
Commit 856a868

ai-edge-bot authored and copybara-github committed
Remove kv_cache_max_len from ModelConfig.
- This is the first step to make kv_cache_max_len configurable when the model is loaded for inference.
- Infer kv_cache_max_len from kv_cache or mask; either of them must not be null.
- Pass kv_cache_max_len as a parameter during export.
- Build mask_cache only when mask_as_input is false.
- Confirmed that conversion generates the same tflite files before and after for gemma3, llama, and deepseek.

PiperOrigin-RevId: 766348863
1 parent 93edc84 commit 856a868
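
The "infer kv_cache_max_len from kv_cache or mask" rule mentioned in the commit message is not visible in the hunks below. The following is a minimal sketch of that rule only; the helper name and tensor layouts are assumptions for illustration, not the actual ai_edge_torch implementation.

```python
# Hypothetical helper illustrating the inference rule from the commit message.
def infer_kv_cache_max_len(kv_cache, mask) -> int:
  """Derives the KV-cache length from whichever of kv_cache or mask is given."""
  if kv_cache is None and mask is None:
    raise ValueError("Either kv_cache or mask must be provided.")
  if kv_cache is not None:
    # Assumes each cache entry stores keys as [batch, kv_len, num_heads, head_dim].
    return kv_cache.caches[0].k_cache.shape[1]
  # Assumes the causal mask spans all cache positions in its last dimension.
  return mask.shape[-1]
```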

50 files changed: +545 −470 lines

ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py

Lines changed: 15 additions & 7 deletions
```diff
@@ -29,8 +29,16 @@ class AmdLlama(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config() -> cfg.ModelConfig:
-  """Returns the model config for an AMD-Llama-135m model."""
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for an AMD-Llama-135m model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for an AMD-Llama-135m model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=12,
       head_dim=64,
@@ -55,15 +63,16 @@ def get_model_config() -> cfg.ModelConfig:
       num_layers=12,
       max_seq_len=2048,
       embedding_dim=768,
+      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
   )
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config()
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   config.block_config(0).ff_config.intermediate_size = 64
@@ -73,13 +82,12 @@ def get_fake_model_config() -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] | None = None,
-    mask_cache_size: int = 0,
+    **kwargs
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(),
+      config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=AmdLlama,
       custom_loader=custom_loader,
-      mask_cache_size=mask_cache_size,
   )
```
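
With build_model now forwarding **kwargs to get_model_config, callers size the KV cache when the PyTorch model is built rather than passing mask_cache_size. A small usage sketch; the checkpoint path and cache length are placeholders:

```python
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m

model = amd_llama_135m.build_model(
    "/path/to/amd-llama-135m",   # hypothetical checkpoint path
    kv_cache_max_len=1280,       # forwarded to get_model_config via **kwargs
)
```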

ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -31,14 +31,13 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      mask_cache_size=converter.get_mask_cache_size_from_flags(),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -23,22 +23,20 @@
 
 flags = converter.define_conversion_flags('deepseek')
 
-
 def main(_):
   checkpoint_path = flags.FLAGS.checkpoint_path
   pytorch_model = deepseek.build_model(
       checkpoint_path,
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      mask_cache_size=converter.get_mask_cache_size_from_flags(),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```
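
End to end, the conversion flow after this change builds the model with kv_cache_max_len and then calls convert_to_tflite without it. A sketch using only names visible in the diff; paths, the prefix, and sequence lengths are placeholders, and the quantize/lora/export_config arguments are omitted for brevity:

```python
pytorch_model = deepseek.build_model(
    "/path/to/deepseek-r1-distill-qwen",   # hypothetical checkpoint path
    kv_cache_max_len=1024,                 # KV-cache length fixed at build time
)
converter.convert_to_tflite(
    pytorch_model,
    output_path="/tmp/deepseek_tflite",    # hypothetical output directory
    output_name_prefix="deepseek",
    prefill_seq_len=[128, 512],            # illustrative prefill lengths
)
```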

ai_edge_torch/generative/examples/deepseek/deepseek.py

Lines changed: 15 additions & 7 deletions
```diff
@@ -29,8 +29,16 @@ class DeepSeekDistillQwen(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config() -> cfg.ModelConfig:
-  """Returns the model config for a Qwen 2.5 3B model."""
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Qwen 2.5 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmolLM model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=12,
       head_dim=128,
@@ -58,15 +66,16 @@ def get_model_config() -> cfg.ModelConfig:
       num_layers=28,
       max_seq_len=4096,
       embedding_dim=1536,
+      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
   )
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config()
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
   # DeepSeek-R1-Distill-Qwen has only one block config.
@@ -77,13 +86,12 @@ def get_fake_model_config() -> cfg.ModelConfig:
 def build_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    mask_cache_size: int = 0,
+    **kwargs
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config(),
+      config=get_model_config(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=DeepSeekDistillQwen,
       custom_loader=custom_loader,
-      mask_cache_size=mask_cache_size,
   )
```

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -31,14 +31,13 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      mask_cache_size=converter.get_mask_cache_size_from_flags(),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -33,14 +33,13 @@ def main(_):
       custom_loader=loader.maybe_get_custom_loader(
           checkpoint_path, flags.FLAGS.custom_checkpoint_loader
       ),
-      mask_cache_size=converter.get_mask_cache_size_from_flags(),
+      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
   )
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```

ai_edge_torch/generative/examples/gemma/gemma1.py

Lines changed: 16 additions & 8 deletions
```diff
@@ -42,8 +42,16 @@ class Gemma1(model_builder.DecoderOnlyModel):
   pass
 
 
-def get_model_config_2b() -> cfg.ModelConfig:
-  """Returns the model config for a Gemma 2B model."""
+def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Gemma 2B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Gemma 2B model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -72,33 +80,33 @@ def get_model_config_2b() -> cfg.ModelConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
+      kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_use_bias=False,
   )
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config_2b()
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config_2b(kv_cache_max_len)
   # Gemma has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 256
+  config.max_seq_len = 2 * kv_cache_max_len
   return config
 
 
 def build_2b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    mask_cache_size: int = 0,
+    **kwargs
 ) -> nn.Module:
   return model_builder.build_decoder_only_model(
       checkpoint_path=checkpoint_path,
-      config=get_model_config_2b(),
+      config=get_model_config_2b(**kwargs),
       tensor_names=TENSOR_NAMES,
       model_class=Gemma1,
       custom_loader=custom_loader,
-      mask_cache_size=mask_cache_size,
   )
```
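
Note that get_fake_model_config now derives max_seq_len from the requested cache size (max_seq_len = 2 * kv_cache_max_len), so its default of 128 reproduces the previously hard-coded 256. A small illustrative check, assuming the module is imported as gemma1 and reflects the state after this commit:

```python
from ai_edge_torch.generative.examples.gemma import gemma1

config = gemma1.get_fake_model_config()                     # kv_cache_max_len defaults to 128
assert config.max_seq_len == 256                            # 2 * 128, matching the old constant
config = gemma1.get_fake_model_config(kv_cache_max_len=64)
assert config.max_seq_len == 128
```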

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 24 additions & 24 deletions
```diff
@@ -104,7 +104,7 @@ def forward(
 class Gemma2(nn.Module):
   """A Gemma2 model built from the Edge Generative API layers."""
 
-  def __init__(self, config: cfg.ModelConfig, mask_cache_size: int = 0):
+  def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
     # Construct model layers.
@@ -126,24 +126,17 @@ def __init__(self, config: cfg.ModelConfig, mask_cache_size: int = 0):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.config = config
-    self.build_mask_cache(mask_cache_size)
-
-  def build_mask_cache(self, mask_cache_size: int):
-    assert (
-        mask_cache_size <= self.config.max_seq_len
-    ), "Mask cache size must be less than or equal to the max seq length."
-    if mask_cache_size <= 0:
-      self.mask_cache = None
-      self.sliding_window_mask_cache = None
-      return
-    self.mask_cache = attn_utils.build_causal_mask_cache(mask_cache_size)
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-        size=mask_cache_size,
-        window_size=self.config.block_config(0).attn_config.sliding_window_size,
+        size=config.kv_cache_max,
+        window_size=attn_config.sliding_window_size,
     )
+    self.config = config
 
   def get_attention_mask(
       self, attn_type: cfg.AttentionType, input_pos: torch.Tensor
@@ -174,7 +167,6 @@ def forward(
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
     if mask is None:
-      assert self.mask_cache is not None, "Mask cache must be built."
       mask = [
           self.get_attention_mask(
               self.config.block_config(i).attn_config.attn_type, input_pos
@@ -230,8 +222,16 @@ def _forward_with_embeds(
     return {"logits": res, "kv_cache": updated_kv_cache}
 
 
-def get_model_config_2b() -> cfg.ModelConfig:
-  """Returns the model config for a Gemma2 2B model."""
+def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Gemma2 2B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Gemma 2B model.
+  """
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True
   )
@@ -277,6 +277,7 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
       max_seq_len=8192,
       embedding_dim=embedding_dim,
       embedding_scale=embedding_dim**0.5,
+      kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
       lm_head_use_bias=False,
@@ -285,11 +286,11 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
   return config
 
 
-def get_fake_model_config() -> cfg.ModelConfig:
-  config = get_model_config_2b()
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config_2b(kv_cache_max_len)
   config.vocab_size = 128
   config.num_layers = 2
-  config.max_seq_len = 256
+  config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
   config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
@@ -304,17 +305,16 @@ def get_fake_model_config() -> cfg.ModelConfig:
 def build_2b_model(
     checkpoint_path: str,
     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
-    mask_cache_size: int = 0,
+    **kwargs,
 ) -> nn.Module:
   for tensor_names in TENSOR_NAMES_DICT.values():
     try:
       return model_builder.build_decoder_only_model(
           checkpoint_path=checkpoint_path,
-          config=get_model_config_2b(),
+          config=get_model_config_2b(**kwargs),
           tensor_names=tensor_names,
           model_class=Gemma2,
           custom_loader=custom_loader,
-          mask_cache_size=mask_cache_size,
       )
     except KeyError as _:
       continue
```
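
Gemma2.__init__ now always builds both mask caches, sized by config.kv_cache_max instead of a separate mask_cache_size argument. A construction sketch using the fake config defined in this diff; it assumes config.kv_cache_max reflects kv_cache_max_len, and the printed shapes are only illustrative:

```python
from ai_edge_torch.generative.examples.gemma import gemma2

config = gemma2.get_fake_model_config(kv_cache_max_len=128)
model = gemma2.Gemma2(config)   # no mask_cache_size argument anymore
# Both caches cover config.kv_cache_max positions: a full causal mask and a
# sliding-window variant for the layers that use local attention.
print(model.mask_cache.shape, model.sliding_window_mask_cache.shape)
```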

ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -40,7 +40,7 @@ def main(_):
         custom_loader=loader.maybe_get_custom_loader(
             checkpoint_path, flags.FLAGS.custom_checkpoint_loader
         ),
-        mask_cache_size=converter.get_mask_cache_size_from_flags(),
+        kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
     )
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
@@ -50,7 +50,6 @@ def main(_):
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
-      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
       export_config=export_config.get_from_flags(),
```
