Skip to content

Commit 389f6f6

Browse files
committed
fix(cache): json query distinct & list comparisons
1 parent 08f4872 commit 389f6f6

File tree

2 files changed

+72
-9
lines changed

2 files changed

+72
-9
lines changed

tests/test_meta/test_cache.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import random
3+
import sys
34
import time
45
from copy import deepcopy
56
from dataclasses import asdict
@@ -12,7 +13,7 @@
1213
from toggl_api.meta import CustomDecoder, CustomEncoder, JSONCache, RequestMethod
1314
from toggl_api.meta.cache.base_cache import Comparison, TogglQuery
1415
from toggl_api.meta.cache.json_cache import JSONSession
15-
from toggl_api.models.models import TogglTracker
16+
from toggl_api.models.models import TogglTag, TogglTracker
1617
from toggl_api.user import UserEndpoint
1718

1819

@@ -176,6 +177,44 @@ def test_query(model_data, tracker_object, faker):
176177
assert len(tracker_object.query(TogglQuery("name", names[:5]))) == 5 # noqa: PLR2004
177178

178179

180+
@pytest.mark.unit
181+
def test_query_distinct(model_data, tracker_object, faker):
182+
t = model_data.pop("tracker")
183+
t.id = 1
184+
185+
d = asdict(t)
186+
for i in range(1, 13):
187+
d["id"] += i
188+
d["timestamp"] = datetime.now(timezone.utc)
189+
tracker_object.save_cache(TogglTracker.from_kwargs(**d), RequestMethod.GET)
190+
191+
tracker_object.cache.commit()
192+
assert len(tracker_object.load_cache()) == 12 # noqa: PLR2004
193+
assert len(tracker_object.query(TogglQuery("name", t["name"]), distinct=True)) == 1
194+
195+
196+
@pytest.mark.unit
197+
def test_query_tag(model_data, tracker_object, faker, number):
198+
names = [faker.name() for _ in range(12)]
199+
t = model_data.pop("tracker")
200+
t.id = 1
201+
202+
d = asdict(t)
203+
tracker_object.save_cache(TogglTracker.from_kwargs(**d), RequestMethod.GET)
204+
tag = TogglTag(number.randint(50, sys.maxsize), faker.name())
205+
206+
for i in range(1, 3):
207+
d["id"] += i
208+
d["name"] = names[i - 1]
209+
d["timestamp"] = datetime.now(timezone.utc)
210+
d["tags"] = [tag]
211+
tracker_object.save_cache(TogglTracker.from_kwargs(**d), RequestMethod.GET)
212+
213+
tracker_object.cache.commit()
214+
assert len(tracker_object.load_cache()) == 3 # noqa: PLR2004
215+
assert len(tracker_object.query(TogglQuery("tags", [tag]))) == 2 # noqa: PLR2004
216+
217+
179218
@pytest.mark.unit
180219
def test_query_parent(tmp_path):
181220
cache = JSONCache(Path(tmp_path))

toggl_api/meta/cache/json_cache.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -272,34 +272,58 @@ def query(self, *query: TogglQuery, distinct: bool = False) -> list[TogglClass]:
272272
search = self.session.data
273273
existing: defaultdict[str, set[Any]] = defaultdict(set)
274274

275-
return [model for model in search if self._query_helper(model, query, existing, min_ts)]
275+
return [
276+
model
277+
for model in search
278+
if self._query_helper(
279+
model,
280+
query,
281+
existing,
282+
min_ts,
283+
distinct=distinct,
284+
)
285+
]
276286

277287
def _query_helper(
278288
self,
279289
model: TogglClass,
280290
queries: tuple[TogglQuery, ...],
281291
existing: dict[str, set[Any]],
282292
min_ts: Optional[datetime],
293+
*,
294+
distinct: bool,
283295
) -> bool:
284296
if self.expire_after and min_ts and model.timestamp and min_ts >= model.timestamp:
285297
return False
286298

287299
for query in queries:
288-
if model[query.key] in existing[query.key] or not self._match_query(model, query):
300+
if (
301+
distinct and not isinstance(query.value, list) and model[query.key] in existing[query.key]
302+
) or not self._match_query(model, query):
289303
return False
290304

291-
for query in queries:
292-
existing[query.key].add(model[query.key])
305+
if distinct:
306+
for query in queries:
307+
existing[query.key].add(model[query.key])
293308

294309
return True
295310

311+
@staticmethod
312+
def _match_equal(model: TogglClass, query: TogglQuery):
313+
if isinstance(query.value, Sequence) and not isinstance(query.value, str):
314+
value = model[query.key]
315+
316+
if isinstance(value, list):
317+
return any(v == comp for comp in query.value for v in value)
318+
319+
return any(value == comp for comp in query.value)
320+
321+
return model[query.key] == query.value
322+
296323
@staticmethod
297324
def _match_query(model: TogglClass, query: TogglQuery) -> bool:
298325
if query.comparison == Comparison.EQUAL:
299-
if isinstance(query.value, Sequence) and not isinstance(query.value, str):
300-
value = model[query.key]
301-
return any(value == comp for comp in query.value)
302-
return model[query.key] == query.value
326+
return JSONCache._match_equal(model, query)
303327
if query.comparison == Comparison.LESS_THEN:
304328
return model[query.key] < query.value
305329
if query.comparison == Comparison.LESS_THEN_OR_EQUAL:

0 commit comments

Comments
 (0)