Skip to content

Commit f05924d

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Enable continuous upload for profile logs.
PiperOrigin-RevId: 624258810
1 parent 894c73f commit f05924d

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

google/cloud/aiplatform/tensorboard/plugins/tf_profiler/profile_uploader.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -160,6 +160,8 @@ def _profile_dir(self, run_name: str) -> str:
160160
Returns:
161161
Full path for run name.
162162
"""
163+
if run_name is None:
164+
return os.path.join(self._logdir, self.PROFILE_PATH)
163165
return os.path.join(self._logdir, run_name, self.PROFILE_PATH)
164166

165167
def send_request(self, run_name: str):
@@ -171,7 +173,7 @@ def send_request(self, run_name: str):
171173
"""
172174

173175
if not self._is_valid_event(run_name):
174-
logger.warning("No such profile run for %s", run_name)
176+
logger.debug("No such profile run for %s", run_name)
175177
return
176178

177179
# Create a profiler loader if one is not created.

google/cloud/aiplatform/tensorboard/uploader.py

-4
Original file line number | Diff line number | Diff line change
@@ -306,10 +306,6 @@ def create_experiment(self):
306306
def _should_profile(self) -> bool:
307307
"""Indicate if profile plugin should be enabled."""
308308
if "profile" in self._allowed_plugins:
309-
if not self._one_shot:
310-
raise ValueError(
311-
"Profile plugin currently only supported for one shot."
312-
)
313309
logger.info("Profile plugin is enabled.")
314310
return True
315311
return False

tests/unit/aiplatform/test_uploader.py

+56 -7
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,14 @@
8282
)
8383
)
8484

85+
_SCALARS_HISTOGRAMS_AND_PROFILE = frozenset(
86+
(
87+
scalars_metadata.PLUGIN_NAME,
88+
"profile",
89+
)
90+
)
91+
92+
8593
# Sentinel for `_create_*` helpers, for arguments for which we want to
8694
# supply a default other than the `None` used by the code under test.
8795
_USE_DEFAULT = object()
@@ -1095,7 +1103,23 @@ def test_thread_continuously_uploads(self):
10951103

10961104
logdir = self.get_temp_dir()
10971105
mock_client = _create_mock_client()
1098-
uploader = _create_uploader(mock_client, logdir)
1106+
builder = _create_dispatcher(
1107+
experiment_resource_name=_TEST_ONE_PLATFORM_EXPERIMENT_NAME,
1108+
api=mock_client,
1109+
allowed_plugins=_SCALARS_HISTOGRAMS_AND_PROFILE,
1110+
logdir=logdir,
1111+
)
1112+
mock_rate_limiter = mock.create_autospec(util.RateLimiter)
1113+
mock_bucket = _create_mock_blob_storage()
1114+
1115+
uploader = _create_uploader(
1116+
mock_client,
1117+
logdir,
1118+
allowed_plugins=_SCALARS_HISTOGRAMS_AND_PROFILE,
1119+
rpc_rate_limiter=mock_rate_limiter,
1120+
blob_storage_bucket=mock_bucket,
1121+
)
1122+
uploader._dispatcher = builder
10991123
uploader.create_experiment()
11001124

11011125
# Convenience helpers for constructing expected requests.
@@ -1104,7 +1128,7 @@ def test_thread_continuously_uploads(self):
11041128
scalar = tensorboard_data.Scalar
11051129

11061130
# Directory with scalar data
1107-
writer = FileWriter(logdir)
1131+
writer = FileWriter(os.path.join(logdir, "a"))
11081132
metadata = summary_pb2.SummaryMetadata(
11091133
plugin_data=summary_pb2.SummaryMetadata.PluginData(
11101134
plugin_name="scalars", content=b"12345"
@@ -1121,18 +1145,43 @@ def test_thread_continuously_uploads(self):
11211145
value_metadata=metadata,
11221146
)
11231147
writer.flush()
1124-
writer_a = FileWriter(os.path.join(logdir, "a"))
1148+
writer_a = FileWriter(os.path.join(logdir, "b"))
11251149
writer_a.add_test_summary("qux", simple_value=9.0, step=2)
11261150
writer_a.flush()
1151+
1152+
# Directory with profile data
1153+
prof_run_name = "2024_04_04_04_24_24"
1154+
prof_path = os.path.join(
1155+
logdir, profile_uploader.ProfileRequestSender.PROFILE_PATH
1156+
)
1157+
os.makedirs(prof_path)
1158+
run_path = os.path.join(prof_path, prof_run_name)
1159+
os.makedirs(run_path)
1160+
tempfile.NamedTemporaryFile(
1161+
prefix="c", suffix=".xplane.pb", dir=run_path, delete=False
1162+
)
1163+
self.assertNotEmpty(os.listdir(run_path))
1164+
11271165
uploader_thread = threading.Thread(target=uploader.start_uploading)
11281166
uploader_thread.start()
11291167
time.sleep(5)
1130-
self.assertEqual(3, mock_client.create_tensorboard_time_series.call_count)
1168+
1169+
# Check create_time_series calls
1170+
self.assertEqual(4, mock_client.create_tensorboard_time_series.call_count)
11311171
call_args_list = mock_client.create_tensorboard_time_series.call_args_list
1132-
request = call_args_list[1][1]["tensorboard_time_series"]
1133-
self.assertEqual("scalars", request.plugin_name)
1134-
self.assertEqual(b"12345", request.plugin_data)
1172+
request1, request2, request3, request4 = (
1173+
call_args_list[0][1]["tensorboard_time_series"],
1174+
call_args_list[1][1]["tensorboard_time_series"],
1175+
call_args_list[2][1]["tensorboard_time_series"],
1176+
call_args_list[3][1]["tensorboard_time_series"],
1177+
)
1178+
self.assertEqual("scalars", request1.plugin_name)
1179+
self.assertEqual("scalars", request2.plugin_name)
1180+
self.assertEqual(b"12345", request2.plugin_data)
1181+
self.assertEqual("scalars", request3.plugin_name)
1182+
self.assertEqual("profile", request4.plugin_name)
11351183

1184+
# Check write_tensorboard_experiment_data calls
11361185
self.assertEqual(1, mock_client.write_tensorboard_experiment_data.call_count)
11371186
call_args_list = mock_client.write_tensorboard_experiment_data.call_args_list
11381187
request1, request2 = (

0 commit comments

Comments
 (0)