|
69 | 69 | from google.cloud import storage
|
70 | 70 | from google.protobuf import json_format
|
71 | 71 | from google.protobuf import struct_pb2
|
72 |
| - |
| 72 | +from google.protobuf import duration_pb2 # type: ignore |
73 | 73 |
|
74 | 74 | _TEST_BUCKET_NAME = "test-bucket"
|
75 | 75 | _TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder"
|
@@ -211,9 +211,13 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"):
|
211 | 211 | custom_job_proto.name = name
|
212 | 212 | custom_job_proto.state = state
|
213 | 213 |
|
214 |
| - custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS |
215 |
| - if state == gca_job_state.JobState.JOB_STATE_RUNNING: |
216 |
| - custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS |
| 214 | + custom_job_proto.job_spec.scheduling.timeout = duration_pb2.Duration( |
| 215 | + seconds=_TEST_TIMEOUT |
| 216 | + ) |
| 217 | + custom_job_proto.job_spec.scheduling.restart_job_on_worker_restart = ( |
| 218 | + _TEST_RESTART_JOB_ON_WORKER_RESTART |
| 219 | + ) |
| 220 | + |
217 | 221 | return custom_job_proto
|
218 | 222 |
|
219 | 223 |
|
@@ -321,40 +325,6 @@ def mock_get_backing_custom_job_with_enable_web_access():
|
321 | 325 | yield get_custom_job_mock
|
322 | 326 |
|
323 | 327 |
|
324 |
| -@pytest.fixture |
325 |
| -def mock_get_backing_custom_job_with_scheduling(): |
326 |
| - with patch.object( |
327 |
| - job_service_client.JobServiceClient, "get_custom_job" |
328 |
| - ) as get_custom_job_mock: |
329 |
| - get_custom_job_mock.side_effect = [ |
330 |
| - _get_custom_job_proto_with_scheduling( |
331 |
| - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, |
332 |
| - state=gca_job_state.JobState.JOB_STATE_PENDING, |
333 |
| - ), |
334 |
| - _get_custom_job_proto_with_scheduling( |
335 |
| - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, |
336 |
| - state=gca_job_state.JobState.JOB_STATE_RUNNING, |
337 |
| - ), |
338 |
| - _get_custom_job_proto_with_scheduling( |
339 |
| - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, |
340 |
| - state=gca_job_state.JobState.JOB_STATE_RUNNING, |
341 |
| - ), |
342 |
| - _get_custom_job_proto_with_scheduling( |
343 |
| - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, |
344 |
| - state=gca_job_state.JobState.JOB_STATE_RUNNING, |
345 |
| - ), |
346 |
| - _get_custom_job_proto_with_scheduling( |
347 |
| - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, |
348 |
| - state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, |
349 |
| - ), |
350 |
| - _get_custom_job_proto_with_scheduling( |
351 |
| - name=_TEST_CUSTOM_JOB_RESOURCE_NAME, |
352 |
| - state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, |
353 |
| - ), |
354 |
| - ] |
355 |
| - yield get_custom_job_mock |
356 |
| - |
357 |
| - |
358 | 328 | class TestTrainingScriptPythonPackagerHelpers:
|
359 | 329 | def setup_method(self):
|
360 | 330 | importlib.reload(initializer)
|
@@ -1505,14 +1475,11 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
|
1505 | 1475 | @pytest.mark.usefixtures(
|
1506 | 1476 | "mock_pipeline_service_create_with_scheduling",
|
1507 | 1477 | "mock_pipeline_service_get_with_scheduling",
|
1508 |
| - "mock_get_backing_custom_job_with_scheduling", |
1509 | 1478 | "mock_python_package_to_gcs",
|
1510 | 1479 | )
|
1511 | 1480 | @pytest.mark.parametrize("sync", [True, False])
|
1512 | 1481 | def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
|
1513 | 1482 |
|
1514 |
| - caplog.set_level(logging.INFO) |
1515 |
| - |
1516 | 1483 | aiplatform.init(
|
1517 | 1484 | project=_TEST_PROJECT,
|
1518 | 1485 | staging_bucket=_TEST_BUCKET_NAME,
|
@@ -2952,13 +2919,10 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
|
2952 | 2919 | @pytest.mark.usefixtures(
|
2953 | 2920 | "mock_pipeline_service_create_with_scheduling",
|
2954 | 2921 | "mock_pipeline_service_get_with_scheduling",
|
2955 |
| - "mock_get_backing_custom_job_with_scheduling", |
2956 | 2922 | )
|
2957 | 2923 | @pytest.mark.parametrize("sync", [True, False])
|
2958 | 2924 | def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
|
2959 | 2925 |
|
2960 |
| - caplog.set_level(logging.INFO) |
2961 |
| - |
2962 | 2926 | aiplatform.init(
|
2963 | 2927 | project=_TEST_PROJECT,
|
2964 | 2928 | staging_bucket=_TEST_BUCKET_NAME,
|
@@ -4670,13 +4634,10 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
|
4670 | 4634 | @pytest.mark.usefixtures(
|
4671 | 4635 | "mock_pipeline_service_create_with_scheduling",
|
4672 | 4636 | "mock_pipeline_service_get_with_scheduling",
|
4673 |
| - "mock_get_backing_custom_job_with_scheduling", |
4674 | 4637 | )
|
4675 | 4638 | @pytest.mark.parametrize("sync", [True, False])
|
4676 | 4639 | def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
|
4677 | 4640 |
|
4678 |
| - caplog.set_level(logging.INFO) |
4679 |
| - |
4680 | 4641 | aiplatform.init(
|
4681 | 4642 | project=_TEST_PROJECT,
|
4682 | 4643 | staging_bucket=_TEST_BUCKET_NAME,
|
|
0 commit comments