@@ -1007,6 +1007,65 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
1007
1007
1008
1008
assert job ._has_logged_custom_job
1009
1009
1010
+ def test_custom_training_tabular_done (
1011
+ self ,
1012
+ mock_pipeline_service_create ,
1013
+ mock_pipeline_service_get ,
1014
+ mock_python_package_to_gcs ,
1015
+ mock_tabular_dataset ,
1016
+ mock_model_service_get ,
1017
+ ):
1018
+ aiplatform .init (
1019
+ project = _TEST_PROJECT ,
1020
+ staging_bucket = _TEST_BUCKET_NAME ,
1021
+ credentials = _TEST_CREDENTIALS ,
1022
+ encryption_spec_key_name = _TEST_DEFAULT_ENCRYPTION_KEY_NAME ,
1023
+ )
1024
+
1025
+ job = training_jobs .CustomTrainingJob (
1026
+ display_name = _TEST_DISPLAY_NAME ,
1027
+ labels = _TEST_LABELS ,
1028
+ script_path = _TEST_LOCAL_SCRIPT_FILE_NAME ,
1029
+ container_uri = _TEST_TRAINING_CONTAINER_IMAGE ,
1030
+ model_serving_container_image_uri = _TEST_SERVING_CONTAINER_IMAGE ,
1031
+ model_serving_container_predict_route = _TEST_SERVING_CONTAINER_PREDICTION_ROUTE ,
1032
+ model_serving_container_health_route = _TEST_SERVING_CONTAINER_HEALTH_ROUTE ,
1033
+ model_instance_schema_uri = _TEST_MODEL_INSTANCE_SCHEMA_URI ,
1034
+ model_parameters_schema_uri = _TEST_MODEL_PARAMETERS_SCHEMA_URI ,
1035
+ model_prediction_schema_uri = _TEST_MODEL_PREDICTION_SCHEMA_URI ,
1036
+ model_serving_container_command = _TEST_MODEL_SERVING_CONTAINER_COMMAND ,
1037
+ model_serving_container_args = _TEST_MODEL_SERVING_CONTAINER_ARGS ,
1038
+ model_serving_container_environment_variables = _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES ,
1039
+ model_serving_container_ports = _TEST_MODEL_SERVING_CONTAINER_PORTS ,
1040
+ model_description = _TEST_MODEL_DESCRIPTION ,
1041
+ )
1042
+
1043
+ job .run (
1044
+ dataset = mock_tabular_dataset ,
1045
+ base_output_dir = _TEST_BASE_OUTPUT_DIR ,
1046
+ service_account = _TEST_SERVICE_ACCOUNT ,
1047
+ network = _TEST_NETWORK ,
1048
+ args = _TEST_RUN_ARGS ,
1049
+ environment_variables = _TEST_ENVIRONMENT_VARIABLES ,
1050
+ machine_type = _TEST_MACHINE_TYPE ,
1051
+ accelerator_type = _TEST_ACCELERATOR_TYPE ,
1052
+ accelerator_count = _TEST_ACCELERATOR_COUNT ,
1053
+ model_display_name = _TEST_MODEL_DISPLAY_NAME ,
1054
+ model_labels = _TEST_MODEL_LABELS ,
1055
+ training_fraction_split = _TEST_TRAINING_FRACTION_SPLIT ,
1056
+ validation_fraction_split = _TEST_VALIDATION_FRACTION_SPLIT ,
1057
+ test_fraction_split = _TEST_TEST_FRACTION_SPLIT ,
1058
+ timestamp_split_column_name = _TEST_TIMESTAMP_SPLIT_COLUMN_NAME ,
1059
+ tensorboard = _TEST_TENSORBOARD_RESOURCE_NAME ,
1060
+ sync = False ,
1061
+ )
1062
+
1063
+ assert job .done () is False
1064
+
1065
+ job .wait ()
1066
+
1067
+ assert job .done () is True
1068
+
1010
1069
@pytest .mark .parametrize ("sync" , [True , False ])
1011
1070
def test_run_call_pipeline_service_create_with_bigquery_destination (
1012
1071
self ,
@@ -2323,6 +2382,59 @@ def setup_method(self):
2323
2382
def teardown_method (self ):
2324
2383
initializer .global_pool .shutdown (wait = True )
2325
2384
2385
+ def test_custom_container_training_tabular_done (
2386
+ self ,
2387
+ mock_pipeline_service_create ,
2388
+ mock_pipeline_service_get ,
2389
+ mock_tabular_dataset ,
2390
+ mock_model_service_get ,
2391
+ ):
2392
+ aiplatform .init (
2393
+ project = _TEST_PROJECT ,
2394
+ staging_bucket = _TEST_BUCKET_NAME ,
2395
+ encryption_spec_key_name = _TEST_DEFAULT_ENCRYPTION_KEY_NAME ,
2396
+ )
2397
+
2398
+ job = training_jobs .CustomContainerTrainingJob (
2399
+ display_name = _TEST_DISPLAY_NAME ,
2400
+ labels = _TEST_LABELS ,
2401
+ container_uri = _TEST_TRAINING_CONTAINER_IMAGE ,
2402
+ command = _TEST_TRAINING_CONTAINER_CMD ,
2403
+ model_serving_container_image_uri = _TEST_SERVING_CONTAINER_IMAGE ,
2404
+ model_serving_container_predict_route = _TEST_SERVING_CONTAINER_PREDICTION_ROUTE ,
2405
+ model_serving_container_health_route = _TEST_SERVING_CONTAINER_HEALTH_ROUTE ,
2406
+ model_instance_schema_uri = _TEST_MODEL_INSTANCE_SCHEMA_URI ,
2407
+ model_parameters_schema_uri = _TEST_MODEL_PARAMETERS_SCHEMA_URI ,
2408
+ model_prediction_schema_uri = _TEST_MODEL_PREDICTION_SCHEMA_URI ,
2409
+ model_serving_container_command = _TEST_MODEL_SERVING_CONTAINER_COMMAND ,
2410
+ model_serving_container_args = _TEST_MODEL_SERVING_CONTAINER_ARGS ,
2411
+ model_serving_container_environment_variables = _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES ,
2412
+ model_serving_container_ports = _TEST_MODEL_SERVING_CONTAINER_PORTS ,
2413
+ model_description = _TEST_MODEL_DESCRIPTION ,
2414
+ )
2415
+
2416
+ job .run (
2417
+ dataset = mock_tabular_dataset ,
2418
+ base_output_dir = _TEST_BASE_OUTPUT_DIR ,
2419
+ args = _TEST_RUN_ARGS ,
2420
+ environment_variables = _TEST_ENVIRONMENT_VARIABLES ,
2421
+ machine_type = _TEST_MACHINE_TYPE ,
2422
+ accelerator_type = _TEST_ACCELERATOR_TYPE ,
2423
+ accelerator_count = _TEST_ACCELERATOR_COUNT ,
2424
+ model_display_name = _TEST_MODEL_DISPLAY_NAME ,
2425
+ model_labels = _TEST_MODEL_LABELS ,
2426
+ predefined_split_column_name = _TEST_PREDEFINED_SPLIT_COLUMN_NAME ,
2427
+ service_account = _TEST_SERVICE_ACCOUNT ,
2428
+ tensorboard = _TEST_TENSORBOARD_RESOURCE_NAME ,
2429
+ sync = False ,
2430
+ )
2431
+
2432
+ assert job .done () is False
2433
+
2434
+ job .wait ()
2435
+
2436
+ assert job .done () is True
2437
+
2326
2438
@pytest .mark .parametrize ("sync" , [True , False ])
2327
2439
def test_run_call_pipeline_service_create_with_tabular_dataset (
2328
2440
self ,
0 commit comments