Skip to content

Commit 28a091a

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Do not reset aiplatform.Experiment or aiplatform.ExperimentRun unnecessarily when running tensorboard uploader.
PiperOrigin-RevId: 644168539
1 parent d548c11 commit 28a091a

File tree

3 files changed

+78
-16
lines changed

3 files changed

+78
-16
lines changed

google/cloud/aiplatform/tensorboard/uploader.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@
3030
from google.cloud.aiplatform.compat.services import (
3131
tensorboard_service_client,
3232
)
33+
from google.cloud.aiplatform.compat.types import execution as gca_execution
3334
from google.cloud.aiplatform.compat.types import tensorboard_data
3435
from google.cloud.aiplatform.compat.types import tensorboard_service
3536
from google.cloud.aiplatform.compat.types import tensorboard_time_series
3637
from google.cloud.aiplatform.metadata import experiment_resources
38+
from google.cloud.aiplatform.metadata import experiment_run_resource
3739
from google.cloud.aiplatform.metadata import metadata
3840
from google.cloud.aiplatform.tensorboard import logdir_loader
3941
from google.cloud.aiplatform.tensorboard import upload_tracker
@@ -164,6 +166,9 @@ def __init__(
164166
self._one_shot = one_shot
165167
self._dispatcher = None
166168
self._additional_senders: Dict[str, uploader_utils.RequestSender] = {}
169+
self._experiment_runs = []
170+
self._project = None
171+
self._location = None
167172
if logdir_poll_rate_limiter is None:
168173
self._logdir_poll_rate_limiter = uploader_utils.RateLimiter(
169174
uploader_constants.MIN_LOGDIR_POLL_INTERVAL_SECS
@@ -221,27 +226,31 @@ def create_experiment(self):
221226
Vertex Experiment and associate it with a Tensorboard Experiment.
222227
"""
223228
m = self._api.parse_tensorboard_path(self._tensorboard_resource_name)
229+
self._project = m["project"]
230+
self._location = m["location"]
224231

225232
existing_experiment = experiment_resources.Experiment.get(
226233
experiment_name=self._experiment_name,
227-
project=m["project"],
228-
location=m["location"],
234+
project=self._project,
235+
location=self._location,
229236
)
230237
if not existing_experiment:
231238
self._is_brand_new_experiment = True
232239

233-
metadata._experiment_tracker.reset()
240+
if metadata._experiment_tracker.experiment_name != self._experiment_name:
241+
logging.info(f"Setting experiment to {self._experiment_name}")
242+
metadata._experiment_tracker.reset()
243+
metadata._experiment_tracker.set_experiment(
244+
project=self._project,
245+
location=self._location,
246+
experiment=self._experiment_name,
247+
description=self._description,
248+
backing_tensorboard=self._tensorboard_resource_name,
249+
)
234250
metadata._experiment_tracker.set_tensorboard(
235251
tensorboard=self._tensorboard_resource_name,
236-
project=m["project"],
237-
location=m["location"],
238-
)
239-
metadata._experiment_tracker.set_experiment(
240-
project=m["project"],
241-
location=m["location"],
242-
experiment=self._experiment_name,
243-
description=self._description,
244-
backing_tensorboard=self._tensorboard_resource_name,
252+
project=self._project,
253+
location=self._location,
245254
)
246255

247256
self._tensorboard_experiment_resource_name = (
@@ -309,6 +318,17 @@ def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
309318
def get_experiment_resource_name(self):
310319
return self._tensorboard_experiment_resource_name
311320

321+
def _end_experiment_runs(self):
322+
# End all runs created by uploader
323+
for run_name in self._experiment_runs:
324+
if run_name:
325+
logging.info("Ending run %s", run_name)
326+
run = experiment_run_resource.ExperimentRun.get(
327+
project=self._project, location=self._location, run_name=run_name
328+
)
329+
if run:
330+
run.update_state(state=gca_execution.Execution.State.COMPLETE)
331+
312332
def start_uploading(self):
313333
"""Blocks forever to continuously upload data from the logdir.
314334
@@ -334,6 +354,7 @@ def start_uploading(self):
334354
self._logdir_poll_rate_limiter.tick()
335355
self._upload_once()
336356
if self._one_shot:
357+
self._end_experiment_runs()
337358
break
338359
if self._one_shot and not self._tracker.has_data():
339360
logger.warning(
@@ -343,6 +364,7 @@ def start_uploading(self):
343364

344365
def _end_uploading(self):
345366
self._continue_uploading = False
367+
self._end_experiment_runs()
346368

347369
def _pre_create_runs_and_time_series(self):
348370
"""Iterates though the log dir to collect TensorboardRuns and
@@ -409,6 +431,7 @@ def _upload_once(self):
409431
run_to_events = {
410432
self._run_name_prefix + k: v for k, v in run_to_events.items()
411433
}
434+
self._experiment_runs = run_to_events.keys()
412435

413436
# Add a profile event to trigger send_request in _additional_senders
414437
if self._should_profile():

google/cloud/aiplatform/tensorboard/uploader_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def _create_or_get_run_resource(
234234
run_name=run_name,
235235
experiment=experiment,
236236
tensorboard=tensorboard,
237-
state=gca_execution.Execution.State.COMPLETE,
237+
state=gca_execution.Execution.State.RUNNING,
238238
)
239239
tb_run_artifact = experiment_run._backing_tensorboard_run
240240
tb_run = tb_run_artifact.resource

tests/unit/aiplatform/test_uploader.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,11 @@ def batch_create_time_series(parent, requests):
785785
with mock.patch.object(
786786
uploader, "_logdir_loader_pre_create", mock_logdir_loader_pre_create
787787
):
788-
uploader.start_uploading()
788+
with mock.patch.object(
789+
uploader, "_end_experiment_runs", return_value=None
790+
):
791+
uploader.start_uploading()
792+
uploader._end_experiment_runs.assert_called_once()
789793

790794
self.assertEqual(existing_experiment is None, uploader._is_brand_new_experiment)
791795
self.assertEqual(2, mock_client.write_tensorboard_experiment_data.call_count)
@@ -797,6 +801,7 @@ def batch_create_time_series(parent, requests):
797801
self.assertLen(mock_tracker.scalars_tracker.call_args[0], 1)
798802
self.assertEqual(mock_tracker.tensors_tracker.call_count, 0)
799803
self.assertEqual(mock_tracker.blob_tracker.call_count, 0)
804+
experiment_tracker_mock.set_experiment.assert_called_once()
800805

801806
@patch.object(metadata, "_experiment_tracker", autospec=True)
802807
@patch.object(experiment_resources, "Experiment", autospec=True)
@@ -814,6 +819,7 @@ def test_upload_empty_logdir(
814819
uploader.create_experiment()
815820
uploader._upload_once()
816821
mock_client.write_tensorboard_experiment_data.assert_not_called()
822+
experiment_tracker_mock.set_experiment.assert_called_once()
817823

818824
@patch.object(metadata, "_experiment_tracker", autospec=True)
819825
@patch.object(experiment_resources, "Experiment", autospec=True)
@@ -847,6 +853,7 @@ def mock_upload_once():
847853
uploader.create_experiment()
848854
with self.assertRaises(SuccessError):
849855
uploader.start_uploading()
856+
experiment_tracker_mock.set_experiment.assert_called_once()
850857

851858
@patch.object(
852859
uploader_utils.OnePlatformResourceManager,
@@ -874,6 +881,7 @@ def test_upload_swallows_rpc_failure(
874881
mock_client.write_tensorboard_experiment_data.side_effect = error
875882
uploader._upload_once()
876883
mock_client.write_tensorboard_experiment_data.assert_called_once()
884+
experiment_tracker_mock.set_experiment.assert_called_once()
877885

878886
@patch.object(
879887
uploader_utils.OnePlatformResourceManager,
@@ -1006,10 +1014,12 @@ def test_upload_full_logdir(
10061014
self.assertProtoEquals(expected_request3[1], request3[1])
10071015
self.assertProtoEquals(expected_request4[0], request4[0])
10081016
mock_client.write_tensorboard_experiment_data.reset_mock()
1017+
experiment_tracker_mock.set_experiment.assert_called_once()
10091018

10101019
# Empty third round
10111020
uploader._upload_once()
10121021
mock_client.write_tensorboard_experiment_data.assert_not_called()
1022+
experiment_tracker_mock.set_experiment.assert_called_once()
10131023

10141024
@patch.object(
10151025
uploader_utils.OnePlatformResourceManager,
@@ -1057,6 +1067,7 @@ def test_verbosity_zero_creates_upload_tracker_with_verbosity_zero(
10571067
self.assertEqual(mock_constructor.call_count, 1)
10581068
self.assertEqual(mock_constructor.call_args[1], {"verbosity": 0})
10591069
self.assertEqual(mock_tracker.scalars_tracker.call_count, 1)
1070+
experiment_tracker_mock.set_experiment.assert_called_once()
10601071

10611072
@patch.object(
10621073
uploader_utils.OnePlatformResourceManager,
@@ -1160,6 +1171,7 @@ def create_time_series(tensorboard_time_series, parent=None):
11601171
self.assertEqual(mock_tracker.scalars_tracker.call_count, 0)
11611172
self.assertEqual(mock_tracker.tensors_tracker.call_count, 0)
11621173
self.assertEqual(mock_tracker.blob_tracker.call_count, 12)
1174+
experiment_tracker_mock.set_experiment.assert_called_once()
11631175

11641176
@patch.object(
11651177
uploader_utils.OnePlatformResourceManager,
@@ -1282,6 +1294,27 @@ def test_profile_plugin_included_by_default(
12821294
profile_sender = senders["profile"]
12831295
self.assertIn(run_name, profile_sender._run_to_profile_loaders)
12841296
self.assertIn(run_name, profile_sender._run_to_file_request_sender)
1297+
experiment_tracker_mock.set_experiment.assert_called_once()
1298+
1299+
@patch.object(metadata, "_experiment_tracker", autospec=True)
1300+
@patch.object(experiment_resources, "Experiment", autospec=True)
1301+
def test_active_experiment_set_experiment_not_called(
1302+
self, experiment_resources_mock, experiment_tracker_mock
1303+
):
1304+
experiment_resources_mock.get.return_value = _TEST_EXPERIMENT_NAME
1305+
experiment_tracker_mock.set_experiment.return_value = _TEST_EXPERIMENT_NAME
1306+
experiment_tracker_mock.experiment_name = _TEST_EXPERIMENT_NAME
1307+
experiment_tracker_mock.set_tensorboard.return_value = (
1308+
_TEST_TENSORBOARD_RESOURCE_NAME
1309+
)
1310+
logdir = self.get_temp_dir()
1311+
mock_client = _create_mock_client()
1312+
1313+
uploader = _create_uploader(mock_client, logdir)
1314+
uploader.create_experiment()
1315+
uploader._upload_once()
1316+
1317+
experiment_tracker_mock.set_experiment.assert_not_called()
12851318

12861319

12871320
# TODO(b/276368161)
@@ -1387,6 +1420,7 @@ def test_thread_continuously_uploads(
13871420
self.assertEqual(b"12345", request2.plugin_data)
13881421
self.assertEqual("scalars", request3.plugin_name)
13891422
self.assertEqual("profile", request4.plugin_name)
1423+
experiment_tracker_mock.set_experiment.assert_called_once()
13901424

13911425
# Check write_tensorboard_experiment_data calls
13921426
self.assertEqual(1, mock_client.write_tensorboard_experiment_data.call_count)
@@ -1425,17 +1459,22 @@ def test_thread_continuously_uploads(
14251459
self.assertProtoEquals(expected_request1[1], request1[1])
14261460
self.assertProtoEquals(expected_request2[0], request2[0])
14271461

1428-
uploader._end_uploading()
1462+
with mock.patch.object(uploader, "_end_experiment_runs", return_value=None):
1463+
uploader._end_uploading()
1464+
uploader._end_experiment_runs.assert_called_once()
14291465
time.sleep(1)
14301466
self.assertFalse(uploader_thread.is_alive())
14311467
mock_client.write_tensorboard_experiment_data.reset_mock()
14321468

14331469
# Empty directory
14341470
uploader._upload_once()
14351471
mock_client.write_tensorboard_experiment_data.assert_not_called()
1436-
uploader._end_uploading()
1472+
with mock.patch.object(uploader, "_end_experiment_runs", return_value=None):
1473+
uploader._end_uploading()
1474+
uploader._end_experiment_runs.assert_called_once()
14371475
time.sleep(1)
14381476
self.assertFalse(uploader_thread.is_alive())
1477+
experiment_tracker_mock.set_experiment.assert_called_once()
14391478

14401479

14411480
@pytest.mark.usefixtures("google_auth_mock")

0 commit comments

Comments
 (0)