Skip to content

Commit efe88f9

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
docs: Add probabilistic inference to TiDE and L2L model code samples.
PiperOrigin-RevId: 570445202
1 parent 51a3bfa commit efe88f9

5 files changed

+9
-0
lines changed

samples/model-builder/create_training_pipeline_forecasting_sample.py

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def create_training_pipeline_forecasting_sample(
4444
export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
4545
export_evaluated_data_items_override_destination: bool = False,
4646
quantiles: Optional[List[float]] = None,
47+
enable_probabilistic_inference: bool = False,
4748
validation_options: Optional[str] = None,
4849
predefined_split_column_name: Optional[str] = None,
4950
sync: bool = True,
@@ -81,6 +82,7 @@ def create_training_pipeline_forecasting_sample(
8182
export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
8283
export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
8384
quantiles=quantiles,
85+
enable_probabilistic_inference=enable_probabilistic_inference,
8486
validation_options=validation_options,
8587
budget_milli_node_hours=budget_milli_node_hours,
8688
model_display_name=model_display_name,

samples/model-builder/create_training_pipeline_forecasting_sample_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_create_training_pipeline_forecasting_sample(
4343
export_evaluated_data_items_bigquery_destination_uri=constants.EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
4444
export_evaluated_data_items_override_destination=constants.EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
4545
quantiles=constants.QUANTILES,
46+
enable_probabilistic_inference=constants.ENABLE_PROBABILISTIC_INFERENCE,
4647
validation_options=constants.VALIDATION_OPTIONS,
4748
predefined_split_column_name=constants.PREDEFINED_SPLIT_COLUMN_NAME,
4849
)
@@ -79,6 +80,7 @@ def test_create_training_pipeline_forecasting_sample(
7980
export_evaluated_data_items_bigquery_destination_uri=constants.EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
8081
export_evaluated_data_items_override_destination=constants.EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
8182
quantiles=constants.QUANTILES,
83+
enable_probabilistic_inference=constants.ENABLE_PROBABILISTIC_INFERENCE,
8284
validation_options=constants.VALIDATION_OPTIONS,
8385
predefined_split_column_name=constants.PREDEFINED_SPLIT_COLUMN_NAME,
8486
sync=True,

samples/model-builder/create_training_pipeline_forecasting_tide_sample.py

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def create_training_pipeline_forecasting_time_series_dense_encoder_sample(
4444
export_evaluated_data_items_bigquery_destination_uri: Optional[str] = None,
4545
export_evaluated_data_items_override_destination: bool = False,
4646
quantiles: Optional[List[float]] = None,
47+
enable_probabilistic_inference: bool = False,
4748
validation_options: Optional[str] = None,
4849
predefined_split_column_name: Optional[str] = None,
4950
sync: bool = True,
@@ -82,6 +83,7 @@ def create_training_pipeline_forecasting_time_series_dense_encoder_sample(
8283
export_evaluated_data_items_bigquery_destination_uri=export_evaluated_data_items_bigquery_destination_uri,
8384
export_evaluated_data_items_override_destination=export_evaluated_data_items_override_destination,
8485
quantiles=quantiles,
86+
enable_probabilistic_inference=enable_probabilistic_inference,
8587
validation_options=validation_options,
8688
budget_milli_node_hours=budget_milli_node_hours,
8789
model_display_name=model_display_name,

samples/model-builder/create_training_pipeline_forecasting_tide_sample_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def test_create_training_pipeline_forecasting_tide_sample(
4343
export_evaluated_data_items_bigquery_destination_uri=constants.EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
4444
export_evaluated_data_items_override_destination=constants.EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
4545
quantiles=constants.QUANTILES,
46+
enable_probabilistic_inference=constants.ENABLE_PROBABILISTIC_INFERENCE,
4647
validation_options=constants.VALIDATION_OPTIONS,
4748
predefined_split_column_name=constants.PREDEFINED_SPLIT_COLUMN_NAME,
4849
)
@@ -79,6 +80,7 @@ def test_create_training_pipeline_forecasting_tide_sample(
7980
export_evaluated_data_items_bigquery_destination_uri=constants.EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
8081
export_evaluated_data_items_override_destination=constants.EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
8182
quantiles=constants.QUANTILES,
83+
enable_probabilistic_inference=constants.ENABLE_PROBABILISTIC_INFERENCE,
8284
validation_options=constants.VALIDATION_OPTIONS,
8385
predefined_split_column_name=constants.PREDEFINED_SPLIT_COLUMN_NAME,
8486
sync=True,

samples/model-builder/test_constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@
269269
EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI = "bq://test:test:test"
270270
EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION = True
271271
QUANTILES = [0, 0.5, 1]
272+
ENABLE_PROBABILISTIC_INFERENCE = True
272273
VALIDATION_OPTIONS = "fail-pipeline"
273274
PREDEFINED_SPLIT_COLUMN_NAME = "predefined"
274275

0 commit comments

Comments
 (0)