Skip to content

Commit d91d272

Browse files
committed
fix: 🐛 a bug in MTGNN runner (issue GestaltCogTeam#220)
1 parent 59a0f95 commit d91d272

File tree

1 file changed

+3
-29
lines changed

1 file changed

+3
-29
lines changed

baselines/MTGNN/runner/mtgnn_runner.py

+3-29
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,9 @@ def __init__(self, cfg: dict):
1717
self.num_split = cfg.TRAIN.CUSTOM.NUM_SPLIT
1818
self.perm = None
1919

20-
def select_input_features(self, data: torch.Tensor) -> torch.Tensor:
21-
"""Select input features.
22-
23-
Args:
24-
data (torch.Tensor): input history data, shape [B, L, N, C]
25-
26-
Returns:
27-
torch.Tensor: reshaped data
28-
"""
29-
30-
# select feature using self.forward_features
31-
if self.forward_features is not None:
32-
data = data[:, :, :, self.forward_features]
33-
return data
34-
35-
def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
36-
"""Select target feature
37-
38-
Args:
39-
data (torch.Tensor): prediction of the model with arbitrary shape.
40-
41-
Returns:
42-
torch.Tensor: reshaped data with shape [B, L, N, C]
43-
"""
44-
45-
# select feature using self.target_features
46-
data = data[:, :, :, self.target_features]
47-
return data
48-
4920
def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: bool = True, **kwargs) -> tuple:
21+
data = self.preprocessing(data)
22+
5023
if train:
5124
future_data, history_data, idx = data['target'], data['inputs'], data['idx']
5225
else:
@@ -68,6 +41,7 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b
6841
model_return["target"] = self.select_target_features(future_data)
6942
assert list(model_return["prediction"].shape)[:3] == [batch_size, seq_len, num_nodes], \
7043
"error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
44+
model_return = self.postprocessing(model_return)
7145
return model_return
7246

7347
def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:

0 commit comments

Comments
 (0)