Skip to content

Commit a00db07

Browse files
lingyinwcopybara-github
authored andcommitted
feat: Support empty index for MatchingEngineIndex create index.
PiperOrigin-RevId: 599760961
1 parent b0b604e commit a00db07

File tree

2 files changed

+143
-14
lines changed

2 files changed

+143
-14
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def description(self) -> str:
101101
def _create(
102102
cls,
103103
display_name: str,
104-
contents_delta_uri: str,
105-
config: matching_engine_index_config.MatchingEngineIndexConfig,
104+
contents_delta_uri: Optional[str] = None,
105+
config: matching_engine_index_config.MatchingEngineIndexConfig = None,
106106
description: Optional[str] = None,
107107
labels: Optional[Dict[str, str]] = None,
108108
project: Optional[str] = None,
@@ -121,7 +121,7 @@ def _create(
121121
The name can be up to 128 characters long and
122122
can be consist of any UTF-8 characters.
123123
contents_delta_uri (str):
124-
Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
124+
Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index.
125125
The string must be a valid Google Cloud Storage directory path. If this
126126
field is set when calling IndexService.UpdateIndex, then no other
127127
Index field can be also updated as part of the same call.
@@ -188,13 +188,17 @@ def _create(
188188
index_update_method
189189
]
190190

191+
metadata = {"config": config.as_dict()}
192+
if contents_delta_uri:
193+
metadata = {
194+
"config": config.as_dict(),
195+
"contentsDeltaUri": contents_delta_uri,
196+
}
197+
191198
gapic_index = gca_matching_engine_index.Index(
192199
display_name=display_name,
193200
description=description,
194-
metadata={
195-
"config": config.as_dict(),
196-
"contentsDeltaUri": contents_delta_uri,
197-
},
201+
metadata=metadata,
198202
index_update_method=index_update_method_enum,
199203
)
200204

@@ -399,9 +403,9 @@ def deployed_indexes(
399403
def create_tree_ah_index(
400404
cls,
401405
display_name: str,
402-
contents_delta_uri: str,
403-
dimensions: int,
404-
approximate_neighbors_count: int,
406+
contents_delta_uri: Optional[str] = None,
407+
dimensions: int = None,
408+
approximate_neighbors_count: int = None,
405409
leaf_node_embedding_count: Optional[int] = None,
406410
leaf_nodes_to_search_percent: Optional[float] = None,
407411
distance_measure_type: Optional[
@@ -439,7 +443,7 @@ def create_tree_ah_index(
439443
The name can be up to 128 characters long and
440444
can be consist of any UTF-8 characters.
441445
contents_delta_uri (str):
442-
Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
446+
Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index.
443447
The string must be a valid Google Cloud Storage directory path. If this
444448
field is set when calling IndexService.UpdateIndex, then no other
445449
Index field can be also updated as part of the same call.
@@ -543,8 +547,8 @@ def create_tree_ah_index(
543547
def create_brute_force_index(
544548
cls,
545549
display_name: str,
546-
contents_delta_uri: str,
547-
dimensions: int,
550+
contents_delta_uri: Optional[str] = None,
551+
dimensions: int = None,
548552
distance_measure_type: Optional[
549553
matching_engine_index_config.DistanceMeasureType
550554
] = None,
@@ -578,7 +582,7 @@ def create_brute_force_index(
578582
The name can be up to 128 characters long and
579583
can be consist of any UTF-8 characters.
580584
contents_delta_uri (str):
581-
Required. Allows inserting, updating or deleting the contents of the Matching Engine Index.
585+
Optional. Allows inserting, updating or deleting the contents of the Matching Engine Index.
582586
The string must be a valid Google Cloud Storage directory path. If this
583587
field is set when calling IndexService.UpdateIndex, then no other
584588
Index field can be also updated as part of the same call.

tests/unit/aiplatform/test_matching_engine_index.py

+125
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,73 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
409409
metadata=_TEST_REQUEST_METADATA,
410410
)
411411

412+
@pytest.mark.usefixtures("get_index_mock")
413+
@pytest.mark.parametrize("sync", [True, False])
414+
@pytest.mark.parametrize(
415+
"index_update_method",
416+
[
417+
_TEST_INDEX_STREAM_UPDATE_METHOD,
418+
_TEST_INDEX_BATCH_UPDATE_METHOD,
419+
_TEST_INDEX_EMPTY_UPDATE_METHOD,
420+
_TEST_INDEX_INVALID_UPDATE_METHOD,
421+
],
422+
)
423+
def test_create_tree_ah_index_with_empty_index(
424+
self, create_index_mock, sync, index_update_method
425+
):
426+
aiplatform.init(project=_TEST_PROJECT)
427+
428+
my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
429+
display_name=_TEST_INDEX_DISPLAY_NAME,
430+
contents_delta_uri=None,
431+
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
432+
approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
433+
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
434+
leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
435+
leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
436+
description=_TEST_INDEX_DESCRIPTION,
437+
labels=_TEST_LABELS,
438+
sync=sync,
439+
index_update_method=index_update_method,
440+
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
441+
)
442+
443+
if not sync:
444+
my_index.wait()
445+
446+
config = {
447+
"treeAhConfig": {
448+
"leafNodeEmbeddingCount": _TEST_LEAF_NODE_EMBEDDING_COUNT,
449+
"leafNodesToSearchPercent": _TEST_LEAF_NODES_TO_SEARCH_PERCENT,
450+
}
451+
}
452+
453+
expected = gca_index.Index(
454+
display_name=_TEST_INDEX_DISPLAY_NAME,
455+
metadata={
456+
"config": {
457+
"algorithmConfig": config,
458+
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
459+
"approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
460+
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
461+
},
462+
},
463+
description=_TEST_INDEX_DESCRIPTION,
464+
labels=_TEST_LABELS,
465+
index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
466+
index_update_method
467+
],
468+
encryption_spec=gca_encryption_spec.EncryptionSpec(
469+
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
470+
),
471+
)
472+
473+
create_index_mock.assert_called_once_with(
474+
parent=_TEST_PARENT,
475+
index=expected,
476+
metadata=_TEST_REQUEST_METADATA,
477+
)
478+
412479
@pytest.mark.usefixtures("get_index_mock")
413480
def test_create_tree_ah_index_backward_compatibility(self, create_index_mock):
414481
aiplatform.init(project=_TEST_PROJECT)
@@ -513,6 +580,64 @@ def test_create_brute_force_index(
513580
metadata=_TEST_REQUEST_METADATA,
514581
)
515582

583+
@pytest.mark.usefixtures("get_index_mock")
584+
@pytest.mark.parametrize("sync", [True, False])
585+
@pytest.mark.parametrize(
586+
"index_update_method",
587+
[
588+
_TEST_INDEX_STREAM_UPDATE_METHOD,
589+
_TEST_INDEX_BATCH_UPDATE_METHOD,
590+
_TEST_INDEX_EMPTY_UPDATE_METHOD,
591+
_TEST_INDEX_INVALID_UPDATE_METHOD,
592+
],
593+
)
594+
def test_create_brute_force_index_with_empty_index(
595+
self, create_index_mock, sync, index_update_method
596+
):
597+
aiplatform.init(project=_TEST_PROJECT)
598+
599+
my_index = aiplatform.MatchingEngineIndex.create_brute_force_index(
600+
display_name=_TEST_INDEX_DISPLAY_NAME,
601+
dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
602+
distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
603+
description=_TEST_INDEX_DESCRIPTION,
604+
labels=_TEST_LABELS,
605+
sync=sync,
606+
index_update_method=index_update_method,
607+
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
608+
)
609+
610+
if not sync:
611+
my_index.wait()
612+
613+
config = {"bruteForceConfig": {}}
614+
615+
expected = gca_index.Index(
616+
display_name=_TEST_INDEX_DISPLAY_NAME,
617+
metadata={
618+
"config": {
619+
"algorithmConfig": config,
620+
"dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
621+
"approximateNeighborsCount": None,
622+
"distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
623+
},
624+
},
625+
description=_TEST_INDEX_DESCRIPTION,
626+
labels=_TEST_LABELS,
627+
index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[
628+
index_update_method
629+
],
630+
encryption_spec=gca_encryption_spec.EncryptionSpec(
631+
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
632+
),
633+
)
634+
635+
create_index_mock.assert_called_once_with(
636+
parent=_TEST_PARENT,
637+
index=expected,
638+
metadata=_TEST_REQUEST_METADATA,
639+
)
640+
516641
@pytest.mark.usefixtures("get_index_mock")
517642
def test_create_brute_force_index_backward_compatibility(self, create_index_mock):
518643
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments
 (0)