
Commit 1634940

matthew29tang authored and copybara-github committed
feat: Support custom target y column name for Bigframes Tensorflow
PiperOrigin-RevId: 592910297
1 parent 6e6d005 commit 1634940
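
In user terms, the change lets the serializer_args entry for a bigframes DataFrame name which column serves as the label y (previously hard-coded to "target"). A minimal sketch of the new call pattern, mirroring the system test below; the remote-training setup, the Keras model, and the train DataFrame are assumed to exist:

    # Assumed context: vertexai.preview remote training is enabled, `model` is a
    # Keras model wrapped for remote execution, `train` is a bigframes DataFrame.
    model.fit.vertex.remote_config.serializer_args[train] = {
        "batch_size": 10,
        "target_col": "species",  # new kwarg; falls back to "target" when omitted
    }
    model.fit(train, epochs=10)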

File tree: 3 files changed (+15, −9)

tests/system/vertexai/test_bigframes_tensorflow.py (+5, −3)
@@ -80,8 +80,7 @@ def test_remote_execution_keras(self, shared_state):
             "virginica": 1,
             "setosa": 2,
         }
-        df["target"] = df["species"].map(species_categories)
-        df = df.drop(columns=["species"])
+        df["species"] = df["species"].map(species_categories)
 
         train, _ = bf_train_test_split(df, test_size=0.2)
 
@@ -96,7 +95,10 @@ def test_remote_execution_keras(self, shared_state):
             enable_cuda=True,
             display_name=self._make_display_name("bigframes-keras-training"),
         )
-        model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}
+        model.fit.vertex.remote_config.serializer_args[train] = {
+            "batch_size": 10,
+            "target_col": "species",
+        }
 
         # Train model on Vertex
         model.fit(train, epochs=10)
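
With this change the test keeps the label in its original species column (mapped to integer codes) instead of copying it into a synthetic target column, and tells the serializer about it via the new target_col entry shown above.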

tests/unit/vertexai/test_any_serializer.py (+1, −0)
@@ -1106,6 +1106,7 @@ def test_any_serializer_deserialize_bigframe_tensorflow(
             any_serializer_instance._instances[serializers.BigframeSerializer],
             serialized_gcs_path=fake_gcs_path,
             batch_size=None,
+            target_col=None,
         )
 
     def test_any_serializer_deserialize_tf_dataset(

vertexai/preview/_workflow/serialization_engine/serializers.py (+9, −6)
@@ -1200,7 +1200,7 @@ def deserialize(
             return self._deserialize_torch(serialized_gcs_path)
         elif detected_framework == "tensorflow":
             return self._deserialize_tensorflow(
-                serialized_gcs_path, kwargs.get("batch_size")
+                serialized_gcs_path, kwargs.get("batch_size"), kwargs.get("target_col")
             )
         else:
             raise ValueError(f"Unsupported framework: {detected_framework}")
@@ -1273,14 +1273,18 @@ def reduce_tensors(a, b):
         return functools.reduce(reduce_tensors, list(parquet_df_dp))
 
     def _deserialize_tensorflow(
-        self, serialized_gcs_path: str, batch_size: Optional[int] = None
+        self,
+        serialized_gcs_path: str,
+        batch_size: Optional[int] = None,
+        target_col: Optional[str] = None,
     ) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
 
         serialized_gcs_path is a folder containing one or more parquet files.
         """
-        # Set default batch_size
+        # Set default kwarg values
         batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
+        target_col = target_col.encode("ASCII") if target_col else b"target"
 
         # Deserialization at remote environment
         try:
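
The ASCII encoding matters because tensorflow-io's parquet reader keys each row dict by the column name as bytes, so a str target_col from the caller has to be converted before lookup. A standalone sketch of just the defaulting logic; the DEFAULT_TENSORFLOW_BATCHSIZE value here is illustrative, not the library's actual constant:

    DEFAULT_TENSORFLOW_BATCHSIZE = 32  # assumed value, for illustration only

    def resolve_kwargs(batch_size=None, target_col=None):
        # Absent serializer_args arrive as None and fall back to defaults.
        batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
        # tfio.IODataset.from_parquet yields row dicts keyed by bytes,
        # so the user-supplied str column name is ASCII-encoded first.
        target_col = target_col.encode("ASCII") if target_col else b"target"
        return batch_size, target_col

    assert resolve_kwargs() == (32, b"target")
    assert resolve_kwargs(10, "species") == (10, b"species")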
@@ -1301,13 +1305,12 @@ def _deserialize_tensorflow(
             ds_shard = tfio.IODataset.from_parquet(file_name)
             ds = ds.concatenate(ds_shard)
 
-        # TODO(b/296474656) Parquet must have "target" column for y
         def map_fn(row):
-            target = row[b"target"]
+            target = row[target_col]
             row = {
                 k: tf.expand_dims(v, -1)
                 for k, v in row.items()
-                if k != b"target" and k != b"index"
+                if k != target_col and k != b"index"
             }
 
             def reduce_fn(a, b):
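
To make the features/label split concrete, here is a hedged, self-contained sketch that applies the same transformation to a single parquet-style row (bytes keys, matching what tfio produces); the column names are illustrative:

    import tensorflow as tf

    target_col = b"species"  # already ASCII-encoded, as above

    def map_fn(row):
        # Pull out the label, keep every other column except the parquet
        # index as a feature, each expanded from scalar to shape (1,).
        target = row[target_col]
        features = {
            k: tf.expand_dims(v, -1)
            for k, v in row.items()
            if k != target_col and k != b"index"
        }
        return features, target

    row = {
        b"sepal_length": tf.constant(5.1),
        b"species": tf.constant(2),
        b"index": tf.constant(0),
    }
    features, label = map_fn(row)
    print(features)  # {b'sepal_length': <tf.Tensor: shape=(1,), ...>}
    print(label)     # tf.Tensor(2, ...)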
