Skip to content

Commit 44e279b

Browse files
authored
fix: change endpoint update method to return resource (#1409)
* fix: change endpoint update method to return resource * fix: update unit tests for endpoint.update
1 parent 643d335 commit 44e279b

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

google/cloud/aiplatform/models.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,19 +1377,13 @@ def update(
13771377
self,
13781378
)
13791379

1380-
update_endpoint_lro = self.api_client.update_endpoint(
1380+
self._gca_resource = self.api_client.update_endpoint(
13811381
endpoint=copied_endpoint_proto,
13821382
update_mask=update_mask,
13831383
metadata=request_metadata,
13841384
timeout=update_request_timeout,
13851385
)
13861386

1387-
_LOGGER.log_action_started_against_resource_with_lro(
1388-
"Update", "endpoint", self.__class__, update_endpoint_lro
1389-
)
1390-
1391-
update_endpoint_lro.result()
1392-
13931387
_LOGGER.log_action_completed_against_resource("endpoint", "updated", self)
13941388

13951389
return self

tests/unit/aiplatform/test_endpoints.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,11 @@ def update_endpoint_mock():
279279
with mock.patch.object(
280280
endpoint_service_client.EndpointServiceClient, "update_endpoint"
281281
) as update_endpoint_mock:
282-
update_endpoint_lro_mock = mock.Mock(ga_operation.Operation)
283-
update_endpoint_mock.return_value = update_endpoint_lro_mock
282+
update_endpoint_mock.return_value = gca_endpoint.Endpoint(
283+
display_name=_TEST_DISPLAY_NAME,
284+
name=_TEST_ENDPOINT_NAME,
285+
encryption_spec=_TEST_ENCRYPTION_SPEC,
286+
)
284287
yield update_endpoint_mock
285288

286289

@@ -768,9 +771,18 @@ def test_update_endpoint(self, update_endpoint_mock):
768771
timeout=_TEST_TIMEOUT,
769772
)
770773

774+
update_endpoint_mock.return_value = gca_endpoint.Endpoint(
775+
name=_TEST_ENDPOINT_NAME,
776+
display_name=_TEST_DISPLAY_NAME,
777+
description=_TEST_DESCRIPTION,
778+
labels=_TEST_LABELS,
779+
encryption_spec=_TEST_ENCRYPTION_SPEC,
780+
)
781+
771782
@pytest.mark.usefixtures("get_endpoint_with_models_mock")
772783
def test_update_traffic_split(self, update_endpoint_mock):
773784
endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
785+
774786
endpoint.update(traffic_split={_TEST_ID: 10, _TEST_ID_2: 80, _TEST_ID_3: 10})
775787

776788
expected_endpoint = gca_endpoint.Endpoint(
@@ -788,6 +800,12 @@ def test_update_traffic_split(self, update_endpoint_mock):
788800
timeout=_TEST_TIMEOUT,
789801
)
790802

803+
update_endpoint_mock.return_value = gca_endpoint.Endpoint(
804+
display_name=_TEST_DISPLAY_NAME,
805+
name=_TEST_ENDPOINT_NAME,
806+
traffic_split={_TEST_ID: 10, _TEST_ID_2: 80, _TEST_ID_3: 10},
807+
)
808+
791809
@pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock")
792810
@pytest.mark.parametrize("sync", [True, False])
793811
def test_deploy(self, deploy_model_mock, sync):

0 commit comments

Comments
 (0)