
Commit 5fe59a4

feat: add additional_experiment flag in the tables and forecasting training job (#979)
* Update training_jobs.py
* Update test_automl_forecasting_training_jobs.py
* Update training_jobs.py
* Update test_automl_tabular_training_jobs.py
* Update test_automl_forecasting_training_jobs.py
* Update test_automl_tabular_training_jobs.py
* Update google/cloud/aiplatform/training_jobs.py

Co-authored-by: sasha-gitg <[email protected]>

* Update google/cloud/aiplatform/training_jobs.py

Co-authored-by: sasha-gitg <[email protected]>

* Update test_automl_forecasting_training_jobs.py
* Update test_automl_tabular_training_jobs.py
* Update training_jobs.py
* Update training_jobs.py

Co-authored-by: sasha-gitg <[email protected]>
1 parent 5ee6354 commit 5fe59a4

File tree

3 files changed: +21 -264 lines changed


google/cloud/aiplatform/training_jobs.py (+12 -196)

@@ -3371,6 +3371,7 @@ def run(
         export_evaluated_data_items: bool = False,
         export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
         export_evaluated_data_items_override_destination: bool = False,
+        additional_experiments: Optional[List[str]] = None,
         sync: bool = True,
     ) -> models.Model:
         """Runs the training job and returns a model.
@@ -3497,6 +3498,8 @@ def run(
                 Applies only if [export_evaluated_data_items] is True and
                 [export_evaluated_data_items_bigquery_destination_uri] is specified.
+            additional_experiments (List[str]):
+                Optional. Additional experiment flags for the automl tables training.
             sync (bool):
                 Whether to execute this method synchronously. If False, this method
                 will be executed in concurrent Future and any downstream object will
                 be immediately returned and synced when the Future has completed.
@@ -3519,6 +3522,9 @@ def run(
         if self._has_run:
             raise RuntimeError("AutoML Tabular Training has already run.")
 
+        if additional_experiments:
+            self._add_additional_experiments(additional_experiments)
+
         return self._run(
             dataset=dataset,
             target_column=target_column,
@@ -3961,6 +3967,7 @@ def run(
         budget_milli_node_hours: int = 1000,
         model_display_name: Optional[str] = None,
         model_labels: Optional[Dict[str, str]] = None,
+        additional_experiments: Optional[List[str]] = None,
         sync: bool = True,
     ) -> models.Model:
         """Runs the training job and returns a model.
@@ -4107,6 +4114,8 @@ def run(
                 are allowed.
                 See https://goo.gl/xmQnxf for more information
                 and examples of labels.
+            additional_experiments (List[str]):
+                Optional. Additional experiment flags for the time series forcasting training.
             sync (bool):
                 Whether to execute this method synchronously. If False, this method
                 will be executed in concurrent Future and any downstream object will
@@ -4132,6 +4141,9 @@ def run(
         if self._has_run:
             raise RuntimeError("AutoML Forecasting Training has already run.")
 
+        if additional_experiments:
+            self._add_additional_experiments(additional_experiments)
+
         return self._run(
             dataset=dataset,
             target_column=target_column,
@@ -4160,202 +4172,6 @@ def run(
             sync=sync,
         )
 
-    def _run_with_experiments(
-        self,
-        dataset: datasets.TimeSeriesDataset,
-        target_column: str,
-        time_column: str,
-        time_series_identifier_column: str,
-        unavailable_at_forecast_columns: List[str],
-        available_at_forecast_columns: List[str],
-        forecast_horizon: int,
-        data_granularity_unit: str,
-        data_granularity_count: int,
-        predefined_split_column_name: Optional[str] = None,
-        weight_column: Optional[str] = None,
-        time_series_attribute_columns: Optional[List[str]] = None,
-        context_window: Optional[int] = None,
-        export_evaluated_data_items: bool = False,
-        export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
-        export_evaluated_data_items_override_destination: bool = False,
-        quantiles: Optional[List[float]] = None,
-        validation_options: Optional[str] = None,
-        budget_milli_node_hours: int = 1000,
-        model_display_name: Optional[str] = None,
-        model_labels: Optional[Dict[str, str]] = None,
-        sync: bool = True,
-        additional_experiments: Optional[List[str]] = None,
-    ) -> models.Model:
-        """Runs the training job with experiment flags and returns a model.
-
-        The training data splits are set by default: Roughly 80% will be used for training,
-        10% for validation, and 10% for test.
-
-        Args:
-            dataset (datasets.TimeSeriesDataset):
-                Required. The dataset within the same Project from which data will be used to train the Model. The
-                Dataset must use schema compatible with Model being trained,
-                and what is compatible should be described in the used
-                TrainingPipeline's [training_task_definition]
-                [google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
-                For time series Datasets, all their data is exported to
-                training, to pick and choose from.
-            target_column (str):
-                Required. Name of the column that the Model is to predict values for.
-            time_column (str):
-                Required. Name of the column that identifies time order in the time series.
-            time_series_identifier_column (str):
-                Required. Name of the column that identifies the time series.
-            unavailable_at_forecast_columns (List[str]):
-                Required. Column names of columns that are unavailable at forecast.
-                Each column contains information for the given entity (identified by the
-                [time_series_identifier_column]) that is unknown before the forecast
-                (e.g. population of a city in a given year, or weather on a given day).
-            available_at_forecast_columns (List[str]):
-                Required. Column names of columns that are available at forecast.
-                Each column contains information for the given entity (identified by the
-                [time_series_identifier_column]) that is known at forecast.
-            forecast_horizon: (int):
-                Required. The amount of time into the future for which forecasted values for the target are
-                returned. Expressed in number of units defined by the [data_granularity_unit] and
-                [data_granularity_count] field. Inclusive.
-            data_granularity_unit (str):
-                Required. The data granularity unit. Accepted values are ``minute``,
-                ``hour``, ``day``, ``week``, ``month``, ``year``.
-            data_granularity_count (int):
-                Required. The number of data granularity units between data points in the training
-                data. If [data_granularity_unit] is `minute`, can be 1, 5, 10, 15, or 30. For all other
-                values of [data_granularity_unit], must be 1.
-            predefined_split_column_name (str):
-                Optional. The key is a name of one of the Dataset's data
-                columns. The value of the key (either the label's value or
-                value in the column) must be one of {``TRAIN``,
-                ``VALIDATE``, ``TEST``}, and it defines to which set the
-                given piece of data is assigned. If for a piece of data the
-                key is not present or has an invalid value, that piece is
-                ignored by the pipeline.
-
-                Supported only for tabular and time series Datasets.
-            weight_column (str):
-                Optional. Name of the column that should be used as the weight column.
-                Higher values in this column give more importance to the row
-                during Model training. The column must have numeric values between 0 and
-                10000 inclusively, and 0 value means that the row is ignored.
-                If the weight column field is not set, then all rows are assumed to have
-                equal weight of 1.
-            time_series_attribute_columns (List[str]):
-                Optional. Column names that should be used as attribute columns.
-                Each column is constant within a time series.
-            context_window (int):
-                Optional. The amount of time into the past training and prediction data is used for
-                model training and prediction respectively. Expressed in number of units defined by the
-                [data_granularity_unit] and [data_granularity_count] fields. When not provided uses the
-                default value of 0 which means the model sets each series context window to be 0 (also
-                known as "cold start"). Inclusive.
-            export_evaluated_data_items (bool):
-                Whether to export the test set predictions to a BigQuery table.
-                If False, then the export is not performed.
-            export_evaluated_data_items_bigquery_destination_uri (string):
-                Optional. URI of desired destination BigQuery table for exported test set predictions.
-
-                Expected format:
-                ``bq://<project_id>:<dataset_id>:<table>``
-
-                If not specified, then results are exported to the following auto-created BigQuery
-                table:
-                ``<project_id>:export_evaluated_examples_<model_name>_<yyyy_MM_dd'T'HH_mm_ss_SSS'Z'>.evaluated_examples``
-
-                Applies only if [export_evaluated_data_items] is True.
-            export_evaluated_data_items_override_destination (bool):
-                Whether to override the contents of [export_evaluated_data_items_bigquery_destination_uri],
-                if the table exists, for exported test set predictions. If False, and the
-                table exists, then the training job will fail.
-
-                Applies only if [export_evaluated_data_items] is True and
-                [export_evaluated_data_items_bigquery_destination_uri] is specified.
-            quantiles (List[float]):
-                Quantiles to use for the `minizmize-quantile-loss`
-                [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
-                this case.
-
-                Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive.
-                Each quantile must be unique.
-            validation_options (str):
-                Validation options for the data validation component. The available options are:
-                "fail-pipeline" - (default), will validate against the validation and fail the pipeline
-                if it fails.
-                "ignore-validation" - ignore the results of the validation and continue the pipeline
-            budget_milli_node_hours (int):
-                Optional. The train budget of creating this Model, expressed in milli node
-                hours i.e. 1,000 value in this field means 1 node hour.
-                The training cost of the model will not exceed this budget. The final
-                cost will be attempted to be close to the budget, though may end up
-                being (even) noticeably smaller - at the backend's discretion. This
-                especially may happen when further model training ceases to provide
-                any improvements.
-                If the budget is set to a value known to be insufficient to train a
-                Model for the given training set, the training won't be attempted and
-                will error.
-                The minimum value is 1000 and the maximum is 72000.
-            model_display_name (str):
-                Optional. If the script produces a managed Vertex AI Model. The display name of
-                the Model. The name can be up to 128 characters long and can be consist
-                of any UTF-8 characters.
-
-                If not provided upon creation, the job's display_name is used.
-            model_labels (Dict[str, str]):
-                Optional. The labels with user-defined metadata to
-                organize your Models.
-                Label keys and values can be no longer than 64
-                characters (Unicode codepoints), can only
-                contain lowercase letters, numeric characters,
-                underscores and dashes. International characters
-                are allowed.
-                See https://goo.gl/xmQnxf for more information
-                and examples of labels.
-            sync (bool):
-                Whether to execute this method synchronously. If False, this method
-                will be executed in concurrent Future and any downstream object will
-                be immediately returned and synced when the Future has completed.
-            additional_experiments (List[str]):
-                Additional experiment flags for the time series forcasting training.
-
-        Returns:
-            model: The trained Vertex AI Model resource or None if training did not
-                produce a Vertex AI Model.
-
-        Raises:
-            RuntimeError: If Training job has already been run or is waiting to run.
-        """
-
-        if additional_experiments:
-            self._add_additional_experiments(additional_experiments)
-
-        return self.run(
-            dataset=dataset,
-            target_column=target_column,
-            time_column=time_column,
-            time_series_identifier_column=time_series_identifier_column,
-            unavailable_at_forecast_columns=unavailable_at_forecast_columns,
-            available_at_forecast_columns=available_at_forecast_columns,
-            forecast_horizon=forecast_horizon,
-            data_granularity_unit=data_granularity_unit,
-            data_granularity_count=data_granularity_count,
-            predefined_split_column_name=predefined_split_column_name,
-            weight_column=weight_column,
-            time_series_attribute_columns=time_series_attribute_columns,
-            context_window=context_window,
-            budget_milli_node_hours=budget_milli_node_hours,
-            export_evaluated_data_items=export_evaluated_data_items,
-            export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
-            export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
-            quantiles=quantiles,
-            validation_options=validation_options,
-            model_display_name=model_display_name,
-            model_labels=model_labels,
-            sync=sync,
-        )
-
     @base.optional_sync()
     def _run(
         self,

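For orientation, here is a minimal usage sketch of the consolidated API after this change: experiment flags that previously went through the private `_run_with_experiments` helper are now passed straight to `run()`. The project, staging bucket, dataset ID, column names, objective, and experiment flag strings below are hypothetical placeholders for illustration, not values taken from this commit.

```python
from google.cloud import aiplatform

# Hypothetical project and staging bucket.
aiplatform.init(project="my-project", staging_bucket="gs://my-staging-bucket")

# Hypothetical dataset ID for an existing time series dataset.
dataset = aiplatform.TimeSeriesDataset(dataset_name="1234567890")

job = aiplatform.AutoMLForecastingTrainingJob(
    display_name="forecasting-job",
    optimization_objective="minimize-rmse",  # illustrative objective
)

# additional_experiments is now an ordinary run() argument; the flag string
# below is a placeholder for whatever experiment flags the backend accepts.
model = job.run(
    dataset=dataset,
    target_column="sales",
    time_column="date",
    time_series_identifier_column="store_id",
    unavailable_at_forecast_columns=["sales"],
    available_at_forecast_columns=["date"],
    forecast_horizon=30,
    data_granularity_unit="day",
    data_granularity_count=1,
    additional_experiments=["enable_experimental_feature_x"],
)
```
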
tests/unit/aiplatform/test_automl_forecasting_training_jobs.py (+7 -67)

@@ -91,9 +91,7 @@
     "validationOptions": _TEST_TRAINING_VALIDATION_OPTIONS,
     "optimizationObjective": _TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
 }
-_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict(
-    _TEST_TRAINING_TASK_INPUTS_DICT, struct_pb2.Value(),
-)
+
 _TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS = json_format.ParseDict(
     {
         **_TEST_TRAINING_TASK_INPUTS_DICT,
@@ -102,6 +100,10 @@
     struct_pb2.Value(),
 )
 
+_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict(
+    _TEST_TRAINING_TASK_INPUTS_DICT, struct_pb2.Value(),
+)
+
 _TEST_DATASET_NAME = "test-dataset-name"
 
 _TEST_MODEL_DISPLAY_NAME = "model-display-name"
@@ -269,6 +271,7 @@ def test_run_call_pipeline_service_create(
             export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
             quantiles=_TEST_TRAINING_QUANTILES,
             validation_options=_TEST_TRAINING_VALIDATION_OPTIONS,
+            additional_experiments=_TEST_ADDITIONAL_EXPERIMENTS,
             sync=sync,
         )
 
@@ -290,7 +293,7 @@ def test_run_call_pipeline_service_create(
             display_name=_TEST_DISPLAY_NAME,
             labels=_TEST_LABELS,
             training_task_definition=schema.training_job.definition.automl_forecasting,
-            training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
+            training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS,
             model_to_upload=true_managed_model,
             input_data_config=true_input_data_config,
         )
@@ -380,69 +383,6 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
             training_pipeline=true_training_pipeline,
         )
 
-    @pytest.mark.usefixtures("mock_pipeline_service_get")
-    @pytest.mark.parametrize("sync", [True, False])
-    def test_run_with_experiments(
-        self,
-        mock_pipeline_service_create,
-        mock_dataset_time_series,
-        mock_model_service_get,
-        sync,
-    ):
-        aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)
-
-        job = AutoMLForecastingTrainingJob(
-            display_name=_TEST_DISPLAY_NAME,
-            optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
-            column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS,
-        )
-
-        model_from_job = job._run_with_experiments(
-            dataset=mock_dataset_time_series,
-            target_column=_TEST_TRAINING_TARGET_COLUMN,
-            time_column=_TEST_TRAINING_TIME_COLUMN,
-            time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
-            unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
-            available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
-            forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON,
-            data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT,
-            data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT,
-            weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
-            time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS,
-            context_window=_TEST_TRAINING_CONTEXT_WINDOW,
-            budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
-            export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS,
-            export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
-            export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
-            quantiles=_TEST_TRAINING_QUANTILES,
-            validation_options=_TEST_TRAINING_VALIDATION_OPTIONS,
-            sync=sync,
-            additional_experiments=_TEST_ADDITIONAL_EXPERIMENTS,
-        )
-
-        if not sync:
-            model_from_job.wait()
-
-        # Test that if defaults to the job display name
-        true_managed_model = gca_model.Model(display_name=_TEST_DISPLAY_NAME)
-
-        true_input_data_config = gca_training_pipeline.InputDataConfig(
-            dataset_id=mock_dataset_time_series.name,
-        )
-
-        true_training_pipeline = gca_training_pipeline.TrainingPipeline(
-            display_name=_TEST_DISPLAY_NAME,
-            training_task_definition=schema.training_job.definition.automl_forecasting,
-            training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS,
-            model_to_upload=true_managed_model,
-            input_data_config=true_input_data_config,
-        )
-
-        mock_pipeline_service_create.assert_called_once_with(
-            parent=initializer.global_config.common_location_path(),
-            training_pipeline=true_training_pipeline,
-        )
-
     @pytest.mark.usefixtures("mock_pipeline_service_get")
     @pytest.mark.parametrize("sync", [True, False])
     def test_run_call_pipeline_if_set_additional_experiments(

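The test fixtures above build the expected request by merging the experiment flags into the regular training task inputs before converting them to a protobuf Value. A minimal sketch of that pattern is below, using stand-in values; the "additionalExperiments" key name and the flag strings are assumptions for illustration, since the exact key does not appear in this excerpt.

```python
from google.protobuf import json_format, struct_pb2

# Stand-ins for fixtures defined elsewhere in the test module (hypothetical values).
_TEST_TRAINING_TASK_INPUTS_DICT = {
    "targetColumn": "sales",
    "optimizationObjective": "minimize-rmse",
}
_TEST_ADDITIONAL_EXPERIMENTS = ["exp1", "exp2"]

# Plain inputs, without experiment flags.
_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict(
    _TEST_TRAINING_TASK_INPUTS_DICT, struct_pb2.Value(),
)

# Inputs with the flags merged in; "additionalExperiments" is an assumed key name.
_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS = json_format.ParseDict(
    {
        **_TEST_TRAINING_TASK_INPUTS_DICT,
        "additionalExperiments": _TEST_ADDITIONAL_EXPERIMENTS,
    },
    struct_pb2.Value(),
)
```
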
tests/unit/aiplatform/test_automl_tabular_training_jobs.py (+2 -1)

@@ -330,6 +330,7 @@ def test_run_call_pipeline_service_create(
             weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
             budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
             disable_early_stopping=_TEST_TRAINING_DISABLE_EARLY_STOPPING,
+            additional_experiments=_TEST_ADDITIONAL_EXPERIMENTS,
             sync=sync,
         )
 
@@ -354,7 +355,7 @@ def test_run_call_pipeline_service_create(
             display_name=_TEST_DISPLAY_NAME,
             labels=_TEST_LABELS,
             training_task_definition=schema.training_job.definition.automl_tabular,
-            training_task_inputs=_TEST_TRAINING_TASK_INPUTS,
+            training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS,
             model_to_upload=true_managed_model,
             input_data_config=true_input_data_config,
             encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
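
The tabular path is symmetric: the flags go directly into AutoMLTabularTrainingJob.run() and the test now expects them to surface in the pipeline's training task inputs. A brief usage sketch under the same caveats (constructor arguments, dataset ID, column name, and flag values are illustrative, not taken from this commit):

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", staging_bucket="gs://my-staging-bucket")  # hypothetical

dataset = aiplatform.TabularDataset("1234567890")  # hypothetical dataset ID

job = aiplatform.AutoMLTabularTrainingJob(
    display_name="tabular-job",
    optimization_prediction_type="regression",  # illustrative prediction type
)

model = job.run(
    dataset=dataset,
    target_column="target",
    budget_milli_node_hours=1000,
    disable_early_stopping=False,
    additional_experiments=["enable_experimental_feature_x"],  # placeholder flag
)
```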
