@@ -47,8 +47,9 @@ def apply_rotary_pos_emb(x: jnp.ndarray, cos_emb: jnp.ndarray, sin_emb: jnp.ndar
     return (x * cos_emb) + (rotate_half(x) * sin_emb)


-def precompute_rotary_embeddings(seq_len: int, head_dim: int,
-                                 base: float = 10000.0) -> tuple[jnp.ndarray, jnp.ndarray]:
+def precompute_rotary_embeddings(
+    seq_len: int, head_dim: int, base: float = 10000.0
+) -> tuple[jnp.ndarray, jnp.ndarray]:
     """Precomputes the RoPE cosine and sine embeddings.

     Args:
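For context on the two helpers referenced in this hunk, here is a minimal sketch of how `rotate_half` and `precompute_rotary_embeddings` are commonly implemented for RoPE. The signatures match the diff above, but the bodies are an assumption and may differ from the repository's actual code.

```python
import jax.numpy as jnp


def rotate_half(x: jnp.ndarray) -> jnp.ndarray:
    """Swaps and negates the two halves of the last dimension (RoPE rotation)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)


def precompute_rotary_embeddings(
    seq_len: int, head_dim: int, base: float = 10000.0
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Builds cos/sin tables of shape (1, 1, seq_len, head_dim) for broadcasting."""
    # Inverse frequencies per dimension pair (assumed standard RoPE formulation).
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.einsum("i,j->ij", jnp.arange(seq_len), inv_freq)  # (seq_len, head_dim // 2)
    angles = jnp.concatenate((angles, angles), axis=-1)            # (seq_len, head_dim)
    return jnp.cos(angles)[None, None], jnp.sin(angles)[None, None]
```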
@@ -91,11 +92,11 @@ class RoPEMultiHeadAttention(nn.Module):
     rope_base: float = 10000.0
     dtype: jnp.dtype = jnp.float32

-    def setup(self) -> None: # Added -> None return type
+    def setup(self) -> None:  # Added -> None return type
         """Initializes the attention projections."""
         # Check head_dim validity early during setup
         if self.head_dim % 2 != 0:
-            raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.")
+            raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE.")

         # Define layers here - they will be initialized when the module is first called
         total_head_dim = self.num_heads * self.head_dim
@@ -109,13 +110,12 @@ def setup(self) -> None: # Added -> None return type
             features=total_head_dim, use_bias=False, dtype=self.dtype, name="value_proj"
         )
         self.output_proj = nn.Dense(
-            features=self.num_heads * self.head_dim, # Output should match embed_dim
+            features=self.num_heads * self.head_dim,  # Output should match embed_dim
             use_bias=False,
             dtype=self.dtype,
-            name="output_proj"
+            name="output_proj",
         )

-
     @nn.compact
     # Also using Optional for the mask type hint for clarity with None default
     def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarray:
@@ -136,8 +136,7 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr

         if embed_dim != total_head_dim:
             raise ValueError(
-                f"embed_dim ({embed_dim}) must equal num_heads*head_dim"
-                f" ({total_head_dim})"
+                f"embed_dim ({embed_dim}) must equal num_heads*head_dim ({total_head_dim})"
             )
         # Note: head_dim even check moved to setup for earlier failure

@@ -159,7 +158,6 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
         cos_emb = cos_emb.astype(self.dtype)
         sin_emb = sin_emb.astype(self.dtype)

-
         # 4. Apply RoPE to Query and Key
         query = apply_rotary_pos_emb(query, cos_emb, sin_emb)
         key = apply_rotary_pos_emb(key, cos_emb, sin_emb)
@@ -172,44 +170,46 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
         # 6. Scaled Dot-Product Attention
         # Attention scores: (batch, num_heads, seq_len, seq_len)
         attn_scores = jnp.matmul(query, key.transpose((0, 1, 3, 2))) / jnp.sqrt(
-            self.head_dim).astype(self.dtype) # Ensure sqrt is correct dtype
+            self.head_dim
+        ).astype(self.dtype)  # Ensure sqrt is correct dtype

         # Apply mask (if provided)
         if mask is not None:
             # Standard Flax causal mask is boolean (True means mask)
             # nn.make_causal_mask returns (1, seq_len, seq_len) or (batch, 1, seq_len, seq_len)
             # Check if mask needs broadcasting or conversion
-            if mask.ndim == 2: # Likely (seq_len, seq_len)
-                mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len)
+            if mask.ndim == 2:  # Likely (seq_len, seq_len)
+                mask = mask[None, None, :, :]  # -> (1, 1, seq_len, seq_len)
             elif mask.ndim == 3 and mask.shape[1] != self.num_heads:
-                # Likely (batch, seq_len, seq_len) or causal (1, sl, sl)
+                # Likely (batch, seq_len, seq_len) or causal (1, sl, sl)
                 mask = mask[:, None, :, :]
-                # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)
+                # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)

             # Ensure mask is broadcastable to attn_scores shape
             mask_shape_expected = (batch_size, self.num_heads, seq_len, seq_len)
             if mask.shape != mask_shape_expected:
-                # Attempt broadcasting common causal mask shapes
-                if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == (batch_size, 1,
-                        seq_len, seq_len): # Causal mask for all batches/heads
-                    mask = jnp.broadcast_to(mask, mask_shape_expected)
-                # Add other broadcasting cases if needed
-                else:
-                    raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}")
-
+                # Attempt broadcasting common causal mask shapes
+                if mask.shape == (1, 1, seq_len, seq_len) or mask.shape == (
+                    batch_size,
+                    1,
+                    seq_len,
+                    seq_len,
+                ):  # Causal mask for all batches/heads
+                    mask = jnp.broadcast_to(mask, mask_shape_expected)
+                # Add other broadcasting cases if needed
+                else:
+                    raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}")

             # Apply mask: Use large negative number where mask is True
             # (or where mask value is 0 if using 0/-inf convention)
             # Assuming boolean mask convention (True = mask) common in Flax examples
             # If using 0/-inf mask, the logic would be: attn_scores = attn_scores + mask
             attn_scores = jnp.where(mask, jnp.finfo(self.dtype).min, attn_scores)

-
         # Softmax to get attention weights
-        attn_weights = jax.nn.softmax(
-            attn_scores, axis=-1
-        ).astype(self.dtype) # Shape: (batch, num_heads, seq_len, seq_len)
-
+        attn_weights = jax.nn.softmax(attn_scores, axis=-1).astype(
+            self.dtype
+        )  # Shape: (batch, num_heads, seq_len, seq_len)

         # Apply attention weights to Value
         # Output per head: (batch, num_heads, seq_len, head_dim)
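As a standalone sanity check of the masked scaled dot-product attention in this hunk, the sketch below reproduces the score / mask / softmax / weighted-sum sequence with small illustrative shapes. The True-means-mask convention mirrors the `jnp.where` call above; the shapes and variable names are illustrative assumptions, not taken from the repository.

```python
import jax
import jax.numpy as jnp

batch, heads, seq_len, head_dim = 1, 2, 4, 8  # illustrative sizes
q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(q_rng, (batch, heads, seq_len, head_dim))
k = jax.random.normal(k_rng, (batch, heads, seq_len, head_dim))
v = jax.random.normal(v_rng, (batch, heads, seq_len, head_dim))

# Scores, causal mask (True = masked out), softmax, weighted sum of values.
scores = jnp.matmul(q, k.transpose((0, 1, 3, 2))) / jnp.sqrt(head_dim)
causal_mask = ~jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
scores = jnp.where(causal_mask[None, None], jnp.finfo(scores.dtype).min, scores)
weights = jax.nn.softmax(scores, axis=-1)
out = jnp.matmul(weights, v)
assert out.shape == (batch, heads, seq_len, head_dim)
```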
@@ -222,6 +222,6 @@ def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None = None) -> jnp.ndarr
         attn_output = attn_output.reshape(batch_size, seq_len, total_head_dim)

         # Final linear projection
-        output = self.output_proj(attn_output) # Use self.output_proj defined in setup
+        output = self.output_proj(attn_output)  # Use self.output_proj defined in setup

         return output
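Finally, a minimal usage sketch for the module as a whole, assuming the `RoPEMultiHeadAttention` class from this diff is importable and behaves as a standard Flax module. The shapes, the PRNG key, and the boolean causal mask (True = masked, matching the module's `jnp.where` convention) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

num_heads, head_dim = 4, 32
embed_dim = num_heads * head_dim  # must equal num_heads * head_dim per the check in __call__
batch, seq_len = 2, 16

attn = RoPEMultiHeadAttention(num_heads=num_heads, head_dim=head_dim)

# Boolean causal mask where True means "mask out"; a (seq_len, seq_len) mask is
# one of the shapes __call__ broadcasts to (batch, num_heads, seq_len, seq_len).
causal_mask = ~jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

x = jnp.ones((batch, seq_len, embed_dim), dtype=jnp.float32)
variables = attn.init(jax.random.PRNGKey(0), x, causal_mask)
out = attn.apply(variables, x, causal_mask)
print(out.shape)  # (2, 16, 128)
```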