|
805 | 805 | " 'are dynamic please set `static_features=[]`.'\n",
|
806 | 806 | " )\n",
|
807 | 807 | " self.static_features_ = statics_on_ends\n",
|
808 |
| - " self.features_order_ = [c for c in df.columns if c not in to_drop] + self.features\n", |
| 808 | + " self.features_order_ = [\n", |
| 809 | + " c for c in df.columns if c not in to_drop\n", |
| 810 | + " ] + [f for f in self.features if f not in df.columns]\n", |
809 | 811 | " return self\n",
|
810 | 812 | "\n",
|
811 | 813 | " def _compute_transforms(\n",
|
|
858 | 860 | " \"\"\"Add the features to `df`.\n",
|
859 | 861 | " \n",
|
860 | 862 | " if `dropna=True` then all the null rows are dropped.\"\"\"\n",
|
861 |
| - " transforms = {k: v for k, v in self.transforms.items() if k not in df}\n", |
862 |
| - " features = self._compute_transforms(transforms=transforms, updates_only=False)\n", |
| 863 | + " # we need to compute all transformations in case they save state\n", |
| 864 | + " features = self._compute_transforms(\n", |
| 865 | + " transforms=self.transforms,\n", |
| 866 | + " updates_only=False\n", |
| 867 | + " )\n", |
| 868 | + " # filter out the features that already exist in df to avoid overwriting them\n", |
| 869 | + " features = {k: v for k, v in features.items() if k not in df}\n", |
863 | 870 | " if self._restore_idxs is not None:\n",
|
864 | 871 | " for k, v in features.items():\n",
|
865 | 872 | " features[k] = v[self._restore_idxs]\n",
|
|
914 | 921 | " del self._restore_idxs, self._sort_idxs\n",
|
915 | 922 | "\n",
|
916 | 923 | " # lag transforms\n",
|
917 |
| - " for feat in transforms.keys():\n", |
918 |
| - " df = ufp.assign_columns(df, feat, features[feat])\n", |
| 924 | + " for feat in self.transforms.keys():\n", |
| 925 | + " if feat in features:\n", |
| 926 | + " df = ufp.assign_columns(df, feat, features[feat])\n", |
919 | 927 | "\n",
|
920 | 928 | " # date features\n",
|
921 | 929 | " names = [f.__name__ if callable(f) else f for f in self.date_features]\n",
|
|
1663 | 1671 | "text/markdown": [
|
1664 | 1672 | "---\n",
|
1665 | 1673 | "\n",
|
1666 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L481){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 1674 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L486){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
1667 | 1675 | "\n",
|
1668 | 1676 | "## TimeSeries.fit_transform\n",
|
1669 | 1677 | "\n",
|
1670 |
| - "> TimeSeries.fit_transform (data:Union[pandas.core.frame.DataFrame,polars.d\n", |
1671 |
| - "> ataframe.frame.DataFrame], id_col:str,\n", |
1672 |
| - "> time_col:str, target_col:str,\n", |
| 1678 | + "> TimeSeries.fit_transform (data:~DFType, id_col:str, time_col:str,\n", |
| 1679 | + "> target_col:str,\n", |
1673 | 1680 | "> static_features:Optional[List[str]]=None,\n",
|
1674 | 1681 | "> dropna:bool=True,\n",
|
1675 | 1682 | "> keep_last_n:Optional[int]=None,\n",
|
|
1685 | 1692 | "text/plain": [
|
1686 | 1693 | "---\n",
|
1687 | 1694 | "\n",
|
1688 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L481){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 1695 | + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L486){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
1689 | 1696 | "\n",
|
1690 | 1697 | "## TimeSeries.fit_transform\n",
|
1691 | 1698 | "\n",
|
1692 |
| - "> TimeSeries.fit_transform (data:Union[pandas.core.frame.DataFrame,polars.d\n", |
1693 |
| - "> ataframe.frame.DataFrame], id_col:str,\n", |
1694 |
| - "> time_col:str, target_col:str,\n", |
| 1699 | + "> TimeSeries.fit_transform (data:~DFType, id_col:str, time_col:str,\n", |
| 1700 | + "> target_col:str,\n", |
1695 | 1701 | "> static_features:Optional[List[str]]=None,\n",
|
1696 | 1702 | "> dropna:bool=True,\n",
|
1697 | 1703 | "> keep_last_n:Optional[int]=None,\n",
|
|
1972 | 1978 | "cell_type": "code",
|
1973 | 1979 | "execution_count": null,
|
1974 | 1980 | "metadata": {},
|
1975 |
| - "outputs": [ |
1976 |
| - { |
1977 |
| - "data": { |
1978 |
| - "text/markdown": [ |
1979 |
| - "---\n", |
1980 |
| - "\n", |
1981 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L726){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
1982 |
| - "\n", |
1983 |
| - "## TimeSeries.predict\n", |
1984 |
| - "\n", |
1985 |
| - "> TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,List\n", |
1986 |
| - "> [sklearn.base.BaseEstimator]]], horizon:int,\n", |
1987 |
| - "> before_predict_callback:Optional[Callable]=None,\n", |
1988 |
| - "> after_predict_callback:Optional[Callable]=None, X_df:\n", |
1989 |
| - "> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n", |
1990 |
| - "> ame.DataFrame,NoneType]=None,\n", |
1991 |
| - "> ids:Optional[List[str]]=None)" |
1992 |
| - ], |
1993 |
| - "text/plain": [ |
1994 |
| - "---\n", |
1995 |
| - "\n", |
1996 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L726){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
1997 |
| - "\n", |
1998 |
| - "## TimeSeries.predict\n", |
1999 |
| - "\n", |
2000 |
| - "> TimeSeries.predict (models:Dict[str,Union[sklearn.base.BaseEstimator,List\n", |
2001 |
| - "> [sklearn.base.BaseEstimator]]], horizon:int,\n", |
2002 |
| - "> before_predict_callback:Optional[Callable]=None,\n", |
2003 |
| - "> after_predict_callback:Optional[Callable]=None, X_df:\n", |
2004 |
| - "> Union[pandas.core.frame.DataFrame,polars.dataframe.fr\n", |
2005 |
| - "> ame.DataFrame,NoneType]=None,\n", |
2006 |
| - "> ids:Optional[List[str]]=None)" |
2007 |
| - ] |
2008 |
| - }, |
2009 |
| - "execution_count": null, |
2010 |
| - "metadata": {}, |
2011 |
| - "output_type": "execute_result" |
2012 |
| - } |
2013 |
| - ], |
| 1981 | + "outputs": [], |
2014 | 1982 | "source": [
|
2015 | 1983 | "show_doc(TimeSeries.predict, title_level=2)"
|
2016 | 1984 | ]
|
|
2126 | 2094 | "cell_type": "code",
|
2127 | 2095 | "execution_count": null,
|
2128 | 2096 | "metadata": {},
|
2129 |
| - "outputs": [ |
2130 |
| - { |
2131 |
| - "data": { |
2132 |
| - "text/markdown": [ |
2133 |
| - "---\n", |
2134 |
| - "\n", |
2135 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L831){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
2136 |
| - "\n", |
2137 |
| - "## TimeSeries.update\n", |
2138 |
| - "\n", |
2139 |
| - "> TimeSeries.update\n", |
2140 |
| - "> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n", |
2141 |
| - "> .frame.DataFrame])\n", |
2142 |
| - "\n", |
2143 |
| - "*Update the values of the stored series.*" |
2144 |
| - ], |
2145 |
| - "text/plain": [ |
2146 |
| - "---\n", |
2147 |
| - "\n", |
2148 |
| - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L831){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
2149 |
| - "\n", |
2150 |
| - "## TimeSeries.update\n", |
2151 |
| - "\n", |
2152 |
| - "> TimeSeries.update\n", |
2153 |
| - "> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n", |
2154 |
| - "> .frame.DataFrame])\n", |
2155 |
| - "\n", |
2156 |
| - "*Update the values of the stored series.*" |
2157 |
| - ] |
2158 |
| - }, |
2159 |
| - "execution_count": null, |
2160 |
| - "metadata": {}, |
2161 |
| - "output_type": "execute_result" |
2162 |
| - } |
2163 |
| - ], |
| 2097 | + "outputs": [], |
2164 | 2098 | "source": [
|
2165 | 2099 | "show_doc(TimeSeries.update, title_level=2)"
|
2166 | 2100 | ]
|
|
2246 | 2180 | "cell_type": "code",
|
2247 | 2181 | "execution_count": null,
|
2248 | 2182 | "metadata": {},
|
2249 |
| - "outputs": [ |
2250 |
| - { |
2251 |
| - "name": "stderr", |
2252 |
| - "output_type": "stream", |
2253 |
| - "text": [ |
2254 |
| - "sys:1: CategoricalRemappingWarning: Local categoricals have different encodings, expensive re-encoding is done to perform this merge operation. Consider using a StringCache or an Enum type if the categories are known in advance\n" |
2255 |
| - ] |
2256 |
| - } |
2257 |
| - ], |
| 2183 | + "outputs": [], |
2258 | 2184 | "source": [
|
2259 | 2185 | "#| hide\n",
|
2260 | 2186 | "#| polars\n",
|
|
0 commit comments