@@ -4737,6 +4737,160 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_with_timeout(
4737
4737
timeout = 180.0 ,
4738
4738
)
4739
4739
4740
+ @pytest .mark .parametrize ("sync" , [True , False ])
4741
+ def test_run_call_pipeline_service_create_with_tabular_dataset_with_timeout_not_explicitly_set (
4742
+ self ,
4743
+ mock_pipeline_service_create ,
4744
+ mock_pipeline_service_get ,
4745
+ mock_tabular_dataset ,
4746
+ mock_model_service_get ,
4747
+ sync ,
4748
+ ):
4749
+ aiplatform .init (
4750
+ project = _TEST_PROJECT ,
4751
+ staging_bucket = _TEST_BUCKET_NAME ,
4752
+ encryption_spec_key_name = _TEST_DEFAULT_ENCRYPTION_KEY_NAME ,
4753
+ )
4754
+
4755
+ job = training_jobs .CustomPythonPackageTrainingJob (
4756
+ display_name = _TEST_DISPLAY_NAME ,
4757
+ labels = _TEST_LABELS ,
4758
+ python_package_gcs_uri = _TEST_OUTPUT_PYTHON_PACKAGE_PATH ,
4759
+ python_module_name = _TEST_PYTHON_MODULE_NAME ,
4760
+ container_uri = _TEST_TRAINING_CONTAINER_IMAGE ,
4761
+ model_serving_container_image_uri = _TEST_SERVING_CONTAINER_IMAGE ,
4762
+ model_serving_container_predict_route = _TEST_SERVING_CONTAINER_PREDICTION_ROUTE ,
4763
+ model_serving_container_health_route = _TEST_SERVING_CONTAINER_HEALTH_ROUTE ,
4764
+ model_serving_container_command = _TEST_MODEL_SERVING_CONTAINER_COMMAND ,
4765
+ model_serving_container_args = _TEST_MODEL_SERVING_CONTAINER_ARGS ,
4766
+ model_serving_container_environment_variables = _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES ,
4767
+ model_serving_container_ports = _TEST_MODEL_SERVING_CONTAINER_PORTS ,
4768
+ model_description = _TEST_MODEL_DESCRIPTION ,
4769
+ model_instance_schema_uri = _TEST_MODEL_INSTANCE_SCHEMA_URI ,
4770
+ model_parameters_schema_uri = _TEST_MODEL_PARAMETERS_SCHEMA_URI ,
4771
+ model_prediction_schema_uri = _TEST_MODEL_PREDICTION_SCHEMA_URI ,
4772
+ )
4773
+
4774
+ model_from_job = job .run (
4775
+ dataset = mock_tabular_dataset ,
4776
+ model_display_name = _TEST_MODEL_DISPLAY_NAME ,
4777
+ model_labels = _TEST_MODEL_LABELS ,
4778
+ base_output_dir = _TEST_BASE_OUTPUT_DIR ,
4779
+ service_account = _TEST_SERVICE_ACCOUNT ,
4780
+ network = _TEST_NETWORK ,
4781
+ args = _TEST_RUN_ARGS ,
4782
+ environment_variables = _TEST_ENVIRONMENT_VARIABLES ,
4783
+ machine_type = _TEST_MACHINE_TYPE ,
4784
+ accelerator_type = _TEST_ACCELERATOR_TYPE ,
4785
+ accelerator_count = _TEST_ACCELERATOR_COUNT ,
4786
+ training_fraction_split = _TEST_TRAINING_FRACTION_SPLIT ,
4787
+ validation_fraction_split = _TEST_VALIDATION_FRACTION_SPLIT ,
4788
+ test_fraction_split = _TEST_TEST_FRACTION_SPLIT ,
4789
+ sync = sync ,
4790
+ )
4791
+
4792
+ if not sync :
4793
+ model_from_job .wait ()
4794
+
4795
+ true_args = _TEST_RUN_ARGS
4796
+ true_env = [
4797
+ {"name" : key , "value" : value }
4798
+ for key , value in _TEST_ENVIRONMENT_VARIABLES .items ()
4799
+ ]
4800
+
4801
+ true_worker_pool_spec = {
4802
+ "replica_count" : _TEST_REPLICA_COUNT ,
4803
+ "machine_spec" : {
4804
+ "machine_type" : _TEST_MACHINE_TYPE ,
4805
+ "accelerator_type" : _TEST_ACCELERATOR_TYPE ,
4806
+ "accelerator_count" : _TEST_ACCELERATOR_COUNT ,
4807
+ },
4808
+ "disk_spec" : {
4809
+ "boot_disk_type" : _TEST_BOOT_DISK_TYPE_DEFAULT ,
4810
+ "boot_disk_size_gb" : _TEST_BOOT_DISK_SIZE_GB_DEFAULT ,
4811
+ },
4812
+ "python_package_spec" : {
4813
+ "executor_image_uri" : _TEST_TRAINING_CONTAINER_IMAGE ,
4814
+ "python_module" : _TEST_PYTHON_MODULE_NAME ,
4815
+ "package_uris" : [_TEST_OUTPUT_PYTHON_PACKAGE_PATH ],
4816
+ "args" : true_args ,
4817
+ "env" : true_env ,
4818
+ },
4819
+ }
4820
+
4821
+ true_fraction_split = gca_training_pipeline .FractionSplit (
4822
+ training_fraction = _TEST_TRAINING_FRACTION_SPLIT ,
4823
+ validation_fraction = _TEST_VALIDATION_FRACTION_SPLIT ,
4824
+ test_fraction = _TEST_TEST_FRACTION_SPLIT ,
4825
+ )
4826
+
4827
+ env = [
4828
+ gca_env_var .EnvVar (name = str (key ), value = str (value ))
4829
+ for key , value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES .items ()
4830
+ ]
4831
+
4832
+ ports = [
4833
+ gca_model .Port (container_port = port )
4834
+ for port in _TEST_MODEL_SERVING_CONTAINER_PORTS
4835
+ ]
4836
+
4837
+ true_container_spec = gca_model .ModelContainerSpec (
4838
+ image_uri = _TEST_SERVING_CONTAINER_IMAGE ,
4839
+ predict_route = _TEST_SERVING_CONTAINER_PREDICTION_ROUTE ,
4840
+ health_route = _TEST_SERVING_CONTAINER_HEALTH_ROUTE ,
4841
+ command = _TEST_MODEL_SERVING_CONTAINER_COMMAND ,
4842
+ args = _TEST_MODEL_SERVING_CONTAINER_ARGS ,
4843
+ env = env ,
4844
+ ports = ports ,
4845
+ )
4846
+
4847
+ true_managed_model = gca_model .Model (
4848
+ display_name = _TEST_MODEL_DISPLAY_NAME ,
4849
+ labels = _TEST_MODEL_LABELS ,
4850
+ description = _TEST_MODEL_DESCRIPTION ,
4851
+ container_spec = true_container_spec ,
4852
+ predict_schemata = gca_model .PredictSchemata (
4853
+ instance_schema_uri = _TEST_MODEL_INSTANCE_SCHEMA_URI ,
4854
+ parameters_schema_uri = _TEST_MODEL_PARAMETERS_SCHEMA_URI ,
4855
+ prediction_schema_uri = _TEST_MODEL_PREDICTION_SCHEMA_URI ,
4856
+ ),
4857
+ encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
4858
+ )
4859
+
4860
+ true_input_data_config = gca_training_pipeline .InputDataConfig (
4861
+ fraction_split = true_fraction_split ,
4862
+ dataset_id = mock_tabular_dataset .name ,
4863
+ gcs_destination = gca_io .GcsDestination (
4864
+ output_uri_prefix = _TEST_BASE_OUTPUT_DIR
4865
+ ),
4866
+ )
4867
+
4868
+ true_training_pipeline = gca_training_pipeline .TrainingPipeline (
4869
+ display_name = _TEST_DISPLAY_NAME ,
4870
+ labels = _TEST_LABELS ,
4871
+ training_task_definition = schema .training_job .definition .custom_task ,
4872
+ training_task_inputs = json_format .ParseDict (
4873
+ {
4874
+ "worker_pool_specs" : [true_worker_pool_spec ],
4875
+ "base_output_directory" : {
4876
+ "output_uri_prefix" : _TEST_BASE_OUTPUT_DIR
4877
+ },
4878
+ "service_account" : _TEST_SERVICE_ACCOUNT ,
4879
+ "network" : _TEST_NETWORK ,
4880
+ },
4881
+ struct_pb2 .Value (),
4882
+ ),
4883
+ model_to_upload = true_managed_model ,
4884
+ input_data_config = true_input_data_config ,
4885
+ encryption_spec = _TEST_DEFAULT_ENCRYPTION_SPEC ,
4886
+ )
4887
+
4888
+ mock_pipeline_service_create .assert_called_once_with (
4889
+ parent = initializer .global_config .common_location_path (),
4890
+ training_pipeline = true_training_pipeline ,
4891
+ timeout = None ,
4892
+ )
4893
+
4740
4894
@pytest .mark .parametrize ("sync" , [True , False ])
4741
4895
def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name_nor_model_labels (
4742
4896
self ,
0 commit comments