Skip to content

Commit 750e17b

Browse files
lingyinwcopybara-github
authored andcommitted
feat: add encryption_spec_key_name, enable_private_service_connect,project_allowlist to MatchingEngineIndexEndpoint create.
PiperOrigin-RevId: 581328160
1 parent fcf05cb commit 750e17b

File tree

4 files changed

+177
-9
lines changed

4 files changed

+177
-9
lines changed

google/cloud/aiplatform/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
types.model_garden_service = types.model_garden_service_v1beta1
112112
types.model_monitoring = types.model_monitoring_v1beta1
113113
types.model_service = types.model_service_v1beta1
114+
types.service_networking = types.service_networking_v1beta1
114115
types.operation = types.operation_v1beta1
115116
types.pipeline_failure_policy = types.pipeline_failure_policy_v1beta1
116117
types.pipeline_job = types.pipeline_job_v1beta1
@@ -208,6 +209,7 @@
208209
types.model_deployment_monitoring_job = types.model_deployment_monitoring_job_v1
209210
types.model_monitoring = types.model_monitoring_v1
210211
types.model_service = types.model_service_v1
212+
types.service_networking = types.service_networking_v1
211213
types.operation = types.operation_v1
212214
types.pipeline_failure_policy = types.pipeline_failure_policy_v1
213215
types.pipeline_job = types.pipeline_job_v1

google/cloud/aiplatform/compat/types/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
pipeline_state as pipeline_state_v1beta1,
7676
prediction_service as prediction_service_v1beta1,
7777
publisher_model as publisher_model_v1beta1,
78+
service_networking as service_networking_v1beta1,
7879
schedule as schedule_v1beta1,
7980
schedule_service as schedule_service_v1beta1,
8081
specialist_pool as specialist_pool_v1beta1,
@@ -147,6 +148,7 @@
147148
publisher_model as publisher_model_v1,
148149
schedule as schedule_v1,
149150
schedule_service as schedule_service_v1,
151+
service_networking as service_networking_v1,
150152
specialist_pool as specialist_pool_v1,
151153
specialist_pool_service as specialist_pool_service_v1,
152154
study as study_v1,

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+80-6
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
2929
match_service_v1beta1 as gca_match_service_v1beta1,
3030
index_v1beta1 as gca_index_v1beta1,
31+
service_networking as gca_service_networking,
32+
encryption_spec as gca_encryption_spec,
3133
)
3234
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
3335
from google.cloud.aiplatform.matching_engine._protos import (
@@ -145,6 +147,9 @@ def create(
145147
credentials: Optional[auth_credentials.Credentials] = None,
146148
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
147149
sync: bool = True,
150+
enable_private_service_connect: Optional[bool] = False,
151+
project_allowlist: Optional[Sequence[str]] = None,
152+
encryption_spec_key_name: Optional[str] = None,
148153
) -> "MatchingEngineIndexEndpoint":
149154
"""Creates a MatchingEngineIndexEndpoint resource.
150155
@@ -205,6 +210,23 @@ def create(
205210
Optional. Whether to execute this creation synchronously. If False, this method
206211
will be executed in concurrent Future and any downstream object will
207212
be immediately returned and synced when the Future has completed.
213+
enable_private_service_connect (bool):
214+
If true, expose the index endpoint via private service connect.
215+
project_allowlist (Sequence[str]):
216+
Optional. List of projects from which the forwarding rule will
217+
target the service attachment.
218+
encryption_spec_key_name (str):
219+
Optional. The Cloud KMS resource identifier of the customer
220+
managed encryption key used to protect the index endpoint.
221+
Has the form:
222+
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
223+
The key needs to be in the same region as where the compute
224+
resource is created.
225+
226+
If set, this index endpoint and all sub-resources of this
227+
index endpoint will be secured by this key.
228+
The key needs to be in the same region as where the index
229+
endpoint is created.
208230
209231
Returns:
210232
MatchingEngineIndexEndpoint - IndexEndpoint resource object
@@ -214,14 +236,27 @@ def create(
214236
"""
215237
network = network or initializer.global_config.network
216238

217-
if not network and not public_endpoint_enabled:
239+
if not (network or public_endpoint_enabled or enable_private_service_connect):
218240
raise ValueError(
219-
"Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
241+
"Please provide `network` argument for Private Service Access endpoint,"
242+
"or provide `enable_private_service_connect` for Private Service"
243+
"Connect endpoint, or provide `public_endpoint_enabled` to"
244+
"deploy to a public endpoint"
220245
)
221246

222-
if network and public_endpoint_enabled:
247+
if (
248+
sum(
249+
bool(network_setting)
250+
for network_setting in [
251+
network,
252+
public_endpoint_enabled,
253+
enable_private_service_connect,
254+
]
255+
)
256+
> 1
257+
):
223258
raise ValueError(
224-
"`network` and `public_endpoint_enabled` argument should not be set at the same time"
259+
"One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."
225260
)
226261

227262
return cls._create(
@@ -235,6 +270,9 @@ def create(
235270
credentials=credentials,
236271
request_metadata=request_metadata,
237272
sync=sync,
273+
enable_private_service_connect=enable_private_service_connect,
274+
project_allowlist=project_allowlist,
275+
encryption_spec_key_name=encryption_spec_key_name,
238276
)
239277

240278
@classmethod
@@ -251,6 +289,9 @@ def _create(
251289
credentials: Optional[auth_credentials.Credentials] = None,
252290
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
253291
sync: bool = True,
292+
enable_private_service_connect: Optional[bool] = False,
293+
project_allowlist: Optional[Sequence[str]] = None,
294+
encryption_spec_key_name: Optional[str] = None,
254295
) -> "MatchingEngineIndexEndpoint":
255296
"""Helper method to ensure network synchronization and to
256297
create a MatchingEngineIndexEndpoint resource.
@@ -304,20 +345,53 @@ def _create(
304345
Optional. Whether to execute this creation synchronously. If False, this method
305346
will be executed in concurrent Future and any downstream object will
306347
be immediately returned and synced when the Future has completed.
348+
encryption_spec_key_name (str):
349+
Immutable. Customer-managed encryption key
350+
spec for an IndexEndpoint. If set, this
351+
IndexEndpoint and all sub-resources of this
352+
IndexEndpoint will be secured by this key.
353+
enable_private_service_connect (bool):
354+
Required. If true, expose the IndexEndpoint
355+
via private service connect.
356+
project_allowlist (MutableSequence[str]):
357+
A list of Projects from which the forwarding
358+
rule will target the service attachment.
307359
308360
Returns:
309361
MatchingEngineIndexEndpoint - IndexEndpoint resource object
310362
"""
311-
363+
# Public
312364
if public_endpoint_enabled:
313365
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
314366
display_name=display_name,
315367
description=description,
316368
public_endpoint_enabled=public_endpoint_enabled,
369+
encryption_spec=gca_encryption_spec.EncryptionSpec(
370+
kms_key_name=encryption_spec_key_name
371+
),
372+
)
373+
# PSA
374+
elif network:
375+
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
376+
display_name=display_name,
377+
description=description,
378+
network=network,
379+
encryption_spec=gca_encryption_spec.EncryptionSpec(
380+
kms_key_name=encryption_spec_key_name
381+
),
317382
)
383+
# PSC
318384
else:
319385
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
320-
display_name=display_name, description=description, network=network
386+
display_name=display_name,
387+
description=description,
388+
private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig(
389+
project_allowlist=project_allowlist,
390+
enable_private_service_connect=enable_private_service_connect,
391+
),
392+
encryption_spec=gca_encryption_spec.EncryptionSpec(
393+
kms_key_name=encryption_spec_key_name
394+
),
321395
)
322396

323397
if labels:

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+93-3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
index as gca_index,
3535
match_service_v1beta1 as gca_match_service_v1beta1,
3636
index_v1beta1 as gca_index_v1beta1,
37+
service_networking as gca_service_networking,
38+
encryption_spec as gca_encryption_spec,
3739
)
3840
from google.cloud.aiplatform.compat.services import (
3941
index_endpoint_service_client,
@@ -236,6 +238,8 @@
236238
_TEST_APPROX_NUM_NEIGHBORS = 2
237239
_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE = 0.8
238240
_TEST_RETURN_FULL_DATAPOINT = True
241+
_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
242+
_TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"]
239243

240244

241245
def uuid_mock():
@@ -619,6 +623,7 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync):
619623
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
620624
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
621625
labels=_TEST_LABELS,
626+
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
622627
)
623628

624629
if not sync:
@@ -629,6 +634,42 @@ def test_create_index_endpoint(self, create_index_endpoint_mock, sync):
629634
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
630635
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
631636
labels=_TEST_LABELS,
637+
encryption_spec=gca_encryption_spec.EncryptionSpec(
638+
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
639+
),
640+
)
641+
create_index_endpoint_mock.assert_called_once_with(
642+
parent=_TEST_PARENT,
643+
index_endpoint=expected,
644+
metadata=_TEST_REQUEST_METADATA,
645+
)
646+
647+
@pytest.mark.usefixtures("get_index_endpoint_mock")
648+
def test_create_index_endpoint_with_private_service_connect(
649+
self, create_index_endpoint_mock
650+
):
651+
aiplatform.init(project=_TEST_PROJECT)
652+
653+
aiplatform.MatchingEngineIndexEndpoint.create(
654+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
655+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
656+
labels=_TEST_LABELS,
657+
enable_private_service_connect=True,
658+
project_allowlist=_TEST_PROJECT_ALLOWLIST,
659+
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
660+
)
661+
662+
expected = gca_index_endpoint.IndexEndpoint(
663+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
664+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
665+
labels=_TEST_LABELS,
666+
private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig(
667+
project_allowlist=_TEST_PROJECT_ALLOWLIST,
668+
enable_private_service_connect=True,
669+
),
670+
encryption_spec=gca_encryption_spec.EncryptionSpec(
671+
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
672+
),
632673
)
633674
create_index_endpoint_mock.assert_called_once_with(
634675
parent=_TEST_PARENT,
@@ -644,6 +685,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
644685
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
645686
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
646687
labels=_TEST_LABELS,
688+
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
647689
)
648690

649691
expected = gca_index_endpoint.IndexEndpoint(
@@ -652,6 +694,9 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
652694
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
653695
labels=_TEST_LABELS,
654696
public_endpoint_enabled=False,
697+
encryption_spec=gca_encryption_spec.EncryptionSpec(
698+
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
699+
),
655700
)
656701

657702
create_index_endpoint_mock.assert_called_once_with(
@@ -671,6 +716,7 @@ def test_create_index_endpoint_with_public_endpoint_enabled(
671716
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
672717
public_endpoint_enabled=True,
673718
labels=_TEST_LABELS,
719+
encryption_spec_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME,
674720
)
675721

676722
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
@@ -682,6 +728,9 @@ def test_create_index_endpoint_with_public_endpoint_enabled(
682728
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
683729
public_endpoint_enabled=True,
684730
labels=_TEST_LABELS,
731+
encryption_spec=gca_encryption_spec.EncryptionSpec(
732+
kms_key_name=_TEST_ENCRYPTION_SPEC_KEY_NAME
733+
),
685734
)
686735

687736
create_index_endpoint_mock.assert_called_once_with(
@@ -700,7 +749,12 @@ def test_create_index_endpoint_missing_argument_throw_error(
700749
):
701750
aiplatform.init(project=_TEST_PROJECT)
702751

703-
expected_message = "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
752+
expected_message = (
753+
"Please provide `network` argument for Private Service Access endpoint,"
754+
"or provide `enable_private_service_connect` for Private Service"
755+
"Connect endpoint, or provide `public_endpoint_enabled` to"
756+
"deploy to a public endpoint"
757+
)
704758

705759
with pytest.raises(ValueError) as exception:
706760
_ = aiplatform.MatchingEngineIndexEndpoint.create(
@@ -711,12 +765,12 @@ def test_create_index_endpoint_missing_argument_throw_error(
711765

712766
assert str(exception.value) == expected_message
713767

714-
def test_create_index_endpoint_set_both_throw_error(
768+
def test_create_index_endpoint_set_both_psa_and_public_throw_error(
715769
self, create_index_endpoint_mock
716770
):
717771
aiplatform.init(project=_TEST_PROJECT)
718772

719-
expected_message = "`network` and `public_endpoint_enabled` argument should not be set at the same time"
773+
expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."
720774

721775
with pytest.raises(ValueError) as exception:
722776
_ = aiplatform.MatchingEngineIndexEndpoint.create(
@@ -729,6 +783,42 @@ def test_create_index_endpoint_set_both_throw_error(
729783

730784
assert str(exception.value) == expected_message
731785

786+
def test_create_index_endpoint_set_both_psa_and_psc_throw_error(
787+
self, create_index_endpoint_mock
788+
):
789+
aiplatform.init(project=_TEST_PROJECT)
790+
791+
expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."
792+
793+
with pytest.raises(ValueError) as exception:
794+
_ = aiplatform.MatchingEngineIndexEndpoint.create(
795+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
796+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
797+
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
798+
labels=_TEST_LABELS,
799+
enable_private_service_connect=True,
800+
)
801+
802+
assert str(exception.value) == expected_message
803+
804+
def test_create_index_endpoint_set_both_psc_and_public_throw_error(
805+
self, create_index_endpoint_mock
806+
):
807+
aiplatform.init(project=_TEST_PROJECT)
808+
809+
expected_message = "One and only one among network, public_endpoint_enabled and enable_private_service_connect should be set."
810+
811+
with pytest.raises(ValueError) as exception:
812+
_ = aiplatform.MatchingEngineIndexEndpoint.create(
813+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
814+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
815+
public_endpoint_enabled=True,
816+
labels=_TEST_LABELS,
817+
enable_private_service_connect=True,
818+
)
819+
820+
assert str(exception.value) == expected_message
821+
732822
@pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock")
733823
def test_deploy_index(self, deploy_index_mock, undeploy_index_mock):
734824
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments
 (0)