Skip to content

Commit 7e6022b

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: add support for create public index endpoint in matching engine
PiperOrigin-RevId: 524917003
1 parent 4d032d5 commit 7e6022b

File tree

3 files changed

+140
-7
lines changed

3 files changed

+140
-7
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
2929
)
3030
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
31-
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2_grpc
31+
from google.cloud.aiplatform.matching_engine._protos import (
32+
match_service_pb2_grpc,
33+
)
3234
from google.protobuf import field_mask_pb2
3335

3436
import grpc
@@ -130,6 +132,7 @@ def create(
130132
cls,
131133
display_name: str,
132134
network: Optional[str] = None,
135+
public_endpoint_enabled: Optional[bool] = False,
133136
description: Optional[str] = None,
134137
labels: Optional[Dict[str, str]] = None,
135138
project: Optional[str] = None,
@@ -163,6 +166,9 @@ def create(
163166
projects/{project}/global/networks/{network}. Where
164167
{project} is a project number, as in '12345', and {network}
165168
is network name.
169+
public_endpoint_enabled (bool):
170+
Optional. If true, the deployed index will be
171+
accessible through public endpoint.
166172
description (str):
167173
Optional. The description of the IndexEndpoint.
168174
labels (Dict[str, str]):
@@ -203,15 +209,20 @@ def create(
203209
"""
204210
network = network or initializer.global_config.network
205211

206-
if not network:
212+
if not network and not public_endpoint_enabled:
207213
raise ValueError(
208-
"Please provide `network` argument or set network"
209-
"using aiplatform.init(network=...)"
214+
"Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
215+
)
216+
217+
if network and public_endpoint_enabled:
218+
raise ValueError(
219+
"`network` and `public_endpoint_enabled` argument should not be set at the same time"
210220
)
211221

212222
return cls._create(
213223
display_name=display_name,
214224
network=network,
225+
public_endpoint_enabled=public_endpoint_enabled,
215226
description=description,
216227
labels=labels,
217228
project=project,
@@ -227,6 +238,7 @@ def _create(
227238
cls,
228239
display_name: str,
229240
network: Optional[str] = None,
241+
public_endpoint_enabled: Optional[bool] = False,
230242
description: Optional[str] = None,
231243
labels: Optional[Dict[str, str]] = None,
232244
project: Optional[str] = None,
@@ -253,6 +265,9 @@ def _create(
253265
projects/{project}/global/networks/{network}. Where
254266
{project} is a project number, as in '12345', and {network}
255267
is network name.
268+
public_endpoint_enabled (bool):
269+
Optional. If true, the deployed index will be
270+
accessible through public endpoint.
256271
description (str):
257272
Optional. The description of the IndexEndpoint.
258273
labels (Dict[str, str]):
@@ -288,9 +303,17 @@ def _create(
288303
Returns:
289304
MatchingEngineIndexEndpoint - IndexEndpoint resource object
290305
"""
291-
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
292-
display_name=display_name, description=description, network=network
293-
)
306+
307+
if public_endpoint_enabled:
308+
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
309+
display_name=display_name,
310+
description=description,
311+
public_endpoint_enabled=public_endpoint_enabled,
312+
)
313+
else:
314+
gapic_index_endpoint = gca_matching_engine_index_endpoint.IndexEndpoint(
315+
display_name=display_name, description=description, network=network
316+
)
294317

295318
if labels:
296319
utils.validate_labels(labels)

tests/system/aiplatform/test_matching_engine_index.py

+49
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,15 @@
5252

5353
# ENDPOINT
5454
_TEST_INDEX_ENDPOINT_DISPLAY_NAME = "endpoint_name"
55+
_TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME = "public_endpoint_name"
5556
_TEST_INDEX_ENDPOINT_DESCRIPTION = "my endpoint"
57+
_TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION = "my public endpoint"
5658

5759
# DEPLOYED INDEX
5860
_TEST_DEPLOYED_INDEX_ID = f"deployed_index_id_{uuid.uuid4()}".replace("-", "_")
5961
_TEST_DEPLOYED_INDEX_DISPLAY_NAME = f"deployed_index_display_name_{uuid.uuid4()}"
62+
_TEST_DEPLOYED_INDEX_ID_PUBLIC = f"deployed_index_id_{uuid.uuid4()}".replace("-", "_")
63+
_TEST_DEPLOYED_INDEX_DISPLAY_NAME_PUBLIC = f"deployed_index_display_name_{uuid.uuid4()}"
6064
_TEST_MIN_REPLICA_COUNT_UPDATED = 4
6165
_TEST_MAX_REPLICA_COUNT_UPDATED = 4
6266

@@ -241,6 +245,27 @@ def test_create_get_list_matching_engine_index(self, shared_state):
241245
assert my_index_endpoint.display_name == _TEST_INDEX_ENDPOINT_DISPLAY_NAME
242246
assert my_index_endpoint.description == _TEST_INDEX_ENDPOINT_DESCRIPTION
243247

248+
# Create endpoint and check that it is listed
249+
public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
250+
display_name=_TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME,
251+
description=_TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION,
252+
public_endpoint_enabled=True,
253+
labels=_TEST_LABELS,
254+
)
255+
assert public_index_endpoint.resource_name in [
256+
index_endpoint.resource_name
257+
for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list()
258+
]
259+
260+
assert public_index_endpoint.labels == _TEST_LABELS
261+
assert (
262+
public_index_endpoint.display_name
263+
== _TEST_PUBLIC_INDEX_ENDPOINT_DISPLAY_NAME
264+
)
265+
assert (
266+
public_index_endpoint.description == _TEST_PUBLIC_INDEX_ENDPOINT_DESCRIPTION
267+
)
268+
244269
shared_state["resources"].append(my_index_endpoint)
245270

246271
# Deploy endpoint
@@ -250,6 +275,15 @@ def test_create_get_list_matching_engine_index(self, shared_state):
250275
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
251276
)
252277

278+
# Deploy public endpoint
279+
public_index_endpoint = public_index_endpoint.deploy_index(
280+
index=index,
281+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID_PUBLIC,
282+
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME_PUBLIC,
283+
min_replica_count=_TEST_MIN_REPLICA_COUNT_UPDATED,
284+
max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED,
285+
)
286+
253287
# Update endpoint
254288
updated_index_endpoint = my_index_endpoint.update(
255289
display_name=_TEST_DISPLAY_NAME_UPDATE,
@@ -268,6 +302,7 @@ def test_create_get_list_matching_engine_index(self, shared_state):
268302
max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED,
269303
)
270304

305+
# deployed index on private endpoint.
271306
deployed_index = my_index_endpoint.deployed_indexes[0]
272307

273308
assert deployed_index.id == _TEST_DEPLOYED_INDEX_ID
@@ -281,6 +316,20 @@ def test_create_get_list_matching_engine_index(self, shared_state):
281316
== _TEST_MAX_REPLICA_COUNT_UPDATED
282317
)
283318

319+
# deployed index on public endpoint.
320+
deployed_index_public = public_index_endpoint.deployed_indexes[0]
321+
322+
assert deployed_index_public.id == _TEST_DEPLOYED_INDEX_ID_PUBLIC
323+
assert deployed_index_public.index == index.resource_name
324+
assert (
325+
deployed_index_public.automatic_resources.min_replica_count
326+
== _TEST_MIN_REPLICA_COUNT_UPDATED
327+
)
328+
assert (
329+
deployed_index_public.automatic_resources.max_replica_count
330+
== _TEST_MAX_REPLICA_COUNT_UPDATED
331+
)
332+
284333
# TODO: Test `my_index_endpoint.match` request. This requires running this test in a VPC.
285334
# results = my_index_endpoint.match(
286335
# deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=[_TEST_MATCH_QUERY]

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

+61
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
547547
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
548548
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
549549
labels=_TEST_LABELS,
550+
public_endpoint_enabled=False,
550551
)
551552

552553
create_index_endpoint_mock.assert_called_once_with(
@@ -555,6 +556,66 @@ def test_create_index_endpoint_with_network_init(self, create_index_endpoint_moc
555556
metadata=_TEST_REQUEST_METADATA,
556557
)
557558

559+
@pytest.mark.usefixtures("get_index_endpoint_mock")
560+
def test_create_index_endpoint_with_public_endpoint_enabled(
561+
self, create_index_endpoint_mock
562+
):
563+
aiplatform.init(project=_TEST_PROJECT)
564+
565+
aiplatform.MatchingEngineIndexEndpoint.create(
566+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
567+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
568+
public_endpoint_enabled=True,
569+
labels=_TEST_LABELS,
570+
)
571+
572+
expected = gca_index_endpoint.IndexEndpoint(
573+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
574+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
575+
public_endpoint_enabled=True,
576+
labels=_TEST_LABELS,
577+
)
578+
579+
create_index_endpoint_mock.assert_called_once_with(
580+
parent=_TEST_PARENT,
581+
index_endpoint=expected,
582+
metadata=_TEST_REQUEST_METADATA,
583+
)
584+
585+
def test_create_index_endpoint_missing_argument_throw_error(
586+
self, create_index_endpoint_mock
587+
):
588+
aiplatform.init(project=_TEST_PROJECT)
589+
590+
expected_message = "Please provide `network` argument for private endpoint or provide `public_endpoint_enabled` to deploy this index to a public endpoint"
591+
592+
with pytest.raises(ValueError) as exception:
593+
_ = aiplatform.MatchingEngineIndexEndpoint.create(
594+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
595+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
596+
labels=_TEST_LABELS,
597+
)
598+
599+
assert str(exception.value) == expected_message
600+
601+
def test_create_index_endpoint_set_both_throw_error(
602+
self, create_index_endpoint_mock
603+
):
604+
aiplatform.init(project=_TEST_PROJECT)
605+
606+
expected_message = "`network` and `public_endpoint_enabled` argument should not be set at the same time"
607+
608+
with pytest.raises(ValueError) as exception:
609+
_ = aiplatform.MatchingEngineIndexEndpoint.create(
610+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
611+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
612+
public_endpoint_enabled=True,
613+
network=_TEST_INDEX_ENDPOINT_VPC_NETWORK,
614+
labels=_TEST_LABELS,
615+
)
616+
617+
assert str(exception.value) == expected_message
618+
558619
@pytest.mark.usefixtures("get_index_endpoint_mock", "get_index_mock")
559620
def test_deploy_index(self, deploy_index_mock, undeploy_index_mock):
560621
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments
 (0)