2
2
import numpy as np
3
3
from data import get_stallion_data , generate_ar_data
4
4
from pytorch_forecasting import TimeSeriesDataSet
5
- from pytorch_forecasting .data import NaNLabelEncoder , EncoderNormalizer
5
+ from pytorch_forecasting .data import GroupNormalizer , NaNLabelEncoder , EncoderNormalizer
6
6
7
7
8
8
@pytest .fixture
@@ -35,8 +35,42 @@ def data_with_covariates():
35
35
return data
36
36
37
37
38
- @pytest .fixture
39
- def dataloaders_with_coveratiates (data_with_covariates ):
38
+ @pytest .fixture (
39
+ params = [
40
+ dict (),
41
+ dict (
42
+ static_categoricals = ["agency" , "sku" ],
43
+ static_reals = ["avg_population_2017" , "avg_yearly_household_income_2017" ],
44
+ time_varying_known_categoricals = ["special_days" , "month" ],
45
+ variable_groups = dict (
46
+ special_days = [
47
+ "easter_day" ,
48
+ "good_friday" ,
49
+ "new_year" ,
50
+ "christmas" ,
51
+ "labor_day" ,
52
+ "independence_day" ,
53
+ "revolution_day_memorial" ,
54
+ "regional_games" ,
55
+ "fifa_u_17_world_cup" ,
56
+ "football_gold_cup" ,
57
+ "beer_capital" ,
58
+ "music_fest" ,
59
+ ]
60
+ ),
61
+ time_varying_known_reals = ["time_idx" , "price_regular" , "price_actual" , "discount" , "discount_in_percent" ],
62
+ time_varying_unknown_categoricals = [],
63
+ time_varying_unknown_reals = ["volume" , "log_volume" , "industry_volume" , "soda_volume" , "avg_max_temp" ],
64
+ constant_fill_strategy = {"volume" : 0 },
65
+ dropout_categoricals = ["sku" ],
66
+ ),
67
+ dict (static_categoricals = ["agency" , "sku" ]),
68
+ dict (target_normalizer = EncoderNormalizer (), min_encoder_length = 2 ),
69
+ dict (target_normalizer = GroupNormalizer (log_scale = True )),
70
+ dict (target_normalizer = GroupNormalizer (groups = ["agency" , "sku" ], coerce_positive = 1.0 )),
71
+ ]
72
+ )
73
+ def dataloaders_with_coveratiates (data_with_covariates , request ):
40
74
training_cutoff = "2016-09-01"
41
75
max_encoder_length = 36
42
76
max_prediction_length = 6
@@ -49,30 +83,7 @@ def dataloaders_with_coveratiates(data_with_covariates):
49
83
group_ids = ["agency" , "sku" ],
50
84
max_encoder_length = max_encoder_length ,
51
85
max_prediction_length = max_prediction_length ,
52
- static_categoricals = ["agency" , "sku" ],
53
- static_reals = ["avg_population_2017" , "avg_yearly_household_income_2017" ],
54
- time_varying_known_categoricals = ["special_days" , "month" ],
55
- variable_groups = dict (
56
- special_days = [
57
- "easter_day" ,
58
- "good_friday" ,
59
- "new_year" ,
60
- "christmas" ,
61
- "labor_day" ,
62
- "independence_day" ,
63
- "revolution_day_memorial" ,
64
- "regional_games" ,
65
- "fifa_u_17_world_cup" ,
66
- "football_gold_cup" ,
67
- "beer_capital" ,
68
- "music_fest" ,
69
- ]
70
- ),
71
- time_varying_known_reals = ["time_idx" , "price_regular" , "price_actual" , "discount" , "discount_in_percent" ],
72
- time_varying_unknown_categoricals = [],
73
- time_varying_unknown_reals = ["volume" , "log_volume" , "industry_volume" , "soda_volume" , "avg_max_temp" ],
74
- constant_fill_strategy = {"volume" : 0 },
75
- dropout_categoricals = ["sku" ],
86
+ ** request .param # fixture parametrization
76
87
)
77
88
78
89
validation = TimeSeriesDataSet .from_dataset (
@@ -85,7 +96,7 @@ def dataloaders_with_coveratiates(data_with_covariates):
85
96
return dict (train = train_dataloader , val = val_dataloader )
86
97
87
98
88
- @pytest .fixture
99
+ @pytest .fixture ()
89
100
def dataloaders_fixed_window_without_coveratiates (data_with_covariates ):
90
101
data = generate_ar_data (seasonality = 10.0 , timesteps = 400 , n_series = 10 )
91
102
data ["static" ] = "2"
0 commit comments