@@ -4037,6 +4037,13 @@ def run(
4037
4037
model_display_name : Optional [str ] = None ,
4038
4038
model_labels : Optional [Dict [str , str ]] = None ,
4039
4039
additional_experiments : Optional [List [str ]] = None ,
4040
+ hierarchy_group_columns : Optional [List [str ]] = None ,
4041
+ hierarchy_group_total_weight : Optional [float ] = None ,
4042
+ hierarchy_temporal_total_weight : Optional [float ] = None ,
4043
+ hierarchy_group_temporal_total_weight : Optional [float ] = None ,
4044
+ window_column : Optional [str ] = None ,
4045
+ window_stride_length : Optional [int ] = None ,
4046
+ window_max_count : Optional [int ] = None ,
4040
4047
sync : bool = True ,
4041
4048
create_request_timeout : Optional [float ] = None ,
4042
4049
) -> models .Model :
@@ -4157,7 +4164,7 @@ def run(
4157
4164
Applies only if [export_evaluated_data_items] is True and
4158
4165
[export_evaluated_data_items_bigquery_destination_uri] is specified.
4159
4166
quantiles (List[float]):
4160
- Quantiles to use for the `minimize-quantile-loss`
4167
+ Quantiles to use for the ``minimize-quantile-loss``
4161
4168
[AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
4162
4169
this case.
4163
4170
@@ -4200,6 +4207,37 @@ def run(
4200
4207
Optional. Additional experiment flags for the time series forecasting training.
4201
4208
create_request_timeout (float):
4202
4209
Optional. The timeout for the create request in seconds.
4210
+ hierarchy_group_columns (List[str]):
4211
+ Optional. A list of time series attribute column names that
4212
+ define the time series hierarchy. Only one level of hierarchy is
4213
+ supported, ex. ``region`` for a hierarchy of stores or
4214
+ ``department`` for a hierarchy of products. If multiple columns
4215
+ are specified, time series will be grouped by their combined
4216
+ values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
4217
+ to 5 columns are accepted. If no group columns are specified,
4218
+ all time series are considered to be part of the same group.
4219
+ hierarchy_group_total_weight (float):
4220
+ Optional. The weight of the loss for predictions aggregated over
4221
+ time series in the same hierarchy group.
4222
+ hierarchy_temporal_total_weight (float):
4223
+ Optional. The weight of the loss for predictions aggregated over
4224
+ the horizon for a single time series.
4225
+ hierarchy_group_temporal_total_weight (float):
4226
+ Optional. The weight of the loss for predictions aggregated over
4227
+ both the horizon and time series in the same hierarchy group.
4228
+ window_column (str):
4229
+ Optional. Name of the column that should be used to filter input
4230
+ rows. The column should contain either booleans or string
4231
+ booleans; if the value of the row is True, generate a sliding
4232
+ window from that row.
4233
+ window_stride_length (int):
4234
+ Optional. Step length used to generate input examples. Every
4235
+ ``window_stride_length`` rows will be used to generate a sliding
4236
+ window.
4237
+ window_max_count (int):
4238
+ Optional. Number of rows that should be used to generate input
4239
+ examples. If the total row count is larger than this number, the
4240
+ input data will be randomly sampled to hit the count.
4203
4241
sync (bool):
4204
4242
Whether to execute this method synchronously. If False, this method
4205
4243
will be executed in concurrent Future and any downstream object will
@@ -4254,6 +4292,13 @@ def run(
4254
4292
validation_options = validation_options ,
4255
4293
model_display_name = model_display_name ,
4256
4294
model_labels = model_labels ,
4295
+ hierarchy_group_columns = hierarchy_group_columns ,
4296
+ hierarchy_group_total_weight = hierarchy_group_total_weight ,
4297
+ hierarchy_temporal_total_weight = hierarchy_temporal_total_weight ,
4298
+ hierarchy_group_temporal_total_weight = hierarchy_group_temporal_total_weight ,
4299
+ window_column = window_column ,
4300
+ window_stride_length = window_stride_length ,
4301
+ window_max_count = window_max_count ,
4257
4302
sync = sync ,
4258
4303
create_request_timeout = create_request_timeout ,
4259
4304
)
@@ -4286,6 +4331,13 @@ def _run(
4286
4331
budget_milli_node_hours : int = 1000 ,
4287
4332
model_display_name : Optional [str ] = None ,
4288
4333
model_labels : Optional [Dict [str , str ]] = None ,
4334
+ hierarchy_group_columns : Optional [List [str ]] = None ,
4335
+ hierarchy_group_total_weight : Optional [float ] = None ,
4336
+ hierarchy_temporal_total_weight : Optional [float ] = None ,
4337
+ hierarchy_group_temporal_total_weight : Optional [float ] = None ,
4338
+ window_column : Optional [str ] = None ,
4339
+ window_stride_length : Optional [int ] = None ,
4340
+ window_max_count : Optional [int ] = None ,
4289
4341
sync : bool = True ,
4290
4342
create_request_timeout : Optional [float ] = None ,
4291
4343
) -> models .Model :
@@ -4453,6 +4505,37 @@ def _run(
4453
4505
are allowed.
4454
4506
See https://goo.gl/xmQnxf for more information
4455
4507
and examples of labels.
4508
+ hierarchy_group_columns (List[str]):
4509
+ Optional. A list of time series attribute column names that
4510
+ define the time series hierarchy. Only one level of hierarchy is
4511
+ supported, ex. ``region`` for a hierarchy of stores or
4512
+ ``department`` for a hierarchy of products. If multiple columns
4513
+ are specified, time series will be grouped by their combined
4514
+ values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
4515
+ to 5 columns are accepted. If no group columns are specified,
4516
+ all time series are considered to be part of the same group.
4517
+ hierarchy_group_total_weight (float):
4518
+ Optional. The weight of the loss for predictions aggregated over
4519
+ time series in the same hierarchy group.
4520
+ hierarchy_temporal_total_weight (float):
4521
+ Optional. The weight of the loss for predictions aggregated over
4522
+ the horizon for a single time series.
4523
+ hierarchy_group_temporal_total_weight (float):
4524
+ Optional. The weight of the loss for predictions aggregated over
4525
+ both the horizon and time series in the same hierarchy group.
4526
+ window_column (str):
4527
+ Optional. Name of the column that should be used to filter input
4528
+ rows. The column should contain either booleans or string
4529
+ booleans; if the value of the row is True, generate a sliding
4530
+ window from that row.
4531
+ window_stride_length (int):
4532
+ Optional. Step length used to generate input examples. Every
4533
+ ``window_stride_length`` rows will be used to generate a sliding
4534
+ window.
4535
+ window_max_count (int):
4536
+ Optional. Number of rows that should be used to generate input
4537
+ examples. If the total row count is larger than this number, the
4538
+ input data will be randomly sampled to hit the count.
4456
4539
sync (bool):
4457
4540
Whether to execute this method synchronously. If False, this method
4458
4541
will be executed in concurrent Future and any downstream object will
@@ -4482,6 +4565,12 @@ def _run(
4482
4565
% column_names
4483
4566
)
4484
4567
4568
+ window_config = self ._create_window_config (
4569
+ column = window_column ,
4570
+ stride_length = window_stride_length ,
4571
+ max_count = window_max_count ,
4572
+ )
4573
+
4485
4574
training_task_inputs_dict = {
4486
4575
# required inputs
4487
4576
"targetColumn" : target_column ,
@@ -4505,6 +4594,24 @@ def _run(
4505
4594
"optimizationObjective" : self ._optimization_objective ,
4506
4595
}
4507
4596
4597
+ # TODO(TheMichaelHu): Remove the ifs once the API supports these inputs.
4598
+ if any (
4599
+ [
4600
+ hierarchy_group_columns ,
4601
+ hierarchy_group_total_weight ,
4602
+ hierarchy_temporal_total_weight ,
4603
+ hierarchy_group_temporal_total_weight ,
4604
+ ]
4605
+ ):
4606
+ training_task_inputs_dict ["hierarchyConfig" ] = {
4607
+ "groupColumns" : hierarchy_group_columns ,
4608
+ "groupTotalWeight" : hierarchy_group_total_weight ,
4609
+ "temporalTotalWeight" : hierarchy_temporal_total_weight ,
4610
+ "groupTemporalTotalWeight" : hierarchy_group_temporal_total_weight ,
4611
+ }
4612
+ if window_config :
4613
+ training_task_inputs_dict ["windowConfig" ] = window_config
4614
+
4508
4615
final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
4509
4616
if final_export_eval_bq_uri and not final_export_eval_bq_uri .startswith (
4510
4617
"bq://"
@@ -4582,6 +4689,29 @@ def _add_additional_experiments(self, additional_experiments: List[str]):
4582
4689
"""
4583
4690
self ._additional_experiments .extend (additional_experiments )
4584
4691
4692
@staticmethod
def _create_window_config(
    column: Optional[str] = None,
    stride_length: Optional[int] = None,
    max_count: Optional[int] = None,
) -> Optional[Dict[str, Union[int, str]]]:
    """Builds the window config dict from the windowing arguments.

    At most one of the three arguments may be provided.

    Args:
        column (str):
            Optional. Stored under the ``column`` key when not None.
        stride_length (int):
            Optional. Stored under the ``strideLength`` key when not None.
        max_count (int):
            Optional. Stored under the ``maxCount`` key when not None.

    Returns:
        None when every argument is None; otherwise a single-entry dict
        holding the one argument that was provided.

    Raises:
        ValueError: If more than one argument is not None.
    """
    candidates = (
        ("column", column),
        ("strideLength", stride_length),
        ("maxCount", max_count),
    )

    window_config = {}
    for key, value in candidates:
        # Compare against None (not truthiness) so values such as 0 or ""
        # still count as "provided".
        if value is not None:
            window_config[key] = value

    if len(window_config) > 1:
        raise ValueError(
            "More than one windowing strategy provided. Make sure only one "
            "of window_column, window_stride_length, or window_max_count "
            "is specified."
        )

    # An empty dict means no strategy was requested; report that as None.
    return window_config or None
4714
+
4585
4715
4586
4716
class AutoMLImageTrainingJob (_TrainingJob ):
4587
4717
_supported_training_schemas = (
0 commit comments