Skip to content

Commit 3fad7bb

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: add TensorBoard log uploader
PiperOrigin-RevId: 521565504
1 parent 00b853b commit 3fad7bb

File tree

7 files changed

+526
-103
lines changed

7 files changed

+526
-103
lines changed

google/cloud/aiplatform/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
MatchingEngineIndexEndpoint,
4444
)
4545
from google.cloud.aiplatform import metadata
46+
from google.cloud.aiplatform.tensorboard import uploader_tracker
4647
from google.cloud.aiplatform.models import Endpoint
4748
from google.cloud.aiplatform.models import PrivateEndpoint
4849
from google.cloud.aiplatform.models import Model
@@ -100,6 +101,10 @@
100101
log_time_series_metrics = metadata.metadata._experiment_tracker.log_time_series_metrics
101102
end_run = metadata.metadata._experiment_tracker.end_run
102103

104+
upload_tb_log = uploader_tracker._tensorboard_tracker.upload_tb_log
105+
start_upload_tb_log = uploader_tracker._tensorboard_tracker.start_upload_tb_log
106+
end_upload_tb_log = uploader_tracker._tensorboard_tracker.end_upload_tb_log
107+
103108
save_model = metadata._models.save_model
104109
get_experiment_model = metadata.schema.google.artifact_schema.ExperimentModel.get
105110

google/cloud/aiplatform/tensorboard/uploader.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2021 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -15,33 +15,46 @@
1515
# limitations under the License.
1616
#
1717
"""Uploads a TensorBoard logdir to TensorBoard.gcp."""
18+
1819
import abc
1920
from collections import defaultdict
2021
import functools
2122
import logging
2223
import os
23-
import time
2424
import re
25-
from typing import (
26-
Dict,
27-
FrozenSet,
28-
Generator,
29-
Iterable,
30-
Optional,
31-
ContextManager,
32-
Tuple,
33-
)
25+
import time
26+
from typing import ContextManager, Dict, FrozenSet, Generator, Iterable, Optional, Tuple
3427
import uuid
3528

29+
from google.api_core import exceptions
30+
from google.cloud import storage
31+
from google.cloud.aiplatform import base
32+
from google.cloud.aiplatform.compat.services import (
33+
tensorboard_service_client,
34+
)
35+
from google.cloud.aiplatform.compat.types import tensorboard_data
36+
from google.cloud.aiplatform.compat.types import tensorboard_experiment
37+
from google.cloud.aiplatform.compat.types import tensorboard_service
38+
from google.cloud.aiplatform.compat.types import tensorboard_time_series
39+
from google.cloud.aiplatform.tensorboard import uploader_utils
40+
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import (
41+
profile_uploader,
42+
)
3643
import grpc
44+
import tensorflow as tf
45+
46+
from google.protobuf import timestamp_pb2 as timestamp
47+
from google.protobuf import message
3748
from tensorboard.backend import process_graph
3849
from tensorboard.backend.event_processing.plugin_event_accumulator import (
3950
directory_loader,
4051
)
4152
from tensorboard.backend.event_processing.plugin_event_accumulator import (
4253
event_file_loader,
4354
)
44-
from tensorboard.backend.event_processing.plugin_event_accumulator import io_wrapper
55+
from tensorboard.backend.event_processing.plugin_event_accumulator import (
56+
io_wrapper,
57+
)
4558
from tensorboard.compat.proto import graph_pb2
4659
from tensorboard.compat.proto import summary_pb2
4760
from tensorboard.compat.proto import types_pb2
@@ -52,19 +65,8 @@
5265
from tensorboard.uploader.proto import server_info_pb2
5366
from tensorboard.util import tb_logging
5467
from tensorboard.util import tensor_util
55-
import tensorflow as tf
5668

57-
from google.api_core import exceptions
58-
from google.cloud import storage
59-
from google.cloud.aiplatform.compat.services import tensorboard_service_client
60-
from google.cloud.aiplatform.compat.types import tensorboard_data
61-
from google.cloud.aiplatform.compat.types import tensorboard_experiment
62-
from google.cloud.aiplatform.compat.types import tensorboard_service
63-
from google.cloud.aiplatform.compat.types import tensorboard_time_series
64-
from google.cloud.aiplatform.tensorboard import uploader_utils
65-
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
66-
from google.protobuf import message
67-
from google.protobuf import timestamp_pb2 as timestamp
69+
_LOGGER = base.Logger(__name__)
6870

6971
TensorboardServiceClient = tensorboard_service_client.TensorboardServiceClient
7072

@@ -189,6 +191,7 @@ def __init__(
189191
self._allowed_plugins = frozenset(allowed_plugins)
190192
self._run_name_prefix = run_name_prefix
191193
self._is_brand_new_experiment = False
194+
self._continue_uploading = True
192195

193196
self._upload_limits = upload_limits
194197
if not self._upload_limits:
@@ -388,20 +391,22 @@ def start_uploading(self):
388391
"performance."
389392
)
390393

391-
while True:
394+
while self._continue_uploading:
392395
self._logdir_poll_rate_limiter.tick()
393396
self._upload_once()
394397
if self._one_shot:
395398
break
396399
if self._one_shot and not self._tracker.has_data():
397400
logger.warning(
398-
"One-shot mode was used on a logdir (%s) "
399-
"without any uploadable data" % self._logdir
401+
"One-shot mode was used on a logdir (%s) without any uploadable data"
402+
% self._logdir
400403
)
401404

405+
def _end_uploading(self):
406+
self._continue_uploading = False
407+
402408
def _pre_create_runs_and_time_series(self):
403-
"""
404-
Iterates though the log dir to collect TensorboardRuns and
409+
"""Iterates though the log dir to collect TensorboardRuns and
405410
TensorboardTimeSeries that need to be created, and creates them in batch
406411
to speed up uploading later on.
407412
"""
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Constants shared between TensorBoard command line uploader and SDK uploader"""
2+
3+
from tensorboard.plugins.distribution import (
4+
metadata as distribution_metadata,
5+
)
6+
from tensorboard.plugins.graph import metadata as graphs_metadata
7+
from tensorboard.plugins.histogram import (
8+
metadata as histogram_metadata,
9+
)
10+
from tensorboard.plugins.hparams import metadata as hparams_metadata
11+
from tensorboard.plugins.image import metadata as images_metadata
12+
from tensorboard.plugins.scalar import metadata as scalar_metadata
13+
from tensorboard.plugins.text import metadata as text_metadata
14+
15+
ALLOWED_PLUGINS = [
16+
scalar_metadata.PLUGIN_NAME,
17+
histogram_metadata.PLUGIN_NAME,
18+
distribution_metadata.PLUGIN_NAME,
19+
text_metadata.PLUGIN_NAME,
20+
hparams_metadata.PLUGIN_NAME,
21+
images_metadata.PLUGIN_NAME,
22+
graphs_metadata.PLUGIN_NAME,
23+
]

google/cloud/aiplatform/tensorboard/uploader_main.py

Lines changed: 12 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,16 @@
1919

2020
from absl import app
2121
from absl import flags
22-
import grpc
23-
from tensorboard.plugins.scalar import metadata as scalar_metadata
24-
from tensorboard.plugins.distribution import metadata as distribution_metadata
25-
from tensorboard.plugins.histogram import metadata as histogram_metadata
26-
from tensorboard.plugins.text import metadata as text_metadata
27-
from tensorboard.plugins.hparams import metadata as hparams_metadata
28-
from tensorboard.plugins.image import metadata as images_metadata
29-
from tensorboard.plugins.graph import metadata as graphs_metadata
30-
3122
from google.api_core import exceptions
32-
from google.cloud import storage
3323
from google.cloud import aiplatform
34-
from google.cloud.aiplatform.constants import base as constants
3524
from google.cloud.aiplatform import jobs
25+
from google.cloud.aiplatform.constants import base as constants
3626
from google.cloud.aiplatform.tensorboard import uploader
27+
from google.cloud.aiplatform.tensorboard import uploader_constants
28+
from google.cloud.aiplatform.tensorboard import uploader_utils
3729
from google.cloud.aiplatform.utils import TensorboardClientWithOverride
3830

31+
3932
FLAGS = flags.FLAGS
4033
flags.DEFINE_string("experiment_name", None, "The name of the Cloud AI Experiment.")
4134
flags.DEFINE_string(
@@ -73,15 +66,7 @@
7366

7467
flags.DEFINE_multi_string(
7568
"allowed_plugins",
76-
[
77-
scalar_metadata.PLUGIN_NAME,
78-
histogram_metadata.PLUGIN_NAME,
79-
distribution_metadata.PLUGIN_NAME,
80-
text_metadata.PLUGIN_NAME,
81-
hparams_metadata.PLUGIN_NAME,
82-
images_metadata.PLUGIN_NAME,
83-
graphs_metadata.PLUGIN_NAME,
84-
],
69+
uploader_constants.ALLOWED_PLUGINS,
8570
"Plugins allowed by the Uploader.",
8671
)
8772

@@ -103,29 +88,12 @@ def main(argv):
10388
location_override=region,
10489
)
10590

106-
try:
107-
tensorboard = api_client.get_tensorboard(name=FLAGS.tensorboard_resource_name)
108-
except grpc.RpcError as rpc_error:
109-
if rpc_error.code() == grpc.StatusCode.NOT_FOUND:
110-
raise app.UsageError(
111-
"Tensorboard resource %s not found" % FLAGS.tensorboard_resource_name,
112-
exitcode=0,
113-
) from rpc_error
114-
raise
115-
116-
if tensorboard.blob_storage_path_prefix:
117-
path_prefix = tensorboard.blob_storage_path_prefix + "/"
118-
first_slash_index = path_prefix.find("/")
119-
bucket_name = path_prefix[:first_slash_index]
120-
blob_storage_bucket = storage.Client(project=project_id).bucket(bucket_name)
121-
blob_storage_folder = path_prefix[first_slash_index + 1 :]
122-
else:
123-
raise app.UsageError(
124-
"Tensorboard resource {} is obsolete. Please create a new one.".format(
125-
FLAGS.tensorboard_resource_name
126-
),
127-
exitcode=0,
128-
)
91+
(
92+
blob_storage_bucket,
93+
blob_storage_folder,
94+
) = uploader_utils.get_blob_storage_bucket_and_folder(
95+
api_client, FLAGS.tensorboard_resource_name, project_id
96+
)
12997

13098
experiment_name = FLAGS.experiment_name
13199
experiment_display_name = get_experiment_display_name_with_override(
@@ -135,7 +103,7 @@ def main(argv):
135103
tb_uploader = uploader.TensorBoardUploader(
136104
experiment_name=experiment_name,
137105
experiment_display_name=experiment_display_name,
138-
tensorboard_resource_name=tensorboard.name,
106+
tensorboard_resource_name=FLAGS.tensorboard_resource_name,
139107
blob_storage_bucket=blob_storage_bucket,
140108
blob_storage_folder=blob_storage_folder,
141109
allowed_plugins=FLAGS.allowed_plugins,

0 commit comments

Comments
 (0)