
Commit 51248ef

ai-edge-bot authored and copybara-github committed
Modified GemmaWrapper for Gemma2 to pass local mask cache to original model and added filename for model weights.
PiperOrigin-RevId: 766626383
1 parent 787e48c · commit 51248ef

2 files changed: +9 −0 lines changed


ai_edge_torch/generative/examples/gemma/verify_gemma2.py

Lines changed: 6 additions & 0 deletions
@@ -42,12 +42,18 @@
     True,
    "Transpose the KV cache to reduce memory usage.",
 )
+_WEIGHT_FILENAME = flags.DEFINE_string(
+    "weight_filename",
+    "model.ckpt",
+    "Name of the weight file in the checkpoint directory.",
+)
 
 def main(_):
   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
 
   verify_util.verify_gemma2(
       checkpoint,
+      _WEIGHT_FILENAME.value,
       _PROMPTS.value,
       _MAX_NEW_TOKENS.value,
       _MASK_AS_INPUT.value,
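The new weight_filename flag is a standard absl string flag with a default of "model.ckpt", so it can be overridden when invoking the verification script. As a rough illustration of how verify_gemma2 might use the value, the sketch below joins it with the downloaded checkpoint directory; the helper name and the path-joining step are assumptions for illustration, not the actual verify_util internals.

# Hypothetical sketch only: the real verify_util.verify_gemma2 internals are
# not shown in this diff.
import os

def resolve_weight_path(checkpoint_dir: str, weight_filename: str = "model.ckpt") -> str:
  # Assumed layout: the weight file lives directly inside the checkpoint
  # directory downloaded by kagglehub.
  weight_path = os.path.join(checkpoint_dir, weight_filename)
  if not os.path.exists(weight_path):
    raise FileNotFoundError(f"No weight file found at {weight_path}")
  return weight_path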

ai_edge_torch/generative/examples/gemma/verify_util.py

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,8 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
     actual_input_len = self._get_actual_input_len(tokens)
     input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
     mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+    local_mask_cache = attn_utils.build_sliding_window_mask_cache(
+        tokens.shape[1], self.model.config.sliding_window_size)
     _, logits = self.model.forward(
         input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
         input_positions=input_pos,
@@ -72,6 +74,7 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         temperatures=None,
         top_ps=torch.tensor([1.0], dtype=torch.float),
         top_ks=torch.tensor([1], dtype=torch.long),
+        local_mask=local_mask_cache.index_select(2, input_pos)
     )
     return logits
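For context on what the local mask cache represents: a sliding-window (local) mask is a causal mask that additionally blocks keys more than sliding_window_size positions behind the query, and index_select(2, input_pos) slices out the query rows for the positions actually present in the prompt. The sketch below is an assumed re-creation of that behavior, not the actual attn_utils.build_sliding_window_mask_cache implementation.

# Illustrative sketch (assumed shapes and mask convention), not the real
# attn_utils implementation.
import torch

def sliding_window_mask_sketch(size: int, window: int) -> torch.Tensor:
  # Returns a [1, 1, size, size] additive mask: 0.0 where attention is
  # allowed, -inf where it is blocked. A query at position i may attend to
  # key positions j with i - window < j <= i (causal and within the window).
  q = torch.arange(size).unsqueeze(-1)   # query positions as a column
  k = torch.arange(size).unsqueeze(0)    # key positions as a row
  allowed = (k <= q) & (k > q - window)
  mask = torch.full((size, size), float("-inf")).masked_fill(allowed, 0.0)
  return mask.reshape(1, 1, size, size)

# Mirroring the diff: keep only the mask rows for the actual input positions.
cache = sliding_window_mask_sketch(size=8, window=4)
input_pos = torch.arange(0, 3, dtype=torch.long)
local_mask = cache.index_select(2, input_pos)   # shape [1, 1, 3, 8]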
