Skip to content

Commit 77b89c0

Browse files
TheMichaelHu authored and copybara-github committed
fix: Fix default AutoML Forecasting transformations list.
PiperOrigin-RevId: 526734524
1 parent 06f8508 commit 77b89c0

File tree

4 files changed

+94
-16
lines changed

4 files changed

+94
-16
lines changed

google/cloud/aiplatform/training_jobs.py

+3 -1
Original file line numberDiff line numberDiff line change
@@ -2438,7 +2438,9 @@ def _run(
24382438
(
24392439
self._column_transformations,
24402440
column_names,
2441-
) = dataset._get_default_column_transformations(target_column)
2441+
) = column_transformations_utils.get_default_column_transformations(
2442+
dataset=dataset, target_column=target_column
2443+
)
24422444

24432445
_LOGGER.info(
24442446
"The column transformation of type 'auto' was set for the following columns: %s."

google/cloud/aiplatform/utils/column_transformations_utils.py

+6 -15
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#
1717

1818
from typing import Dict, List, Optional, Tuple
19-
import warnings
2019

2120
from google.cloud.aiplatform import datasets
2221

@@ -51,9 +50,9 @@ def get_default_column_transformations(
5150

5251

5352
def validate_and_get_column_transformations(
54-
column_specs: Optional[Dict[str, str]],
55-
column_transformations: Optional[List[Dict[str, Dict[str, str]]]],
56-
) -> List[Dict[str, Dict[str, str]]]:
53+
column_specs: Optional[Dict[str, str]] = None,
54+
column_transformations: Optional[List[Dict[str, Dict[str, str]]]] = None,
55+
) -> Optional[List[Dict[str, Dict[str, str]]]]:
5756
"""Validates column specs and transformations, then returns processed transformations.
5857
5958
Args:
@@ -91,21 +90,13 @@ def validate_and_get_column_transformations(
9190
# user populated transformations
9291
if column_transformations is not None and column_specs is not None:
9392
raise ValueError(
94-
"Both column_transformations and column_specs were passed. Only one is allowed."
93+
"Both column_transformations and column_specs were passed. Only "
94+
"one is allowed."
9595
)
96-
if column_transformations is not None:
97-
warnings.simplefilter("always", DeprecationWarning)
98-
warnings.warn(
99-
"consider using column_specs instead. column_transformations will be deprecated in the future.",
100-
DeprecationWarning,
101-
stacklevel=2,
102-
)
103-
104-
return column_transformations
10596
elif column_specs is not None:
10697
return [
10798
{transformation: {"column_name": column_name}}
10899
for column_name, transformation in column_specs.items()
109100
]
110101
else:
111-
return None
102+
return column_transformations

tests/unit/aiplatform/test_automl_forecasting_training_jobs.py

+38
Original file line numberDiff line numberDiff line change
@@ -1294,3 +1294,41 @@ def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference
12941294
training_pipeline=true_training_pipeline,
12951295
timeout=None,
12961296
)
1297+
1298+
def test_automl_forecasting_with_no_transformations(
1299+
self,
1300+
mock_pipeline_service_create,
1301+
mock_pipeline_service_get,
1302+
mock_dataset_time_series,
1303+
mock_model_service_get,
1304+
):
1305+
aiplatform.init(project=_TEST_PROJECT)
1306+
job = training_jobs.AutoMLForecastingTrainingJob(
1307+
display_name=_TEST_DISPLAY_NAME,
1308+
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
1309+
)
1310+
mock_dataset_time_series.column_names = [
1311+
"a",
1312+
"b",
1313+
_TEST_TRAINING_TARGET_COLUMN,
1314+
]
1315+
job.run(
1316+
dataset=mock_dataset_time_series,
1317+
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
1318+
target_column=_TEST_TRAINING_TARGET_COLUMN,
1319+
time_column=_TEST_TRAINING_TIME_COLUMN,
1320+
time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
1321+
unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
1322+
available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
1323+
forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON,
1324+
data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT,
1325+
data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT,
1326+
model_display_name=_TEST_MODEL_DISPLAY_NAME,
1327+
time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS,
1328+
context_window=_TEST_TRAINING_CONTEXT_WINDOW,
1329+
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
1330+
)
1331+
assert job._column_transformations == [
1332+
{"auto": {"column_name": "a"}},
1333+
{"auto": {"column_name": "b"}},
1334+
]

tests/unit/aiplatform/test_utils.py

+47
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
from google.cloud import storage
3636
from google.cloud.aiplatform import compat, utils
3737
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
38+
from google.cloud.aiplatform import datasets
3839
from google.cloud.aiplatform.utils import (
40+
column_transformations_utils,
3941
gcs_utils,
4042
pipeline_utils,
4143
prediction_utils,
@@ -485,6 +487,51 @@ def test_timestamped_unique_name():
485487
assert re.match(r"\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-.{5}", name)
486488

487489

490+
class TestColumnTransformationsUtils:
491+
492+
column_transformations = [
493+
{"auto": {"column_name": "a"}},
494+
{"auto": {"column_name": "b"}},
495+
]
496+
column_specs = {"a": "auto", "b": "auto"}
497+
498+
def test_get_default_column_transformations(self):
499+
ds = mock.MagicMock(datasets.TimeSeriesDataset)
500+
ds.column_names = ["a", "b", "target"]
501+
(
502+
transforms,
503+
columns,
504+
) = column_transformations_utils.get_default_column_transformations(
505+
dataset=ds, target_column="target"
506+
)
507+
assert transforms == [
508+
{"auto": {"column_name": "a"}},
509+
{"auto": {"column_name": "b"}},
510+
]
511+
assert columns == ["a", "b"]
512+
513+
def test_validate_transformations_with_multiple_configs(self):
514+
with pytest.raises(ValueError):
515+
(
516+
column_transformations_utils.validate_and_get_column_transformations(
517+
column_transformations=self.column_transformations,
518+
column_specs=self.column_specs,
519+
)
520+
)
521+
522+
def test_validate_transformations_with_column_specs(self):
523+
actual = column_transformations_utils.validate_and_get_column_transformations(
524+
column_specs=self.column_specs
525+
)
526+
assert actual == self.column_transformations
527+
528+
def test_validate_transformations_with_column_transformations(self):
529+
actual = column_transformations_utils.validate_and_get_column_transformations(
530+
column_transformations=self.column_transformations
531+
)
532+
assert actual == self.column_transformations
533+
534+
488535
@pytest.mark.usefixtures("google_auth_mock")
489536
class TestGcsUtils:
490537
def test_upload_to_gcs(self, json_file, mock_storage_blob_upload_from_filename):

0 commit comments

Comments
 (0)