Skip to content

Commit 8b0add1

Browse files
sararob authored and copybara-github committed
feat: add Custom Job support to from_pretrained
PiperOrigin-RevId: 565175389
1 parent 220cbe8 commit 8b0add1

File tree

12 files changed

+457
-58
lines changed

12 files changed

+457
-58
lines changed

google/cloud/aiplatform/jobs.py

+13
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,19 @@
8787
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
8888
)
8989

90+
_JOB_PENDING_STATES = (
91+
gca_job_state.JobState.JOB_STATE_QUEUED,
92+
gca_job_state.JobState.JOB_STATE_PENDING,
93+
gca_job_state.JobState.JOB_STATE_RUNNING,
94+
gca_job_state.JobState.JOB_STATE_CANCELLING,
95+
gca_job_state.JobState.JOB_STATE_UPDATING,
96+
gca_job_state_v1beta1.JobState.JOB_STATE_QUEUED,
97+
gca_job_state_v1beta1.JobState.JOB_STATE_PENDING,
98+
gca_job_state_v1beta1.JobState.JOB_STATE_RUNNING,
99+
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLING,
100+
gca_job_state_v1beta1.JobState.JOB_STATE_UPDATING,
101+
)
102+
90103
# _block_until_complete wait times
91104
_JOB_WAIT_TIME = 5 # start at five seconds
92105
_LOG_WAIT_TIME = 5

tests/unit/vertexai/conftest.py

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
8181
),
8282
},
83+
labels={"trained_by_vertex_ai": "true"},
8384
)
8485

8586

tests/unit/vertexai/test_model_utils.py

+202
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,19 @@
2222
import vertexai
2323
from vertexai.preview._workflow.serialization_engine import (
2424
any_serializer,
25+
serializers_base,
26+
)
27+
from google.cloud.aiplatform.compat.services import job_service_client
28+
from google.cloud.aiplatform.compat.types import (
29+
job_state as gca_job_state,
30+
custom_job as gca_custom_job,
31+
io as gca_io,
2532
)
2633
import pytest
2734

35+
import cloudpickle
36+
import numpy as np
37+
import sklearn
2838
from sklearn.linear_model import _logistic
2939
import tensorflow
3040
import torch
@@ -45,6 +55,9 @@
4555
_MODEL_RESOURCE_NAME = "projects/123/locations/us-central1/models/456"
4656
_REWRAPPER = "rewrapper"
4757

58+
# customJob constants
59+
_TEST_CUSTOM_JOB_RESOURCE_NAME = "projects/123/locations/us-central1/customJobs/456"
60+
4861

4962
@pytest.fixture
5063
def mock_serialize_model():
@@ -123,6 +136,126 @@ def mock_deserialize_model_exception():
123136
yield mock_deserialize_model_exception
124137

125138

139+
@pytest.fixture
140+
def mock_any_serializer_serialize_sklearn():
141+
with mock.patch.object(
142+
any_serializer.AnySerializer,
143+
"serialize",
144+
side_effect=[
145+
{
146+
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
147+
f"scikit-learn=={sklearn.__version__}"
148+
]
149+
},
150+
{
151+
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
152+
f"numpy=={np.__version__}",
153+
f"cloudpickle=={cloudpickle.__version__}",
154+
]
155+
},
156+
{
157+
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
158+
f"numpy=={np.__version__}",
159+
f"cloudpickle=={cloudpickle.__version__}",
160+
]
161+
},
162+
{
163+
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
164+
f"numpy=={np.__version__}",
165+
f"cloudpickle=={cloudpickle.__version__}",
166+
]
167+
},
168+
],
169+
) as mock_any_serializer_serialize:
170+
yield mock_any_serializer_serialize
171+
172+
173+
_TEST_PROJECT = "test-project"
174+
_TEST_LOCATION = "us-central1"
175+
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
176+
_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345"
177+
_TEST_BUCKET_NAME = "gs://test_bucket"
178+
_TEST_BASE_OUTPUT_DIR = f"{_TEST_BUCKET_NAME}/test_base_output_dir"
179+
180+
_TEST_INPUTS = [
181+
"--arg_0=string_val_0",
182+
"--arg_1=string_val_1",
183+
"--arg_2=int_val_0",
184+
"--arg_3=int_val_1",
185+
]
186+
_TEST_IMAGE_URI = "test_image_uri"
187+
_TEST_MACHINE_TYPE = "test_machine_type"
188+
_TEST_WORKER_POOL_SPEC = [
189+
{
190+
"machine_spec": {
191+
"machine_type": _TEST_MACHINE_TYPE,
192+
},
193+
"replica_count": 1,
194+
"container_spec": {
195+
"image_uri": _TEST_IMAGE_URI,
196+
"args": _TEST_INPUTS,
197+
},
198+
}
199+
]
200+
_TEST_CUSTOM_JOB_PROTO = gca_custom_job.CustomJob(
201+
display_name=_TEST_DISPLAY_NAME,
202+
job_spec={
203+
"worker_pool_specs": _TEST_WORKER_POOL_SPEC,
204+
"base_output_directory": gca_io.GcsDestination(
205+
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
206+
),
207+
},
208+
labels={"trained_by_vertex_ai": "true"},
209+
)
210+
211+
212+
@pytest.fixture
213+
def mock_get_custom_job_pending():
214+
with mock.patch.object(
215+
job_service_client.JobServiceClient, "get_custom_job"
216+
) as mock_get_custom_job:
217+
218+
mock_get_custom_job.side_effect = [
219+
gca_custom_job.CustomJob(
220+
name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
221+
state=gca_job_state.JobState.JOB_STATE_RUNNING,
222+
display_name=_TEST_DISPLAY_NAME,
223+
job_spec={
224+
"worker_pool_specs": _TEST_WORKER_POOL_SPEC,
225+
"base_output_directory": gca_io.GcsDestination(
226+
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
227+
),
228+
},
229+
labels={"trained_by_vertex_ai": "true"},
230+
),
231+
gca_custom_job.CustomJob(
232+
name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
233+
state=gca_job_state.JobState.JOB_STATE_SUCCEEDED,
234+
display_name=_TEST_DISPLAY_NAME,
235+
job_spec={
236+
"worker_pool_specs": _TEST_WORKER_POOL_SPEC,
237+
"base_output_directory": gca_io.GcsDestination(
238+
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
239+
),
240+
},
241+
labels={"trained_by_vertex_ai": "true"},
242+
),
243+
]
244+
yield mock_get_custom_job
245+
246+
247+
@pytest.fixture
248+
def mock_get_custom_job_failed():
249+
with mock.patch.object(
250+
job_service_client.JobServiceClient, "get_custom_job"
251+
) as mock_get_custom_job:
252+
custom_job_proto = _TEST_CUSTOM_JOB_PROTO
253+
custom_job_proto.name = _TEST_CUSTOM_JOB_RESOURCE_NAME
254+
custom_job_proto.state = gca_job_state.JobState.JOB_STATE_FAILED
255+
mock_get_custom_job.return_value = custom_job_proto
256+
yield mock_get_custom_job
257+
258+
126259
@pytest.mark.usefixtures("google_auth_mock")
127260
class TestModelUtils:
128261
def setup_method(self):
@@ -289,3 +422,72 @@ def test_local_model_from_pretrained_fail(self):
289422

290423
with pytest.raises(ValueError):
291424
vertexai.preview.from_pretrained(model_name=_MODEL_RESOURCE_NAME)
425+
426+
@pytest.mark.usefixtures(
427+
"mock_get_vertex_model",
428+
"mock_get_custom_job_succeeded",
429+
)
430+
def test_custom_job_from_pretrained_succeed(self, mock_deserialize_model):
431+
vertexai.init(
432+
project=_TEST_PROJECT,
433+
location=_TEST_LOCATION,
434+
staging_bucket=_TEST_BUCKET,
435+
)
436+
437+
local_model = vertexai.preview.from_pretrained(
438+
custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME
439+
)
440+
assert local_model == _SKLEARN_MODEL
441+
assert 2 == mock_deserialize_model.call_count
442+
443+
mock_deserialize_model.assert_has_calls(
444+
calls=[
445+
mock.call(
446+
f"{_TEST_BASE_OUTPUT_DIR}/output/output_estimator",
447+
),
448+
],
449+
any_order=True,
450+
)
451+
452+
@pytest.mark.usefixtures(
453+
"mock_get_vertex_model",
454+
"mock_get_custom_job_pending",
455+
"mock_cloud_logging_list_entries",
456+
)
457+
def test_custom_job_from_pretrained_logs_and_blocks_until_complete_on_pending_job(
458+
self, mock_deserialize_model
459+
):
460+
vertexai.init(
461+
project=_TEST_PROJECT,
462+
location=_TEST_LOCATION,
463+
staging_bucket=_TEST_BUCKET,
464+
)
465+
466+
local_model = vertexai.preview.from_pretrained(
467+
custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME
468+
)
469+
assert local_model == _SKLEARN_MODEL
470+
assert 2 == mock_deserialize_model.call_count
471+
472+
mock_deserialize_model.assert_has_calls(
473+
calls=[
474+
mock.call(
475+
f"{_TEST_BASE_OUTPUT_DIR}/output/output_estimator",
476+
),
477+
],
478+
any_order=True,
479+
)
480+
481+
@pytest.mark.usefixtures("mock_get_vertex_model", "mock_get_custom_job_failed")
482+
def test_custom_job_from_pretrained_fails_on_errored_job(self):
483+
vertexai.init(
484+
project=_TEST_PROJECT,
485+
location=_TEST_LOCATION,
486+
staging_bucket=_TEST_BUCKET,
487+
)
488+
489+
with pytest.raises(ValueError) as err_msg:
490+
vertexai.preview.from_pretrained(
491+
custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME
492+
)
493+
assert "did not complete" in err_msg

tests/unit/vertexai/test_remote_training.py

+13
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ def _get_custom_job_proto(
388388
env.append(
389389
{"name": metadata_constants.ENV_EXPERIMENT_RUN_KEY, "value": experiment_run}
390390
)
391+
job.labels = ({"trained_by_vertex_ai": "true"},)
391392
return job
392393

393394

@@ -480,6 +481,12 @@ def mock_any_serializer_serialize_sklearn():
480481
f"cloudpickle=={cloudpickle.__version__}",
481482
]
482483
},
484+
{
485+
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
486+
f"numpy=={np.__version__}",
487+
f"cloudpickle=={cloudpickle.__version__}",
488+
]
489+
},
483490
],
484491
) as mock_any_serializer_serialize:
485492
yield mock_any_serializer_serialize
@@ -557,6 +564,12 @@ def mock_any_serializer_serialize_keras():
557564
f"cloudpickle=={cloudpickle.__version__}",
558565
]
559566
},
567+
{
568+
serializers_base.SERIALIZATION_METADATA_DEPENDENCIES_KEY: [
569+
f"numpy=={np.__version__}",
570+
f"cloudpickle=={cloudpickle.__version__}",
571+
]
572+
},
560573
],
561574
) as mock_any_serializer_serialize:
562575
yield mock_any_serializer_serialize

vertexai/preview/_workflow/driver/__init__.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def invoke(self, invokable: shared._Invokable) -> Any:
241241
):
242242
rewrapper = _unwrapper(invokable.instance)
243243

244-
result = self._launch(invokable)
244+
result = self._launch(invokable, rewrapper)
245245

246246
# rewrap the original instance
247247
if rewrapper and invokable.instance is not None:
@@ -255,12 +255,14 @@ def invoke(self, invokable: shared._Invokable) -> Any:
255255

256256
return result
257257

258-
def _launch(self, invokable: shared._Invokable) -> Any:
258+
def _launch(self, invokable: shared._Invokable, rewrapper: Any) -> Any:
259259
"""
260260
Launches an invokable.
261261
"""
262262
return self._launcher.launch(
263-
invokable=invokable, global_remote=vertexai.preview.global_config.remote
263+
invokable=invokable,
264+
global_remote=vertexai.preview.global_config.remote,
265+
rewrapper=rewrapper,
264266
)
265267

266268

vertexai/preview/_workflow/executor/__init__.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,18 @@ def local_execute(self, invokable: shared._Invokable) -> Any:
3737
*invokable.bound_arguments.args, **invokable.bound_arguments.kwargs
3838
)
3939

40-
def remote_execute(self, invokable: shared._Invokable) -> Any:
40+
def remote_execute(self, invokable: shared._Invokable, rewrapper: Any) -> Any:
4141
if invokable.remote_executor not in (
4242
remote_container_training.train,
4343
training.remote_training,
4444
prediction.remote_prediction,
4545
):
4646
raise ValueError(f"{invokable.remote_executor} is not supported.")
4747

48-
return invokable.remote_executor(invokable)
48+
if invokable.remote_executor == remote_container_training.train:
49+
invokable.remote_executor(invokable)
50+
else:
51+
return invokable.remote_executor(invokable, rewrapper=rewrapper)
4952

5053

5154
_workflow_executor = _WorkflowExecutor()

vertexai/preview/_workflow/executor/prediction.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
#
15+
from typing import Any
16+
1517
from vertexai.preview._workflow import (
1618
shared,
1719
)
@@ -20,9 +22,9 @@
2022
)
2123

2224

23-
def remote_prediction(invokable: shared._Invokable):
25+
def remote_prediction(invokable: shared._Invokable, rewrapper: Any):
2426
"""Wrapper function that makes a method executable by Vertex CustomJob."""
25-
predictions = training.remote_training(invokable=invokable)
27+
predictions = training.remote_training(invokable=invokable, rewrapper=rewrapper)
2628
return predictions
2729

2830

0 commit comments

Comments (0)