
Commit 7dc8771

matthew29tang authored and copybara-github committed
feat: Support custom batch size for Bigframes Tensorflow
PiperOrigin-RevId: 589190954
1 parent 0cb1a7b commit 7dc8771
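
Usage, as exercised by the system test in this commit: a minimal sketch assuming `model` is a Keras model prepared for Vertex remote training and `train` is the bigframes.dataframe.DataFrame passed to fit():

    # Request a custom batch size for the BigFrames -> tf.data.Dataset deserialization.
    model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}

    # Train model on Vertex
    model.fit(train, epochs=10)

When no batch_size is supplied, deserialization falls back to DEFAULT_TENSORFLOW_BATCHSIZE (32) instead of the previously hardcoded batch of 32.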

File tree

3 files changed: +16 -8 lines changed


tests/system/vertexai/test_bigframes_tensorflow.py (+1 -1)

@@ -62,7 +62,6 @@
     "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources"
 )
 class TestRemoteExecutionBigframesTensorflow(e2e_base.TestEndToEnd):
-
     _temp_prefix = "temp-vertexai-remote-execution"

     def test_remote_execution_keras(self, shared_state):
@@ -97,6 +96,7 @@ def test_remote_execution_keras(self, shared_state):
             enable_cuda=True,
             display_name=self._make_display_name("bigframes-keras-training"),
         )
+        model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}

         # Train model on Vertex
         model.fit(train, epochs=10)

tests/unit/vertexai/test_any_serializer.py (+1 -0)

@@ -1105,6 +1105,7 @@ def test_any_serializer_deserialize_bigframe_tensorflow(
         mock_bigframe_deserialize_tensorflow.assert_called_once_with(
             any_serializer_instance._instances[serializers.BigframeSerializer],
             serialized_gcs_path=fake_gcs_path,
+            batch_size=None,
         )

     def test_any_serializer_deserialize_tf_dataset(

vertexai/preview/_workflow/serialization_engine/serializers.py (+14 -7)

@@ -90,6 +90,7 @@
     "2.12": "0.32.0",
     "2.13": "0.34.0",  # TODO(b/295580335): Support TF 2.13
 }
+DEFAULT_TENSORFLOW_BATCHSIZE = 32


 def get_uri_prefix(gcs_uri: str) -> str:
@@ -1174,7 +1175,9 @@ def serialize(
         # Convert bigframes.dataframe.DataFrame to Parquet (GCS)
         parquet_gcs_path = gcs_path + "/*"  # path is required to contain '*'
         to_serialize.to_parquet(parquet_gcs_path, index=True)
-        return parquet_gcs_path
+
+        # Return original gcs_path to retrieve the metadata for later
+        return gcs_path

     def _get_tfio_verison(self):
         major, minor, _ = version.Version(tf.__version__).release
@@ -1190,15 +1193,15 @@ def _get_tfio_verison(self):
     def deserialize(
         self, serialized_gcs_path: str, **kwargs
     ) -> Union["pandas.DataFrame", "bigframes.dataframe.DataFrame"]:  # noqa: F821
-        del kwargs
-
         detected_framework = BigframeSerializer._metadata.framework
         if detected_framework == "sklearn":
             return self._deserialize_sklearn(serialized_gcs_path)
         elif detected_framework == "torch":
             return self._deserialize_torch(serialized_gcs_path)
         elif detected_framework == "tensorflow":
-            return self._deserialize_tensorflow(serialized_gcs_path)
+            return self._deserialize_tensorflow(
+                serialized_gcs_path, kwargs.get("batch_size")
+            )
         else:
             raise ValueError(f"Unsupported framework: {detected_framework}")

@@ -1269,11 +1272,16 @@ def reduce_tensors(a, b):

         return functools.reduce(reduce_tensors, list(parquet_df_dp))

-    def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
+    def _deserialize_tensorflow(
+        self, serialized_gcs_path: str, batch_size: Optional[int] = None
+    ) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset

         serialized_gcs_path is a folder containing one or more parquet files.
         """
+        # Set default batch_size
+        batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
+
         # Deserialization at remote environment
         try:
             import tensorflow_io as tfio
@@ -1307,8 +1315,7 @@ def reduce_fn(a, b):

            return functools.reduce(reduce_fn, row.values()), target

-        # TODO(b/295535730): Remove hardcoded batch_size of 32
-        return ds.map(map_fn).batch(32)
+        return ds.map(map_fn).batch(batch_size)


 class CloudPickleSerializer(serializers_base.Serializer):
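
For reference, a minimal sketch of the fallback behavior the new parameter follows, using only names from this diff (the helper name `_batch` is illustrative, not part of the module):

    DEFAULT_TENSORFLOW_BATCHSIZE = 32

    def _batch(ds, batch_size=None):
        # A missing (None) batch_size falls back to the module default of 32.
        return ds.batch(batch_size or DEFAULT_TENSORFLOW_BATCHSIZE)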
