Skip to content

Commit 3626418

Browse files
authored
fix(distributed): exogenous handling in distributed cross validation (#443)
1 parent eab21b1 commit 3626418

File tree

3 files changed

+310
-229
lines changed

3 files changed

+310
-229
lines changed

mlforecast/distributed/forecast.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -195,7 +195,7 @@ def _preprocess_partition(
195195
train = part[train_mask]
196196
valid_keep_cols = part.columns
197197
if static_features is not None:
198-
valid_keep_cols.drop(static_features)
198+
valid_keep_cols = valid_keep_cols.drop(static_features)
199199
valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)
200200
transformed = ts.fit_transform(
201201
train,
@@ -456,6 +456,8 @@ def _predict(
456456
) -> Iterable[pd.DataFrame]:
457457
for serialized_ts, _, serialized_valid in items:
458458
valid = cloudpickle.loads(serialized_valid)
459+
if valid is not None:
460+
X_df = valid
459461
ts = cloudpickle.loads(serialized_ts)
460462
res = ts.predict(
461463
models=models,
@@ -649,7 +651,11 @@ def cross_validation(
649651
engine=self.engine,
650652
)
651653
results.append(fa.get_native_as_df(preds))
652-
return fa.union(*results)
654+
if len(results) == 1:
655+
res = results[0]
656+
else:
657+
res = fa.union(*results)
658+
return res
653659

654660
@staticmethod
655661
def _save_ts(items: List[List[Any]], path: str) -> Iterable[pd.DataFrame]:

nbs/distributed.forecast.ipynb

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -261,7 +261,7 @@
261261
" train = part[train_mask]\n",
262262
" valid_keep_cols = part.columns\n",
263263
" if static_features is not None:\n",
264-
" valid_keep_cols.drop(static_features)\n",
264+
" valid_keep_cols = valid_keep_cols.drop(static_features)\n",
265265
" valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)\n",
266266
" transformed = ts.fit_transform(\n",
267267
" train,\n",
@@ -508,6 +508,8 @@
508508
" ) -> Iterable[pd.DataFrame]:\n",
509509
" for serialized_ts, _, serialized_valid in items:\n",
510510
" valid = cloudpickle.loads(serialized_valid)\n",
511+
" if valid is not None:\n",
512+
" X_df = valid\n",
511513
" ts = cloudpickle.loads(serialized_ts)\n",
512514
" res = ts.predict(\n",
513515
" models=models,\n",
@@ -695,7 +697,11 @@
695697
" engine=self.engine,\n",
696698
" )\n",
697699
" results.append(fa.get_native_as_df(preds))\n",
698-
" return fa.union(*results)\n",
700+
" if len(results) == 1:\n",
701+
" res = results[0]\n",
702+
" else:\n",
703+
" res = fa.union(*results)\n",
704+
" return res\n",
699705
"\n",
700706
" @staticmethod\n",
701707
" def _save_ts(items: List[List[Any]], path: str) -> Iterable[pd.DataFrame]:\n",

0 commit comments

Comments
 (0)