|
1184 | 1184 | " result = df_constructor({self.id_col: uids, self.time_col: dates})\n",
|
1185 | 1185 | " for name, model in models.items():\n",
|
1186 | 1186 | " with self._backup():\n",
|
| 1187 | + " self._predict_setup()\n", |
1187 | 1188 | " new_x = self._get_features_for_next_step(X_df)\n",
|
1188 | 1189 | " if before_predict_callback is not None:\n",
|
1189 | 1190 | " new_x = before_predict_callback(new_x)\n",
|
|
1261 | 1262 | " raise ValueError(\n",
|
1262 | 1263 | " f\"The following features were provided through `X_df` but were considered as static during fit: {common}.\\n\"\n",
|
1263 | 1264 | " \"Please re-run the fit step using the `static_features` argument to indicate which features are static. \"\n",
|
1264 |
| - " \"If all your features are dynamic please pass an empty list (static_features=[]).\"\n", |
| 1265 | + " \"If all your features are dynamic please provide an empty list (static_features=[]).\"\n", |
1265 | 1266 | " )\n",
|
1266 | 1267 | " starts = ufp.offset_times(self.last_dates, self.freq, 1)\n",
|
1267 |
| - " ends = ufp.offset_times(self.last_dates, self.freq, horizon)\n", |
| 1268 | + " if getattr(self, 'max_horizon', None) is None:\n", |
| 1269 | + " ends = ufp.offset_times(self.last_dates, self.freq, horizon)\n", |
| 1270 | + " expected_rows_X = len(self.uids) * horizon\n", |
| 1271 | + " else:\n", |
| 1272 | + " # direct approach uses only the immediate next timestamp\n", |
| 1273 | + " ends = starts\n", |
| 1274 | + " expected_rows_X = len(self.uids)\n", |
1268 | 1275 | " dates_validation = type(X_df)(\n",
|
1269 | 1276 | " {\n",
|
1270 | 1277 | " self.id_col: self.uids,\n",
|
|
1275 | 1282 | " X_df = ufp.join(X_df, dates_validation, on=self.id_col)\n",
|
1276 | 1283 | " mask = ufp.between(X_df[self.time_col], X_df['_start'], X_df['_end'])\n",
|
1277 | 1284 | " X_df = ufp.filter_with_mask(X_df, mask)\n",
|
1278 |
| - " if X_df.shape[0] != len(self.uids) * horizon:\n", |
| 1285 | + " if X_df.shape[0] != expected_rows_X:\n", |
1279 | 1286 | " msg = (\n",
|
1280 | 1287 | " \"Found missing inputs in X_df. \"\n",
|
1281 | 1288 | " \"It should have one row per id and time for the complete forecasting horizon.\\n\"\n",
|
|
2015 | 2022 | "text/markdown": [
|
2016 | 2023 | "---\n",
|
2017 | 2024 | "\n",
|
2018 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L757){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 2025 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
2019 | 2026 | "\n",
|
2020 | 2027 | "## TimeSeries.predict\n",
|
2021 | 2028 | "\n",
|
|
2029 | 2036 | "text/plain": [
|
2030 | 2037 | "---\n",
|
2031 | 2038 | "\n",
|
2032 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L757){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 2039 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
2033 | 2040 | "\n",
|
2034 | 2041 | "## TimeSeries.predict\n",
|
2035 | 2042 | "\n",
|
|
2167 | 2174 | "text/markdown": [
|
2168 | 2175 | "---\n",
|
2169 | 2176 | "\n",
|
2170 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L862){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 2177 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L869){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
2171 | 2178 | "\n",
|
2172 | 2179 | "## TimeSeries.update\n",
|
2173 | 2180 | "\n",
|
|
2180 | 2187 | "text/plain": [
|
2181 | 2188 | "---\n",
|
2182 | 2189 | "\n",
|
2183 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L862){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 2190 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L869){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
2184 | 2191 | "\n",
|
2185 | 2192 | "## TimeSeries.update\n",
|
2186 | 2193 | "\n",
|
|
0 commit comments