Commit 90d7dc2

protobird-git authored and copybara-github committed
Fix verifier or verify_util of each model
mask_cache_size and kv_cache_max_len should have been passed.

PiperOrigin-RevId: 769912691
1 parent 5f22c45 commit 90d7dc2


17 files changed, +65 -23 lines changed

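Every file below gets the same one-line fix: the model builder is now called with mask_cache_size, so the causal-mask cache the builder creates matches the KV cache length the verifier allocates. A minimal sketch of the shared pattern, assuming the verifier utilities import path used elsewhere in this repo; build_for_verification is a hypothetical helper, not code from this commit:

# Hypothetical helper illustrating the pattern; `verifier` is assumed to be
# ai_edge_torch.generative.utilities.verifier, as in the verify scripts below.
from ai_edge_torch.generative.utilities import verifier

def build_for_verification(build_model, checkpoint_path, custom_loader=None):
  # `build_model` stands in for any per-model builder patched in this commit
  # (amd_llama_135m.build_model, deepseek.build_model, ...).
  return build_model(
      checkpoint_path=checkpoint_path,
      custom_loader=custom_loader,
      # The previously missing argument: without it, the mask cache was
      # sized independently of the verifier's KV cache length.
      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
  )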

ai_edge_torch/generative/examples/amd_llama_135m/verify_util.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ def verify_amd_llama_135m(
   reauthored_model = amd_llama_135m.build_model(
       checkpoint_path=reauthored_checkpoint,
       custom_loader=custom_loader,
+      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
   )

   logging.info("Loading the tokenizer from: %s", checkpoint_dir)

ai_edge_torch/generative/examples/deepseek/verify_util.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ def verify_deepseek_r1_distill_1_5b(
   reauthored_model = deepseek.build_model(
       checkpoint_path=reauthored_checkpoint,
       custom_loader=custom_loader,
+      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
   )

   logging.info("Loading the tokenizer from: %s", checkpoint_dir)

ai_edge_torch/generative/examples/gemma/verify_util.py

Lines changed: 11 additions & 5 deletions
@@ -143,9 +143,7 @@ def verify_reauthored_gemma_model(
   return verifier.verify_reauthored_model(
       original_model=GemmaWrapper(original_model),
       reauthored_model=verifier.ReauthoredModelWrapper(
-          reauthored_model,
-          mask_as_input=mask_as_input,
-          kv_layout=kv_layout,
+          reauthored_model, mask_as_input, kv_layout
       ),
       tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
       generate_prompts=generate_prompts,
@@ -171,7 +169,11 @@ def verify_gemma2(
   """
   checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
   logging.info("Building the reauthored model from: %s", checkpoint_path)
-  reauthored_model = gemma2.build_2b_model(checkpoint_path, custom_loader)
+  reauthored_model = gemma2.build_2b_model(
+      checkpoint_path,
+      custom_loader,
+      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
+  )

   return verify_reauthored_gemma_model(
       checkpoint=checkpoint_dir,
@@ -193,7 +195,11 @@ def verify_gemma1_with_custom_loader(checkpoint_dir: str) -> bool:
   weight_filename = "gemma-2b-it.ckpt"
   checkpoint_path = os.path.join(checkpoint_dir, weight_filename)
   custom_loader = loader.get_custom_loader(checkpoint_path)
-  reauthored_model = gemma1.build_2b_model(checkpoint_path, custom_loader)
+  reauthored_model = gemma1.build_2b_model(
+      checkpoint_path,
+      custom_loader,
+      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
+  )
   return verify_reauthored_gemma_model(
       checkpoint=checkpoint_dir,
       variant="2b",

ai_edge_torch/generative/examples/gemma3/verify_util.py

Lines changed: 18 additions & 10 deletions
@@ -93,21 +93,27 @@ def generate(
 class UnifiedGemma3Wrapper(verifier.ReauthoredModelWrapper):
   """Unified Gemma3 model wrapper for verification."""

-  def __init__(self, model: torch.nn.Module):
-    super().__init__(model, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED)
+  def __init__(
+      self,
+      model: torch.nn.Module,
+      kv_cache_max_len: int = verifier.DEFAULT_KV_CACHE_MAX_LEN,
+  ):
+    super().__init__(
+        model,
+        kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED,
+        kv_cache_max_len=kv_cache_max_len,
+    )

   def _init_kv_cache(self):
     return kv_utils.KVCache.from_model_config(
-        self.model.model.config, kv_layout=self.kv_layout
+        self.kv_cache_max_len, self.model.model.config, kv_layout=self.kv_layout
     )

   def forward(
       self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
   ) -> torch.Tensor:
     """Forwards the model."""
-    mask = attn_utils.build_causal_mask_cache(
-        self.model.model.config.kv_cache_max_len
-    )
+    mask = attn_utils.build_causal_mask_cache(self.kv_cache_max_len)
     input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
     mask = mask.index_select(2, input_pos)
     output = self.model.model.forward(
@@ -127,9 +133,7 @@ def generate(
     tokens = torch.tensor([input_ids])
     input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
     kv_cache = self._init_kv_cache()
-    mask_cache = attn_utils.build_causal_mask_cache(
-        self.model.model.config.kv_cache_max_len
-    )
+    mask_cache = attn_utils.build_causal_mask_cache(self.kv_cache_max_len)
     for _ in range(max_new_tokens):
       mask = mask_cache.index_select(2, input_pos)
       output = self.model.model.forward(
@@ -245,7 +249,11 @@ def verify_gemma3(

   if variant == "1b":
     reauthored_model = UnifiedGemma3Wrapper(
-        gemma3.build_model_1b(gemma3_model_path, custom_loader)
+        gemma3.build_model_1b(
+            gemma3_model_path,
+            custom_loader,
+            mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
+        )
     )
   else:
     raise ValueError(f"Unsupported Gemma3 variant: {variant}")
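Gemma3's fix goes further than the one-argument change in the other files: the wrapper now owns kv_cache_max_len and sizes both the KV cache and the causal-mask cache from it, instead of reading self.model.model.config.kv_cache_max_len. A condensed sketch of that invariant, using only calls visible in the diff; the import paths are assumptions based on this repo's layout, and model_config is a hypothetical input:

# Assumed import paths; the diff above only shows the aliases attn_utils,
# kv_utils, and verifier.
from ai_edge_torch.generative.layers import attention_utils as attn_utils
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.utilities import verifier

def build_caches(model_config, kv_cache_max_len=verifier.DEFAULT_KV_CACHE_MAX_LEN):
  # Both caches are sized from the same wrapper-owned value, never from
  # model_config.kv_cache_max_len.
  kv = kv_utils.KVCache.from_model_config(
      kv_cache_max_len, model_config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
  )
  mask_cache = attn_utils.build_causal_mask_cache(kv_cache_max_len)
  return kv, mask_cache

At each decode step the wrapper slices mask rows with mask_cache.index_select(2, input_pos); that stays in bounds for every position the KV cache can hold only when the two sizes agree, which is what threading kv_cache_max_len through guarantees.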

ai_edge_torch/generative/examples/hammer/verify_util.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ def verify_hammer(
   reauthored_model = _BUILDER[model_size](
       checkpoint_path=reauthored_checkpoint,
       custom_loader=custom_loader,
+      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
   )

   logging.info("Loading the tokenizer from: %s", checkpoint_dir)

ai_edge_torch/generative/examples/llama/verify_util.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ def verify_llama_3_2(
   reauthored_model = _BUILDER[model_size](
       checkpoint_path=reauthored_checkpoint,
       custom_loader=custom_loader,
+      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
   )

   logging.info("Loading the tokenizer from: %s", checkpoint_dir)

ai_edge_torch/generative/examples/openelm/verify_util.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ def verify_openelm(
   reauthored_model = openelm.build_model(
       checkpoint_path=reauthored_checkpoint,
       custom_loader=custom_loader,
+      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
   )

   logging.info("Loading the tokenizer from: %s", checkpoint_dir)

ai_edge_torch/generative/examples/paligemma/verify.py

Lines changed: 6 additions & 2 deletions
@@ -66,7 +66,9 @@ class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored PaliGemma model wrapper."""

   def _init_kv_cache(self):
-    return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
+    return kv_cache.KVCache.from_model_config(
+        self.kv_cache_max_len, self.model.config.decoder_config
+    )


 def main(_):
@@ -88,7 +90,9 @@ def main(_):

   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = paligemma.build_model(
-      reauthored_checkpoint, version=int(_VERSION.value)
+      reauthored_checkpoint,
+      version=int(_VERSION.value),
+      mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN,
   )
   wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)

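The PaliGemma wrapper picks up the same signature change: KVCache.from_model_config now takes the cache length as an explicit first argument instead of reading it from the config. A usage sketch under the same assumptions as above; decoder_config is a hypothetical stand-in for self.model.config.decoder_config:

# Sketch only; `kv_cache` and `verifier` are the modules verify.py imports.
kv = kv_cache.KVCache.from_model_config(
    verifier.DEFAULT_KV_CACHE_MAX_LEN,  # cache length, now passed explicitly
    decoder_config,  # hypothetical stand-in for the decoder's config
)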

ai_edge_torch/generative/examples/paligemma/verify_decoder.py

Lines changed: 3 additions & 1 deletion
@@ -51,7 +51,9 @@ def main(_):
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = decoder.build_decoder(reauthored_checkpoint)
+  reauthored_model = decoder.build_decoder(
+      reauthored_checkpoint, mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN
+  )

   logging.info("Loading the tokenizer from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,

ai_edge_torch/generative/examples/paligemma/verify_decoder2.py

Lines changed: 3 additions & 1 deletion
@@ -48,7 +48,9 @@ def main(_):
   original_language_model = original_full_model.eval().language_model

   logging.info("Building the reauthored model from: %s", checkpoint)
-  reauthored_model = decoder2.build_decoder2(checkpoint)
+  reauthored_model = decoder2.build_decoder2(
+      checkpoint, mask_cache_size=verifier.DEFAULT_KV_CACHE_MAX_LEN
+  )

   logging.info("Loading the tokenizer from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
