@@ -119,9 +119,7 @@ def __init__(self, config: cfg.ModelConfig):
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
     self.lm_head = nn.Linear(
-        config.embedding_dim,
-        config.vocab_size,
-        bias=config.lm_head_use_bias,
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     # Gemma3 re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
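Aside (not part of the diff): the last context line above is the weight tying that the comment describes; the nn.Linear head is still constructed, but its weight tensor is replaced by the embedding table so both layers share one set of parameters. A minimal standalone sketch of that pattern, with made-up toy sizes rather than the real Gemma3 config:

import torch
from torch import nn

vocab_size, embedding_dim = 32, 8  # toy sizes, not the Gemma3 values
tok_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

# Re-point the head's weight at the embedding table; both now share one tensor.
lm_head.weight.data = tok_embedding.weight.data

ids = torch.tensor([[1, 2, 3]])
logits = lm_head(tok_embedding(ids))  # shape [1, 3, vocab_size]
assert lm_head.weight.data_ptr() == tok_embedding.weight.data_ptr()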
@@ -130,30 +128,13 @@ def __init__(self, config: cfg.ModelConfig):
         for idx in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
+        config.embedding_dim, config.final_norm_config
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
-    # Gemma3 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-        size=config.kv_cache_max,
-        window_size=attn_config.sliding_window_size,
-    )
     self.config = config
 
-  def get_attention_mask(
-      self,
-      attn_type: cfg.AttentionType,
-      input_pos: torch.Tensor,
-  ) -> torch.Tensor:
-    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
-      return self.sliding_window_mask_cache.index_select(2, input_pos)
-    return self.mask_cache.index_select(2, input_pos)
-
   def get_local_global_attention_mask(
       self,
       attention_mask: torch.Tensor,
@@ -200,9 +181,7 @@ def create_sliding_mask(
         sliding_mask_bool,
         torch.zeros_like(sliding_mask_bool, dtype=torch.float),
         torch.full_like(
-            sliding_mask_bool,
-            self.config.causal_mask_value,
-            dtype=torch.float,
+            sliding_mask_bool, self.config.causal_mask_value, dtype=torch.float
         ),
     )
 
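Aside (not part of the diff): the torch.zeros_like / torch.full_like pair being reflowed here is the usual way a boolean sliding-window mask becomes an additive attention mask, 0.0 where attention is allowed and config.causal_mask_value elsewhere. A self-contained sketch of that pattern under assumed values (a toy window size, -inf standing in for causal_mask_value, and torch.where used for the select):

import torch

causal_mask_value = float("-inf")  # stand-in for self.config.causal_mask_value
seq_len, window_size = 6, 3        # toy sizes

pos = torch.arange(seq_len)
causal = pos[None, :] <= pos[:, None]                    # no attending to the future
in_window = (pos[:, None] - pos[None, :]) < window_size  # bounded look-back
sliding_mask_bool = causal & in_window                   # [seq_len, seq_len] bool

sliding_mask = torch.where(
    sliding_mask_bool,
    torch.zeros_like(sliding_mask_bool, dtype=torch.float),
    torch.full_like(sliding_mask_bool, causal_mask_value, dtype=torch.float),
)
print(sliding_mask)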
@@ -272,12 +251,8 @@ def forward(
         for i in range(self.config.num_layers)
     ]
     if mask is None:
-      mask = [
-          self.get_attention_mask(
-              self.config.block_config(i).attn_config.attn_type, input_pos
-          )
-          for i in range(self.config.num_layers)
-      ]
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.kv_cache_max]
 
     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, pixel_mask, export_config
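Aside (not part of the diff): with get_attention_mask and the per-layer sliding-window cache removed, the `mask is None` fallback above no longer varies by attention type; it always takes rows from the single causal mask cache and clamps the key dimension to kv_cache_max. A rough standalone sketch of that path, assuming the cache is a [1, 1, kv_cache_max, kv_cache_max] additive mask (the real one comes from attn_utils.build_causal_mask_cache, whose exact layout isn't shown here):

import torch

kv_cache_max = 8  # assumed small value for illustration
mask_cache = torch.triu(
    torch.full((1, 1, kv_cache_max, kv_cache_max), float("-inf")), diagonal=1
)

input_pos = torch.tensor([3, 4])              # positions of the current tokens
mask = mask_cache.index_select(2, input_pos)  # rows for these query positions
mask = mask[:, :, :, :kv_cache_max]           # clamp keys to the cache length
print(mask.shape)                             # torch.Size([1, 1, 2, 8])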
@@ -329,6 +304,7 @@ def _forward_with_embeds(
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
     if export_config is not None:
       if (
           torch.numel(input_pos) > 1