Skip to content

Commit 2164ce4

Browse files
authored
feat: add weight_col to MLForecast.fit and MLForecast.cross_validation (#444)
1 parent 3edd9df commit 2164ce4

File tree

7 files changed

+1180
-118
lines changed

7 files changed

+1180
-118
lines changed

mlforecast/core.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def _fit(
242242
target_col: str,
243243
static_features: Optional[List[str]] = None,
244244
keep_last_n: Optional[int] = None,
245+
weight_col: Optional[str] = None,
245246
) -> "TimeSeries":
246247
"""Save the series values, ids and last dates."""
247248
validate_format(df, id_col, time_col, target_col)
@@ -251,6 +252,7 @@ def _fit(
251252
self.id_col = id_col
252253
self.target_col = target_col
253254
self.time_col = time_col
255+
self.weight_col = weight_col
254256
self.keep_last_n = keep_last_n
255257
self.static_features = static_features
256258
sorted_df = df[[id_col, time_col, target_col]]
@@ -298,9 +300,12 @@ def _fit(
298300
if static_features is None:
299301
static_features = [c for c in df.columns if c not in [time_col, target_col]]
300302
elif id_col not in static_features:
301-
static_features = [id_col] + static_features
303+
static_features = [id_col, *static_features]
302304
else: # static_features defined and contain id_col
303305
to_drop = [time_col, target_col]
306+
if weight_col is not None:
307+
to_drop.append(weight_col)
308+
static_features = [f for f in static_features if f != weight_col]
304309
self.ga = ga
305310
series_starts = ga.indptr[:-1]
306311
series_ends = ga.indptr[1:] - 1
@@ -478,7 +483,11 @@ def _transform(
478483

479484
# assemble return
480485
if return_X_y:
481-
X = df[self.features_order_]
486+
if self.weight_col is not None:
487+
x_cols = [self.weight_col, *self.features_order_]
488+
else:
489+
x_cols = self.features_order_
490+
X = df[x_cols]
482491
if as_numpy:
483492
X = ufp.to_numpy(X)
484493
return X, target
@@ -506,6 +515,7 @@ def fit_transform(
506515
max_horizon: Optional[int] = None,
507516
return_X_y: bool = False,
508517
as_numpy: bool = False,
518+
weight_col: Optional[str] = None,
509519
) -> Union[DFType, Tuple[DFType, np.ndarray]]:
510520
"""Add the features to `data` and save the required information for the predictions step.
511521
@@ -522,6 +532,7 @@ def fit_transform(
522532
target_col=target_col,
523533
static_features=static_features,
524534
keep_last_n=keep_last_n,
535+
weight_col=weight_col,
525536
)
526537
return self._transform(
527538
df=data,

mlforecast/forecast.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def preprocess(
214214
max_horizon: Optional[int] = None,
215215
return_X_y: bool = False,
216216
as_numpy: bool = False,
217+
weight_col: Optional[str] = None,
217218
) -> Union[DFType, Tuple[DFType, np.ndarray]]:
218219
"""Add the features to `data`.
219220
@@ -239,6 +240,8 @@ def preprocess(
239240
Return a tuple with the features and the target. If False will return a single dataframe.
240241
as_numpy : bool (default = False)
241242
Cast features to numpy array. Only works for `return_X_y=True`.
243+
weight_col : str, optional (default=None)
244+
Column that contains the sample weights.
242245
243246
Returns
244247
-------
@@ -256,6 +259,7 @@ def preprocess(
256259
max_horizon=max_horizon,
257260
return_X_y=return_X_y,
258261
as_numpy=as_numpy,
262+
weight_col=weight_col,
259263
)
260264

261265
def fit_models(
@@ -277,21 +281,31 @@ def fit_models(
277281
self : MLForecast
278282
Forecast object with trained models.
279283
"""
284+
285+
def fit_model(model, X, y, weight_col):
286+
fit_kwargs = {}
287+
if weight_col is not None:
288+
if isinstance(X, np.ndarray):
289+
fit_kwargs["sample_weight"] = X[:, 0]
290+
X = X[:, 1:]
291+
else:
292+
fit_kwargs["sample_weight"] = X[weight_col]
293+
X = ufp.drop_columns(X, weight_col)
294+
return clone(model).fit(X, y, **fit_kwargs)
295+
280296
self.models_: Dict[str, Union[BaseEstimator, List[BaseEstimator]]] = {}
281297
for name, model in self.models.items():
282298
if y.ndim == 2 and y.shape[1] > 1:
283299
self.models_[name] = []
284300
for col in range(y.shape[1]):
285301
keep = ~np.isnan(y[:, col])
286-
if isinstance(X, np.ndarray):
287-
# TODO: migrate to utils
288-
Xh = X[keep]
289-
else:
290-
Xh = ufp.filter_with_mask(X, keep)
302+
Xh = ufp.filter_with_mask(X, keep)
291303
yh = y[keep, col]
292-
self.models_[name].append(clone(model).fit(Xh, yh))
304+
self.models_[name].append(
305+
fit_model(model, Xh, yh, self.ts.weight_col)
306+
)
293307
else:
294-
self.models_[name] = clone(model).fit(X, y)
308+
self.models_[name] = fit_model(model, X, y, self.ts.weight_col)
295309
return self
296310

297311
def _conformity_scores(
@@ -380,8 +394,12 @@ def _extract_X_y(
380394
self,
381395
prep: DFType,
382396
target_col: str,
397+
weight_col: Optional[str],
383398
) -> Tuple[Union[DFType, np.ndarray], np.ndarray]:
384-
X = prep[self.ts.features_order_]
399+
x_cols = self.ts.features_order_
400+
if weight_col is not None:
401+
x_cols = [weight_col, *x_cols]
402+
X = prep[x_cols]
385403
targets = [c for c in prep.columns if re.match(rf"^{target_col}\d*$", c)]
386404
if len(targets) == 1:
387405
targets = targets[0]
@@ -397,7 +415,13 @@ def _compute_fitted_values(
397415
time_col: str,
398416
target_col: str,
399417
max_horizon: Optional[int],
418+
weight_col: Optional[str],
400419
) -> DFType:
420+
if weight_col is not None:
421+
if isinstance(X, np.ndarray):
422+
X = X[:, 1:]
423+
else:
424+
X = ufp.drop_columns(X, weight_col)
401425
base = ufp.copy_if_pandas(base, deep=False)
402426
sort_idxs = ufp.maybe_compute_sort_indices(base, id_col, time_col)
403427
if sort_idxs is not None:
@@ -456,6 +480,7 @@ def fit(
456480
prediction_intervals: Optional[PredictionIntervals] = None,
457481
fitted: bool = False,
458482
as_numpy: bool = False,
483+
weight_col: Optional[str] = None,
459484
) -> "MLForecast":
460485
"""Apply the feature engineering and train the models.
461486
@@ -484,6 +509,8 @@ def fit(
484509
Save in-sample predictions.
485510
as_numpy : bool (default = False)
486511
Cast features to numpy array.
512+
weight_col : str, optional (default=None)
513+
Column that contains the sample weights.
487514
488515
Returns
489516
-------
@@ -520,12 +547,13 @@ def fit(
520547
max_horizon=max_horizon,
521548
return_X_y=not fitted,
522549
as_numpy=as_numpy,
550+
weight_col=weight_col,
523551
)
524552
if isinstance(prep, tuple):
525553
X, y = prep
526554
else:
527555
base = prep[[id_col, time_col]]
528-
X, y = self._extract_X_y(prep, target_col)
556+
X, y = self._extract_X_y(prep, target_col, weight_col)
529557
if as_numpy:
530558
X = ufp.to_numpy(X)
531559
del prep
@@ -539,6 +567,7 @@ def fit(
539567
time_col=time_col,
540568
target_col=target_col,
541569
max_horizon=max_horizon,
570+
weight_col=self.ts.weight_col,
542571
)
543572
fitted_values = ufp.drop_index_if_pandas(fitted_values)
544573
self.fcst_fitted_values_ = fitted_values
@@ -784,6 +813,7 @@ def cross_validation(
784813
input_size: Optional[int] = None,
785814
fitted: bool = False,
786815
as_numpy: bool = False,
816+
weight_col: Optional[str] = None,
787817
) -> DFType:
788818
"""Perform time series cross validation.
789819
Creates `n_windows` splits where each window has `h` test periods,
@@ -835,6 +865,8 @@ def cross_validation(
835865
Store the in-sample predictions.
836866
as_numpy : bool (default = False)
837867
Cast features to numpy array.
868+
weight_col : str, optional (default=None)
869+
Column that contains the sample weights.
838870
839871
Returns
840872
-------
@@ -869,6 +901,7 @@ def cross_validation(
869901
prediction_intervals=prediction_intervals,
870902
fitted=fitted,
871903
as_numpy=as_numpy,
904+
weight_col=weight_col,
872905
)
873906
cv_models.append(self.models_)
874907
if fitted:
@@ -890,10 +923,11 @@ def cross_validation(
890923
keep_last_n=keep_last_n,
891924
max_horizon=max_horizon,
892925
return_X_y=False,
926+
weight_col=weight_col,
893927
)
894928
assert not isinstance(prep, tuple)
895929
base = prep[[id_col, time_col]]
896-
train_X, train_y = self._extract_X_y(prep, target_col)
930+
train_X, train_y = self._extract_X_y(prep, target_col, weight_col)
897931
if as_numpy:
898932
train_X = ufp.to_numpy(train_X)
899933
del prep
@@ -905,6 +939,7 @@ def cross_validation(
905939
time_col=time_col,
906940
target_col=target_col,
907941
max_horizon=max_horizon,
942+
weight_col=weight_col,
908943
)
909944
fitted_values = ufp.assign_columns(fitted_values, "fold", i_window)
910945
cv_fitted_values.append(fitted_values)

nbs/core.ipynb

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,7 @@
736736
" target_col: str,\n",
737737
" static_features: Optional[List[str]] = None,\n",
738738
" keep_last_n: Optional[int] = None,\n",
739+
" weight_col: Optional[str] = None,\n",
739740
" ) -> 'TimeSeries':\n",
740741
" \"\"\"Save the series values, ids and last dates.\"\"\"\n",
741742
" validate_format(df, id_col, time_col, target_col)\n",
@@ -745,6 +746,7 @@
745746
" self.id_col = id_col\n",
746747
" self.target_col = target_col\n",
747748
" self.time_col = time_col\n",
749+
" self.weight_col = weight_col\n",
748750
" self.keep_last_n = keep_last_n\n",
749751
" self.static_features = static_features\n",
750752
" sorted_df = df[[id_col, time_col, target_col]]\n",
@@ -790,9 +792,12 @@
790792
" if static_features is None:\n",
791793
" static_features = [c for c in df.columns if c not in [time_col, target_col]]\n",
792794
" elif id_col not in static_features:\n",
793-
" static_features = [id_col] + static_features\n",
795+
" static_features = [id_col, *static_features]\n",
794796
" else: # static_features defined and contain id_col\n",
795797
" to_drop = [time_col, target_col]\n",
798+
" if weight_col is not None:\n",
799+
" to_drop.append(weight_col)\n",
800+
" static_features = [f for f in static_features if f != weight_col]\n",
796801
" self.ga = ga\n",
797802
" series_starts = ga.indptr[:-1]\n",
798803
" series_ends = ga.indptr[1:] - 1\n",
@@ -967,7 +972,11 @@
967972
"\n",
968973
" # assemble return\n",
969974
" if return_X_y:\n",
970-
" X = df[self.features_order_]\n",
975+
" if self.weight_col is not None:\n",
976+
" x_cols = [self.weight_col, *self.features_order_]\n",
977+
" else:\n",
978+
" x_cols = self.features_order_\n",
979+
" X = df[x_cols]\n",
971980
" if as_numpy:\n",
972981
" X = ufp.to_numpy(X)\n",
973982
" return X, target\n",
@@ -996,6 +1005,7 @@
9961005
" max_horizon: Optional[int] = None,\n",
9971006
" return_X_y: bool = False,\n",
9981007
" as_numpy: bool = False,\n",
1008+
" weight_col: Optional[str] = None,\n",
9991009
" ) -> Union[DFType, Tuple[DFType, np.ndarray]]:\n",
10001010
" \"\"\"Add the features to `data` and save the required information for the predictions step.\n",
10011011
" \n",
@@ -1012,6 +1022,7 @@
10121022
" target_col=target_col,\n",
10131023
" static_features=static_features,\n",
10141024
" keep_last_n=keep_last_n,\n",
1025+
" weight_col=weight_col,\n",
10151026
" )\n",
10161027
" return self._transform(\n",
10171028
" df=data,\n",
@@ -1690,7 +1701,7 @@
16901701
"text/markdown": [
16911702
"---\n",
16921703
"\n",
1693-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L487){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
1704+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L496){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
16941705
"\n",
16951706
"## TimeSeries.fit_transform\n",
16961707
"\n",
@@ -1700,7 +1711,8 @@
17001711
"> dropna:bool=True,\n",
17011712
"> keep_last_n:Optional[int]=None,\n",
17021713
"> max_horizon:Optional[int]=None,\n",
1703-
"> return_X_y:bool=False, as_numpy:bool=False)\n",
1714+
"> return_X_y:bool=False, as_numpy:bool=False,\n",
1715+
"> weight_col:Optional[str]=None)\n",
17041716
"\n",
17051717
"*Add the features to `data` and save the required information for the predictions step.\n",
17061718
"\n",
@@ -1711,7 +1723,7 @@
17111723
"text/plain": [
17121724
"---\n",
17131725
"\n",
1714-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L487){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
1726+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L496){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
17151727
"\n",
17161728
"## TimeSeries.fit_transform\n",
17171729
"\n",
@@ -1721,7 +1733,8 @@
17211733
"> dropna:bool=True,\n",
17221734
"> keep_last_n:Optional[int]=None,\n",
17231735
"> max_horizon:Optional[int]=None,\n",
1724-
"> return_X_y:bool=False, as_numpy:bool=False)\n",
1736+
"> return_X_y:bool=False, as_numpy:bool=False,\n",
1737+
"> weight_col:Optional[str]=None)\n",
17251738
"\n",
17261739
"*Add the features to `data` and save the required information for the predictions step.\n",
17271740
"\n",
@@ -2003,7 +2016,7 @@
20032016
"text/markdown": [
20042017
"---\n",
20052018
"\n",
2006-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L732){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2019+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L743){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
20072020
"\n",
20082021
"## TimeSeries.predict\n",
20092022
"\n",
@@ -2017,7 +2030,7 @@
20172030
"text/plain": [
20182031
"---\n",
20192032
"\n",
2020-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L732){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2033+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L743){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
20212034
"\n",
20222035
"## TimeSeries.predict\n",
20232036
"\n",
@@ -2155,7 +2168,7 @@
21552168
"text/markdown": [
21562169
"---\n",
21572170
"\n",
2158-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L837){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2171+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L848){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
21592172
"\n",
21602173
"## TimeSeries.update\n",
21612174
"\n",
@@ -2168,7 +2181,7 @@
21682181
"text/plain": [
21692182
"---\n",
21702183
"\n",
2171-
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L837){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
2184+
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L848){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
21722185
"\n",
21732186
"## TimeSeries.update\n",
21742187
"\n",

0 commit comments

Comments
 (0)