 """Common utility functions for model conversion."""

+import enum
 import os
 import pathlib
 from typing import Optional, Union
@@ -42,6 +43,27 @@ def forward(self, *export_args, **export_kwargs):
     return self.module(*export_args, **full_kwargs)


+class QuantizationName(str, enum.Enum):
+  """Strings for all supported quantization recipes.
+
+  none: No quantization.
+  dynamic_int8: Dynamic range quantization with int8 weights.
+  weight_only_int8: Weight only quantization with int8 weights.
+  fp16: Float16 quantization.
+  dynamic_int4_block32: Dynamic range quantization with int4 weights and block
+    size of 32, better model quality but slower inference.
+  dynamic_int4_block128: Dynamic range quantization with int4 weights and block
+    size of 128, faster inference but worse model quality.
+  """
+
+  NONE = 'none'
+  DYNAMIC_INT8 = 'dynamic_int8'
+  WEIGHT_ONLY_INT8 = 'weight_only_int8'
+  FP16 = 'fp16'
+  DYNAMIC_INT4_BLOCK32 = 'dynamic_int4_block32'
+  DYNAMIC_INT4_BLOCK128 = 'dynamic_int4_block128'
+
+
 def define_conversion_flags(
     model_name: str,
     default_mask_as_input: bool = False,
@@ -74,10 +96,10 @@ def define_conversion_flags(
       1280,
       'The maximum size of KV cache buffer, including both prefill and decode.',
   )
-  flags.DEFINE_bool(
+  flags.DEFINE_string(
       'quantize',
-      True,
-      'Whether the model should be quantized.',
+      'dynamic_int8',
+      'How the model should be quantized.',
   )
   flags.DEFINE_multi_integer(
       'lora_ranks',
@@ -99,6 +121,66 @@ def define_conversion_flags(
   return flags


+def get_quant_recipe_from_flag(
+    quantize: str,
+) -> Optional[quant_recipes.QuantizationRecipe]:
+  """Processes the quantization flag and returns the corresponding recipe.
+
+  Args:
+    quantize: The quantization type.
+
+  Returns:
+    The quantization recipe, or None if no quantization is needed.
+
+  Raises:
+    ValueError: If the quantization type is not supported.
+  """
+  match quantize:
+    case QuantizationName.NONE:
+      return None
+    case QuantizationName.DYNAMIC_INT8:
+      return quant_recipes.full_int8_dynamic_recipe()
+    case QuantizationName.WEIGHT_ONLY_INT8:
+      return quant_recipes.full_int8_weight_only_recipe()
+    case QuantizationName.FP16:
+      return quant_recipes.full_fp16_recipe()
+    case QuantizationName.DYNAMIC_INT4_BLOCK32:
+      return quant_recipes.full_int4_dynamic_block_recipe(32)
+    case QuantizationName.DYNAMIC_INT4_BLOCK128:
+      return quant_recipes.full_int4_dynamic_block_recipe(128)
+    case _:
+      raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
+def create_quantize_suffix(quantize: str) -> str:
+  """Creates a suffix for the output file name based on the quantization type.
+
+  Args:
+    quantize: The quantization type.
+
+  Returns:
+    A string representing the quantization suffix.
+
+  Raises:
+    ValueError: If the quantization type is not supported.
+  """
+  match quantize:
+    case QuantizationName.NONE:
+      return 'f32'
+    case QuantizationName.DYNAMIC_INT8:
+      return 'q8'
+    case QuantizationName.WEIGHT_ONLY_INT8:
+      return 'q8_wo'
+    case QuantizationName.FP16:
+      return 'fp16'
+    case QuantizationName.DYNAMIC_INT4_BLOCK32:
+      return 'q4_block32'
+    case QuantizationName.DYNAMIC_INT4_BLOCK128:
+      return 'q4_block128'
+    case _:
+      raise ValueError(f'Unsupported quantization flag: {quantize}')
+
+
 def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
   if isinstance(mask_len, list):
     return [
@@ -118,7 +200,7 @@ def convert_to_tflite(
     prefill_seq_len: Union[int, list[int]],
     pixel_values_size: torch.Size = None,
     pixel_seq_len: int = 0,
-    quantize: bool = True,
+    quantize: str = 'dynamic_int8',
     config: cfg.ModelConfig = None,
     lora_ranks: Optional[list[int]] = None,
     export_config: ExportConfig = None,
@@ -164,8 +246,8 @@ def convert_to_tflite(
       embeddings generated by the image encoder with pixel values. The actual
       length of prefill_seq_len will be added by pixel_seq_len when pixel
       values are passed.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
+    quantize (str, optional): The quantization type. Defaults to
+      'dynamic_int8'.
     config (cfg.ModelConfig, optional): The model config used to configure KV
       cache. If None, it uses the config of the pytorch_model.
     lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
@@ -186,7 +268,7 @@ def convert_to_tflite(
       lora = lora_utils.LoRA.zeros(rank, config)
       loras.append(lora)

-  quant_suffix = 'q8' if quantize else 'f32'
+  quant_suffix = create_quantize_suffix(quantize)
   kv_size = config.kv_cache_max_len
   lora_suffix = (
       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
@@ -220,7 +302,7 @@ def _export_helper(
    prefill_seq_lens: list[int],
    pixel_values_size: torch.Size,
    pixel_seq_len: int,
-    quantize: bool,
+    quantize: str,
    config: cfg.ModelConfig,
    loras: list[None | lora_utils.LoRA],
    export_config: ExportConfig,
@@ -269,7 +351,7 @@ def _export_helper(
       kv_layout=export_config.kvcache_layout,
   )

-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  quant_config = get_quant_recipe_from_flag(quantize)
   quant_config._model_config = config

   # For export, we create a module that captures any non-exportable,
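
For orientation only, here is a minimal sketch (not part of this change) of how the new string-valued quantize flag flows through the two helpers added above. The import path is an assumption about where this module lives, and the loop is purely illustrative.

# Illustrative sketch only: exercises the helpers added in this change.
# The import path below is an assumption about the module's location.
from ai_edge_torch.generative.utilities import converter

for name in converter.QuantizationName:
  # Because QuantizationName subclasses str, members compare equal to their
  # plain flag strings, so either the enum member or the raw value works.
  recipe = converter.get_quant_recipe_from_flag(name.value)
  suffix = converter.create_quantize_suffix(name.value)
  print(f'{name.value}: suffix={suffix}, recipe={type(recipe).__name__}')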