Skip to content

Commit 8a4a41a

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Enable Tensorboard profile plugin in all regions by default.
PiperOrigin-RevId: 638377255
1 parent cb2f4aa commit 8a4a41a

File tree

4 files changed

+32
-27
lines changed

4 files changed

+32
-27
lines changed

google/cloud/aiplatform/tensorboard/uploader_constants.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@
1313
from tensorboard.plugins.image import metadata as images_metadata
1414
from tensorboard.plugins.scalar import metadata as scalar_metadata
1515
from tensorboard.plugins.text import metadata as text_metadata
16-
17-
ALLOWED_PLUGINS = [
18-
scalar_metadata.PLUGIN_NAME,
19-
histogram_metadata.PLUGIN_NAME,
20-
distribution_metadata.PLUGIN_NAME,
21-
text_metadata.PLUGIN_NAME,
22-
hparams_metadata.PLUGIN_NAME,
23-
images_metadata.PLUGIN_NAME,
24-
graphs_metadata.PLUGIN_NAME,
25-
]
16+
from tensorboard_plugin_profile import profile_plugin
17+
18+
ALLOWED_PLUGINS = frozenset(
19+
[
20+
scalar_metadata.PLUGIN_NAME,
21+
histogram_metadata.PLUGIN_NAME,
22+
distribution_metadata.PLUGIN_NAME,
23+
text_metadata.PLUGIN_NAME,
24+
hparams_metadata.PLUGIN_NAME,
25+
images_metadata.PLUGIN_NAME,
26+
graphs_metadata.PLUGIN_NAME,
27+
profile_plugin.PLUGIN_NAME,
28+
]
29+
)
2630

2731
# Minimum length of a logdir polling cycle in seconds. Shorter cycles will
2832
# sleep to avoid spinning over the logdir, which isn't great for disks and can

google/cloud/aiplatform/tensorboard/uploader_main.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,11 @@ def main(argv):
103103
experiment_name, FLAGS.experiment_display_name, project_id, region
104104
)
105105

106-
plugins = uploader_constants.ALLOWED_PLUGINS
107-
if FLAGS.allowed_plugins:
108-
plugins += [
109-
plugin
110-
for plugin in FLAGS.allowed_plugins
111-
if plugin not in uploader_constants.ALLOWED_PLUGINS
112-
]
106+
plugins = (
107+
uploader_constants.ALLOWED_PLUGINS.union(FLAGS.allowed_plugins)
108+
if FLAGS.allowed_plugins
109+
else uploader_constants.ALLOWED_PLUGINS
110+
)
113111

114112
tb_uploader = uploader.TensorBoardUploader(
115113
experiment_name=experiment_name,

google/cloud/aiplatform/tensorboard/uploader_tracker.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,11 @@ def _create_uploader(
263263
api_client, tensorboard_resource_name, project
264264
)
265265

266-
plugins = uploader_constants.ALLOWED_PLUGINS
267-
if allowed_plugins:
268-
plugins += [
269-
plugin
270-
for plugin in allowed_plugins
271-
if plugin not in uploader_constants.ALLOWED_PLUGINS
272-
]
266+
plugins = (
267+
uploader_constants.ALLOWED_PLUGINS.union(allowed_plugins)
268+
if allowed_plugins
269+
else uploader_constants.ALLOWED_PLUGINS
270+
)
273271

274272
tensorboard_uploader = TensorBoardUploader(
275273
experiment_name=tensorboard_experiment_name,

tests/unit/aiplatform/test_uploader.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -255,12 +255,18 @@ def _create_uploader(
255255
max_blob_size=max_blob_size,
256256
)
257257

258+
plugins = (
259+
uploader_constants.ALLOWED_PLUGINS.union(allowed_plugins)
260+
if allowed_plugins
261+
else uploader_constants.ALLOWED_PLUGINS
262+
)
263+
258264
return uploader_lib.TensorBoardUploader(
259265
experiment_name=experiment_name,
260266
tensorboard_resource_name=tensorboard_resource_name,
261267
writer_client=writer_client,
262268
logdir=logdir,
263-
allowed_plugins=allowed_plugins,
269+
allowed_plugins=plugins,
264270
upload_limits=upload_limits,
265271
blob_storage_bucket=blob_storage_bucket,
266272
blob_storage_folder=blob_storage_folder,
@@ -1239,7 +1245,7 @@ def create_time_series(tensorboard_time_series, parent=None):
12391245
)
12401246
@patch.object(metadata, "_experiment_tracker", autospec=True)
12411247
@patch.object(experiment_resources, "Experiment", autospec=True)
1242-
def test_add_profile_plugin(
1248+
def test_profile_plugin_included_by_default(
12431249
self, experiment_resources_mock, experiment_tracker_mock, run_resource_mock
12441250
):
12451251
experiment_resources_mock.get.return_value = _TEST_EXPERIMENT_NAME
@@ -1259,7 +1265,6 @@ def test_add_profile_plugin(
12591265
_create_mock_client(),
12601266
logdir,
12611267
one_shot=True,
1262-
allowed_plugins=frozenset(("profile",)),
12631268
run_name_prefix=run_name,
12641269
)
12651270

0 commit comments

Comments
 (0)