Skip to content

Commit 755e63d

Browse files
[OV] Add quantization for text2text-generation models (#1359)
* Initial commit * Add tests * Fix tests * Extend OVModelForSeq2SeqLM for integration with lm-evaluation-harness * Cleanup * Set default configs for T5 models * Update docs * Move table row higher * Add reference to the original code block
1 parent 1a10775 commit 755e63d

File tree

8 files changed

+235
-46
lines changed

8 files changed

+235
-46
lines changed

docs/source/openvino/optimization.mdx

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,59 @@ Click on a ✅ to copy the command/code for the corresponding optimization case.
623623
</button>
624624
</td>
625625
</tr>
626+
<tr>
627+
<td style="text-align: center; vertical-align: middle;">text2text-generation<br>(OVModelForSeq2SeqLM)</td>
628+
<td style="text-align: center; vertical-align: middle;">
629+
<button
630+
onclick="navigator.clipboard.writeText('optimum-cli export openvino -m google-t5/t5-small --weight-format int8 ./save_dir')">
631+
632+
</button>
633+
</td>
634+
<td style="text-align: center; vertical-align: middle;">
635+
<button
636+
onclick="navigator.clipboard.writeText('OVModelForSeq2SeqLM.from_pretrained(\'google-t5/t5-small\', quantization_config=OVWeightQuantizationConfig(bits=8)).save_pretrained(\'save_dir\')')">
637+
638+
</button>
639+
</td>
640+
<td style="text-align: center; vertical-align: middle;">
641+
<button
642+
onclick="navigator.clipboard.writeText('optimum-cli export openvino -m google-t5/t5-small --weight-format int4 --dataset wikitext2 ./save_dir')">
643+
644+
</button>
645+
</td>
646+
<td style="text-align: center; vertical-align: middle;">
647+
<button
648+
onclick="navigator.clipboard.writeText('OVModelForSeq2SeqLM.from_pretrained(\'google-t5/t5-small\', quantization_config=OVWeightQuantizationConfig(bits=4, dataset=\'wikitext2\')).save_pretrained(\'save_dir\')')">
649+
650+
</button>
651+
</td>
652+
<td style="text-align: center; vertical-align: middle;">–</td>
653+
<td style="text-align: center; vertical-align: middle;">–</td>
654+
<td style="text-align: center; vertical-align: middle;">
655+
<button
656+
onclick="navigator.clipboard.writeText('optimum-cli export openvino -m google-t5/t5-small --quant-mode int8 --dataset wikitext2 --smooth-quant-alpha -1 ./save_dir')">
657+
658+
</button>
659+
</td>
660+
<td style="text-align: center; vertical-align: middle;">
661+
<button
662+
onclick="navigator.clipboard.writeText('OVModelForSeq2SeqLM.from_pretrained(\'google-t5/t5-small\', quantization_config=OVQuantizationConfig(bits=8, dataset=\'wikitext2\', smooth_quant_alpha=-1)).save_pretrained(\'save_dir\')')">
663+
664+
</button>
665+
</td>
666+
<td style="text-align: center; vertical-align: middle;">
667+
<button
668+
onclick="navigator.clipboard.writeText('optimum-cli export openvino -m google-t5/t5-small --quant-mode nf4_f8e4m3 --dataset wikitext2 --smooth-quant-alpha -1 ./save_dir')">
669+
670+
</button>
671+
</td>
672+
<td style="text-align: center; vertical-align: middle;">
673+
<button
674+
onclick="navigator.clipboard.writeText('OVModelForSeq2SeqLM.from_pretrained(\'google-t5/t5-small\', quantization_config=OVMixedQuantizationConfig(OVWeightQuantizationConfig(bits=4, dtype=\'nf4\'), OVQuantizationConfig(dtype=\'f8e4m3\', dataset=\'wikitext2\', smooth_quant_alpha=-1))).save_pretrained(\'save_dir\')')">
675+
676+
</button>
677+
</td>
678+
</tr>
626679
<tr>
627680
<td style="text-align: center; vertical-align: middle;">zero-shot-image-classification<br>(OVModelForZeroShotImageClassification)</td>
628681
<td style="text-align: center; vertical-align: middle;">

optimum/commands/export/openvino.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ def run(self):
482482
and (
483483
task in ["fill-mask", "zero-shot-image-classification"]
484484
or task.startswith("text-generation")
485+
or task.startswith("text2text-generation")
485486
or task.startswith("automatic-speech-recognition")
486487
or task.startswith("feature-extraction")
487488
)
@@ -491,6 +492,10 @@ def run(self):
491492
from optimum.intel import OVModelForCausalLM
492493

493494
model_cls = OVModelForCausalLM
495+
elif task.startswith("text2text-generation"):
496+
from optimum.intel import OVModelForSeq2SeqLM
497+
498+
model_cls = OVModelForSeq2SeqLM
494499
elif task == "image-text-to-text":
495500
from optimum.intel import OVModelForVisualCausalLM
496501

optimum/intel/openvino/configuration.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,24 @@ class OVQuantizationMethod(str, Enum):
361361

362362
# Default configs for int8 full quantization
363363
_DEFAULT_INT8_FQ_CONFIGS = {
364+
"google-t5/t5-small": {
365+
"dtype": "int8",
366+
"dataset": "wikitext2",
367+
"num_samples": 300,
368+
"smooth_quant_alpha": -1,
369+
},
370+
"google-t5/t5-large": {
371+
"dtype": "int8",
372+
"dataset": "wikitext2",
373+
"num_samples": 300,
374+
"smooth_quant_alpha": -1,
375+
},
376+
"google-t5/t5-3b": {
377+
"dtype": "int8",
378+
"dataset": "wikitext2",
379+
"num_samples": 300,
380+
"smooth_quant_alpha": -1,
381+
},
364382
"FacebookAI/roberta-large": {
365383
"dtype": "int8",
366384
"dataset": "wikitext2",

optimum/intel/openvino/modeling_seq2seq.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import copy
1514
import logging
1615
import os
1716
from pathlib import Path
@@ -42,9 +41,9 @@
4241

4342
from ...exporters.openvino import main_export
4443
from ...exporters.openvino.stateful import model_has_state
45-
from .. import OVConfig, OVQuantizer
44+
from .. import OVConfig
4645
from ..utils import is_transformers_version
47-
from .configuration import OVQuantizationConfigBase, OVWeightQuantizationConfig
46+
from .configuration import OVWeightQuantizationConfig
4847
from .modeling_base import OVBaseModel
4948
from .utils import (
5049
ONNX_DECODER_NAME,
@@ -477,7 +476,6 @@ def _from_pretrained(
477476
decoder_with_past_file_name = decoder_with_past_file_name or default_decoder_with_past_file_name
478477
decoder_with_past = None
479478

480-
quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
481479
compile_only = kwargs.pop("compile_only", False)
482480
device = kwargs.pop("device", "CPU")
483481
ov_config = kwargs.pop("ov_config", None)
@@ -521,10 +519,10 @@ def _from_pretrained(
521519
"decoder_with_past": model_save_dir / decoder_with_past_file_name,
522520
}
523521
if not compile_only:
524-
encoder = cls.load_model(file_names["encoder"], quantization_config)
525-
decoder = cls.load_model(file_names["decoder"], quantization_config)
522+
encoder = cls.load_model(file_names["encoder"])
523+
decoder = cls.load_model(file_names["decoder"])
526524
if use_cache and not model_has_state(decoder) and os.path.exists(file_names["decoder_with_past"]):
527-
decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config)
525+
decoder_with_past = cls.load_model(file_names["decoder_with_past"])
528526
else:
529527
model_kwargs = {"device": device, "ov_config": ov_config, "model_save_dir": model_save_dir}
530528
encoder = cls._compile_model(file_names["encoder"], **model_kwargs)
@@ -551,7 +549,8 @@ def _from_pretrained(
551549
"Generation config file not found, using a generation config created from the model config."
552550
)
553551

554-
return cls(
552+
quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
553+
model = cls(
555554
encoder=encoder,
556555
decoder=decoder,
557556
decoder_with_past=decoder_with_past,
@@ -565,6 +564,17 @@ def _from_pretrained(
565564
**kwargs,
566565
)
567566

567+
if quantization_config is not None:
568+
from optimum.intel import OVQuantizer
569+
570+
quantizer = OVQuantizer(model)
571+
quantization_config_copy = quantization_config.clone()
572+
quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id
573+
quantization_config_copy.processor = quantization_config.processor or model_id
574+
quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
575+
576+
return model
577+
568578
@classmethod
569579
def _export(
570580
cls,
@@ -657,12 +667,17 @@ def forward(
657667
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
658668
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
659669
cache_position: Optional[torch.LongTensor] = None,
670+
labels: Optional[torch.LongTensor] = None,
660671
**kwargs,
661672
) -> Seq2SeqLMOutput:
662673
# Encode if needed : first prediction pass
663674
if encoder_outputs is None:
664675
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
665676

677+
if labels is not None and decoder_input_ids is None:
678+
# get decoder inputs from shifting lm labels to the right
679+
decoder_input_ids = self._shift_right(labels)
680+
666681
# Decode
667682
if past_key_values is None or self.decoder_with_past is None:
668683
decoder_outputs = self.decoder(
@@ -786,6 +801,28 @@ def compile(self):
786801
for submodel_name in self._ov_submodel_names:
787802
getattr(self, submodel_name)._compile()
788803

804+
def _shift_right(self, input_ids):
805+
# Adapted from https://github.com/huggingface/transformers/blob/v4.53.1/src/transformers/models/t5/modeling_t5.py (T5PreTrainedModel._shift_right)
806+
decoder_start_token_id = self.config.decoder_start_token_id
807+
pad_token_id = self.config.pad_token_id
808+
809+
if decoder_start_token_id is None:
810+
raise ValueError(
811+
"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
812+
"See T5 docs for more information."
813+
)
814+
815+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
816+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
817+
shifted_input_ids[..., 0] = decoder_start_token_id
818+
819+
if pad_token_id is None:
820+
raise ValueError("self.model.config.pad_token_id has to be defined.")
821+
# replace possible -100 values in labels by `pad_token_id`
822+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
823+
824+
return shifted_input_ids
825+
789826

790827
class OVEncoder:
791828
"""
@@ -1345,27 +1382,9 @@ def _from_pretrained(
13451382
cls,
13461383
model_id: Union[str, Path],
13471384
config: "PretrainedConfig",
1348-
load_in_8bit: bool = False,
1349-
quantization_config: Union[dict, OVQuantizationConfigBase] = None,
13501385
**kwargs,
13511386
):
1352-
compile_only = kwargs.get("compile_only", False)
1353-
1354-
quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
1355-
is_data_aware_quantization = quantization_config is not None and quantization_config.dataset is not None
1356-
if not compile_only and is_data_aware_quantization:
1357-
model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(
1358-
model_id, config, load_in_8bit=False, **kwargs
1359-
)
1360-
quantization_config_copy = copy.deepcopy(quantization_config)
1361-
quantization_config_copy.processor = quantization_config.processor or model_id
1362-
OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
1363-
else:
1364-
model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(
1365-
model_id, config, load_in_8bit=load_in_8bit, quantization_config=quantization_config, **kwargs
1366-
)
1367-
1368-
return model
1387+
return super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs)
13691388

13701389
class DummyWhisperModel:
13711390
def __init__(self):

optimum/intel/openvino/quantization.py

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> OV
270270
OVModelForCausalLM,
271271
OVModelForFeatureExtraction,
272272
OVModelForMaskedLM,
273+
OVModelForSeq2SeqLM,
273274
OVModelForVisualCausalLM,
274275
OVModelForZeroShotImageClassification,
275276
OVSentenceTransformer,
@@ -344,7 +345,9 @@ def build_from_quantization_config(self, config: OVQuantizationConfigBase) -> OV
344345
)
345346

346347
return self.build_from_dataset(config, dataset)
347-
elif isinstance(self.model, (OVModelForFeatureExtraction, OVSentenceTransformer, OVModelForMaskedLM)):
348+
elif isinstance(
349+
self.model, (OVModelForFeatureExtraction, OVSentenceTransformer, OVModelForMaskedLM, OVModelForSeq2SeqLM)
350+
):
348351
if isinstance(config.dataset, str):
349352
dataset_metadata = PREDEFINED_LANGUAGE_DATASETS[config.dataset]
350353
dataset = self.load_dataset(
@@ -467,6 +470,7 @@ def build_from_dataset(
467470
from optimum.intel import (
468471
OVModelForFeatureExtraction,
469472
OVModelForMaskedLM,
473+
OVModelForSeq2SeqLM,
470474
OVModelForVisualCausalLM,
471475
OVModelForZeroShotImageClassification,
472476
OVSentenceTransformer,
@@ -492,6 +496,7 @@ def build_from_dataset(
492496
OVModelForMaskedLM,
493497
OVModelForZeroShotImageClassification,
494498
OVSentenceTransformer,
499+
OVModelForSeq2SeqLM,
495500
),
496501
) or (is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline)):
497502
# Prepare from raw dataset avoiding dataloader creation
@@ -504,6 +509,8 @@ def build_from_dataset(
504509
return self._prepare_visual_causal_lm_calibration_data(quantization_config, dataset)
505510
elif isinstance(self.model, _OVModelForWhisper):
506511
return self._prepare_speech_to_text_calibration_data(quantization_config, dataset)
512+
elif isinstance(self.model, OVModelForSeq2SeqLM):
513+
return self._prepare_text_to_text_calibration_data(quantization_config, dataset)
507514
elif is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline):
508515
return self._prepare_diffusion_calibration_data(quantization_config, dataset)
509516
elif isinstance(self.model, (OVModelForFeatureExtraction, OVSentenceTransformer, OVModelForMaskedLM)):
@@ -770,6 +777,56 @@ def _prepare_speech_to_text_calibration_data(
770777

771778
return OVCalibrationDataset(collected_inputs)
772779

780+
def _prepare_text_to_text_calibration_data(
781+
self,
782+
config: OVQuantizationConfigBase,
783+
dataset: "Dataset",
784+
seq_len: int = 128,
785+
) -> OVCalibrationDataset:
786+
"""
787+
Prepares calibration data for text-to-text pipelines by inferring it on a dataset and collecting incurred inputs.
788+
"""
789+
from optimum.intel.openvino.modeling_seq2seq import OVDecoder, OVEncoder
790+
791+
models: Dict[str, Union[OVEncoder, OVDecoder]] = {}
792+
collected_inputs: Dict[str, List[Dict[str, Any]]] = {}
793+
for submodel_name in self.model._ov_submodel_names:
794+
ov_component: Union[OVEncoder, OVDecoder] = getattr(self.model, submodel_name)
795+
models[submodel_name] = ov_component
796+
collected_inputs[submodel_name] = []
797+
ov_component._compile()
798+
ov_component.request = InferRequestWrapper(
799+
ov_component.request, collected_inputs[submodel_name], apply_caching=True
800+
)
801+
try:
802+
803+
def get_tokenizer():
804+
if config.tokenizer is None:
805+
raise ValueError("Please provide tokenizer for calibration via quantization_config.tokenizer.")
806+
return AutoTokenizer.from_pretrained(config.tokenizer, trust_remote_code=config.trust_remote_code)
807+
808+
num_samples = config.num_samples or 128
809+
dataset = list(tqdm(dataset.take(num_samples), desc="Downloading dataset", total=num_samples))
810+
811+
tokenizer = None
812+
for item in tqdm(dataset, desc="Collecting calibration data"):
813+
if "input_ids" in item:
814+
# Assuming that dataset contains already preprocessed text
815+
inputs = self._wrap_sample_as_array(item, add_batch_dim=True)
816+
else:
817+
tokenizer = tokenizer or get_tokenizer()
818+
inputs = tokenizer(item["text"], truncation=True, max_length=seq_len, return_tensors="pt")
819+
820+
self.model.generate(**inputs, max_new_tokens=seq_len)
821+
finally:
822+
for model in models.values():
823+
model.request = model.request.request
824+
825+
for model_name in collected_inputs:
826+
collected_inputs[model_name] = nncf.Dataset(collected_inputs[model_name])
827+
828+
return OVCalibrationDataset(collected_inputs)
829+
773830
def _prepare_diffusion_calibration_data(
774831
self, config: OVQuantizationConfigBase, dataset: Union[List, "Dataset"]
775832
) -> OVCalibrationDataset:
@@ -1202,18 +1259,16 @@ def _quantize_ovbasemodel(
12021259
#
12031260
# Regular (non-hybrid) weight-only quantization
12041261
#
1205-
if is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline):
1206-
for submodel_name in self.model.ov_submodels:
1207-
quantization_configs[submodel_name] = quantization_config
1208-
elif isinstance(self.model, OVModelForVisualCausalLM):
1262+
if isinstance(self.model, OVModelForVisualCausalLM):
12091263
for submodel_name in self.model.ov_submodels:
12101264
quantization_configs[submodel_name] = (
12111265
quantization_config
12121266
if submodel_name == "lm_model"
12131267
else OVWeightQuantizationConfig(bits=8, sym=True)
12141268
)
12151269
else:
1216-
quantization_configs["model"] = quantization_config
1270+
for submodel_name in self.model.ov_submodels:
1271+
quantization_configs[submodel_name] = quantization_config
12171272
else:
12181273
#
12191274
# Hybrid/Full/Mixed quantization
@@ -1274,15 +1329,17 @@ def _quantize_ovbasemodel(
12741329
else OVWeightQuantizationConfig(bits=8, sym=True)
12751330
)
12761331
else:
1277-
quantization_configs["model"] = quantization_config
1332+
for submodel_name in self.model.ov_submodels:
1333+
quantization_configs[submodel_name] = quantization_config
12781334
elif isinstance(quantization_config, OVMixedQuantizationConfig):
12791335
#
12801336
# Mixed quantization
12811337
#
12821338
if is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline):
12831339
raise NotImplementedError("Mixed precision quantization isn't supported for diffusers.")
12841340

1285-
quantization_configs["model"] = quantization_config
1341+
for submodel_name in self.model.ov_submodels:
1342+
quantization_configs[submodel_name] = quantization_config
12861343
else:
12871344
raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}")
12881345

0 commit comments

Comments
 (0)