Skip to content

Commit 3edd9df

Browse files
authored
feat: infer samples required for built-in lag transforms updates (#445)
1 parent 8c28a7f commit 3edd9df

File tree

6 files changed

+337
-97
lines changed

6 files changed

+337
-97
lines changed

mlforecast/_modidx.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,8 @@
262262
'mlforecast/lag_transforms.py'),
263263
'mlforecast.lag_transforms.Combine.update': ( 'lag_transforms.html#combine.update',
264264
'mlforecast/lag_transforms.py'),
265+
'mlforecast.lag_transforms.Combine.update_samples': ( 'lag_transforms.html#combine.update_samples',
266+
'mlforecast/lag_transforms.py'),
265267
'mlforecast.lag_transforms.ExpandingMax': ( 'lag_transforms.html#expandingmax',
266268
'mlforecast/lag_transforms.py'),
267269
'mlforecast.lag_transforms.ExpandingMean': ( 'lag_transforms.html#expandingmean',
@@ -272,12 +274,16 @@
272274
'mlforecast/lag_transforms.py'),
273275
'mlforecast.lag_transforms.ExpandingQuantile.__init__': ( 'lag_transforms.html#expandingquantile.__init__',
274276
'mlforecast/lag_transforms.py'),
277+
'mlforecast.lag_transforms.ExpandingQuantile.update_samples': ( 'lag_transforms.html#expandingquantile.update_samples',
278+
'mlforecast/lag_transforms.py'),
275279
'mlforecast.lag_transforms.ExpandingStd': ( 'lag_transforms.html#expandingstd',
276280
'mlforecast/lag_transforms.py'),
277281
'mlforecast.lag_transforms.ExponentiallyWeightedMean': ( 'lag_transforms.html#exponentiallyweightedmean',
278282
'mlforecast/lag_transforms.py'),
279283
'mlforecast.lag_transforms.ExponentiallyWeightedMean.__init__': ( 'lag_transforms.html#exponentiallyweightedmean.__init__',
280284
'mlforecast/lag_transforms.py'),
285+
'mlforecast.lag_transforms.ExponentiallyWeightedMean.update_samples': ( 'lag_transforms.html#exponentiallyweightedmean.update_samples',
286+
'mlforecast/lag_transforms.py'),
281287
'mlforecast.lag_transforms.Lag': ('lag_transforms.html#lag', 'mlforecast/lag_transforms.py'),
282288
'mlforecast.lag_transforms.Lag.__eq__': ( 'lag_transforms.html#lag.__eq__',
283289
'mlforecast/lag_transforms.py'),
@@ -287,6 +293,8 @@
287293
'mlforecast/lag_transforms.py'),
288294
'mlforecast.lag_transforms.Lag._set_core_tfm': ( 'lag_transforms.html#lag._set_core_tfm',
289295
'mlforecast/lag_transforms.py'),
296+
'mlforecast.lag_transforms.Lag.update_samples': ( 'lag_transforms.html#lag.update_samples',
297+
'mlforecast/lag_transforms.py'),
290298
'mlforecast.lag_transforms.Offset': ( 'lag_transforms.html#offset',
291299
'mlforecast/lag_transforms.py'),
292300
'mlforecast.lag_transforms.Offset.__init__': ( 'lag_transforms.html#offset.__init__',
@@ -295,6 +303,8 @@
295303
'mlforecast/lag_transforms.py'),
296304
'mlforecast.lag_transforms.Offset._set_core_tfm': ( 'lag_transforms.html#offset._set_core_tfm',
297305
'mlforecast/lag_transforms.py'),
306+
'mlforecast.lag_transforms.Offset.update_samples': ( 'lag_transforms.html#offset.update_samples',
307+
'mlforecast/lag_transforms.py'),
298308
'mlforecast.lag_transforms.RollingMax': ( 'lag_transforms.html#rollingmax',
299309
'mlforecast/lag_transforms.py'),
300310
'mlforecast.lag_transforms.RollingMean': ( 'lag_transforms.html#rollingmean',
@@ -327,6 +337,8 @@
327337
'mlforecast/lag_transforms.py'),
328338
'mlforecast.lag_transforms._BaseLagTransform._get_name': ( 'lag_transforms.html#_baselagtransform._get_name',
329339
'mlforecast/lag_transforms.py'),
340+
'mlforecast.lag_transforms._BaseLagTransform._lag': ( 'lag_transforms.html#_baselagtransform._lag',
341+
'mlforecast/lag_transforms.py'),
330342
'mlforecast.lag_transforms._BaseLagTransform._set_core_tfm': ( 'lag_transforms.html#_baselagtransform._set_core_tfm',
331343
'mlforecast/lag_transforms.py'),
332344
'mlforecast.lag_transforms._BaseLagTransform.stack': ( 'lag_transforms.html#_baselagtransform.stack',
@@ -337,18 +349,26 @@
337349
'mlforecast/lag_transforms.py'),
338350
'mlforecast.lag_transforms._BaseLagTransform.update': ( 'lag_transforms.html#_baselagtransform.update',
339351
'mlforecast/lag_transforms.py'),
352+
'mlforecast.lag_transforms._BaseLagTransform.update_samples': ( 'lag_transforms.html#_baselagtransform.update_samples',
353+
'mlforecast/lag_transforms.py'),
340354
'mlforecast.lag_transforms._ExpandingBase': ( 'lag_transforms.html#_expandingbase',
341355
'mlforecast/lag_transforms.py'),
342356
'mlforecast.lag_transforms._ExpandingBase.__init__': ( 'lag_transforms.html#_expandingbase.__init__',
343357
'mlforecast/lag_transforms.py'),
358+
'mlforecast.lag_transforms._ExpandingBase.update_samples': ( 'lag_transforms.html#_expandingbase.update_samples',
359+
'mlforecast/lag_transforms.py'),
344360
'mlforecast.lag_transforms._RollingBase': ( 'lag_transforms.html#_rollingbase',
345361
'mlforecast/lag_transforms.py'),
346362
'mlforecast.lag_transforms._RollingBase.__init__': ( 'lag_transforms.html#_rollingbase.__init__',
347363
'mlforecast/lag_transforms.py'),
364+
'mlforecast.lag_transforms._RollingBase.update_samples': ( 'lag_transforms.html#_rollingbase.update_samples',
365+
'mlforecast/lag_transforms.py'),
348366
'mlforecast.lag_transforms._Seasonal_RollingBase': ( 'lag_transforms.html#_seasonal_rollingbase',
349367
'mlforecast/lag_transforms.py'),
350368
'mlforecast.lag_transforms._Seasonal_RollingBase.__init__': ( 'lag_transforms.html#_seasonal_rollingbase.__init__',
351369
'mlforecast/lag_transforms.py'),
370+
'mlforecast.lag_transforms._Seasonal_RollingBase.update_samples': ( 'lag_transforms.html#_seasonal_rollingbase.update_samples',
371+
'mlforecast/lag_transforms.py'),
352372
'mlforecast.lag_transforms._pascal2camel': ( 'lag_transforms.html#_pascal2camel',
353373
'mlforecast/lag_transforms.py')},
354374
'mlforecast.lgb_cv': { 'mlforecast.lgb_cv.LightGBMCV': ('lgb_cv.html#lightgbmcv', 'mlforecast/lgb_cv.py'),

mlforecast/core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,16 @@ def _transform(
432432
self._dropped_series = None
433433

434434
# once we've computed the features and target we can slice the series
435+
update_samples = [
436+
getattr(tfm, "update_samples", -1) for tfm in self.transforms.values()
437+
]
438+
if (
439+
self.keep_last_n is None
440+
and update_samples
441+
and all(samples > 0 for samples in update_samples)
442+
):
443+
# user didn't set keep_last_n and we can infer it from the transforms
444+
self.keep_last_n = max(update_samples)
435445
if self.keep_last_n is not None:
436446
self.ga = self.ga.take_from_groups(slice(-self.keep_last_n, None))
437447
del self._restore_idxs, self._sort_idxs

mlforecast/lag_transforms.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,17 @@ def stack(transforms: Sequence["_BaseLagTransform"]) -> "_BaseLagTransform":
6868
)
6969
return out
7070

71+
@property
72+
def _lag(self):
73+
return self._core_tfm.lag - 1
74+
75+
@property
76+
def update_samples(self) -> int:
77+
return -1
78+
7179
# %% ../nbs/lag_transforms.ipynb 6
7280
class Lag(_BaseLagTransform):
81+
7382
def __init__(self, lag: int):
7483
self.lag = lag
7584
self._core_tfm = core_tfms.Lag(lag=lag)
@@ -83,6 +92,10 @@ def _get_name(self, lag: int) -> str:
8392
def __eq__(self, other):
8493
return isinstance(other, Lag) and self.lag == other.lag
8594

95+
@property
96+
def update_samples(self) -> int:
97+
return self.lag
98+
8699
# %% ../nbs/lag_transforms.ipynb 7
87100
class _RollingBase(_BaseLagTransform):
88101
"Rolling statistic"
@@ -100,6 +113,10 @@ def __init__(self, window_size: int, min_samples: Optional[int] = None):
100113
self.window_size = window_size
101114
self.min_samples = min_samples
102115

116+
@property
117+
def update_samples(self) -> int:
118+
return self._lag + self.window_size
119+
103120
# %% ../nbs/lag_transforms.ipynb 8
104121
class RollingMean(_RollingBase): ...
105122

@@ -149,6 +166,10 @@ def __init__(
149166
self.window_size = window_size
150167
self.min_samples = min_samples
151168

169+
@property
170+
def update_samples(self) -> int:
171+
return self._lag + self.season_length * self.window_size
172+
152173
# %% ../nbs/lag_transforms.ipynb 11
153174
class SeasonalRollingMean(_Seasonal_RollingBase): ...
154175

@@ -183,6 +204,10 @@ class _ExpandingBase(_BaseLagTransform):
183204

184205
def __init__(self): ...
185206

207+
@property
208+
def update_samples(self) -> int:
209+
return 1
210+
186211
# %% ../nbs/lag_transforms.ipynb 14
187212
class ExpandingMean(_ExpandingBase): ...
188213

@@ -200,6 +225,10 @@ class ExpandingQuantile(_ExpandingBase):
200225
def __init__(self, p: float):
201226
self.p = p
202227

228+
@property
229+
def update_samples(self) -> int:
230+
return -1
231+
203232
# %% ../nbs/lag_transforms.ipynb 16
204233
class ExponentiallyWeightedMean(_BaseLagTransform):
205234
"""Exponentially weighted average
@@ -212,6 +241,10 @@ class ExponentiallyWeightedMean(_BaseLagTransform):
212241
def __init__(self, alpha: float):
213242
self.alpha = alpha
214243

244+
@property
245+
def update_samples(self) -> int:
246+
return 1
247+
215248
# %% ../nbs/lag_transforms.ipynb 18
216249
class Offset(_BaseLagTransform):
217250
"""Shift series before computing transformation
@@ -231,9 +264,14 @@ def _get_name(self, lag: int) -> str:
231264
return self.tfm._get_name(lag + self.n)
232265

233266
def _set_core_tfm(self, lag: int) -> "Offset":
234-
self._core_tfm = clone(self.tfm)._set_core_tfm(lag + self.n)
267+
self.tfm = clone(self.tfm)._set_core_tfm(lag + self.n)
268+
self._core_tfm = self.tfm._core_tfm
235269
return self
236270

271+
@property
272+
def update_samples(self) -> int:
273+
return self.tfm.update_samples + self.n
274+
237275
# %% ../nbs/lag_transforms.ipynb 20
238276
class Combine(_BaseLagTransform):
239277
"""Combine two lag transformations using an operator
@@ -269,3 +307,7 @@ def transform(self, ga: CoreGroupedArray) -> np.ndarray:
269307

270308
def update(self, ga: CoreGroupedArray) -> np.ndarray:
271309
return self.operator(self.tfm1.update(ga), self.tfm2.update(ga))
310+
311+
@property
312+
def update_samples(self):
313+
return max(self.tfm1.update_samples, self.tfm2.update_samples)

0 commit comments

Comments
 (0)