Skip to content

Commit 41cd943

Browse files
jaycee-licopybara-github
authored andcommitted
feat: add get method for Experiment and ExperimentRun
PiperOrigin-RevId: 518042292
1 parent 9fa3c68 commit 41cd943

File tree

3 files changed

+181
-4
lines changed

3 files changed

+181
-4
lines changed

google/cloud/aiplatform/metadata/experiment_resources.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -18,8 +18,9 @@
1818
import abc
1919
from dataclasses import dataclass
2020
import logging
21-
from typing import Dict, List, NamedTuple, Optional, Union, Tuple, Type
21+
from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union
2222

23+
from google.api_core import exceptions
2324
from google.auth import credentials as auth_credentials
2425

2526
from google.cloud.aiplatform import base
@@ -211,6 +212,43 @@ def create(
211212

212213
return self
213214

215+
@classmethod
216+
def get(
217+
cls,
218+
experiment_name: str,
219+
*,
220+
project: Optional[str] = None,
221+
location: Optional[str] = None,
222+
credentials: Optional[auth_credentials.Credentials] = None,
223+
) -> Optional["Experiment"]:
224+
"""Gets experiment if one exists with this experiment_name in Vertex AI Experiments.
225+
226+
Args:
227+
experiment_name (str):
228+
Required. The name of this experiment.
229+
project (str):
230+
Optional. Project used to retrieve this resource.
231+
Overrides project set in aiplatform.init.
232+
location (str):
233+
Optional. Location used to retrieve this resource.
234+
Overrides location set in aiplatform.init.
235+
credentials (auth_credentials.Credentials):
236+
Optional. Custom credentials used to retrieve this resource.
237+
Overrides credentials set in aiplatform.init.
238+
239+
Returns:
240+
Vertex AI experiment or None if no resource was found.
241+
"""
242+
try:
243+
return cls(
244+
experiment_name=experiment_name,
245+
project=project,
246+
location=location,
247+
credentials=credentials,
248+
)
249+
except exceptions.NotFound:
250+
return None
251+
214252
@classmethod
215253
def get_or_create(
216254
cls,

google/cloud/aiplatform/metadata/experiment_run_resource.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -386,6 +386,56 @@ def _lookup_tensorboard_run_artifact(
386386
metadata=tensorboard_run_artifact,
387387
)
388388

389+
@classmethod
390+
def get(
391+
cls,
392+
run_name: str,
393+
*,
394+
experiment: Optional[Union[experiment_resources.Experiment, str]] = None,
395+
project: Optional[str] = None,
396+
location: Optional[str] = None,
397+
credentials: Optional[auth_credentials.Credentials] = None,
398+
) -> Optional["ExperimentRun"]:
399+
"""Gets experiment run if one exists with this run_name.
400+
401+
Args:
402+
run_name (str):
403+
Required. The name of this run.
404+
experiment (Union[experiment_resources.Experiment, str]):
405+
Optional. The name or instance of this experiment.
406+
If not set, use the default experiment in `aiplatform.init`
407+
project (str):
408+
Optional. Project where this experiment run is located.
409+
Overrides project set in aiplatform.init.
410+
location (str):
411+
Optional. Location where this experiment run is located.
412+
Overrides location set in aiplatform.init.
413+
credentials (auth_credentials.Credentials):
414+
Optional. Custom credentials used to retrieve this experiment run.
415+
Overrides credentials set in aiplatform.init.
416+
417+
Returns:
418+
Vertex AI experimentRun or None if no resource was found.
419+
"""
420+
experiment = experiment or metadata._experiment_tracker.experiment
421+
422+
if not experiment:
423+
raise ValueError(
424+
"experiment must be provided or "
425+
"experiment should be set using aiplatform.init"
426+
)
427+
428+
try:
429+
return cls(
430+
run_name=run_name,
431+
experiment=experiment,
432+
project=project,
433+
location=location,
434+
credentials=credentials,
435+
)
436+
except exceptions.NotFound:
437+
return None
438+
389439
@classmethod
390440
def list(
391441
cls,

tests/unit/aiplatform/test_metadata.py

+90-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -280,6 +280,17 @@ def get_execution_mock():
280280
yield get_execution_mock
281281

282282

283+
@pytest.fixture
284+
def get_execution_not_found_mock():
285+
with patch.object(
286+
MetadataServiceClient, "get_execution"
287+
) as get_execution_not_found_mock:
288+
get_execution_not_found_mock.side_effect = exceptions.NotFound(
289+
"test: not found"
290+
)
291+
yield get_execution_not_found_mock
292+
293+
283294
@pytest.fixture
284295
def get_execution_wrong_schema_mock():
285296
with patch.object(
@@ -681,6 +692,13 @@ def get_experiment_mock():
681692
yield get_context_mock
682693

683694

695+
@pytest.fixture
696+
def get_experiment_not_found_mock():
697+
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
698+
get_context_mock.side_effect = exceptions.NotFound("test: not found")
699+
yield get_context_mock
700+
701+
684702
@pytest.fixture
685703
def get_experiment_run_run_mock():
686704
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
@@ -704,6 +722,17 @@ def get_experiment_run_mock():
704722
yield get_context_mock
705723

706724

725+
@pytest.fixture
726+
def get_experiment_run_not_found_mock():
727+
with patch.object(MetadataServiceClient, "get_context") as get_context_mock:
728+
get_context_mock.side_effect = [
729+
_EXPERIMENT_MOCK,
730+
exceptions.NotFound("test: not found"),
731+
]
732+
733+
yield get_context_mock
734+
735+
707736
@pytest.fixture
708737
def create_experiment_context_mock():
709738
with patch.object(MetadataServiceClient, "create_context") as create_context_mock:
@@ -1125,6 +1154,66 @@ def test_init_experiment_wrong_schema(self):
11251154
experiment=_TEST_EXPERIMENT,
11261155
)
11271156

1157+
def test_get_experiment(self, get_experiment_mock):
1158+
aiplatform.init(
1159+
project=_TEST_PROJECT,
1160+
location=_TEST_LOCATION,
1161+
)
1162+
1163+
exp = aiplatform.Experiment.get(_TEST_EXPERIMENT)
1164+
1165+
assert exp.name == _TEST_EXPERIMENT
1166+
get_experiment_mock.assert_called_with(
1167+
name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY
1168+
)
1169+
1170+
def test_get_experiment_not_found(self, get_experiment_not_found_mock):
1171+
aiplatform.init(
1172+
project=_TEST_PROJECT,
1173+
location=_TEST_LOCATION,
1174+
)
1175+
1176+
exp = aiplatform.Experiment.get(_TEST_EXPERIMENT)
1177+
1178+
assert exp is None
1179+
get_experiment_not_found_mock.assert_called_with(
1180+
name=_TEST_CONTEXT_NAME, retry=base._DEFAULT_RETRY
1181+
)
1182+
1183+
@pytest.mark.usefixtures(
1184+
"get_metadata_store_mock", "get_tensorboard_run_artifact_not_found_mock"
1185+
)
1186+
def test_get_experiment_run(self, get_experiment_run_mock):
1187+
aiplatform.init(
1188+
project=_TEST_PROJECT,
1189+
location=_TEST_LOCATION,
1190+
)
1191+
1192+
run = aiplatform.ExperimentRun.get(_TEST_RUN, experiment=_TEST_EXPERIMENT)
1193+
1194+
assert run.name == _TEST_RUN
1195+
get_experiment_run_mock.assert_called_with(
1196+
name=f"{_TEST_CONTEXT_NAME}-{_TEST_RUN}", retry=base._DEFAULT_RETRY
1197+
)
1198+
1199+
@pytest.mark.usefixtures(
1200+
"get_metadata_store_mock",
1201+
"get_tensorboard_run_artifact_not_found_mock",
1202+
"get_execution_not_found_mock",
1203+
)
1204+
def test_get_experiment_run_not_found(self, get_experiment_run_not_found_mock):
1205+
aiplatform.init(
1206+
project=_TEST_PROJECT,
1207+
location=_TEST_LOCATION,
1208+
)
1209+
1210+
run = aiplatform.ExperimentRun.get(_TEST_RUN, experiment=_TEST_EXPERIMENT)
1211+
1212+
assert run is None
1213+
get_experiment_run_not_found_mock.assert_called_with(
1214+
name=f"{_TEST_CONTEXT_NAME}-{_TEST_RUN}", retry=base._DEFAULT_RETRY
1215+
)
1216+
11281217
@pytest.mark.usefixtures("get_metadata_store_mock")
11291218
def test_start_run(
11301219
self,

0 commit comments

Comments
 (0)