
Commit 869f6ad

zichuan-wei authored and copybara-github committed
enable user choice for various quantization schemes
PiperOrigin-RevId: 756939691
1 parent f219b46 commit 869f6ad

File tree: 2 files changed, 96 additions & 11 deletions

ai_edge_torch/generative/tools/batch_convert.py

Lines changed: 5 additions & 2 deletions

@@ -282,9 +282,12 @@ def convert_models(conversion_configs: Sequence[ConversionConfig]) -> None:
       )
       converter.convert_to_tflite(
           pytorch_model,
-          tflite_path=os.path.join(config.tflite_output_path, output_filename),
+          output_path=config.tflite_output_path,
+          output_name_prefix=output_filename,
           prefill_seq_len=config.prefill_seq_lens,
-          quantize=True if precision == ExportPrecision.INT8 else False,
+          quantize=converter.QuantizationName.DYNAMIC_INT8
+          if precision == ExportPrecision.INT8
+          else converter.QuantizationName.NONE,
           export_config=ExportConfig(),
       )
       logging.info("Successfully converted model: %s", output_filename)

ai_edge_torch/generative/utilities/converter.py

Lines changed: 91 additions & 9 deletions

@@ -15,6 +15,7 @@
 
 """Common utility functions for model conversion."""
 
+import enum
 import os
 import pathlib
 from typing import Optional, Union
@@ -42,6 +43,27 @@ def forward(self, *export_args, **export_kwargs):
     return self.module(*export_args, **full_kwargs)
 
 
+class QuantizationName(str, enum.Enum):
+  """Strings for all supported quantization recipes.
+
+  none: No quantization.
+  dynamic_int8: Dynamic range quantization with int8 weights.
+  weight_only_int8: Weight only quantization with int8 weights.
+  fp16: Float16 quantization.
+  dynamic_int4_block32: Dynamic range quantization with int4 weights and block
+    size of 32, better model quality but slower inference.
+  dynamic_int4_block128: Dynamic range quantization with int4 weights and block
+    size of 128, faster inference but worse model quality.
+  """
+
+  NONE = 'none'
+  DYNAMIC_INT8 = 'dynamic_int8'
+  WEIGHT_ONLY_INT8 = 'weight_only_int8'
+  FP16 = 'fp16'
+  DYNAMIC_INT4_BLOCK32 = 'dynamic_int4_block32'
+  DYNAMIC_INT4_BLOCK128 = 'dynamic_int4_block128'
+
+
 def define_conversion_flags(
     model_name: str,
     default_mask_as_input: bool = False,
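A note on the pattern: because `QuantizationName` mixes `str` into the enum, a raw flag string compares equal to the corresponding member, which is what lets the `match` statements added further below accept either form. A minimal standalone sketch of that behavior (names here are illustrative, not from the commit):

import enum

class Recipe(str, enum.Enum):        # same str-mixin pattern as QuantizationName
  NONE = 'none'
  DYNAMIC_INT8 = 'dynamic_int8'

assert Recipe.DYNAMIC_INT8 == 'dynamic_int8'   # the str mixin makes these equal

def describe(quantize: str) -> str:
  match quantize:                    # value patterns compare with ==, so a plain
    case Recipe.NONE:                # string from a flag hits the enum cases
      return 'no quantization'
    case Recipe.DYNAMIC_INT8:
      return 'int8 dynamic-range quantization'
    case _:
      raise ValueError(f'Unsupported quantization flag: {quantize}')

print(describe('dynamic_int8'))      # -> int8 dynamic-range quantization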
@@ -74,10 +96,10 @@ def define_conversion_flags(
       1280,
       'The maximum size of KV cache buffer, including both prefill and decode.',
   )
-  flags.DEFINE_bool(
+  flags.DEFINE_string(
       'quantize',
-      True,
-      'Whether the model should be quantized.',
+      'dynamic_int8',
+      'How the model should be quantized.',
   )
   flags.DEFINE_multi_integer(
       'lora_ranks',
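With `--quantize` now a string flag, any of the `QuantizationName` values can be passed on the command line. A hedged sketch of the parsing behavior (the program name is arbitrary; absl itself accepts any string, validation happens later in `get_quant_recipe_from_flag`):

from absl import flags

# Re-declare the flag exactly as define_conversion_flags does, then simulate
# command-line parsing.
flags.DEFINE_string('quantize', 'dynamic_int8', 'How the model should be quantized.')
flags.FLAGS(['convert_demo', '--quantize=fp16'])
print(flags.FLAGS.quantize)  # -> 'fp16'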
@@ -99,6 +121,66 @@ def define_conversion_flags(
   return flags
 
 
+def get_quant_recipe_from_flag(
+    quantize: str,
+) -> Optional[quant_recipes.QuantizationRecipe]:
+  """Processes the quantization flag and returns the corresponding recipe.
+
+  Args:
+    quantize: The quantization type.
+
+  Returns:
+    The quantization recipe, or None if no quantization is needed.
+
+  Raises:
+    ValueError: If the quantization type is not supported.
+  """
+  match quantize:
+    case QuantizationName.NONE:
+      return None
+    case QuantizationName.DYNAMIC_INT8:
+      return quant_recipes.full_int8_dynamic_recipe()
+    case QuantizationName.WEIGHT_ONLY_INT8:
+      return quant_recipes.full_int8_weight_only_recipe()
+    case QuantizationName.FP16:
+      return quant_recipes.full_fp16_recipe()
+    case QuantizationName.DYNAMIC_INT4_BLOCK32:
+      return quant_recipes.full_int4_dynamic_block_recipe(32)
+    case QuantizationName.DYNAMIC_INT4_BLOCK128:
+      return quant_recipes.full_int4_dynamic_block_recipe(128)
+    case _:
+      raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
+def create_quantize_suffix(quantize: str) -> str:
+  """Creates a suffix for the output file name based on the quantization type.
+
+  Args:
+    quantize: The quantization type.
+
+  Returns:
+    A string representing the quantization suffix.
+
+  Raises:
+    ValueError: If the quantization type is not supported.
+  """
+  match quantize:
+    case QuantizationName.NONE:
+      return 'f32'
+    case QuantizationName.DYNAMIC_INT8:
+      return 'q8'
+    case QuantizationName.WEIGHT_ONLY_INT8:
+      return 'q8_wo'
+    case QuantizationName.FP16:
+      return 'fp16'
+    case QuantizationName.DYNAMIC_INT4_BLOCK32:
+      return 'q4_block32'
+    case QuantizationName.DYNAMIC_INT4_BLOCK128:
+      return 'q4_block128'
+    case _:
+      raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
 def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
   if isinstance(mask_len, list):
     return [
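Putting the two new helpers together: a short sketch (assuming the ai_edge_torch package is importable) of how each flag value maps to a recipe and to the suffix that ends up in the output file name.

from ai_edge_torch.generative.utilities import converter

# 'none' maps to no recipe at all: the model is exported as float32.
assert converter.get_quant_recipe_from_flag(converter.QuantizationName.NONE) is None

# Every supported value has a short suffix used when naming the .tflite output.
for name in converter.QuantizationName:
  print(f'{name.value} -> suffix {converter.create_quantize_suffix(name)!r}')
# e.g. 'dynamic_int8' -> suffix 'q8', 'dynamic_int4_block32' -> suffix 'q4_block32'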
@@ -118,7 +200,7 @@ def convert_to_tflite(
     prefill_seq_len: Union[int, list[int]],
     pixel_values_size: torch.Size = None,
     pixel_seq_len: int = 0,
-    quantize: bool = True,
+    quantize: str = 'dynamic_int8',
     config: cfg.ModelConfig = None,
     lora_ranks: Optional[list[int]] = None,
     export_config: ExportConfig = None,
@@ -164,8 +246,8 @@ def convert_to_tflite(
       embeddings generated by the image encoder with pixel values. The actual
       length of prefill_seq_len will be added by pixel_seq_len when pixel
       values are passed.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
+    quantize (str, optional): The quantization type. Defaults to
+      'dynamic_int8'.
     config (cfg.ModelConfig, optional): The model config used to configure KV
       cache. If None, it uses the config of the pytorch_model.
     lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
@@ -186,7 +268,7 @@ def convert_to_tflite(
       lora = lora_utils.LoRA.zeros(rank, config)
       loras.append(lora)
 
-  quant_suffix = 'q8' if quantize else 'f32'
+  quant_suffix = create_quantize_suffix(quantize)
   kv_size = config.kv_cache_max_len
   lora_suffix = (
       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
@@ -220,7 +302,7 @@ def _export_helper(
     prefill_seq_lens: list[int],
     pixel_values_size: torch.Size,
     pixel_seq_len: int,
-    quantize: bool,
+    quantize: str,
    config: cfg.ModelConfig,
    loras: list[None | lora_utils.LoRA],
    export_config: ExportConfig,
@@ -269,7 +351,7 @@ def _export_helper(
       kv_layout=export_config.kvcache_layout,
   )
 
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  quant_config = get_quant_recipe_from_flag(quantize)
   quant_config._model_config = config
 
   # For export, we create a module that captures any non-exportable,
