Skip to content

Commit 4b5f0ad

Browse files
committed
fix: 🐛 a bug in STAEformer dow embeddings (isse GestaltCogTeam#219)
1 parent d91d272 commit 4b5f0ad

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

baselines/STAEformer/arch/staeformer_arch.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,16 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_s
201201
batch_size = x.shape[0]
202202

203203
if self.tod_embedding_dim > 0:
204-
tod = x[..., 1]
204+
tod = x[..., 1] * self.steps_per_day
205205
if self.dow_embedding_dim > 0:
206-
dow = x[..., 2]
206+
dow = x[..., 2] * 7
207207
x = x[..., : self.input_dim]
208208

209209
x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim)
210210
features = [x]
211211
if self.tod_embedding_dim > 0:
212212
tod_emb = self.tod_embedding(
213-
(tod * self.steps_per_day).long()
213+
tod.long()
214214
) # (batch_size, in_steps, num_nodes, tod_embedding_dim)
215215
features.append(tod_emb)
216216
if self.dow_embedding_dim > 0:

0 commit comments

Comments
 (0)