Skip to content

Commit 0ecfe1e

Browse files
morganduSam Goodman
and
Sam Goodman
authored
feat: add update endpoint (#1162)
* feat: add update endpoint * add validate_traffic and validate_traffic_split * remove validation, add system tests * Text fixes * Nox blacken change Co-authored-by: Sam Goodman <[email protected]>
1 parent b4a0bee commit 0ecfe1e

File tree

3 files changed

+211
-45
lines changed

3 files changed

+211
-45
lines changed

google/cloud/aiplatform/models.py

+109-5
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from google.protobuf import field_mask_pb2, json_format
5252

5353
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
54+
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
5455

5556
_LOGGER = base.Logger(__name__)
5657

@@ -485,7 +486,7 @@ def _allocate_traffic(
485486
new_traffic_split[deployed_model] += 1
486487
unallocated_traffic -= 1
487488

488-
new_traffic_split["0"] = traffic_percentage
489+
new_traffic_split[_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY] = traffic_percentage
489490

490491
return new_traffic_split
491492

@@ -611,7 +612,6 @@ def _validate_deploy_args(
611612
raise ValueError("Traffic percentage cannot be negative.")
612613

613614
elif traffic_split:
614-
# TODO(b/172678233) verify every referenced deployed model exists
615615
if sum(traffic_split.values()) != 100:
616616
raise ValueError(
617617
"Sum of all traffic within traffic split needs to be 100."
@@ -1290,6 +1290,110 @@ def _instantiate_prediction_client(
12901290
prediction_client=True,
12911291
)
12921292

1293+
def update(
1294+
self,
1295+
display_name: Optional[str] = None,
1296+
description: Optional[str] = None,
1297+
labels: Optional[Dict[str, str]] = None,
1298+
traffic_split: Optional[Dict[str, int]] = None,
1299+
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
1300+
update_request_timeout: Optional[float] = None,
1301+
) -> "Endpoint":
1302+
"""Updates an endpoint.
1303+
1304+
Example usage:
1305+
1306+
my_endpoint = my_endpoint.update(
1307+
display_name='my-updated-endpoint',
1308+
description='my updated description',
1309+
labels={'key': 'value'},
1310+
traffic_split={
1311+
'123456': 20,
1312+
'234567': 80,
1313+
},
1314+
)
1315+
1316+
Args:
1317+
display_name (str):
1318+
Optional. The display name of the Endpoint.
1319+
The name can be up to 128 characters long and can be consist of any UTF-8
1320+
characters.
1321+
description (str):
1322+
Optional. The description of the Endpoint.
1323+
labels (Dict[str, str]):
1324+
Optional. The labels with user-defined metadata to organize your Endpoints.
1325+
Label keys and values can be no longer than 64 characters
1326+
(Unicode codepoints), can only contain lowercase letters, numeric
1327+
characters, underscores and dashes. International characters are allowed.
1328+
See https://goo.gl/xmQnxf for more information and examples of labels.
1329+
traffic_split (Dict[str, int]):
1330+
Optional. A map from a DeployedModel's ID to the percentage of this Endpoint's
1331+
traffic that should be forwarded to that DeployedModel.
1332+
If a DeployedModel's ID is not listed in this map, then it receives no traffic.
1333+
The traffic percentage values must add up to 100, or map must be empty if
1334+
the Endpoint is to not accept any traffic at a moment.
1335+
request_metadata (Sequence[Tuple[str, str]]):
1336+
Optional. Strings which should be sent along with the request as metadata.
1337+
update_request_timeout (float):
1338+
Optional. The timeout for the update request in seconds.
1339+
1340+
Returns:
1341+
Endpoint - Updated endpoint resource.
1342+
1343+
Raises:
1344+
ValueError: If `labels` is not the correct format.
1345+
"""
1346+
1347+
self.wait()
1348+
1349+
current_endpoint_proto = self.gca_resource
1350+
copied_endpoint_proto = current_endpoint_proto.__class__(current_endpoint_proto)
1351+
1352+
update_mask: List[str] = []
1353+
1354+
if display_name:
1355+
utils.validate_display_name(display_name)
1356+
copied_endpoint_proto.display_name = display_name
1357+
update_mask.append("display_name")
1358+
1359+
if description:
1360+
copied_endpoint_proto.description = description
1361+
update_mask.append("description")
1362+
1363+
if labels:
1364+
utils.validate_labels(labels)
1365+
copied_endpoint_proto.labels = labels
1366+
update_mask.append("labels")
1367+
1368+
if traffic_split:
1369+
update_mask.append("traffic_split")
1370+
copied_endpoint_proto.traffic_split = traffic_split
1371+
1372+
update_mask = field_mask_pb2.FieldMask(paths=update_mask)
1373+
1374+
_LOGGER.log_action_start_against_resource(
1375+
"Updating",
1376+
"endpoint",
1377+
self,
1378+
)
1379+
1380+
update_endpoint_lro = self.api_client.update_endpoint(
1381+
endpoint=copied_endpoint_proto,
1382+
update_mask=update_mask,
1383+
metadata=request_metadata,
1384+
timeout=update_request_timeout,
1385+
)
1386+
1387+
_LOGGER.log_action_started_against_resource_with_lro(
1388+
"Update", "endpoint", self.__class__, update_endpoint_lro
1389+
)
1390+
1391+
update_endpoint_lro.result()
1392+
1393+
_LOGGER.log_action_completed_against_resource("endpoint", "updated", self)
1394+
1395+
return self
1396+
12931397
def predict(
12941398
self,
12951399
instances: List,
@@ -1445,15 +1549,15 @@ def list(
14451549
credentials=credentials,
14461550
)
14471551

1448-
def list_models(self) -> Sequence[gca_endpoint_compat.DeployedModel]:
1552+
def list_models(self) -> List[gca_endpoint_compat.DeployedModel]:
14491553
"""Returns a list of the models deployed to this Endpoint.
14501554
14511555
Returns:
1452-
deployed_models (Sequence[aiplatform.gapic.DeployedModel]):
1556+
deployed_models (List[aiplatform.gapic.DeployedModel]):
14531557
A list of the models deployed in this Endpoint.
14541558
"""
14551559
self._sync_gca_resource()
1456-
return self._gca_resource.deployed_models
1560+
return list(self._gca_resource.deployed_models)
14571561

14581562
def undeploy_all(self, sync: bool = True) -> "Endpoint":
14591563
"""Undeploys every model deployed to this Endpoint.

tests/system/aiplatform/test_model_upload.py

+9
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,12 @@ def test_upload_and_deploy_xgboost_model(self, shared_state):
7676
assert model.display_name == "new_name"
7777
assert model.description == "new_description"
7878
assert model.labels == {"my_label": "updated"}
79+
80+
assert len(endpoint.list_models) == 1
81+
endpoint.deploy(model, traffic_percentage=100)
82+
assert len(endpoint.list_models) == 2
83+
traffic_split = {
84+
deployed_model.id: 50 for deployed_model in endpoint.list_models()
85+
}
86+
endpoint.update(traffic_split=traffic_split)
87+
assert endpoint.traffic_split == traffic_split

tests/unit/aiplatform/test_endpoints.py

+93-40
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from google.api_core import operation as ga_operation
2626
from google.auth import credentials as auth_credentials
2727

28+
from google.protobuf import field_mask_pb2
29+
2830
from google.cloud import aiplatform
2931
from google.cloud.aiplatform import base
3032
from google.cloud.aiplatform import initializer
@@ -58,6 +60,8 @@
5860
_TEST_ID_2 = "4366591682456584192"
5961
_TEST_ID_3 = "5820582938582924817"
6062
_TEST_DESCRIPTION = "test-description"
63+
_TEST_REQUEST_METADATA = ()
64+
_TEST_TIMEOUT = None
6165

6266
_TEST_ENDPOINT_NAME = (
6367
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}"
@@ -270,6 +274,16 @@ def create_endpoint_mock():
270274
yield create_endpoint_mock
271275

272276

277+
@pytest.fixture
278+
def update_endpoint_mock():
279+
with mock.patch.object(
280+
endpoint_service_client.EndpointServiceClient, "update_endpoint"
281+
) as update_endpoint_mock:
282+
update_endpoint_lro_mock = mock.Mock(ga_operation.Operation)
283+
update_endpoint_mock.return_value = update_endpoint_lro_mock
284+
yield update_endpoint_mock
285+
286+
273287
@pytest.fixture
274288
def deploy_model_mock():
275289
with mock.patch.object(
@@ -726,6 +740,54 @@ def test_create_with_labels(self, create_endpoint_mock, sync):
726740
timeout=None,
727741
)
728742

743+
@pytest.mark.usefixtures("get_endpoint_mock")
744+
def test_update_endpoint(self, update_endpoint_mock):
745+
endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
746+
endpoint.update(
747+
display_name=_TEST_DISPLAY_NAME,
748+
description=_TEST_DESCRIPTION,
749+
labels=_TEST_LABELS,
750+
)
751+
752+
expected_endpoint = gca_endpoint.Endpoint(
753+
name=_TEST_ENDPOINT_NAME,
754+
display_name=_TEST_DISPLAY_NAME,
755+
description=_TEST_DESCRIPTION,
756+
labels=_TEST_LABELS,
757+
encryption_spec=_TEST_ENCRYPTION_SPEC,
758+
)
759+
760+
expected_update_mask = field_mask_pb2.FieldMask(
761+
paths=["display_name", "description", "labels"]
762+
)
763+
764+
update_endpoint_mock.assert_called_once_with(
765+
endpoint=expected_endpoint,
766+
update_mask=expected_update_mask,
767+
metadata=_TEST_REQUEST_METADATA,
768+
timeout=_TEST_TIMEOUT,
769+
)
770+
771+
@pytest.mark.usefixtures("get_endpoint_with_models_mock")
772+
def test_update_traffic_split(self, update_endpoint_mock):
773+
endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
774+
endpoint.update(traffic_split={_TEST_ID: 10, _TEST_ID_2: 80, _TEST_ID_3: 10})
775+
776+
expected_endpoint = gca_endpoint.Endpoint(
777+
name=_TEST_ENDPOINT_NAME,
778+
display_name=_TEST_DISPLAY_NAME,
779+
deployed_models=_TEST_DEPLOYED_MODELS,
780+
traffic_split={_TEST_ID: 10, _TEST_ID_2: 80, _TEST_ID_3: 10},
781+
)
782+
expected_update_mask = field_mask_pb2.FieldMask(paths=["traffic_split"])
783+
784+
update_endpoint_mock.assert_called_once_with(
785+
endpoint=expected_endpoint,
786+
update_mask=expected_update_mask,
787+
metadata=_TEST_REQUEST_METADATA,
788+
timeout=_TEST_TIMEOUT,
789+
)
790+
729791
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
730792
@pytest.mark.parametrize("sync", [True, False])
731793
def test_deploy(self, deploy_model_mock, sync):
@@ -920,7 +982,7 @@ def test_deploy_raise_error_max_replica(self, sync):
920982
)
921983
test_endpoint.deploy(model=test_model, max_replica_count=-2, sync=sync)
922984

923-
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
985+
@pytest.mark.usefixtures("get_endpoint_with_models_mock", "get_model_mock")
924986
@pytest.mark.parametrize("sync", [True, False])
925987
def test_deploy_raise_error_traffic_split(self, sync):
926988
with pytest.raises(ValueError):
@@ -973,48 +1035,39 @@ def test_deploy_with_traffic_percent(self, deploy_model_mock, sync):
9731035
timeout=None,
9741036
)
9751037

976-
@pytest.mark.usefixtures("get_model_mock")
1038+
@pytest.mark.usefixtures("get_endpoint_with_models_mock", "get_model_mock")
9771039
@pytest.mark.parametrize("sync", [True, False])
9781040
def test_deploy_with_traffic_split(self, deploy_model_mock, sync):
979-
with mock.patch.object(
980-
endpoint_service_client.EndpointServiceClient, "get_endpoint"
981-
) as get_endpoint_mock:
982-
get_endpoint_mock.return_value = gca_endpoint.Endpoint(
983-
display_name=_TEST_DISPLAY_NAME,
984-
name=_TEST_ENDPOINT_NAME,
985-
traffic_split={"model1": 100},
986-
)
987-
988-
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
989-
test_model = models.Model(_TEST_ID)
990-
test_model._gca_resource.supported_deployment_resources_types.append(
991-
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
992-
)
993-
test_endpoint.deploy(
994-
model=test_model,
995-
traffic_split={"model1": 30, "0": 70},
996-
sync=sync,
997-
deploy_request_timeout=None,
998-
)
1041+
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
1042+
test_model = models.Model(_TEST_ID)
1043+
test_model._gca_resource.supported_deployment_resources_types.append(
1044+
aiplatform.gapic.Model.DeploymentResourcesType.AUTOMATIC_RESOURCES
1045+
)
1046+
test_endpoint.deploy(
1047+
model=test_model,
1048+
traffic_split={_TEST_ID: 10, _TEST_ID_2: 40, _TEST_ID_3: 10, "0": 40},
1049+
sync=sync,
1050+
deploy_request_timeout=None,
1051+
)
9991052

1000-
if not sync:
1001-
test_endpoint.wait()
1002-
automatic_resources = gca_machine_resources.AutomaticResources(
1003-
min_replica_count=1,
1004-
max_replica_count=1,
1005-
)
1006-
deployed_model = gca_endpoint.DeployedModel(
1007-
automatic_resources=automatic_resources,
1008-
model=test_model.resource_name,
1009-
display_name=None,
1010-
)
1011-
deploy_model_mock.assert_called_once_with(
1012-
endpoint=test_endpoint.resource_name,
1013-
deployed_model=deployed_model,
1014-
traffic_split={"model1": 30, "0": 70},
1015-
metadata=(),
1016-
timeout=None,
1017-
)
1053+
if not sync:
1054+
test_endpoint.wait()
1055+
automatic_resources = gca_machine_resources.AutomaticResources(
1056+
min_replica_count=1,
1057+
max_replica_count=1,
1058+
)
1059+
deployed_model = gca_endpoint.DeployedModel(
1060+
automatic_resources=automatic_resources,
1061+
model=test_model.resource_name,
1062+
display_name=None,
1063+
)
1064+
deploy_model_mock.assert_called_once_with(
1065+
endpoint=test_endpoint.resource_name,
1066+
deployed_model=deployed_model,
1067+
traffic_split={_TEST_ID: 10, _TEST_ID_2: 40, _TEST_ID_3: 10, "0": 40},
1068+
metadata=(),
1069+
timeout=None,
1070+
)
10181071

10191072
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
10201073
@pytest.mark.parametrize("sync", [True, False])

0 commit comments

Comments
 (0)