Commit c833952

feat(distributed): support ids in predict (#454)

1 parent 2342dea · commit c833952

3 files changed: +111 -60
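
What this enables: `DistributedMLForecast.predict` now accepts an `ids` argument to compute forecasts only for a subset of the series seen during training. A minimal usage sketch, assuming the fitted `fcst` object and the `series`/`future` frames from the quick start notebook below (any names not in the diff are illustrative):

import numpy as np

# choose a subset of the training series
ids = np.random.choice(series['unique_id'].unique(), size=10, replace=False)
# request forecasts only for those ids; when exogenous features are used,
# X_df should be filtered to the same subset
preds_ids = fcst.predict(7, X_df=future[future['unique_id'].isin(ids)], ids=ids).compute()
assert set(preds_ids['unique_id']) == set(ids)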

mlforecast/distributed/forecast.py

Lines changed: 28 additions & 11 deletions
@@ -38,6 +38,7 @@
 except ModuleNotFoundError:
     RAY_INSTALLED = False
 from sklearn.base import clone
+from triad import Schema
 
 from mlforecast.core import (
     DateFeature,
@@ -455,31 +456,43 @@ def _predict(
         before_predict_callback=None,
         after_predict_callback=None,
         X_df=None,
+        ids=None,
+        schema=None,
     ) -> Iterable[pd.DataFrame]:
         for serialized_ts, _, serialized_valid in items:
             valid = cloudpickle.loads(serialized_valid)
             if valid is not None:
                 X_df = valid
             ts = cloudpickle.loads(serialized_ts)
+            if ids is not None:
+                ids = ts.uids.intersection(ids).tolist()
+                if not ids:
+                    yield pd.DataFrame(
+                        {
+                            field.name: pd.Series(dtype=field.type.to_pandas_dtype())
+                            for field in schema.values()
+                        }
+                    )
+                    return
             res = ts.predict(
                 models=models,
                 horizon=horizon,
                 before_predict_callback=before_predict_callback,
                 after_predict_callback=after_predict_callback,
                 X_df=X_df,
+                ids=ids,
             )
             if valid is not None:
                 res = res.merge(valid, how="left")
             yield res
 
-    def _get_predict_schema(self) -> str:
-        model_names = self.models.keys()
-        models_schema = ",".join(f"{model_name}:double" for model_name in model_names)
-        schema = (
-            f"{self._base_ts.id_col}:string,{self._base_ts.time_col}:datetime,"
-            + models_schema
-        )
-        return schema
+    def _get_predict_schema(self) -> Schema:
+        ids_schema = [
+            (self._base_ts.id_col, "string"),
+            (self._base_ts.time_col, "datetime"),
+        ]
+        models_schema = [(model, "double") for model in self.models.keys()]
+        return Schema(ids_schema + models_schema)
 
     def predict(
         self,
@@ -488,6 +501,7 @@ def predict(
         after_predict_callback: Optional[Callable] = None,
         X_df: Optional[pd.DataFrame] = None,
         new_df: Optional[fugue.AnyDataFrame] = None,
+        ids: Optional[List[str]] = None,
     ) -> fugue.AnyDataFrame:
         """Compute the predictions for the next `horizon` steps.
 
@@ -509,6 +523,8 @@ def predict(
             Series data of new observations for which forecasts are to be generated.
             This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
             If `new_df` is not None, the method will generate forecasts for the new observations.
+        ids : list of str, optional (default=None)
+            List with subset of ids seen during training for which the forecasts should be computed.
 
         Returns
         -------
@@ -540,6 +556,8 @@ def predict(
                 "before_predict_callback": before_predict_callback,
                 "after_predict_callback": after_predict_callback,
                 "X_df": X_df,
+                "ids": ids,
+                "schema": schema,
             },
             schema=schema,
             engine=self.engine,
@@ -636,9 +654,8 @@ def cross_validation(
             keep_last_n=keep_last_n,
             window_info=window_info,
         )
-        schema = (
-            self._get_predict_schema()
-            + f",cutoff:datetime,{self._base_ts.target_col}:double"
+        schema = self._get_predict_schema() + Schema(
+            ("cutoff", "datetime"), (self._base_ts.target_col, "double")
         )
         preds = fa.transform(
             partition_results,
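
Note on the schema change: `_get_predict_schema` now returns a triad `Schema` built from `(name, type)` tuples instead of a comma-separated type string. This lets `cross_validation` append the `cutoff` and target fields with `+`, and lets `_predict` yield a correctly typed empty DataFrame for partitions that contain none of the requested ids. A minimal sketch of the triad operations the new code relies on (column names here are illustrative):

import pandas as pd
from triad import Schema

# Schemas concatenate with `+`, as in cross_validation above
schema = Schema([('unique_id', 'string'), ('ds', 'datetime')]) + Schema(('model1', 'double'))
# build an empty DataFrame with the declared dtypes, as _predict does
# when a partition has no overlap with the requested ids
empty = pd.DataFrame(
    {field.name: pd.Series(dtype=field.type.to_pandas_dtype()) for field in schema.values()}
)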

nbs/distributed.forecast.ipynb

Lines changed: 29 additions & 8 deletions
@@ -98,6 +98,7 @@
     "except ModuleNotFoundError:\n",
     "    RAY_INSTALLED = False\n",
     "from sklearn.base import clone\n",
+    "from triad import Schema\n",
     "\n",
     "from mlforecast.core import (\n",
     "    DateFeature,\n",
@@ -506,29 +507,41 @@
     "        horizon,\n",
     "        before_predict_callback=None,\n",
     "        after_predict_callback=None,\n",
-    "        X_df=None, \n",
+    "        X_df=None,\n",
+    "        ids=None,\n",
+    "        schema=None,\n",
     "    ) -> Iterable[pd.DataFrame]:\n",
     "        for serialized_ts, _, serialized_valid in items:\n",
     "            valid = cloudpickle.loads(serialized_valid)\n",
     "            if valid is not None:\n",
     "                X_df = valid\n",
     "            ts = cloudpickle.loads(serialized_ts)\n",
+    "            if ids is not None:\n",
+    "                ids = ts.uids.intersection(ids).tolist()\n",
+    "                if not ids:\n",
+    "                    yield pd.DataFrame(\n",
+    "                        {\n",
+    "                            field.name: pd.Series(dtype=field.type.to_pandas_dtype())\n",
+    "                            for field in schema.values()\n",
+    "                        }\n",
+    "                    )\n",
+    "                    return\n",
     "            res = ts.predict(\n",
     "                models=models,\n",
     "                horizon=horizon,\n",
     "                before_predict_callback=before_predict_callback,\n",
     "                after_predict_callback=after_predict_callback,\n",
     "                X_df=X_df,\n",
+    "                ids=ids,\n",
     "            )\n",
     "            if valid is not None:\n",
     "                res = res.merge(valid, how='left')\n",
     "            yield res\n",
     "    \n",
-    "    def _get_predict_schema(self) -> str:\n",
-    "        model_names = self.models.keys()\n",
-    "        models_schema = ','.join(f'{model_name}:double' for model_name in model_names)\n",
-    "        schema = f'{self._base_ts.id_col}:string,{self._base_ts.time_col}:datetime,' + models_schema\n",
-    "        return schema\n",
+    "    def _get_predict_schema(self) -> Schema:\n",
+    "        ids_schema = [(self._base_ts.id_col, 'string'), (self._base_ts.time_col, 'datetime')]\n",
+    "        models_schema = [(model, 'double') for model in self.models.keys()]\n",
+    "        return Schema(ids_schema + models_schema)\n",
     "\n",
     "    def predict(\n",
     "        self,\n",
@@ -537,6 +550,7 @@
     "        after_predict_callback: Optional[Callable] = None,\n",
     "        X_df: Optional[pd.DataFrame] = None,\n",
     "        new_df: Optional[fugue.AnyDataFrame] = None,\n",
+    "        ids: Optional[List[str]] = None,\n",
     "    ) -> fugue.AnyDataFrame:\n",
     "        \"\"\"Compute the predictions for the next `horizon` steps.\n",
     "\n",
@@ -557,7 +571,9 @@
     "        new_df : dask or spark DataFrame, optional (default=None)\n",
     "            Series data of new observations for which forecasts are to be generated.\n",
     "            This dataframe should have the same structure as the one used to fit the model, including any features and time series data.\n",
-    "            If `new_df` is not None, the method will generate forecasts for the new observations. \n",
+    "            If `new_df` is not None, the method will generate forecasts for the new observations.\n",
+    "        ids : list of str, optional (default=None)\n",
+    "            List with subset of ids seen during training for which the forecasts should be computed. \n",
     "\n",
     "        Returns\n",
     "        -------\n",
@@ -589,6 +605,8 @@
     "                'before_predict_callback': before_predict_callback,\n",
     "                'after_predict_callback': after_predict_callback,\n",
     "                'X_df': X_df,\n",
+    "                'ids': ids,\n",
+    "                'schema': schema,\n",
     "            },\n",
     "            schema=schema,\n",
     "            engine=self.engine,\n",
@@ -685,7 +703,10 @@
     "            keep_last_n=keep_last_n,\n",
     "            window_info=window_info,\n",
     "        )\n",
-    "        schema = self._get_predict_schema() + f',cutoff:datetime,{self._base_ts.target_col}:double'\n",
+    "        schema = (\n",
+    "            self._get_predict_schema() + Schema(\n",
+    "                ('cutoff', 'datetime'), (self._base_ts.target_col, 'double'))\n",
+    "        )\n",
     "        preds = fa.transform(\n",
     "            partition_results,\n",
     "            DistributedMLForecast._predict,\n",

nbs/docs/getting-started/quick_start_distributed.ipynb

Lines changed: 54 additions & 41 deletions
@@ -366,32 +366,31 @@
    "source": [
     "#| hide\n",
     "# test num_partitions works properly\n",
-    "if sys.version_info >= (3, 9):\n",
-    "    num_partitions_test = 4\n",
-    "    test_dd = dd.from_pandas(series, npartitions=num_partitions_test) # In this case we dont have to specify the column\n",
-    "    test_dd['unique_id'] = test_dd['unique_id'].astype(str)\n",
-    "    fcst_np = DistributedMLForecast(\n",
-    "        models=models,\n",
-    "        freq='D',\n",
-    "        target_transforms=[Differences([7])], \n",
-    "        lags=[7],\n",
-    "        lag_transforms={\n",
-    "            1: [ExpandingMean()],\n",
-    "            7: [RollingMean(window_size=14)]\n",
-    "        },\n",
-    "        date_features=['dayofweek', 'month'],\n",
-    "        num_threads=1,\n",
-    "        engine=client,\n",
-    "        num_partitions=num_partitions_test\n",
-    "    )\n",
-    "    fcst_np.fit(test_dd)\n",
-    "    test_partition_results_size(fcst_np, num_partitions_test)\n",
-    "    preds_np = fcst_np.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-    "    preds = fcst.predict(7, X_df=future).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-    "    pd.testing.assert_frame_equal(\n",
-    "        preds[['unique_id', 'ds']], \n",
-    "        preds_np[['unique_id', 'ds']], \n",
-    "    )"
+    "num_partitions_test = 4\n",
+    "test_dd = dd.from_pandas(series, npartitions=num_partitions_test) # In this case we dont have to specify the column\n",
+    "test_dd['unique_id'] = test_dd['unique_id'].astype(str)\n",
+    "fcst_np = DistributedMLForecast(\n",
+    "    models=models,\n",
+    "    freq='D',\n",
+    "    target_transforms=[Differences([7])], \n",
+    "    lags=[7],\n",
+    "    lag_transforms={\n",
+    "        1: [ExpandingMean()],\n",
+    "        7: [RollingMean(window_size=14)]\n",
+    "    },\n",
+    "    date_features=['dayofweek', 'month'],\n",
+    "    num_threads=1,\n",
+    "    engine=client,\n",
+    "    num_partitions=num_partitions_test\n",
+    ")\n",
+    "fcst_np.fit(test_dd)\n",
+    "test_partition_results_size(fcst_np, num_partitions_test)\n",
+    "preds_np = fcst_np.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+    "preds = fcst.predict(7, X_df=future).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+    "pd.testing.assert_frame_equal(\n",
+    "    preds[['unique_id', 'ds']], \n",
+    "    preds_np[['unique_id', 'ds']], \n",
+    ")"
    ]
   },
   {
@@ -448,48 +447,48 @@
     "      <th>0</th>\n",
     "      <td>id_00</td>\n",
     "      <td>2002-09-27 00:00:00</td>\n",
-    "      <td>22.267619</td>\n",
-    "      <td>21.835798</td>\n",
+    "      <td>21.722841</td>\n",
+    "      <td>21.725511</td>\n",
     "    </tr>\n",
     "    <tr>\n",
     "      <th>1</th>\n",
     "      <td>id_00</td>\n",
     "      <td>2002-09-28 00:00:00</td>\n",
-    "      <td>85.230055</td>\n",
-    "      <td>83.996424</td>\n",
+    "      <td>84.918194</td>\n",
+    "      <td>84.606362</td>\n",
     "    </tr>\n",
     "    <tr>\n",
     "      <th>2</th>\n",
     "      <td>id_00</td>\n",
     "      <td>2002-09-29 00:00:00</td>\n",
-    "      <td>168.256154</td>\n",
-    "      <td>163.076652</td>\n",
+    "      <td>162.067624</td>\n",
+    "      <td>163.36802</td>\n",
     "    </tr>\n",
     "    <tr>\n",
     "      <th>3</th>\n",
     "      <td>id_00</td>\n",
     "      <td>2002-09-30 00:00:00</td>\n",
-    "      <td>246.712244</td>\n",
-    "      <td>245.827467</td>\n",
+    "      <td>249.001477</td>\n",
+    "      <td>246.422894</td>\n",
     "    </tr>\n",
     "    <tr>\n",
     "      <th>4</th>\n",
     "      <td>id_00</td>\n",
     "      <td>2002-10-01 00:00:00</td>\n",
-    "      <td>314.184225</td>\n",
-    "      <td>315.257849</td>\n",
+    "      <td>317.149512</td>\n",
+    "      <td>315.538403</td>\n",
     "    </tr>\n",
     "  </tbody>\n",
     "</table>\n",
     "</div>"
    ],
    "text/plain": [
     "  unique_id                  ds  DaskXGBForecast  DaskLGBMForecast\n",
-    "0     id_00 2002-09-27 00:00:00        22.267619         21.835798\n",
-    "1     id_00 2002-09-28 00:00:00        85.230055         83.996424\n",
-    "2     id_00 2002-09-29 00:00:00       168.256154        163.076652\n",
-    "3     id_00 2002-09-30 00:00:00       246.712244        245.827467\n",
-    "4     id_00 2002-10-01 00:00:00       314.184225        315.257849"
+    "0     id_00 2002-09-27 00:00:00        21.722841         21.725511\n",
+    "1     id_00 2002-09-28 00:00:00        84.918194         84.606362\n",
+    "2     id_00 2002-09-29 00:00:00       162.067624         163.36802\n",
+    "3     id_00 2002-09-30 00:00:00       249.001477        246.422894\n",
+    "4     id_00 2002-10-01 00:00:00       317.149512        315.538403"
    ]
   },
   "execution_count": null,
@@ -502,6 +501,20 @@
     "preds.head()"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0150de6d-88b5-4513-bd82-c835ba945e79",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| hide\n",
+    "# predict with ids\n",
+    "ids = np.random.choice(series['unique_id'].unique(), size=10, replace=False)\n",
+    "preds_ids = fcst.predict(7, X_df=future[future['unique_id'].isin(ids)], ids=ids).compute()\n",
+    "assert set(preds_ids['unique_id']) == set(ids)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
