@@ -409,6 +409,73 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
409
409
metadata = _TEST_REQUEST_METADATA ,
410
410
)
411
411
412
+ @pytest .mark .usefixtures ("get_index_mock" )
413
+ @pytest .mark .parametrize ("sync" , [True , False ])
414
+ @pytest .mark .parametrize (
415
+ "index_update_method" ,
416
+ [
417
+ _TEST_INDEX_STREAM_UPDATE_METHOD ,
418
+ _TEST_INDEX_BATCH_UPDATE_METHOD ,
419
+ _TEST_INDEX_EMPTY_UPDATE_METHOD ,
420
+ _TEST_INDEX_INVALID_UPDATE_METHOD ,
421
+ ],
422
+ )
423
+ def test_create_tree_ah_index_with_empty_index (
424
+ self , create_index_mock , sync , index_update_method
425
+ ):
426
+ aiplatform .init (project = _TEST_PROJECT )
427
+
428
+ my_index = aiplatform .MatchingEngineIndex .create_tree_ah_index (
429
+ display_name = _TEST_INDEX_DISPLAY_NAME ,
430
+ contents_delta_uri = None ,
431
+ dimensions = _TEST_INDEX_CONFIG_DIMENSIONS ,
432
+ approximate_neighbors_count = _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT ,
433
+ distance_measure_type = _TEST_INDEX_DISTANCE_MEASURE_TYPE ,
434
+ leaf_node_embedding_count = _TEST_LEAF_NODE_EMBEDDING_COUNT ,
435
+ leaf_nodes_to_search_percent = _TEST_LEAF_NODES_TO_SEARCH_PERCENT ,
436
+ description = _TEST_INDEX_DESCRIPTION ,
437
+ labels = _TEST_LABELS ,
438
+ sync = sync ,
439
+ index_update_method = index_update_method ,
440
+ encryption_spec_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
441
+ )
442
+
443
+ if not sync :
444
+ my_index .wait ()
445
+
446
+ config = {
447
+ "treeAhConfig" : {
448
+ "leafNodeEmbeddingCount" : _TEST_LEAF_NODE_EMBEDDING_COUNT ,
449
+ "leafNodesToSearchPercent" : _TEST_LEAF_NODES_TO_SEARCH_PERCENT ,
450
+ }
451
+ }
452
+
453
+ expected = gca_index .Index (
454
+ display_name = _TEST_INDEX_DISPLAY_NAME ,
455
+ metadata = {
456
+ "config" : {
457
+ "algorithmConfig" : config ,
458
+ "dimensions" : _TEST_INDEX_CONFIG_DIMENSIONS ,
459
+ "approximateNeighborsCount" : _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT ,
460
+ "distanceMeasureType" : _TEST_INDEX_DISTANCE_MEASURE_TYPE ,
461
+ },
462
+ },
463
+ description = _TEST_INDEX_DESCRIPTION ,
464
+ labels = _TEST_LABELS ,
465
+ index_update_method = _TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP [
466
+ index_update_method
467
+ ],
468
+ encryption_spec = gca_encryption_spec .EncryptionSpec (
469
+ kms_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME
470
+ ),
471
+ )
472
+
473
+ create_index_mock .assert_called_once_with (
474
+ parent = _TEST_PARENT ,
475
+ index = expected ,
476
+ metadata = _TEST_REQUEST_METADATA ,
477
+ )
478
+
412
479
@pytest .mark .usefixtures ("get_index_mock" )
413
480
def test_create_tree_ah_index_backward_compatibility (self , create_index_mock ):
414
481
aiplatform .init (project = _TEST_PROJECT )
@@ -513,6 +580,64 @@ def test_create_brute_force_index(
513
580
metadata = _TEST_REQUEST_METADATA ,
514
581
)
515
582
583
+ @pytest .mark .usefixtures ("get_index_mock" )
584
+ @pytest .mark .parametrize ("sync" , [True , False ])
585
+ @pytest .mark .parametrize (
586
+ "index_update_method" ,
587
+ [
588
+ _TEST_INDEX_STREAM_UPDATE_METHOD ,
589
+ _TEST_INDEX_BATCH_UPDATE_METHOD ,
590
+ _TEST_INDEX_EMPTY_UPDATE_METHOD ,
591
+ _TEST_INDEX_INVALID_UPDATE_METHOD ,
592
+ ],
593
+ )
594
+ def test_create_brute_force_index_with_empty_index (
595
+ self , create_index_mock , sync , index_update_method
596
+ ):
597
+ aiplatform .init (project = _TEST_PROJECT )
598
+
599
+ my_index = aiplatform .MatchingEngineIndex .create_brute_force_index (
600
+ display_name = _TEST_INDEX_DISPLAY_NAME ,
601
+ dimensions = _TEST_INDEX_CONFIG_DIMENSIONS ,
602
+ distance_measure_type = _TEST_INDEX_DISTANCE_MEASURE_TYPE ,
603
+ description = _TEST_INDEX_DESCRIPTION ,
604
+ labels = _TEST_LABELS ,
605
+ sync = sync ,
606
+ index_update_method = index_update_method ,
607
+ encryption_spec_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
608
+ )
609
+
610
+ if not sync :
611
+ my_index .wait ()
612
+
613
+ config = {"bruteForceConfig" : {}}
614
+
615
+ expected = gca_index .Index (
616
+ display_name = _TEST_INDEX_DISPLAY_NAME ,
617
+ metadata = {
618
+ "config" : {
619
+ "algorithmConfig" : config ,
620
+ "dimensions" : _TEST_INDEX_CONFIG_DIMENSIONS ,
621
+ "approximateNeighborsCount" : None ,
622
+ "distanceMeasureType" : _TEST_INDEX_DISTANCE_MEASURE_TYPE ,
623
+ },
624
+ },
625
+ description = _TEST_INDEX_DESCRIPTION ,
626
+ labels = _TEST_LABELS ,
627
+ index_update_method = _TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP [
628
+ index_update_method
629
+ ],
630
+ encryption_spec = gca_encryption_spec .EncryptionSpec (
631
+ kms_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
632
+ ),
633
+ )
634
+
635
+ create_index_mock .assert_called_once_with (
636
+ parent = _TEST_PARENT ,
637
+ index = expected ,
638
+ metadata = _TEST_REQUEST_METADATA ,
639
+ )
640
+
516
641
@pytest .mark .usefixtures ("get_index_mock" )
517
642
def test_create_brute_force_index_backward_compatibility (self , create_index_mock ):
518
643
aiplatform .init (project = _TEST_PROJECT )
0 commit comments