Skip to content

Commit 369a0cc

Browse files
sararobcopybara-github
authored andcommitted
feat: enable passing experiment_tensorboard to init without experiment
PiperOrigin-RevId: 501298160
1 parent 2e509d0 commit 369a0cc

File tree

5 files changed

+160
-5
lines changed

5 files changed

+160
-5
lines changed

google/cloud/aiplatform/initializer.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ def init(
8383
8484
Example tensorboard resource name format:
8585
"projects/123/locations/us-central1/tensorboards/456"
86+
87+
If `experiment_tensorboard` is provided and `experiment` is not,
88+
the provided `experiment_tensorboard` will be set as the global Tensorboard.
89+
Any subsequent calls to aiplatform.init() with `experiment` and without
90+
`experiment_tensorboard` will automatically assign the global Tensorboard
91+
to the `experiment`.
8692
staging_bucket (str): The default staging bucket to use to stage artifacts
8793
when making API calls. In the form gs://...
8894
credentials (google.auth.credentials.Credentials): The default custom
@@ -106,17 +112,19 @@ def init(
106112
Raises:
107113
ValueError:
108114
If experiment_description is provided but experiment is not.
109-
If experiment_tensorboard is provided but experiment is not.
110115
"""
111116

112117
if experiment_description and experiment is None:
113118
raise ValueError(
114119
"Experiment needs to be set in `init` in order to add experiment descriptions."
115120
)
116121

117-
if experiment_tensorboard and experiment is None:
118-
raise ValueError(
119-
"Experiment needs to be set in `init` in order to add experiment_tensorboard."
122+
if experiment_tensorboard:
123+
metadata._experiment_tracker.set_tensorboard(
124+
tensorboard=experiment_tensorboard,
125+
project=project,
126+
location=location,
127+
credentials=credentials,
120128
)
121129

122130
# reset metadata_service config if project or location is updated.

google/cloud/aiplatform/metadata/experiment_resources.py

+7
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,13 @@ def resource_name(self) -> str:
326326
"""The Metadata context resource name of this experiment."""
327327
return self._metadata_context.resource_name
328328

329+
@property
330+
def backing_tensorboard_resource_name(self) -> Optional[str]:
331+
"""The Tensorboard resource associated with this Experiment if there is one."""
332+
return self._metadata_context.metadata.get(
333+
constants._BACKING_TENSORBOARD_RESOURCE_KEY
334+
)
335+
329336
def delete(self, *, delete_backing_tensorboard_runs: bool = False):
330337
"""Deletes this experiment all the experiment runs under this experiment
331338

google/cloud/aiplatform/metadata/metadata.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ class _ExperimentTracker:
186186
def __init__(self):
187187
self._experiment: Optional[experiment_resources.Experiment] = None
188188
self._experiment_run: Optional[experiment_run_resource.ExperimentRun] = None
189+
self._global_tensorboard: Optional[tensorboard_resource.Tensorboard] = None
189190

190191
def reset(self):
191192
"""Resets this experiment tracker, clearing the current experiment and run."""
@@ -235,11 +236,47 @@ def set_experiment(
235236
experiment_name=experiment, description=description
236237
)
237238

238-
if backing_tensorboard:
239+
backing_tb = backing_tensorboard or self._global_tensorboard
240+
241+
current_backing_tb = experiment.backing_tensorboard_resource_name
242+
243+
if not current_backing_tb and backing_tb:
239244
experiment.assign_backing_tensorboard(tensorboard=backing_tensorboard)
240245

241246
self._experiment = experiment
242247

248+
def set_tensorboard(
249+
self,
250+
tensorboard: Union[
251+
tensorboard_resource.Tensorboard,
252+
str,
253+
],
254+
project: Optional[str] = None,
255+
location: Optional[str] = None,
256+
credentials: Optional[auth_credentials.Credentials] = None,
257+
):
258+
"""Sets the global Tensorboard resource for this session.
259+
260+
Args:
261+
tensorboard (Union[str, aiplatform.Tensorboard]):
262+
Required. The Tensorboard resource to set as the global Tensorboard.
263+
project (str):
264+
Optional. Project associated with this Tensorboard resource.
265+
location (str):
266+
Optional. Location associated with this Tensorboard resource.
267+
credentials (auth_credentials.Credentials):
268+
Optional. Custom credentials used to set this Tensorboard resource.
269+
"""
270+
if isinstance(tensorboard, str):
271+
tensorboard = tensorboard_resource.Tensorboard(
272+
tensorboard,
273+
project=project,
274+
location=location,
275+
credentials=credentials,
276+
)
277+
278+
self._global_tensorboard = tensorboard
279+
243280
def start_run(
244281
self,
245282
run: str,

tests/system/aiplatform/test_experiments.py

+46
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,49 @@ def test_delete_experiment(self):
416416

417417
with pytest.raises(exceptions.NotFound):
418418
aiplatform.Experiment(experiment_name=self._experiment_name)
419+
420+
def test_init_associates_global_tensorboard_to_experiment(self, shared_state):
421+
422+
tensorboard = aiplatform.Tensorboard.create(
423+
project=e2e_base._PROJECT,
424+
location=e2e_base._LOCATION,
425+
display_name=self._make_display_name("")[:64],
426+
)
427+
428+
shared_state["resources"] = [tensorboard]
429+
430+
aiplatform.init(
431+
project=e2e_base._PROJECT,
432+
location=e2e_base._LOCATION,
433+
experiment_tensorboard=tensorboard,
434+
)
435+
436+
assert (
437+
aiplatform.metadata.metadata._experiment_tracker._global_tensorboard
438+
== tensorboard
439+
)
440+
441+
new_experiment_name = self._make_display_name("")[:64]
442+
new_experiment_resource = aiplatform.Experiment.create(
443+
experiment_name=new_experiment_name
444+
)
445+
446+
shared_state["resources"].append(new_experiment_resource)
447+
448+
aiplatform.init(
449+
project=e2e_base._PROJECT,
450+
location=e2e_base._LOCATION,
451+
experiment=new_experiment_name,
452+
)
453+
454+
assert (
455+
new_experiment_resource._lookup_backing_tensorboard().resource_name
456+
== tensorboard.resource_name
457+
)
458+
459+
assert (
460+
new_experiment_resource._metadata_context.metadata.get(
461+
aiplatform.metadata.constants._BACKING_TENSORBOARD_RESOURCE_KEY
462+
)
463+
== tensorboard.resource_name
464+
)

tests/unit/aiplatform/test_initializer.py

+57
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@
4444
_TEST_STAGING_BUCKET = "test-bucket"
4545
_TEST_NETWORK = "projects/12345/global/networks/myVPC"
4646

47+
# tensorboard
48+
_TEST_TENSORBOARD_ID = "1028944691210842416"
49+
_TEST_TENSORBOARD_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/tensorboards/{_TEST_TENSORBOARD_ID}"
50+
4751

4852
@pytest.mark.usefixtures("google_auth_mock")
4953
class TestInit:
@@ -115,6 +119,59 @@ def test_init_experiment_sets_experiment_with_description(
115119
backing_tensorboard=None,
116120
)
117121

122+
@patch.object(_experiment_tracker, "set_tensorboard")
123+
def test_init_with_experiment_tensorboard_id_sets_global_tensorboard(
124+
self, set_tensorboard_mock
125+
):
126+
creds = credentials.AnonymousCredentials()
127+
initializer.global_config.init(
128+
experiment_tensorboard=_TEST_TENSORBOARD_ID,
129+
project=_TEST_PROJECT,
130+
location=_TEST_LOCATION,
131+
credentials=creds,
132+
)
133+
134+
set_tensorboard_mock.assert_called_once_with(
135+
tensorboard=_TEST_TENSORBOARD_ID,
136+
project=_TEST_PROJECT,
137+
location=_TEST_LOCATION,
138+
credentials=creds,
139+
)
140+
141+
@patch.object(_experiment_tracker, "set_tensorboard")
142+
def test_init_with_experiment_tensorboard_resource_sets_global_tensorboard(
143+
self, set_tensorboard_mock
144+
):
145+
initializer.global_config.init(experiment_tensorboard=_TEST_TENSORBOARD_NAME)
146+
147+
set_tensorboard_mock.assert_called_once_with(
148+
tensorboard=_TEST_TENSORBOARD_NAME,
149+
project=None,
150+
location=None,
151+
credentials=None,
152+
)
153+
154+
@patch.object(_experiment_tracker, "set_tensorboard")
155+
@patch.object(_experiment_tracker, "set_experiment")
156+
def test_init_experiment_without_tensorboard_uses_global_tensorboard(
157+
self,
158+
set_tensorboard_mock,
159+
set_experiment_mock,
160+
):
161+
162+
initializer.global_config.init(experiment_tensorboard=_TEST_TENSORBOARD_NAME)
163+
164+
initializer.global_config.init(
165+
experiment=_TEST_EXPERIMENT,
166+
)
167+
168+
set_experiment_mock.assert_called_once_with(
169+
tensorboard=_TEST_TENSORBOARD_NAME,
170+
project=None,
171+
location=None,
172+
credentials=None,
173+
)
174+
118175
def test_init_experiment_description_fail_without_experiment(self):
119176
with pytest.raises(ValueError):
120177
initializer.global_config.init(experiment_description=_TEST_DESCRIPTION)

0 commit comments

Comments
 (0)