Skip to content

Commit 7aaffe5

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: Install Bigframes sklearn dependencies automatically
PiperOrigin-RevId: 571490689
1 parent 9b75259 commit 7aaffe5

File tree

2 files changed

+11
-15
lines changed

2 files changed

+11
-15
lines changed

vertexai/preview/_workflow/serialization_engine/serializers.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -1138,14 +1138,20 @@ def serialize(
11381138
if not _is_valid_gcs_path(gcs_path):
11391139
raise ValueError(f"Invalid gcs path: {gcs_path}")
11401140

1141-
BigframeSerializer._metadata.dependencies = (
1142-
supported_frameworks._get_bigframe_deps()
1143-
)
1144-
11451141
# Record the framework in metadata for deserialization
11461142
detected_framework = kwargs.get("framework")
11471143
BigframeSerializer._metadata.framework = detected_framework
1148-
if detected_framework == "torch":
1144+
1145+
# Reset dependencies and custom_commands in case the framework is different
1146+
BigframeSerializer._metadata.dependencies = []
1147+
BigframeSerializer._metadata.custom_commands = []
1148+
1149+
# Add dependencies based on framework
1150+
if detected_framework == "sklearn":
1151+
sklearn_deps = supported_frameworks._get_pandas_deps()
1152+
sklearn_deps += supported_frameworks._get_pyarrow_deps()
1153+
BigframeSerializer._metadata.dependencies += sklearn_deps
1154+
elif detected_framework == "torch":
11491155
# Install using custom_commands to avoid numpy dependency conflict
11501156
BigframeSerializer._metadata.custom_commands.append("pip install torchdata")
11511157
BigframeSerializer._metadata.custom_commands.append("pip install torcharrow")

vertexai/preview/_workflow/shared/supported_frameworks.py

-10
Original file line numberDiff line numberDiff line change
@@ -276,16 +276,6 @@ def _get_deps_if_pandas_dataframe(possible_dataframe: Any) -> List[str]:
276276
return deps
277277

278278

279-
def _get_bigframe_deps() -> List[str]:
280-
deps = []
281-
# Note: bigframe serialization can only occur locally so bigframes
282-
# should not be installed remotely. Pandas and pyarrow are required
283-
# to deserialize for sklearn bigframes though.
284-
deps += _get_pandas_deps()
285-
deps += _get_pyarrow_deps()
286-
return deps
287-
288-
289279
def _get_pyarrow_deps() -> List[str]:
290280
deps = []
291281
try:

0 commit comments

Comments
 (0)