Skip to content

Commit 8779df5

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Fix bug that broke profiler with '0-rc2' tensorflow versions.
PiperOrigin-RevId: 491683085
1 parent 3e95e8d commit 8779df5

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

google/cloud/aiplatform/training_utils/cloud_profiler/plugins/tensorflow/tf_profiler.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717

1818
"""A plugin to handle remote tensoflow profiler sessions for Vertex AI."""
1919

20-
from google.cloud.aiplatform.training_utils.cloud_profiler import cloud_profiler_utils
20+
from google.cloud.aiplatform.training_utils.cloud_profiler import (
21+
cloud_profiler_utils,
22+
)
2123

2224
try:
2325
import tensorflow as tf
24-
from tensorboard_plugin_profile.profile_plugin import ProfilePlugin
26+
from tensorboard_plugin_profile.profile_plugin import (
27+
ProfilePlugin,
28+
)
2529
except ImportError as err:
2630
raise ImportError(cloud_profiler_utils.import_error_msg) from err
2731

@@ -36,10 +40,14 @@
3640
import tensorboard.plugins.base_plugin as tensorboard_base_plugin
3741
from werkzeug import Response
3842

39-
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
43+
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import (
44+
profile_uploader,
45+
)
4046
from google.cloud.aiplatform.training_utils import environment_variables
4147
from google.cloud.aiplatform.training_utils.cloud_profiler import wsgi_types
42-
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
48+
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import (
49+
base_plugin,
50+
)
4351
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import (
4452
tensorboard_api,
4553
)
@@ -68,8 +76,7 @@ def _get_tf_versioning() -> Optional[Version]:
6876
versioning = version.split(".")
6977
if len(versioning) != 3:
7078
return
71-
72-
return Version(int(versioning[0]), int(versioning[1]), int(versioning[2]))
79+
return Version(int(versioning[0]), int(versioning[1]), versioning[2])
7380

7481

7582
def _is_compatible_version(version: Version) -> bool:
@@ -228,7 +235,7 @@ def warn_tensorboard_env_var(var_name: str):
228235
Required. The name of the missing environment variable.
229236
"""
230237
logging.warning(
231-
f"Environment variable `{var_name}` must be set. " + _BASE_TB_ENV_WARNING
238+
"Environment variable `%s` must be set. %s", var_name, _BASE_TB_ENV_WARNING
232239
)
233240

234241

tests/unit/aiplatform/test_cloud_profiler.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,12 @@
3131
from google.api_core import exceptions
3232
from google.cloud import aiplatform
3333
from google.cloud.aiplatform import training_utils
34-
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import profile_uploader
35-
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import base_plugin
34+
from google.cloud.aiplatform.tensorboard.plugins.tf_profiler import (
35+
profile_uploader,
36+
)
37+
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins import (
38+
base_plugin,
39+
)
3640
from google.cloud.aiplatform.training_utils.cloud_profiler.plugins.tensorflow import (
3741
tf_profiler,
3842
)
@@ -175,15 +179,21 @@ def tf_import_mock(name, *args, **kwargs):
175179
def testCanInitializeTFVersion(self):
176180
import tensorflow
177181

178-
with mock.patch.object(tensorflow, "__version__", return_value="1.2.3.4"):
182+
with mock.patch.object(tensorflow, "__version__", "1.2.3.4"):
179183
assert not TFProfiler.can_initialize()
180184

181185
def testCanInitializeOldTFVersion(self):
182186
import tensorflow
183187

184-
with mock.patch.object(tensorflow, "__version__", return_value="2.3.0"):
188+
with mock.patch.object(tensorflow, "__version__", "2.3.0"):
185189
assert not TFProfiler.can_initialize()
186190

191+
def testCanInitializeRcTFVersion(self):
192+
import tensorflow as tf
193+
194+
with mock.patch.object(tf, "__version__", "2.4.0-rc2"):
195+
assert TFProfiler.can_initialize()
196+
187197
def testCanInitializeNoProfilePlugin(self):
188198
orig_find_spec = importlib.util.find_spec
189199

0 commit comments

Comments
 (0)