Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit fa5c919

Browse files
author
David Robertson
committed
Add tests for StreamIdGenerator
1 parent 618e4ab commit fa5c919

File tree

1 file changed

+143
-2
lines changed

1 file changed

+143
-2
lines changed

tests/storage/test_id_generators.py

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,156 @@
1616
from twisted.test.proto_helpers import MemoryReactor
1717

1818
from synapse.server import HomeServer
19-
from synapse.storage.database import DatabasePool, LoggingTransaction
19+
from synapse.storage.database import (
20+
DatabasePool,
21+
LoggingDatabaseConnection,
22+
LoggingTransaction,
23+
)
2024
from synapse.storage.engines import IncorrectDatabaseSetup
21-
from synapse.storage.util.id_generators import MultiWriterIdGenerator
25+
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
2226
from synapse.util import Clock
2327

2428
from tests.unittest import HomeserverTestCase
2529
from tests.utils import USE_POSTGRES_FOR_TESTS
2630

2731

32+
class StreamIdGeneratorTestCase(HomeserverTestCase):
33+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
34+
self.store = hs.get_datastores().main
35+
self.db_pool: DatabasePool = self.store.db_pool
36+
37+
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
38+
39+
def _setup_db(self, txn: LoggingTransaction) -> None:
40+
txn.execute(
41+
"""
42+
CREATE TABLE foobar (
43+
stream_id BIGINT NOT NULL,
44+
data TEXT
45+
);
46+
"""
47+
)
48+
txn.execute("INSERT INTO foobar VALUES (123, 'hello world');")
49+
50+
def _create_id_generator(self) -> StreamIdGenerator:
51+
def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
52+
return StreamIdGenerator(
53+
db_conn=conn,
54+
table="foobar",
55+
column="stream_id",
56+
)
57+
58+
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
59+
60+
def test_initial_value(self) -> None:
61+
"""Check that we read the current token from the DB."""
62+
id_gen = self._create_id_generator()
63+
self.assertEqual(id_gen.get_current_token(), 123)
64+
65+
def test_single_gen_next(self) -> None:
66+
"""Check that we correctly increment the current token from the DB."""
67+
id_gen = self._create_id_generator()
68+
69+
async def test_gen_next() -> None:
70+
async with id_gen.get_next() as next_id:
71+
# We haven't persisted `next_id` yet; current token is still 123
72+
self.assertEqual(id_gen.get_current_token(), 123)
73+
# But we did learn what the next value is
74+
self.assertEqual(next_id, 124)
75+
76+
# Once the context manager closes we assume that the `next_id` has been
77+
# written to the DB.
78+
self.assertEqual(id_gen.get_current_token(), 124)
79+
80+
self.get_success(test_gen_next())
81+
82+
def test_multiple_gen_nexts(self) -> None:
83+
"""Check that we handle overlapping calls to gen_next sensibly."""
84+
id_gen = self._create_id_generator()
85+
86+
async def test_gen_next() -> None:
87+
ctx1 = id_gen.get_next()
88+
ctx2 = id_gen.get_next()
89+
ctx3 = id_gen.get_next()
90+
91+
# Request three new stream IDs.
92+
self.assertEqual(await ctx1.__aenter__(), 124)
93+
self.assertEqual(await ctx2.__aenter__(), 125)
94+
self.assertEqual(await ctx3.__aenter__(), 126)
95+
96+
# None are persisted: current token unchanged.
97+
self.assertEqual(id_gen.get_current_token(), 123)
98+
99+
# Persist each in turn.
100+
await ctx1.__aexit__(None, None, None)
101+
self.assertEqual(id_gen.get_current_token(), 124)
102+
await ctx2.__aexit__(None, None, None)
103+
self.assertEqual(id_gen.get_current_token(), 125)
104+
await ctx3.__aexit__(None, None, None)
105+
self.assertEqual(id_gen.get_current_token(), 126)
106+
107+
self.get_success(test_gen_next())
108+
109+
def test_multiple_gen_nexts_closed_in_different_order(self) -> None:
110+
"""Check that we handle overlapping calls to gen_next, even when their IDs
111+
created and persisted in different orders."""
112+
id_gen = self._create_id_generator()
113+
114+
async def test_gen_next() -> None:
115+
ctx1 = id_gen.get_next()
116+
ctx2 = id_gen.get_next()
117+
ctx3 = id_gen.get_next()
118+
119+
# Request three new stream IDs.
120+
self.assertEqual(await ctx1.__aenter__(), 124)
121+
self.assertEqual(await ctx2.__aenter__(), 125)
122+
self.assertEqual(await ctx3.__aenter__(), 126)
123+
124+
# None are persisted: current token unchanged.
125+
self.assertEqual(id_gen.get_current_token(), 123)
126+
127+
# Persist them in a different order, starting with 126 from ctx3.
128+
await ctx3.__aexit__(None, None, None)
129+
# We haven't persisted 124 from ctx1 yet---current token is still 123.
130+
self.assertEqual(id_gen.get_current_token(), 123)
131+
132+
# Now persist 124 from ctx1.
133+
await ctx1.__aexit__(None, None, None)
134+
# Current token is then 124, waiting for 125 to be persisted.
135+
self.assertEqual(id_gen.get_current_token(), 124)
136+
137+
# Finally persist 125 from ctx2.
138+
await ctx2.__aexit__(None, None, None)
139+
# Current token is then 126 (skipping over 125).
140+
self.assertEqual(id_gen.get_current_token(), 126)
141+
142+
self.get_success(test_gen_next())
143+
144+
def test_gen_next_while_still_waiting_for_persistence(self) -> None:
145+
"""Check that we handle overlapping calls to gen_next."""
146+
id_gen = self._create_id_generator()
147+
148+
async def test_gen_next() -> None:
149+
ctx1 = id_gen.get_next()
150+
ctx2 = id_gen.get_next()
151+
ctx3 = id_gen.get_next()
152+
153+
# Request two new stream IDs.
154+
self.assertEqual(await ctx1.__aenter__(), 124)
155+
self.assertEqual(await ctx2.__aenter__(), 125)
156+
157+
# Persist ctx2 first.
158+
await ctx2.__aexit__(None, None, None)
159+
# Still waiting on ctx1's ID to be persisted.
160+
self.assertEqual(id_gen.get_current_token(), 123)
161+
162+
# Now request a third stream ID. It should be 126 (the smallest ID that
163+
# we've not yet handed out.)
164+
self.assertEqual(await ctx3.__aenter__(), 126)
165+
166+
self.get_success(test_gen_next())
167+
168+
28169
class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
29170
if not USE_POSTGRES_FOR_TESTS:
30171
skip = "Requires Postgres"

0 commit comments

Comments
 (0)