Skip to content

Commit a99ae89

Browse files
committed
fix: 🐛 bugs in STAEformer and MTGNN (issues #219 and #220)
1 parent 87be63b commit a99ae89

File tree

2 files changed

+6
-32
lines changed

2 files changed

+6
-32
lines changed

baselines/MTGNN/runner/mtgnn_runner.py

+3-29
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,9 @@ def __init__(self, cfg: dict):
1717
self.num_split = cfg.TRAIN.CUSTOM.NUM_SPLIT
1818
self.perm = None
1919

20-
def select_input_features(self, data: torch.Tensor) -> torch.Tensor:
21-
"""Select input features.
22-
23-
Args:
24-
data (torch.Tensor): input history data, shape [B, L, N, C]
25-
26-
Returns:
27-
torch.Tensor: reshaped data
28-
"""
29-
30-
# select feature using self.forward_features
31-
if self.forward_features is not None:
32-
data = data[:, :, :, self.forward_features]
33-
return data
34-
35-
def select_target_features(self, data: torch.Tensor) -> torch.Tensor:
36-
"""Select target feature
37-
38-
Args:
39-
data (torch.Tensor): prediction of the model with arbitrary shape.
40-
41-
Returns:
42-
torch.Tensor: reshaped data with shape [B, L, N, C]
43-
"""
44-
45-
# select feature using self.target_features
46-
data = data[:, :, :, self.target_features]
47-
return data
48-
4920
def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: bool = True, **kwargs) -> tuple:
21+
data = self.preprocessing(data)
22+
5023
if train:
5124
future_data, history_data, idx = data['target'], data['inputs'], data['idx']
5225
else:
@@ -68,6 +41,7 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b
6841
model_return["target"] = self.select_target_features(future_data)
6942
assert list(model_return["prediction"].shape)[:3] == [batch_size, seq_len, num_nodes], \
7043
"error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
44+
model_return = self.postprocessing(model_return)
7145
return model_return
7246

7347
def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:

baselines/STAEformer/arch/staeformer_arch.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,16 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_s
201201
batch_size = x.shape[0]
202202

203203
if self.tod_embedding_dim > 0:
204-
tod = x[..., 1]
204+
tod = x[..., 1] * self.steps_per_day
205205
if self.dow_embedding_dim > 0:
206-
dow = x[..., 2]
206+
dow = x[..., 2] * 7
207207
x = x[..., : self.input_dim]
208208

209209
x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim)
210210
features = [x]
211211
if self.tod_embedding_dim > 0:
212212
tod_emb = self.tod_embedding(
213-
(tod * self.steps_per_day).long()
213+
tod.long()
214214
) # (batch_size, in_steps, num_nodes, tod_embedding_dim)
215215
features.append(tod_emb)
216216
if self.dow_embedding_dim > 0:

0 commit comments

Comments
 (0)