Commit f219b46

zichuan-wei authored and copybara-github committed
fix: set quantization recipe for the softmax with embedding's value
PiperOrigin-RevId: 756931600
1 parent 750f58e commit f219b46

File tree

6 files changed: +26 −6 lines changed


ai_edge_torch/generative/quantize/quant_recipe.py

Lines changed: 5 additions & 1 deletion
@@ -16,9 +16,12 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
+from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import supported_schemes
 
+ModelConfig = model_config.ModelConfig
+
 
 @dataclass
 class LayerQuantRecipe:
@@ -52,7 +55,7 @@ def __str__(self):
         f'w:{self.weight_dtype.name}, '
         f'{self.mode.name}, '
         f'{self.algorithm.name}, '
-        f'{self.granularity.name}'
+        f'{self.granularity.name}, '
         f'{self.block_size}'
     )
     return f'{base_str})'
@@ -133,6 +136,7 @@ class GenerativeQuantRecipe:
   feedforward: Union[
       Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
   ] = None
+  _model_config: Optional[ModelConfig] = None
 
   def __str__(self):
     return f"""GenerativeQuantRecipe(
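
The new _model_config field gives the recipe translator access to model-level settings, in particular whether the LM head shares its weight with the token embedding. A minimal sketch of attaching a config to a recipe (my_model_config is a hypothetical placeholder for a pre-built ModelConfig):

    from ai_edge_torch.generative.layers import model_config
    from ai_edge_torch.generative.quantize import quant_recipe

    # my_model_config: a model_config.ModelConfig for the target model
    # (hypothetical placeholder; real configs come from the model's builder).
    recipe = quant_recipe.GenerativeQuantRecipe()
    # Attach it so translate_recipe can later check
    # lm_head_share_weight_with_embedding when emitting the recipe.
    recipe._model_config = my_model_config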

ai_edge_torch/generative/quantize/quant_recipes.py

Lines changed: 2 additions & 1 deletion
@@ -63,6 +63,7 @@ def all_supported_int4_dynamic_block_recipe(
      generative_recipe=quant_recipe.GenerativeQuantRecipe(
          default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
              block_size
-         )
+         ),
+         embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
      )
  )
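
With the added embedding entry, this recipe now also quantizes the token embedding with int8 dynamic quantization alongside the int4 dynamic-block default. A usage sketch (the block size of 32 is an arbitrary example value):

    from ai_edge_torch.generative.quantize import quant_recipes

    # The returned quant config carries an int4 dynamic-block recipe for
    # most layers plus an int8 dynamic recipe for the embedding table.
    quant_config = quant_recipes.all_supported_int4_dynamic_block_recipe(32)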

ai_edge_torch/generative/test/test_quantize.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
 # ==============================================================================
 
 import ai_edge_torch
-from ai_edge_torch import config
 from ai_edge_torch.generative.examples.test_models import toy_model  # NOQA
 from ai_edge_torch.generative.quantize import quant_recipe
 from ai_edge_torch.generative.quantize import quant_recipe_utils

ai_edge_torch/generative/utilities/converter.py

Lines changed: 4 additions & 1 deletion
@@ -270,6 +270,7 @@ def _export_helper(
   )
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  quant_config._model_config = config
 
   # For export, we create a module that captures any non-exportable,
   # arugments, e.g. the generation config object.
@@ -334,5 +335,7 @@ def _export_helper(
       sample_kwargs=sample_kwargs,
   )
 
-  edge_model = converter.convert(quant_config=quant_config)
+  edge_model = converter.convert(
+      quant_config=quant_config,
+  )
   edge_model.export(output_file)
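
The exporter now stamps the model's config onto the quantization config before conversion, so the translator can see the weight-sharing flag. A condensed sketch of the flow (names as in the diff); note the diff assigns _model_config unconditionally, so the None guard shown here is a defensive variation, not what the commit does:

    # Condensed from _export_helper: thread the model config through the
    # quant config before converting.
    quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
    if quant_config is not None:
      quant_config._model_config = config
    edge_model = converter.convert(quant_config=quant_config)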

ai_edge_torch/lowertools/_shim.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def exported_programs_to_tflite(
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
     _tfl_converter_flags: Optional[dict[str, Any]] = None,
-    _saved_model_dir: Optional[str] = None
+    _saved_model_dir: Optional[str] = None,
 ):
   """Converts a list of ExportedProgram to a TFLite model.
 
ai_edge_torch/lowertools/translate_recipe.py

Lines changed: 14 additions & 1 deletion
@@ -29,6 +29,8 @@
 _ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
 _FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
 _EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
+# TODO: b/415833584 - Improve the regex for pre-softmax layer.
+_DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall'
 _ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
 
 
@@ -95,10 +97,11 @@ def _set_quant_config(
     rm: quantizer.recipe_manager.RecipeManager,
     layer_recipe: quant_recipe.LayerQuantRecipe,
     regex: str,
+    operation_name: _OpName = _OpName.ALL_SUPPORTED,
 ):
   rm.add_quantization_config(
       regex=regex,
-      operation_name=_OpName.ALL_SUPPORTED,
+      operation_name=operation_name,
       op_config=_OpQuantConfig(
           weight_tensor_config=_TensorQuantConfig(
               num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
@@ -126,6 +129,16 @@ def translate_to_ai_edge_recipe(
 
   if recipe.embedding is not None:
     _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)
+    if (
+        recipe._model_config is not None
+        and recipe._model_config.lm_head_share_weight_with_embedding
+    ):
+      _set_quant_config(
+          rm,
+          recipe.embedding,
+          _DECODE_LOGITS_REGEX_STR,
+          _OpName.FULLY_CONNECTED,
+      )
 
   if recipe.attention is not None:
     if isinstance(recipe.attention, dict):
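
With the new operation_name parameter, the embedding recipe can be reapplied narrowly to the pre-softmax fully-connected op when the LM head reuses the embedding weight. A sketch of both call shapes, assuming rm and recipe are set up as inside translate_to_ai_edge_recipe:

    # Default: apply the recipe to all supported ops under the regex.
    _set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)

    # Restricted: cover only FULLY_CONNECTED ops, i.e. the decode-logits
    # matmul matched by the (temporary) StatefulPartitionedCall regex.
    _set_quant_config(
        rm, recipe.embedding, _DECODE_LOGITS_REGEX_STR, _OpName.FULLY_CONNECTED
    )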
