|
16 | 16 | from twisted.test.proto_helpers import MemoryReactor
|
17 | 17 |
|
18 | 18 | from synapse.server import HomeServer
|
19 |
| -from synapse.storage.database import DatabasePool, LoggingTransaction |
| 19 | +from synapse.storage.database import ( |
| 20 | + DatabasePool, |
| 21 | + LoggingDatabaseConnection, |
| 22 | + LoggingTransaction, |
| 23 | +) |
20 | 24 | from synapse.storage.engines import IncorrectDatabaseSetup
|
21 |
| -from synapse.storage.util.id_generators import MultiWriterIdGenerator |
| 25 | +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator |
22 | 26 | from synapse.util import Clock
|
23 | 27 |
|
24 | 28 | from tests.unittest import HomeserverTestCase
|
25 | 29 | from tests.utils import USE_POSTGRES_FOR_TESTS
|
26 | 30 |
|
27 | 31 |
|
| 32 | +class StreamIdGeneratorTestCase(HomeserverTestCase): |
| 33 | + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: |
| 34 | + self.store = hs.get_datastores().main |
| 35 | + self.db_pool: DatabasePool = self.store.db_pool |
| 36 | + |
| 37 | + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) |
| 38 | + |
| 39 | + def _setup_db(self, txn: LoggingTransaction) -> None: |
| 40 | + txn.execute( |
| 41 | + """ |
| 42 | + CREATE TABLE foobar ( |
| 43 | + stream_id BIGINT NOT NULL, |
| 44 | + data TEXT |
| 45 | + ); |
| 46 | + """ |
| 47 | + ) |
| 48 | + txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") |
| 49 | + |
| 50 | + def _create_id_generator(self) -> StreamIdGenerator: |
| 51 | + def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: |
| 52 | + return StreamIdGenerator( |
| 53 | + db_conn=conn, |
| 54 | + table="foobar", |
| 55 | + column="stream_id", |
| 56 | + ) |
| 57 | + |
| 58 | + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) |
| 59 | + |
| 60 | + def test_initial_value(self) -> None: |
| 61 | + """Check that we read the current token from the DB.""" |
| 62 | + id_gen = self._create_id_generator() |
| 63 | + self.assertEqual(id_gen.get_current_token(), 123) |
| 64 | + |
| 65 | + def test_single_gen_next(self) -> None: |
| 66 | + """Check that we correctly increment the current token from the DB.""" |
| 67 | + id_gen = self._create_id_generator() |
| 68 | + |
| 69 | + async def test_gen_next() -> None: |
| 70 | + async with id_gen.get_next() as next_id: |
| 71 | + # We haven't persisted `next_id` yet; current token is still 123 |
| 72 | + self.assertEqual(id_gen.get_current_token(), 123) |
| 73 | + # But we did learn what the next value is |
| 74 | + self.assertEqual(next_id, 124) |
| 75 | + |
| 76 | + # Once the context manager closes we assume that the `next_id` has been |
| 77 | + # written to the DB. |
| 78 | + self.assertEqual(id_gen.get_current_token(), 124) |
| 79 | + |
| 80 | + self.get_success(test_gen_next()) |
| 81 | + |
| 82 | + def test_multiple_gen_nexts(self) -> None: |
| 83 | + """Check that we handle overlapping calls to gen_next sensibly.""" |
| 84 | + id_gen = self._create_id_generator() |
| 85 | + |
| 86 | + async def test_gen_next() -> None: |
| 87 | + ctx1 = id_gen.get_next() |
| 88 | + ctx2 = id_gen.get_next() |
| 89 | + ctx3 = id_gen.get_next() |
| 90 | + |
| 91 | + # Request three new stream IDs. |
| 92 | + self.assertEqual(await ctx1.__aenter__(), 124) |
| 93 | + self.assertEqual(await ctx2.__aenter__(), 125) |
| 94 | + self.assertEqual(await ctx3.__aenter__(), 126) |
| 95 | + |
| 96 | + # None are persisted: current token unchanged. |
| 97 | + self.assertEqual(id_gen.get_current_token(), 123) |
| 98 | + |
| 99 | + # Persist each in turn. |
| 100 | + await ctx1.__aexit__(None, None, None) |
| 101 | + self.assertEqual(id_gen.get_current_token(), 124) |
| 102 | + await ctx2.__aexit__(None, None, None) |
| 103 | + self.assertEqual(id_gen.get_current_token(), 125) |
| 104 | + await ctx3.__aexit__(None, None, None) |
| 105 | + self.assertEqual(id_gen.get_current_token(), 126) |
| 106 | + |
| 107 | + self.get_success(test_gen_next()) |
| 108 | + |
| 109 | + def test_multiple_gen_nexts_closed_in_different_order(self) -> None: |
| 110 | + """Check that we handle overlapping calls to gen_next, even when their IDs |
| 111 | + created and persisted in different orders.""" |
| 112 | + id_gen = self._create_id_generator() |
| 113 | + |
| 114 | + async def test_gen_next() -> None: |
| 115 | + ctx1 = id_gen.get_next() |
| 116 | + ctx2 = id_gen.get_next() |
| 117 | + ctx3 = id_gen.get_next() |
| 118 | + |
| 119 | + # Request three new stream IDs. |
| 120 | + self.assertEqual(await ctx1.__aenter__(), 124) |
| 121 | + self.assertEqual(await ctx2.__aenter__(), 125) |
| 122 | + self.assertEqual(await ctx3.__aenter__(), 126) |
| 123 | + |
| 124 | + # None are persisted: current token unchanged. |
| 125 | + self.assertEqual(id_gen.get_current_token(), 123) |
| 126 | + |
| 127 | + # Persist them in a different order, starting with 126 from ctx3. |
| 128 | + await ctx3.__aexit__(None, None, None) |
| 129 | + # We haven't persisted 124 from ctx1 yet---current token is still 123. |
| 130 | + self.assertEqual(id_gen.get_current_token(), 123) |
| 131 | + |
| 132 | + # Now persist 124 from ctx1. |
| 133 | + await ctx1.__aexit__(None, None, None) |
| 134 | + # Current token is then 124, waiting for 125 to be persisted. |
| 135 | + self.assertEqual(id_gen.get_current_token(), 124) |
| 136 | + |
| 137 | + # Finally persist 125 from ctx2. |
| 138 | + await ctx2.__aexit__(None, None, None) |
| 139 | + # Current token is then 126 (skipping over 125). |
| 140 | + self.assertEqual(id_gen.get_current_token(), 126) |
| 141 | + |
| 142 | + self.get_success(test_gen_next()) |
| 143 | + |
| 144 | + def test_gen_next_while_still_waiting_for_persistence(self) -> None: |
| 145 | + """Check that we handle overlapping calls to gen_next.""" |
| 146 | + id_gen = self._create_id_generator() |
| 147 | + |
| 148 | + async def test_gen_next() -> None: |
| 149 | + ctx1 = id_gen.get_next() |
| 150 | + ctx2 = id_gen.get_next() |
| 151 | + ctx3 = id_gen.get_next() |
| 152 | + |
| 153 | + # Request two new stream IDs. |
| 154 | + self.assertEqual(await ctx1.__aenter__(), 124) |
| 155 | + self.assertEqual(await ctx2.__aenter__(), 125) |
| 156 | + |
| 157 | + # Persist ctx2 first. |
| 158 | + await ctx2.__aexit__(None, None, None) |
| 159 | + # Still waiting on ctx1's ID to be persisted. |
| 160 | + self.assertEqual(id_gen.get_current_token(), 123) |
| 161 | + |
| 162 | + # Now request a third stream ID. It should be 126 (the smallest ID that |
| 163 | + # we've not yet handed out.) |
| 164 | + self.assertEqual(await ctx3.__aenter__(), 126) |
| 165 | + |
| 166 | + self.get_success(test_gen_next()) |
| 167 | + |
| 168 | + |
28 | 169 | class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
29 | 170 | if not USE_POSTGRES_FOR_TESTS:
|
30 | 171 | skip = "Requires Postgres"
|
|
0 commit comments