Skip to content

Commit 8cefabb

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Implement preview deployment with RolloutOptions.
PiperOrigin-RevId: 729610395
1 parent a6225a0 commit 8cefabb

File tree

3 files changed

+205
-2
lines changed

3 files changed

+205
-2
lines changed

google/cloud/aiplatform/preview/models.py

+111-2
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,80 @@ def list(
471471
)
472472

473473

474+
class RolloutOptions(object):
    """RolloutOptions contains configurations for rolling deployments.

    For each of the surge and unavailable budgets, at most one of the
    percentage form or the replica-count form may be set; `to_gapic` raises
    `ValueError` when both forms of the same budget are set. A value of
    `None` or `0` is treated as "not set".

    Attributes:
        previous_deployed_model:
            The ID of the previous deployed model.
        max_surge_percentage:
            Maximum additional replicas to create during the deployment,
            specified as a percentage of the current replica count.
        max_surge_replicas:
            Maximum number of additional replicas to create during the
            deployment.
        max_unavailable_percentage:
            Maximum amount of replicas that can be unavailable during the
            deployment, specified as a percentage of the current replica count.
        max_unavailable_replicas:
            Maximum number of replicas that can be unavailable during the
            deployment.
    """

    def __init__(
        self,
        previous_deployed_model: Union[int, str],
        max_surge_percentage: Optional[int] = None,
        max_surge_replicas: Optional[int] = None,
        max_unavailable_percentage: Optional[int] = None,
        max_unavailable_replicas: Optional[int] = None,
    ):
        """Stores the rollout configuration without validating it.

        Args:
            previous_deployed_model:
                Required. The ID of the previous deployed model. It is
                stringified when converted to the gapic proto, so both an
                int and its string form are accepted.
            max_surge_percentage:
                Optional. Surge budget as a percentage.
            max_surge_replicas:
                Optional. Surge budget as a replica count.
            max_unavailable_percentage:
                Optional. Unavailability budget as a percentage.
            max_unavailable_replicas:
                Optional. Unavailability budget as a replica count.
        """
        self.previous_deployed_model = previous_deployed_model
        self.max_surge_percentage = max_surge_percentage
        self.max_surge_replicas = max_surge_replicas
        self.max_unavailable_percentage = max_unavailable_percentage
        self.max_unavailable_replicas = max_unavailable_replicas

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"previous_deployed_model={self.previous_deployed_model!r}, "
            f"max_surge_percentage={self.max_surge_percentage!r}, "
            f"max_surge_replicas={self.max_surge_replicas!r}, "
            f"max_unavailable_percentage={self.max_unavailable_percentage!r}, "
            f"max_unavailable_replicas={self.max_unavailable_replicas!r})"
        )

    @classmethod
    def from_gapic(cls, opts: "gca_endpoint_compat.RolloutOptions") -> "RolloutOptions":
        """Builds a RolloutOptions from a gapic RolloutOptions proto.

        Args:
            opts: The gapic RolloutOptions proto to read from.

        Returns:
            A RolloutOptions carrying the proto's field values, with
            `previous_deployed_model` converted to an int.

        Raises:
            ValueError: If `opts.previous_deployed_model` cannot be parsed
                by `int()`.
        """
        return cls(
            previous_deployed_model=int(opts.previous_deployed_model),
            max_surge_percentage=opts.max_surge_percentage,
            max_surge_replicas=opts.max_surge_replicas,
            max_unavailable_percentage=opts.max_unavailable_percentage,
            max_unavailable_replicas=opts.max_unavailable_replicas,
        )

    @staticmethod
    def _pick_budget(
        percentage_field: str,
        percentage: Optional[int],
        replicas_field: str,
        replicas: Optional[int],
    ) -> Dict[str, int]:
        """Selects which single proto field to set for one budget.

        Returns a one-item mapping of proto field name to value: the
        percentage field when a percentage is set, otherwise the replicas
        field (defaulting to 0 when neither is set).

        Raises:
            ValueError: If both the percentage and the replicas are set.
        """
        # Truthiness (not `is not None`) is deliberate: values read back via
        # `from_gapic` are 0 rather than None for fields that were never set.
        if percentage:
            if replicas:
                raise ValueError(
                    f"{percentage_field} and {replicas_field} cannot both be set."
                )
            return {percentage_field: percentage}
        return {replicas_field: replicas or 0}

    def to_gapic(self) -> "gca_endpoint_compat.RolloutOptions":
        """Converts RolloutOptions class to gapic RolloutOptions proto.

        Returns:
            A gapic RolloutOptions with `previous_deployed_model` stringified
            and exactly one surge field and one unavailable field assigned.

        Raises:
            ValueError: If both `max_surge_percentage` and
                `max_surge_replicas` are set, or both
                `max_unavailable_percentage` and `max_unavailable_replicas`
                are set.
        """
        # Validate both budgets up front so an invalid configuration fails
        # before any proto is constructed. Surge is checked first.
        budgets = {
            **self._pick_budget(
                "max_surge_percentage",
                self.max_surge_percentage,
                "max_surge_replicas",
                self.max_surge_replicas,
            ),
            **self._pick_budget(
                "max_unavailable_percentage",
                self.max_unavailable_percentage,
                "max_unavailable_replicas",
                self.max_unavailable_replicas,
            ),
        }

        result = gca_endpoint_compat.RolloutOptions(
            previous_deployed_model=str(self.previous_deployed_model),
        )
        # Assign after construction (rather than via constructor kwargs) so
        # that an explicit 0 is still written onto the chosen field.
        for field_name, value in budgets.items():
            setattr(result, field_name, value)

        return result
546+
547+
474548
class Endpoint(aiplatform.Endpoint):
475549
@staticmethod
476550
def _validate_deploy_args(
@@ -616,6 +690,7 @@ def deploy(
616690
fast_tryout_enabled: bool = False,
617691
system_labels: Optional[Dict[str, str]] = None,
618692
required_replica_count: Optional[int] = 0,
693+
rollout_options: Optional[RolloutOptions] = None,
619694
) -> None:
620695
"""Deploys a Model to the Endpoint.
621696
@@ -712,6 +787,8 @@ def deploy(
712787
set, the model deploy/mutate operation will succeed once
713788
available_replica_count reaches required_replica_count, and the
714789
rest of the replicas will be retried.
790+
rollout_options (RolloutOptions):
791+
Optional. Options to configure a rolling deployment.
715792
716793
"""
717794
self._sync_gca_resource_if_skipped()
@@ -754,6 +831,7 @@ def deploy(
754831
fast_tryout_enabled=fast_tryout_enabled,
755832
system_labels=system_labels,
756833
required_replica_count=required_replica_count,
834+
rollout_options=rollout_options,
757835
)
758836

759837
@base.optional_sync()
@@ -780,6 +858,7 @@ def _deploy(
780858
fast_tryout_enabled: bool = False,
781859
system_labels: Optional[Dict[str, str]] = None,
782860
required_replica_count: Optional[int] = 0,
861+
rollout_options: Optional[RolloutOptions] = None,
783862
) -> None:
784863
"""Deploys a Model to the Endpoint.
785864
@@ -870,7 +949,8 @@ def _deploy(
870949
set, the model deploy/mutate operation will succeed once
871950
available_replica_count reaches required_replica_count, and the
872951
rest of the replicas will be retried.
873-
952+
rollout_options (RolloutOptions): Optional.
953+
Options to configure a rolling deployment.
874954
"""
875955
_LOGGER.log_action_start_against_resource(
876956
f"Deploying Model {model.resource_name} to", "", self
@@ -901,6 +981,7 @@ def _deploy(
901981
fast_tryout_enabled=fast_tryout_enabled,
902982
system_labels=system_labels,
903983
required_replica_count=required_replica_count,
984+
rollout_options=rollout_options,
904985
)
905986

906987
_LOGGER.log_action_completed_against_resource("model", "deployed", self)
@@ -934,6 +1015,7 @@ def _deploy_call(
9341015
fast_tryout_enabled: bool = False,
9351016
system_labels: Optional[Dict[str, str]] = None,
9361017
required_replica_count: Optional[int] = 0,
1018+
rollout_options: Optional[RolloutOptions] = None,
9371019
) -> None:
9381020
"""Helper method to deploy model to endpoint.
9391021
@@ -1031,6 +1113,8 @@ def _deploy_call(
10311113
set, the model deploy/mutate operation will succeed once
10321114
available_replica_count reaches required_replica_count, and the
10331115
rest of the replicas will be retried.
1116+
rollout_options (RolloutOptions): Optional. Options to configure a
1117+
rolling deployment.
10341118
10351119
Raises:
10361120
ValueError: If only `accelerator_type` or `accelerator_count` is
@@ -1103,7 +1187,7 @@ def _deploy_call(
11031187
machine_type = _DEFAULT_MACHINE_TYPE
11041188
_LOGGER.info(f"Using default machine_type: {machine_type}")
11051189

1106-
if use_dedicated_resources:
1190+
if use_dedicated_resources and not rollout_options:
11071191
dedicated_resources = gca_machine_resources_compat.DedicatedResources(
11081192
min_replica_count=min_replica_count,
11091193
max_replica_count=max_replica_count,
@@ -1146,6 +1230,15 @@ def _deploy_call(
11461230
)
11471231
)
11481232
deployed_model.dedicated_resources = dedicated_resources
1233+
elif rollout_options:
1234+
deployed_model.rollout_options = rollout_options.to_gapic()
1235+
elif supports_automatic_resources:
1236+
deployed_model.automatic_resources = (
1237+
gca_machine_resources_compat.AutomaticResources(
1238+
min_replica_count=min_replica_count,
1239+
max_replica_count=max_replica_count,
1240+
)
1241+
)
11491242
else:
11501243
deployed_model = gca_endpoint_compat.DeployedModel(
11511244
model=model.versioned_resource_name,
@@ -1444,6 +1537,7 @@ def deploy(
14441537
fast_tryout_enabled: bool = False,
14451538
system_labels: Optional[Dict[str, str]] = None,
14461539
required_replica_count: Optional[int] = 0,
1540+
rollout_options: Optional[RolloutOptions] = None,
14471541
) -> Union[Endpoint, models.PrivateEndpoint]:
14481542
"""Deploys model to endpoint.
14491543
@@ -1561,6 +1655,8 @@ def deploy(
15611655
set, the model deploy/mutate operation will succeed once
15621656
available_replica_count reaches required_replica_count, and the
15631657
rest of the replicas will be retried.
1658+
rollout_options (RolloutOptions):
1659+
Optional. Options to configure a rolling deployment.
15641660
15651661
Returns:
15661662
endpoint (Union[Endpoint, models.PrivateEndpoint]):
@@ -1620,6 +1716,7 @@ def deploy(
16201716
fast_tryout_enabled=fast_tryout_enabled,
16211717
system_labels=system_labels,
16221718
required_replica_count=required_replica_count,
1719+
rollout_options=rollout_options,
16231720
)
16241721

16251722
def _should_enable_dedicated_endpoint(self, fast_tryout_enabled: bool) -> bool:
@@ -1655,6 +1752,7 @@ def _deploy(
16551752
fast_tryout_enabled: bool = False,
16561753
system_labels: Optional[Dict[str, str]] = None,
16571754
required_replica_count: Optional[int] = 0,
1755+
rollout_options: Optional[RolloutOptions] = None,
16581756
) -> Union[Endpoint, models.PrivateEndpoint]:
16591757
"""Deploys model to endpoint.
16601758
@@ -1763,6 +1861,8 @@ def _deploy(
17631861
set, the model deploy/mutate operation will succeed once
17641862
available_replica_count reaches required_replica_count, and the
17651863
rest of the replicas will be retried.
1864+
rollout_options (RolloutOptions):
1865+
Optional. Options to configure a rolling deployment.
17661866
17671867
Returns:
17681868
endpoint (Union[Endpoint, models.PrivateEndpoint]):
@@ -1771,6 +1871,10 @@ def _deploy(
17711871

17721872
if endpoint is None:
17731873
display_name = self.display_name[:118] + "_endpoint"
1874+
if rollout_options is not None:
1875+
raise ValueError(
1876+
"Rollout options may only be used when deploying to an existing endpoint."
1877+
)
17741878

17751879
if not network:
17761880
endpoint = Endpoint.create(
@@ -1792,6 +1896,10 @@ def _deploy(
17921896
credentials=self.credentials,
17931897
encryption_spec_key_name=encryption_spec_key_name,
17941898
)
1899+
if isinstance(endpoint, Endpoint):
1900+
preview_kwargs = {"rollout_options": rollout_options}
1901+
else:
1902+
preview_kwargs = {}
17951903

17961904
_LOGGER.log_action_start_against_resource("Deploying model to", "", endpoint)
17971905

@@ -1820,6 +1928,7 @@ def _deploy(
18201928
fast_tryout_enabled=fast_tryout_enabled,
18211929
system_labels=system_labels,
18221930
required_replica_count=required_replica_count,
1931+
**preview_kwargs,
18231932
)
18241933

18251934
_LOGGER.log_action_completed_against_resource("model", "deployed", endpoint)

tests/unit/aiplatform/test_endpoints.py

+46
Original file line numberDiff line numberDiff line change
@@ -2352,6 +2352,52 @@ def test_allocate_traffic(self, model1, model2, model3, percent):
23522352
assert new_split_sum == 100
23532353
assert new_split["0"] == percent
23542354

2355+
@pytest.mark.usefixtures(
    "get_model_mock",
    "preview_deploy_model_mock",
    "create_endpoint_mock",
    "get_endpoint_mock",
)
@pytest.mark.parametrize("sync", [True, False])
def test_preview_deploy_with_rollout_options(self, preview_deploy_model_mock, sync):
    """Deploying via a preview Endpoint forwards rollout options to the API.

    The RolloutOptions passed to `deploy` must show up, converted to its
    gapic form, on the DeployedModel sent to the deploy-model mock.
    """
    # The same field values are used for the SDK object handed to deploy()
    # and for the gapic proto the mock is expected to receive.
    rollout_fields = {
        "previous_deployed_model": "123",
        "max_surge_percentage": 10,
        "max_unavailable_replicas": 2,
    }

    model = models.Model(_TEST_ID).preview
    model._gca_resource.supported_deployment_resources_types.append(
        aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
    )
    endpoint = preview_models.Endpoint(_TEST_ENDPOINT_NAME)

    endpoint.deploy(
        model=model,
        sync=sync,
        deploy_request_timeout=None,
        rollout_options=preview_models.RolloutOptions(**rollout_fields),
        disable_container_logging=False,
    )
    if not sync:
        endpoint.wait()

    preview_deploy_model_mock.assert_called_once_with(
        endpoint=endpoint.resource_name,
        deployed_model=gca_endpoint_v1beta1.DeployedModel(
            model=model.resource_name,
            display_name=None,
            rollout_options=gca_endpoint_v1beta1.RolloutOptions(**rollout_fields),
            enable_container_logging=True,
        ),
        traffic_split={"0": 100},
        metadata=(),
        timeout=None,
    )
2400+
23552401
@pytest.mark.parametrize(
23562402
"model1, model2, model3, deployed_model",
23572403
[

tests/unit/aiplatform/test_models.py

+48
Original file line numberDiff line numberDiff line change
@@ -2854,6 +2854,54 @@ def test_deploy_with_deployment_resource_pool(self, deploy_model_mock, sync):
28542854
timeout=None,
28552855
)
28562856

2857+
@pytest.mark.usefixtures(
    "get_model_mock",
    "preview_deploy_model_mock",
    "create_endpoint_mock",
    "get_endpoint_mock",
)
@pytest.mark.parametrize("sync", [True, False])
def test_preview_deploy_with_rollout_options(self, preview_deploy_model_mock, sync):
    """Deploying a preview Model to an existing Endpoint forwards rollout options.

    The RolloutOptions passed to `Model.deploy` must show up, converted to
    its gapic form, on the DeployedModel sent to the deploy-model mock.
    """
    # The same field values are used for the SDK object handed to deploy()
    # and for the gapic proto the mock is expected to receive.
    rollout_fields = {
        "previous_deployed_model": "123",
        "max_surge_percentage": 10,
        "max_unavailable_replicas": 2,
    }

    model = models.Model(_TEST_ID).preview
    model._gca_resource.supported_deployment_resources_types.append(
        aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES
    )
    endpoint = preview_models.Endpoint(
        test_constants.EndpointConstants._TEST_ENDPOINT_NAME
    )

    model.deploy(
        endpoint=endpoint,
        sync=sync,
        deploy_request_timeout=None,
        rollout_options=preview_models.RolloutOptions(**rollout_fields),
        disable_container_logging=False,
    )
    if not sync:
        endpoint.wait()

    preview_deploy_model_mock.assert_called_once_with(
        endpoint=endpoint.resource_name,
        deployed_model=gca_endpoint_v1beta1.DeployedModel(
            model=model.resource_name,
            display_name=None,
            rollout_options=gca_endpoint_v1beta1.RolloutOptions(**rollout_fields),
            enable_container_logging=True,
        ),
        traffic_split={"0": 100},
        metadata=(),
        timeout=None,
    )
2904+
28572905
@pytest.mark.parametrize("sync", [True, False])
28582906
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
28592907
def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_and_dest(

0 commit comments

Comments
 (0)