Skip to content

Commit 5ec8614

Browse files
committed
Seasonality values can be negative while prediction stays positive
1 parent 3063ab3 commit 5ec8614

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

python/prophet/forecaster.py

+4
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,10 @@ def predict(self, df: pd.DataFrame = None, vectorized: bool = True) -> pd.DataFr
13261326
df2['trend'] * (1 + df2['multiplicative_terms'])
13271327
+ df2['additive_terms']
13281328
)
1329+
1330+
if not self.negative_prediction_values:
1331+
df2['yhat'] = df2['yhat'].clip(lower=0)
1332+
13291333
return df2
13301334

13311335
@staticmethod

python/prophet/tests/test_prophet.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -434,21 +434,15 @@ def test_override_n_changepoints(self, daily_univariate_ts, backend):
434434
cp = m.changepoints_t
435435
assert cp.shape[0] == 15
436436

437-
@pytest.mark.parametrize(
438-
"expected",
439-
[5.656087514685135],
440-
)
441-
def test_without_negative_predictions(self, subdaily_univariate_ts, backend, expected):
437+
def test_without_negative_predictions(self, subdaily_univariate_ts, backend):
442438
test_days = 280
443439
train, test = train_test_split(subdaily_univariate_ts, test_days)
444-
forecaster = Prophet(stan_backend=backend, negative_prediction_values=False)
440+
forecaster = Prophet(stan_backend=backend, negative_prediction_values=False, weekly_seasonality=True, yearly_seasonality=True)
445441
forecaster.fit(train, seed=1237861298)
446442
np.random.seed(876543987)
447443
future = forecaster.make_future_dataframe(test_days, include_history=False)
448444
future = forecaster.predict(future)
449-
res = rmse(future["yhat"], test["y"])
450-
tolerance = 1e-5
451-
assert res == pytest.approx(expected, rel=tolerance), "backend: {}".format(forecaster.stan_backend)
445+
assert (future['yhat'].values >= 0).all()
452446

453447

454448
class TestProphetSeasonalComponent:

0 commit comments

Comments
 (0)