File tree 5 files changed +14
-4
lines changed
5 files changed +14
-4
lines changed Original file line number Diff line number Diff line change 70
70
AutoMLForecastingTrainingJob ,
71
71
SequenceToSequencePlusForecastingTrainingJob ,
72
72
TemporalFusionTransformerForecastingTrainingJob ,
73
+ TimeSeriesDenseEncoderForecastingTrainingJob ,
73
74
AutoMLImageTrainingJob ,
74
75
AutoMLTextTrainingJob ,
75
76
AutoMLVideoTrainingJob ,
178
179
"TextDataset" ,
179
180
"TemporalFusionTransformerForecastingTrainingJob" ,
180
181
"TimeSeriesDataset" ,
182
+ "TimeSeriesDenseEncoderForecastingTrainingJob" ,
181
183
"VideoDataset" ,
182
184
)
Original file line number Diff line number Diff line change @@ -25,6 +25,7 @@ class definition:
25
25
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
26
26
seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
27
27
tft_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/temporal_fusion_transformer_time_series_forecasting_1.0.0.yaml"
28
+ tide_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/time_series_dense_encoder_forecasting_1.0.0.yaml"
28
29
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
29
30
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
30
31
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"
Original file line number Diff line number Diff line change @@ -5257,6 +5257,14 @@ class TemporalFusionTransformerForecastingTrainingJob(_ForecastingTrainingJob):
5257
5257
_supported_training_schemas = (schema .training_job .definition .tft_forecasting ,)
5258
5258
5259
5259
5260
+ class TimeSeriesDenseEncoderForecastingTrainingJob (_ForecastingTrainingJob ):
5261
+ """Class to train Time series Dense Encoder (TiDE) forecasting models."""
5262
+
5263
+ _model_type = "TiDE"
5264
+ _training_task_definition = schema .training_job .definition .tide_forecasting
5265
+ _supported_training_schemas = (schema .training_job .definition .tide_forecasting ,)
5266
+
5267
+
5260
5268
class AutoMLImageTrainingJob (_TrainingJob ):
5261
5269
_supported_training_schemas = (
5262
5270
schema .training_job .definition .automl_image_classification ,
Original file line number Diff line number Diff line change @@ -42,10 +42,8 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):
42
42
[
43
43
training_jobs .AutoMLForecastingTrainingJob ,
44
44
training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
45
- pytest .param (
46
- training_jobs .TemporalFusionTransformerForecastingTrainingJob ,
47
- marks = pytest .mark .skip (reason = "TFT not yet released." ),
48
- ),
45
+ training_jobs .TemporalFusionTransformerForecastingTrainingJob ,
46
+ training_jobs .TimeSeriesDenseEncoderForecastingTrainingJob ,
49
47
],
50
48
)
51
49
def test_end_to_end_forecasting (self , shared_state , training_job ):
Original file line number Diff line number Diff line change 187
187
training_jobs .AutoMLForecastingTrainingJob ,
188
188
training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
189
189
training_jobs .TemporalFusionTransformerForecastingTrainingJob ,
190
+ training_jobs .TimeSeriesDenseEncoderForecastingTrainingJob ,
190
191
]
191
192
192
193
You can’t perform that action at this time.
0 commit comments