38
38
serializers_base ,
39
39
)
40
40
41
+ from packaging import version
42
+
41
43
try :
42
44
import cloudpickle
43
45
except ImportError :
125
127
_LIGHTNING_ROOT_DIR = "/vertex_lightning_root_dir/"
126
128
SERIALIZATION_METADATA_FILENAME = "serialization_metadata"
127
129
130
+ # Maps a TensorFlow "major.minor" version string to the tensorflow-io version
+ # paired with it; the pairs come from https://pypi.org/project/tensorflow-io/
+ # Looked up by _get_tfio_verison to build the "tensorflow-io==" dependency pin.
131
+ _TFIO_VERSION_DICT = {
132
+ "2.3" : "0.16.0" , # Align with testing_extra_require: tensorflow >= 2.3.0
133
+ "2.4" : "0.17.1" ,
134
+ "2.5" : "0.19.1" ,
135
+ "2.6" : "0.21.0" ,
136
+ "2.7" : "0.23.1" ,
137
+ "2.8" : "0.25.0" ,
138
+ "2.9" : "0.26.0" ,
139
+ "2.10" : "0.27.0" ,
140
+ "2.11" : "0.31.0" ,
141
+ "2.12" : "0.32.0" ,
142
+ "2.13" : "0.34.0" , # TODO(b/295580335): Support TF 2.13
143
+ }
144
+
128
145
129
146
def get_uri_prefix (gcs_uri : str ) -> str :
130
147
"""Gets the directory of the gcs_uri.
@@ -1117,20 +1134,24 @@ def serialize(
1117
1134
gcs_path : str ,
1118
1135
** kwargs ,
1119
1136
) -> str :
1120
- # All bigframe serializers will be identical (bigframes.dataframe.DataFrame --> parquet)
1121
- # Record the framework in metadata for deserialization
1122
- detected_framework = kwargs .get ("framework" )
1123
- BigframeSerializer ._metadata .framework = detected_framework
1124
- if detected_framework == "torch" :
1125
- self .register_custom_command ("pip install torchdata" )
1126
- self .register_custom_command ("pip install torcharrow" )
1137
+ # All bigframe serializers will convert bigframes.dataframe.DataFrame --> parquet
1127
1138
if not _is_valid_gcs_path (gcs_path ):
1128
1139
raise ValueError (f"Invalid gcs path: { gcs_path } " )
1129
1140
1130
1141
BigframeSerializer ._metadata .dependencies = (
1131
1142
supported_frameworks ._get_bigframe_deps ()
1132
1143
)
1133
1144
1145
+ # Record the framework in metadata for deserialization
1146
+ detected_framework = kwargs .get ("framework" )
1147
+ BigframeSerializer ._metadata .framework = detected_framework
1148
+ if detected_framework == "torch" :
1149
+ self .register_custom_command ("pip install torchdata" )
1150
+ self .register_custom_command ("pip install torcharrow" )
1151
+ elif detected_framework == "tensorflow" :
1152
+ tensorflow_io_dep = "tensorflow-io==" + self ._get_tfio_verison ()
1153
+ BigframeSerializer ._metadata .dependencies .append (tensorflow_io_dep )
1154
+
1134
1155
# Check if index.name is default and set index.name if not
1135
1156
if to_serialize .index .name and to_serialize .index .name != "index" :
1136
1157
raise ValueError ("Index name must be 'index'" )
@@ -1141,6 +1162,17 @@ def serialize(
1141
1162
parquet_gcs_path = gcs_path + "/*" # path is required to contain '*'
1142
1163
to_serialize .to_parquet (parquet_gcs_path , index = True )
1143
1164
1165
def _get_tfio_verison(self):
    """Returns the tensorflow-io version paired with the installed TensorFlow.

    The installed TensorFlow version (``tf.__version__``) is reduced to its
    "major.minor" form and looked up in ``_TFIO_VERSION_DICT``.

    Returns:
        The tensorflow-io version string mapped to the installed TensorFlow
        "major.minor" version.

    Raises:
        ValueError: If the installed TensorFlow "major.minor" version has no
            entry in ``_TFIO_VERSION_DICT``.
    """
    # Read the parsed major/minor fields instead of unpacking ``release``:
    # the release tuple is not guaranteed to have exactly three components
    # (e.g. "2.12" or "2.12.0.1"), and a fixed three-way unpack would fail
    # with an unrelated "values to unpack" ValueError before the supported
    # version check below is ever reached.
    parsed_version = version.Version(tf.__version__)
    tf_version = f"{parsed_version.major}.{parsed_version.minor}"

    tfio_version = _TFIO_VERSION_DICT.get(tf_version)
    if tfio_version is None:
        # NOTE(review): the message caps support at 2.12.0 while
        # _TFIO_VERSION_DICT also lists "2.13" -- confirm which is intended.
        raise ValueError(
            f"Tensorflow version {tf_version} is not supported for Bigframes."
            + " Supported versions: tensorflow >= 2.3.0, <= 2.12.0."
        )
    return tfio_version
1175
+
1144
1176
def deserialize (
1145
1177
self , serialized_gcs_path : str , ** kwargs
1146
1178
) -> Union [PandasData , BigframesData ]:
0 commit comments