File tree 2 files changed +11
-15
lines changed
vertexai/preview/_workflow
2 files changed +11
-15
lines changed Original file line number Diff line number Diff line change @@ -1138,14 +1138,20 @@ def serialize(
1138
1138
if not _is_valid_gcs_path (gcs_path ):
1139
1139
raise ValueError (f"Invalid gcs path: { gcs_path } " )
1140
1140
1141
- BigframeSerializer ._metadata .dependencies = (
1142
- supported_frameworks ._get_bigframe_deps ()
1143
- )
1144
-
1145
1141
# Record the framework in metadata for deserialization
1146
1142
detected_framework = kwargs .get ("framework" )
1147
1143
BigframeSerializer ._metadata .framework = detected_framework
1148
- if detected_framework == "torch" :
1144
+
1145
+ # Reset dependencies and custom_commands in case the framework is different
1146
+ BigframeSerializer ._metadata .dependencies = []
1147
+ BigframeSerializer ._metadata .custom_commands = []
1148
+
1149
+ # Add dependencies based on framework
1150
+ if detected_framework == "sklearn" :
1151
+ sklearn_deps = supported_frameworks ._get_pandas_deps ()
1152
+ sklearn_deps += supported_frameworks ._get_pyarrow_deps ()
1153
+ BigframeSerializer ._metadata .dependencies += sklearn_deps
1154
+ elif detected_framework == "torch" :
1149
1155
# Install using custom_commands to avoid numpy dependency conflict
1150
1156
BigframeSerializer ._metadata .custom_commands .append ("pip install torchdata" )
1151
1157
BigframeSerializer ._metadata .custom_commands .append ("pip install torcharrow" )
Original file line number Diff line number Diff line change @@ -276,16 +276,6 @@ def _get_deps_if_pandas_dataframe(possible_dataframe: Any) -> List[str]:
276
276
return deps
277
277
278
278
279
- def _get_bigframe_deps () -> List [str ]:
280
- deps = []
281
- # Note: bigframe serialization can only occur locally so bigframes
282
- # should not be installed remotely. Pandas and pyarrow are required
283
- # to deserialize for sklearn bigframes though.
284
- deps += _get_pandas_deps ()
285
- deps += _get_pyarrow_deps ()
286
- return deps
287
-
288
-
289
279
def _get_pyarrow_deps () -> List [str ]:
290
280
deps = []
291
281
try :
You can’t perform that action at this time.
0 commit comments