@@ -17,36 +17,9 @@ def __init__(self, cfg: dict):
17
17
self .num_split = cfg .TRAIN .CUSTOM .NUM_SPLIT
18
18
self .perm = None
19
19
20
- def select_input_features (self , data : torch .Tensor ) -> torch .Tensor :
21
- """Select input features.
22
-
23
- Args:
24
- data (torch.Tensor): input history data, shape [B, L, N, C]
25
-
26
- Returns:
27
- torch.Tensor: reshaped data
28
- """
29
-
30
- # select feature using self.forward_features
31
- if self .forward_features is not None :
32
- data = data [:, :, :, self .forward_features ]
33
- return data
34
-
35
- def select_target_features (self , data : torch .Tensor ) -> torch .Tensor :
36
- """Select target feature
37
-
38
- Args:
39
- data (torch.Tensor): prediction of the model with arbitrary shape.
40
-
41
- Returns:
42
- torch.Tensor: reshaped data with shape [B, L, N, C]
43
- """
44
-
45
- # select feature using self.target_features
46
- data = data [:, :, :, self .target_features ]
47
- return data
48
-
49
20
def forward (self , data : tuple , epoch : int = None , iter_num : int = None , train : bool = True , ** kwargs ) -> tuple :
21
+ data = self .preprocessing (data )
22
+
50
23
if train :
51
24
future_data , history_data , idx = data ['target' ], data ['inputs' ], data ['idx' ]
52
25
else :
@@ -68,6 +41,7 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b
68
41
model_return ["target" ] = self .select_target_features (future_data )
69
42
assert list (model_return ["prediction" ].shape )[:3 ] == [batch_size , seq_len , num_nodes ], \
70
43
"error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
44
+ model_return = self .postprocessing (model_return )
71
45
return model_return
72
46
73
47
def train_iters (self , epoch : int , iter_index : int , data : Union [torch .Tensor , Tuple ]) -> torch .Tensor :
0 commit comments