|
42 | 42 | Artifact as GapicArtifact,
|
43 | 43 | Context as GapicContext,
|
44 | 44 | Execution as GapicExecution,
|
| 45 | + JobServiceClient, |
45 | 46 | MetadataServiceClient,
|
46 | 47 | AddExecutionEventsResponse,
|
47 | 48 | MetadataStore as GapicMetadataStore,
|
@@ -686,6 +687,21 @@ def test_get_pipeline_df_wrong_schema(self):
|
686 | 687 | _EXPERIMENT_RUN_MOCK_WITH_PARENT_EXPERIMENT = copy.deepcopy(_EXPERIMENT_RUN_MOCK)
|
687 | 688 | _EXPERIMENT_RUN_MOCK_WITH_PARENT_EXPERIMENT.parent_contexts = [_TEST_CONTEXT_NAME]
|
688 | 689 |
|
| 690 | +_TEST_CUSTOM_JOB_NAME = ( |
| 691 | + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/customJobs/12345" |
| 692 | +) |
| 693 | +_TEST_CUSTOM_JOB_CONSOLE_URI = "test-custom-job-console-uri" |
| 694 | + |
| 695 | +_EXPERIMENT_RUN_MOCK_WITH_CUSTOM_JOBS = copy.deepcopy( |
| 696 | + _EXPERIMENT_RUN_MOCK_WITH_PARENT_EXPERIMENT |
| 697 | +) |
| 698 | +_EXPERIMENT_RUN_MOCK_WITH_CUSTOM_JOBS.metadata[constants._CUSTOM_JOB_KEY] = [ |
| 699 | + { |
| 700 | + constants._CUSTOM_JOB_RESOURCE_NAME: _TEST_CUSTOM_JOB_NAME, |
| 701 | + constants._CUSTOM_JOB_CONSOLE_URI: _TEST_CUSTOM_JOB_CONSOLE_URI, |
| 702 | + }, |
| 703 | +] |
| 704 | + |
689 | 705 |
|
690 | 706 | @pytest.fixture
|
691 | 707 | def get_experiment_mock():
|
@@ -724,6 +740,17 @@ def get_experiment_run_mock():
|
724 | 740 | yield get_context_mock
|
725 | 741 |
|
726 | 742 |
|
| 743 | +@pytest.fixture |
| 744 | +def get_experiment_run_with_custom_jobs_mock(): |
| 745 | + with patch.object(MetadataServiceClient, "get_context") as get_context_mock: |
| 746 | + get_context_mock.side_effect = [ |
| 747 | + _EXPERIMENT_MOCK, |
| 748 | + _EXPERIMENT_RUN_MOCK_WITH_CUSTOM_JOBS, |
| 749 | + ] |
| 750 | + |
| 751 | + yield get_context_mock |
| 752 | + |
| 753 | + |
727 | 754 | @pytest.fixture
|
728 | 755 | def get_experiment_run_not_found_mock():
|
729 | 756 | with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
|
@@ -831,6 +858,12 @@ def add_context_children_mock():
|
831 | 858 | yield add_context_children_mock
|
832 | 859 |
|
833 | 860 |
|
| 861 | +@pytest.fixture |
| 862 | +def get_custom_job_mock(): |
| 863 | + with patch.object(JobServiceClient, "get_custom_job") as get_custom_job_mock: |
| 864 | + yield get_custom_job_mock |
| 865 | + |
| 866 | + |
834 | 867 | _EXPERIMENT_RUN_MOCK_POPULATED_1 = copy.deepcopy(
|
835 | 868 | _EXPERIMENT_RUN_MOCK_WITH_PARENT_EXPERIMENT
|
836 | 869 | )
|
@@ -1869,3 +1902,19 @@ def test_get_experiment_df_wrong_schema(self):
|
1869 | 1902 | aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
|
1870 | 1903 | with pytest.raises(ValueError):
|
1871 | 1904 | aiplatform.get_experiment_df(_TEST_EXPERIMENT)
|
| 1905 | + |
| 1906 | + @pytest.mark.usefixtures( |
| 1907 | + "get_experiment_run_with_custom_jobs_mock", |
| 1908 | + "get_metadata_store_mock", |
| 1909 | + "get_tensorboard_run_artifact_not_found_mock", |
| 1910 | + ) |
| 1911 | + def test_experiment_run_get_logged_custom_jobs(self, get_custom_job_mock): |
| 1912 | + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) |
| 1913 | + run = aiplatform.ExperimentRun(_TEST_RUN, experiment=_TEST_EXPERIMENT) |
| 1914 | + jobs = run.get_logged_custom_jobs() |
| 1915 | + |
| 1916 | + assert len(jobs) == 1 |
| 1917 | + get_custom_job_mock.assert_called_once_with( |
| 1918 | + name=_TEST_CUSTOM_JOB_NAME, |
| 1919 | + retry=base._DEFAULT_RETRY, |
| 1920 | + ) |
0 commit comments