Skip to content

Commit 547c375

Browse files
authored
fix: pass namespace in the docstore init (#683)
* pass namespace in the docstore init * manage serialization
1 parent 1adbc2c commit 547c375

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py

+5
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
embedding_dimension: int = 768,
5656
duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE,
5757
similarity: str = "cosine",
58+
namespace: Optional[str] = None,
5859
):
5960
"""
6061
The connection to Astra DB is established and managed through the JSON API.
@@ -99,13 +100,15 @@ def __init__(
99100
self.embedding_dimension = embedding_dimension
100101
self.duplicates_policy = duplicates_policy
101102
self.similarity = similarity
103+
self.namespace = namespace
102104

103105
self.index = AstraClient(
104106
resolved_api_endpoint,
105107
resolved_token,
106108
self.collection_name,
107109
self.embedding_dimension,
108110
self.similarity,
111+
namespace,
109112
)
110113

111114
@classmethod
@@ -128,6 +131,7 @@ def to_dict(self) -> Dict[str, Any]:
128131
:returns:
129132
Dictionary with serialized data.
130133
"""
134+
131135
return default_to_dict(
132136
self,
133137
api_endpoint=self.api_endpoint.to_dict(),
@@ -136,6 +140,7 @@ def to_dict(self) -> Dict[str, Any]:
136140
embedding_dimension=self.embedding_dimension,
137141
duplicates_policy=self.duplicates_policy.name,
138142
similarity=self.similarity,
143+
namespace=self.namespace,
139144
)
140145

141146
def write_documents(

integrations/astra/tests/test_document_store.py

+28
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44
import os
55
from typing import List
6+
from unittest import mock
67

78
import pytest
89
from haystack import Document
@@ -13,6 +14,33 @@
1314
from haystack_integrations.document_stores.astra import AstraDocumentStore
1415

1516

17+
def test_namespace_init():
18+
with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client:
19+
AstraDocumentStore()
20+
assert "namespace" in client.call_args.kwargs
21+
assert client.call_args.kwargs["namespace"] is None
22+
23+
AstraDocumentStore(namespace="foo")
24+
assert "namespace" in client.call_args.kwargs
25+
assert client.call_args.kwargs["namespace"] == "foo"
26+
27+
28+
def test_to_dict():
29+
with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"):
30+
ds = AstraDocumentStore()
31+
result = ds.to_dict()
32+
assert result["type"] == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore"
33+
assert set(result["init_parameters"]) == {
34+
"api_endpoint",
35+
"token",
36+
"collection_name",
37+
"embedding_dimension",
38+
"duplicates_policy",
39+
"similarity",
40+
"namespace",
41+
}
42+
43+
1644
@pytest.mark.integration
1745
@pytest.mark.skipif(
1846
os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set"

integrations/astra/tests/test_retriever.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def test_retriever_to_json(*_):
3030
"embedding_dimension": 768,
3131
"duplicates_policy": "NONE",
3232
"similarity": "cosine",
33+
"namespace": None,
3334
},
3435
},
3536
},
@@ -42,7 +43,6 @@ def test_retriever_to_json(*_):
4243
)
4344
@patch("haystack_integrations.document_stores.astra.document_store.AstraClient")
4445
def test_retriever_from_json(*_):
45-
4646
data = {
4747
"type": "haystack_integrations.components.retrievers.astra.retriever.AstraEmbeddingRetriever",
4848
"init_parameters": {

0 commit comments

Comments
 (0)