Skip to content

Commit 47791f7

Browse files
authored
fix: change default for create_request_timeout arg to None (#1175)
Change default value for `create_request_timeout` from `False` to `None` and add test for when `create_request_timeout` isn't explicitly set. Fixes b/229868042 🦕
1 parent 4c21993 commit 47791f7

File tree

2 files changed

+155
-1
lines changed

2 files changed

+155
-1
lines changed

google/cloud/aiplatform/training_jobs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4736,7 +4736,7 @@ def run(
47364736
model_labels: Optional[Dict[str, str]] = None,
47374737
disable_early_stopping: bool = False,
47384738
sync: bool = True,
4739-
create_request_timeout: Optional[float] = False,
4739+
create_request_timeout: Optional[float] = None,
47404740
) -> models.Model:
47414741
"""Runs the AutoML Image training job and returns a model.
47424742

tests/unit/aiplatform/test_training_jobs.py

+154
Original file line numberDiff line numberDiff line change
@@ -4737,6 +4737,160 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_with_timeout(
47374737
timeout=180.0,
47384738
)
47394739

4740+
@pytest.mark.parametrize("sync", [True, False])
4741+
def test_run_call_pipeline_service_create_with_tabular_dataset_with_timeout_not_explicitly_set(
4742+
self,
4743+
mock_pipeline_service_create,
4744+
mock_pipeline_service_get,
4745+
mock_tabular_dataset,
4746+
mock_model_service_get,
4747+
sync,
4748+
):
4749+
aiplatform.init(
4750+
project=_TEST_PROJECT,
4751+
staging_bucket=_TEST_BUCKET_NAME,
4752+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
4753+
)
4754+
4755+
job = training_jobs.CustomPythonPackageTrainingJob(
4756+
display_name=_TEST_DISPLAY_NAME,
4757+
labels=_TEST_LABELS,
4758+
python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH,
4759+
python_module_name=_TEST_PYTHON_MODULE_NAME,
4760+
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
4761+
model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
4762+
model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
4763+
model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
4764+
model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
4765+
model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
4766+
model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
4767+
model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
4768+
model_description=_TEST_MODEL_DESCRIPTION,
4769+
model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
4770+
model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
4771+
model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
4772+
)
4773+
4774+
model_from_job = job.run(
4775+
dataset=mock_tabular_dataset,
4776+
model_display_name=_TEST_MODEL_DISPLAY_NAME,
4777+
model_labels=_TEST_MODEL_LABELS,
4778+
base_output_dir=_TEST_BASE_OUTPUT_DIR,
4779+
service_account=_TEST_SERVICE_ACCOUNT,
4780+
network=_TEST_NETWORK,
4781+
args=_TEST_RUN_ARGS,
4782+
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
4783+
machine_type=_TEST_MACHINE_TYPE,
4784+
accelerator_type=_TEST_ACCELERATOR_TYPE,
4785+
accelerator_count=_TEST_ACCELERATOR_COUNT,
4786+
training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT,
4787+
validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT,
4788+
test_fraction_split=_TEST_TEST_FRACTION_SPLIT,
4789+
sync=sync,
4790+
)
4791+
4792+
if not sync:
4793+
model_from_job.wait()
4794+
4795+
true_args = _TEST_RUN_ARGS
4796+
true_env = [
4797+
{"name": key, "value": value}
4798+
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
4799+
]
4800+
4801+
true_worker_pool_spec = {
4802+
"replica_count": _TEST_REPLICA_COUNT,
4803+
"machine_spec": {
4804+
"machine_type": _TEST_MACHINE_TYPE,
4805+
"accelerator_type": _TEST_ACCELERATOR_TYPE,
4806+
"accelerator_count": _TEST_ACCELERATOR_COUNT,
4807+
},
4808+
"disk_spec": {
4809+
"boot_disk_type": _TEST_BOOT_DISK_TYPE_DEFAULT,
4810+
"boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB_DEFAULT,
4811+
},
4812+
"python_package_spec": {
4813+
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
4814+
"python_module": _TEST_PYTHON_MODULE_NAME,
4815+
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
4816+
"args": true_args,
4817+
"env": true_env,
4818+
},
4819+
}
4820+
4821+
true_fraction_split = gca_training_pipeline.FractionSplit(
4822+
training_fraction=_TEST_TRAINING_FRACTION_SPLIT,
4823+
validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT,
4824+
test_fraction=_TEST_TEST_FRACTION_SPLIT,
4825+
)
4826+
4827+
env = [
4828+
gca_env_var.EnvVar(name=str(key), value=str(value))
4829+
for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items()
4830+
]
4831+
4832+
ports = [
4833+
gca_model.Port(container_port=port)
4834+
for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
4835+
]
4836+
4837+
true_container_spec = gca_model.ModelContainerSpec(
4838+
image_uri=_TEST_SERVING_CONTAINER_IMAGE,
4839+
predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
4840+
health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
4841+
command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
4842+
args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
4843+
env=env,
4844+
ports=ports,
4845+
)
4846+
4847+
true_managed_model = gca_model.Model(
4848+
display_name=_TEST_MODEL_DISPLAY_NAME,
4849+
labels=_TEST_MODEL_LABELS,
4850+
description=_TEST_MODEL_DESCRIPTION,
4851+
container_spec=true_container_spec,
4852+
predict_schemata=gca_model.PredictSchemata(
4853+
instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
4854+
parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
4855+
prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
4856+
),
4857+
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
4858+
)
4859+
4860+
true_input_data_config = gca_training_pipeline.InputDataConfig(
4861+
fraction_split=true_fraction_split,
4862+
dataset_id=mock_tabular_dataset.name,
4863+
gcs_destination=gca_io.GcsDestination(
4864+
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
4865+
),
4866+
)
4867+
4868+
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
4869+
display_name=_TEST_DISPLAY_NAME,
4870+
labels=_TEST_LABELS,
4871+
training_task_definition=schema.training_job.definition.custom_task,
4872+
training_task_inputs=json_format.ParseDict(
4873+
{
4874+
"worker_pool_specs": [true_worker_pool_spec],
4875+
"base_output_directory": {
4876+
"output_uri_prefix": _TEST_BASE_OUTPUT_DIR
4877+
},
4878+
"service_account": _TEST_SERVICE_ACCOUNT,
4879+
"network": _TEST_NETWORK,
4880+
},
4881+
struct_pb2.Value(),
4882+
),
4883+
model_to_upload=true_managed_model,
4884+
input_data_config=true_input_data_config,
4885+
encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC,
4886+
)
4887+
4888+
mock_pipeline_service_create.assert_called_once_with(
4889+
parent=initializer.global_config.common_location_path(),
4890+
training_pipeline=true_training_pipeline,
4891+
timeout=None,
4892+
)
4893+
47404894
@pytest.mark.parametrize("sync", [True, False])
47414895
def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name_nor_model_labels(
47424896
self,

0 commit comments

Comments
 (0)