Skip to content

Commit e869465

Browse files
authored
enh(distributed): propagate null features in spark (#448)
1 parent ca67b98 commit e869465

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

mlforecast/distributed/forecast.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,9 @@ def _fit(
377377
]
378378
self.models_ = {}
379379
if SPARK_INSTALLED and isinstance(data, SparkDataFrame):
380-
featurizer = VectorAssembler(inputCols=features, outputCol="features")
380+
featurizer = VectorAssembler(
381+
inputCols=features, outputCol="features", handleInvalid="keep"
382+
)
381383
train_data = featurizer.transform(prep)[target_col, "features"]
382384
for name, model in self.models.items():
383385
trained_model = model._pre_fit(target_col).fit(train_data)

nbs/distributed.forecast.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,9 @@
431431
" features = [x for x in fa.get_column_names(prep) if x not in {id_col, time_col, target_col}]\n",
432432
" self.models_ = {}\n",
433433
" if SPARK_INSTALLED and isinstance(data, SparkDataFrame):\n",
434-
" featurizer = VectorAssembler(inputCols=features, outputCol=\"features\")\n",
434+
" featurizer = VectorAssembler(\n",
435+
" inputCols=features, outputCol=\"features\", handleInvalid=\"keep\"\n",
436+
" )\n",
435437
" train_data = featurizer.transform(prep)[target_col, \"features\"]\n",
436438
" for name, model in self.models.items():\n",
437439
" trained_model = model._pre_fit(target_col).fit(train_data)\n",

0 commit comments

Comments
 (0)