
Commit 750f58e

protobird-git authored and copybara-github committed
Remove redundant code in gemma3 decoder
- sliding window mask is calculated on the fly
- no need to get local mask from cache any more when mask_as_input is false

PiperOrigin-RevId: 756820429
1 parent 45017e4 commit 750f58e
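For context on "calculated on the fly": the commit drops the decoder's dedicated sliding-window mask cache because an equivalent local mask can be built directly from the current query positions. The sketch below is a minimal, self-contained illustration in plain torch, not the library's code; the function name make_sliding_mask and its parameters are invented for the example, and mask_value stands in for the model's config.causal_mask_value.

# Minimal sketch (not ai_edge_torch code): build a sliding-window causal
# mask on the fly from the current query positions.
import torch

def make_sliding_mask(
    input_pos: torch.Tensor,  # positions of the current query tokens, shape [T]
    cache_len: int,           # total number of key/value cache entries
    window_size: int,         # sliding-window size
    mask_value: float = float("-inf"),  # stand-in for config.causal_mask_value
) -> torch.Tensor:
  """Returns an additive attention mask of shape [1, 1, T, cache_len]."""
  key_pos = torch.arange(cache_len)          # [cache_len]
  query_pos = input_pos.unsqueeze(-1)        # [T, 1]
  causal = key_pos <= query_pos              # no attending to future keys
  in_window = (query_pos - key_pos) < window_size  # stay inside the window
  allowed = causal & in_window
  # 0.0 where attention is allowed, mask_value where it is blocked.
  return torch.where(
      allowed,
      torch.zeros_like(allowed, dtype=torch.float),
      torch.full_like(allowed, mask_value, dtype=torch.float),
  )[None, None, :, :]

# Example: 4 query tokens at positions 8..11, 16-entry cache, window of 4.
print(make_sliding_mask(torch.arange(8, 12), cache_len=16, window_size=4).shape)

Because such a mask depends only on the query positions, the window size, and the cache length, it does not need to be materialized and stored per model instance, which is what allows the cached variant to be deleted below.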

File tree

2 files changed: +7 -32 lines changed


ai_edge_torch/generative/examples/gemma3/decoder.py

Lines changed: 6 additions & 30 deletions
@@ -119,9 +119,7 @@ def __init__(self, config: cfg.ModelConfig):
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
     self.lm_head = nn.Linear(
-        config.embedding_dim,
-        config.vocab_size,
-        bias=config.lm_head_use_bias,
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
     )
     # Gemma3 re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
@@ -130,30 +128,13 @@ def __init__(self, config: cfg.ModelConfig):
         for idx in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
+        config.embedding_dim, config.final_norm_config
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
-    # Gemma3 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-        size=config.kv_cache_max,
-        window_size=attn_config.sliding_window_size,
-    )
     self.config = config

-  def get_attention_mask(
-      self,
-      attn_type: cfg.AttentionType,
-      input_pos: torch.Tensor,
-  ) -> torch.Tensor:
-    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
-      return self.sliding_window_mask_cache.index_select(2, input_pos)
-    return self.mask_cache.index_select(2, input_pos)
-
   def get_local_global_attention_mask(
       self,
       attention_mask: torch.Tensor,
@@ -200,9 +181,7 @@ def create_sliding_mask(
         sliding_mask_bool,
         torch.zeros_like(sliding_mask_bool, dtype=torch.float),
         torch.full_like(
-            sliding_mask_bool,
-            self.config.causal_mask_value,
-            dtype=torch.float,
+            sliding_mask_bool, self.config.causal_mask_value, dtype=torch.float
         ),
     )

@@ -272,12 +251,8 @@ def forward(
         for i in range(self.config.num_layers)
     ]
     if mask is None:
-      mask = [
-          self.get_attention_mask(
-              self.config.block_config(i).attn_config.attn_type, input_pos
-          )
-          for i in range(self.config.num_layers)
-      ]
+      mask = self.mask_cache.index_select(2, input_pos)
+      mask = mask[:, :, :, : self.config.kv_cache_max]

     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, pixel_mask, export_config
@@ -329,6 +304,7 @@ def _forward_with_embeds(
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
     if export_config is not None:
       if (
           torch.numel(input_pos) > 1
ai_edge_torch/generative/utilities/model_builder.py

Lines changed: 1 addition & 2 deletions
@@ -75,8 +75,7 @@ def __init__(self, config: cfg.ModelConfig):
         for idx in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
+        config.embedding_dim, config.final_norm_config
     )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,