
Commit df36945

Support pagination tokens from /sync and /messages in the relations API. (#11952)

1 parent 337f38c

File tree

5 files changed: +217 −53 lines changed

changelog.d/11952.bugfix

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Fix a long-standing bug where pagination tokens from `/sync` and `/messages` could not be provided to the `/relations` API.
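A minimal client-side sketch of what this fix enables: a pagination token taken from a `/sync` response can be passed straight to the `/relations` endpoint's `from` parameter. The homeserver URL, access token, room and event IDs, and the unstable endpoint prefix below are illustrative placeholders, not values taken from this commit.

```python
# Hypothetical client flow; identifiers and the endpoint prefix are placeholders.
import requests

HOMESERVER = "https://example.org"                    # placeholder homeserver
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder access token

# 1. Take a pagination token from /sync (e.g. a room timeline's prev_batch).
sync = requests.get(f"{HOMESERVER}/_matrix/client/v3/sync", headers=HEADERS).json()
room = sync["rooms"]["join"]["!room:example.org"]     # placeholder room ID
prev_batch = room["timeline"]["prev_batch"]

# 2. Feed that same token to /relations; before this fix, only the legacy
#    "<topological>-<stream>" relation tokens were accepted here.
rel = requests.get(
    f"{HOMESERVER}/_matrix/client/unstable/rooms/!room:example.org"
    "/relations/$event:example.org",
    params={"from": prev_batch, "limit": 5},
    headers=HEADERS,
)
print(rel.json().get("next_batch"))
```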

synapse/rest/client/relations.py

Lines changed: 39 additions & 18 deletions
@@ -32,14 +32,45 @@
     PaginationChunk,
     RelationPaginationToken,
 )
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomStreamToken, StreamToken

 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore

 logger = logging.getLogger(__name__)


+async def _parse_token(
+    store: "DataStore", token: Optional[str]
+) -> Optional[StreamToken]:
+    """
+    For backwards compatibility support RelationPaginationToken, but new pagination
+    tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
+    """
+    if not token:
+        return None
+    # Luckily the format for StreamToken and RelationPaginationToken differ enough
+    # that they can easily be separated. An "_" appears in the serialization of
+    # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
+    # "-" only for separators.
+    if "_" in token:
+        return await StreamToken.from_string(store, token)
+    else:
+        relation_token = RelationPaginationToken.from_string(token)
+        return StreamToken(
+            room_key=RoomStreamToken(relation_token.topological, relation_token.stream),
+            presence_key=0,
+            typing_key=0,
+            receipt_key=0,
+            account_data_key=0,
+            push_rules_key=0,
+            to_device_key=0,
+            device_list_key=0,
+            groups_key=0,
+        )
+
+
 class RelationPaginationServlet(RestServlet):
     """API to paginate relations on an event by topological ordering, optionally
     filtered by relation type and event type.

@@ -88,13 +119,8 @@ async def on_GET(
             pagination_chunk = PaginationChunk(chunk=[])
         else:
             # Return the relations
-            from_token = None
-            if from_token_str:
-                from_token = RelationPaginationToken.from_string(from_token_str)
-
-            to_token = None
-            if to_token_str:
-                to_token = RelationPaginationToken.from_string(to_token_str)
+            from_token = await _parse_token(self.store, from_token_str)
+            to_token = await _parse_token(self.store, to_token_str)

             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,

@@ -125,7 +151,7 @@ async def on_GET(
             events, now, bundle_aggregations=aggregations
         )

-        return_value = pagination_chunk.to_dict()
+        return_value = await pagination_chunk.to_dict(self.store)
         return_value["chunk"] = serialized_events
         return_value["original_event"] = original_event

@@ -216,7 +242,7 @@ async def on_GET(
             to_token=to_token,
         )

-        return 200, pagination_chunk.to_dict()
+        return 200, await pagination_chunk.to_dict(self.store)


 class RelationAggregationGroupPaginationServlet(RestServlet):

@@ -287,13 +313,8 @@ async def on_GET(
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")

-        from_token = None
-        if from_token_str:
-            from_token = RelationPaginationToken.from_string(from_token_str)
-
-        to_token = None
-        if to_token_str:
-            to_token = RelationPaginationToken.from_string(to_token_str)
+        from_token = await _parse_token(self.store, from_token_str)
+        to_token = await _parse_token(self.store, to_token_str)

         result = await self.store.get_relations_for_event(
             event_id=parent_id,

@@ -313,7 +334,7 @@ async def on_GET(
         now = self.clock.time_msec()
         serialized_events = self._event_serializer.serialize_events(events, now)

-        return_value = result.to_dict()
+        return_value = await result.to_dict(self.store)
         return_value["chunk"] = serialized_events

         return 200, return_value
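A small standalone illustration of the format check `_parse_token` relies on. The token strings are made up for the example and their exact shapes are assumptions, not output captured from a homeserver.

```python
# Illustrative only: why the "_" test can separate the two serializations.
# A legacy RelationPaginationToken serializes as "<topological>-<stream>"
# (only "-" separators), while a serialized StreamToken joins its component
# keys with "_" (the room part assumed to look like "t12-345").

def looks_like_stream_token(token: str) -> bool:
    return "_" in token


assert not looks_like_stream_token("12-345")               # legacy relation token
assert looks_like_stream_token("t12-345_0_0_0_0_0_0_0_0")  # assumed StreamToken shape
```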

synapse/storage/databases/main/relations.py

Lines changed: 31 additions & 15 deletions
@@ -39,16 +39,13 @@
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import (
-    AggregationPaginationToken,
-    PaginationChunk,
-    RelationPaginationToken,
-)
-from synapse.types import JsonDict
+from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList

 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore

 logger = logging.getLogger(__name__)

@@ -98,8 +95,8 @@ async def get_relations_for_event(
         aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
-        from_token: Optional[RelationPaginationToken] = None,
-        to_token: Optional[RelationPaginationToken] = None,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
     ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.

@@ -138,8 +135,10 @@ async def get_relations_for_event(
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
-            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
+            from_token=from_token.room_key.as_historical_tuple()
+            if from_token
+            else None,
+            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
             engine=self.database_engine,
         )

@@ -177,12 +176,27 @@ def _get_recent_references_for_event_txn(
             last_topo_id = row[1]
             last_stream_id = row[2]

-            next_batch = None
+            # If there are more events, generate the next pagination key.
+            next_token = None
             if len(events) > limit and last_topo_id and last_stream_id:
-                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+                if from_token:
+                    next_token = from_token.copy_and_replace("room_key", next_key)
+                else:
+                    next_token = StreamToken(
+                        room_key=next_key,
+                        presence_key=0,
+                        typing_key=0,
+                        receipt_key=0,
+                        account_data_key=0,
+                        push_rules_key=0,
+                        to_device_key=0,
+                        device_list_key=0,
+                        groups_key=0,
+                    )

             return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+                chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
             )

         return await self.db_pool.runInteraction(

@@ -676,13 +690,15 @@ async def _get_bundled_aggregation_for_event(

         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations.annotations = annotations.to_dict()
+            aggregations.annotations = await annotations.to_dict(
+                cast("DataStore", self)
+            )

         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations.references = references.to_dict()
+            aggregations.references = await references.to_dict(cast("DataStore", self))

         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
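The next-token logic above carries every non-room stream position of the caller's token through unchanged and only swaps in the new room position. A generic sketch of that "copy and replace one field" pattern, written with plain attrs so it runs standalone; `FakeStreamToken` is an illustrative stand-in, not Synapse's `StreamToken`.

```python
import attr


@attr.s(frozen=True, slots=True, auto_attribs=True)
class FakeStreamToken:
    """Illustrative stand-in for a multi-stream token."""

    room_key: tuple  # (topological_ordering, stream_ordering)
    presence_key: int = 0
    typing_key: int = 0


caller_token = FakeStreamToken(room_key=(10, 100), presence_key=7)

# Advance only the room position; all other stream positions are preserved,
# which is what keeps the returned next_batch usable with /sync and /messages.
next_token = attr.evolve(caller_token, room_key=(12, 130))

assert next_token.room_key == (12, 130)
assert next_token.presence_key == 7
```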

synapse/storage/relations.py

Lines changed: 9 additions & 6 deletions
@@ -13,13 +13,16 @@
 # limitations under the License.

 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 import attr

 from synapse.api.errors import SynapseError
 from synapse.types import JsonDict

+if TYPE_CHECKING:
+    from synapse.storage.databases.main import DataStore
+
 logger = logging.getLogger(__name__)


@@ -39,14 +42,14 @@ class PaginationChunk:
     next_batch: Optional[Any] = None
     prev_batch: Optional[Any] = None

-    def to_dict(self) -> Dict[str, Any]:
+    async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
         d = {"chunk": self.chunk}

         if self.next_batch:
-            d["next_batch"] = self.next_batch.to_string()
+            d["next_batch"] = await self.next_batch.to_string(store)

         if self.prev_batch:
-            d["prev_batch"] = self.prev_batch.to_string()
+            d["prev_batch"] = await self.prev_batch.to_string(store)

         return d

@@ -75,7 +78,7 @@ def from_string(string: str) -> "RelationPaginationToken":
         except ValueError:
             raise SynapseError(400, "Invalid relation pagination token")

-    def to_string(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         return "%d-%d" % (self.topological, self.stream)

     def as_tuple(self) -> Tuple[Any, ...]:

@@ -105,7 +108,7 @@ def from_string(string: str) -> "AggregationPaginationToken":
         except ValueError:
             raise SynapseError(400, "Invalid aggregation pagination token")

-    def to_string(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         return "%d-%d" % (self.count, self.stream)

     def as_tuple(self) -> Tuple[Any, ...]:
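Because `PaginationChunk.to_dict` now awaits `to_string(store)` on whatever token it holds, legacy tokens and `StreamToken`s can share one serialization path. A minimal sketch of that duck typing, using made-up classes rather than Synapse's own:

```python
import asyncio
from typing import Any, Dict, List, Optional


class FakeToken:
    """Illustrative token: anything exposing an async to_string(store) works."""

    def __init__(self, value: str) -> None:
        self.value = value

    async def to_string(self, store: Any) -> str:
        return self.value


class FakeChunk:
    """Illustrative stand-in for PaginationChunk."""

    def __init__(self, chunk: List[Any], next_batch: Optional[FakeToken] = None) -> None:
        self.chunk = chunk
        self.next_batch = next_batch

    async def to_dict(self, store: Any) -> Dict[str, Any]:
        d: Dict[str, Any] = {"chunk": self.chunk}
        if self.next_batch:
            # Serialization may need database access, hence the store argument.
            d["next_batch"] = await self.next_batch.to_string(store)
        return d


print(asyncio.run(FakeChunk(["event"], FakeToken("t1-2_0_0")).to_dict(store=None)))
```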
