Skip to content

Commit c116b07

Browse files
jaycee-licopybara-github
authored andcommitted
feat: add ExperimentRun.get_logged_custom_jobs method
PiperOrigin-RevId: 523824096
1 parent f837e0e commit c116b07

File tree

5 files changed

+79
-11
lines changed

5 files changed

+79
-11
lines changed

google/cloud/aiplatform/constants/base.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,9 @@
114114

115115
# Used in CustomJob.from_local_script for experiments integration in training
116116
AIPLATFORM_DEPENDENCY_PATH = (
117-
"google-cloud-aiplatform[metadata,tensorboard]"
118-
+ f"=={aiplatform_version.__version__}"
117+
f"google-cloud-aiplatform=={aiplatform_version.__version__}"
119118
)
120119

121120
AIPLATFORM_AUTOLOG_DEPENDENCY_PATH = (
122-
"google-cloud-aiplatform[metadata,tensorboard,autologging]"
123-
+ f"=={aiplatform_version.__version__}"
121+
f"google-cloud-aiplatform[autologging]=={aiplatform_version.__version__}"
124122
)

google/cloud/aiplatform/jobs.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1488,10 +1488,12 @@ def from_local_script(
14881488
).pool_specs
14891489
)
14901490

1491+
# if users enable autolog, automatically install SDK in their container image
1492+
# otherwise users need to manually install SDK
14911493
if enable_autolog:
14921494
experiment_requirements = [constants.AIPLATFORM_AUTOLOG_DEPENDENCY_PATH]
14931495
else:
1494-
experiment_requirements = [constants.AIPLATFORM_DEPENDENCY_PATH]
1496+
experiment_requirements = []
14951497

14961498
if requirements:
14971499
requirements.extend(experiment_requirements)

google/cloud/aiplatform/metadata/experiment_run_resource.py

+19
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.cloud.aiplatform import base
2626
from google.cloud.aiplatform import initializer
2727
from google.cloud.aiplatform import pipeline_jobs
28+
from google.cloud.aiplatform import jobs
2829
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
2930
from google.cloud.aiplatform.compat.types import execution as gca_execution
3031
from google.cloud.aiplatform.compat.types import (
@@ -1304,6 +1305,24 @@ def get_logged_pipeline_jobs(self) -> List[pipeline_jobs.PipelineJob]:
13041305
for c in pipeline_job_contexts
13051306
]
13061307

1308+
@_v1_not_supported
1309+
def get_logged_custom_jobs(self) -> List[jobs.CustomJob]:
1310+
"""Get all CustomJobs associated to this experiment run.
1311+
1312+
Returns:
1313+
List of CustomJobs associated this run.
1314+
"""
1315+
1316+
custom_jobs = self._metadata_node.metadata.get(constants._CUSTOM_JOB_KEY)
1317+
1318+
return [
1319+
jobs.CustomJob.get(
1320+
resource_name=custom_job.get(constants._CUSTOM_JOB_RESOURCE_NAME),
1321+
credentials=self.credentials,
1322+
)
1323+
for custom_job in custom_jobs
1324+
]
1325+
13071326
def __enter__(self):
13081327
return self
13091328

tests/system/aiplatform/test_custom_job.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@
4040
@mock.patch.object(
4141
constants,
4242
"AIPLATFORM_DEPENDENCY_PATH",
43-
"google-cloud-aiplatform[metadata,tensorboard] @ git+https://github.com/googleapis/"
43+
"google-cloud-aiplatform @ git+https://github.com/googleapis/"
4444
f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}#egg=google-cloud-aiplatform"
4545
if os.environ.get("KOKORO_GIT_COMMIT")
4646
else constants.AIPLATFORM_DEPENDENCY_PATH,
4747
)
4848
@mock.patch.object(
4949
constants,
5050
"AIPLATFORM_AUTOLOG_DEPENDENCY_PATH",
51-
"google-cloud-aiplatform[metadata,tensorboard,autologging] @ git+https://github.com/googleapis/"
51+
"google-cloud-aiplatform[autologging] @ git+https://github.com/googleapis/"
5252
f"python-aiplatform.git@{os.environ['KOKORO_GIT_COMMIT']}#egg=google-cloud-aiplatform"
5353
if os.environ.get("KOKORO_GIT_COMMIT")
5454
else constants.AIPLATFORM_AUTOLOG_DEPENDENCY_PATH,
@@ -88,7 +88,7 @@ def test_from_local_script_prebuilt_container(self, shared_state):
8888
display_name=display_name,
8989
script_path=_LOCAL_TRAINING_SCRIPT_PATH,
9090
container_uri=_PREBUILT_CONTAINER_IMAGE,
91-
requirements=["scikit-learn"],
91+
requirements=["scikit-learn", "pandas"],
9292
)
9393
custom_job.run()
9494

@@ -110,7 +110,7 @@ def test_from_local_script_custom_container(self, shared_state):
110110
display_name=display_name,
111111
script_path=_LOCAL_TRAINING_SCRIPT_PATH,
112112
container_uri=_CUSTOM_CONTAINER_IMAGE,
113-
requirements=["scikit-learn"],
113+
requirements=["scikit-learn", "pandas"],
114114
)
115115
custom_job.run()
116116

@@ -139,7 +139,7 @@ def test_from_local_script_enable_autolog_prebuilt_container(self, shared_state)
139139
display_name=display_name,
140140
script_path=_LOCAL_TRAINING_SCRIPT_PATH,
141141
container_uri=_PREBUILT_CONTAINER_IMAGE,
142-
requirements=["scikit-learn"],
142+
requirements=["scikit-learn", "pandas"],
143143
enable_autolog=True,
144144
)
145145

@@ -169,7 +169,7 @@ def test_from_local_script_enable_autolog_custom_container(self, shared_state):
169169
display_name=display_name,
170170
script_path=_LOCAL_TRAINING_SCRIPT_PATH,
171171
container_uri=_CUSTOM_CONTAINER_IMAGE,
172-
requirements=["scikit-learn"],
172+
requirements=["scikit-learn", "pandas"],
173173
enable_autolog=True,
174174
)
175175

tests/unit/aiplatform/test_metadata.py

+49
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
Artifact as GapicArtifact,
4343
Context as GapicContext,
4444
Execution as GapicExecution,
45+
JobServiceClient,
4546
MetadataServiceClient,
4647
AddExecutionEventsResponse,
4748
MetadataStore as GapicMetadataStore,
@@ -686,6 +687,21 @@ def test_get_pipeline_df_wrong_schema(self):
686687
_EXPERIMENT_RUN_MOCK_WITH_PARENT_EXPERIMENT = copy.deepcopy(_EXPERIMENT_RUN_MOCK)
687688
_EXPERIMENT_RUN_MOCK_WITH_PARENT_EXPERIMENT.parent_contexts = [_TEST_CONTEXT_NAME]
688689

690+
_TEST_CUSTOM_JOB_NAME = (
691+
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/customJobs/12345"
692+
)
693+
_TEST_CUSTOM_JOB_CONSOLE_URI = "test-custom-job-console-uri"
694+
695+
_EXPERIMENT_RUN_MOCK_WITH_CUSTOM_JOBS = copy.deepcopy(
696+
_EXPERIMENT_RUN_MOCK_WITH_PARENT_EXPERIMENT
697+
)
698+
_EXPERIMENT_RUN_MOCK_WITH_CUSTOM_JOBS.metadata[constants._CUSTOM_JOB_KEY] = [
699+
{
700+
constants._CUSTOM_JOB_RESOURCE_NAME: _TEST_CUSTOM_JOB_NAME,
701+
constants._CUSTOM_JOB_CONSOLE_URI: _TEST_CUSTOM_JOB_CONSOLE_URI,
702+
},
703+
]
704+
689705

690706
@pytest.fixture
691707
def get_experiment_mock():
@@ -724,6 +740,17 @@ def get_experiment_run_mock():
724740
yield get_context_mock
725741

726742

743+
@pytest.fixture
744+
def get_experiment_run_with_custom_jobs_mock():
745+
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
746+
get_context_mock.side_effect = [
747+
_EXPERIMENT_MOCK,
748+
_EXPERIMENT_RUN_MOCK_WITH_CUSTOM_JOBS,
749+
]
750+
751+
yield get_context_mock
752+
753+
727754
@pytest.fixture
728755
def get_experiment_run_not_found_mock():
729756
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
@@ -831,6 +858,12 @@ def add_context_children_mock():
831858
yield add_context_children_mock
832859

833860

861+
@pytest.fixture
862+
def get_custom_job_mock():
863+
with patch.object(JobServiceClient, "get_custom_job") as get_custom_job_mock:
864+
yield get_custom_job_mock
865+
866+
834867
_EXPERIMENT_RUN_MOCK_POPULATED_1 = copy.deepcopy(
835868
_EXPERIMENT_RUN_MOCK_WITH_PARENT_EXPERIMENT
836869
)
@@ -1869,3 +1902,19 @@ def test_get_experiment_df_wrong_schema(self):
18691902
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
18701903
with pytest.raises(ValueError):
18711904
aiplatform.get_experiment_df(_TEST_EXPERIMENT)
1905+
1906+
@pytest.mark.usefixtures(
1907+
"get_experiment_run_with_custom_jobs_mock",
1908+
"get_metadata_store_mock",
1909+
"get_tensorboard_run_artifact_not_found_mock",
1910+
)
1911+
def test_experiment_run_get_logged_custom_jobs(self, get_custom_job_mock):
1912+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
1913+
run = aiplatform.ExperimentRun(_TEST_RUN, experiment=_TEST_EXPERIMENT)
1914+
jobs = run.get_logged_custom_jobs()
1915+
1916+
assert len(jobs) == 1
1917+
get_custom_job_mock.assert_called_once_with(
1918+
name=_TEST_CUSTOM_JOB_NAME,
1919+
retry=base._DEFAULT_RETRY,
1920+
)

0 commit comments

Comments
 (0)