Skip to content

Commit 001b36d

Browse files
authored
Fix natten x y dims (#128)
1 parent 41a96b6 commit 001b36d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

attn_gym/masks/natten.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def get_x_y_tiled(idx: IntTensor) -> Tuple[IntTensor, IntTensor]:
6666
Map 1-D index to 2-D coordinates for static tiles of T_H x T_W.
6767
"""
6868
t_id = idx // (T_H * T_W)
69-
t_x, t_y = t_id // (W // T_W), t_id % (W // T_W)
69+
t_y, t_x = t_id // (W // T_W), t_id % (W // T_W)
7070
t_offset = idx % (T_H * T_W)
71-
i_x, i_y = t_offset // T_W, t_offset % T_W
71+
i_y, i_x = t_offset // T_W, t_offset % T_W
7272
return t_x * T_W + i_x, t_y * T_H + i_y
7373

7474
def tiled_natten_mask(

0 commit comments

Comments
 (0)