-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathstorage.go
1225 lines (1148 loc) · 44.4 KB
/
storage.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
package state
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"golang.org/x/exp/slices"
"github.com/matrix-org/sliding-sync/sync2"
"github.com/getsentry/sentry-go"
"github.com/lib/pq"
"github.com/jmoiron/sqlx"
"github.com/matrix-org/sliding-sync/internal"
"github.com/matrix-org/sliding-sync/sqlutil"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
)
// logger is the package-level structured logger. Although the chain starts
// with zerolog.New(os.Stdout), the final Output(...) call redirects everything
// to a human-readable ConsoleWriter on stderr with HH:MM:SS timestamps.
var logger = zerolog.New(os.Stdout).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "15:04:05",
})
// MaxPostgresParameters is the maximum number of bind parameters ($1..$N)
// PostgreSQL accepts in a single SQL command (a wire-protocol limit).
// Bulk queries must be chunked to stay under this.
const MaxPostgresParameters = 65535
// StartupSnapshot represents a snapshot of startup data for the sliding sync HTTP API instances.
// It is produced atomically by GlobalSnapshot.
type StartupSnapshot struct {
// GlobalMetadata maps room_id -> metadata for every known room.
GlobalMetadata map[string]internal.RoomMetadata // room_id -> metadata
// AllJoinedMembers maps room_id -> list of joined user_ids.
AllJoinedMembers map[string][]string // room_id -> [user_id]
}
// LatestEvents bundles the most recent timeline events for a single room
// together with pagination information. Populated by LatestEventsInRooms.
type LatestEvents struct {
// Timeline holds raw event JSON, oldest event first (most recent last).
Timeline []json.RawMessage
// PrevBatch is a pagination token for events older than Timeline[0].
// May be empty if no token was found.
PrevBatch string
// LatestNID is the event NID of the most recent event in Timeline.
LatestNID int64
}
// DiscardIgnoredMessages modifies the struct in-place, replacing the Timeline with
// a copy that has all ignored events omitted. The order of timelines is preserved.
func (e *LatestEvents) DiscardIgnoredMessages(shouldIgnore func(sender string) bool) {
	// We rebuild the timeline rather than mutating it. Slightly wasteful, but
	// typically nothing is ignored and the timeline is short, so the copy is
	// unlikely to be a bottleneck.
	kept := make([]json.RawMessage, 0, len(e.Timeline))
	for _, raw := range e.Timeline {
		parsed := gjson.ParseBytes(raw)
		// State events are always kept; non-state events are dropped when their
		// sender is ignored.
		isState := parsed.Get("state_key").Exists()
		if !isState && shouldIgnore(parsed.Get("sender").Str) {
			continue
		}
		kept = append(kept, raw)
	}
	e.Timeline = kept
}
// Storage is the top-level persistence layer: it bundles the per-table helper
// types and the shared sqlx DB handle they all operate on.
type Storage struct {
Accumulator *Accumulator
EventsTable *EventTable
ToDeviceTable *ToDeviceTable
UnreadTable *UnreadTable
AccountDataTable *AccountDataTable
InvitesTable *InvitesTable
TransactionsTable *TransactionsTable
DeviceDataTable *DeviceDataTable
ReceiptTable *ReceiptTable
DB *sqlx.DB
// MaxTimelineLimit caps the per-room timeline limit (see LatestEventsInRooms)
// and drives snapshot retention in RemoveInaccessibleStateSnapshots.
MaxTimelineLimit int
// shutdownCh signals background goroutines to stop; presumably closed on
// teardown — the close site is not visible in this section, confirm elsewhere.
shutdownCh chan struct{}
shutdown bool
}
// NewStorage opens a postgres connection for the given URI and returns a
// Storage wrapping it, with Prometheus metrics disabled. On connection
// failure the error is reported to Sentry and the process panics via the
// logger.
func NewStorage(postgresURI string) *Storage {
db, err := sqlx.Open("postgres", postgresURI)
if err != nil {
sentry.CaptureException(err)
// TODO: if we panic(), will sentry have a chance to flush the event?
logger.Panic().Err(err).Str("uri", postgresURI).Msg("failed to open SQL DB")
}
return NewStorageWithDB(db, false)
}
// NewStorageWithDB constructs a Storage (and its internal Accumulator) around
// an existing DB handle. The default MaxTimelineLimit is 50.
// NOTE(review): addPrometheusMetrics is not referenced anywhere in this
// function body — confirm whether metric registration was meant to happen here.
func NewStorageWithDB(db *sqlx.DB, addPrometheusMetrics bool) *Storage {
acc := &Accumulator{
db: db,
roomsTable: NewRoomsTable(db),
eventsTable: NewEventTable(db),
snapshotTable: NewSnapshotsTable(db),
spacesTable: NewSpacesTable(db),
invitesTable: NewInvitesTable(db),
entityName: "server",
}
// The Storage shares the accumulator's events/invites tables rather than
// creating fresh instances.
return &Storage{
Accumulator: acc,
ToDeviceTable: NewToDeviceTable(db),
UnreadTable: NewUnreadTable(db),
EventsTable: acc.eventsTable,
AccountDataTable: NewAccountDataTable(db),
InvitesTable: acc.invitesTable,
TransactionsTable: NewTransactionsTable(db),
DeviceDataTable: NewDeviceDataTable(db),
ReceiptTable: NewReceiptTable(db),
DB: db,
MaxTimelineLimit: 50,
shutdownCh: make(chan struct{}),
}
}
// LatestEventNID returns the highest event NID currently stored.
func (s *Storage) LatestEventNID() (int64, error) {
	highest, err := s.Accumulator.eventsTable.SelectHighestNID()
	return highest, err
}
// AccountData fetches, in a single transaction, the account data rows for
// this user in the given room which match the given event types.
func (s *Storage) AccountData(userID, roomID string, eventTypes []string) (data []AccountData, err error) {
	err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
		var selectErr error
		data, selectErr = s.AccountDataTable.Select(txn, userID, eventTypes, roomID)
		return selectErr
	})
	return data, err
}
// RoomAccountDatasWithType fetches, in a single transaction, all room account
// data rows of the given event type for this user.
func (s *Storage) RoomAccountDatasWithType(userID, eventType string) (data []AccountData, err error) {
	err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
		var selectErr error
		data, selectErr = s.AccountDataTable.SelectWithType(txn, userID, eventType)
		return selectErr
	})
	return data, err
}
// AccountDatas pulls out all account data for this user. If roomIDs is empty,
// global account data is returned. If roomIDs is non-empty, all account data
// for these rooms are extracted.
func (s *Storage) AccountDatas(userID string, roomIDs ...string) (datas []AccountData, err error) {
	err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
		var selectErr error
		datas, selectErr = s.AccountDataTable.SelectMany(txn, userID, roomIDs...)
		return selectErr
	})
	return datas, err
}
// InsertAccountData persists the given raw account data events for this
// user/room (roomID empty for global data) and returns the inserted rows.
// Each event's "type" field is extracted from its JSON.
func (s *Storage) InsertAccountData(userID, roomID string, events []json.RawMessage) ([]AccountData, error) {
	rows := make([]AccountData, 0, len(events))
	for _, ev := range events {
		rows = append(rows, AccountData{
			UserID: userID,
			RoomID: roomID,
			Data:   ev,
			Type:   gjson.ParseBytes(ev).Get("type").Str,
		})
	}
	var inserted []AccountData
	err := sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
		var insertErr error
		inserted, insertErr = s.AccountDataTable.Insert(txn, rows)
		return insertErr
	})
	return inserted, err
}
// PrepareSnapshot prepares a snapshot of the database for calling snapshot
// functions (e.g. AllJoinedMembers / MetadataForAllRooms). It returns the name
// of the temporary table it created within txn.
func (s *Storage) PrepareSnapshot(txn *sqlx.Tx) (tableName string, err error) {
// create a temporary table with all the membership nids for the current snapshots for all rooms.
// A temporary table will be deleted when the postgres session ends (this process quits).
// We insert these into a temporary table to let the query planner make better decisions. In practice,
// if we instead nest this SELECT as a subselect, we see very poor query times for large tables as
// each event NID is queried using a btree index, rather than doing a seq scan as this query will pull
// out ~50% of the rows in syncv3_events.
tempTableName := "temp_snapshot"
_, err = txn.Exec(
`SELECT UNNEST(membership_events) AS membership_nid INTO TEMP ` + tempTableName + ` FROM syncv3_snapshots
JOIN syncv3_rooms ON syncv3_snapshots.snapshot_id = syncv3_rooms.current_snapshot_id`,
)
return tempTableName, err
}
// GlobalSnapshot snapshots the entire database for the purposes of initialising
// a sliding sync instance. It will atomically grab metadata for all rooms and all joined members
// in a single transaction. Errors from each phase are wrapped and reported to Sentry.
func (s *Storage) GlobalSnapshot() (ss StartupSnapshot, err error) {
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
// Build the temp table of membership NIDs which the helpers below rely on.
tempTableName, err := s.PrepareSnapshot(txn)
if err != nil {
err = fmt.Errorf("GlobalSnapshot: failed to call PrepareSnapshot: %w", err)
sentry.CaptureException(err)
return err
}
var metadata map[string]internal.RoomMetadata
ss.AllJoinedMembers, metadata, err = s.AllJoinedMembers(txn, tempTableName)
if err != nil {
err = fmt.Errorf("GlobalSnapshot: failed to call AllJoinedMembers: %w", err)
sentry.CaptureException(err)
return err
}
// Fill in the remaining metadata (names, timestamps, spaces, ...) in-place.
err = s.MetadataForAllRooms(txn, tempTableName, metadata)
if err != nil {
err = fmt.Errorf("GlobalSnapshot: failed to call MetadataForAllRooms: %w", err)
sentry.CaptureException(err)
return err
}
ss.GlobalMetadata = metadata
return err
})
return
}
// MetadataForAllRooms extracts hero info, latest timestamps, name / canonical
// alias / avatar state, room info and space children for all rooms, merging
// the results into result in-place. Requires a prepared snapshot (see
// PrepareSnapshot) in order to be called.
// NOTE(review): tempTableName is not referenced in this function body —
// confirm whether the helpers it calls were meant to receive it.
func (s *Storage) MetadataForAllRooms(txn *sqlx.Tx, tempTableName string, result map[string]internal.RoomMetadata) error {
	// loadMetadata returns the existing metadata for roomID, or a fresh one.
	loadMetadata := func(roomID string) internal.RoomMetadata {
		metadata, ok := result[roomID]
		if !ok {
			metadata = *internal.NewRoomMetadata(roomID)
		}
		return metadata
	}
	// work out latest timestamps
	events, err := s.Accumulator.eventsTable.selectLatestEventByTypeInAllRooms(txn)
	if err != nil {
		return err
	}
	for _, ev := range events {
		metadata := loadMetadata(ev.RoomID)
		// Parse the event JSON once and reuse the result below.
		parsed := gjson.ParseBytes(ev.JSON)
		// For a given room, we'll see many events (one for each event type in the
		// room's state). We need to pick the largest of these events' timestamps here.
		ts := parsed.Get("origin_server_ts").Uint()
		if ts > metadata.LastMessageTimestamp {
			metadata.LastMessageTimestamp = ts
		}
		metadata.LatestEventsByType[parsed.Get("type").Str] = internal.EventMetadata{
			NID:       ev.NID,
			Timestamp: ts,
		}
		// it's possible the latest event is a brand new room not caught by the first SELECT for joined
		// rooms e.g when you're invited to a room so we need to make sure to set the metadata again here
		// TODO: is the comment above stale now that we explicitly call NewRoomMetadata above
		// when handling invites?
		metadata.RoomID = ev.RoomID
		result[ev.RoomID] = metadata
	}
	// Select the name / canonical alias / avatar for all rooms
	roomIDToStateEvents, err := s.currentNotMembershipStateEventsInAllRooms(txn, []string{
		"m.room.name", "m.room.canonical_alias", "m.room.avatar",
	})
	if err != nil {
		return fmt.Errorf("failed to load state events for all rooms: %w", err)
	}
	for roomID, stateEvents := range roomIDToStateEvents {
		metadata := loadMetadata(roomID)
		for _, ev := range stateEvents {
			// Only honour the canonical (empty state key) variant of each event.
			if ev.Type == "m.room.name" && ev.StateKey == "" {
				metadata.NameEvent = gjson.ParseBytes(ev.JSON).Get("content.name").Str
			} else if ev.Type == "m.room.canonical_alias" && ev.StateKey == "" {
				metadata.CanonicalAlias = gjson.ParseBytes(ev.JSON).Get("content.alias").Str
			} else if ev.Type == "m.room.avatar" && ev.StateKey == "" {
				metadata.AvatarEvent = gjson.ParseBytes(ev.JSON).Get("content.url").Str
			}
		}
		result[roomID] = metadata
	}
	roomInfos, err := s.Accumulator.roomsTable.SelectRoomInfos(txn)
	if err != nil {
		return fmt.Errorf("failed to select room infos: %w", err)
	}
	var spaceRoomIDs []string
	for _, info := range roomInfos {
		metadata := loadMetadata(info.ID)
		metadata.Encrypted = info.IsEncrypted
		metadata.UpgradedRoomID = info.UpgradedRoomID
		metadata.PredecessorRoomID = info.PredecessorRoomID
		metadata.RoomType = info.Type
		result[info.ID] = metadata
		if metadata.IsSpace() {
			spaceRoomIDs = append(spaceRoomIDs, info.ID)
		}
	}
	// select space children
	spaceRoomToRelations, err := s.Accumulator.spacesTable.SelectChildren(txn, spaceRoomIDs)
	if err != nil {
		return fmt.Errorf("failed to select space children: %w", err)
	}
	for roomID, relations := range spaceRoomToRelations {
		if _, exists := result[roomID]; !exists {
			// this can happen when you join a space (so it populates the spaces table) then leave the space,
			// so there are no joined members in the space so result doesn't include the room. In this case,
			// we don't want to have a stub metadata with just the space children, so skip it.
			continue
		}
		metadata := loadMetadata(roomID)
		metadata.ChildSpaceRooms = make(map[string]struct{}, len(relations))
		for _, r := range relations {
			// For now we only honour child state events, but we store all the mappings just in case.
			if r.Relation == RelationMSpaceChild {
				metadata.ChildSpaceRooms[r.Child] = struct{}{}
			}
		}
		result[roomID] = metadata
	}
	return nil
}
// ResetMetadataState updates the given metadata in-place to reflect the current state
// of the room. This is only safe to call from the subscriber goroutine; it is not safe
// to call from the connection goroutines.
// TODO: could have this create a new RoomMetadata and get the caller to assign it.
func (s *Storage) ResetMetadataState(metadata *internal.RoomMetadata) error {
	var events []Event
	// Pull name/avatar/alias/encryption state plus all join/invite memberships
	// from the room's current snapshot, in ascending NID order.
	err := s.DB.Select(&events, `
WITH snapshot(events, membership_events) AS (
SELECT events, membership_events
FROM syncv3_snapshots
JOIN syncv3_rooms ON snapshot_id = current_snapshot_id
WHERE syncv3_rooms.room_id = $1
)
SELECT event_id, event_type, state_key, event, membership
FROM syncv3_events JOIN snapshot ON (
event_nid = ANY (ARRAY_CAT(events, membership_events))
)
WHERE (event_type IN ('m.room.name', 'm.room.avatar', 'm.room.canonical_alias', 'm.room.encryption') AND state_key = '')
OR (event_type = 'm.room.member' AND membership IN ('join', '_join', 'invite', '_invite'))
ORDER BY event_nid ASC
;`, metadata.RoomID)
	if err != nil {
		return fmt.Errorf("ResetMetadataState[%s]: %w", metadata.RoomID, err)
	}
	// Ring buffer holding up to 6 of the most recently appended membership
	// events, which become the room's heroes below.
	heroMemberships := circularSlice[*Event]{max: 6}
	metadata.JoinCount = 0
	metadata.InviteCount = 0
	metadata.ChildSpaceRooms = make(map[string]struct{})
	for i, ev := range events {
		switch ev.Type {
		case "m.room.name":
			metadata.NameEvent = gjson.GetBytes(ev.JSON, "content.name").Str
		case "m.room.avatar":
			metadata.AvatarEvent = gjson.GetBytes(ev.JSON, "content.url").Str
		case "m.room.canonical_alias":
			metadata.CanonicalAlias = gjson.GetBytes(ev.JSON, "content.alias").Str
		case "m.room.encryption":
			metadata.Encrypted = true
		case "m.room.member":
			heroMemberships.append(&events[i])
			// Count both real and proxy-internal ("_"-prefixed) memberships.
			switch ev.Membership {
			case "join", "_join":
				metadata.JoinCount++
			case "invite", "_invite":
				metadata.InviteCount++
			}
		case "m.space.child":
			metadata.ChildSpaceRooms[ev.StateKey] = struct{}{}
		}
	}
	metadata.Heroes = make([]internal.Hero, 0, len(heroMemberships.vals))
	for _, ev := range heroMemberships.vals {
		parsed := gjson.ParseBytes(ev.JSON)
		hero := internal.Hero{
			ID:     ev.StateKey,
			Name:   parsed.Get("content.displayname").Str,
			Avatar: parsed.Get("content.avatar_url").Str,
		}
		metadata.Heroes = append(metadata.Heroes, hero)
	}
	// For now, don't bother reloading PredecessorID and UpgradedRoomID.
	// These shouldn't be changing during a room's lifetime in normal operation.
	// We haven't updated LatestEventsByType because that's not part of the timeline.
	return nil
}
// FetchMemberships looks up the latest snapshot for the given room and determines the
// latest membership events in the room. Returns
//   - the list of joined members,
//   - the list of invited members, and then
//   - the list of all other memberships. (This is called "leaves", but includes bans. It
//     also includes knocks, but the proxy doesn't support those.)
//
// Each lists' members are arranged in no particular order.
//
// TODO: there is a very similar query in ResetMetadataState which also selects
// events rows for memberships. It is a shame to have to do this twice---can we
// query once and pass the data around?
func (s *Storage) FetchMemberships(roomID string) (joins, invites, leaves []string, err error) {
	var events []Event
	err = s.DB.Select(&events, `
WITH snapshot(membership_nids) AS (
SELECT membership_events
FROM syncv3_snapshots
JOIN syncv3_rooms ON snapshot_id = current_snapshot_id
WHERE syncv3_rooms.room_id = $1
)
SELECT state_key, membership
FROM syncv3_events JOIN snapshot ON (
event_nid = ANY( membership_nids )
)
`, roomID)
	if err != nil {
		return nil, nil, nil, err
	}
	joins = make([]string, 0, len(events))
	invites = make([]string, 0, len(events))
	leaves = make([]string, 0, len(events))
	for _, e := range events {
		// "_"-prefixed memberships are proxy-internal variants; treat them the
		// same as their plain counterparts.
		switch e.Membership {
		case "join", "_join":
			joins = append(joins, e.StateKey)
		case "invite", "_invite":
			invites = append(invites, e.StateKey)
		default:
			leaves = append(leaves, e.StateKey)
		}
	}
	return
}
// currentNotMembershipStateEventsInAllRooms returns all current NOT MEMBERSHIP
// state events matching the event types given, in all rooms. Returns a map of
// room ID to events in that room.
func (s *Storage) currentNotMembershipStateEventsInAllRooms(txn *sqlx.Tx, eventTypes []string) (map[string][]Event, error) {
	query, args, err := sqlx.In(
		`SELECT syncv3_events.room_id, syncv3_events.event_type, syncv3_events.state_key, syncv3_events.event FROM syncv3_events
WHERE syncv3_events.event_type IN (?)
AND syncv3_events.event_nid IN (
SELECT UNNEST(events) FROM syncv3_snapshots WHERE syncv3_snapshots.snapshot_id IN (SELECT current_snapshot_id FROM syncv3_rooms)
)`,
		eventTypes,
	)
	if err != nil {
		return nil, err
	}
	rows, err := txn.Query(txn.Rebind(query), args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	result := make(map[string][]Event)
	for rows.Next() {
		var ev Event
		if err := rows.Scan(&ev.RoomID, &ev.Type, &ev.StateKey, &ev.JSON); err != nil {
			return nil, err
		}
		result[ev.RoomID] = append(result[ev.RoomID], ev)
	}
	// rows.Next returns false on both normal completion and mid-iteration
	// failure; surface any iteration error rather than silently returning a
	// truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// Accumulate persists the given timeline for roomID on behalf of userID,
// inside a single transaction. It is a no-op when the timeline has no events.
func (s *Storage) Accumulate(userID, roomID string, timeline sync2.TimelineResponse) (AccumulateResult, error) {
	if len(timeline.Events) == 0 {
		return AccumulateResult{}, nil
	}
	var result AccumulateResult
	err := sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
		var accErr error
		result, accErr = s.Accumulator.Accumulate(txn, userID, roomID, timeline)
		return accErr
	})
	return result, err
}
// Initialise sets the current state of the room to the given state events,
// delegating entirely to the Accumulator.
func (s *Storage) Initialise(roomID string, state []json.RawMessage) (InitialiseResult, error) {
return s.Accumulator.Initialise(roomID, state)
}
// EventNIDs fetches the raw JSON form of events given a slice of eventNIDs. The events
// are returned in ascending NID order; the order of eventNIDs is ignored.
func (s *Storage) EventNIDs(eventNIDs []int64) ([]json.RawMessage, error) {
	// TODO: this selects a bunch of rows from the DB, but we only use the raw
	// JSON itself.
	rows, err := s.EventsTable.SelectByNIDs(nil, true, eventNIDs)
	if err != nil {
		return nil, err
	}
	raw := make([]json.RawMessage, 0, len(rows))
	for _, row := range rows {
		raw = append(raw, row.JSON)
	}
	return raw, nil
}
// StateSnapshot loads the state events referenced by the given snapshot ID
// (both membership and other events) and returns their raw JSON.
func (s *Storage) StateSnapshot(snapID int64) (state []json.RawMessage, err error) {
	err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
		snapshotRow, err := s.Accumulator.snapshotTable.Select(txn, snapID)
		if err != nil {
			return err
		}
		events, err := s.Accumulator.eventsTable.SelectByNIDs(txn, true, append(snapshotRow.MembershipEvents, snapshotRow.OtherEvents...))
		if err != nil {
			return fmt.Errorf("failed to select state snapshot %v: %w", snapID, err)
		}
		state = make([]json.RawMessage, len(events))
		for i := range events {
			state[i] = events[i].JSON
		}
		return nil
	})
	return
}
// RoomStateAfterEventPosition looks up room state after the given event position and no further.
// eventTypesToStateKeys is a map of event type to a list of state keys for that event type.
// If the list of state keys is empty then all events matching that event type will be returned.
// If the map is empty entirely, then all room state will be returned.
func (s *Storage) RoomStateAfterEventPosition(ctx context.Context, roomIDs []string, pos int64, eventTypesToStateKeys map[string][]string) (roomToEvents map[string][]Event, err error) {
	_, span := internal.StartSpan(ctx, "RoomStateAfterEventPosition")
	defer span.End()
	roomToEvents = make(map[string][]Event, len(roomIDs))
	roomIndex := make(map[string]int, len(roomIDs))
	err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
		// we have 2 ways to pull the latest events:
		//  - superfast rooms table (which races as it can be updated before the new state hits the dispatcher)
		//  - slower events table query
		// we will try to fulfill as many rooms as possible with the rooms table, only using the slower events table
		// query if we can prove we have races. We can prove this because the latest NIDs will be > pos, meaning the
		// database state is ahead of the in-memory state (which is normal as we update the DB first). This should
		// happen infrequently though, so we will warn about this behaviour.
		roomToLatestNIDs, err := s.Accumulator.roomsTable.LatestNIDs(txn, roomIDs)
		if err != nil {
			return err
		}
		fastNIDs := make([]int64, 0, len(roomToLatestNIDs))
		var slowRooms []string
		for roomID, latestNID := range roomToLatestNIDs {
			if latestNID > pos {
				slowRooms = append(slowRooms, roomID)
			} else {
				fastNIDs = append(fastNIDs, latestNID)
			}
		}
		latestEvents, err := s.Accumulator.eventsTable.SelectByNIDs(txn, true, fastNIDs)
		if err != nil {
			return fmt.Errorf("failed to select latest nids in rooms %v: %w", roomIDs, err)
		}
		if len(slowRooms) > 0 {
			logger.Warn().Int("slow_rooms", len(slowRooms)).Msg("RoomStateAfterEventPosition: pos value provided is far behind the database copy, performance degraded")
			latestSlowEvents, err := s.Accumulator.eventsTable.LatestEventInRooms(txn, slowRooms, pos)
			if err != nil {
				return err
			}
			latestEvents = append(latestEvents, latestSlowEvents...)
		}
		for i, ev := range latestEvents {
			roomIndex[ev.RoomID] = i
			if ev.BeforeStateSnapshotID == 0 {
				// if there is no before snapshot then this last event NID is _part of_ the initial state,
				// ergo the state after this == the current state and we can safely ignore the lastEventNID
				ev.BeforeStateSnapshotID, err = s.Accumulator.roomsTable.CurrentAfterSnapshotID(txn, ev.RoomID)
				if err != nil {
					return err
				}
				latestEvents[i] = ev
			}
		}
		if len(eventTypesToStateKeys) == 0 {
			// No filter given: return each room's full state snapshot.
			for _, ev := range latestEvents {
				snapshotRow, err := s.Accumulator.snapshotTable.Select(txn, ev.BeforeStateSnapshotID)
				if err != nil {
					return err
				}
				allStateEventNIDs := append(snapshotRow.MembershipEvents, snapshotRow.OtherEvents...)
				// we need to roll forward if this event is state
				if gjson.ParseBytes(ev.JSON).Get("state_key").Exists() {
					if ev.ReplacesNID != 0 {
						// we determined at insert time of this event that this event replaces a nid in the snapshot.
						// find it and replace it
						for j := range allStateEventNIDs {
							if allStateEventNIDs[j] == ev.ReplacesNID {
								allStateEventNIDs[j] = ev.NID
								break
							}
						}
					} else {
						// the event is still state, but it doesn't replace anything, so just add it onto the snapshot,
						// but only if we haven't already
						alreadyExists := false
						for _, nid := range allStateEventNIDs {
							if nid == ev.NID {
								alreadyExists = true
								break
							}
						}
						if !alreadyExists {
							allStateEventNIDs = append(allStateEventNIDs, ev.NID)
						}
					}
				}
				events, err := s.Accumulator.eventsTable.SelectByNIDs(txn, true, allStateEventNIDs)
				if err != nil {
					return fmt.Errorf("failed to select state snapshot %v for room %v: %w", ev.BeforeStateSnapshotID, ev.RoomID, err)
				}
				roomToEvents[ev.RoomID] = events
			}
		} else {
			// do an optimised query to pull out only the event types and state keys we care about.
			var args []interface{} // event type, state key, event type, state key, ....
			snapIDs := make([]int64, len(latestEvents))
			for i := range latestEvents {
				snapIDs[i] = latestEvents[i].BeforeStateSnapshotID
			}
			args = append(args, pq.Int64Array(snapIDs))
			var wheres []string
			hasMembershipFilter := false
			hasOtherFilter := false
			for evType, skeys := range eventTypesToStateKeys {
				if evType == "m.room.member" {
					hasMembershipFilter = true
				} else {
					hasOtherFilter = true
				}
				for _, skey := range skeys {
					args = append(args, evType, skey)
					wheres = append(wheres, "(syncv3_events.event_type = ? AND syncv3_events.state_key = ?)")
				}
				if len(skeys) == 0 {
					args = append(args, evType)
					wheres = append(wheres, "syncv3_events.event_type = ?")
				}
			}
			// figure out which state events to look at - if there is no m.room.member filter we can be super fast
			nidcols := "array_cat(events, membership_events)"
			if hasMembershipFilter && !hasOtherFilter {
				nidcols = "membership_events"
			} else if !hasMembershipFilter && hasOtherFilter {
				nidcols = "events"
			}
			// it's not possible for there to be no membership filter and no other filter, we wouldn't be executing this code
			// it is possible to have both, so neither if will execute.
			// Similar to CurrentStateEventsInAllRooms
			// We're using a CTE here, since unnestting the nids is quite expensive. Using the array as is
			// and using ANY() instead performs quite well (e.g. 86k membership events and 130ms execution time, vs
			// the previous query with unnest took 2.5s)
			query, args, err := sqlx.In(
				`
WITH nids AS (
SELECT `+nidcols+` AS allNids FROM syncv3_snapshots WHERE syncv3_snapshots.snapshot_id = ANY(?)
)
SELECT syncv3_events.event_nid, syncv3_events.room_id, syncv3_events.event_type, syncv3_events.state_key, syncv3_events.event
FROM syncv3_events, nids
WHERE (`+strings.Join(wheres, " OR ")+`) AND syncv3_events.event_nid = ANY(nids.allNids)
ORDER BY syncv3_events.event_nid ASC`,
				args...,
			)
			if err != nil {
				return fmt.Errorf("failed to form sql query: %w", err)
			}
			rows, err := txn.Query(txn.Rebind(query), args...)
			if err != nil {
				return fmt.Errorf("failed to execute query: %w", err)
			}
			defer rows.Close()
			for rows.Next() {
				var ev Event
				if err := rows.Scan(&ev.NID, &ev.RoomID, &ev.Type, &ev.StateKey, &ev.JSON); err != nil {
					return err
				}
				i := roomIndex[ev.RoomID]
				if latestEvents[i].ReplacesNID == ev.NID {
					// this event is replaced by the last event
					ev = latestEvents[i]
				}
				roomToEvents[ev.RoomID] = append(roomToEvents[ev.RoomID], ev)
			}
			// rows.Next returns false on both completion and mid-iteration
			// failure; surface the error rather than returning truncated state.
			if err := rows.Err(); err != nil {
				return err
			}
			// handle the most recent events which won't be in the snapshot but may need to be.
			// we handle the replace case but don't handle brand new state events
			for i := range latestEvents {
				if latestEvents[i].ReplacesNID == 0 {
					// check if we should include it
					for evType, stateKeys := range eventTypesToStateKeys {
						if evType != latestEvents[i].Type {
							continue
						}
						if len(stateKeys) == 0 {
							roomToEvents[latestEvents[i].RoomID] = append(roomToEvents[latestEvents[i].RoomID], latestEvents[i])
						} else {
							for _, skey := range stateKeys {
								if skey == latestEvents[i].StateKey {
									roomToEvents[latestEvents[i].RoomID] = append(roomToEvents[latestEvents[i].RoomID], latestEvents[i])
									break
								}
							}
						}
					}
				}
			}
		}
		return nil
	})
	return
}
// LatestEventsInRooms returns the most recent events
// - in the given rooms
// - that the user has permission to see
// - with NIDs <= `to`.
// Up to `limit` events are chosen per room. This limit is itself capped by MaxTimelineLimit.
func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64, limit int) (map[string]*LatestEvents, error) {
// Work out, per room, the NID window [from,to] this user may see.
roomIDToRange, err := s.visibleEventNIDsBetweenForRooms(userID, roomIDs, 0, to)
if err != nil {
return nil, err
}
if s.MaxTimelineLimit != 0 && limit > s.MaxTimelineLimit {
limit = s.MaxTimelineLimit
}
result := make(map[string]*LatestEvents, len(roomIDs))
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
for roomID, r := range roomIDToRange {
var earliestEventNID int64
var latestEventNID int64
var roomEvents []json.RawMessage
// the most recent event will be first
events, err := s.EventsTable.SelectLatestEventsBetween(txn, roomID, r[0]-1, r[1], limit)
if err != nil {
return fmt.Errorf("room %s failed to SelectEventsBetween: %s", roomID, err)
}
for _, ev := range events {
if latestEventNID == 0 { // set first time and never again
latestEventNID = ev.NID
}
roomEvents = append(roomEvents, ev.JSON)
// events are newest-first, so the last one appended is the earliest
earliestEventNID = ev.NID
if len(roomEvents) >= limit {
break
}
}
// we want the most recent event to be last, so reverse the slice now in-place.
slices.Reverse(roomEvents)
latestEvents := LatestEvents{
LatestNID: latestEventNID,
Timeline: roomEvents,
}
if earliestEventNID != 0 {
// the oldest event needs a prev batch token, so find one now
prevBatch, err := s.EventsTable.SelectClosestPrevBatch(txn, roomID, earliestEventNID)
if err != nil {
return fmt.Errorf("failed to select prev_batch for room %s : %s", roomID, err)
}
latestEvents.PrevBatch = prevBatch
}
result[roomID] = &latestEvents
}
return nil
})
return result, err
}
// RemoveInaccessibleStateSnapshots removes state snapshots which cannot be
// accessed by clients. The latest MaxTimelineLimit snapshots must be kept, +1
// for the current state. This handles the worst case where all
// MaxTimelineLimit events are state events and hence each event makes a new
// snapshot. We can safely delete all snapshots older than this, as it's not
// possible to reach this snapshot as the proxy does not handle historical
// state (deferring to the homeserver for that).
func (s *Storage) RemoveInaccessibleStateSnapshots() error {
	numToKeep := s.MaxTimelineLimit + 1
	// Create a CTE which ranks each snapshot so we can figure out which snapshots to delete
	// then execute the delete using the CTE. The retention count is passed as a
	// bind parameter rather than spliced into the SQL text.
	//
	// A per-room version of this query:
	//   WITH ranked_snapshots AS (
	//       SELECT
	//           snapshot_id,
	//           room_id,
	//           ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
	//       FROM syncv3_snapshots
	//   )
	//   DELETE FROM syncv3_snapshots WHERE snapshot_id IN(
	//       SELECT snapshot_id FROM ranked_snapshots WHERE row_num > 51 AND room_id='!....'
	//   );
	awfulQuery := `WITH ranked_snapshots AS (
SELECT
snapshot_id,
room_id,
ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
FROM
syncv3_snapshots
)
DELETE FROM syncv3_snapshots USING ranked_snapshots
WHERE syncv3_snapshots.snapshot_id = ranked_snapshots.snapshot_id
AND ranked_snapshots.row_num > $1;`
	result, err := s.DB.Exec(awfulQuery, numToKeep)
	if err != nil {
		return fmt.Errorf("failed to RemoveInaccessibleStateSnapshots: Exec %w", err)
	}
	// Best effort logging only: a RowsAffected failure is not fatal.
	rowsAffected, err := result.RowsAffected()
	if err == nil {
		logger.Info().Int64("rows_affected", rowsAffected).Msg("RemoveInaccessibleStateSnapshots: deleted rows")
	}
	return nil
}
// GetClosestPrevBatch looks up the prev_batch token closest to the given
// event NID in this room. Best effort: returns "" when the lookup fails.
func (s *Storage) GetClosestPrevBatch(roomID string, eventNID int64) (prevBatch string) {
	// The transaction error is deliberately discarded: this is best effort.
	_ = sqlutil.WithTransaction(s.DB, func(txn *sqlx.Tx) error {
		var err error
		prevBatch, err = s.EventsTable.SelectClosestPrevBatch(txn, roomID, eventNID)
		return err
	})
	return prevBatch
}
// visibleEventNIDsBetweenForRooms determines which events a given user has permission to see.
// It accepts a nid range [from, to]. For each given room, it calculates the NID range
// [A1, B1] within [from, to] in which the user has permission to see events.
func (s *Storage) visibleEventNIDsBetweenForRooms(userID string, roomIDs []string, from, to int64) (map[string][2]int64, error) {
	// load *THESE* joined rooms for this user at from (inclusive)
	var membershipEvents []Event
	var err error
	if from != 0 {
		// if from==0 then this query will return nothing, so optimise it out
		membershipEvents, err = s.Accumulator.eventsTable.SelectEventsWithTypeStateKeyInRooms(roomIDs, "m.room.member", userID, 0, from)
		if err != nil {
			// NB: previously this message named the (different) exported method
			// VisibleEventNIDsBetweenForRooms, which made logs misleading.
			return nil, fmt.Errorf("visibleEventNIDsBetweenForRooms.SelectEventsWithTypeStateKeyInRooms: %s", err)
		}
	}
	joinTimingsAtFromByRoomID, err := s.determineJoinedRoomsFromMemberships(membershipEvents)
	if err != nil {
		return nil, fmt.Errorf("failed to work out joined rooms for %s at pos %d: %s", userID, from, err)
	}
	// load membership deltas for *THESE* rooms for this user
	membershipEvents, err = s.Accumulator.eventsTable.SelectEventsWithTypeStateKeyInRooms(roomIDs, "m.room.member", userID, from, to)
	if err != nil {
		return nil, fmt.Errorf("failed to load membership events: %s", err)
	}
	return s.visibleEventNIDsWithData(joinTimingsAtFromByRoomID, membershipEvents, userID, from, to)
}
// VisibleEventNIDsBetween works out the NID ranges to pull events from for this user.
// Given a from and to event nid stream position, it returns a map of room ID to a
// 2-element from|to position pair. These positions are all INCLUSIVE, and the client
// should be informed of these events at some point. For example:
//
//	                     Stream Positions
//	           1     2   3    4   5   6   7   8   9   10
//	   Room A  Maj   E   E                E
//	   Room B                 E   Maj E   E
//	   Room C                 E   Mal E   E   E   E   E   (a already joined to this room at position 0)
//
// E=message event, M=membership event, followed by user letter, followed by 'i' or 'j' or 'l' for invite|join|leave
//
//   - For Room A: from=1, to=10, returns { RoomA: [ 1,10 ]}  (tests events in joined room)
//   - For Room B: from=1, to=10, returns { RoomB: [ 5,10 ]}  (tests joining a room starts events)
//   - For Room C: from=1, to=10, returns { RoomC: [ 0,9 ]}  (tests leaving a room stops events)
//
// In cases where a user joins/leaves a room multiple times in the nid range, only the last range is returned.
// This is critical to ensure we don't skip out timeline events due to history visibility (which the proxy defers
// to the upstream HS for). See https://github.com/matrix-org/sliding-sync/issues/365 for what happens if we returned
// all ranges.
//
//	                     Stream Positions
//	           1     2   3    4   5   6   7   8   9   10  11  12  13  14  15
//	   Room D  Maj                E   Mal E   Maj E   Mal E
//	   Room E        E   Mai  E                               E   Maj E   E
//
//   - For Room D: from=1, to=15 returns { RoomD: [ 8,10 ] } (tests multi-join/leave)
//   - For Room E: from=1, to=15 returns { RoomE: [ 13,15 ] } (tests invites)
func (s *Storage) VisibleEventNIDsBetween(userID string, from, to int64) (map[string][2]int64, error) {
	// All rooms the user was joined to at `from` (inclusive).
	joinTimings, err := s.JoinedRoomsAfterPosition(userID, from)
	if err != nil {
		return nil, fmt.Errorf("failed to work out joined rooms for %s at pos %d: %s", userID, from, err)
	}
	// Every membership delta for this user, across all rooms, within the nid range.
	memberships, err := s.Accumulator.eventsTable.SelectEventsWithTypeStateKey("m.room.member", userID, from, to)
	if err != nil {
		return nil, fmt.Errorf("failed to load membership events: %s", err)
	}
	return s.visibleEventNIDsWithData(joinTimings, memberships, userID, from, to)
}
// visibleEventNIDsWithData implements the range-calculation algorithm shared by
// VisibleEventNIDsBetween and visibleEventNIDsBetweenForRooms. joinTimingsAtFromByRoomID
// is the set of rooms the user was joined to at `from`; membershipEvents are the user's
// membership deltas (in nid order) within [from, to]. It returns, per room, the single
// INCLUSIVE [from, to] NID range the user is allowed to see.
func (s *Storage) visibleEventNIDsWithData(joinTimingsAtFromByRoomID map[string]internal.EventMetadata, membershipEvents []Event, userID string, from, to int64) (map[string][2]int64, error) {
	// load membership events in order and bucket based on room ID
	roomIDToLogs := make(map[string][]membershipEvent)
	for _, ev := range membershipEvents {
		evJSON := gjson.ParseBytes(ev.JSON)
		roomIDToLogs[ev.RoomID] = append(roomIDToLogs[ev.RoomID], membershipEvent{
			Event:      ev,
			StateKey:   evJSON.Get("state_key").Str,
			Membership: evJSON.Get("content.membership").Str,
		})
	}
	// Performs the algorithm
	calculateVisibleEventNIDs := func(isJoined bool, fromIncl, toIncl int64, logs []membershipEvent) [2]int64 {
		// short circuit when there are no membership deltas
		if len(logs) == 0 {
			return [2]int64{
				fromIncl, toIncl, // TODO: is this actually valid? Surely omitting it is the right answer?
			}
		}
		var result [][2]int64
		var startIndex int64 = -1
		if isJoined {
			startIndex = fromIncl
		}
		for _, memEvent := range logs {
			// check for a valid transition (join->leave|ban or leave|invite->join) - we won't always get valid transitions
			// e.g logs will be there for things like leave->ban which we don't care about
			isValidTransition := false
			if isJoined && (memEvent.Membership == "leave" || memEvent.Membership == "ban") {
				isValidTransition = true
			} else if !isJoined && memEvent.Membership == "join" {
				isValidTransition = true
			} else if !isJoined && memEvent.Membership == "invite" {
				// short-circuit: invites are sent on their own and don't affect ranges
				result = append(result, [2]int64{memEvent.NID, memEvent.NID})
				continue
			}
			if !isValidTransition {
				continue
			}
			if isJoined {
				// transitioning to leave, we get all events up to and including the leave event
				result = append(result, [2]int64{startIndex, memEvent.NID})
				isJoined = false
			} else {
				// transitioning to joined, we will get the join and some more events in a bit
				startIndex = memEvent.NID
				isJoined = true
			}
		}
		// if we are still joined to the room at this point, grab all events up to toIncl
		if isJoined {
			result = append(result, [2]int64{startIndex, toIncl})
		}
		if len(result) == 0 {
			return [2]int64{}
		}
		// we only care about the LAST nid range, otherwise we can end up with gaps being returned in the
		// timeline. See https://github.com/matrix-org/sliding-sync/issues/365
		return result[len(result)-1]
	}
	// For each joined room, perform the algorithm and delete the logs afterwards
	result := make(map[string][2]int64)
	// NB: `for k, _ := range` is unidiomatic; the blank identifier is redundant.
	for joinedRoomID := range joinTimingsAtFromByRoomID {
		roomResult := calculateVisibleEventNIDs(true, from, to, roomIDToLogs[joinedRoomID])
		result[joinedRoomID] = roomResult
		delete(roomIDToLogs, joinedRoomID)
	}
	// Handle rooms which we are not joined to but have logs for
	for roomID, logs := range roomIDToLogs {
		roomResult := calculateVisibleEventNIDs(false, from, to, logs)
		result[roomID] = roomResult
	}
	return result, nil
}
// RoomMembershipDelta returns the JSON of up to `limit` m.room.member events in the given
// room within the nid range (from, to], along with the NID of the last event returned
// (upTo). If there are no membership events in the range, upTo is 0 and eventJSON is nil.
func (s *Storage) RoomMembershipDelta(roomID string, from, to int64, limit int) (eventJSON []json.RawMessage, upTo int64, err error) {
	err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
		memberNIDs, innerErr := s.Accumulator.eventsTable.SelectEventNIDsWithTypeInRoom(txn, "m.room.member", limit, roomID, from, to)
		if innerErr != nil {
			return innerErr
		}
		if len(memberNIDs) == 0 {
			// nothing in range: leave upTo at 0 and eventJSON nil
			return nil
		}
		upTo = memberNIDs[len(memberNIDs)-1]
		events, innerErr := s.Accumulator.eventsTable.SelectByNIDs(txn, true, memberNIDs)
		if innerErr != nil {
			return innerErr
		}
		eventJSON = make([]json.RawMessage, 0, len(events))
		for _, ev := range events {
			eventJSON = append(eventJSON, ev.JSON)
		}
		return nil
	})
	return eventJSON, upTo, err
}
// Extract all rooms with joined members, and include the joined user list. Requires a prepared snapshot in order to be called.
// Populates the join/invite count and heroes for the returned metadata.
func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMembers map[string][]string, metadata map[string]internal.RoomMetadata, err error) {
// Select the most recent members for each room to serve as Heroes. The spec is ambiguous here:
// "This should be the first 5 members of the room, ordered by stream ordering, which are joined or invited."
// Unclear if this is the first 5 *most recent* (backwards) or forwards. For now we'll use the most recent
// ones, and select 6 of them so we can always use 5 no matter who is requesting the room name.