Skip to content

Commit 95b91c4

Browse files
authored
enh: add step_size to AutoMLForecast (#426)
1 parent 282d91f commit 95b91c4

File tree

2 files changed

+64
-22
lines changed

2 files changed

+64
-22
lines changed

mlforecast/auto.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ def fit(
444444
n_windows: int,
445445
h: int,
446446
num_samples: int,
447+
step_size: Optional[int] = None,
447448
refit: Union[bool, int] = False,
448449
loss: Optional[Callable[[DataFrame, DataFrame], float]] = None,
449450
id_col: str = "unique_id",
@@ -467,6 +468,8 @@ def fit(
467468
Forecast horizon.
468469
num_samples : int
469470
Number of trials to run
471+
step_size : int, optional (default=None)
472+
Step size between each cross validation window. If None it will be equal to `h`.
470473
refit : bool or int (default=False)
471474
Retrain model for each cross validation window.
472475
If False, the models are trained at the beginning and then used to predict each window.
@@ -541,6 +544,7 @@ def config_fn(trial: optuna.Trial) -> Dict[str, Any]:
541544
freq=self.freq,
542545
n_windows=n_windows,
543546
h=h,
547+
step_size=step_size,
544548
refit=refit,
545549
id_col=id_col,
546550
time_col=time_col,

nbs/auto.ipynb

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@
272272
"text/markdown": [
273273
"---\n",
274274
"\n",
275-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L113){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
275+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
276276
"\n",
277277
"### AutoModel\n",
278278
"\n",
@@ -289,7 +289,7 @@
289289
"text/plain": [
290290
"---\n",
291291
"\n",
292-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L113){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
292+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
293293
"\n",
294294
"### AutoModel\n",
295295
"\n",
@@ -522,6 +522,7 @@
522522
" n_windows: int,\n",
523523
" h: int,\n",
524524
" num_samples: int,\n",
525+
" step_size: Optional[int] = None,\n",
525526
" refit: Union[bool, int] = False,\n",
526527
" loss: Optional[Callable[[DataFrame, DataFrame], float]] = None,\n",
527528
" id_col: str = 'unique_id',\n",
@@ -545,6 +546,8 @@
545546
" Forecast horizon.\n",
546547
" num_samples : int\n",
547548
" Number of trials to run\n",
549+
" step_size : int, optional (default=None)\n",
550+
" Step size between each cross validation window. If None it will be equal to `h`.\n",
548551
" refit : bool or int (default=False)\n",
549552
" Retrain model for each cross validation window.\n",
550553
" If False, the models are trained at the beginning and then used to predict each window.\n",
@@ -616,6 +619,7 @@
616619
" freq=self.freq,\n",
617620
" n_windows=n_windows,\n",
618621
" h=h,\n",
622+
" step_size=step_size,\n",
619623
" refit=refit,\n",
620624
" id_col=id_col,\n",
621625
" time_col=time_col,\n",
@@ -726,7 +730,7 @@
726730
"text/markdown": [
727731
"---\n",
728732
"\n",
729-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L240){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
733+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L241){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
730734
"\n",
731735
"### AutoMLForecast\n",
732736
"\n",
@@ -752,7 +756,7 @@
752756
"text/plain": [
753757
"---\n",
754758
"\n",
755-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L240){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
759+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L241){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
756760
"\n",
757761
"### AutoMLForecast\n",
758762
"\n",
@@ -796,18 +800,19 @@
796800
"text/markdown": [
797801
"---\n",
798802
"\n",
799-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L432){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
803+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L441){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
800804
"\n",
801805
"### AutoMLForecast.fit\n",
802806
"\n",
803807
"> AutoMLForecast.fit\n",
804808
"> (df:Union[pandas.core.frame.DataFrame,polars.datafram\n",
805809
"> e.frame.DataFrame], n_windows:int, h:int,\n",
806-
"> num_samples:int, refit:Union[bool,int]=False, loss:Op\n",
807-
"> tional[Callable[[Union[pandas.core.frame.DataFrame,po\n",
808-
"> lars.dataframe.frame.DataFrame],Union[pandas.core.fra\n",
809-
"> me.DataFrame,polars.dataframe.frame.DataFrame]],float\n",
810-
"> ]]=None, id_col:str='unique_id', time_col:str='ds',\n",
810+
"> num_samples:int, step_size:Optional[int]=None,\n",
811+
"> refit:Union[bool,int]=False, loss:Optional[Callable[[\n",
812+
"> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n",
813+
"> ame.DataFrame],Union[pandas.core.frame.DataFrame,pola\n",
814+
"> rs.dataframe.frame.DataFrame]],float]]=None,\n",
815+
"> id_col:str='unique_id', time_col:str='ds',\n",
811816
"> target_col:str='y',\n",
812817
"> study_kwargs:Optional[Dict[str,Any]]=None,\n",
813818
"> optimize_kwargs:Optional[Dict[str,Any]]=None,\n",
@@ -823,6 +828,7 @@
823828
"| n_windows | int | | Number of windows to evaluate. |\n",
824829
"| h | int | | Forecast horizon. |\n",
825830
"| num_samples | int | | Number of trials to run |\n",
831+
"| step_size | Optional | None | Step size between each cross validation window. If None it will be equal to `h`. |\n",
826832
"| refit | Union | False | Retrain model for each cross validation window.<br>If False, the models are trained at the beginning and then used to predict each window.<br>If positive int, the models are retrained every `refit` windows. |\n",
827833
"| loss | Optional | None | Function that takes the validation and train dataframes and produces a float.<br>If `None` will use the average SMAPE across series. |\n",
828834
"| id_col | str | unique_id | Column that identifies each serie. |\n",
@@ -837,18 +843,19 @@
837843
"text/plain": [
838844
"---\n",
839845
"\n",
840-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L432){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
846+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L441){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
841847
"\n",
842848
"### AutoMLForecast.fit\n",
843849
"\n",
844850
"> AutoMLForecast.fit\n",
845851
"> (df:Union[pandas.core.frame.DataFrame,polars.datafram\n",
846852
"> e.frame.DataFrame], n_windows:int, h:int,\n",
847-
"> num_samples:int, refit:Union[bool,int]=False, loss:Op\n",
848-
"> tional[Callable[[Union[pandas.core.frame.DataFrame,po\n",
849-
"> lars.dataframe.frame.DataFrame],Union[pandas.core.fra\n",
850-
"> me.DataFrame,polars.dataframe.frame.DataFrame]],float\n",
851-
"> ]]=None, id_col:str='unique_id', time_col:str='ds',\n",
853+
"> num_samples:int, step_size:Optional[int]=None,\n",
854+
"> refit:Union[bool,int]=False, loss:Optional[Callable[[\n",
855+
"> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n",
856+
"> ame.DataFrame],Union[pandas.core.frame.DataFrame,pola\n",
857+
"> rs.dataframe.frame.DataFrame]],float]]=None,\n",
858+
"> id_col:str='unique_id', time_col:str='ds',\n",
852859
"> target_col:str='y',\n",
853860
"> study_kwargs:Optional[Dict[str,Any]]=None,\n",
854861
"> optimize_kwargs:Optional[Dict[str,Any]]=None,\n",
@@ -864,6 +871,7 @@
864871
"| n_windows | int | | Number of windows to evaluate. |\n",
865872
"| h | int | | Forecast horizon. |\n",
866873
"| num_samples | int | | Number of trials to run |\n",
874+
"| step_size | Optional | None | Step size between each cross validation window. If None it will be equal to `h`. |\n",
867875
"| refit | Union | False | Retrain model for each cross validation window.<br>If False, the models are trained at the beginning and then used to predict each window.<br>If positive int, the models are retrained every `refit` windows. |\n",
868876
"| loss | Optional | None | Function that takes the validation and train dataframes and produces a float.<br>If `None` will use the average SMAPE across series. |\n",
869877
"| id_col | str | unique_id | Column that identifies each serie. |\n",
@@ -896,7 +904,7 @@
896904
"text/markdown": [
897905
"---\n",
898906
"\n",
899-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L561){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
907+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
900908
"\n",
901909
"### AutoMLForecast.predict\n",
902910
"\n",
@@ -916,7 +924,7 @@
916924
"text/plain": [
917925
"---\n",
918926
"\n",
919-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L561){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
927+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
920928
"\n",
921929
"### AutoMLForecast.predict\n",
922930
"\n",
@@ -954,7 +962,7 @@
954962
"text/markdown": [
955963
"---\n",
956964
"\n",
957-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L593){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
965+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
958966
"\n",
959967
"### AutoMLForecast.save\n",
960968
"\n",
@@ -970,7 +978,7 @@
970978
"text/plain": [
971979
"---\n",
972980
"\n",
973-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L593){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
981+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
974982
"\n",
975983
"### AutoMLForecast.save\n",
976984
"\n",
@@ -1004,7 +1012,7 @@
10041012
"text/markdown": [
10051013
"---\n",
10061014
"\n",
1007-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L603){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
1015+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
10081016
"\n",
10091017
"### AutoMLForecast.forecast_fitted_values\n",
10101018
"\n",
@@ -1022,7 +1030,7 @@
10221030
"text/plain": [
10231031
"---\n",
10241032
"\n",
1025-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L603){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
1033+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
10261034
"\n",
10271035
"### AutoMLForecast.forecast_fitted_values\n",
10281036
"\n",
@@ -1702,6 +1710,36 @@
17021710
"#| polars\n",
17031711
"auto_mlf.forecast_fitted_values(level=[95])"
17041712
]
1713+
},
1714+
{
1715+
"cell_type": "code",
1716+
"execution_count": null,
1717+
"id": "0dfe2c18-d3df-41f2-a3b8-1cf73ef9765c",
1718+
"metadata": {},
1719+
"outputs": [],
1720+
"source": [
1721+
"#| hide\n",
1722+
"#| polars\n",
1723+
"auto_mlf2 = AutoMLForecast(\n",
1724+
" freq=1,\n",
1725+
" season_length=season_length,\n",
1726+
" models={'ridge': AutoRidge()},\n",
1727+
" num_threads=2,\n",
1728+
")\n",
1729+
"auto_mlf2.fit(\n",
1730+
" df=train_pl,\n",
1731+
" n_windows=2,\n",
1732+
" h=h,\n",
1733+
" step_size=1,\n",
1734+
" num_samples=2,\n",
1735+
" optimize_kwargs={'timeout': 60},\n",
1736+
" fitted=True,\n",
1737+
" prediction_intervals=PredictionIntervals(n_windows=2, h=h),\n",
1738+
")\n",
1739+
"metric_step_h = auto_mlf.results_['ridge'].best_trial.value\n",
1740+
"metric_step_1 = auto_mlf2.results_['ridge'].best_trial.value\n",
1741+
"assert abs(metric_step_h / metric_step_1 - 1) > 0.02"
1742+
]
17051743
}
17061744
],
17071745
"metadata": {

0 commit comments

Comments
 (0)