Commit a5ca56f

Supporting seq2seq models for bitsandbytes integration (#18579)
* Supporting seq2seq models for `bitsandbytes` integration
  - the `bitsandbytes` integration now supports seq2seq models
  - check whether a model has tied weights as an additional check

* Small modification
  - tie the weights before looking for tied weights!
1 parent ed1924e commit a5ca56f
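For context, the sketch below shows the kind of call this commit enables: loading a seq2seq checkpoint with `load_in_8bit=True` and `device_map="auto"`, as the new test does with `t5-small`. It is illustrative rather than part of the diff, and assumes a CUDA GPU with `bitsandbytes` and `accelerate` installed; the prompt and generation settings are arbitrary.

```python
# Illustrative usage sketch (not part of this commit): load a seq2seq model in 8-bit.
# Assumes a CUDA device and that `bitsandbytes` and `accelerate` are installed.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "t5-small"  # the checkpoint exercised by the new test
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name, load_in_8bit=True, device_map="auto"
)

# Run a small generation to confirm the int8 model works end to end.
inputs = tokenizer("translate English to German: Hello, how are you?", return_tensors="pt").to(0)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```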

File tree: 2 files changed, +33 -3 lines

  - src/transformers/utils/bitsandbytes.py
  - tests/mixed_int8/test_mixed_int8.py

src/transformers/utils/bitsandbytes.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 from transformers.utils import is_accelerate_available, is_bitsandbytes_available
 
 
@@ -9,6 +11,7 @@
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
+    from accelerate.utils import find_tied_parameters
 
 
 def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None):
@@ -132,8 +135,17 @@ def get_key_to_not_convert(model):
     model (`torch.nn.Module`):
         Input model
     """
+    # Create a copy of the model and tie the weights, then
+    # check if it contains tied weights
+    tied_model = deepcopy(model)  # this has 0 cost since it is done inside the `init_empty_weights` context manager
+    tied_model.tie_weights()
+    has_tied_params = len(find_tied_parameters(tied_model)) > 0
+
+    # Check if it is a base model
+    is_base_model = not hasattr(model, model.base_model_prefix)
+
     # Ignore this for base models (BertModel, GPT2Model, etc.)
-    if not hasattr(model, model.base_model_prefix):
+    if (not has_tied_params) and is_base_model:
         return ""
 
     # otherwise they have an attached head
```
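To show the new check in isolation, here is a hedged sketch of the same logic outside `get_key_to_not_convert`. It builds an empty (meta-device) model with `init_empty_weights`, which is why the `deepcopy` above is described as free, and it only uses the length of the `find_tied_parameters` result, since the exact return format may differ between `accelerate` versions.

```python
# Standalone sketch of the tied-weights check added above (illustrative, not from the diff).
# Assumes `accelerate` is installed; t5-small is just an example of a model whose input
# and output embeddings are tied.
from copy import deepcopy

from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained("t5-small")
with init_empty_weights():
    # Parameters are created on the meta device, so copying the model is essentially free.
    model = AutoModelForSeq2SeqLM.from_config(config)

tied_model = deepcopy(model)
tied_model.tie_weights()  # tie first, otherwise the tied weights may not be detected yet
has_tied_params = len(find_tied_parameters(tied_model)) > 0

is_base_model = not hasattr(model, model.base_model_prefix)

# For a tied seq2seq model such as T5, `has_tied_params` is True, so the
# `(not has_tied_params) and is_base_model` early return is skipped and the
# head is kept out of the int8 conversion.
print(has_tied_params, is_base_model)
```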

tests/mixed_int8/test_mixed_int8.py

Lines changed: 20 additions & 2 deletions
```diff
@@ -15,7 +15,14 @@
 import gc
 import unittest
 
-from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, pipeline
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    pipeline,
+)
 from transformers.testing_utils import (
     is_torch_available,
     require_accelerate,
@@ -106,12 +113,21 @@ def setUp(self):
         super().setUp()
         # model_name
         self.model_name = "bigscience/bloom-560m"
-        # Models and tokenizer
+        self.seq_to_seq_name = "t5-small"
+
+        # Different types of model
+
         self.base_model = AutoModel.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+        # Sequence classification model
         self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
             self.model_name, load_in_8bit=True, device_map="auto"
         )
+        # CausalLM model
         self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
+        # Seq2seq model
+        self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
+            self.seq_to_seq_name, load_in_8bit=True, device_map="auto"
+        )
 
     def tearDown(self):
         r"""
@@ -121,6 +137,7 @@ def tearDown(self):
         del self.base_model
         del self.sequence_model
         del self.model_8bit
+        del self.seq_to_seq_model
 
         gc.collect()
         torch.cuda.empty_cache()
@@ -138,6 +155,7 @@ def test_correct_head_class(self):
         # Other heads should be nn.Parameter
         self.assertTrue(self.model_8bit.lm_head.weight.__class__ == torch.nn.Parameter)
         self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
+        self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)
 
 
 class MixedInt8TestPipeline(BaseMixedInt8Test):
```
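As a manual spot-check mirroring the new `test_correct_head_class` assertion, one could inspect the weight classes of an int8-loaded seq2seq model directly. The sketch below is an assumption-laden illustration: it needs a CUDA machine with `bitsandbytes` and `accelerate` installed, and the path into the T5 decoder block is only an example of an inner linear layer.

```python
# Manual spot-check of the new assertion (sketch, not part of the diff).
import bitsandbytes as bnb
import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")

# The output head kept back by `get_key_to_not_convert` should remain a plain Parameter...
assert model.lm_head.weight.__class__ == torch.nn.Parameter

# ...while ordinary linear layers inside the blocks are converted to 8-bit layers.
inner_linear = model.decoder.block[0].layer[0].SelfAttention.q  # example path into t5-small
print(isinstance(inner_linear, bnb.nn.Linear8bitLt))  # expected: True
```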
