|
272 | 272 | "text/markdown": [
|
273 | 273 | "---\n",
|
274 | 274 | "\n",
|
275 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L113){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 275 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
276 | 276 | "\n",
|
277 | 277 | "### AutoModel\n",
|
278 | 278 | "\n",
|
|
289 | 289 | "text/plain": [
|
290 | 290 | "---\n",
|
291 | 291 | "\n",
|
292 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L113){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 292 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
293 | 293 | "\n",
|
294 | 294 | "### AutoModel\n",
|
295 | 295 | "\n",
|
|
522 | 522 | " n_windows: int,\n",
|
523 | 523 | " h: int,\n",
|
524 | 524 | " num_samples: int,\n",
|
| 525 | + " step_size: Optional[int] = None,\n", |
525 | 526 | " refit: Union[bool, int] = False,\n",
|
526 | 527 | " loss: Optional[Callable[[DataFrame, DataFrame], float]] = None,\n",
|
527 | 528 | " id_col: str = 'unique_id',\n",
|
|
545 | 546 | " Forecast horizon.\n",
|
546 | 547 | " num_samples : int\n",
|
547 | 548 | " Number of trials to run\n",
|
| 549 | + " step_size : int, optional (default=None)\n", |
| 550 | + " Step size between each cross validation window. If None it will be equal to `h`.\n", |
548 | 551 | " refit : bool or int (default=False)\n",
|
549 | 552 | " Retrain model for each cross validation window.\n",
|
550 | 553 | " If False, the models are trained at the beginning and then used to predict each window.\n",
|
|
616 | 619 | " freq=self.freq,\n",
|
617 | 620 | " n_windows=n_windows,\n",
|
618 | 621 | " h=h,\n",
|
| 622 | + " step_size=step_size,\n", |
619 | 623 | " refit=refit,\n",
|
620 | 624 | " id_col=id_col,\n",
|
621 | 625 | " time_col=time_col,\n",
|
|
726 | 730 | "text/markdown": [
|
727 | 731 | "---\n",
|
728 | 732 | "\n",
|
729 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L240){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 733 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L241){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
730 | 734 | "\n",
|
731 | 735 | "### AutoMLForecast\n",
|
732 | 736 | "\n",
|
|
752 | 756 | "text/plain": [
|
753 | 757 | "---\n",
|
754 | 758 | "\n",
|
755 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L240){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 759 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L241){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
756 | 760 | "\n",
|
757 | 761 | "### AutoMLForecast\n",
|
758 | 762 | "\n",
|
|
796 | 800 | "text/markdown": [
|
797 | 801 | "---\n",
|
798 | 802 | "\n",
|
799 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L432){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 803 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L441){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
800 | 804 | "\n",
|
801 | 805 | "### AutoMLForecast.fit\n",
|
802 | 806 | "\n",
|
803 | 807 | "> AutoMLForecast.fit\n",
|
804 | 808 | "> (df:Union[pandas.core.frame.DataFrame,polars.datafram\n",
|
805 | 809 | "> e.frame.DataFrame], n_windows:int, h:int,\n",
|
806 |
| - "> num_samples:int, refit:Union[bool,int]=False, loss:Op\n", |
807 |
| - "> tional[Callable[[Union[pandas.core.frame.DataFrame,po\n", |
808 |
| - "> lars.dataframe.frame.DataFrame],Union[pandas.core.fra\n", |
809 |
| - "> me.DataFrame,polars.dataframe.frame.DataFrame]],float\n", |
810 |
| - "> ]]=None, id_col:str='unique_id', time_col:str='ds',\n", |
| 810 | + "> num_samples:int, step_size:Optional[int]=None,\n", |
| 811 | + "> refit:Union[bool,int]=False, loss:Optional[Callable[[\n", |
| 812 | + "> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n", |
| 813 | + "> ame.DataFrame],Union[pandas.core.frame.DataFrame,pola\n", |
| 814 | + "> rs.dataframe.frame.DataFrame]],float]]=None,\n", |
| 815 | + "> id_col:str='unique_id', time_col:str='ds',\n", |
811 | 816 | "> target_col:str='y',\n",
|
812 | 817 | "> study_kwargs:Optional[Dict[str,Any]]=None,\n",
|
813 | 818 | "> optimize_kwargs:Optional[Dict[str,Any]]=None,\n",
|
|
823 | 828 | "| n_windows | int | | Number of windows to evaluate. |\n",
|
824 | 829 | "| h | int | | Forecast horizon. |\n",
|
825 | 830 | "| num_samples | int | | Number of trials to run |\n",
|
| 831 | + "| step_size | Optional | None | Step size between each cross validation window. If None it will be equal to `h`. |\n", |
826 | 832 | "| refit | Union | False | Retrain model for each cross validation window.<br>If False, the models are trained at the beginning and then used to predict each window.<br>If positive int, the models are retrained every `refit` windows. |\n",
|
827 | 833 | "| loss | Optional | None | Function that takes the validation and train dataframes and produces a float.<br>If `None` will use the average SMAPE across series. |\n",
|
828 | 834 | "| id_col | str | unique_id | Column that identifies each serie. |\n",
|
|
837 | 843 | "text/plain": [
|
838 | 844 | "---\n",
|
839 | 845 | "\n",
|
840 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L432){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 846 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L441){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
841 | 847 | "\n",
|
842 | 848 | "### AutoMLForecast.fit\n",
|
843 | 849 | "\n",
|
844 | 850 | "> AutoMLForecast.fit\n",
|
845 | 851 | "> (df:Union[pandas.core.frame.DataFrame,polars.datafram\n",
|
846 | 852 | "> e.frame.DataFrame], n_windows:int, h:int,\n",
|
847 |
| - "> num_samples:int, refit:Union[bool,int]=False, loss:Op\n", |
848 |
| - "> tional[Callable[[Union[pandas.core.frame.DataFrame,po\n", |
849 |
| - "> lars.dataframe.frame.DataFrame],Union[pandas.core.fra\n", |
850 |
| - "> me.DataFrame,polars.dataframe.frame.DataFrame]],float\n", |
851 |
| - "> ]]=None, id_col:str='unique_id', time_col:str='ds',\n", |
| 853 | + "> num_samples:int, step_size:Optional[int]=None,\n", |
| 854 | + "> refit:Union[bool,int]=False, loss:Optional[Callable[[\n", |
| 855 | + "> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n", |
| 856 | + "> ame.DataFrame],Union[pandas.core.frame.DataFrame,pola\n", |
| 857 | + "> rs.dataframe.frame.DataFrame]],float]]=None,\n", |
| 858 | + "> id_col:str='unique_id', time_col:str='ds',\n", |
852 | 859 | "> target_col:str='y',\n",
|
853 | 860 | "> study_kwargs:Optional[Dict[str,Any]]=None,\n",
|
854 | 861 | "> optimize_kwargs:Optional[Dict[str,Any]]=None,\n",
|
|
864 | 871 | "| n_windows | int | | Number of windows to evaluate. |\n",
|
865 | 872 | "| h | int | | Forecast horizon. |\n",
|
866 | 873 | "| num_samples | int | | Number of trials to run |\n",
|
| 874 | + "| step_size | Optional | None | Step size between each cross validation window. If None it will be equal to `h`. |\n", |
867 | 875 | "| refit | Union | False | Retrain model for each cross validation window.<br>If False, the models are trained at the beginning and then used to predict each window.<br>If positive int, the models are retrained every `refit` windows. |\n",
|
868 | 876 | "| loss | Optional | None | Function that takes the validation and train dataframes and produces a float.<br>If `None` will use the average SMAPE across series. |\n",
|
869 | 877 | "| id_col | str | unique_id | Column that identifies each serie. |\n",
|
|
896 | 904 | "text/markdown": [
|
897 | 905 | "---\n",
|
898 | 906 | "\n",
|
899 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L561){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 907 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
900 | 908 | "\n",
|
901 | 909 | "### AutoMLForecast.predict\n",
|
902 | 910 | "\n",
|
|
916 | 924 | "text/plain": [
|
917 | 925 | "---\n",
|
918 | 926 | "\n",
|
919 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L561){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 927 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
920 | 928 | "\n",
|
921 | 929 | "### AutoMLForecast.predict\n",
|
922 | 930 | "\n",
|
|
954 | 962 | "text/markdown": [
|
955 | 963 | "---\n",
|
956 | 964 | "\n",
|
957 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L593){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 965 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
958 | 966 | "\n",
|
959 | 967 | "### AutoMLForecast.save\n",
|
960 | 968 | "\n",
|
|
970 | 978 | "text/plain": [
|
971 | 979 | "---\n",
|
972 | 980 | "\n",
|
973 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L593){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 981 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
974 | 982 | "\n",
|
975 | 983 | "### AutoMLForecast.save\n",
|
976 | 984 | "\n",
|
|
1004 | 1012 | "text/markdown": [
|
1005 | 1013 | "---\n",
|
1006 | 1014 | "\n",
|
1007 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L603){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 1015 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
1008 | 1016 | "\n",
|
1009 | 1017 | "### AutoMLForecast.forecast_fitted_values\n",
|
1010 | 1018 | "\n",
|
|
1022 | 1030 | "text/plain": [
|
1023 | 1031 | "---\n",
|
1024 | 1032 | "\n",
|
1025 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L603){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 1033 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
1026 | 1034 | "\n",
|
1027 | 1035 | "### AutoMLForecast.forecast_fitted_values\n",
|
1028 | 1036 | "\n",
|
|
1702 | 1710 | "#| polars\n",
|
1703 | 1711 | "auto_mlf.forecast_fitted_values(level=[95])"
|
1704 | 1712 | ]
|
| 1713 | + }, |
| 1714 | + { |
| 1715 | + "cell_type": "code", |
| 1716 | + "execution_count": null, |
| 1717 | + "id": "0dfe2c18-d3df-41f2-a3b8-1cf73ef9765c", |
| 1718 | + "metadata": {}, |
| 1719 | + "outputs": [], |
| 1720 | + "source": [ |
| 1721 | + "#| hide\n", |
| 1722 | + "#| polars\n", |
| 1723 | + "auto_mlf2 = AutoMLForecast(\n", |
| 1724 | + " freq=1,\n", |
| 1725 | + " season_length=season_length,\n", |
| 1726 | + " models={'ridge': AutoRidge()},\n", |
| 1727 | + " num_threads=2,\n", |
| 1728 | + ")\n", |
| 1729 | + "auto_mlf2.fit(\n", |
| 1730 | + " df=train_pl,\n", |
| 1731 | + " n_windows=2,\n", |
| 1732 | + " h=h,\n", |
| 1733 | + " step_size=1,\n", |
| 1734 | + " num_samples=2,\n", |
| 1735 | + " optimize_kwargs={'timeout': 60},\n", |
| 1736 | + " fitted=True,\n", |
| 1737 | + " prediction_intervals=PredictionIntervals(n_windows=2, h=h),\n", |
| 1738 | + ")\n", |
| 1739 | + "metric_step_h = auto_mlf.results_['ridge'].best_trial.value\n", |
| 1740 | + "metric_step_1 = auto_mlf2.results_['ridge'].best_trial.value\n", |
| 1741 | + "assert abs(metric_step_h / metric_step_1 - 1) > 0.02" |
| 1742 | + ] |
1705 | 1743 | }
|
1706 | 1744 | ],
|
1707 | 1745 | "metadata": {
|
|
0 commit comments