Skip to content

Commit bd529a8

Browse files
committed
fix(cli): fix broken tests due to changes in chroma client API.
1 parent fba200c commit bd529a8

File tree

3 files changed

+28
-37
lines changed

3 files changed

+28
-37
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ description = "A tool to vectorise repositories for RAG."
55
authors = [{ name = "Davidyz", email = "[email protected]" }]
66
dependencies = [
77
"sentence-transformers",
8-
"chromadb>=1.0.0",
8+
"chromadb @ git+https://github.com/chroma-core/chroma.git@main#commit=05ff1f0a428d00877f97037e41a289c43346b254",
99
"pathspec",
1010
"tabulate",
1111
"shtab",

src/vectorcode/common.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
async def get_collections(
2020
client: AsyncClientAPI,
2121
) -> AsyncGenerator[AsyncCollection, None]:
22-
for collection_name in await client.list_collections():
23-
collection = await client.get_collection(collection_name, None)
22+
for collection in await client.list_collections():
2423
meta = collection.metadata
2524
if meta is None:
2625
continue
@@ -41,9 +40,9 @@ async def try_server(host: str, port: int, use_v2: bool = True):
4140
url = f"http://{host}:{port}/api/v{2 if use_v2 else 1}/heartbeat"
4241
try:
4342
async with httpx.AsyncClient() as client:
44-
response = await client.get(url=url)
43+
response = await client.get(url=url, timeout=1)
4544
return response.status_code == 200
46-
except (httpx.ConnectError, httpx.ConnectTimeout):
45+
except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout):
4746
return False
4847

4948

tests/test_common.py

+24-32
Original file line numberDiff line numberDiff line change
@@ -410,62 +410,54 @@ async def test_get_collections():
410410
# Mocking AsyncClientAPI and AsyncCollection
411411
mock_client = MagicMock(spec=AsyncClientAPI)
412412

413-
# Mock successful get_collection
414-
mock_collection1 = MagicMock(spec=AsyncCollection)
415-
mock_collection1.metadata = {
413+
# Create test collections with different metadata scenarios
414+
valid_collection = MagicMock(spec=AsyncCollection)
415+
valid_collection.metadata = {
416416
"created-by": "VectorCode",
417417
"username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")),
418418
"hostname": socket.gethostname(),
419419
}
420420

421-
# collection with meta == None
422-
mock_collection2 = MagicMock(spec=AsyncCollection)
423-
mock_collection2.metadata = None
421+
no_metadata_collection = MagicMock(spec=AsyncCollection)
422+
no_metadata_collection.metadata = None
424423

425-
# collection with wrong "created-by"
426-
mock_collection3 = MagicMock(spec=AsyncCollection)
427-
mock_collection3.metadata = {
428-
"created-by": "NotVectorCode",
424+
wrong_creator_collection = MagicMock(spec=AsyncCollection)
425+
wrong_creator_collection.metadata = {
426+
"created-by": "OtherTool",
429427
"username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")),
430428
"hostname": socket.gethostname(),
431429
}
432430

433-
# collection with wrong "username"
434-
mock_collection4 = MagicMock(spec=AsyncCollection)
435-
mock_collection4.metadata = {
431+
wrong_user_collection = MagicMock(spec=AsyncCollection)
432+
wrong_user_collection.metadata = {
436433
"created-by": "VectorCode",
437434
"username": "wrong_user",
438435
"hostname": socket.gethostname(),
439436
}
440437

441-
# collection with wrong "hostname"
442-
mock_collection5 = MagicMock(spec=AsyncCollection)
443-
mock_collection5.metadata = {
438+
wrong_host_collection = MagicMock(spec=AsyncCollection)
439+
wrong_host_collection.metadata = {
444440
"created-by": "VectorCode",
445441
"username": os.environ.get("USER", os.environ.get("USERNAME", "DEFAULT_USER")),
446442
"hostname": "wrong_host",
447443
}
448444

445+
# Mock list_collections to return the collections directly
449446
mock_client.list_collections.return_value = [
450-
"collection1",
451-
"collection2",
452-
"collection3",
453-
"collection4",
454-
"collection5",
455-
]
456-
mock_client.get_collection.side_effect = [
457-
mock_collection1,
458-
mock_collection2,
459-
mock_collection3,
460-
mock_collection4,
461-
mock_collection5,
447+
valid_collection,
448+
no_metadata_collection,
449+
wrong_creator_collection,
450+
wrong_user_collection,
451+
wrong_host_collection,
462452
]
463453

464-
collections = [
465-
collection async for collection in get_collections(mock_client)
466-
] # call get_collections
454+
# Collect the filtered collections
455+
collections = [collection async for collection in get_collections(mock_client)]
456+
457+
# Verify only the valid collection was returned
467458
assert len(collections) == 1
468-
assert collections[0] == mock_collection1
459+
assert collections[0] == valid_collection
460+
mock_client.list_collections.assert_called_once()
469461

470462

471463
def test_get_embedding_function_fallback():

0 commit comments

Comments
 (0)