
Commit f3338fc

feat: Add done method for pipeline, training, and batch prediction jobs (#1062)
Added a done method via a `DoneMixin` class to check the status of long-running jobs (returns True or False based on job state):

* Implemented by `PipelineJob`, `_Job`, and `_TrainingJob`
* Added system tests in `tests/system/aiplatform/test_e2e_tabular.py`
* Added pipeline job tests in `tests/unit/aiplatform/test_pipeline_jobs.py`
* Still need to add unit tests in `test_jobs` and `test_training_jobs`

Fixes b/215396514
1 parent 6002d5d commit f3338fc
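
For context on what the new method enables, here is a minimal usage sketch; it assumes this commit is installed, and the project, location, and pipeline template names are placeholders rather than values taken from this change:

    from google.cloud import aiplatform

    # Placeholder project and template for illustration only.
    aiplatform.init(project="my-project", location="us-central1")

    job = aiplatform.PipelineJob(
        display_name="example-pipeline",
        template_path="gs://my-bucket/pipeline_spec.json",
    )

    job.submit()          # non-blocking submission
    print(job.done())     # False while the pipeline is still running

    job.wait()            # block until the run finishes
    print(job.done())     # True once the job has reached a completed state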

File tree

8 files changed: +249 -3 lines

    google/cloud/aiplatform/base.py
    google/cloud/aiplatform/jobs.py
    google/cloud/aiplatform/pipeline_jobs.py
    google/cloud/aiplatform/training_jobs.py
    tests/system/aiplatform/test_e2e_tabular.py
    tests/unit/aiplatform/test_jobs.py
    tests/unit/aiplatform/test_pipeline_jobs.py
    tests/unit/aiplatform/test_training_jobs.py

google/cloud/aiplatform/base.py

+55 lines

@@ -1220,3 +1220,58 @@ def get_annotation_class(annotation: type) -> type:
         return annotation.__args__[0]
     else:
         return annotation
+
+
+class DoneMixin(abc.ABC):
+    """An abstract class for implementing a done method, indicating
+    whether a job has completed.
+
+    """
+
+    @abc.abstractmethod
+    def done(self) -> bool:
+        """Method indicating whether a job has completed."""
+        pass
+
+
+class StatefulResource(DoneMixin):
+    """Extends DoneMixin to check whether a job returning a stateful resource has completed."""
+
+    @property
+    @abc.abstractmethod
+    def state(self):
+        """The current state of the job."""
+        pass
+
+    @property
+    @classmethod
+    @abc.abstractmethod
+    def _valid_done_states(cls):
+        """A set() containing all job states associated with a completed job."""
+        pass
+
+    def done(self) -> bool:
+        """Method indicating whether a job has completed.
+
+        Returns:
+            True if the job has completed.
+        """
+        if self.state in self._valid_done_states:
+            return True
+        else:
+            return False
+
+
+class VertexAiStatefulResource(VertexAiResourceNounWithFutureManager, StatefulResource):
+    """Extends StatefulResource to include a check for self._gca_resource."""
+
+    def done(self) -> bool:
+        """Method indicating whether a job has completed.
+
+        Returns:
+            True if the job has completed.
+        """
+        if self._gca_resource and self._gca_resource.name:
+            return super().done()
+        else:
+            return False
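
To make the control flow above concrete: a StatefulResource subclass only has to expose a `state` property and a class-level `_valid_done_states` collection, and `done()` falls out of comparing the two. A minimal, self-contained sketch of that pattern; `FakeJob` and its string states are stand-ins for illustration, not the real Vertex AI JobState/PipelineState enums:

    import abc


    class SketchStatefulResource(abc.ABC):
        """Simplified restatement of the StatefulResource pattern above."""

        _valid_done_states = frozenset()  # subclasses list their terminal states here

        @property
        @abc.abstractmethod
        def state(self):
            """Current state of the resource."""

        def done(self) -> bool:
            # A job is "done" exactly when its state is one of the terminal states.
            return self.state in self._valid_done_states


    class FakeJob(SketchStatefulResource):
        # Hypothetical states; the SDK uses proto enum values instead.
        _valid_done_states = frozenset({"SUCCEEDED", "FAILED", "CANCELLED"})

        def __init__(self, state):
            self._state = state

        @property
        def state(self):
            return self._state


    assert FakeJob("RUNNING").done() is False
    assert FakeJob("SUCCEEDED").done() is True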

google/cloud/aiplatform/jobs.py

+4 -1 lines

@@ -66,7 +66,7 @@
 )


-class _Job(base.VertexAiResourceNounWithFutureManager):
+class _Job(base.VertexAiStatefulResource):
     """Class that represents a general Job resource in Vertex AI.
     Cannot be directly instantiated.

@@ -83,6 +83,9 @@ class _Job(base.VertexAiResourceNounWithFutureManager):

     client_class = utils.JobClientWithOverride

+    # Required by the done() method
+    _valid_done_states = _JOB_COMPLETE_STATES
+
     def __init__(
         self,
         job_name: str,

google/cloud/aiplatform/pipeline_jobs.py

+4 -1 lines

@@ -77,7 +77,7 @@ def _set_enable_caching_value(
                 task["cachingOptions"] = {"enableCache": enable_caching}


-class PipelineJob(base.VertexAiResourceNounWithFutureManager):
+class PipelineJob(base.VertexAiStatefulResource):

     client_class = utils.PipelineJobClientWithOverride
     _resource_noun = "pipelineJobs"
@@ -87,6 +87,9 @@ class PipelineJob(base.VertexAiResourceNounWithFutureManager):
     _parse_resource_name_method = "parse_pipeline_job_path"
     _format_resource_name_method = "pipeline_job_path"

+    # Required by the done() method
+    _valid_done_states = _PIPELINE_COMPLETE_STATES
+
     def __init__(
         self,
         display_name: str,

google/cloud/aiplatform/training_jobs.py

+4 -1 lines

@@ -66,7 +66,7 @@
 )


-class _TrainingJob(base.VertexAiResourceNounWithFutureManager):
+class _TrainingJob(base.VertexAiStatefulResource):

     client_class = utils.PipelineClientWithOverride
     _resource_noun = "trainingPipelines"
@@ -76,6 +76,9 @@ class _TrainingJob(base.VertexAiResourceNounWithFutureManager):
     _parse_resource_name_method = "parse_training_pipeline_path"
     _format_resource_name_method = "training_pipeline_path"

+    # Required by the done() method
+    _valid_done_states = _PIPELINE_COMPLETE_STATES
+
     def __init__(
         self,
         display_name: str,
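
With _TrainingJob now exposing done(), a caller can launch training with sync=False and poll for completion instead of blocking on run(). A rough sketch of that usage under this commit; the project, bucket, script path, and container image below are placeholders, not values from this change:

    import time

    from google.cloud import aiplatform

    # Placeholder project and staging bucket for illustration only.
    aiplatform.init(
        project="my-project",
        location="us-central1",
        staging_bucket="gs://my-staging-bucket",
    )

    job = aiplatform.CustomTrainingJob(
        display_name="example-training-job",
        script_path="train.py",  # hypothetical local training script
        container_uri="gcr.io/my-project/my-training-image:latest",  # placeholder image
    )

    job.run(sync=False)               # returns immediately
    job.wait_for_resource_creation()  # done() reports False until the resource exists

    while not job.done():             # poll the terminal-state check added here
        time.sleep(60)

    print("Training job reached a terminal state")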

tests/system/aiplatform/test_e2e_tabular.py

+8 lines

@@ -142,7 +142,9 @@ def test_end_to_end_tabular(self, shared_state):

         shared_state["resources"].append(custom_batch_prediction_job)

+        in_progress_done_check = custom_job.done()
         custom_job.wait_for_resource_creation()
+
         automl_job.wait_for_resource_creation()
         custom_batch_prediction_job.wait_for_resource_creation()

@@ -174,6 +176,8 @@ def test_end_to_end_tabular(self, shared_state):
         # Test lazy loading of Endpoint, check getter was never called after predict()
         custom_endpoint = aiplatform.Endpoint(custom_endpoint.resource_name)
         custom_endpoint.predict([_INSTANCE])
+
+        completion_done_check = custom_job.done()
         assert custom_endpoint._skipped_getter_call()

         assert (
@@ -201,3 +205,7 @@ def test_end_to_end_tabular(self, shared_state):
             assert 200000 > custom_result > 50000
         except KeyError as e:
             raise RuntimeError("Unexpected prediction response structure:", e)
+
+        # Check done() method works correctly
+        assert in_progress_done_check is False
+        assert completion_done_check is True

tests/unit/aiplatform/test_jobs.py

+29 lines

@@ -401,6 +401,14 @@ def test_batch_prediction_job_status(self, get_batch_prediction_job_mock):
             name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=base._DEFAULT_RETRY
         )

+    def test_batch_prediction_job_done_get(self, get_batch_prediction_job_mock):
+        bp = jobs.BatchPredictionJob(
+            batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
+        )
+
+        assert bp.done() is False
+        assert get_batch_prediction_job_mock.call_count == 2
+
     @pytest.mark.usefixtures("get_batch_prediction_job_gcs_output_mock")
     def test_batch_prediction_iter_dirs_gcs(self, storage_list_blobs_mock):
         bp = jobs.BatchPredictionJob(
@@ -507,6 +515,27 @@ def test_batch_predict_gcs_source_and_dest(
             batch_prediction_job=expected_gapic_batch_prediction_job,
         )

+    @pytest.mark.usefixtures("get_batch_prediction_job_mock")
+    def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock):
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        # Make SDK batch_predict method call
+        batch_prediction_job = jobs.BatchPredictionJob.create(
+            model_name=_TEST_MODEL_NAME,
+            job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
+            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
+            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
+            sync=False,
+        )
+
+        batch_prediction_job.wait_for_resource_creation()
+
+        assert batch_prediction_job.done() is False
+
+        batch_prediction_job.wait()
+
+        assert batch_prediction_job.done() is True
+
     @pytest.mark.parametrize("sync", [True, False])
     @pytest.mark.usefixtures("get_batch_prediction_job_mock")
     def test_batch_predict_gcs_source_bq_dest(

tests/unit/aiplatform/test_pipeline_jobs.py

+33 lines

@@ -532,6 +532,39 @@ def test_submit_call_pipeline_service_pipeline_job_create(
             gca_pipeline_state_v1.PipelineState.PIPELINE_STATE_SUCCEEDED
         )

+    @pytest.mark.parametrize(
+        "job_spec_json", [_TEST_PIPELINE_SPEC, _TEST_PIPELINE_JOB],
+    )
+    def test_done_method_pipeline_service(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        job_spec_json,
+        mock_load_json,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_GCS_BUCKET_NAME,
+            location=_TEST_LOCATION,
+            credentials=_TEST_CREDENTIALS,
+        )
+
+        job = pipeline_jobs.PipelineJob(
+            display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
+            template_path=_TEST_TEMPLATE_PATH,
+            job_id=_TEST_PIPELINE_JOB_ID,
+            parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
+            enable_caching=True,
+        )
+
+        job.submit(service_account=_TEST_SERVICE_ACCOUNT, network=_TEST_NETWORK)
+
+        assert job.done() is False
+
+        job.wait()
+
+        assert job.done() is True
+
     @pytest.mark.parametrize(
         "job_spec_json", [_TEST_PIPELINE_SPEC_LEGACY, _TEST_PIPELINE_JOB_LEGACY],
     )

tests/unit/aiplatform/test_training_jobs.py

+112 lines

@@ -1007,6 +1007,65 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(

         assert job._has_logged_custom_job

+    def test_custom_training_tabular_done(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        mock_python_package_to_gcs,
+        mock_tabular_dataset,
+        mock_model_service_get,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_BUCKET_NAME,
+            credentials=_TEST_CREDENTIALS,
+            encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
+        )
+
+        job = training_jobs.CustomTrainingJob(
+            display_name=_TEST_DISPLAY_NAME,
+            labels=_TEST_LABELS,
+            script_path=_TEST_LOCAL_SCRIPT_FILE_NAME,
+            container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
+            model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
+            model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
+            model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
+            model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
+            model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
+            model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
+            model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
+            model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
+            model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
+            model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
+            model_description=_TEST_MODEL_DESCRIPTION,
+        )
+
+        job.run(
+            dataset=mock_tabular_dataset,
+            base_output_dir=_TEST_BASE_OUTPUT_DIR,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            network=_TEST_NETWORK,
+            args=_TEST_RUN_ARGS,
+            environment_variables=_TEST_ENVIRONMENT_VARIABLES,
+            machine_type=_TEST_MACHINE_TYPE,
+            accelerator_type=_TEST_ACCELERATOR_TYPE,
+            accelerator_count=_TEST_ACCELERATOR_COUNT,
+            model_display_name=_TEST_MODEL_DISPLAY_NAME,
+            model_labels=_TEST_MODEL_LABELS,
+            training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
+            validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
+            test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
+            timestamp_split_column_name=_TEST_TIMESTAMP_SPLIT_COLUMN_NAME,
+            tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
+            sync=False,
+        )
+
+        assert job.done() is False
+
+        job.wait()
+
+        assert job.done() is True
+
     @pytest.mark.parametrize("sync", [True, False])
     def test_run_call_pipeline_service_create_with_bigquery_destination(
         self,
@@ -2323,6 +2382,59 @@ def setup_method(self):
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)

+    def test_custom_container_training_tabular_done(
+        self,
+        mock_pipeline_service_create,
+        mock_pipeline_service_get,
+        mock_tabular_dataset,
+        mock_model_service_get,
+    ):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            staging_bucket=_TEST_BUCKET_NAME,
+            encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
+        )
+
+        job = training_jobs.CustomContainerTrainingJob(
+            display_name=_TEST_DISPLAY_NAME,
+            labels=_TEST_LABELS,
+            container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
+            command=_TEST_TRAINING_CONTAINER_CMD,
+            model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
+            model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
+            model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
+            model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
+            model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
+            model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
+            model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
+            model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
+            model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
+            model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
+            model_description=_TEST_MODEL_DESCRIPTION,
+        )
+
+        job.run(
+            dataset=mock_tabular_dataset,
+            base_output_dir=_TEST_BASE_OUTPUT_DIR,
+            args=_TEST_RUN_ARGS,
+            environment_variables=_TEST_ENVIRONMENT_VARIABLES,
+            machine_type=_TEST_MACHINE_TYPE,
+            accelerator_type=_TEST_ACCELERATOR_TYPE,
+            accelerator_count=_TEST_ACCELERATOR_COUNT,
+            model_display_name=_TEST_MODEL_DISPLAY_NAME,
+            model_labels=_TEST_MODEL_LABELS,
+            predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
+            sync=False,
+        )
+
+        assert job.done() is False
+
+        job.wait()
+
+        assert job.done() is True
+
     @pytest.mark.parametrize("sync", [True, False])
     def test_run_call_pipeline_service_create_with_tabular_dataset(
         self,
