
Commit e58689b

matthew29tang authored and copybara-github committed
feat: Install Bigframes tensorflow dependencies automatically
PiperOrigin-RevId: 571107540
1 parent 5c993d2 commit e58689b

File tree

1 file changed: +39 -7 lines changed


vertexai/preview/_workflow/serialization_engine/serializers.py

+39 -7
@@ -38,6 +38,8 @@
     serializers_base,
 )
 
+from packaging import version
+
 try:
     import cloudpickle
 except ImportError:
@@ -125,6 +127,21 @@
 _LIGHTNING_ROOT_DIR = "/vertex_lightning_root_dir/"
 SERIALIZATION_METADATA_FILENAME = "serialization_metadata"
 
+# Map tf major.minor version to tfio version from https://pypi.org/project/tensorflow-io/
+_TFIO_VERSION_DICT = {
+    "2.3": "0.16.0",  # Align with testing_extra_require: tensorflow >= 2.3.0
+    "2.4": "0.17.1",
+    "2.5": "0.19.1",
+    "2.6": "0.21.0",
+    "2.7": "0.23.1",
+    "2.8": "0.25.0",
+    "2.9": "0.26.0",
+    "2.10": "0.27.0",
+    "2.11": "0.31.0",
+    "2.12": "0.32.0",
+    "2.13": "0.34.0",  # TODO(b/295580335): Support TF 2.13
+}
+
 
 def get_uri_prefix(gcs_uri: str) -> str:
     """Gets the directory of the gcs_uri.
@@ -1117,20 +1134,24 @@ def serialize(
         gcs_path: str,
         **kwargs,
     ) -> str:
-        # All bigframe serializers will be identical (bigframes.dataframe.DataFrame --> parquet)
-        # Record the framework in metadata for deserialization
-        detected_framework = kwargs.get("framework")
-        BigframeSerializer._metadata.framework = detected_framework
-        if detected_framework == "torch":
-            self.register_custom_command("pip install torchdata")
-            self.register_custom_command("pip install torcharrow")
+        # All bigframe serializers will convert bigframes.dataframe.DataFrame --> parquet
         if not _is_valid_gcs_path(gcs_path):
             raise ValueError(f"Invalid gcs path: {gcs_path}")
 
         BigframeSerializer._metadata.dependencies = (
             supported_frameworks._get_bigframe_deps()
         )
 
+        # Record the framework in metadata for deserialization
+        detected_framework = kwargs.get("framework")
+        BigframeSerializer._metadata.framework = detected_framework
+        if detected_framework == "torch":
+            self.register_custom_command("pip install torchdata")
+            self.register_custom_command("pip install torcharrow")
+        elif detected_framework == "tensorflow":
+            tensorflow_io_dep = "tensorflow-io==" + self._get_tfio_verison()
+            BigframeSerializer._metadata.dependencies.append(tensorflow_io_dep)
+
         # Check if index.name is default and set index.name if not
         if to_serialize.index.name and to_serialize.index.name != "index":
             raise ValueError("Index name must be 'index'")
@@ -1141,6 +1162,17 @@ def serialize(
         parquet_gcs_path = gcs_path + "/*"  # path is required to contain '*'
         to_serialize.to_parquet(parquet_gcs_path, index=True)
 
+    def _get_tfio_verison(self):
+        major, minor, _ = version.Version(tf.__version__).release
+        tf_version = f"{major}.{minor}"
+
+        if tf_version not in _TFIO_VERSION_DICT:
+            raise ValueError(
+                f"Tensorflow version {tf_version} is not supported for Bigframes."
+                + " Supported versions: tensorflow >= 2.3.0, <= 2.12.0."
+            )
+        return _TFIO_VERSION_DICT[tf_version]
+
     def deserialize(
         self, serialized_gcs_path: str, **kwargs
     ) -> Union[PandasData, BigframesData]:
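Net effect: when the detected framework is "tensorflow", the serializer appends a pinned requirement such as "tensorflow-io==0.32.0" (for TF 2.12) to BigframeSerializer._metadata.dependencies, so the dependency is installed automatically for the remote run rather than left to the user, as the commit message states. A TensorFlow version outside the mapping makes _get_tfio_verison raise a ValueError instead of guessing a pin.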

0 commit comments
