Commit a8f85ec

matthew29tang authored and copybara-github committed
feat: Support bigframes sharded parquet ingestion at remote deserialization (Tensorflow)
PiperOrigin-RevId: 562030438
1 parent 468e6e7 commit a8f85ec

File tree

1 file changed: +12 -9 lines changed

vertexai/preview/_workflow/serialization_engine/serializers.py (+12 -9)
@@ -1047,6 +1047,8 @@ def _deserialize_sklearn(self, serialized_gcs_path: str) -> PandasData:
         By default, sklearn returns a numpy array which uses CloudPickleSerializer.
         If a bigframes.dataframe.DataFrame is desired for the return type,
         b/291147206 (cl/548228568) is required
+
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
@@ -1069,7 +1071,7 @@ def _deserialize_sklearn(self, serialized_gcs_path: str) -> PandasData:
     def _deserialize_torch(self, serialized_gcs_path: str) -> TorchTensor:
         """Torch deserializes parquet (GCS) --> torch.tensor
 
-        Assumes one parquet file is created.
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
@@ -1107,7 +1109,7 @@ def reduce_tensors(a, b):
     def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
 
-        Assumes one parquet file is created.
+        serialized_gcs_path is a folder containing one or more parquet files.
         """
         # Deserialization at remote environment
         try:
@@ -1118,14 +1120,15 @@ def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
             ) from e
 
         # Deserialization always happens at remote, so gcs filesystem is mounted to /gcs/
-        # TODO(b/296475384): Handle multiple parquet shards
-        if len(os.listdir(serialized_gcs_path + "/")) > 1:
-            raise RuntimeError(
-                "Large datasets which are serialized into sharded parquet are not yet supported (b/296475384)"
-            )
+        files = os.listdir(serialized_gcs_path + "/")
+        files = list(
+            map(lambda file_name: serialized_gcs_path + "/" + file_name, files)
+        )
+        ds = tfio.IODataset.from_parquet(files[0])
 
-        single_parquet_gcs_path = serialized_gcs_path + "/" + "000000000000"
-        ds = tfio.IODataset.from_parquet(single_parquet_gcs_path)
+        for file_name in files[1:]:
+            ds_shard = tfio.IODataset.from_parquet(file_name)
+            ds = ds.concatenate(ds_shard)
 
         # TODO(b/296474656) Parquet must have "target" column for y
         def map_fn(row):
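
In effect, the change lists every parquet shard under the mounted GCS folder and chains them into one dataset with Dataset.concatenate instead of reading a single hardcoded shard. Below is a minimal standalone sketch of the same pattern, assuming tensorflow_io is installed; /tmp/shards is a hypothetical stand-in for the mounted /gcs/ path and is assumed to hold only parquet shards with identical schemas:

import os

import tensorflow_io as tfio

# Hypothetical local folder standing in for the mounted /gcs/ path;
# each file is one parquet shard with the same schema.
shard_dir = "/tmp/shards"
files = [os.path.join(shard_dir, f) for f in sorted(os.listdir(shard_dir))]

# Read the first shard, then append the remaining shards in order.
# IODataset is a tf.data.Dataset, so concatenate() yields one logical
# dataset that iterates shard 0's rows, then shard 1's, and so on.
ds = tfio.IODataset.from_parquet(files[0])
for file_name in files[1:]:
    ds = ds.concatenate(tfio.IODataset.from_parquet(file_name))

One difference from the committed code: the sketch sorts the listing, since os.listdir() returns entries in arbitrary order and sorting makes the shard concatenation order deterministic.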
