34
34
index as gca_index ,
35
35
match_service_v1beta1 as gca_match_service_v1beta1 ,
36
36
index_v1beta1 as gca_index_v1beta1 ,
37
+ service_networking as gca_service_networking ,
38
+ encryption_spec as gca_encryption_spec ,
37
39
)
38
40
from google .cloud .aiplatform .compat .services import (
39
41
index_endpoint_service_client ,
236
238
_TEST_APPROX_NUM_NEIGHBORS = 2
237
239
_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE = 0.8
238
240
_TEST_RETURN_FULL_DATAPOINT = True
241
+ _TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
242
+ _TEST_PROJECT_ALLOWLIST = ["project-1" , "project-2" ]
239
243
240
244
241
245
def uuid_mock ():
@@ -619,6 +623,7 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync):
619
623
network = _TEST_INDEX_ENDPOINT_VPC_NETWORK ,
620
624
description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
621
625
labels = _TEST_LABELS ,
626
+ encryption_spec_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
622
627
)
623
628
624
629
if not sync :
@@ -629,6 +634,42 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync):
629
634
network = _TEST_INDEX_ENDPOINT_VPC_NETWORK ,
630
635
description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
631
636
labels = _TEST_LABELS ,
637
+ encryption_spec = gca_encryption_spec .EncryptionSpec (
638
+ kms_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME
639
+ ),
640
+ )
641
+ create_index_endpoint_mock .assert_called_once_with (
642
+ parent = _TEST_PARENT ,
643
+ index_endpoint = expected ,
644
+ metadata = _TEST_REQUEST_METADATA ,
645
+ )
646
+
647
+ @pytest .mark .usefixtures ("get_index_endpoint_mock" )
648
+ def test_create_index_endpoint_with_private_service_connect (
649
+ self , create_index_endpoint_mock
650
+ ):
651
+ aiplatform .init (project = _TEST_PROJECT )
652
+
653
+ aiplatform .MatchingEngineIndexEndpoint .create (
654
+ display_name = _TEST_INDEX_ENDPOINT_DISPLAY_NAME ,
655
+ description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
656
+ labels = _TEST_LABELS ,
657
+ enable_private_service_connect = True ,
658
+ project_allowlist = _TEST_PROJECT_ALLOWLIST ,
659
+ encryption_spec_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
660
+ )
661
+
662
+ expected = gca_index_endpoint .IndexEndpoint (
663
+ display_name = _TEST_INDEX_ENDPOINT_DISPLAY_NAME ,
664
+ description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
665
+ labels = _TEST_LABELS ,
666
+ private_service_connect_config = gca_service_networking .PrivateServiceConnectConfig (
667
+ project_allowlist = _TEST_PROJECT_ALLOWLIST ,
668
+ enable_private_service_connect = True ,
669
+ ),
670
+ encryption_spec = gca_encryption_spec .EncryptionSpec (
671
+ kms_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME
672
+ ),
632
673
)
633
674
create_index_endpoint_mock .assert_called_once_with (
634
675
parent = _TEST_PARENT ,
@@ -644,6 +685,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
644
685
display_name = _TEST_INDEX_ENDPOINT_DISPLAY_NAME ,
645
686
description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
646
687
labels = _TEST_LABELS ,
688
+ encryption_spec_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
647
689
)
648
690
649
691
expected = gca_index_endpoint .IndexEndpoint (
@@ -652,6 +694,9 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
652
694
description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
653
695
labels = _TEST_LABELS ,
654
696
public_endpoint_enabled = False ,
697
+ encryption_spec = gca_encryption_spec .EncryptionSpec (
698
+ kms_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME
699
+ ),
655
700
)
656
701
657
702
create_index_endpoint_mock .assert_called_once_with (
@@ -671,6 +716,7 @@ def test_create_index_endpoint_with_public_endpoint_enabled(
671
716
description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
672
717
public_endpoint_enabled = True ,
673
718
labels = _TEST_LABELS ,
719
+ encryption_spec_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
674
720
)
675
721
676
722
my_index_endpoint = aiplatform .MatchingEngineIndexEndpoint (
@@ -682,6 +728,9 @@ def test_create_index_endpoint_with_public_endpoint_enabled(
682
728
description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
683
729
public_endpoint_enabled = True ,
684
730
labels = _TEST_LABELS ,
731
+ encryption_spec = gca_encryption_spec .EncryptionSpec (
732
+ kms_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME
733
+ ),
685
734
)
686
735
687
736
create_index_endpoint_mock .assert_called_once_with (
@@ -700,7 +749,12 @@ def test_create_index_endpoint_missing_argument_throw_error(
700
749
):
701
750
aiplatform .init (project = _TEST_PROJECT )
702
751
703
- expected_message = "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
752
+ expected_message = (
753
+ "Please provide `network` argument for Private Service Access endpoint,"
754
+ "or provide `enable_private_service_connect` for Private Service"
755
+ "Connect endpoint, or provide `public_endpoint_enabled` to"
756
+ "deploy to a public endpoint"
757
+ )
704
758
705
759
with pytest .raises (ValueError ) as exception :
706
760
_ = aiplatform .MatchingEngineIndexEndpoint .create (
@@ -711,12 +765,12 @@ def test_create_index_endpoint_missing_argument_throw_error(
711
765
712
766
assert str (exception .value ) == expected_message
713
767
714
- def test_create_index_endpoint_set_both_throw_error (
768
+ def test_create_index_endpoint_set_both_psa_and_public_throw_error (
715
769
self , create_index_endpoint_mock
716
770
):
717
771
aiplatform .init (project = _TEST_PROJECT )
718
772
719
- expected_message = "`network` and `public_endpoint_enabled` argument should not be set at the same time "
773
+ expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set. "
720
774
721
775
with pytest .raises (ValueError ) as exception :
722
776
_ = aiplatform .MatchingEngineIndexEndpoint .create (
@@ -729,6 +783,42 @@ def test_create_index_endpoint_set_both_throw_error(
729
783
730
784
assert str (exception .value ) == expected_message
731
785
786
+ def test_create_index_endpoint_set_both_psa_and_psc_throw_error (
787
+ self , create_index_endpoint_mock
788
+ ):
789
+ aiplatform .init (project = _TEST_PROJECT )
790
+
791
+ expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."
792
+
793
+ with pytest .raises (ValueError ) as exception :
794
+ _ = aiplatform .MatchingEngineIndexEndpoint .create (
795
+ display_name = _TEST_INDEX_ENDPOINT_DISPLAY_NAME ,
796
+ description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
797
+ network = _TEST_INDEX_ENDPOINT_VPC_NETWORK ,
798
+ labels = _TEST_LABELS ,
799
+ enable_private_service_connect = True ,
800
+ )
801
+
802
+ assert str (exception .value ) == expected_message
803
+
804
+ def test_create_index_endpoint_set_both_psc_and_public_throw_error (
805
+ self , create_index_endpoint_mock
806
+ ):
807
+ aiplatform .init (project = _TEST_PROJECT )
808
+
809
+ expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."
810
+
811
+ with pytest .raises (ValueError ) as exception :
812
+ _ = aiplatform .MatchingEngineIndexEndpoint .create (
813
+ display_name = _TEST_INDEX_ENDPOINT_DISPLAY_NAME ,
814
+ description = _TEST_INDEX_ENDPOINT_DESCRIPTION ,
815
+ public_endpoint_enabled = True ,
816
+ labels = _TEST_LABELS ,
817
+ enable_private_service_connect = True ,
818
+ )
819
+
820
+ assert str (exception .value ) == expected_message
821
+
732
822
@pytest .mark .usefixtures ("get_index_endpoint_mock" , "get_index_mock" )
733
823
def test_deploy_index (self , deploy_index_mock , undeploy_index_mock ):
734
824
aiplatform .init (project = _TEST_PROJECT )
0 commit comments