
Commit 0769411

casaro authored and Flax Authors committed
Add QK Norm.
PiperOrigin-RevId: 733610890
1 parent 50b996d commit 0769411

2 files changed: 13 additions & 0 deletions

examples/gemma/modules.py

Lines changed: 11 additions & 0 deletions
@@ -85,6 +85,7 @@ def __init__(
       rngs: nnx.Rngs,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
+      use_qk_norm: bool = False,
       sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
   ):
     if attn_type == AttentionType.LOCAL_SLIDING and sliding_window_size is None:
@@ -100,6 +101,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config

     if num_heads == num_kv_heads:
@@ -119,6 +121,9 @@ def __init__(
           shape=(2, num_kv_heads, features, head_dim),
           rngs=rngs,
       )
+    if self.use_qk_norm:
+      self._query_norm = layers.RMSNorm(head_dim, rngs=rngs)
+      self._key_norm = layers.RMSNorm(head_dim, rngs=rngs)

   def __call__(
       self,
@@ -135,6 +140,10 @@ def __call__(
       query_proj = self.q_einsum(x)
       key_proj, value_proj = self.kv_einsum(x)

+    if self.use_qk_norm:
+      query_proj = self._query_norm(query_proj)
+      key_proj = self._key_norm(key_proj)
+
     query_proj = positional_embeddings.apply_rope(
         query_proj,
         segment_pos,
@@ -300,6 +309,7 @@ def __init__(
       rngs: nnx.Rngs,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
+      use_qk_norm: bool = False,
       sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
   ):
     self.pre_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs)
@@ -312,6 +322,7 @@ def __init__(
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
+        use_qk_norm=use_qk_norm,
         sow_config=sow_config,
     )
     if use_post_attn_norm:
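For context, the added _query_norm and _key_norm layers RMS-normalize the query and key projections over the head dimension before RoPE is applied. Below is a minimal, self-contained sketch of that operation in plain JAX; the shapes, the epsilon value, and the rms_norm helper are illustrative assumptions, not the repo's layers.RMSNorm implementation.

# Minimal sketch of per-head QK normalization, assuming RMSNorm over the
# last (head_dim) axis with a learned scale initialized to ones.
# This is an illustration, not the repo's layers.RMSNorm.
import jax
import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
  # Normalize over the last axis and apply the learned per-feature scale.
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * scale

batch, seq, num_heads, head_dim = 2, 8, 4, 16
query_proj = jnp.ones((batch, seq, num_heads, head_dim))
key_proj = jnp.ones((batch, seq, num_heads, head_dim))

query_scale = jnp.ones((head_dim,))  # stand-in for _query_norm's scale
key_scale = jnp.ones((head_dim,))    # stand-in for _key_norm's scale

query_proj = rms_norm(query_proj, query_scale)
key_proj = rms_norm(key_proj, key_scale)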

examples/gemma/transformer.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ class TransformerConfig:
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
   attn_logits_soft_cap: float | None = None
+  use_qk_norm: bool = False
   sliding_window_size: int | None = None

   @classmethod
@@ -248,6 +249,7 @@ def __init__(
             attn_logits_soft_cap=config.attn_logits_soft_cap,
             attn_type=attn_type,
             rngs=rngs,
+            use_qk_norm=config.use_qk_norm,
             sow_config=sow_config,
         )
         for _, attn_type in zip(
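To see how the new flag flows end to end, here is a small, self-contained sketch of the wiring pattern this change follows: a config-level boolean is forwarded from the transformer down to each attention block. The class names below are simplified stand-ins for illustration, not the repo's actual TransformerConfig, Block, or Attention.

# Simplified stand-ins showing the config-to-module wiring pattern added here;
# not the repo's actual classes.
import dataclasses

@dataclasses.dataclass
class ConfigSketch:
  num_layers: int
  use_qk_norm: bool = False  # new flag, off by default as in the diff

class AttentionSketch:
  def __init__(self, use_qk_norm: bool):
    # The real Attention builds _query_norm/_key_norm RMSNorm layers when enabled.
    self.use_qk_norm = use_qk_norm

class TransformerSketch:
  def __init__(self, config: ConfigSketch):
    self.blocks = [
        AttentionSketch(use_qk_norm=config.use_qk_norm)
        for _ in range(config.num_layers)
    ]

model = TransformerSketch(ConfigSketch(num_layers=2, use_qk_norm=True))
assert all(block.use_qk_norm for block in model.blocks)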
