Skip to content

Commit 595b580

Browse files
lingyinwcopybara-github
authored andcommitted
fix: fix server error due to no encryption_spec_key_name in MatchingEngineIndex create_tree_ah_index and create_brute_force_index
PiperOrigin-RevId: 582880161
1 parent dd4b852 commit 595b580

File tree

2 files changed

+87
-4
lines changed

2 files changed

+87
-4
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,14 @@ def _create(
196196
"contentsDeltaUri": contents_delta_uri,
197197
},
198198
index_update_method=index_update_method_enum,
199-
encryption_spec=gca_encryption_spec.EncryptionSpec(
200-
kms_key_name=encryption_spec_key_name
201-
),
202199
)
203200

201+
if encryption_spec_key_name:
202+
encryption_spec = gca_encryption_spec.EncryptionSpec(
203+
kms_key_name=encryption_spec_key_name
204+
)
205+
gapic_index.encryption_spec = encryption_spec
206+
204207
if labels:
205208
utils.validate_labels(labels)
206209
gapic_index.labels = labels

tests/unit/aiplatform/test_matching_engine_index.py

+81-1
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,50 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
369369
metadata=_TEST_REQUEST_METADATA,
370370
)
371371

372+
@pytest.mark.usefixtures("get_index_mock")
373+
def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
374+
aiplatform.init(project=_TEST_PROJECT)
375+
376+
aiplatform.MatchingEngineIndex.create_tree_ah_index(
377+
display_name=_TEST_INDEX_DISPLAY_NAME,
378+
contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
379+
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
380+
approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
381+
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
382+
leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
383+
leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
384+
description=_TEST_INDEX_DESCRIPTION,
385+
labels=_TEST_LABELS,
386+
)
387+
388+
config = {
389+
"treeAhConfig": {
390+
"leafNodeEmbeddingCount": _TEST_LEAF_NODE_EMBEDDING_COUNT,
391+
"leafNodesToSearchPercent": _TEST_LEAF_NODES_TO_SEARCH_PERCENT,
392+
}
393+
}
394+
395+
expected = gca_index.Index(
396+
display_name=_TEST_INDEX_DISPLAY_NAME,
397+
metadata={
398+
"config": {
399+
"algorithmConfig": config,
400+
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
401+
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
402+
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
403+
},
404+
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
405+
},
406+
description=_TEST_INDEX_DESCRIPTION,
407+
labels=_TEST_LABELS,
408+
)
409+
410+
create_index_mock.assert_called_once_with(
411+
parent=_TEST_PARENT,
412+
index=expected,
413+
metadata=_TEST_REQUEST_METADATA,
414+
)
415+
372416
@pytest.mark.usefixtures("get_index_mock")
373417
@pytest.mark.parametrize("sync", [True, False])
374418
@pytest.mark.parametrize(
@@ -419,7 +463,7 @@ def test_create_brute_force_index(
419463
index_update_method
420464
],
421465
encryption_spec=gca_encryption_spec.EncryptionSpec(
422-
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
466+
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
423467
),
424468
)
425469

@@ -429,6 +473,42 @@ def test_create_brute_force_index(
429473
metadata=_TEST_REQUEST_METADATA,
430474
)
431475

476+
@pytest.mark.usefixtures("get_index_mock")
477+
def test_create_brute_force_index_backward_compatibility(self, create_index_mock):
478+
aiplatform.init(project=_TEST_PROJECT)
479+
480+
aiplatform.MatchingEngineIndex.create_brute_force_index(
481+
display_name=_TEST_INDEX_DISPLAY_NAME,
482+
contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
483+
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
484+
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
485+
description=_TEST_INDEX_DESCRIPTION,
486+
labels=_TEST_LABELS,
487+
)
488+
489+
config = {"bruteForceConfig": {}}
490+
491+
expected = gca_index.Index(
492+
display_name=_TEST_INDEX_DISPLAY_NAME,
493+
metadata={
494+
"config": {
495+
"algorithmConfig": config,
496+
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
497+
"approximateNeighborsCount": None,
498+
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
499+
},
500+
"contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
501+
},
502+
description=_TEST_INDEX_DESCRIPTION,
503+
labels=_TEST_LABELS,
504+
)
505+
506+
create_index_mock.assert_called_once_with(
507+
parent=_TEST_PARENT,
508+
index=expected,
509+
metadata=_TEST_REQUEST_METADATA,
510+
)
511+
432512
@pytest.mark.usefixtures("get_index_mock")
433513
def test_remove_datapoints(self, remove_datapoints_mock):
434514
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments
 (0)