
Commit 0318b3e

Remove 5D tensor reshape in attention layer implementation. (#57)
* Remove 5D tensor reshape in attention layer implementation.
* formatting
1 parent 5b31b82 commit 0318b3e

File tree: 2 files changed, +15 -11 lines

ai_edge_torch/generative/layers/attention.py

Lines changed: 14 additions & 11 deletions
@@ -189,20 +189,23 @@ def forward(
 
     # Assemble into a number of query groups to support MHA, MQA and GQA.
     q_per_kv = self.config.num_heads // self.config.num_query_groups
-    total_qkv = q_per_kv + 2  # Each group has >=1 queries, 1 key, and 1 value.
+    # Each group has >=1 queries, 1 key, and 1 value.
     if self.config.qkv_transpose_before_split:
-      qkv = qkv.view(
-          B, T, total_qkv, self.config.num_query_groups, self.head_dim
-      )  # (B, T, total_qkv, num_query_groups, head_dim)
-      qkv_axis = -3
+      qkv = qkv.view(B, T, -1, self.head_dim)
+      q, k, v = qkv.split(
+          (
+              q_per_kv * self.config.num_query_groups,
+              self.config.num_query_groups,
+              self.config.num_query_groups,
+          ),
+          dim=-2,
+      )
     else:
-      qkv = qkv.view(
-          B, T, self.config.num_query_groups, total_qkv, self.head_dim
-      )  # (B, T, num_query_groups, total_qkv, head_dim)
-      qkv_axis = -2
+      qkv = qkv.view(B, T, self.config.num_query_groups, -1)
+      q, k, v = qkv.split(
+          (q_per_kv * self.head_dim, self.head_dim, self.head_dim), dim=-1
+      )
 
-    # Split batched computation into three.
-    q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis)
     q = q.reshape(B, T, -1, self.head_dim)
     k = k.reshape(B, T, -1, self.head_dim)
     v = v.reshape(B, T, -1, self.head_dim)
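
A quick way to see that the 4D view-and-split is equivalent to the removed 5D reshape is to run both paths on a dummy fused QKV tensor. The snippet below is a minimal standalone sketch, not part of the commit: the sizes (B, T, num_heads, num_query_groups, head_dim) are made-up example values, and only the `qkv_transpose_before_split=True` branch is compared.

import torch

# Hypothetical example sizes, not taken from the repository.
B, T = 2, 8
num_heads, num_query_groups = 8, 2
head_dim = 16
q_per_kv = num_heads // num_query_groups
total_qkv = q_per_kv + 2  # queries per group + 1 key + 1 value

qkv = torch.randn(B, T, total_qkv * num_query_groups * head_dim)

# Old path: reshape to 5D, then split along the total_qkv axis.
old = qkv.view(B, T, total_qkv, num_query_groups, head_dim)
q_old, k_old, v_old = old.split((q_per_kv, 1, 1), dim=-3)

# New path: keep the tensor 4D and split along the fused head axis.
new = qkv.view(B, T, -1, head_dim)
q_new, k_new, v_new = new.split(
    (q_per_kv * num_query_groups, num_query_groups, num_query_groups), dim=-2
)

# Both paths flatten to identical (B, T, -1, head_dim) tensors.
for a, b in ((q_old, q_new), (k_old, k_new), (v_old, v_new)):
  assert torch.equal(
      a.reshape(B, T, -1, head_dim), b.reshape(B, T, -1, head_dim)
  )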

ai_edge_torch/generative/layers/unet/blocks_2d.py

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
     x = input_tensor.view(B, C, H * W)
     x = x.transpose(-1, -2)
     x = self.norm(x)
+    x = x.contiguous()  # Prevent BATCH_MATMUL op in converted tflite.
     x = self.attention(x)
     x = x.transpose(-1, -2)
     x = x.view(B, C, H, W)
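
For context, `transpose` only swaps strides and does not move data, so the tensor handed to the attention layer is non-contiguous until `contiguous()` materializes a packed copy. The sketch below is not from the commit and uses made-up sizes; the claim that a contiguous input avoids a BATCH_MATMUL op in the converted tflite model comes from the in-line comment above, not from anything verified here.

import torch

x = torch.randn(2, 4, 8 * 8)  # (B, C, H*W), made-up sizes
x = x.transpose(-1, -2)       # stride-swapped view of the same storage
print(x.is_contiguous())      # False

x = x.contiguous()            # materialize a densely packed copy
print(x.is_contiguous())      # True; downstream ops now see a plain layout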
