
Commit 6843748

Merge pull request #30 from jdb78/feature/dependency
Add calculation of partial dependencies
2 parents 41c5f3b + aa1ba56

File tree

9 files changed: +765 -133 lines

docs/source/tutorials/stallion.ipynb

Lines changed: 468 additions & 75 deletions
Large diffs are not rendered by default.

pytorch_forecasting/data/timeseries.py

Lines changed: 12 additions & 3 deletions
@@ -691,14 +691,18 @@ def __len__(self) -> int:
         """
         return self.index.shape[0]
 
-    def set_overwrite_values(self, values: Union[float, torch.Tensor], variable: str, target: str = "decoder") -> None:
+    def set_overwrite_values(
+        self, values: Union[float, torch.Tensor], variable: str, target: Union[str, slice] = "decoder"
+    ) -> None:
         """
         Convenience method to quickly overwrite values in decoder or encoder (or both) for a specific variable.
 
         Args:
             values (Union[float, torch.Tensor]): values to use for overwrite.
             variable (str): variable whose values should be overwritten.
-            target (str, optional): positions to overwrite. One of "decoder", "encoder" or "all". Defaults to "decoder".
+            target (Union[str, slice], optional): positions to overwrite. One of "decoder", "encoder" or "all" or
+                a slice object which is directly used to overwrite indices, e.g. ``slice(-5, None)`` will overwrite
+                the last 5 values. Defaults to "decoder".
         """
         values = torch.tensor(self.transform_values(variable, np.asarray(values).reshape(-1), inverse=False)).squeeze()
         assert target in [
@@ -707,6 +711,9 @@ def set_overwrite_values(self, values: Union[float, torch.Tensor], variable: str
             "encoder",
         ], f"target has to be one of 'all', 'decoder' or 'encoder' but target={target} instead"
 
+        if variable in self.static_categoricals or variable in self.static_reals:
+            target = "all"
+
         if variable == self.target:
             raise NotImplementedError("Target variable is not supported")
         if self.weight is not None and self.weight == variable:
@@ -856,7 +863,9 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
 
         # overwrite values
         if self._overwrite_values is not None:
-            if self._overwrite_values["target"] == "all":
+            if isinstance(self._overwrite_values["target"], slice):
+                positions = self._overwrite_values["target"]
+            elif self._overwrite_values["target"] == "all":
                 positions = slice(None)
             elif self._overwrite_values["target"] == "encoder":
                 positions = slice(None, encoder_length)
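
To see the new slice support in action, a minimal usage sketch (assuming an already built TimeSeriesDataSet named dataset and a continuous variable called "discount" — both names are hypothetical):

# overwrite the last 5 positions of "discount" with 0.5, using the new
# slice support instead of the coarser "decoder"/"encoder"/"all" options
dataset.set_overwrite_values(values=0.5, variable="discount", target=slice(-5, None))
x, y = dataset[0]                 # __getitem__ applies the overwrite to the sample
dataset.reset_overwrite_values()  # remove the overwrite again to avoid side-effects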

pytorch_forecasting/metrics.py

Lines changed: 5 additions & 2 deletions
@@ -137,16 +137,19 @@ def forward(self, y_pred: Dict[str, torch.Tensor], target: Union[torch.Tensor, r
         mask = torch.arange(target.size(1), device=target.device).unsqueeze(0) >= lengths.unsqueeze(-1)
         if losses.ndim > 2:
             mask = mask.unsqueeze(-1)
+            dim_normalizer = losses.size(-1)
+        else:
+            dim_normalizer = 1.0
         # reduce to one number
         if self.reduction == "none":
             loss = losses.masked_fill(mask, float("nan"))
         else:
             if self.reduction == "mean":
                 losses = losses.masked_fill(mask, 0.0)
-                loss = losses.sum() / lengths.sum()
+                loss = losses.sum() / lengths.sum() / dim_normalizer
             elif self.reduction == "sqrt-mean":
                 losses = losses.masked_fill(mask, 0.0)
-                loss = losses.sum() / lengths.sum()
+                loss = losses.sum() / lengths.sum() / dim_normalizer
                 loss = loss.sqrt()
         assert not torch.isnan(loss), (
             "Loss should not be nan - i.e. something went wrong "

pytorch_forecasting/models/base_model.py

Lines changed: 173 additions & 27 deletions
@@ -3,6 +3,7 @@
 """
 from copy import deepcopy
 import inspect
+from pytorch_forecasting.data.encoders import GroupNormalizer
 from torch import unsqueeze
 from torch import optim
 import cloudpickle
@@ -11,7 +12,7 @@
 from tqdm.notebook import tqdm
 
 from pytorch_forecasting.metrics import SMAPE
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
 from pytorch_lightning import LightningModule
 from pytorch_lightning.metrics.metric import TensorMetric
 from pytorch_forecasting.optim import Ranger
@@ -50,13 +51,6 @@ def forward(self, x):
             encoding_target = x["encoder_target"]
             return dict(prediction=..., target_scale=x["target_scale"])
 
-        # implement lightning steps
-        def training_step(self, batch, batch_idx):
-            x, y = batch
-            return {"loss": self.loss(self(x), y)}
-
-        # implement further steps
-
     """
 
     def __init__(
@@ -516,7 +510,7 @@ def predict(
             batch_size: batch size for dataloader - only used if data is not a dataloader is passed
             num_workers: number of workers for dataloader - only used if data is not a dataloader is passed
             fast_dev_run: if to only return results of first batch
-            show_progress_bar: if to show progress bar. Defaults to True
+            show_progress_bar: if to show progress bar. Defaults to False.
             return_x: if to return network inputs
 
         Returns:
@@ -608,6 +602,118 @@
         output.append(torch.cat(decode_lenghts, dim=0))
         return output
 
+    def predict_dependency(
+        self,
+        data: Union[DataLoader, pd.DataFrame, TimeSeriesDataSet],
+        variable: str,
+        values: Iterable,
+        mode: str = "dataframe",
+        target="decoder",
+        show_progress_bar: bool = False,
+        **kwargs,
+    ) -> Union[np.ndarray, torch.Tensor, pd.Series, pd.DataFrame]:
+        """
+        Predict partial dependency.
+
+        Args:
+            data (Union[DataLoader, pd.DataFrame, TimeSeriesDataSet]): data
+            variable (str): variable to modify
+            values (Iterable): array of values to probe
+            mode (str, optional): Output mode. Defaults to "dataframe". Either
+
+                * "series": values are average prediction and index are probed values
+                * "dataframe": columns are as obtained by the `dataset.get_index()` method,
+                  prediction (which is the mean prediction over the time horizon),
+                  normalized_prediction (which are predictions divided by the prediction for the first
+                  probed value) and the variable name for the probed values
+                * "raw": outputs a tensor of shape len(values) x prediction_shape
+
+            target: Defines which values are overwritten for making a prediction.
+                Same as in :py:meth:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet.set_overwrite_values`.
+                Defaults to "decoder".
+            show_progress_bar: if to show progress bar. Defaults to False.
+            **kwargs: additional kwargs to :py:meth:`~predict` method
+
+        Returns:
+            Union[np.ndarray, torch.Tensor, pd.Series, pd.DataFrame]: output
+        """
+        values = np.asarray(values)
+        if isinstance(data, pd.DataFrame):  # convert to dataset
+            data = TimeSeriesDataSet.from_parameters(self.dataset_parameters, data, predict=True)
+        elif isinstance(data, DataLoader):
+            data = data.dataset
+
+        results = []
+        progress_bar = tqdm(desc="Predict", unit=" batches", total=len(values), disable=not show_progress_bar)
+        for value in values:
+            # set values
+            data.set_overwrite_values(variable=variable, values=value, target=target)
+            # predict
+            kwargs.setdefault("mode", "prediction")
+            results.append(self.predict(data, **kwargs))
+            # increment progress
+            progress_bar.update()
+
+        data.reset_overwrite_values()  # reset overwrite values to avoid side-effect
+
+        # results to one tensor
+        results = torch.stack(results, dim=0)
+
+        # convert results to requested output format
+        if mode == "series":
+            results = results[:, ~torch.isnan(results[0])].mean(1)  # average samples and prediction horizon
+            results = pd.Series(results, index=values)
+
+        elif mode == "dataframe":
+            # take mean over time
+            is_nan = torch.isnan(results)
+            results[is_nan] = 0
+            results = results.sum(-1) / (~is_nan).float().sum(-1)
+
+            # create dataframe
+            dependencies = data.get_index()
+            dependencies = (
+                dependencies.iloc[np.tile(np.arange(len(dependencies)), len(values))]
+                .reset_index(drop=True)
+                .assign(prediction=results.flatten())
+            )
+            dependencies[variable] = values.repeat(len(data))
+            first_prediction = dependencies.groupby(data.group_ids, observed=True).prediction.transform("first")
+            dependencies["normalized_prediction"] = dependencies["prediction"] / first_prediction
+            dependencies["id"] = dependencies.groupby(data.group_ids, observed=True).ngroup()
+            results = dependencies
+
+        elif mode == "raw":
+            pass
+
+        else:
+            raise ValueError(f"mode {mode} is unknown - see documentation for available modes")
+
+        return results
+
+
+class CovariatesMixin:
+    """
+    Model mix-in for additional methods using covariates.
+
+    Assumes the following hyperparameters:
+
+    Args:
+        x_reals: order of continuous variables in tensor passed to forward function
+        x_categoricals: order of categorical variables in tensor passed to forward function
+        embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and
+            embedding size
+        embedding_labels: dictionary mapping (string) indices to list of categorical labels
+    """
+
+    @property
+    def categorical_groups_mapping(self) -> Dict[str, str]:
+        groups = {}
+        for group_name, sublist in self.hparams.categorical_groups.items():
+            groups.update({name: group_name for name in sublist})
+        return groups
+
     def calculate_prediction_actual_by_variable(
         self,
         x: Dict[str, torch.Tensor],
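
A short usage sketch of the new method (assuming a trained model best_model, a validation dataloader val_dataloader, and a variable named "discount" — all names are hypothetical):

import numpy as np

dependency = best_model.predict_dependency(
    val_dataloader, variable="discount", values=np.linspace(0.0, 0.5, 10), mode="dataframe"
)
# average the normalized predictions over all series for each probed value
partial_dependency = dependency.groupby("discount").normalized_prediction.mean()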
@@ -621,13 +727,13 @@ def calculate_prediction_actual_by_variable(
 
         Args:
             x: input as ``forward()``
-            y_pred: predictions obtained by ``self.loss.to_prediction(self(x))``
+            y_pred: predictions obtained by ``self.transform_output(self(x))``
             normalize: if to return normalized averages, i.e. mean or sum of ``y``
             bins: number of bins to calculate
             std: number of standard deviations for standard scaled continuous variables
 
         Returns:
-            dictionary that can be used to plot averages with ``plot_prediction_actual_by_variable()``
+            dictionary that can be used to plot averages with :py:meth:`~plot_prediction_actual_by_variable`
         """
         support = {}  # histogram
         # averages
@@ -640,7 +746,10 @@ def calculate_prediction_actual_by_variable(
         # select valid y values
         y_flat = x["decoder_target"][mask]
         y_pred_flat = y_pred[mask]
-        if self.loss.log_space:
+        log_y = self.dataset_parameters["target_normalizer"] is not None and getattr(
+            self.dataset_parameters["target_normalizer"], "log_scale", False
+        )
+        if log_y:
             y_flat = torch.log(y_flat + 1e-8)
             y_pred_flat = torch.log(y_pred_flat + 1e-8)
 
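The log check now derives from the dataset's target normalizer instead of the loss. A sketch of how it resolves (a hypothetical setup; the log_scale argument is assumed to exist on GroupNormalizer in this release):

from pytorch_forecasting.data.encoders import GroupNormalizer

normalizer = GroupNormalizer(groups=["agency"], log_scale=True)
log_y = normalizer is not None and getattr(normalizer, "log_scale", False)  # -> True
# a normalizer without a log_scale attribute simply yields False via getattr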

@@ -675,28 +784,51 @@
         # categorical_variables
         cats = x["decoder_cat"]
         for idx, name in enumerate(self.hparams.x_categoricals):  # todo: make it work for grouped categoricals
-            averages_actual[name], support[name] = groupby_apply(
+            reduction = "sum"
+            name = self.categorical_groups_mapping.get(name, name)
+            averages_actual_cat, support_cat = groupby_apply(
                 cats[..., idx][mask],
                 y_flat,
-                bins=self.hparams.embedding_sizes[idx][0],
+                bins=self.hparams.embedding_sizes[name][0],
                 reduction=reduction,
                 return_histogram=True,
             )
-            averages_prediction[name], _ = groupby_apply(
+            averages_prediction_cat, _ = groupby_apply(
                 cats[..., idx][mask],
                 y_pred_flat,
-                bins=self.hparams.embedding_sizes[idx][0],
+                bins=self.hparams.embedding_sizes[name][0],
                 reduction=reduction,
                 return_histogram=True,
             )
+
+            # add either to existing calculations or create new ones
+            if name in averages_actual:
+                averages_actual[name] += averages_actual_cat
+                support[name] += support_cat
+                averages_prediction[name] += averages_prediction_cat
+            else:
+                averages_actual[name] = averages_actual_cat
+                support[name] = support_cat
+                averages_prediction[name] = averages_prediction_cat
+
+        if normalize:  # run reduction for categoricals
+            for name in self.hparams.embedding_sizes.keys():
+                averages_actual[name] /= support[name].clamp(min=1)
+                averages_prediction[name] /= support[name].clamp(min=1)
+
+        if log_y:  # reverse log scaling
+            for name in support.keys():
+                averages_actual[name] = torch.exp(averages_actual[name])
+                averages_prediction[name] = torch.exp(averages_prediction[name])
+
         return {
             "support": support,
             "average": {"actual": averages_actual, "prediction": averages_prediction},
             "std": std,
         }
 
     def plot_prediction_actual_by_variable(
-        self, data: Dict[str, Dict[str, torch.Tensor]], name: str = None
+        self, data: Dict[str, Dict[str, torch.Tensor]], name: str = None, ax=None
     ) -> Union[Dict[str, plt.Figure], plt.Figure]:
         """
         Plot predictions and actual averages by variables
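
The accumulation logic in this hunk is what makes grouped categoricals work: every member of a group maps to the same name via categorical_groups_mapping, so their sums and histograms are added together. A toy sketch (the group definition is hypothetical):

# e.g. hparams.categorical_groups = {"special_days": ["easter", "christmas"]}
categorical_groups = {"special_days": ["easter", "christmas"]}
mapping = {name: group for group, members in categorical_groups.items() for name in members}
mapping.get("easter", "easter")      # -> "special_days": summed into the group entry
mapping.get("discount", "discount")  # -> "discount": ungrouped names pass through unchanged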
@@ -720,23 +852,29 @@ def plot_prediction_actual_by_variable(
         # create figure
         kwargs = {}
         # adjust figure size for figures with many labels
-        if self.hparams.embedding_sizes[name][0] > 10:
+        if self.hparams.embedding_sizes.get(name, [1e9])[0] > 10:
             kwargs = dict(figsize=(10, 5))
-        fig, ax = plt.subplots(**kwargs)
+        if ax is None:
+            fig, ax = plt.subplots(**kwargs)
+        else:
+            fig = ax.get_figure()
         ax.set_title(f"{name} averages")
         ax.set_xlabel(name)
-        if self.loss.log_space:
-            ax.set_ylabel("Log prediction")
-        else:
-            ax.set_ylabel("Prediction")
+        ax.set_ylabel("Prediction")
+
         ax2 = ax.twinx()  # second axis for histogram
         ax2.set_ylabel("Frequency")
 
         # get values for average plot and histogram
         values_actual = data["average"]["actual"][name].cpu().numpy()
         values_prediction = data["average"]["prediction"][name].cpu().numpy()
         bins = values_actual.size
-        support = data["average"][name].cpu().numpy()
+        support = data["support"][name].cpu().numpy()
+
+        if self.dataset_parameters["target_normalizer"] is not None and getattr(
+            self.dataset_parameters["target_normalizer"], "log_scale", False
+        ):
+            ax.set_yscale("log")
 
         # only display values where samples were observed
         support_non_zero = support > 0
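
The new ax argument makes it possible to compose several variable plots into one figure. A minimal sketch (assuming data was produced by calculate_prediction_actual_by_variable and that "discount" and "agency" exist in the dataset — both names are hypothetical):

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(15, 5))
model.plot_prediction_actual_by_variable(data, name="discount", ax=axes[0])
model.plot_prediction_actual_by_variable(data, name="agency", ax=axes[1])
fig.tight_layout()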
@@ -746,8 +884,14 @@ def plot_prediction_actual_by_variable(
 
         # plot averages
         if name in self.hparams.x_reals:
-            mean, scale = self.dataset_parameters.scalers[name].mean, self.dataset_parameters.scalers[name].scale
-            x = np.linspace(-data["std"], data["std"], bins) * scale + mean
+            # create x
+            scaler = self.dataset_parameters["scalers"][name]
+            x = np.linspace(-data["std"], data["std"], bins)
+            # reversing normalization for group normalizer is not possible without sample level information
+            if not isinstance(scaler, GroupNormalizer):
+                x = scaler.inverse_transform(x)
+                ax.set_xlabel(f"Normalized {name}")
+
             if len(x) > 0:
                 x_step = x[1] - x[0]
             else:
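
For a standard scaler, inverse_transform undoes exactly what the removed manual code did. A self-contained sketch (the fitted scaler and values are made up, assuming an sklearn-style StandardScaler):

import numpy as np
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(np.array([[1.0], [2.0], [3.0]]))  # toy fit
x = np.linspace(-2.0, 2.0, 10)              # bin centers in normalized units
x_orig = x * scaler.scale_ + scaler.mean_   # equivalent to the old manual x * scale + mean
# for a GroupNormalizer this mapping depends on each sample's group,
# which is why the plot keeps normalized units in that case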
@@ -759,7 +903,7 @@ def plot_prediction_actual_by_variable(
         elif name in self.hparams.embedding_labels:
             # sort values from lowest to highest
             sorting = values_actual.argsort()
-            labels = np.asarray(self.hparams.embedding_labels[name])[support_non_zero][sorting]
+            labels = np.asarray(list(self.hparams.embedding_labels[name].keys()))[support_non_zero][sorting]
             values_actual = values_actual[sorting]
             values_prediction = values_prediction[sorting]
             support = support[sorting]
@@ -783,6 +927,8 @@ def plot_prediction_actual_by_variable(
         else:
             raise ValueError(f"Unknown name {name}")
         # plot support histogram
+        if len(support) > 1 and np.median(support) < support.max() / 10:
+            ax2.set_yscale("log")
         ax2.bar(x, support, width=x_step, linewidth=0, alpha=0.2, color="k")
         # adjust layout and legend
         fig.tight_layout()
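
The new histogram heuristic switches the frequency axis to log scale when a few bins dominate the support. A sketch of the condition with made-up counts:

import numpy as np

support = np.array([1, 2, 3, 500])       # one dominant bin
np.median(support) < support.max() / 10  # 2.5 < 50.0 -> use log scale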
