Skip to content

Commit b4ee8f1

Browse files
imad24Imad Rahmouni
and
Imad Rahmouni
authored
Include predictions for missing y (NaN) dates in the history (#2530)
Co-authored-by: Imad Rahmouni <[email protected]>
1 parent 7e83e2c commit b4ee8f1

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

python/prophet/forecaster.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,7 @@ def preprocess(self, df: pd.DataFrame, **kwargs) -> ModelInputData:
11301130
history = df[df['y'].notnull()].copy()
11311131
if history.shape[0] < 2:
11321132
raise ValueError('Dataframe has less than 2 non-NaN rows.')
1133-
self.history_dates = pd.to_datetime(pd.Series(history['ds'].unique(), name='ds')).sort_values()
1133+
self.history_dates = pd.to_datetime(pd.Series(df['ds'].unique(), name='ds')).sort_values()
11341134

11351135
self.history = self.setup_dataframe(history, initialize_scales=True)
11361136
self.set_auto_seasonalities()

python/prophet/tests/test_prophet.py

+10
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,16 @@ def test_make_future_dataframe(self, daily_univariate_ts, backend):
244244
assert len(future) == 3
245245
assert np.all(future["ds"].values == correct.values)
246246

247+
def test_make_future_dataframe_include_history(self, daily_univariate_ts, backend):
248+
train = daily_univariate_ts.head(468 // 2).copy()
249+
#cover history with NAs
250+
train.loc[train.sample(10).index, "y"] = np.nan
251+
252+
forecaster = Prophet(stan_backend=backend)
253+
forecaster.fit(train)
254+
future = forecaster.make_future_dataframe(periods=3, freq="D", include_history=True)
255+
256+
assert len(future) == train.shape[0] + 3
247257

248258
class TestProphetTrendComponent:
249259
def test_invalid_growth_input(self, backend):

0 commit comments

Comments
 (0)