Skip to content

Commit 8560fa8

Browse files
authored
feat: Add hierarchy and window configs to Vertex Forecasting training job (#1255)
Adds support for hierarchical forecasting and window filtering. --- Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) Fixes b/229907889 b/228499154 🦕
1 parent e82c179 commit 8560fa8

File tree

2 files changed

+224
-1
lines changed

2 files changed

+224
-1
lines changed

google/cloud/aiplatform/training_jobs.py

+131-1
Original file line numberDiff line numberDiff line change
@@ -4037,6 +4037,13 @@ def run(
40374037
model_display_name: Optional[str] = None,
40384038
model_labels: Optional[Dict[str, str]] = None,
40394039
additional_experiments: Optional[List[str]] = None,
4040+
hierarchy_group_columns: Optional[List[str]] = None,
4041+
hierarchy_group_total_weight: Optional[float] = None,
4042+
hierarchy_temporal_total_weight: Optional[float] = None,
4043+
hierarchy_group_temporal_total_weight: Optional[float] = None,
4044+
window_column: Optional[str] = None,
4045+
window_stride_length: Optional[int] = None,
4046+
window_max_count: Optional[int] = None,
40404047
sync: bool = True,
40414048
create_request_timeout: Optional[float] = None,
40424049
) -> models.Model:
@@ -4157,7 +4164,7 @@ def run(
41574164
Applies only if [export_evaluated_data_items] is True and
41584165
[export_evaluated_data_items_bigquery_destination_uri] is specified.
41594166
quantiles (List[float]):
4160-
Quantiles to use for the `minimize-quantile-loss`
4167+
Quantiles to use for the ``minimize-quantile-loss``
41614168
[AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
41624169
this case.
41634170
@@ -4200,6 +4207,37 @@ def run(
42004207
Optional. Additional experiment flags for the time series forcasting training.
42014208
create_request_timeout (float):
42024209
Optional. The timeout for the create request in seconds.
4210+
hierarchy_group_columns (List[str]):
4211+
Optional. A list of time series attribute column names that
4212+
define the time series hierarchy. Only one level of hierarchy is
4213+
supported, ex. ``region`` for a hierarchy of stores or
4214+
``department`` for a hierarchy of products. If multiple columns
4215+
are specified, time series will be grouped by their combined
4216+
values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
4217+
to 5 columns are accepted. If no group columns are specified,
4218+
all time series are considered to be part of the same group.
4219+
hierarchy_group_total_weight (float):
4220+
Optional. The weight of the loss for predictions aggregated over
4221+
time series in the same hierarchy group.
4222+
hierarchy_temporal_total_weight (float):
4223+
Optional. The weight of the loss for predictions aggregated over
4224+
the horizon for a single time series.
4225+
hierarchy_group_temporal_total_weight (float):
4226+
Optional. The weight of the loss for predictions aggregated over
4227+
both the horizon and time series in the same hierarchy group.
4228+
window_column (str):
4229+
Optional. Name of the column that should be used to filter input
4230+
rows. The column should contain either booleans or string
4231+
booleans; if the value of the row is True, generate a sliding
4232+
window from that row.
4233+
window_stride_length (int):
4234+
Optional. Step length used to generate input examples. Every
4235+
``window_stride_length`` rows will be used to generate a sliding
4236+
window.
4237+
window_max_count (int):
4238+
Optional. Number of rows that should be used to generate input
4239+
examples. If the total row count is larger than this number, the
4240+
input data will be randomly sampled to hit the count.
42034241
sync (bool):
42044242
Whether to execute this method synchronously. If False, this method
42054243
will be executed in concurrent Future and any downstream object will
@@ -4254,6 +4292,13 @@ def run(
42544292
validation_options=validation_options,
42554293
model_display_name=model_display_name,
42564294
model_labels=model_labels,
4295+
hierarchy_group_columns=hierarchy_group_columns,
4296+
hierarchy_group_total_weight=hierarchy_group_total_weight,
4297+
hierarchy_temporal_total_weight=hierarchy_temporal_total_weight,
4298+
hierarchy_group_temporal_total_weight=hierarchy_group_temporal_total_weight,
4299+
window_column=window_column,
4300+
window_stride_length=window_stride_length,
4301+
window_max_count=window_max_count,
42574302
sync=sync,
42584303
create_request_timeout=create_request_timeout,
42594304
)
@@ -4286,6 +4331,13 @@ def _run(
42864331
budget_milli_node_hours: int = 1000,
42874332
model_display_name: Optional[str] = None,
42884333
model_labels: Optional[Dict[str, str]] = None,
4334+
hierarchy_group_columns: Optional[List[str]] = None,
4335+
hierarchy_group_total_weight: Optional[float] = None,
4336+
hierarchy_temporal_total_weight: Optional[float] = None,
4337+
hierarchy_group_temporal_total_weight: Optional[float] = None,
4338+
window_column: Optional[str] = None,
4339+
window_stride_length: Optional[int] = None,
4340+
window_max_count: Optional[int] = None,
42894341
sync: bool = True,
42904342
create_request_timeout: Optional[float] = None,
42914343
) -> models.Model:
@@ -4453,6 +4505,37 @@ def _run(
44534505
are allowed.
44544506
See https://goo.gl/xmQnxf for more information
44554507
and examples of labels.
4508+
hierarchy_group_columns (List[str]):
4509+
Optional. A list of time series attribute column names that
4510+
define the time series hierarchy. Only one level of hierarchy is
4511+
supported, ex. ``region`` for a hierarchy of stores or
4512+
``department`` for a hierarchy of products. If multiple columns
4513+
are specified, time series will be grouped by their combined
4514+
values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
4515+
to 5 columns are accepted. If no group columns are specified,
4516+
all time series are considered to be part of the same group.
4517+
hierarchy_group_total_weight (float):
4518+
Optional. The weight of the loss for predictions aggregated over
4519+
time series in the same hierarchy group.
4520+
hierarchy_temporal_total_weight (float):
4521+
Optional. The weight of the loss for predictions aggregated over
4522+
the horizon for a single time series.
4523+
hierarchy_group_temporal_total_weight (float):
4524+
Optional. The weight of the loss for predictions aggregated over
4525+
both the horizon and time series in the same hierarchy group.
4526+
window_column (str):
4527+
Optional. Name of the column that should be used to filter input
4528+
rows. The column should contain either booleans or string
4529+
booleans; if the value of the row is True, generate a sliding
4530+
window from that row.
4531+
window_stride_length (int):
4532+
Optional. Step length used to generate input examples. Every
4533+
``window_stride_length`` rows will be used to generate a sliding
4534+
window.
4535+
window_max_count (int):
4536+
Optional. Number of rows that should be used to generate input
4537+
examples. If the total row count is larger than this number, the
4538+
input data will be randomly sampled to hit the count.
44564539
sync (bool):
44574540
Whether to execute this method synchronously. If False, this method
44584541
will be executed in concurrent Future and any downstream object will
@@ -4482,6 +4565,12 @@ def _run(
44824565
% column_names
44834566
)
44844567

4568+
window_config = self._create_window_config(
4569+
column=window_column,
4570+
stride_length=window_stride_length,
4571+
max_count=window_max_count,
4572+
)
4573+
44854574
training_task_inputs_dict = {
44864575
# required inputs
44874576
"targetColumn": target_column,
@@ -4505,6 +4594,24 @@ def _run(
45054594
"optimizationObjective": self._optimization_objective,
45064595
}
45074596

4597+
# TODO(TheMichaelHu): Remove the ifs once the API supports these inputs.
4598+
if any(
4599+
[
4600+
hierarchy_group_columns,
4601+
hierarchy_group_total_weight,
4602+
hierarchy_temporal_total_weight,
4603+
hierarchy_group_temporal_total_weight,
4604+
]
4605+
):
4606+
training_task_inputs_dict["hierarchyConfig"] = {
4607+
"groupColumns": hierarchy_group_columns,
4608+
"groupTotalWeight": hierarchy_group_total_weight,
4609+
"temporalTotalWeight": hierarchy_temporal_total_weight,
4610+
"groupTemporalTotalWeight": hierarchy_group_temporal_total_weight,
4611+
}
4612+
if window_config:
4613+
training_task_inputs_dict["windowConfig"] = window_config
4614+
45084615
final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
45094616
if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith(
45104617
"bq://"
@@ -4582,6 +4689,29 @@ def _add_additional_experiments(self, additional_experiments: List[str]):
45824689
"""
45834690
self._additional_experiments.extend(additional_experiments)
45844691

4692+
@staticmethod
4693+
def _create_window_config(
4694+
column: Optional[str] = None,
4695+
stride_length: Optional[int] = None,
4696+
max_count: Optional[int] = None,
4697+
) -> Optional[Dict[str, Union[int, str]]]:
4698+
"""Creates a window config from training job arguments."""
4699+
configs = {
4700+
"column": column,
4701+
"strideLength": stride_length,
4702+
"maxCount": max_count,
4703+
}
4704+
present_configs = {k: v for k, v in configs.items() if v is not None}
4705+
if not present_configs:
4706+
return None
4707+
if len(present_configs) > 1:
4708+
raise ValueError(
4709+
"More than one windowing strategy provided. Make sure only one "
4710+
"of window_column, window_stride_length, or window_max_count "
4711+
"is specified."
4712+
)
4713+
return present_configs
4714+
45854715

45864716
class AutoMLImageTrainingJob(_TrainingJob):
45874717
_supported_training_schemas = (

0 commit comments

Comments
 (0)