Commit 63df87c

add tests
1 parent 4e8ef97 commit 63df87c

File tree

3 files changed: +519 -0 lines changed

integrations/google_vertex/src/haystack_integrations/components/embedders/google_vertex/document_embedder.py (+4)

@@ -272,6 +272,10 @@ def to_dict(self) -> Dict[str, Any]:
             truncate_dim=self.truncate_dim,
             meta_fields_to_embed=self.meta_fields_to_embed,
             embedding_separator=self.embedding_separator,
+            max_tokens_total=self.max_tokens_total,
+            task_type=self.task_type,
+            time_sleep=self.time_sleep,
+            retries=self.retries,
         )

     @classmethod
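
For context, a minimal round-trip sketch of what these newly serialized fields enable. This is illustrative only and not part of the commit: the parameter values mirror the defaults asserted in the tests below, and it assumes valid GCP credentials are configured, since constructing the embedder initializes the Vertex AI client.

# Illustrative sketch: the new init parameters (max_tokens_total, task_type,
# time_sleep, retries) now survive a to_dict()/from_dict() round trip.
from haystack_integrations.components.embedders.google_vertex.document_embedder import (
    VertexAIDocumentEmbedder,
)

embedder = VertexAIDocumentEmbedder(
    model="text-embedding-005",
    task_type="RETRIEVAL_DOCUMENT",
    max_tokens_total=20000,
    time_sleep=30,
    retries=3,
)
data = embedder.to_dict()
assert data["init_parameters"]["max_tokens_total"] == 20000
assert data["init_parameters"]["time_sleep"] == 30
assert data["init_parameters"]["retries"] == 3
assert data["init_parameters"]["task_type"] == "RETRIEVAL_DOCUMENT"

restored = VertexAIDocumentEmbedder.from_dict(data)
assert restored.retries == 3 and restored.task_type == "RETRIEVAL_DOCUMENT"
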
New file (+301)

@@ -0,0 +1,301 @@
import math
import time
from unittest.mock import MagicMock, call, patch

import pytest
import vertexai
from haystack import Document
from haystack.utils.auth import Secret
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel

# Assume the classes are in this path for the tests
from haystack_integrations.components.embedders.google_vertex.document_embedder import (
    SUPPORTED_EMBEDDING_MODELS,
    VertexAIDocumentEmbedder,
)


# Mock the TextEmbeddingResponse structure expected by the embedder
class MockTextEmbeddingResponse:
    def __init__(self, values):
        self.values = values


# Mock the CountTokensResponse structure
class MockCountTokensResponse:
    def __init__(self, total_tokens):
        self.total_tokens = total_tokens


@pytest.fixture()
def mock_vertex_init_and_model():
    """
    Fixture to mock vertexai.init and TextEmbeddingModel.from_pretrained
    """
    with patch("vertexai.init") as mock_init, patch(
        "vertexai.language_models.TextEmbeddingModel.from_pretrained"
    ) as mock_from_pretrained:
        mock_embedder = MagicMock(spec=TextEmbeddingModel)
        mock_embedder.get_embeddings.return_value = [MockTextEmbeddingResponse([0.1] * 768)]
        mock_embedder.count_tokens.return_value = MockCountTokensResponse(total_tokens=10)
        mock_from_pretrained.return_value = mock_embedder
        yield mock_init, mock_from_pretrained, mock_embedder


# Define valid parameters for initialization
VALID_MODEL = "text-embedding-005"
VALID_TASK_TYPE = "RETRIEVAL_DOCUMENT"


def test_init_defaults():
    """Test default initialization."""

    embedder = VertexAIDocumentEmbedder(model=VALID_MODEL, task_type=VALID_TASK_TYPE)
    assert embedder.model == VALID_MODEL
    assert embedder.task_type == VALID_TASK_TYPE
    assert embedder.batch_size == 32
    assert embedder.max_tokens_total == 20000
    assert embedder.time_sleep == 30
    assert embedder.retries == 3
    assert embedder.progress_bar is True
    assert embedder.truncate_dim is None
    assert embedder.meta_fields_to_embed == []
    assert embedder.embedding_separator == "\n"
    assert isinstance(embedder.gcp_project_id, Secret)
    assert isinstance(embedder.gcp_region_name, Secret)


def test_init_custom_params(mock_vertex_init_and_model):
    """Test initialization with custom parameters."""
    mock_init, mock_from_pretrained, _ = mock_vertex_init_and_model
    project_id = Secret.from_token("test-project")
    region = Secret.from_token("us-west1")

    embedder = VertexAIDocumentEmbedder(
        model="textembedding-gecko-multilingual@001",
        task_type="SEMANTIC_SIMILARITY",
        gcp_project_id=project_id,
        gcp_region_name=region,
        batch_size=64,
        max_tokens_total=10000,
        time_sleep=10,
        retries=5,
        progress_bar=False,
        truncate_dim=256,
        meta_fields_to_embed=["meta_key"],
        embedding_separator=" | ",
    )

    assert embedder.model == "textembedding-gecko-multilingual@001"
    assert embedder.task_type == "SEMANTIC_SIMILARITY"
    assert embedder.batch_size == 64
    assert embedder.max_tokens_total == 10000
    assert embedder.time_sleep == 10
    assert embedder.retries == 5
    assert embedder.progress_bar is False
    assert embedder.truncate_dim == 256
    assert embedder.meta_fields_to_embed == ["meta_key"]
    assert embedder.embedding_separator == " | "
    assert embedder.gcp_project_id == project_id
    assert embedder.gcp_region_name == region

    mock_init.assert_called_once_with(project="test-project", location="us-west1")
    mock_from_pretrained.assert_called_once_with("textembedding-gecko-multilingual@001")


def test_init_invalid_model():
    """Test initialization with an invalid model name."""
    with pytest.raises(ValueError, match="Please provide a valid model"):
        VertexAIDocumentEmbedder(model="invalid-model", task_type=VALID_TASK_TYPE)


def test_prepare_texts_to_embed_no_meta():
    """Test _prepare_texts_to_embed without meta fields."""
    embedder = VertexAIDocumentEmbedder(model=VALID_MODEL, task_type=VALID_TASK_TYPE)
    docs = [Document(content="doc1 text"), Document(content="doc2 text")]
    texts = embedder._prepare_texts_to_embed(docs)
    assert texts == ["doc1 text", "doc2 text"]


def test_prepare_texts_to_embed_with_meta():
    """Test _prepare_texts_to_embed with meta fields."""
    embedder = VertexAIDocumentEmbedder(
        model=VALID_MODEL, task_type=VALID_TASK_TYPE, meta_fields_to_embed=["meta_key1", "meta_key2"]
    )
    docs = [
        Document(content="doc1 text", meta={"meta_key1": "value1"}),
        Document(content="doc2 text", meta={"meta_key1": "value2", "meta_key2": "value3"}),
        Document(content="doc3 text", meta={"other_key": "value4"}),  # meta_key1/2 missing
        Document(content=None, meta={"meta_key1": "value5"}),  # None content
    ]
    texts = embedder._prepare_texts_to_embed(docs)
    assert texts == [
        "value1\ndoc1 text",
        "value2\nvalue3\ndoc2 text",
        "doc3 text",  # Only content if specified meta keys are missing
        "value5\n",  # Separator is still added even if content is None
    ]


def test_prepare_texts_to_embed_custom_separator():
    """Test _prepare_texts_to_embed with a custom separator."""
    embedder = VertexAIDocumentEmbedder(
        model=VALID_MODEL, task_type=VALID_TASK_TYPE, meta_fields_to_embed=["meta_key"], embedding_separator=" --- "
    )
    docs = [Document(content="doc text", meta={"meta_key": "value"})]
    texts = embedder._prepare_texts_to_embed(docs)
    assert texts == ["value --- doc text"]


def test_get_text_embedding_input(mock_vertex_init_and_model):
    """Test conversion of Documents to TextEmbeddingInput."""
    embedder = VertexAIDocumentEmbedder(model=VALID_MODEL, task_type="CLASSIFICATION")
    docs = [Document(content="text1"), Document(content="text2")]

    with patch.object(embedder, "_prepare_texts_to_embed", return_value=["prep_text1", "prep_text2"]) as mock_prepare:
        inputs = embedder.get_text_embedding_input(docs)

    mock_prepare.assert_called_once_with(documents=docs)
    assert len(inputs) == 2
    assert isinstance(inputs[0], TextEmbeddingInput)
    assert inputs[0].text == "prep_text1"
    assert inputs[0].task_type == "CLASSIFICATION"
    assert isinstance(inputs[1], TextEmbeddingInput)
    assert inputs[1].text == "prep_text2"
    assert inputs[1].task_type == "CLASSIFICATION"


def test_embed_batch(mock_vertex_init_and_model):
    """Test embedding a single batch successfully."""
    _, _, mock_embedder_instance = mock_vertex_init_and_model
    embedder = VertexAIDocumentEmbedder(model=VALID_MODEL, task_type=VALID_TASK_TYPE)
    docs = [Document(content="text1"), Document(content="text2")]
    prepared_texts = ["text1", "text2"]
    expected_embeddings = [[0.1] * 10, [0.2] * 10]

    # Mock the response from the underlying API
    mock_embedder_instance.get_embeddings.return_value = [
        MockTextEmbeddingResponse(expected_embeddings[0]),
        MockTextEmbeddingResponse(expected_embeddings[1]),
    ]

    with patch.object(embedder, "_prepare_texts_to_embed", return_value=prepared_texts):
        embeddings = embedder.embed_batch(docs)

    assert embeddings == expected_embeddings
    # Check that get_embeddings was called with the correct TextEmbeddingInput objects
    call_args, _ = mock_embedder_instance.get_embeddings.call_args
    inputs = call_args[0]
    assert len(inputs) == 2
    assert inputs[0].text == "text1"
    assert inputs[0].task_type == VALID_TASK_TYPE
    assert inputs[1].text == "text2"
    assert inputs[1].task_type == VALID_TASK_TYPE


def test_to_dict(mock_vertex_init_and_model):
    """Test serialization to dictionary."""
    project_id = Secret.from_env_var("GCP_PROJECT_ID", strict=False)
    region = Secret.from_env_var("GCP_DEFAULT_REGION", strict=False)
    embedder = VertexAIDocumentEmbedder(
        model=VALID_MODEL,
        task_type=VALID_TASK_TYPE,
        gcp_project_id=project_id,
        gcp_region_name=region,
        batch_size=64,
        progress_bar=False,
        truncate_dim=128,
        meta_fields_to_embed=["meta1"],
        embedding_separator="||",
    )
    data = embedder.to_dict()

    assert data == {
        "type": "haystack_integrations.components.embedders.google_vertex.document_embedder.VertexAIDocumentEmbedder",
        "init_parameters": {
            "model": VALID_MODEL,
            "task_type": VALID_TASK_TYPE,
            "gcp_project_id": project_id.to_dict(),
            "gcp_region_name": region.to_dict(),
            "batch_size": 64,
            "max_tokens_total": 20000,  # Default value was not overridden
            "time_sleep": 30,  # Default value was not overridden
            "retries": 3,  # Default value was not overridden
            "progress_bar": False,
            "truncate_dim": 128,
            "meta_fields_to_embed": ["meta1"],
            "embedding_separator": "||",
        },
    }


def test_from_dict(mock_vertex_init_and_model):
    """Test deserialization from dictionary."""
    mock_init, mock_from_pretrained, _ = mock_vertex_init_and_model
    project_id_dict = Secret.from_env_var("GCP_PROJECT_ID", strict=False).to_dict()
    region_dict = Secret.from_env_var("GCP_DEFAULT_REGION", strict=False).to_dict()

    data = {
        "type": "haystack_integrations.components.embedders.google_vertex.document_embedder.VertexAIDocumentEmbedder",
        "init_parameters": {
            "model": "text-multilingual-embedding-002",
            "task_type": "CLUSTERING",
            "gcp_project_id": project_id_dict,
            "gcp_region_name": region_dict,
            "batch_size": 16,
            "progress_bar": True,
            "truncate_dim": None,
            "meta_fields_to_embed": None,
            "embedding_separator": "\n",
            # Include defaults that might be missing if saved from older versions
            "max_tokens_total": 20000,
            "time_sleep": 30,
            "retries": 3,
        },
    }

    embedder = VertexAIDocumentEmbedder.from_dict(data)

    assert embedder.model == "text-multilingual-embedding-002"
    assert embedder.task_type == "CLUSTERING"
    assert isinstance(embedder.gcp_project_id, Secret)
    assert isinstance(embedder.gcp_region_name, Secret)
    assert embedder.batch_size == 16
    assert embedder.progress_bar is True
    assert embedder.truncate_dim is None
    assert embedder.meta_fields_to_embed is None
    assert embedder.embedding_separator == "\n"
    assert embedder.max_tokens_total == 20000
    assert embedder.time_sleep == 30
    assert embedder.retries == 3

    # Check that vertexai.init and from_pretrained were called again
    mock_init.assert_called_once()
    mock_from_pretrained.assert_called_once_with("text-multilingual-embedding-002")


def test_from_dict_no_secrets(mock_vertex_init_and_model):
    """Test deserialization when secrets are not in the dictionary."""
    mock_init, mock_from_pretrained, _ = mock_vertex_init_and_model
    data = {
        "type": "haystack_integrations.components.embedders.google_vertex.document_embedder.VertexAIDocumentEmbedder",
        "init_parameters": {
            "model": VALID_MODEL,
            "task_type": VALID_TASK_TYPE,
            "gcp_project_id": None,  # Explicitly None
            "gcp_region_name": None,  # Explicitly None
            "batch_size": 32,
            "progress_bar": True,
            "truncate_dim": None,
            "meta_fields_to_embed": None,
            "embedding_separator": "\n",
            "max_tokens_total": 20000,
            "time_sleep": 30,
            "retries": 3,
        },
    }
    embedder = VertexAIDocumentEmbedder.from_dict(data)
    assert embedder.gcp_project_id is None
    assert embedder.gcp_region_name is None
    mock_init.assert_called_once_with(project=None, location=None)
    mock_from_pretrained.assert_called_once_with(VALID_MODEL)

0 commit comments
