Skip to content

Commit 9fb24d7

Browse files
jaycee-licopybara-github
authored andcommitted
fix: Endpoint.undeploy_all() doesn't undeploy all models
#1441 PiperOrigin-RevId: 500253890
1 parent f87fef0 commit 9fb24d7

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

google/cloud/aiplatform/models.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1709,11 +1709,21 @@ def undeploy_all(self, sync: bool = True) -> "Endpoint":
17091709
"""
17101710
self._sync_gca_resource()
17111711

1712-
models_to_undeploy = sorted( # Undeploy zero traffic models first
1712+
models_in_traffic_split = sorted( # Undeploy zero traffic models first
17131713
self._gca_resource.traffic_split.keys(),
17141714
key=lambda id: self._gca_resource.traffic_split[id],
17151715
)
17161716

1717+
# Some deployed models may not in the traffic_split dict.
1718+
# These models have 0% traffic and should be undeployed first.
1719+
models_not_in_traffic_split = [
1720+
deployed_model.id
1721+
for deployed_model in self._gca_resource.deployed_models
1722+
if deployed_model.id not in models_in_traffic_split
1723+
]
1724+
1725+
models_to_undeploy = models_not_in_traffic_split + models_in_traffic_split
1726+
17171727
for deployed_model in models_to_undeploy:
17181728
self._undeploy(deployed_model_id=deployed_model, sync=sync)
17191729

tests/unit/aiplatform/test_endpoints.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,12 @@
102102
"m2": 10,
103103
"m3": 30,
104104
"m4": 0,
105-
"m5": 5,
106-
"m6": 8,
107-
"m7": 7,
105+
"m5": 20,
108106
}
109-
_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m5", "m7", "m6", "m2", "m3", "m1"]
107+
_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS = ["m4", "m2", "m5", "m3", "m1"]
110108
_TEST_LONG_DEPLOYED_MODELS = [
111109
gca_endpoint.DeployedModel(id=id, display_name=f"{id}_display_name")
112-
for id in _TEST_LONG_TRAFFIC_SPLIT.keys()
110+
for id in ["m1", "m2", "m3", "m4", "m5", "m6", "m7"]
113111
]
114112

115113
_TEST_MACHINE_TYPE = "n1-standard-32"
@@ -1861,11 +1859,6 @@ def test_list_models(self, get_endpoint_with_models_mock):
18611859
@pytest.mark.parametrize("sync", [True, False])
18621860
def test_undeploy_all(self, sdk_private_undeploy_mock, sync):
18631861

1864-
# Ensure mock traffic split deployed model IDs are same as expected IDs
1865-
assert set(_TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS) == set(
1866-
_TEST_LONG_TRAFFIC_SPLIT.keys()
1867-
)
1868-
18691862
ept = aiplatform.Endpoint(_TEST_ID)
18701863
ept.undeploy_all(sync=sync)
18711864

@@ -1874,10 +1867,11 @@ def test_undeploy_all(self, sdk_private_undeploy_mock, sync):
18741867

18751868
# undeploy_all() results in an undeploy() call for each deployed_model
18761869
# Models are undeployed in ascending order of traffic percentage
1870+
expected_models_to_undeploy = ["m6", "m7"] + _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS
18771871
sdk_private_undeploy_mock.assert_has_calls(
18781872
[
18791873
mock.call(deployed_model_id=deployed_model_id, sync=sync)
1880-
for deployed_model_id in _TEST_LONG_TRAFFIC_SPLIT_SORTED_IDS
1874+
for deployed_model_id in expected_models_to_undeploy
18811875
],
18821876
)
18831877

0 commit comments

Comments
 (0)