Skip to content

Commit 469db6b

Browse files
authored
feat: Make matching engine API public (#1192)
* Reinstated matching engine * Reinstated VPC-dependent system tests * Debug * Commented out test * Fix resource bug when the key doesn't exist * Skip tests for now * Tweaked logs * Raise * Added VPC preparation and deletion * Fixed firewall deletion code * Reinstate system tests * Removed VPC network generation code * Updated system tests * Fixed matching engine system tests * Reverted skips * Ran linter * Moved protos to their own folder * Removed constant for VPC network * Added protos * Fixed import * Fixed network constsant * Fixed reversion issue * Removed unused imports
1 parent 5949674 commit 469db6b

11 files changed

+97
-20
lines changed

google/cloud/aiplatform/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
Feature,
3939
Featurestore,
4040
)
41+
from google.cloud.aiplatform.matching_engine import (
42+
MatchingEngineIndex,
43+
MatchingEngineIndexEndpoint,
44+
)
4145
from google.cloud.aiplatform.metadata import metadata
4246
from google.cloud.aiplatform.models import Endpoint
4347
from google.cloud.aiplatform.models import Model
@@ -105,6 +109,8 @@
105109
"EntityType",
106110
"Feature",
107111
"Featurestore",
112+
"MatchingEngineIndex",
113+
"MatchingEngineIndexEndpoint",
108114
"ImageDataset",
109115
"HyperparameterTuningJob",
110116
"Model",

google/cloud/aiplatform/_matching_engine/__init__.py google/cloud/aiplatform/matching_engine/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515
# limitations under the License.
1616
#
1717

18-
from google.cloud.aiplatform._matching_engine.matching_engine_index import (
18+
from google.cloud.aiplatform.matching_engine.matching_engine_index import (
1919
MatchingEngineIndex,
2020
)
21-
from google.cloud.aiplatform._matching_engine.matching_engine_index_config import (
21+
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
2222
BruteForceConfig as MatchingEngineBruteForceAlgorithmConfig,
2323
MatchingEngineIndexConfig as MatchingEngineIndexConfig,
2424
TreeAhConfig as MatchingEngineTreeAhAlgorithmConfig,
2525
)
26-
from google.cloud.aiplatform._matching_engine.matching_engine_index_endpoint import (
26+
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
2727
MatchingEngineIndexEndpoint,
2828
)
2929

google/cloud/aiplatform/_matching_engine/match_service_pb2_grpc.py google/cloud/aiplatform/matching_engine/_protos/match_service_pb2_grpc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
1818
"""Client and server classes corresponding to protobuf-defined services."""
19-
from google.cloud.aiplatform._matching_engine import match_service_pb2
19+
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
2020

2121
import grpc
2222

google/cloud/aiplatform/_matching_engine/matching_engine_index.py google/cloud/aiplatform/matching_engine/matching_engine_index.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
matching_engine_index as gca_matching_engine_index,
2626
)
2727
from google.cloud.aiplatform import initializer
28-
from google.cloud.aiplatform._matching_engine import matching_engine_index_config
28+
from google.cloud.aiplatform.matching_engine import matching_engine_index_config
2929
from google.cloud.aiplatform import utils
3030

3131
_LOGGER = base.Logger(__name__)

google/cloud/aiplatform/_matching_engine/matching_engine_index_endpoint.py google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
from google.auth import credentials as auth_credentials
2222
from google.cloud.aiplatform import base
2323
from google.cloud.aiplatform import initializer
24-
from google.cloud.aiplatform import _matching_engine
24+
from google.cloud.aiplatform import matching_engine
2525
from google.cloud.aiplatform import utils
2626
from google.cloud.aiplatform.compat.types import (
2727
machine_resources as gca_machine_resources_compat,
2828
matching_engine_index_endpoint as gca_matching_engine_index_endpoint,
2929
)
30-
from google.cloud.aiplatform._matching_engine import match_service_pb2
31-
from google.cloud.aiplatform._matching_engine import match_service_pb2_grpc
30+
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
31+
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2_grpc
3232
from google.protobuf import field_mask_pb2
3333

3434
import grpc
@@ -432,7 +432,7 @@ def _build_deployed_index(
432432

433433
def deploy_index(
434434
self,
435-
index: _matching_engine.MatchingEngineIndex,
435+
index: matching_engine.MatchingEngineIndex,
436436
deployed_index_id: str,
437437
display_name: Optional[str] = None,
438438
machine_type: Optional[str] = None,

tests/system/aiplatform/e2e_base.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import pytest
2222
import uuid
23+
2324
from typing import Any, Dict, Generator
2425

2526
from google.api_core import exceptions
@@ -29,8 +30,7 @@
2930
from google.cloud.aiplatform import initializer
3031

3132
_PROJECT = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
32-
_PROJECT_NUMBER = os.getenv("PROJECT_NUMBER")
33-
_VPC_NETWORK_NAME = os.getenv("private-net")
33+
_VPC_NETWORK_URI = os.getenv("_VPC_NETWORK_URI")
3434
_LOCATION = "us-central1"
3535

3636

@@ -136,7 +136,10 @@ def tear_down_resources(self, shared_state: Dict[str, Any]):
136136
# Bring all Endpoints to the front of the list
137137
# Ensures Models are undeployed first before we attempt deletion
138138
shared_state["resources"].sort(
139-
key=lambda r: 1 if isinstance(r, aiplatform.Endpoint) else 2
139+
key=lambda r: 1
140+
if isinstance(r, aiplatform.Endpoint)
141+
or isinstance(r, aiplatform.MatchingEngineIndexEndpoint)
142+
else 2
140143
)
141144

142145
for resource in shared_state["resources"]:
@@ -146,6 +149,7 @@ def tear_down_resources(self, shared_state: Dict[str, Any]):
146149
(
147150
aiplatform.Endpoint,
148151
aiplatform.Featurestore,
152+
aiplatform.MatchingEngineIndexEndpoint,
149153
),
150154
):
151155
# For endpoint, undeploy model then delete endpoint

tests/system/aiplatform/test_matching_engine_index.py

+75-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#
1717

1818
import uuid
19-
import pytest
2019

2120
from google.cloud import aiplatform
2221

@@ -52,10 +51,6 @@
5251
_TEST_INDEX_ENDPOINT_DISPLAY_NAME = "endpoint_name"
5352
_TEST_INDEX_ENDPOINT_DESCRIPTION = "my endpoint"
5453

55-
_TEST_INDEX_ENDPOINT_VPC_NETWORK = "projects/{}/global/networks/{}".format(
56-
e2e_base._PROJECT_NUMBER, e2e_base._VPC_NETWORK_NAME
57-
)
58-
5954
# DEPLOYED INDEX
6055
_TEST_DEPLOYED_INDEX_ID = f"deployed_index_id_{uuid.uuid4()}"
6156
_TEST_DEPLOYED_INDEX_DISPLAY_NAME = f"deployed_index_display_name_{uuid.uuid4()}"
@@ -167,7 +162,6 @@
167162
]
168163

169164

170-
@pytest.mark.skip(reason="TestMatchingEngine not available")
171165
class TestMatchingEngine(e2e_base.TestEndToEnd):
172166

173167
_temp_prefix = "temp_vertex_sdk_e2e_matching_engine_test"
@@ -226,9 +220,84 @@ def test_create_get_list_matching_engine_index(self, shared_state):
226220

227221
assert updated_index.name == get_index.name
228222

223+
# Create endpoint and check that it is listed
224+
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
225+
display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME,
226+
description=_TEST_INDEX_ENDPOINT_DESCRIPTION,
227+
network=e2e_base._VPC_NETWORK_URI,
228+
labels=_TEST_LABELS,
229+
)
230+
assert my_index_endpoint.resource_name in [
231+
index_endpoint.resource_name
232+
for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list()
233+
]
234+
235+
assert my_index_endpoint.labels == _TEST_LABELS
236+
assert my_index_endpoint.display_name == _TEST_INDEX_ENDPOINT_DISPLAY_NAME
237+
assert my_index_endpoint.description == _TEST_INDEX_ENDPOINT_DESCRIPTION
238+
239+
shared_state["resources"].append(my_index_endpoint)
240+
241+
# Deploy endpoint
242+
my_index_endpoint = my_index_endpoint.deploy_index(
243+
index=index,
244+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
245+
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME,
246+
)
247+
248+
# Update endpoint
249+
updated_index_endpoint = my_index_endpoint.update(
250+
display_name=_TEST_DISPLAY_NAME_UPDATE,
251+
description=_TEST_DESCRIPTION_UPDATE,
252+
labels=_TEST_LABELS_UPDATE,
253+
)
254+
255+
assert updated_index_endpoint.labels == _TEST_LABELS_UPDATE
256+
assert updated_index_endpoint.display_name == _TEST_DISPLAY_NAME_UPDATE
257+
assert updated_index_endpoint.description == _TEST_DESCRIPTION_UPDATE
258+
259+
# Mutate deployed index
260+
my_index_endpoint.mutate_deployed_index(
261+
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
262+
min_replica_count=_TEST_MIN_REPLICA_COUNT_UPDATED,
263+
max_replica_count=_TEST_MAX_REPLICA_COUNT_UPDATED,
264+
)
265+
266+
deployed_index = my_index_endpoint.deployed_indexes[0]
267+
268+
assert deployed_index.id == _TEST_DEPLOYED_INDEX_ID
269+
assert deployed_index.index == index.resource_name
270+
assert (
271+
deployed_index.automatic_resources.min_replica_count
272+
== _TEST_MIN_REPLICA_COUNT_UPDATED
273+
)
274+
assert (
275+
deployed_index.automatic_resources.max_replica_count
276+
== _TEST_MAX_REPLICA_COUNT_UPDATED
277+
)
278+
279+
# TODO: Test `my_index_endpoint.match` request. This requires running this test in a VPC.
280+
# results = my_index_endpoint.match(
281+
# deployed_index_id=_TEST_DEPLOYED_INDEX_ID, queries=[_TEST_MATCH_QUERY]
282+
# )
283+
284+
# assert results[0][0].id == 870
285+
286+
# Undeploy index
287+
my_index_endpoint = my_index_endpoint.undeploy_index(
288+
deployed_index_id=deployed_index.id
289+
)
290+
229291
# Delete index and check that it is no longer listed
230292
index.delete()
231293
list_indexes = aiplatform.MatchingEngineIndex.list()
232294
assert get_index.resource_name not in [
233295
index.resource_name for index in list_indexes
234296
]
297+
298+
# Delete index endpoint and check that it is no longer listed
299+
my_index_endpoint.delete()
300+
assert my_index_endpoint.resource_name not in [
301+
index_endpoint.resource_name
302+
for index_endpoint in aiplatform.MatchingEngineIndexEndpoint.list()
303+
]

tests/unit/aiplatform/test_matching_engine_index.py

-1
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ def create_index_mock():
167167
yield create_index_mock
168168

169169

170-
@pytest.mark.skip(reason="MatchingEngineIndex not available")
171170
class TestMatchingEngineIndex:
172171
def setup_method(self):
173172
reload(initializer)

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

-1
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,6 @@ def create_index_endpoint_mock():
383383
yield create_index_endpoint_mock
384384

385385

386-
@pytest.mark.skip(reason="MatchingEngineIndexEndpoint not available")
387386
class TestMatchingEngineIndexEndpoint:
388387
def setup_method(self):
389388
reload(initializer)

0 commit comments

Comments
 (0)