Skip to content

Commit 7b9db49

Browse files
committed
Simplfied tests
1 parent 316c7b4 commit 7b9db49

File tree

1 file changed

+8
-47
lines changed

1 file changed

+8
-47
lines changed

tests/unit/aiplatform/test_training_jobs.py

+8-47
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from google.cloud import storage
7070
from google.protobuf import json_format
7171
from google.protobuf import struct_pb2
72-
72+
from google.protobuf import duration_pb2 # type: ignore
7373

7474
_TEST_BUCKET_NAME = "test-bucket"
7575
_TEST_GCS_PATH_WITHOUT_BUCKET = "path/to/folder"
@@ -211,9 +211,13 @@ def _get_custom_job_proto_with_scheduling(state=None, name=None, version="v1"):
211211
custom_job_proto.name = name
212212
custom_job_proto.state = state
213213

214-
custom_job_proto.job_spec.enable_web_access = _TEST_ENABLE_WEB_ACCESS
215-
if state == gca_job_state.JobState.JOB_STATE_RUNNING:
216-
custom_job_proto.web_access_uris = _TEST_WEB_ACCESS_URIS
214+
custom_job_proto.job_spec.scheduling.timeout = duration_pb2.Duration(
215+
seconds=_TEST_TIMEOUT
216+
)
217+
custom_job_proto.job_spec.scheduling.restart_job_on_worker_restart = (
218+
_TEST_RESTART_JOB_ON_WORKER_RESTART
219+
)
220+
217221
return custom_job_proto
218222

219223

@@ -321,40 +325,6 @@ def mock_get_backing_custom_job_with_enable_web_access():
321325
yield get_custom_job_mock
322326

323327

324-
@pytest.fixture
325-
def mock_get_backing_custom_job_with_scheduling():
326-
with patch.object(
327-
job_service_client.JobServiceClient, "get_custom_job"
328-
) as get_custom_job_mock:
329-
get_custom_job_mock.side_effect = [
330-
_get_custom_job_proto_with_scheduling(
331-
name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
332-
state=gca_job_state.JobState.JOB_STATE_PENDING,
333-
),
334-
_get_custom_job_proto_with_scheduling(
335-
name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
336-
state=gca_job_state.JobState.JOB_STATE_RUNNING,
337-
),
338-
_get_custom_job_proto_with_scheduling(
339-
name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
340-
state=gca_job_state.JobState.JOB_STATE_RUNNING,
341-
),
342-
_get_custom_job_proto_with_scheduling(
343-
name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
344-
state=gca_job_state.JobState.JOB_STATE_RUNNING,
345-
),
346-
_get_custom_job_proto_with_scheduling(
347-
name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
348-
state=gca_job_state.JobState.JOB_STATE_SUCCEEDED,
349-
),
350-
_get_custom_job_proto_with_scheduling(
351-
name=_TEST_CUSTOM_JOB_RESOURCE_NAME,
352-
state=gca_job_state.JobState.JOB_STATE_SUCCEEDED,
353-
),
354-
]
355-
yield get_custom_job_mock
356-
357-
358328
class TestTrainingScriptPythonPackagerHelpers:
359329
def setup_method(self):
360330
importlib.reload(initializer)
@@ -1505,14 +1475,11 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
15051475
@pytest.mark.usefixtures(
15061476
"mock_pipeline_service_create_with_scheduling",
15071477
"mock_pipeline_service_get_with_scheduling",
1508-
"mock_get_backing_custom_job_with_scheduling",
15091478
"mock_python_package_to_gcs",
15101479
)
15111480
@pytest.mark.parametrize("sync", [True, False])
15121481
def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
15131482

1514-
caplog.set_level(logging.INFO)
1515-
15161483
aiplatform.init(
15171484
project=_TEST_PROJECT,
15181485
staging_bucket=_TEST_BUCKET_NAME,
@@ -2952,13 +2919,10 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
29522919
@pytest.mark.usefixtures(
29532920
"mock_pipeline_service_create_with_scheduling",
29542921
"mock_pipeline_service_get_with_scheduling",
2955-
"mock_get_backing_custom_job_with_scheduling",
29562922
)
29572923
@pytest.mark.parametrize("sync", [True, False])
29582924
def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
29592925

2960-
caplog.set_level(logging.INFO)
2961-
29622926
aiplatform.init(
29632927
project=_TEST_PROJECT,
29642928
staging_bucket=_TEST_BUCKET_NAME,
@@ -4670,13 +4634,10 @@ def test_run_call_pipeline_service_create_with_enable_web_access(
46704634
@pytest.mark.usefixtures(
46714635
"mock_pipeline_service_create_with_scheduling",
46724636
"mock_pipeline_service_get_with_scheduling",
4673-
"mock_get_backing_custom_job_with_scheduling",
46744637
)
46754638
@pytest.mark.parametrize("sync", [True, False])
46764639
def test_run_call_pipeline_service_create_with_scheduling(self, sync, caplog):
46774640

4678-
caplog.set_level(logging.INFO)
4679-
46804641
aiplatform.init(
46814642
project=_TEST_PROJECT,
46824643
staging_bucket=_TEST_BUCKET_NAME,

0 commit comments

Comments
 (0)