Skip to content

Commit d8e6744

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Adds the Time series Dense Encoder (TiDE) forecasting job.
PiperOrigin-RevId: 524068121
1 parent 29d4e45 commit d8e6744

File tree

5 files changed

+14
-4
lines changed

5 files changed

+14
-4
lines changed

google/cloud/aiplatform/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
AutoMLForecastingTrainingJob,
7171
SequenceToSequencePlusForecastingTrainingJob,
7272
TemporalFusionTransformerForecastingTrainingJob,
73+
TimeSeriesDenseEncoderForecastingTrainingJob,
7374
AutoMLImageTrainingJob,
7475
AutoMLTextTrainingJob,
7576
AutoMLVideoTrainingJob,
@@ -178,5 +179,6 @@
178179
"TextDataset",
179180
"TemporalFusionTransformerForecastingTrainingJob",
180181
"TimeSeriesDataset",
182+
"TimeSeriesDenseEncoderForecastingTrainingJob",
181183
"VideoDataset",
182184
)

google/cloud/aiplatform/schema.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class definition:
2525
automl_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_time_series_forecasting_1.0.0.yaml"
2626
seq2seq_plus_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml"
2727
tft_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/temporal_fusion_transformer_time_series_forecasting_1.0.0.yaml"
28+
tide_forecasting = "gs://google-cloud-aiplatform/schema/trainingjob/definition/time_series_dense_encoder_forecasting_1.0.0.yaml"
2829
automl_image_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml"
2930
automl_image_object_detection = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml"
3031
automl_text_classification = "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml"

google/cloud/aiplatform/training_jobs.py

+8
Original file line numberDiff line numberDiff line change
@@ -5257,6 +5257,14 @@ class TemporalFusionTransformerForecastingTrainingJob(_ForecastingTrainingJob):
52575257
_supported_training_schemas = (schema.training_job.definition.tft_forecasting,)
52585258

52595259

5260+
class TimeSeriesDenseEncoderForecastingTrainingJob(_ForecastingTrainingJob):
5261+
"""Class to train Time series Dense Encoder (TiDE) forecasting models."""
5262+
5263+
_model_type = "TiDE"
5264+
_training_task_definition = schema.training_job.definition.tide_forecasting
5265+
_supported_training_schemas = (schema.training_job.definition.tide_forecasting,)
5266+
5267+
52605268
class AutoMLImageTrainingJob(_TrainingJob):
52615269
_supported_training_schemas = (
52625270
schema.training_job.definition.automl_image_classification,

tests/system/aiplatform/test_e2e_forecasting.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,8 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):
4242
[
4343
training_jobs.AutoMLForecastingTrainingJob,
4444
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
45-
pytest.param(
46-
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
47-
marks=pytest.mark.skip(reason="TFT not yet released."),
48-
),
45+
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
46+
training_jobs.TimeSeriesDenseEncoderForecastingTrainingJob,
4947
],
5048
)
5149
def test_end_to_end_forecasting(self, shared_state, training_job):

tests/unit/aiplatform/test_automl_forecasting_training_jobs.py

+1
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@
187187
training_jobs.AutoMLForecastingTrainingJob,
188188
training_jobs.SequenceToSequencePlusForecastingTrainingJob,
189189
training_jobs.TemporalFusionTransformerForecastingTrainingJob,
190+
training_jobs.TimeSeriesDenseEncoderForecastingTrainingJob,
190191
]
191192

192193

0 commit comments

Comments
 (0)