Skip to content

Commit 73490b2

Browse files
jsondai authored and copybara-github committed
feat: Allow customizing pipeline caching options for model evaluation jobs.
PiperOrigin-RevId: 673540795
1 parent ef80003 commit 73490b2

File tree

4 files changed

+163
-0
lines changed

4 files changed

+163
-0
lines changed

google/cloud/aiplatform/_pipeline_based_service/pipeline_based_service.py

+12
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def _create_and_submit_pipeline_job(
269269
location: Optional[str] = None,
270270
credentials: Optional[auth_credentials.Credentials] = None,
271271
experiment: Optional[Union[str, "aiplatform.Experiment"]] = None,
272+
enable_caching: Optional[bool] = None,
272273
) -> "_VertexAiPipelineBasedService":
273274
"""Create a new PipelineJob using the provided template and parameters.
274275
@@ -310,6 +311,16 @@ def _create_and_submit_pipeline_job(
310311
experiment (Union[str, experiments_resource.Experiment]):
311312
Optional. The Vertex AI experiment name or instance to associate
312313
to the PipelineJob executing this model evaluation job.
314+
enable_caching (bool):
315+
Optional. Whether to turn on caching for the run.
316+
317+
If this is not set, defaults to the compile time settings, which
318+
are True for all tasks by default, while users may specify
319+
different caching options for individual tasks.
320+
321+
If this is set, the setting applies to all tasks in the pipeline.
322+
323+
Overrides the compile time settings.
313324
Returns:
314325
(VertexAiPipelineBasedService):
315326
Instantiated representation of a Vertex AI Pipeline based service.
@@ -334,6 +345,7 @@ def _create_and_submit_pipeline_job(
334345
project=project,
335346
location=location,
336347
credentials=credentials,
348+
enable_caching=enable_caching,
337349
)
338350

339351
# Suppresses logs from PipelineJob

google/cloud/aiplatform/model_evaluation/model_evaluation_job.py

+12
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def submit(
174174
location: Optional[str] = None,
175175
credentials: Optional[auth_credentials.Credentials] = None,
176176
experiment: Optional[Union[str, "aiplatform.Experiment"]] = None,
177+
enable_caching: Optional[bool] = None,
177178
) -> "_ModelEvaluationJob":
178179
"""Submits a Model Evaluation Job using aiplatform.PipelineJob and returns
179180
the ModelEvaluationJob resource.
@@ -277,6 +278,16 @@ def submit(
277278
experiment (Union[str, experiments_resource.Experiment]):
278279
Optional. The Vertex AI experiment name or instance to associate to the PipelineJob executing
279280
this model evaluation job.
281+
enable_caching (bool):
282+
Optional. Whether to turn on caching for the run.
283+
284+
If this is not set, defaults to the compile time settings, which
285+
are True for all tasks by default, while users may specify
286+
different caching options for individual tasks.
287+
288+
If this is set, the setting applies to all tasks in the pipeline.
289+
290+
Overrides the compile time settings.
280291
Returns:
281292
(ModelEvaluationJob): Instantiated representation of the model evaluation job.
282293
"""
@@ -351,6 +362,7 @@ def submit(
351362
location=location,
352363
credentials=credentials,
353364
experiment=experiment,
365+
enable_caching=enable_caching,
354366
)
355367

356368
_LOGGER.info(

google/cloud/aiplatform/models.py

+12
Original file line numberDiff line numberDiff line change
@@ -6883,6 +6883,7 @@ def evaluate(
68836883
network: Optional[str] = None,
68846884
encryption_spec_key_name: Optional[str] = None,
68856885
experiment: Optional[Union[str, "aiplatform.Experiment"]] = None,
6886+
enable_caching: Optional[bool] = None,
68866887
) -> "model_evaluation._ModelEvaluationJob":
68876888
"""Creates a model evaluation job running on Vertex Pipelines and returns the resulting
68886889
ModelEvaluationJob resource.
@@ -6968,6 +6969,16 @@ def evaluate(
69686969
this model evaluation job. Metrics produced by the PipelineJob as system.Metric Artifacts
69696970
will be associated as metrics to the provided experiment, and parameters from this PipelineJob
69706971
will be associated as parameters to the provided experiment.
6972+
enable_caching (bool):
6973+
Optional. Whether to turn on caching for the run.
6974+
6975+
If this is not set, defaults to the compile time settings, which
6976+
are True for all tasks by default, while users may specify
6977+
different caching options for individual tasks.
6978+
6979+
If this is set, the setting applies to all tasks in the pipeline.
6980+
6981+
Overrides the compile time settings.
69716982
Returns:
69726983
model_evaluation.ModelEvaluationJob: Instantiated representation of the
69736984
_ModelEvaluationJob.
@@ -7088,6 +7099,7 @@ def evaluate(
70887099
encryption_spec_key_name=encryption_spec_key_name,
70897100
credentials=self.credentials,
70907101
experiment=experiment,
7102+
enable_caching=enable_caching,
70917103
)
70927104

70937105

tests/unit/aiplatform/test_model_evaluation.py

+127
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,39 @@
198198
}
199199
)
200200

201+
_TEST_MODEL_EVAL_PIPELINE_SPEC_WITH_CACHING_OPTIONS_JSON = json.dumps(
202+
{
203+
"pipelineInfo": {"name": "evaluation-default-pipeline"},
204+
"root": {
205+
"dag": {
206+
"tasks": {
207+
"model-evaluation-text-generation": {
208+
"taskInfo": {"name": "model-evaluation-text-generation"},
209+
"cachingOptions": {"enableCache": False},
210+
}
211+
}
212+
},
213+
"inputDefinitions": {
214+
"parameters": {
215+
"batch_predict_gcs_source_uris": {"type": "STRING"},
216+
"dataflow_service_account": {"type": "STRING"},
217+
"batch_predict_instances_format": {"type": "STRING"},
218+
"batch_predict_machine_type": {"type": "STRING"},
219+
"evaluation_class_labels": {"type": "STRING"},
220+
"location": {"type": "STRING"},
221+
"model_name": {"type": "STRING"},
222+
"project": {"type": "STRING"},
223+
"batch_predict_gcs_destination_output_uri": {"type": "STRING"},
224+
"target_field_name": {"type": "STRING"},
225+
}
226+
},
227+
},
228+
"schemaVersion": "2.0.0",
229+
"sdkVersion": "kfp-1.8.12",
230+
"components": {},
231+
}
232+
)
233+
201234
_TEST_INVALID_MODEL_EVAL_PIPELINE_SPEC = json.dumps(
202235
{
203236
"pipelineInfo": {"name": "my-pipeline"},
@@ -1083,6 +1116,100 @@ def test_model_evaluation_job_submit(
10831116

10841117
assert mock_model_eval_job_get.called_once
10851118

1119+
@pytest.mark.parametrize(
1120+
"job_spec",
1121+
[_TEST_MODEL_EVAL_PIPELINE_SPEC_WITH_CACHING_OPTIONS_JSON],
1122+
)
1123+
@pytest.mark.usefixtures("mock_pipeline_service_create")
1124+
def test_model_evaluation_job_submit_with_caching_disabled(
1125+
self,
1126+
job_spec,
1127+
mock_load_yaml_and_json,
1128+
mock_model,
1129+
get_model_mock,
1130+
mock_model_eval_get,
1131+
mock_model_eval_job_get,
1132+
mock_pipeline_service_get,
1133+
mock_model_eval_job_create,
1134+
mock_pipeline_bucket_exists,
1135+
mock_request_urlopen,
1136+
):
1137+
test_model_eval_job = model_evaluation_job._ModelEvaluationJob.submit(
1138+
model_name=_TEST_MODEL_RESOURCE_NAME,
1139+
prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE,
1140+
instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[
1141+
"batch_predict_instances_format"
1142+
],
1143+
model_type="automl_tabular",
1144+
pipeline_root=_TEST_GCS_BUCKET_NAME,
1145+
target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[
1146+
"target_field_name"
1147+
],
1148+
evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME,
1149+
gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[
1150+
"batch_predict_gcs_source_uris"
1151+
],
1152+
job_id=_TEST_PIPELINE_JOB_ID,
1153+
service_account=_TEST_SERVICE_ACCOUNT,
1154+
network=_TEST_NETWORK,
1155+
enable_caching=False,
1156+
)
1157+
1158+
test_model_eval_job.wait()
1159+
1160+
expected_runtime_config_dict = {
1161+
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
1162+
"parameters": {
1163+
"batch_predict_gcs_source_uris": {
1164+
"stringValue": '["gs://my-bucket/my-prediction-data.csv"]'
1165+
},
1166+
"dataflow_service_account": {"stringValue": _TEST_SERVICE_ACCOUNT},
1167+
"batch_predict_instances_format": {"stringValue": "csv"},
1168+
"model_name": {"stringValue": _TEST_MODEL_RESOURCE_NAME},
1169+
"project": {"stringValue": _TEST_PROJECT},
1170+
"location": {"stringValue": _TEST_LOCATION},
1171+
"batch_predict_gcs_destination_output_uri": {
1172+
"stringValue": _TEST_GCS_BUCKET_NAME
1173+
},
1174+
"target_field_name": {"stringValue": "predict_class"},
1175+
},
1176+
}
1177+
1178+
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
1179+
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
1180+
1181+
job_spec = yaml.safe_load(job_spec)
1182+
pipeline_spec = job_spec.get("pipelineSpec") or job_spec
1183+
1184+
# Construct expected request
1185+
expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob(
1186+
display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME,
1187+
pipeline_spec={
1188+
"components": {},
1189+
"pipelineInfo": pipeline_spec["pipelineInfo"],
1190+
"root": pipeline_spec["root"],
1191+
"schemaVersion": "2.0.0",
1192+
"sdkVersion": "kfp-1.8.12",
1193+
},
1194+
runtime_config=runtime_config,
1195+
service_account=_TEST_SERVICE_ACCOUNT,
1196+
network=_TEST_NETWORK,
1197+
template_uri=_TEST_KFP_TEMPLATE_URI,
1198+
)
1199+
1200+
mock_model_eval_job_create.assert_called_with(
1201+
parent=_TEST_PARENT,
1202+
pipeline_job=expected_gapic_pipeline_job,
1203+
pipeline_job_id=_TEST_PIPELINE_JOB_ID,
1204+
timeout=None,
1205+
)
1206+
1207+
assert mock_model_eval_job_get.called_once
1208+
1209+
assert mock_pipeline_service_get.called_once
1210+
1211+
assert mock_model_eval_job_get.called_once
1212+
10861213
@pytest.mark.parametrize(
10871214
"job_spec",
10881215
[_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],

0 commit comments

Comments
 (0)