Commit 691a464

Enable creation of custom performance metrics (#2599)
1 parent 2a57e9d commit 691a464

8 files changed: +624 −240 lines


docs/_docs/diagnostics.md (+171 −71)
@@ -86,45 +86,45 @@ df_cv.head()
 <tr>
 <th>0</th>
 <td>2010-02-16</td>
-<td>8.959678</td>
-<td>8.470035</td>
-<td>9.451618</td>
+<td>8.954582</td>
+<td>8.462876</td>
+<td>9.452305</td>
 <td>8.242493</td>
 <td>2010-02-15</td>
 </tr>
 <tr>
 <th>1</th>
 <td>2010-02-17</td>
-<td>8.726195</td>
-<td>8.236734</td>
-<td>9.219616</td>
+<td>8.720932</td>
+<td>8.222682</td>
+<td>9.242788</td>
 <td>8.008033</td>
 <td>2010-02-15</td>
 </tr>
 <tr>
 <th>2</th>
 <td>2010-02-18</td>
-<td>8.610011</td>
-<td>8.104834</td>
-<td>9.125484</td>
+<td>8.604608</td>
+<td>8.066920</td>
+<td>9.144968</td>
 <td>8.045268</td>
 <td>2010-02-15</td>
 </tr>
 <tr>
 <th>3</th>
 <td>2010-02-19</td>
-<td>8.532004</td>
-<td>7.985031</td>
-<td>9.041575</td>
+<td>8.526379</td>
+<td>8.029189</td>
+<td>9.043045</td>
 <td>7.928766</td>
 <td>2010-02-15</td>
 </tr>
 <tr>
 <th>4</th>
 <td>2010-02-20</td>
-<td>8.274090</td>
-<td>7.779034</td>
-<td>8.745627</td>
+<td>8.268247</td>
+<td>7.749520</td>
+<td>8.741847</td>
 <td>7.745003</td>
 <td>2010-02-15</td>
 </tr>
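One of the statistics later computed from `df_cv` columns like those shown above is `coverage`: the fraction of actual `y` values that fall inside the `[yhat_lower, yhat_upper]` interval. A minimal sketch with made-up numbers (not Prophet's implementation):

```python
import numpy as np

# toy values standing in for df_cv columns (illustrative, not real output)
y          = np.array([10.0, 11.0, 12.0, 13.0])
yhat_lower = np.array([ 9.0, 10.5, 12.5, 12.0])
yhat_upper = np.array([11.0, 11.5, 13.5, 14.0])

# coverage: fraction of observations inside [yhat_lower, yhat_upper]
coverage = np.mean((y >= yhat_lower) & (y <= yhat_upper))
```

Here three of the four observations land inside their intervals, so `coverage` is 0.75.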
@@ -154,6 +154,20 @@ df_cv2 = cross_validation(m, cutoffs=cutoffs, horizon='365 days')
 The `performance_metrics` utility can be used to compute some useful statistics of the prediction performance (`yhat`, `yhat_lower`, and `yhat_upper` compared to `y`), as a function of the distance from the cutoff (how far into the future the prediction was). The statistics computed are mean squared error (MSE), root mean squared error (RMSE), mean absolute error (MAE), mean absolute percent error (MAPE), median absolute percent error (MDAPE), and coverage of the `yhat_lower` and `yhat_upper` estimates. These are computed on a rolling window of the predictions in `df_cv` after sorting by horizon (`ds` minus `cutoff`). By default 10% of the predictions will be included in each window, but this can be changed with the `rolling_window` argument.
 
 
+
+In Python, you can also create a custom performance metric using the `register_performance_metric` decorator. The registered metric function should take the following arguments:
+
+- df: Cross-validation results dataframe.
+
+- w: Aggregation window size.
+
+and return:
+
+- A dataframe with columns `horizon` and the metric name.
+
 ```R
 # R
 df.p <- performance_metrics(df.cv)
@@ -200,57 +214,143 @@ df_p.head()
 <tr>
 <th>0</th>
 <td>37 days</td>
-<td>0.493764</td>
-<td>0.702683</td>
-<td>0.504754</td>
-<td>0.058485</td>
-<td>0.049922</td>
-<td>0.058774</td>
-<td>0.674052</td>
+<td>0.493358</td>
+<td>0.702395</td>
+<td>0.503977</td>
+<td>0.058376</td>
+<td>0.049365</td>
+<td>0.058677</td>
+<td>0.676565</td>
 </tr>
 <tr>
 <th>1</th>
 <td>38 days</td>
-<td>0.499522</td>
-<td>0.706769</td>
-<td>0.509723</td>
-<td>0.059060</td>
-<td>0.049389</td>
-<td>0.059409</td>
-<td>0.672910</td>
+<td>0.499112</td>
+<td>0.706478</td>
+<td>0.508946</td>
+<td>0.058951</td>
+<td>0.049135</td>
+<td>0.059312</td>
+<td>0.675423</td>
 </tr>
 <tr>
 <th>2</th>
 <td>39 days</td>
-<td>0.521614</td>
-<td>0.722229</td>
-<td>0.515793</td>
-<td>0.059657</td>
-<td>0.049540</td>
-<td>0.060131</td>
-<td>0.670169</td>
+<td>0.521344</td>
+<td>0.722042</td>
+<td>0.515016</td>
+<td>0.059547</td>
+<td>0.049225</td>
+<td>0.060034</td>
+<td>0.672682</td>
 </tr>
 <tr>
 <th>3</th>
 <td>40 days</td>
-<td>0.528760</td>
-<td>0.727159</td>
-<td>0.518634</td>
-<td>0.059961</td>
-<td>0.049232</td>
-<td>0.060504</td>
-<td>0.671311</td>
+<td>0.528651</td>
+<td>0.727084</td>
+<td>0.517873</td>
+<td>0.059852</td>
+<td>0.049072</td>
+<td>0.060409</td>
+<td>0.676336</td>
 </tr>
 <tr>
 <th>4</th>
 <td>41 days</td>
-<td>0.536078</td>
-<td>0.732174</td>
-<td>0.519585</td>
-<td>0.060036</td>
-<td>0.049389</td>
-<td>0.060641</td>
-<td>0.678849</td>
+<td>0.536149</td>
+<td>0.732222</td>
+<td>0.518843</td>
+<td>0.059927</td>
+<td>0.049135</td>
+<td>0.060548</td>
+<td>0.681361</td>
+</tr>
+</tbody>
+</table>
+</div>
+
+
+
+```python
+# Python
+from prophet.diagnostics import register_performance_metric, rolling_mean_by_h
+import numpy as np
+import pandas as pd
+
+@register_performance_metric
+def mase(df, w):
+    """Mean absolute scaled error
+
+    Parameters
+    ----------
+    df: Cross-validation results dataframe.
+    w: Aggregation window size.
+
+    Returns
+    -------
+    Dataframe with columns horizon and mase.
+    """
+    e = (df['y'] - df['yhat'])
+    d = np.abs(np.diff(df['y'])).sum() / (df['y'].shape[0] - 1)
+    se = np.abs(e / d)
+    if w < 0:
+        return pd.DataFrame({'horizon': df['horizon'], 'mase': se})
+    return rolling_mean_by_h(
+        x=se.values, h=df['horizon'].values, w=w, name='mase'
+    )
+
+df_mase = performance_metrics(df_cv, metrics=['mase'])
+df_mase.head()
+```
+
+
+
+<div>
+<style scoped>
+    .dataframe tbody tr th:only-of-type {
+        vertical-align: middle;
+    }
+
+    .dataframe tbody tr th {
+        vertical-align: top;
+    }
+
+    .dataframe thead th {
+        text-align: right;
+    }
+</style>
+<table border="1" class="dataframe">
+<thead>
+<tr style="text-align: right;">
+<th></th>
+<th>horizon</th>
+<th>mase</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<th>0</th>
+<td>37 days</td>
+<td>0.522946</td>
+</tr>
+<tr>
+<th>1</th>
+<td>38 days</td>
+<td>0.528102</td>
+</tr>
+<tr>
+<th>2</th>
+<td>39 days</td>
+<td>0.534401</td>
+</tr>
+<tr>
+<th>3</th>
+<td>40 days</td>
+<td>0.537365</td>
+</tr>
+<tr>
+<th>4</th>
+<td>41 days</td>
+<td>0.538372</td>
 </tr>
 </tbody>
 </table>
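The `rolling_mean_by_h` helper used in the `mase` example above can be sketched in plain numpy. This is a simplified, hypothetical reimplementation for illustration only, not Prophet's actual code (the real helper also merges rows that share a horizon): sort the per-prediction values by horizon, then average each trailing window of `w` rows, labelling the window by its largest horizon.

```python
import numpy as np

def rolling_mean_by_h_sketch(x, h, w):
    # sort values by horizon, then take trailing windows of w rows
    order = np.argsort(h)
    xs, hs = np.asarray(x, float)[order], np.asarray(h)[order]
    horizons = [hs[i] for i in range(w - 1, len(xs))]
    means = [xs[i - w + 1:i + 1].mean() for i in range(w - 1, len(xs))]
    return horizons, means

# toy per-row errors at horizons 1..4, averaged over windows of 2 rows
horizons, means = rolling_mean_by_h_sketch([1.0, 3.0, 2.0, 4.0], [1, 2, 3, 4], w=2)
```

For this toy input the windows are (1.0, 3.0), (3.0, 2.0), and (2.0, 4.0), reported at horizons 2, 3, and 4.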
@@ -271,7 +371,7 @@ from prophet.plot import plot_cross_validation_metric
 fig = plot_cross_validation_metric(df_cv, metric='mape')
 ```
 
-![png](/prophet/static/diagnostics_files/diagnostics_17_0.png)
+![png](/prophet/static/diagnostics_files/diagnostics_18_0.png)
 
 
 The size of the rolling window in the figure can be changed with the optional argument `rolling_window`, which specifies the proportion of forecasts to use in each rolling window. The default is 0.1, corresponding to 10% of rows from `df_cv` included in each window; increasing this will lead to a smoother average curve in the figure. The `initial` period should be long enough to capture all of the components of the model, in particular seasonalities and extra regressors: at least a year for yearly seasonality, at least a week for weekly seasonality, etc.
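To see how the `rolling_window` proportion maps to an absolute window size and why a larger value smooths the curve, here is a toy sketch (the variable names and data are illustrative, not Prophet's internals):

```python
import numpy as np

# toy per-horizon errors (illustrative values)
errors = np.linspace(1.0, 10.0, 10)

rolling_window = 0.2                            # 20% of rows per window
w = max(1, int(rolling_window * len(errors)))   # -> window of 2 rows
# moving average over windows of w rows
smoothed = np.convolve(errors, np.ones(w) / w, mode='valid')
```

Doubling `rolling_window` doubles `w`, averaging more rows per point and flattening noise in the plotted metric.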
@@ -355,33 +455,33 @@ for params in all_params:
 tuning_results = pd.DataFrame(all_params)
 tuning_results['rmse'] = rmses
 print(tuning_results)
-```
-    changepoint_prior_scale  seasonality_prior_scale      rmse
-0                     0.001                     0.01  0.757694
-1                     0.001                     0.10  0.743399
-2                     0.001                     1.00  0.753387
-3                     0.001                    10.00  0.762890
-4                     0.010                     0.01  0.542315
-5                     0.010                     0.10  0.535546
-6                     0.010                     1.00  0.527008
-7                     0.010                    10.00  0.541544
-8                     0.100                     0.01  0.524835
-9                     0.100                     0.10  0.516061
-10                    0.100                     1.00  0.521406
-11                    0.100                    10.00  0.518580
-12                    0.500                     0.01  0.532140
-13                    0.500                     0.10  0.524668
-14                    0.500                     1.00  0.521130
-15                    0.500                    10.00  0.522980
-
+```
+
+    changepoint_prior_scale  seasonality_prior_scale      rmse
+0                     0.001                     0.01  0.757694
+1                     0.001                     0.10  0.743399
+2                     0.001                     1.00  0.753387
+3                     0.001                    10.00  0.762890
+4                     0.010                     0.01  0.542315
+5                     0.010                     0.10  0.535546
+6                     0.010                     1.00  0.527008
+7                     0.010                    10.00  0.541544
+8                     0.100                     0.01  0.524835
+9                     0.100                     0.10  0.516061
+10                    0.100                     1.00  0.521406
+11                    0.100                    10.00  0.518580
+12                    0.500                     0.01  0.532140
+13                    0.500                     0.10  0.524668
+14                    0.500                     1.00  0.521130
+15                    0.500                    10.00  0.522980
 
 ```python
 # Python
 best_params = all_params[np.argmin(rmses)]
 print(best_params)
 ```
-{'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.1}
 
+{'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.1}
 
 Alternatively, parallelization could be done across parameter combinations by parallelizing the loop above.
 
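The parallelization across parameter combinations mentioned above could be sketched with `concurrent.futures`. Here `evaluate` is a hypothetical stand-in for the fit/cross-validate/score step of the loop (for real CPU-bound model fitting a `ProcessPoolExecutor` would typically replace the thread pool):

```python
import itertools
from concurrent.futures import ThreadPoolExecutor

param_grid = {
    'changepoint_prior_scale': [0.001, 0.01, 0.1, 0.5],
    'seasonality_prior_scale': [0.01, 0.1, 1.0, 10.0],
}
# all 4 x 4 = 16 combinations of the two hyperparameters
all_params = [dict(zip(param_grid.keys(), v))
              for v in itertools.product(*param_grid.values())]

def evaluate(params):
    # hypothetical stand-in: in the real loop this would fit
    # Prophet(**params), run cross_validation and performance_metrics,
    # and return the mean RMSE
    return sum(params.values())

# each combination is scored independently, so they map cleanly to a pool
with ThreadPoolExecutor() as pool:
    rmses = list(pool.map(evaluate, all_params))

best_params = all_params[rmses.index(min(rmses))]
```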
Binary file not shown.