|
| 1 | +import math |
| 2 | +import time |
| 3 | +from unittest.mock import MagicMock, call, patch |
| 4 | + |
| 5 | +import pytest |
| 6 | +import vertexai |
| 7 | +from haystack import Document |
| 8 | +from haystack.dataclasses.document import Document |
| 9 | +from haystack.utils.auth import Secret |
| 10 | +from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel |
| 11 | + |
| 12 | +# Assume the classes are in this path for the tests |
| 13 | +from haystack_integrations.components.embedders.google_vertex.document_embedder import ( |
| 14 | + SUPPORTED_EMBEDDING_MODELS, |
| 15 | + VertexAIDocumentEmbedder, |
| 16 | +) |
| 17 | + |
| 18 | + |
| 19 | +# Mock the TextEmbeddingResponse structure expected by the embedder |
class MockTextEmbeddingResponse:
    """Lightweight stand-in for Vertex AI's ``TextEmbeddingResponse``.

    The embedder under test only reads the ``values`` attribute (the
    embedding vector), so that is all this mock provides.
    """

    def __init__(self, values):
        # The raw embedding vector, as a list of floats.
        self.values = values
| 24 | + |
| 25 | +# Mock the CountTokensResponse structure |
class MockCountTokensResponse:
    """Lightweight stand-in for Vertex AI's ``CountTokensResponse``.

    Only the ``total_tokens`` attribute is consumed by the embedder's
    token-budget logic.
    """

    def __init__(self, total_tokens):
        # Total token count reported for the counted text(s).
        self.total_tokens = total_tokens
| 30 | + |
@pytest.fixture()
def mock_vertex_init_and_model():
    """Patch ``vertexai.init`` and ``TextEmbeddingModel.from_pretrained``.

    Yields a 3-tuple ``(init_mock, from_pretrained_mock, model_mock)`` where
    ``model_mock`` is a MagicMock standing in for the loaded embedding model,
    pre-wired with canned ``get_embeddings`` / ``count_tokens`` responses.
    """
    with patch("vertexai.init") as init_mock:
        with patch("vertexai.language_models.TextEmbeddingModel.from_pretrained") as from_pretrained_mock:
            model_mock = MagicMock(spec=TextEmbeddingModel)
            # Default canned responses; individual tests may override these.
            model_mock.get_embeddings.return_value = [MockTextEmbeddingResponse([0.1] * 768)]
            model_mock.count_tokens.return_value = MockCountTokensResponse(total_tokens=10)
            from_pretrained_mock.return_value = model_mock
            yield init_mock, from_pretrained_mock, model_mock
| 44 | + |
| 45 | + |
# Valid parameters for initialization, shared across the tests below.
VALID_MODEL = "text-embedding-005"  # must be a member of SUPPORTED_EMBEDDING_MODELS
VALID_TASK_TYPE = "RETRIEVAL_DOCUMENT"  # a task type accepted by the Vertex embedding API
| 49 | + |
| 50 | + |
def test_init_defaults(mock_vertex_init_and_model):
    """Test default initialization.

    The mock fixture is required here: the constructor calls ``vertexai.init``
    and ``TextEmbeddingModel.from_pretrained`` (see test_init_custom_params),
    which would otherwise reach out to real GCP during the test run.
    """
    embedder = VertexAIDocumentEmbedder(model=VALID_MODEL, task_type=VALID_TASK_TYPE)

    # Explicitly supplied arguments.
    assert embedder.model == VALID_MODEL
    assert embedder.task_type == VALID_TASK_TYPE
    # Everything else falls back to the documented defaults.
    assert embedder.batch_size == 32
    assert embedder.max_tokens_total == 20000
    assert embedder.time_sleep == 30
    assert embedder.retries == 3
    assert embedder.progress_bar is True
    assert embedder.truncate_dim is None
    assert embedder.meta_fields_to_embed == []
    assert embedder.embedding_separator == "\n"
    # Credentials default to env-var-backed Secret objects.
    assert isinstance(embedder.gcp_project_id, Secret)
    assert isinstance(embedder.gcp_region_name, Secret)
| 67 | + |
| 68 | + |
def test_init_custom_params(mock_vertex_init_and_model):
    """Test initialization with custom parameters.

    Verifies that every constructor argument is stored on the instance and
    that the (mocked) Vertex AI SDK is initialized with the resolved secret
    values rather than the Secret wrappers themselves.
    """
    mock_init, mock_from_pretrained, _ = mock_vertex_init_and_model
    project_id = Secret.from_token("test-project")
    region = Secret.from_token("us-west1")

    embedder = VertexAIDocumentEmbedder(
        model="textembedding-gecko-multilingual@001",
        task_type="SEMANTIC_SIMILARITY",
        gcp_project_id=project_id,
        gcp_region_name=region,
        batch_size=64,
        max_tokens_total=10000,
        time_sleep=10,
        retries=5,
        progress_bar=False,
        truncate_dim=256,
        meta_fields_to_embed=["meta_key"],
        embedding_separator=" | ",
    )

    # Every custom argument must be reflected on the instance unchanged.
    assert embedder.model == "textembedding-gecko-multilingual@001"
    assert embedder.task_type == "SEMANTIC_SIMILARITY"
    assert embedder.batch_size == 64
    assert embedder.max_tokens_total == 10000
    assert embedder.time_sleep == 10
    assert embedder.retries == 5
    assert embedder.progress_bar is False
    assert embedder.truncate_dim == 256
    assert embedder.meta_fields_to_embed == ["meta_key"]
    assert embedder.embedding_separator == " | "
    assert embedder.gcp_project_id == project_id
    assert embedder.gcp_region_name == region

    # The SDK is initialized with the *resolved* token values of the secrets.
    mock_init.assert_called_once_with(project="test-project", location="us-west1")
    mock_from_pretrained.assert_called_once_with("textembedding-gecko-multilingual@001")
| 105 | + |
| 106 | + |
def test_init_invalid_model():
    """An unsupported model name must raise a descriptive ValueError.

    Model validation happens before any SDK call, so no mocking is needed.
    """
    with pytest.raises(ValueError, match="Please provide a valid model"):
        VertexAIDocumentEmbedder(task_type=VALID_TASK_TYPE, model="invalid-model")
| 111 | + |
| 112 | + |
def test_prepare_texts_to_embed_no_meta(mock_vertex_init_and_model):
    """Test _prepare_texts_to_embed without meta fields.

    The mock fixture is required: constructing the embedder calls
    ``vertexai.init`` / ``from_pretrained``, which must not hit real GCP.
    """
    embedder = VertexAIDocumentEmbedder(model=VALID_MODEL, task_type=VALID_TASK_TYPE)
    docs = [Document(content="doc1 text"), Document(content="doc2 text")]
    texts = embedder._prepare_texts_to_embed(docs)
    # With no meta fields configured, the texts are just the document contents.
    assert texts == ["doc1 text", "doc2 text"]
| 119 | + |
| 120 | + |
def test_prepare_texts_to_embed_with_meta(mock_vertex_init_and_model):
    """Test _prepare_texts_to_embed with meta fields.

    The mock fixture prevents the constructor from calling the real Vertex AI
    SDK. Covers present/missing meta keys and a document with None content.
    """
    embedder = VertexAIDocumentEmbedder(
        model=VALID_MODEL, task_type=VALID_TASK_TYPE, meta_fields_to_embed=["meta_key1", "meta_key2"]
    )
    docs = [
        Document(content="doc1 text", meta={"meta_key1": "value1"}),
        Document(content="doc2 text", meta={"meta_key1": "value2", "meta_key2": "value3"}),
        Document(content="doc3 text", meta={"other_key": "value4"}),  # meta_key1/2 missing
        Document(content=None, meta={"meta_key1": "value5"}),  # None content
    ]
    texts = embedder._prepare_texts_to_embed(docs)
    assert texts == [
        "value1\ndoc1 text",
        "value2\nvalue3\ndoc2 text",
        "doc3 text",  # Only content if specified meta keys are missing
        "value5\n",  # Separator is still added even if content is None
    ]
| 139 | + |
| 140 | + |
def test_prepare_texts_to_embed_custom_separator(mock_vertex_init_and_model):
    """Test _prepare_texts_to_embed with a custom separator.

    The mock fixture prevents the constructor from calling the real Vertex AI
    SDK.
    """
    embedder = VertexAIDocumentEmbedder(
        model=VALID_MODEL, task_type=VALID_TASK_TYPE, meta_fields_to_embed=["meta_key"], embedding_separator=" --- "
    )
    docs = [Document(content="doc text", meta={"meta_key": "value"})]
    texts = embedder._prepare_texts_to_embed(docs)
    # The configured separator joins meta values and content.
    assert texts == ["value --- doc text"]
| 149 | + |
| 150 | + |
def test_get_text_embedding_input(mock_vertex_init_and_model):
    """Documents are converted into TextEmbeddingInput objects carrying the task type."""
    embedder = VertexAIDocumentEmbedder(model=VALID_MODEL, task_type="CLASSIFICATION")
    documents = [Document(content="text1"), Document(content="text2")]
    prepared = ["prep_text1", "prep_text2"]

    # Stub the text-preparation step so only the conversion itself is under test.
    with patch.object(embedder, "_prepare_texts_to_embed", return_value=prepared) as prepare_mock:
        embedding_inputs = embedder.get_text_embedding_input(documents)

    prepare_mock.assert_called_once_with(documents=documents)
    assert len(embedding_inputs) == 2
    for embedding_input, expected_text in zip(embedding_inputs, prepared):
        assert isinstance(embedding_input, TextEmbeddingInput)
        assert embedding_input.text == expected_text
        assert embedding_input.task_type == "CLASSIFICATION"
| 167 | + |
| 168 | + |
def test_embed_batch(mock_vertex_init_and_model):
    """A single batch of documents is embedded via the mocked Vertex model."""
    _, _, model_mock = mock_vertex_init_and_model
    embedder = VertexAIDocumentEmbedder(model=VALID_MODEL, task_type=VALID_TASK_TYPE)
    documents = [Document(content="text1"), Document(content="text2")]
    prepared_texts = ["text1", "text2"]
    expected_embeddings = [[0.1] * 10, [0.2] * 10]

    # Replace the fixture's canned response with one vector per document.
    model_mock.get_embeddings.return_value = [MockTextEmbeddingResponse(vec) for vec in expected_embeddings]

    with patch.object(embedder, "_prepare_texts_to_embed", return_value=prepared_texts):
        embeddings = embedder.embed_batch(documents)

    assert embeddings == expected_embeddings
    # The underlying API must receive TextEmbeddingInput objects built from the prepared texts.
    positional_args, _ = model_mock.get_embeddings.call_args
    inputs = positional_args[0]
    assert len(inputs) == 2
    for embedding_input, expected_text in zip(inputs, prepared_texts):
        assert embedding_input.text == expected_text
        assert embedding_input.task_type == VALID_TASK_TYPE
| 195 | + |
def test_to_dict(mock_vertex_init_and_model):
    """Test serialization to dictionary.

    Checks the full serialized form, including defaults that were not
    overridden and the Secret objects serialized via their own to_dict().
    """
    project_id = Secret.from_env_var("GCP_PROJECT_ID", strict=False)
    region = Secret.from_env_var("GCP_DEFAULT_REGION", strict=False)
    embedder = VertexAIDocumentEmbedder(
        model=VALID_MODEL,
        task_type=VALID_TASK_TYPE,
        gcp_project_id=project_id,
        gcp_region_name=region,
        batch_size=64,
        progress_bar=False,
        truncate_dim=128,
        meta_fields_to_embed=["meta1"],
        embedding_separator="||",
    )
    data = embedder.to_dict()

    # Exact comparison: any extra, missing, or renamed key must fail the test.
    assert data == {
        "type": "haystack_integrations.components.embedders.google_vertex.document_embedder.VertexAIDocumentEmbedder",
        "init_parameters": {
            "model": VALID_MODEL,
            "task_type": VALID_TASK_TYPE,
            "gcp_project_id": project_id.to_dict(),
            "gcp_region_name": region.to_dict(),
            "batch_size": 64,
            "max_tokens_total": 20000,  # Default value was not overridden
            "time_sleep": 30,  # Default value was not overridden
            "retries": 3,  # Default value was not overridden
            "progress_bar": False,
            "truncate_dim": 128,
            "meta_fields_to_embed": ["meta1"],
            "embedding_separator": "||",
        },
    }
| 230 | + |
| 231 | + |
def test_from_dict(mock_vertex_init_and_model):
    """Test deserialization from dictionary.

    Round-trips a serialized config through from_dict and verifies both the
    restored attributes and that the (mocked) SDK is re-initialized.
    """
    mock_init, mock_from_pretrained, _ = mock_vertex_init_and_model
    project_id_dict = Secret.from_env_var("GCP_PROJECT_ID", strict=False).to_dict()
    region_dict = Secret.from_env_var("GCP_DEFAULT_REGION", strict=False).to_dict()

    data = {
        "type": "haystack_integrations.components.embedders.google_vertex.document_embedder.VertexAIDocumentEmbedder",
        "init_parameters": {
            "model": "text-multilingual-embedding-002",
            "task_type": "CLUSTERING",
            "gcp_project_id": project_id_dict,
            "gcp_region_name": region_dict,
            "batch_size": 16,
            "progress_bar": True,
            "truncate_dim": None,
            "meta_fields_to_embed": None,
            "embedding_separator": "\n",
            # Include defaults that might be missing if saved from older versions
            "max_tokens_total": 20000,
            "time_sleep": 30,
            "retries": 3,
        },
    }

    embedder = VertexAIDocumentEmbedder.from_dict(data)

    assert embedder.model == "text-multilingual-embedding-002"
    assert embedder.task_type == "CLUSTERING"
    # Secret dicts must be deserialized back into Secret instances.
    assert isinstance(embedder.gcp_project_id, Secret)
    assert isinstance(embedder.gcp_region_name, Secret)
    assert embedder.batch_size == 16
    assert embedder.progress_bar is True
    assert embedder.truncate_dim is None
    assert embedder.meta_fields_to_embed is None
    assert embedder.embedding_separator == "\n"
    assert embedder.max_tokens_total == 20000
    assert embedder.time_sleep == 30
    assert embedder.retries == 3

    # Check that vertexai.init and from_pretrained were called again
    mock_init.assert_called_once()
    mock_from_pretrained.assert_called_once_with("text-multilingual-embedding-002")
| 275 | + |
| 276 | + |
def test_from_dict_no_secrets(mock_vertex_init_and_model):
    """Test deserialization when secrets are not in the dictionary.

    With both secrets explicitly None, the restored embedder should keep them
    as None and initialize the (mocked) SDK with project=None, location=None.
    """
    mock_init, mock_from_pretrained, _ = mock_vertex_init_and_model
    data = {
        "type": "haystack_integrations.components.embedders.google_vertex.document_embedder.VertexAIDocumentEmbedder",
        "init_parameters": {
            "model": VALID_MODEL,
            "task_type": VALID_TASK_TYPE,
            "gcp_project_id": None,  # Explicitly None
            "gcp_region_name": None,  # Explicitly None
            "batch_size": 32,
            "progress_bar": True,
            "truncate_dim": None,
            "meta_fields_to_embed": None,
            "embedding_separator": "\n",
            "max_tokens_total": 20000,
            "time_sleep": 30,
            "retries": 3,
        },
    }
    embedder = VertexAIDocumentEmbedder.from_dict(data)
    assert embedder.gcp_project_id is None
    assert embedder.gcp_region_name is None
    # The SDK falls back to application-default project/location resolution.
    mock_init.assert_called_once_with(project=None, location=None)
    mock_from_pretrained.assert_called_once_with(VALID_MODEL)