@@ -1047,6 +1047,8 @@ def _deserialize_sklearn(self, serialized_gcs_path: str) -> PandasData:
         By default, sklearn returns a numpy array which uses CloudPickleSerializer.
         If a bigframes.dataframe.DataFrame is desired for the return type,
         b/291147206 (cl/548228568) is required
+
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
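Not part of the diff: a minimal sketch of what the new docstring line means for the sklearn path, i.e. reading a folder of parquet shards back as a single pandas DataFrame. It assumes pandas with the pyarrow engine is available; the folder path and shard names are hypothetical.

```python
import pandas as pd

# Hypothetical mounted GCS folder holding one or more parquet shards
# (e.g. 000000000000, 000000000001, ...).
serialized_gcs_path = "/gcs/my-bucket/serialized-input"

# With the pyarrow engine, pandas accepts a directory path and reads every
# shard inside it, returning one concatenated DataFrame.
df = pd.read_parquet(serialized_gcs_path)
print(df.shape)
```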
@@ -1069,7 +1071,7 @@ def _deserialize_sklearn(self, serialized_gcs_path: str) -> PandasData:
     def _deserialize_torch(self, serialized_gcs_path: str) -> TorchTensor:
         """Torch deserializes parquet (GCS) --> torch.tensor

-        Assumes one parquet file is created.
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
@@ -1107,7 +1109,7 @@ def reduce_tensors(a, b):
     def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset

-        Assumes one parquet file is created.
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
@@ -1118,14 +1120,15 @@ def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
         ) from e

         # Deserialization always happens at remote, so gcs filesystem is mounted to /gcs/
-        # TODO(b/296475384): Handle multiple parquet shards
-        if len(os.listdir(serialized_gcs_path + "/")) > 1:
-            raise RuntimeError(
-                "Large datasets which are serialized into sharded parquet are not yet supported (b/296475384)"
-            )
+        files = os.listdir(serialized_gcs_path + "/")
+        files = list(
+            map(lambda file_name: serialized_gcs_path + "/" + file_name, files)
+        )
+        ds = tfio.IODataset.from_parquet(files[0])

-        single_parquet_gcs_path = serialized_gcs_path + "/" + "000000000000"
-        ds = tfio.IODataset.from_parquet(single_parquet_gcs_path)
+        for file_name in files[1:]:
+            ds_shard = tfio.IODataset.from_parquet(file_name)
+            ds = ds.concatenate(ds_shard)

         # TODO(b/296474656) Parquet must have "target" column for y
         def map_fn(row):
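Not part of the diff: a self-contained sketch of the shard-concatenation pattern introduced above, exercised against two locally written parquet files so it can run without GCS. It assumes pandas, pyarrow, and tensorflow-io are installed; the temp folder, shard names, and columns are made up for illustration.

```python
import os
import tempfile

import pandas as pd
import tensorflow_io as tfio

# Write two tiny parquet "shards" into a temp folder, mimicking a sharded export.
tmp_dir = tempfile.mkdtemp()
pd.DataFrame({"feature": [1.0, 2.0], "target": [0, 1]}).to_parquet(
    os.path.join(tmp_dir, "000000000000")
)
pd.DataFrame({"feature": [3.0], "target": [1]}).to_parquet(
    os.path.join(tmp_dir, "000000000001")
)

files = sorted(os.path.join(tmp_dir, name) for name in os.listdir(tmp_dir))

# Same shape as the new code: start from the first shard, then append the rest.
ds = tfio.IODataset.from_parquet(files[0])
for file_name in files[1:]:
    ds = ds.concatenate(tfio.IODataset.from_parquet(file_name))

for row in ds:
    print(row)  # one element per parquet row, across all shards
```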