@@ -85,6 +85,7 @@ def __init__(
       rngs: nnx.Rngs,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
+      use_qk_norm: bool = False,
       sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
   ):
     if attn_type == AttentionType.LOCAL_SLIDING and sliding_window_size is None:
@@ -100,6 +101,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config

     if num_heads == num_kv_heads:
@@ -119,6 +121,9 @@ def __init__(
           shape=(2, num_kv_heads, features, head_dim),
           rngs=rngs,
       )
+    if self.use_qk_norm:
+      self._query_norm = layers.RMSNorm(head_dim, rngs=rngs)
+      self._key_norm = layers.RMSNorm(head_dim, rngs=rngs)

   def __call__(
       self,
@@ -135,6 +140,10 @@ def __call__(
     query_proj = self.q_einsum(x)
     key_proj, value_proj = self.kv_einsum(x)

+    if self.use_qk_norm:
+      query_proj = self._query_norm(query_proj)
+      key_proj = self._key_norm(key_proj)
+
     query_proj = positional_embeddings.apply_rope(
         query_proj,
         segment_pos,
@@ -300,6 +309,7 @@ def __init__(
       rngs: nnx.Rngs,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
+      use_qk_norm: bool = False,
       sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
   ):
     self.pre_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs)
@@ -312,6 +322,7 @@ def __init__(
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
+        use_qk_norm=use_qk_norm,
         sow_config=sow_config,
     )
     if use_post_attn_norm: