|
| 1 | +# Stateful tests for arbitrary Zarr stores. |
| 2 | + |
| 3 | + |
| 4 | +import hypothesis.strategies as st |
| 5 | +from hypothesis import assume, note |
| 6 | +from hypothesis.stateful import ( |
| 7 | + RuleBasedStateMachine, |
| 8 | + invariant, |
| 9 | + precondition, |
| 10 | + rule, |
| 11 | +) |
| 12 | + |
| 13 | +import zarr |
| 14 | +from zarr.abc.store import AccessMode, Store |
| 15 | +from zarr.core.buffer import Buffer, BufferPrototype, default_buffer_prototype |
| 16 | +from zarr.store import MemoryStore |
| 17 | +from zarr.testing.strategies import key_ranges, paths |
| 18 | + |
| 19 | + |
| 20 | +class SyncStoreWrapper(zarr.core.sync.SyncMixin): |
| 21 | + def __init__(self, store: Store): |
| 22 | + """Synchronous Store wrapper |
| 23 | +
|
| 24 | + This class holds synchronous methods that map to async methods of Store classes. |
| 25 | + The synchronous wrapper is needed because hypothesis' stateful testing infra does |
| 26 | + not support asyncio so we redefine sync versions of the Store API. |
| 27 | + https://github.com/HypothesisWorks/hypothesis/issues/3712#issuecomment-1668999041 |
| 28 | + """ |
| 29 | + self.store = store |
| 30 | + |
| 31 | + @property |
| 32 | + def mode(self) -> AccessMode: |
| 33 | + return self.store.mode |
| 34 | + |
| 35 | + def set(self, key: str, data_buffer: zarr.core.buffer.Buffer) -> None: |
| 36 | + return self._sync(self.store.set(key, data_buffer)) |
| 37 | + |
| 38 | + def list(self) -> list: |
| 39 | + return self._sync_iter(self.store.list()) |
| 40 | + |
| 41 | + def get(self, key: str, prototype: BufferPrototype) -> zarr.core.buffer.Buffer: |
| 42 | + obs = self._sync(self.store.get(key, prototype=prototype)) |
| 43 | + return obs |
| 44 | + |
| 45 | + def get_partial_values( |
| 46 | + self, key_ranges: list, prototype: BufferPrototype |
| 47 | + ) -> zarr.core.buffer.Buffer: |
| 48 | + obs_partial = self._sync( |
| 49 | + self.store.get_partial_values(prototype=prototype, key_ranges=key_ranges) |
| 50 | + ) |
| 51 | + return obs_partial |
| 52 | + |
| 53 | + def delete(self, path: str) -> None: |
| 54 | + return self._sync(self.store.delete(path)) |
| 55 | + |
| 56 | + def empty(self) -> bool: |
| 57 | + return self._sync(self.store.empty()) |
| 58 | + |
| 59 | + def clear(self) -> None: |
| 60 | + return self._sync(self.store.clear()) |
| 61 | + |
| 62 | + def exists(self, key) -> bool: |
| 63 | + return self._sync(self.store.exists(key)) |
| 64 | + |
| 65 | + def list_dir(self, prefix): |
| 66 | + raise NotImplementedError |
| 67 | + |
| 68 | + def list_prefix(self, prefix: str): |
| 69 | + raise NotImplementedError |
| 70 | + |
| 71 | + def set_partial_values(self, key_start_values): |
| 72 | + raise NotImplementedError |
| 73 | + |
| 74 | + @property |
| 75 | + def supports_listing(self) -> bool: |
| 76 | + return self.store.supports_listing |
| 77 | + |
| 78 | + @property |
| 79 | + def supports_partial_writes(self) -> bool: |
| 80 | + return self.supports_partial_writes |
| 81 | + |
| 82 | + @property |
| 83 | + def supports_writes(self) -> bool: |
| 84 | + return self.store.supports_writes |
| 85 | + |
| 86 | + |
| 87 | +class ZarrStoreStateMachine(RuleBasedStateMachine): |
| 88 | + """ " |
| 89 | + Zarr store state machine |
| 90 | +
|
| 91 | + This is a subclass of a Hypothesis RuleBasedStateMachine. |
| 92 | + It is testing a framework to ensure that the state of a Zarr store matches |
| 93 | + an expected state after a set of random operations. It contains a store |
| 94 | + (currently, a Zarr MemoryStore) and a model, a simplified version of a |
| 95 | + zarr store (in this case, a dict). It also contains rules which represent |
| 96 | + actions that can be applied to a zarr store. Rules apply an action to both |
| 97 | + the store and the model, and invariants assert that the state of the model |
| 98 | + is equal to the state of the store. Hypothesis then generates sequences of |
| 99 | + rules, running invariants after each rule. It raises an error if a sequence |
| 100 | + produces discontinuity between state of the model and state of the store |
| 101 | + (ie. an invariant is violated). |
| 102 | + https://hypothesis.readthedocs.io/en/latest/stateful.html |
| 103 | + """ |
| 104 | + |
| 105 | + def __init__(self): |
| 106 | + super().__init__() |
| 107 | + self.model: dict[str, bytes] = {} |
| 108 | + self.store = SyncStoreWrapper(MemoryStore(mode="w")) |
| 109 | + self.prototype = default_buffer_prototype() |
| 110 | + |
| 111 | + @rule(key=paths, data=st.binary(min_size=0, max_size=100)) |
| 112 | + def set(self, key: str, data: bytes) -> None: |
| 113 | + note(f"(set) Setting {key!r} with {data}") |
| 114 | + assert not self.store.mode.readonly |
| 115 | + data_buf = Buffer.from_bytes(data) |
| 116 | + self.store.set(key, data_buf) |
| 117 | + self.model[key] = data_buf |
| 118 | + |
| 119 | + @precondition(lambda self: len(self.model.keys()) > 0) |
| 120 | + @rule(key=paths, data=st.data()) |
| 121 | + def get(self, key: str, data: bytes) -> None: |
| 122 | + key = data.draw( |
| 123 | + st.sampled_from(sorted(self.model.keys())) |
| 124 | + ) # hypothesis wants to sample from sorted list |
| 125 | + note("(get)") |
| 126 | + store_value = self.store.get(key, self.prototype) |
| 127 | + # to bytes here necessary because data_buf set to model in set() |
| 128 | + assert self.model[key].to_bytes() == (store_value.to_bytes()) |
| 129 | + |
| 130 | + @rule(key=paths, data=st.data()) |
| 131 | + def get_invalid_keys(self, key: str, data: bytes) -> None: |
| 132 | + note("(get_invalid)") |
| 133 | + assume(key not in self.model.keys()) |
| 134 | + assert self.store.get(key, self.prototype) is None |
| 135 | + |
| 136 | + @precondition(lambda self: len(self.model.keys()) > 0) |
| 137 | + @rule(data=st.data()) |
| 138 | + def get_partial_values(self, data: bytes) -> None: |
| 139 | + key_range = data.draw(key_ranges(keys=st.sampled_from(sorted(self.model.keys())))) |
| 140 | + note(f"(get partial) {key_range=}") |
| 141 | + obs_maybe = self.store.get_partial_values(key_range, self.prototype) |
| 142 | + observed = [] |
| 143 | + |
| 144 | + for obs in obs_maybe: |
| 145 | + assert obs is not None |
| 146 | + observed.append(obs.to_bytes()) |
| 147 | + |
| 148 | + model_vals_ls = [] |
| 149 | + |
| 150 | + for key, byte_range in key_range: |
| 151 | + start = byte_range[0] or 0 |
| 152 | + step = byte_range[1] |
| 153 | + stop = start + step if step is not None else None |
| 154 | + model_vals_ls.append(self.model[key][start:stop]) |
| 155 | + |
| 156 | + assert all( |
| 157 | + obs == exp.to_bytes() for obs, exp in zip(observed, model_vals_ls, strict=True) |
| 158 | + ), ( |
| 159 | + observed, |
| 160 | + model_vals_ls, |
| 161 | + ) |
| 162 | + |
| 163 | + @precondition(lambda self: len(self.model.keys()) > 0) |
| 164 | + @rule(data=st.data()) |
| 165 | + def delete(self, data: bytes) -> None: |
| 166 | + key = data.draw(st.sampled_from(sorted(self.model.keys()))) |
| 167 | + note(f"(delete) Deleting {key=}") |
| 168 | + |
| 169 | + self.store.delete(key) |
| 170 | + del self.model[key] |
| 171 | + |
| 172 | + @rule() |
| 173 | + def clear(self): |
| 174 | + assert not self.store.mode.readonly |
| 175 | + note("(clear)") |
| 176 | + self.store.clear() |
| 177 | + self.model.clear() |
| 178 | + |
| 179 | + assert len(self.model.keys()) == len(list(self.store.list())) == 0 |
| 180 | + |
| 181 | + @rule() |
| 182 | + def empty(self) -> None: |
| 183 | + note("(empty)") |
| 184 | + |
| 185 | + # make sure they either both are or both aren't empty (same state) |
| 186 | + assert self.store.empty() == (not self.model) |
| 187 | + |
| 188 | + @rule(key=paths) |
| 189 | + def exists(self, key: str) -> None: |
| 190 | + note("(exists)") |
| 191 | + |
| 192 | + assert self.store.exists(key) == (key in self.model) |
| 193 | + |
| 194 | + @invariant() |
| 195 | + def check_paths_equal(self) -> None: |
| 196 | + note("Checking that paths are equal") |
| 197 | + paths = list(self.store.list()) |
| 198 | + |
| 199 | + assert list(self.model.keys()) == paths |
| 200 | + |
| 201 | + @invariant() |
| 202 | + def check_vals_equal(self) -> None: |
| 203 | + note("Checking values equal") |
| 204 | + for key, _val in self.model.items(): |
| 205 | + store_item = self.store.get(key, self.prototype).to_bytes() |
| 206 | + assert self.model[key].to_bytes() == store_item |
| 207 | + |
| 208 | + @invariant() |
| 209 | + def check_num_keys_equal(self) -> None: |
| 210 | + note("check num keys equal") |
| 211 | + |
| 212 | + assert len(self.model) == len(list(self.store.list())) |
| 213 | + |
| 214 | + @invariant() |
| 215 | + def check_keys(self) -> None: |
| 216 | + keys = list(self.store.list()) |
| 217 | + |
| 218 | + if len(keys) == 0: |
| 219 | + assert self.store.empty() is True |
| 220 | + |
| 221 | + elif len(keys) != 0: |
| 222 | + assert self.store.empty() is False |
| 223 | + |
| 224 | + for key in keys: |
| 225 | + assert self.store.exists(key) is True |
| 226 | + note("checking keys / exists / empty") |
| 227 | + |
| 228 | + |
| 229 | +StatefulStoreTest = ZarrStoreStateMachine.TestCase |
0 commit comments