Skip to content

Commit 78a92a1

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Use default run_name in Tensorboard uploader for direct directory upload.
PiperOrigin-RevId: 646265028
1 parent e96fc91 commit 78a92a1

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

google/cloud/aiplatform/tensorboard/uploader.py

+7
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from tensorboard.util import tensor_util
6868

6969
_LOGGER = base.Logger(__name__)
70+
_DEFAULT_RUN_NAME = "default"
7071

7172
TensorboardServiceClient = tensorboard_service_client.TensorboardServiceClient
7273

@@ -381,6 +382,7 @@ def _pre_create_runs_and_time_series(self):
381382
run_names = []
382383
run_tag_name_to_time_series_proto = {}
383384
for (run_name, events) in run_to_events.items():
385+
run_name = run_name if (run_name and run_name != ".") else _DEFAULT_RUN_NAME
384386
run_names.append(run_name)
385387
for event in events:
386388
_filter_graph_defs(event)
@@ -427,6 +429,11 @@ def _upload_once(self):
427429
logger.info("Logdir sync took %.3f seconds", sync_duration_secs)
428430

429431
run_to_events = self._logdir_loader.get_run_events()
432+
run_to_events = {
433+
k if (k and k != ".") else _DEFAULT_RUN_NAME: v
434+
for k, v in run_to_events.items()
435+
if v
436+
}
430437
if self._run_name_prefix:
431438
run_to_events = {
432439
self._run_name_prefix + k: v for k, v in run_to_events.items()

google/cloud/aiplatform/tensorboard/uploader_utils.py

-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import re
2424
import time
2525
from typing import Callable, Dict, Generator, List, Optional, Tuple
26-
import uuid
2726

2827
from absl import app
2928
from google.api_core import exceptions
@@ -222,8 +221,6 @@ def _create_or_get_run_resource(
222221
location = m[2]
223222
tensorboard = m[3]
224223
experiment = m[4]
225-
if not run_name or run_name == ".":
226-
run_name = str(uuid.uuid4())
227224
experiment_run = experiment_run_resource.ExperimentRun.get(
228225
project=project, location=location, run_name=run_name
229226
)

tests/unit/aiplatform/test_uploader.py

+47
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
_TEST_ONE_PLATFORM_RUN_NAME, _TEST_TIME_SERIES_NAME
116116
)
117117
_TEST_BLOB_STORAGE_FOLDER = "test_folder"
118+
_DEFAULT_RUN_NAME = "default"
118119

119120

120121
def _create_example_graph_bytes(large_attr_size):
@@ -821,6 +822,52 @@ def test_upload_empty_logdir(
821822
mock_client.write_tensorboard_experiment_data.assert_not_called()
822823
experiment_tracker_mock.set_experiment.assert_called_once()
823824

825+
@parameterized.parameters(
826+
{"run_name_prefix": None},
827+
{"run_name_prefix": "run-prefix-"},
828+
)
829+
@patch.object(
830+
uploader_utils.OnePlatformResourceManager,
831+
"get_run_resource_name",
832+
autospec=True,
833+
)
834+
@patch.object(metadata, "_experiment_tracker", autospec=True)
835+
@patch.object(experiment_resources, "Experiment", autospec=True)
836+
def test_default_run_name(
837+
self,
838+
experiment_resources_mock,
839+
experiment_tracker_mock,
840+
run_resource_mock,
841+
run_name_prefix,
842+
):
843+
run_resource_mock.return_value = "."
844+
experiment_resources_mock.get.return_value = _TEST_EXPERIMENT_NAME
845+
experiment_tracker_mock.set_experiment.return_value = _TEST_EXPERIMENT_NAME
846+
experiment_tracker_mock.set_tensorboard.return_value = (
847+
_TEST_TENSORBOARD_RESOURCE_NAME
848+
)
849+
logdir = self.get_temp_dir()
850+
with FileWriter(logdir) as writer:
851+
writer.add_test_summary("foo")
852+
853+
uploader = _create_uploader(
854+
logdir=logdir,
855+
run_name_prefix=run_name_prefix,
856+
)
857+
uploader.create_experiment()
858+
mock_dispatcher = mock.create_autospec(uploader_lib._Dispatcher)
859+
uploader._dispatcher = mock_dispatcher
860+
mock_logdir_loader = mock.create_autospec(logdir_loader.LogdirLoader)
861+
mock.patch.object(uploader, "_logdir_loader", mock_logdir_loader)
862+
expected_run_name = _DEFAULT_RUN_NAME
863+
if run_name_prefix:
864+
expected_run_name = run_name_prefix + _DEFAULT_RUN_NAME
865+
866+
uploader._upload_once()
867+
868+
run_to_events = mock_dispatcher.dispatch_requests.call_args[0][0]
869+
self.assertIn(expected_run_name, run_to_events)
870+
824871
@patch.object(metadata, "_experiment_tracker", autospec=True)
825872
@patch.object(experiment_resources, "Experiment", autospec=True)
826873
def test_upload_polls_slowly_once_done(

0 commit comments

Comments
 (0)