|
53 | 53 | _TEST_PIPELINE_JOB_DISPLAY_NAME = "sample-pipeline-job-display-name"
|
54 | 54 | _TEST_PIPELINE_JOB_ID = "sample-test-pipeline-202111111"
|
55 | 55 | _TEST_GCS_BUCKET_NAME = "my-bucket"
|
| 56 | +_TEST_GCS_OUTPUT_DIRECTORY = f"gs://{_TEST_GCS_BUCKET_NAME}/output_artifacts/" |
56 | 57 | _TEST_CREDENTIALS = auth_credentials.AnonymousCredentials()
|
57 | 58 | _TEST_SERVICE_ACCOUNT = "[email protected]"
|
58 | 59 |
|
@@ -249,7 +250,7 @@ def mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
|
249 | 250 |
|
250 | 251 | with mock.patch(
|
251 | 252 | "google.cloud.aiplatform.utils.gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist",
|
252 |
| - new=mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist, |
| 253 | + wraps=mock_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist, |
253 | 254 | ) as mock_context:
|
254 | 255 | yield mock_context
|
255 | 256 |
|
@@ -1097,6 +1098,44 @@ def test_submit_call_pipeline_service_pipeline_job_create(
|
1097 | 1098 | gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
|
1098 | 1099 | )
|
1099 | 1100 |
|
| 1101 | + @pytest.mark.parametrize( |
| 1102 | + "job_spec", |
| 1103 | + [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB], |
| 1104 | + ) |
| 1105 | + def test_submit_call_gcs_utils_get_or_create_with_correct_arguments( |
| 1106 | + self, |
| 1107 | + mock_pipeline_service_create, |
| 1108 | + mock_pipeline_service_get, |
| 1109 | + mock_pipeline_bucket_exists, |
| 1110 | + job_spec, |
| 1111 | + mock_load_yaml_and_json, |
| 1112 | + ): |
| 1113 | + job = pipeline_jobs.PipelineJob( |
| 1114 | + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, |
| 1115 | + template_path=_TEST_TEMPLATE_PATH, |
| 1116 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 1117 | + parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, |
| 1118 | + enable_caching=True, |
| 1119 | + project=_TEST_PROJECT, |
| 1120 | + pipeline_root=_TEST_GCS_OUTPUT_DIRECTORY, |
| 1121 | + location=_TEST_LOCATION, |
| 1122 | + credentials=_TEST_CREDENTIALS, |
| 1123 | + ) |
| 1124 | + |
| 1125 | + job.submit( |
| 1126 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 1127 | + network=_TEST_NETWORK, |
| 1128 | + create_request_timeout=None, |
| 1129 | + ) |
| 1130 | + |
| 1131 | + mock_pipeline_bucket_exists.assert_called_once_with( |
| 1132 | + output_artifacts_gcs_dir=_TEST_GCS_OUTPUT_DIRECTORY, |
| 1133 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 1134 | + project=_TEST_PROJECT, |
| 1135 | + location=_TEST_LOCATION, |
| 1136 | + credentials=_TEST_CREDENTIALS, |
| 1137 | + ) |
| 1138 | + |
1100 | 1139 | @pytest.mark.parametrize(
|
1101 | 1140 | "job_spec",
|
1102 | 1141 | [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
|
|
0 commit comments