Skip to content

Commit b9758be

Browse files
authored
Merge pull request #5 from jdb78/test/models
Add more tests - testing multiple normalizers and settings for temporal fusion transformer
2 parents e909fc4 + 9164982 commit b9758be

File tree

4 files changed

+68
-40
lines changed

4 files changed

+68
-40
lines changed

.github/workflows/code_quality.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
# Enable linters
3131
black: true
3232
flake8: true
33-
mypy: true
33+
# mypy: true

pytorch_forecasting/data.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,22 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D
11911191
Args:
11921192
train (bool, optional): if dataloader is used for training or prediction
11931193
Will shuffle and drop last batch if True. Defaults to True.
1194+
batch_size (int): batch size for training model. Defaults to 64.
1195+
**kwargs: additional arguments to ``DataLoader()``
1196+
1197+
1198+
Examples:
1199+
1200+
To sample for training:
1201+
1202+
.. code-block:: python
1203+
1204+
from torch.utils.data import WeightedRandomSampler
1205+
1206+
# length of probabilities for sampler has to be equal to the length of the index
1207+
probabilities = np.sqrt(1 + data.loc[dataset.index, "target"])
1208+
sampler = WeightedRandomSampler(probabilities, len(probabilities))
1209+
dataset.to_dataloader(train=True, sampler=sampler, shuffle=False)
11941210
11951211
Returns:
11961212
DataLoader: dataloader that returns Tuple.
@@ -1208,15 +1224,16 @@ def to_dataloader(self, train: bool = True, batch_size: int = 64, **kwargs) -> D
12081224
Second entry is target
12091225
)
12101226
"""
1211-
return DataLoader(
1212-
self,
1227+
default_kwargs = dict(
12131228
shuffle=train,
12141229
drop_last=train and len(self) > batch_size,
12151230
collate_fn=self._collate_fn,
12161231
batch_size=batch_size,
1217-
**kwargs,
12181232
)
12191233

1234+
default_kwargs.update(kwargs)
1235+
return DataLoader(self, **default_kwargs,)
1236+
12201237
def get_index(self) -> pd.DataFrame:
12211238
"""
12221239
Data index / order in which items are returned in train=False mode by dataloader.

tests/test_models/conftest.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from data import get_stallion_data, generate_ar_data
44
from pytorch_forecasting import TimeSeriesDataSet
5-
from pytorch_forecasting.data import NaNLabelEncoder, EncoderNormalizer
5+
from pytorch_forecasting.data import GroupNormalizer, NaNLabelEncoder, EncoderNormalizer
66

77

88
@pytest.fixture
@@ -35,8 +35,42 @@ def data_with_covariates():
3535
return data
3636

3737

38-
@pytest.fixture
39-
def dataloaders_with_coveratiates(data_with_covariates):
38+
@pytest.fixture(
39+
params=[
40+
dict(),
41+
dict(
42+
static_categoricals=["agency", "sku"],
43+
static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
44+
time_varying_known_categoricals=["special_days", "month"],
45+
variable_groups=dict(
46+
special_days=[
47+
"easter_day",
48+
"good_friday",
49+
"new_year",
50+
"christmas",
51+
"labor_day",
52+
"independence_day",
53+
"revolution_day_memorial",
54+
"regional_games",
55+
"fifa_u_17_world_cup",
56+
"football_gold_cup",
57+
"beer_capital",
58+
"music_fest",
59+
]
60+
),
61+
time_varying_known_reals=["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"],
62+
time_varying_unknown_categoricals=[],
63+
time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"],
64+
constant_fill_strategy={"volume": 0},
65+
dropout_categoricals=["sku"],
66+
),
67+
dict(static_categoricals=["agency", "sku"]),
68+
dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2),
69+
dict(target_normalizer=GroupNormalizer(log_scale=True)),
70+
dict(target_normalizer=GroupNormalizer(groups=["agency", "sku"], coerce_positive=1.0)),
71+
]
72+
)
73+
def dataloaders_with_coveratiates(data_with_covariates, request):
4074
training_cutoff = "2016-09-01"
4175
max_encoder_length = 36
4276
max_prediction_length = 6
@@ -49,30 +83,7 @@ def dataloaders_with_coveratiates(data_with_covariates):
4983
group_ids=["agency", "sku"],
5084
max_encoder_length=max_encoder_length,
5185
max_prediction_length=max_prediction_length,
52-
static_categoricals=["agency", "sku"],
53-
static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
54-
time_varying_known_categoricals=["special_days", "month"],
55-
variable_groups=dict(
56-
special_days=[
57-
"easter_day",
58-
"good_friday",
59-
"new_year",
60-
"christmas",
61-
"labor_day",
62-
"independence_day",
63-
"revolution_day_memorial",
64-
"regional_games",
65-
"fifa_u_17_world_cup",
66-
"football_gold_cup",
67-
"beer_capital",
68-
"music_fest",
69-
]
70-
),
71-
time_varying_known_reals=["time_idx", "price_regular", "price_actual", "discount", "discount_in_percent"],
72-
time_varying_unknown_categoricals=[],
73-
time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume", "avg_max_temp"],
74-
constant_fill_strategy={"volume": 0},
75-
dropout_categoricals=["sku"],
86+
**request.param # fixture parametrization
7687
)
7788

7889
validation = TimeSeriesDataSet.from_dataset(
@@ -85,7 +96,7 @@ def dataloaders_with_coveratiates(data_with_covariates):
8596
return dict(train=train_dataloader, val=val_dataloader)
8697

8798

88-
@pytest.fixture
99+
@pytest.fixture()
89100
def dataloaders_fixed_window_without_coveratiates(data_with_covariates):
90101
data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=10)
91102
data["static"] = "2"

tests/test_models/test_temporal_fusion_transformer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from pytorch_forecasting.data import TimeSeriesDataSet
2+
import pytest
13
import shutil
24
import pytorch_lightning as pl
35
from pytorch_lightning.loggers import TensorBoardLogger
@@ -7,8 +9,6 @@
79

810

911
# todo: run with multiple normalizers
10-
# todo: run with muliple datasets and normalizers: ...
11-
# todo: monotonicity
1212
# todo: test different parameters
1313
def test_integration(dataloaders_with_coveratiates, tmp_path):
1414
train_dataloader = dataloaders_with_coveratiates["train"]
@@ -28,7 +28,11 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):
2828
fast_dev_run=True,
2929
logger=logger,
3030
)
31-
31+
# test monotone constraints automatically
32+
if "discount_in_percent" in dataloaders_with_coveratiates["train"].dataset.reals:
33+
monotone_constaints = {"discount_in_percent": +1}
34+
else:
35+
monotone_constaints = {}
3236
net = TemporalFusionTransformer.from_dataset(
3337
train_dataloader.dataset,
3438
learning_rate=0.15,
@@ -40,7 +44,7 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):
4044
log_interval=5,
4145
log_val_interval=1,
4246
log_gradient_flow=True,
43-
monotone_constaints={"discount_in_percent": +1},
47+
monotone_constaints=monotone_constaints,
4448
)
4549
net.size()
4650
try:
@@ -56,7 +60,3 @@ def test_integration(dataloaders_with_coveratiates, tmp_path):
5660
net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True)
5761
finally:
5862
shutil.rmtree(tmp_path, ignore_errors=True)
59-
60-
61-
def test_monotinicity():
62-
pass

0 commit comments

Comments
 (0)