Skip to content

Commit 7b7d7d2

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: add additional parameters to Model.upload().
PiperOrigin-RevId: 586452396
1 parent 6f2860a commit 7b7d7d2

File tree

4 files changed

+199
-1
lines changed

4 files changed

+199
-1
lines changed

google/cloud/aiplatform/models.py

+75
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
from google.api_core import exceptions as api_exceptions
3737
from google.auth import credentials as auth_credentials
3838
from google.auth.transport import requests as google_auth_requests
39+
from google.protobuf import duration_pb2
3940
import proto
4041

4142
from google.cloud import aiplatform
@@ -2974,6 +2975,14 @@ def upload(
29742975
staging_bucket: Optional[str] = None,
29752976
sync=True,
29762977
upload_request_timeout: Optional[float] = None,
2978+
serving_container_deployment_timeout: Optional[int] = None,
2979+
serving_container_shared_memory_size_mb: Optional[int] = None,
2980+
serving_container_startup_probe_exec: Optional[Sequence[str]] = None,
2981+
serving_container_startup_probe_period_seconds: Optional[int] = None,
2982+
serving_container_startup_probe_timeout_seconds: Optional[int] = None,
2983+
serving_container_health_probe_exec: Optional[Sequence[str]] = None,
2984+
serving_container_health_probe_period_seconds: Optional[int] = None,
2985+
serving_container_health_probe_timeout_seconds: Optional[int] = None,
29772986
) -> "Model":
29782987
"""Uploads a model and returns a Model representing the uploaded Model
29792988
resource.
@@ -3153,6 +3162,31 @@ def upload(
31533162
staging_bucket set in aiplatform.init.
31543163
upload_request_timeout (float):
31553164
Optional. The timeout for the upload request in seconds.
3165+
serving_container_deployment_timeout (int):
3166+
Optional. Deployment timeout in seconds.
3167+
serving_container_shared_memory_size_mb (int):
3168+
Optional. The amount of the VM memory to reserve as the shared
3169+
memory for the model in megabytes.
3170+
serving_container_startup_probe_exec (Sequence[str]):
3171+
Optional. Exec specifies the action to take. Used by startup
3172+
probe. An example of this argument would be
3173+
["cat", "/tmp/healthy"]
3174+
serving_container_startup_probe_period_seconds (int):
3175+
Optional. How often (in seconds) to perform the startup probe.
3176+
Defaults to 10 seconds. Minimum value is 1.
3177+
serving_container_startup_probe_timeout_seconds (int):
3178+
Optional. Number of seconds after which the startup probe times
3179+
out. Defaults to 1 second. Minimum value is 1.
3180+
serving_container_health_probe_exec (Sequence[str]):
3181+
Optional. Exec specifies the action to take. Used by health
3182+
probe. An example of this argument would be
3183+
["cat", "/tmp/healthy"]
3184+
serving_container_health_probe_period_seconds (int):
3185+
Optional. How often (in seconds) to perform the health probe.
3186+
Defaults to 10 seconds. Minimum value is 1.
3187+
serving_container_health_probe_timeout_seconds (int):
3188+
Optional. Number of seconds after which the health probe times
3189+
out. Defaults to 1 second. Minimum value is 1.
31563190
31573191
Returns:
31583192
model (aiplatform.Model):
@@ -3187,6 +3221,13 @@ def upload(
31873221

31883222
env = None
31893223
ports = None
3224+
deployment_timeout = (
3225+
duration_pb2.Duration(seconds=serving_container_deployment_timeout)
3226+
if serving_container_deployment_timeout
3227+
else None
3228+
)
3229+
startup_probe = None
3230+
health_probe = None
31903231

31913232
if serving_container_environment_variables:
31923233
env = [
@@ -3198,6 +3239,36 @@ def upload(
31983239
gca_model_compat.Port(container_port=port)
31993240
for port in serving_container_ports
32003241
]
3242+
if (
3243+
serving_container_startup_probe_exec
3244+
or serving_container_startup_probe_period_seconds
3245+
or serving_container_startup_probe_timeout_seconds
3246+
):
3247+
startup_probe_exec = None
3248+
if serving_container_startup_probe_exec:
3249+
startup_probe_exec = gca_model_compat.Probe.ExecAction(
3250+
command=serving_container_startup_probe_exec
3251+
)
3252+
startup_probe = gca_model_compat.Probe(
3253+
exec=startup_probe_exec,
3254+
period_seconds=serving_container_startup_probe_period_seconds,
3255+
timeout_seconds=serving_container_startup_probe_timeout_seconds,
3256+
)
3257+
if (
3258+
serving_container_health_probe_exec
3259+
or serving_container_health_probe_period_seconds
3260+
or serving_container_health_probe_timeout_seconds
3261+
):
3262+
health_probe_exec = None
3263+
if serving_container_health_probe_exec:
3264+
health_probe_exec = gca_model_compat.Probe.ExecAction(
3265+
command=serving_container_health_probe_exec
3266+
)
3267+
health_probe = gca_model_compat.Probe(
3268+
exec=health_probe_exec,
3269+
period_seconds=serving_container_health_probe_period_seconds,
3270+
timeout_seconds=serving_container_health_probe_timeout_seconds,
3271+
)
32013272

32023273
container_spec = gca_model_compat.ModelContainerSpec(
32033274
image_uri=serving_container_image_uri,
@@ -3207,6 +3278,10 @@ def upload(
32073278
ports=ports,
32083279
predict_route=serving_container_predict_route,
32093280
health_route=serving_container_health_route,
3281+
deployment_timeout=deployment_timeout,
3282+
shared_memory_size_mb=serving_container_shared_memory_size_mb,
3283+
startup_probe=startup_probe,
3284+
health_probe=health_probe,
32103285
)
32113286

32123287
model_predict_schemata = None

google/cloud/aiplatform/prediction/local_model.py

+76
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,8 @@
3636
from google.cloud.aiplatform.prediction.predictor import Predictor
3737
from google.cloud.aiplatform.utils import prediction_utils
3838

39+
from google.protobuf import duration_pb2
40+
3941
DEFAULT_PREDICT_ROUTE = "/predict"
4042
DEFAULT_HEALTH_ROUTE = "/health"
4143
DEFAULT_HTTP_PORT = 8080
@@ -58,6 +60,14 @@ def __init__(
5860
serving_container_args: Optional[Sequence[str]] = None,
5961
serving_container_environment_variables: Optional[Dict[str, str]] = None,
6062
serving_container_ports: Optional[Sequence[int]] = None,
63+
serving_container_deployment_timeout: Optional[int] = None,
64+
serving_container_shared_memory_size_mb: Optional[int] = None,
65+
serving_container_startup_probe_exec: Optional[Sequence[str]] = None,
66+
serving_container_startup_probe_period_seconds: Optional[int] = None,
67+
serving_container_startup_probe_timeout_seconds: Optional[int] = None,
68+
serving_container_health_probe_exec: Optional[Sequence[str]] = None,
69+
serving_container_health_probe_period_seconds: Optional[int] = None,
70+
serving_container_health_probe_timeout_seconds: Optional[int] = None,
6171
):
6272
"""Creates a local model instance.
6373
@@ -100,6 +110,31 @@ def __init__(
100110
no impact on whether the port is actually exposed, any port listening on
101111
the default "0.0.0.0" address inside a container will be accessible from
102112
the network.
113+
serving_container_deployment_timeout (int):
114+
Optional. Deployment timeout in seconds.
115+
serving_container_shared_memory_size_mb (int):
116+
Optional. The amount of the VM memory to reserve as the shared
117+
memory for the model in megabytes.
118+
serving_container_startup_probe_exec (Sequence[str]):
119+
Optional. Exec specifies the action to take. Used by startup
120+
probe. An example of this argument would be
121+
["cat", "/tmp/healthy"]
122+
serving_container_startup_probe_period_seconds (int):
123+
Optional. How often (in seconds) to perform the startup probe.
124+
Defaults to 10 seconds. Minimum value is 1.
125+
serving_container_startup_probe_timeout_seconds (int):
126+
Optional. Number of seconds after which the startup probe times
127+
out. Defaults to 1 second. Minimum value is 1.
128+
serving_container_health_probe_exec (Sequence[str]):
129+
Optional. Exec specifies the action to take. Used by health
130+
probe. An example of this argument would be
131+
["cat", "/tmp/healthy"]
132+
serving_container_health_probe_period_seconds (int):
133+
Optional. How often (in seconds) to perform the health probe.
134+
Defaults to 10 seconds. Minimum value is 1.
135+
serving_container_health_probe_timeout_seconds (int):
136+
Optional. Number of seconds after which the health probe times
137+
out. Defaults to 1 second. Minimum value is 1.
103138
104139
Raises:
105140
ValueError: If ``serving_container_spec`` is specified but ``serving_container_spec.image_uri``
@@ -121,6 +156,13 @@ def __init__(
121156

122157
env = None
123158
ports = None
159+
deployment_timeout = (
160+
duration_pb2.Duration(seconds=serving_container_deployment_timeout)
161+
if serving_container_deployment_timeout
162+
else None
163+
)
164+
startup_probe = None
165+
health_probe = None
124166

125167
if serving_container_environment_variables:
126168
env = [
@@ -132,6 +174,36 @@ def __init__(
132174
gca_model_compat.Port(container_port=port)
133175
for port in serving_container_ports
134176
]
177+
if (
178+
serving_container_startup_probe_exec
179+
or serving_container_startup_probe_period_seconds
180+
or serving_container_startup_probe_timeout_seconds
181+
):
182+
startup_probe_exec = None
183+
if serving_container_startup_probe_exec:
184+
startup_probe_exec = gca_model_compat.Probe.ExecAction(
185+
command=serving_container_startup_probe_exec
186+
)
187+
startup_probe = gca_model_compat.Probe(
188+
exec=startup_probe_exec,
189+
period_seconds=serving_container_startup_probe_period_seconds,
190+
timeout_seconds=serving_container_startup_probe_timeout_seconds,
191+
)
192+
if (
193+
serving_container_health_probe_exec
194+
or serving_container_health_probe_period_seconds
195+
or serving_container_health_probe_timeout_seconds
196+
):
197+
health_probe_exec = None
198+
if serving_container_health_probe_exec:
199+
health_probe_exec = gca_model_compat.Probe.ExecAction(
200+
command=serving_container_health_probe_exec
201+
)
202+
health_probe = gca_model_compat.Probe(
203+
exec=health_probe_exec,
204+
period_seconds=serving_container_health_probe_period_seconds,
205+
timeout_seconds=serving_container_health_probe_timeout_seconds,
206+
)
135207

136208
self.serving_container_spec = gca_model_compat.ModelContainerSpec(
137209
image_uri=serving_container_image_uri,
@@ -141,6 +213,10 @@ def __init__(
141213
ports=ports,
142214
predict_route=serving_container_predict_route,
143215
health_route=serving_container_health_route,
216+
deployment_timeout=deployment_timeout,
217+
shared_memory_size_mb=serving_container_shared_memory_size_mb,
218+
startup_probe=startup_probe,
219+
health_probe=health_probe,
144220
)
145221

146222
@classmethod

tests/system/aiplatform/test_prediction_cpr.py

+2
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,8 @@ def test_build_cpr_model_upload_and_deploy(self, shared_state, caplog):
9797
local_model=local_model,
9898
display_name=f"cpr_e2e_test_{_TIMESTAMP}",
9999
artifact_uri=_ARTIFACT_URI,
100+
serving_container_deployment_timeout=3600,
101+
serving_container_shared_memory_size_mb=20,
100102
)
101103
shared_state["resources"] = [model]
102104

tests/unit/aiplatform/test_models.py

+46-1
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,12 @@
7878
from google.cloud.aiplatform_v1 import Execution as GapicExecution
7979
from google.cloud.aiplatform.model_evaluation import model_evaluation_job
8080

81-
from google.protobuf import field_mask_pb2, struct_pb2, timestamp_pb2
81+
from google.protobuf import (
82+
field_mask_pb2,
83+
struct_pb2,
84+
timestamp_pb2,
85+
duration_pb2,
86+
)
8287

8388
import constants as test_constants
8489

@@ -108,6 +113,14 @@
108113
"loss_fn": "mse",
109114
}
110115
_TEST_SERVING_CONTAINER_PORTS = [8888, 10000]
116+
_TEST_SERVING_CONTAINER_DEPLOYMENT_TIMEOUT = 100
117+
_TEST_SERVING_CONTAINER_SHARED_MEMORY_SIZE_MB = 1000
118+
_TEST_SERVING_CONTAINER_STARTUP_PROBE_EXEC = ["a", "b"]
119+
_TEST_SERVING_CONTAINER_STARTUP_PROBE_PERIOD_SECONDS = 5
120+
_TEST_SERVING_CONTAINER_STARTUP_PROBE_TIMEOUT_SECONDS = 100
121+
_TEST_SERVING_CONTAINER_HEALTH_PROBE_EXEC = ["c", "d"]
122+
_TEST_SERVING_CONTAINER_HEALTH_PROBE_PERIOD_SECONDS = 20
123+
_TEST_SERVING_CONTAINER_HEALTH_PROBE_TIMEOUT_SECONDS = 200
111124
_TEST_ID = "1028944691210842416"
112125
_TEST_LABEL = test_constants.ProjectConstants._TEST_LABELS
113126
_TEST_APPENDED_USER_AGENT = ["fake_user_agent", "another_fake_user_agent"]
@@ -1598,6 +1611,14 @@ def test_upload_uploads_and_gets_model_with_all_args(
15981611
labels=_TEST_LABEL,
15991612
sync=sync,
16001613
upload_request_timeout=None,
1614+
serving_container_deployment_timeout=_TEST_SERVING_CONTAINER_DEPLOYMENT_TIMEOUT,
1615+
serving_container_shared_memory_size_mb=_TEST_SERVING_CONTAINER_SHARED_MEMORY_SIZE_MB,
1616+
serving_container_startup_probe_exec=_TEST_SERVING_CONTAINER_STARTUP_PROBE_EXEC,
1617+
serving_container_startup_probe_period_seconds=_TEST_SERVING_CONTAINER_STARTUP_PROBE_PERIOD_SECONDS,
1618+
serving_container_startup_probe_timeout_seconds=_TEST_SERVING_CONTAINER_STARTUP_PROBE_TIMEOUT_SECONDS,
1619+
serving_container_health_probe_exec=_TEST_SERVING_CONTAINER_HEALTH_PROBE_EXEC,
1620+
serving_container_health_probe_period_seconds=_TEST_SERVING_CONTAINER_HEALTH_PROBE_PERIOD_SECONDS,
1621+
serving_container_health_probe_timeout_seconds=_TEST_SERVING_CONTAINER_HEALTH_PROBE_TIMEOUT_SECONDS,
16011622
)
16021623

16031624
if not sync:
@@ -1613,6 +1634,26 @@ def test_upload_uploads_and_gets_model_with_all_args(
16131634
for port in _TEST_SERVING_CONTAINER_PORTS
16141635
]
16151636

1637+
deployment_timeout = duration_pb2.Duration(
1638+
seconds=_TEST_SERVING_CONTAINER_DEPLOYMENT_TIMEOUT
1639+
)
1640+
1641+
startup_probe = gca_model.Probe(
1642+
exec=gca_model.Probe.ExecAction(
1643+
command=_TEST_SERVING_CONTAINER_STARTUP_PROBE_EXEC
1644+
),
1645+
period_seconds=_TEST_SERVING_CONTAINER_STARTUP_PROBE_PERIOD_SECONDS,
1646+
timeout_seconds=_TEST_SERVING_CONTAINER_STARTUP_PROBE_TIMEOUT_SECONDS,
1647+
)
1648+
1649+
health_probe = gca_model.Probe(
1650+
exec=gca_model.Probe.ExecAction(
1651+
command=_TEST_SERVING_CONTAINER_HEALTH_PROBE_EXEC
1652+
),
1653+
period_seconds=_TEST_SERVING_CONTAINER_HEALTH_PROBE_PERIOD_SECONDS,
1654+
timeout_seconds=_TEST_SERVING_CONTAINER_HEALTH_PROBE_TIMEOUT_SECONDS,
1655+
)
1656+
16161657
container_spec = gca_model.ModelContainerSpec(
16171658
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
16181659
predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
@@ -1621,6 +1662,10 @@ def test_upload_uploads_and_gets_model_with_all_args(
16211662
args=_TEST_SERVING_CONTAINER_ARGS,
16221663
env=env,
16231664
ports=ports,
1665+
deployment_timeout=deployment_timeout,
1666+
shared_memory_size_mb=_TEST_SERVING_CONTAINER_SHARED_MEMORY_SIZE_MB,
1667+
startup_probe=startup_probe,
1668+
health_probe=health_probe,
16241669
)
16251670

16261671
managed_model = gca_model.Model(

0 commit comments

Comments
 (0)