Skip to content

Commit 3cbaa16

Browse files
authored
fix(distributed): support pre-computed features (#436)
1 parent feba066 commit 3cbaa16

File tree

4 files changed

+127
-131
lines changed

4 files changed

+127
-131
lines changed

mlforecast/core.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,9 @@ def _fit(
322322
"are dynamic please set `static_features=[]`."
323323
)
324324
self.static_features_ = statics_on_ends
325-
self.features_order_ = [
326-
c for c in df.columns if c not in to_drop
327-
] + self.features
325+
self.features_order_ = [c for c in df.columns if c not in to_drop] + [
326+
f for f in self.features if f not in df.columns
327+
]
328328
return self
329329

330330
def _compute_transforms(
@@ -377,8 +377,12 @@ def _transform(
377377
"""Add the features to `df`.
378378
379379
if `dropna=True` then all the null rows are dropped."""
380-
transforms = {k: v for k, v in self.transforms.items() if k not in df}
381-
features = self._compute_transforms(transforms=transforms, updates_only=False)
380+
# we need to compute all transformations in case they save state
381+
features = self._compute_transforms(
382+
transforms=self.transforms, updates_only=False
383+
)
384+
# filter out the features that already exist in df to avoid overwriting them
385+
features = {k: v for k, v in features.items() if k not in df}
382386
if self._restore_idxs is not None:
383387
for k, v in features.items():
384388
features[k] = v[self._restore_idxs]
@@ -433,8 +437,9 @@ def _transform(
433437
del self._restore_idxs, self._sort_idxs
434438

435439
# lag transforms
436-
for feat in transforms.keys():
437-
df = ufp.assign_columns(df, feat, features[feat])
440+
for feat in self.transforms.keys():
441+
if feat in features:
442+
df = ufp.assign_columns(df, feat, features[feat])
438443

439444
# date features
440445
names = [f.__name__ if callable(f) else f for f in self.date_features]

mlforecast/distributed/forecast.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,12 +293,14 @@ def _preprocess(
293293
keep_last_n=keep_last_n,
294294
window_info=window_info,
295295
)
296-
base_schema = str(fa.get_schema(data))
297-
features_schema = ",".join(f"{feat}:double" for feat in self._base_ts.features)
296+
base_schema = fa.get_schema(data)
297+
features_schema = {
298+
f: "double" for f in self._base_ts.features if f not in base_schema
299+
}
298300
res = fa.transform(
299301
self._partition_results,
300302
DistributedMLForecast._retrieve_df,
301-
schema=f"{base_schema},{features_schema}",
303+
schema=base_schema + features_schema,
302304
engine=self.engine,
303305
)
304306
return fa.get_native_as_df(res)

nbs/core.ipynb

Lines changed: 22 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,9 @@
805805
" 'are dynamic please set `static_features=[]`.'\n",
806806
" )\n",
807807
" self.static_features_ = statics_on_ends\n",
808-
" self.features_order_ = [c for c in df.columns if c not in to_drop] + self.features\n",
808+
" self.features_order_ = [\n",
809+
" c for c in df.columns if c not in to_drop\n",
810+
" ] + [f for f in self.features if f not in df.columns]\n",
809811
" return self\n",
810812
"\n",
811813
" def _compute_transforms(\n",
@@ -858,8 +860,13 @@
858860
" \"\"\"Add the features to `df`.\n",
859861
" \n",
860862
" if `dropna=True` then all the null rows are dropped.\"\"\"\n",
861-
" transforms = {k: v for k, v in self.transforms.items() if k not in df}\n",
862-
" features = self._compute_transforms(transforms=transforms, updates_only=False)\n",
863+
" # we need to compute all transformations in case they save state\n",
864+
" features = self._compute_transforms(\n",
865+
" transforms=self.transforms,\n",
866+
" updates_only=False\n",
867+
" )\n",
868+
" # filter out the features that already exist in df to avoid overwriting them\n",
869+
" features = {k: v for k, v in features.items() if k not in df}\n",
863870
" if self._restore_idxs is not None:\n",
864871
" for k, v in features.items():\n",
865872
" features[k] = v[self._restore_idxs]\n",
@@ -914,8 +921,9 @@
914921
" del self._restore_idxs, self._sort_idxs\n",
915922
"\n",
916923
" # lag transforms\n",
917-
" for feat in transforms.keys():\n",
918-
" df = ufp.assign_columns(df, feat, features[feat])\n",
924+
" for feat in self.transforms.keys():\n",
925+
" if feat in features:\n",
926+
" df = ufp.assign_columns(df, feat, features[feat])\n",
919927
"\n",
920928
" # date features\n",
921929
" names = [f.__name__ if callable(f) else f for f in self.date_features]\n",
@@ -1663,13 +1671,12 @@
16631671
"text/markdown": [
16641672
"---\n",
16651673
"\n",
1666-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L481){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
1674+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L486){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
16671675
"\n",
16681676
"## TimeSeries.fit_transform\n",
16691677
"\n",
1670-
"> TimeSeries.fit_transform (data:Union[pandas.core.frame.DataFrame,polars.d\n",
1671-
"> ataframe.frame.DataFrame], id_col:str,\n",
1672-
"> time_col:str, target_col:str,\n",
1678+
"> TimeSeries.fit_transform (data:~DFType, id_col:str, time_col:str,\n",
1679+
"> target_col:str,\n",
16731680
"> static_features:Optional[List[str]]=None,\n",
16741681
"> dropna:bool=True,\n",
16751682
"> keep_last_n:Optional[int]=None,\n",
@@ -1685,13 +1692,12 @@
16851692
"text/plain": [
16861693
"---\n",
16871694
"\n",
1688-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L481){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
1695+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L486){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
16891696
"\n",
16901697
"## TimeSeries.fit_transform\n",
16911698
"\n",
1692-
"> TimeSeries.fit_transform (data:Union[pandas.core.frame.DataFrame,polars.d\n",
1693-
"> ataframe.frame.DataFrame], id_col:str,\n",
1694-
"> time_col:str, target_col:str,\n",
1699+
"> TimeSeries.fit_transform (data:~DFType, id_col:str, time_col:str,\n",
1700+
"> target_col:str,\n",
16951701
"> static_features:Optional[List[str]]=None,\n",
16961702
"> dropna:bool=True,\n",
16971703
"> keep_last_n:Optional[int]=None,\n",
@@ -1972,45 +1978,7 @@
19721978
"cell_type": "code",
19731979
"execution_count": null,
19741980
"metadata": {},
1975-
"outputs": [
1976-
{
1977-
"data": {
1978-
"text/markdown": [
1979-
"---\n",
1980-
"\n",
1981-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L726){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
1982-
"\n",
1983-
"## TimeSeries.predict\n",
1984-
"\n",
1985-
"> TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,List\n",
1986-
"> [sklearn.base.BaseEstimator]]], horizon:int,\n",
1987-
"> before_predict_callback:Optional[Callable]=None,\n",
1988-
"> after_predict_callback:Optional[Callable]=None, X_df:\n",
1989-
"> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n",
1990-
"> ame.DataFrame,NoneType]=None,\n",
1991-
"> ids:Optional[List[str]]=None)"
1992-
],
1993-
"text/plain": [
1994-
"---\n",
1995-
"\n",
1996-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L726){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
1997-
"\n",
1998-
"## TimeSeries.predict\n",
1999-
"\n",
2000-
"> TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,List\n",
2001-
"> [sklearn.base.BaseEstimator]]], horizon:int,\n",
2002-
"> before_predict_callback:Optional[Callable]=None,\n",
2003-
"> after_predict_callback:Optional[Callable]=None, X_df:\n",
2004-
"> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n",
2005-
"> ame.DataFrame,NoneType]=None,\n",
2006-
"> ids:Optional[List[str]]=None)"
2007-
]
2008-
},
2009-
"execution_count": null,
2010-
"metadata": {},
2011-
"output_type": "execute_result"
2012-
}
2013-
],
1981+
"outputs": [],
20141982
"source": [
20151983
"show_doc(TimeSeries.predict, title_level=2)"
20161984
]
@@ -2126,41 +2094,7 @@
21262094
"cell_type": "code",
21272095
"execution_count": null,
21282096
"metadata": {},
2129-
"outputs": [
2130-
{
2131-
"data": {
2132-
"text/markdown": [
2133-
"---\n",
2134-
"\n",
2135-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L831){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2136-
"\n",
2137-
"## TimeSeries.update\n",
2138-
"\n",
2139-
"> TimeSeries.update\n",
2140-
"> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n",
2141-
"> .frame.DataFrame])\n",
2142-
"\n",
2143-
"*Update the values of the stored series.*"
2144-
],
2145-
"text/plain": [
2146-
"---\n",
2147-
"\n",
2148-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L831){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2149-
"\n",
2150-
"## TimeSeries.update\n",
2151-
"\n",
2152-
"> TimeSeries.update\n",
2153-
"> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n",
2154-
"> .frame.DataFrame])\n",
2155-
"\n",
2156-
"*Update the values of the stored series.*"
2157-
]
2158-
},
2159-
"execution_count": null,
2160-
"metadata": {},
2161-
"output_type": "execute_result"
2162-
}
2163-
],
2097+
"outputs": [],
21642098
"source": [
21652099
"show_doc(TimeSeries.update, title_level=2)"
21662100
]
@@ -2246,15 +2180,7 @@
22462180
"cell_type": "code",
22472181
"execution_count": null,
22482182
"metadata": {},
2249-
"outputs": [
2250-
{
2251-
"name": "stderr",
2252-
"output_type": "stream",
2253-
"text": [
2254-
"sys:1: CategoricalRemappingWarning: Local categoricals have different encodings, expensive re-encoding is done to perform this merge operation. Consider using a StringCache or an Enum type if the categories are known in advance\n"
2255-
]
2256-
}
2257-
],
2183+
"outputs": [],
22582184
"source": [
22592185
"#| hide\n",
22602186
"#| polars\n",

0 commit comments

Comments
 (0)