90
90
"2.12" : "0.32.0" ,
91
91
"2.13" : "0.34.0" , # TODO(b/295580335): Support TF 2.13
92
92
}
93
+ DEFAULT_TENSORFLOW_BATCHSIZE = 32
93
94
94
95
95
96
def get_uri_prefix (gcs_uri : str ) -> str :
@@ -1174,7 +1175,9 @@ def serialize(
1174
1175
# Convert bigframes.dataframe.DataFrame to Parquet (GCS)
1175
1176
parquet_gcs_path = gcs_path + "/*" # path is required to contain '*'
1176
1177
to_serialize .to_parquet (parquet_gcs_path , index = True )
1177
- return parquet_gcs_path
1178
+
1179
+ # Return original gcs_path to retrieve the metadata for later
1180
+ return gcs_path
1178
1181
1179
1182
def _get_tfio_verison (self ):
1180
1183
major , minor , _ = version .Version (tf .__version__ ).release
@@ -1190,15 +1193,15 @@ def _get_tfio_verison(self):
1190
1193
def deserialize (
1191
1194
self , serialized_gcs_path : str , ** kwargs
1192
1195
) -> Union ["pandas.DataFrame" , "bigframes.dataframe.DataFrame" ]: # noqa: F821
1193
- del kwargs
1194
-
1195
1196
detected_framework = BigframeSerializer ._metadata .framework
1196
1197
if detected_framework == "sklearn" :
1197
1198
return self ._deserialize_sklearn (serialized_gcs_path )
1198
1199
elif detected_framework == "torch" :
1199
1200
return self ._deserialize_torch (serialized_gcs_path )
1200
1201
elif detected_framework == "tensorflow" :
1201
- return self ._deserialize_tensorflow (serialized_gcs_path )
1202
+ return self ._deserialize_tensorflow (
1203
+ serialized_gcs_path , kwargs .get ("batch_size" )
1204
+ )
1202
1205
else :
1203
1206
raise ValueError (f"Unsupported framework: { detected_framework } " )
1204
1207
@@ -1269,11 +1272,16 @@ def reduce_tensors(a, b):
1269
1272
1270
1273
return functools .reduce (reduce_tensors , list (parquet_df_dp ))
1271
1274
1272
- def _deserialize_tensorflow (self , serialized_gcs_path : str ) -> TFDataset :
1275
+ def _deserialize_tensorflow (
1276
+ self , serialized_gcs_path : str , batch_size : Optional [int ] = None
1277
+ ) -> TFDataset :
1273
1278
"""Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
1274
1279
1275
1280
serialized_gcs_path is a folder containing one or more parquet files.
1276
1281
"""
1282
+ # Set default batch_size
1283
+ batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
1284
+
1277
1285
# Deserialization at remote environment
1278
1286
try :
1279
1287
import tensorflow_io as tfio
@@ -1307,8 +1315,7 @@ def reduce_fn(a, b):
1307
1315
1308
1316
return functools .reduce (reduce_fn , row .values ()), target
1309
1317
1310
- # TODO(b/295535730): Remove hardcoded batch_size of 32
1311
- return ds .map (map_fn ).batch (32 )
1318
+ return ds .map (map_fn ).batch (batch_size )
1312
1319
1313
1320
1314
1321
class CloudPickleSerializer (serializers_base .Serializer ):
0 commit comments