3
3
"""
4
4
from copy import deepcopy
5
5
import inspect
6
+ from pytorch_forecasting .data .encoders import GroupNormalizer
6
7
from torch import unsqueeze
7
8
from torch import optim
8
9
import cloudpickle
11
12
from tqdm .notebook import tqdm
12
13
13
14
from pytorch_forecasting .metrics import SMAPE
14
- from typing import Any , Callable , Dict , List , Tuple , Union
15
+ from typing import Any , Callable , Dict , Iterable , List , Tuple , Union
15
16
from pytorch_lightning import LightningModule
16
17
from pytorch_lightning .metrics .metric import TensorMetric
17
18
from pytorch_forecasting .optim import Ranger
@@ -50,13 +51,6 @@ def forward(self, x):
50
51
encoding_target = x["encoder_target"]
51
52
return dict(prediction=..., target_scale=x["target_scale"])
52
53
53
- # implement lightning steps
54
- def training_step(self, batch, batch_idx):
55
- x, y = batch
56
- return {"loss": self.loss(self(x), y)}
57
-
58
- # implement further steps
59
-
60
54
"""
61
55
62
56
def __init__ (
@@ -516,7 +510,7 @@ def predict(
516
510
batch_size: batch size for dataloader - only used if data is not a dataloader is passed
517
511
num_workers: number of workers for dataloader - only used if data is not a dataloader is passed
518
512
fast_dev_run: if to only return results of first batch
519
- show_progress_bar: if to show progress bar. Defaults to True
513
+ show_progress_bar: if to show progress bar. Defaults to False.
520
514
return_x: if to return network inputs
521
515
522
516
Returns:
@@ -608,6 +602,118 @@ def predict(
608
602
output .append (torch .cat (decode_lenghts , dim = 0 ))
609
603
return output
610
604
605
+ def predict_dependency (
606
+ self ,
607
+ data : Union [DataLoader , pd .DataFrame , TimeSeriesDataSet ],
608
+ variable : str ,
609
+ values : Iterable ,
610
+ mode : str = "dataframe" ,
611
+ target = "decoder" ,
612
+ show_progress_bar : bool = False ,
613
+ ** kwargs ,
614
+ ) -> Union [np .ndarray , torch .Tensor , pd .Series , pd .DataFrame ]:
615
+ """
616
+ Predict partial dependency.
617
+
618
+
619
+ Args:
620
+ data (Union[DataLoader, pd.DataFrame, TimeSeriesDataSet]): data
621
+ variable (str): variable which to modify
622
+ values (Iterable): array of values to probe
623
+ mode (str, optional): Output mode. Defaults to "dataframe". Either
624
+
625
+ * "series": values are average prediction and index are probed values
626
+ * "dataframe": columns are as obtained by the `dataset.get_index()` method,
627
+ prediction (which is the mean prediction over the time horizon),
628
+ normalized_prediction (which are predictions devided by the prediction for the first probed value)
629
+ the variable name for the probed values
630
+ * "raw": outputs a tensor of shape len(values) x prediction_shape
631
+
632
+ target: Defines which values are overwritten for making a prediction.
633
+ Same as in :py:meth:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet.set_overwrite_values`.
634
+ Defaults to "decoder".
635
+ show_progress_bar: if to show progress bar. Defaults to False.
636
+ **kwargs: additional kwargs to :py:meth:`~predict` method
637
+
638
+ Returns:
639
+ Union[np.ndarray, torch.Tensor, pd.Series, pd.DataFrame]: output
640
+ """
641
+ values = np .asarray (values )
642
+ if isinstance (data , pd .DataFrame ): # convert to dataframe
643
+ data = TimeSeriesDataSet .from_parameters (self .dataset_parameters , data , predict = True )
644
+ elif isinstance (data , DataLoader ):
645
+ data = data .dataset
646
+
647
+ results = []
648
+ progress_bar = tqdm (desc = "Predict" , unit = " batches" , total = len (values ), disable = not show_progress_bar )
649
+ for value in values :
650
+ # set values
651
+ data .set_overwrite_values (variable = variable , values = value , target = target )
652
+ # predict
653
+ kwargs .setdefault ("mode" , "prediction" )
654
+ results .append (self .predict (data , ** kwargs ))
655
+ # increment progress
656
+ progress_bar .update ()
657
+
658
+ data .reset_overwrite_values () # reset overwrite values to avoid side-effect
659
+
660
+ # results to one tensor
661
+ results = torch .stack (results , dim = 0 )
662
+
663
+ # convert results to requested output format
664
+ if mode == "series" :
665
+ results = results [:, ~ torch .isnan (results [0 ])].mean (1 ) # average samples and prediction horizon
666
+ results = pd .Series (results , index = values )
667
+
668
+ elif mode == "dataframe" :
669
+ # take mean over time
670
+ is_nan = torch .isnan (results )
671
+ results [is_nan ] = 0
672
+ results = results .sum (- 1 ) / (~ is_nan ).float ().sum (- 1 )
673
+
674
+ # create dataframe
675
+ dependencies = data .get_index ()
676
+ dependencies = (
677
+ dependencies .iloc [np .tile (np .arange (len (dependencies )), len (values ))]
678
+ .reset_index (drop = True )
679
+ .assign (prediction = results .flatten ())
680
+ )
681
+ dependencies [variable ] = values .repeat (len (data ))
682
+ first_prediction = dependencies .groupby (data .group_ids , observed = True ).prediction .transform ("first" )
683
+ dependencies ["normalized_prediction" ] = dependencies ["prediction" ] / first_prediction
684
+ dependencies ["id" ] = dependencies .groupby (data .group_ids , observed = True ).ngroup ()
685
+ results = dependencies
686
+
687
+ elif mode == "raw" :
688
+ pass
689
+
690
+ else :
691
+ raise ValueError (f"mode { mode } is unknown - see documentation for available modes" )
692
+
693
+ return results
694
+
695
+
696
+ class CovariatesMixin :
697
+ """
698
+ Model mix-in for additional methods using covariates.
699
+
700
+ Assumes the following hyperparameters:
701
+
702
+ Args:
703
+ x_reals: order of continuous variables in tensor passed to forward function
704
+ x_categoricals: order of categorical variables in tensor passed to forward function
705
+ embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and
706
+ embedding size
707
+ embedding_labels: dictionary mapping (string) indices to list of categorical labels
708
+ """
709
+
710
+ @property
711
+ def categorical_groups_mapping (self ) -> Dict [str , str ]:
712
+ groups = {}
713
+ for group_name , sublist in self .hparams .categorical_groups .items ():
714
+ groups .update ({name : group_name for name in sublist })
715
+ return groups
716
+
611
717
def calculate_prediction_actual_by_variable (
612
718
self ,
613
719
x : Dict [str , torch .Tensor ],
@@ -621,13 +727,13 @@ def calculate_prediction_actual_by_variable(
621
727
622
728
Args:
623
729
x: input as ``forward()``
624
- y_pred: predictions obtained by ``self.loss.to_prediction (self(x))``
730
+ y_pred: predictions obtained by ``self.transform_output (self(x))``
625
731
normalize: if to return normalized averages, i.e. mean or sum of ``y``
626
732
bins: number of bins to calculate
627
733
std: number of standard deviations for standard scaled continuous variables
628
734
629
735
Returns:
630
- dictionary that can be used to plot averages with `` plot_prediction_actual_by_variable()` `
736
+ dictionary that can be used to plot averages with :py:meth:`~ plot_prediction_actual_by_variable`
631
737
"""
632
738
support = {} # histogram
633
739
# averages
@@ -640,7 +746,10 @@ def calculate_prediction_actual_by_variable(
640
746
# select valid y values
641
747
y_flat = x ["decoder_target" ][mask ]
642
748
y_pred_flat = y_pred [mask ]
643
- if self .loss .log_space :
749
+ log_y = self .dataset_parameters ["target_normalizer" ] is not None and getattr (
750
+ self .dataset_parameters ["target_normalizer" ], "log_scale" , False
751
+ )
752
+ if log_y :
644
753
y_flat = torch .log (y_flat + 1e-8 )
645
754
y_pred_flat = torch .log (y_pred_flat + 1e-8 )
646
755
@@ -675,28 +784,51 @@ def calculate_prediction_actual_by_variable(
675
784
# categorical_variables
676
785
cats = x ["decoder_cat" ]
677
786
for idx , name in enumerate (self .hparams .x_categoricals ): # todo: make it work for grouped categoricals
678
- averages_actual [name ], support [name ] = groupby_apply (
787
+ reduction = "sum"
788
+ name = self .categorical_groups_mapping .get (name , name )
789
+ averages_actual_cat , support_cat = groupby_apply (
679
790
cats [..., idx ][mask ],
680
791
y_flat ,
681
- bins = self .hparams .embedding_sizes [idx ][0 ],
792
+ bins = self .hparams .embedding_sizes [name ][0 ],
682
793
reduction = reduction ,
683
794
return_histogram = True ,
684
795
)
685
- averages_prediction [ name ] , _ = groupby_apply (
796
+ averages_prediction_cat , _ = groupby_apply (
686
797
cats [..., idx ][mask ],
687
798
y_pred_flat ,
688
- bins = self .hparams .embedding_sizes [idx ][0 ],
799
+ bins = self .hparams .embedding_sizes [name ][0 ],
689
800
reduction = reduction ,
690
801
return_histogram = True ,
691
802
)
803
+
804
+ # add either to existing calculations or
805
+ if name in averages_actual :
806
+ averages_actual [name ] += averages_actual_cat
807
+ support [name ] += support_cat
808
+ averages_prediction [name ] += averages_prediction_cat
809
+ else :
810
+ averages_actual [name ] = averages_actual_cat
811
+ support [name ] = support_cat
812
+ averages_prediction [name ] = averages_prediction_cat
813
+
814
+ if normalize : # run reduction for categoricals
815
+ for name in self .hparams .embedding_sizes .keys ():
816
+ averages_actual [name ] /= support [name ].clamp (min = 1 )
817
+ averages_prediction [name ] /= support [name ].clamp (min = 1 )
818
+
819
+ if log_y : # reverse log scaling
820
+ for name in support .keys ():
821
+ averages_actual [name ] = torch .exp (averages_actual [name ])
822
+ averages_prediction [name ] = torch .exp (averages_prediction [name ])
823
+
692
824
return {
693
825
"support" : support ,
694
826
"average" : {"actual" : averages_actual , "prediction" : averages_prediction },
695
827
"std" : std ,
696
828
}
697
829
698
830
def plot_prediction_actual_by_variable (
699
- self , data : Dict [str , Dict [str , torch .Tensor ]], name : str = None
831
+ self , data : Dict [str , Dict [str , torch .Tensor ]], name : str = None , ax = None
700
832
) -> Union [Dict [str , plt .Figure ], plt .Figure ]:
701
833
"""
702
834
Plot predicions and actual averages by variables
@@ -720,23 +852,29 @@ def plot_prediction_actual_by_variable(
720
852
# create figure
721
853
kwargs = {}
722
854
# adjust figure size for figures with many labels
723
- if self .hparams .embedding_sizes [ name ] [0 ] > 10 :
855
+ if self .hparams .embedding_sizes . get ( name , [ 1e9 ]) [0 ] > 10 :
724
856
kwargs = dict (figsize = (10 , 5 ))
725
- fig , ax = plt .subplots (** kwargs )
857
+ if ax is None :
858
+ fig , ax = plt .subplots (** kwargs )
859
+ else :
860
+ fig = ax .get_figure ()
726
861
ax .set_title (f"{ name } averages" )
727
862
ax .set_xlabel (name )
728
- if self .loss .log_space :
729
- ax .set_ylabel ("Log prediction" )
730
- else :
731
- ax .set_ylabel ("Prediction" )
863
+ ax .set_ylabel ("Prediction" )
864
+
732
865
ax2 = ax .twinx () # second axis for histogram
733
866
ax2 .set_ylabel ("Frequency" )
734
867
735
868
# get values for average plot and histogram
736
869
values_actual = data ["average" ]["actual" ][name ].cpu ().numpy ()
737
870
values_prediction = data ["average" ]["prediction" ][name ].cpu ().numpy ()
738
871
bins = values_actual .size
739
- support = data ["average" ][name ].cpu ().numpy ()
872
+ support = data ["support" ][name ].cpu ().numpy ()
873
+
874
+ if self .dataset_parameters ["target_normalizer" ] is not None and getattr (
875
+ self .dataset_parameters ["target_normalizer" ], "log_scale" , False
876
+ ):
877
+ ax .set_yscale ("log" )
740
878
741
879
# only display values where samples were observed
742
880
support_non_zero = support > 0
@@ -746,8 +884,14 @@ def plot_prediction_actual_by_variable(
746
884
747
885
# plot averages
748
886
if name in self .hparams .x_reals :
749
- mean , scale = self .dataset_parameters .scalers [name ].mean , self .dataset_parameters .scalers [name ].scale
750
- x = np .linspace (- data ["std" ], data ["std" ], bins ) * scale + mean
887
+ # create x
888
+ scaler = self .dataset_parameters ["scalers" ][name ]
889
+ x = np .linspace (- data ["std" ], data ["std" ], bins )
890
+ # reversing normalization for group normalizer is not possible without sample level information
891
+ if not isinstance (scaler , GroupNormalizer ):
892
+ x = scaler .inverse_transform (x )
893
+ ax .set_xlabel (f"Normalized { name } " )
894
+
751
895
if len (x ) > 0 :
752
896
x_step = x [1 ] - x [0 ]
753
897
else :
@@ -759,7 +903,7 @@ def plot_prediction_actual_by_variable(
759
903
elif name in self .hparams .embedding_labels :
760
904
# sort values from lowest to highest
761
905
sorting = values_actual .argsort ()
762
- labels = np .asarray (self .hparams .embedding_labels [name ])[support_non_zero ][sorting ]
906
+ labels = np .asarray (list ( self .hparams .embedding_labels [name ]. keys ()) )[support_non_zero ][sorting ]
763
907
values_actual = values_actual [sorting ]
764
908
values_prediction = values_prediction [sorting ]
765
909
support = support [sorting ]
@@ -783,6 +927,8 @@ def plot_prediction_actual_by_variable(
783
927
else :
784
928
raise ValueError (f"Unknown name { name } " )
785
929
# plot support histogram
930
+ if len (support ) > 1 and np .median (support ) < support .max () / 10 :
931
+ ax2 .set_yscale ("log" )
786
932
ax2 .bar (x , support , width = x_step , linewidth = 0 , alpha = 0.2 , color = "k" )
787
933
# adjust layout and legend
788
934
fig .tight_layout ()
0 commit comments