Skip to content

Commit e09106c

Browse files
authored
fix: X_df handling in direct approach (#468)
1 parent 1fba842 commit e09106c

File tree

3 files changed

+52
-10
lines changed

3 files changed

+52
-10
lines changed

mlforecast/core.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -703,6 +703,7 @@ def _predict_multi(
703703
result = df_constructor({self.id_col: uids, self.time_col: dates})
704704
for name, model in models.items():
705705
with self._backup():
706+
self._predict_setup()
706707
new_x = self._get_features_for_next_step(X_df)
707708
if before_predict_callback is not None:
708709
new_x = before_predict_callback(new_x)
@@ -789,10 +790,16 @@ def predict(
789790
raise ValueError(
790791
f"The following features were provided through `X_df` but were considered as static during fit: {common}.\n"
791792
"Please re-run the fit step using the `static_features` argument to indicate which features are static. "
792-
"If all your features are dynamic please pass an empty list (static_features=[])."
793+
"If all your features are dynamic please provide an empty list (static_features=[])."
793794
)
794795
starts = ufp.offset_times(self.last_dates, self.freq, 1)
795-
ends = ufp.offset_times(self.last_dates, self.freq, horizon)
796+
if getattr(self, "max_horizon", None) is None:
797+
ends = ufp.offset_times(self.last_dates, self.freq, horizon)
798+
expected_rows_X = len(self.uids) * horizon
799+
else:
800+
# direct approach uses only the immediate next timestamp
801+
ends = starts
802+
expected_rows_X = len(self.uids)
796803
dates_validation = type(X_df)(
797804
{
798805
self.id_col: self.uids,
@@ -803,7 +810,7 @@ def predict(
803810
X_df = ufp.join(X_df, dates_validation, on=self.id_col)
804811
mask = ufp.between(X_df[self.time_col], X_df["_start"], X_df["_end"])
805812
X_df = ufp.filter_with_mask(X_df, mask)
806-
if X_df.shape[0] != len(self.uids) * horizon:
813+
if X_df.shape[0] != expected_rows_X:
807814
msg = (
808815
"Found missing inputs in X_df. "
809816
"It should have one row per id and time for the complete forecasting horizon.\n"

nbs/core.ipynb

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1184,6 +1184,7 @@
11841184
" result = df_constructor({self.id_col: uids, self.time_col: dates})\n",
11851185
" for name, model in models.items():\n",
11861186
" with self._backup():\n",
1187+
" self._predict_setup()\n",
11871188
" new_x = self._get_features_for_next_step(X_df)\n",
11881189
" if before_predict_callback is not None:\n",
11891190
" new_x = before_predict_callback(new_x)\n",
@@ -1261,10 +1262,16 @@
12611262
" raise ValueError(\n",
12621263
" f\"The following features were provided through `X_df` but were considered as static during fit: {common}.\\n\"\n",
12631264
" \"Please re-run the fit step using the `static_features` argument to indicate which features are static. \"\n",
1264-
" \"If all your features are dynamic please pass an empty list (static_features=[]).\"\n",
1265+
" \"If all your features are dynamic please provide an empty list (static_features=[]).\"\n",
12651266
" )\n",
12661267
" starts = ufp.offset_times(self.last_dates, self.freq, 1)\n",
1267-
" ends = ufp.offset_times(self.last_dates, self.freq, horizon)\n",
1268+
" if getattr(self, 'max_horizon', None) is None:\n",
1269+
" ends = ufp.offset_times(self.last_dates, self.freq, horizon)\n",
1270+
" expected_rows_X = len(self.uids) * horizon\n",
1271+
" else:\n",
1272+
" # direct approach uses only the immediate next timestamp\n",
1273+
" ends = starts\n",
1274+
" expected_rows_X = len(self.uids)\n",
12681275
" dates_validation = type(X_df)(\n",
12691276
" {\n",
12701277
" self.id_col: self.uids,\n",
@@ -1275,7 +1282,7 @@
12751282
" X_df = ufp.join(X_df, dates_validation, on=self.id_col)\n",
12761283
" mask = ufp.between(X_df[self.time_col], X_df['_start'], X_df['_end'])\n",
12771284
" X_df = ufp.filter_with_mask(X_df, mask)\n",
1278-
" if X_df.shape[0] != len(self.uids) * horizon:\n",
1285+
" if X_df.shape[0] != expected_rows_X:\n",
12791286
" msg = (\n",
12801287
" \"Found missing inputs in X_df. \"\n",
12811288
" \"It should have one row per id and time for the complete forecasting horizon.\\n\"\n",
@@ -2015,7 +2022,7 @@
20152022
"text/markdown": [
20162023
"---\n",
20172024
"\n",
2018-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L757){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2025+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
20192026
"\n",
20202027
"## TimeSeries.predict\n",
20212028
"\n",
@@ -2029,7 +2036,7 @@
20292036
"text/plain": [
20302037
"---\n",
20312038
"\n",
2032-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L757){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2039+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
20332040
"\n",
20342041
"## TimeSeries.predict\n",
20352042
"\n",
@@ -2167,7 +2174,7 @@
21672174
"text/markdown": [
21682175
"---\n",
21692176
"\n",
2170-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L862){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2177+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L869){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
21712178
"\n",
21722179
"## TimeSeries.update\n",
21732180
"\n",
@@ -2180,7 +2187,7 @@
21802187
"text/plain": [
21812188
"---\n",
21822189
"\n",
2183-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L862){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2190+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L869){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
21842191
"\n",
21852192
"## TimeSeries.update\n",
21862193
"\n",

nbs/forecast.ipynb

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1149,6 +1149,7 @@
11491149
"import numpy as np\n",
11501150
"import xgboost as xgb\n",
11511151
"from sklearn.linear_model import LinearRegression\n",
1152+
"from utilsforecast.feature_engineering import time_features\n",
11521153
"from utilsforecast.plotting import plot_series\n",
11531154
"\n",
11541155
"from mlforecast.lag_transforms import ExpandingMean, ExponentiallyWeightedMean, RollingMean\n",
@@ -5439,6 +5440,33 @@
54395440
"preds2 = fcst2.predict(10)\n",
54405441
"pd.testing.assert_frame_equal(preds, preds2)"
54415442
]
5443+
},
5444+
{
5445+
"cell_type": "code",
5446+
"execution_count": null,
5447+
"id": "4a39447b-3d6b-4960-83c9-4f927c472ebb",
5448+
"metadata": {},
5449+
"outputs": [],
5450+
"source": [
5451+
"#| hide\n",
5452+
"# direct approach requires only one timestamp and produces same results for two models\n",
5453+
"series = generate_daily_series(5)\n",
5454+
"h = 5\n",
5455+
"freq = 'D'\n",
5456+
"train, future = time_features(series, freq=freq, features=['day'], h=h)\n",
5457+
"models = [LinearRegression(), lgb.LGBMRegressor(n_estimators=5)]\n",
5458+
"\n",
5459+
"fcst1 = MLForecast(models=models, freq=freq, date_features=['dayofweek'])\n",
5460+
"fcst1.fit(train, max_horizon=h, static_features=[])\n",
5461+
"preds1 = fcst1.predict(h=h, X_df=future) # extra timestamps\n",
5462+
"\n",
5463+
"fcst2 = MLForecast(models=models[::-1], freq=freq, date_features=['dayofweek'])\n",
5464+
"fcst2.fit(train, max_horizon=h, static_features=[])\n",
5465+
"X_df_one = future.groupby('unique_id', observed=True).head(1)\n",
5466+
"preds2 = fcst2.predict(h=h, X_df=X_df_one) # only needed timestamp\n",
5467+
"\n",
5468+
"pd.testing.assert_frame_equal(preds1, preds2[preds1.columns])"
5469+
]
54425470
}
54435471
],
54445472
"metadata": {

0 commit comments

Comments
 (0)