
Commit 4a51075

younesbelkada, sgugger, and michaelbenayoun authored
bitsandbytes - Linear8bitLt integration into transformers models (#17901)
* first commit
* correct replace function
* add final changes - works like charm! - cannot implement tests yet - tested
* clean up a bit
* add bitsandbytes dependencies
* working version - added import function - added bitsandbytes utils file
* small fix
* small fix - fix import issue
* fix import issues
* Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]>
* refactor a bit - move bitsandbytes utils to utils - change comments on functions
* reformat docstring - reformat docstring on init_empty_weights_8bit
* Update src/transformers/__init__.py Co-authored-by: Sylvain Gugger <[email protected]>
* revert bad formatting
* change to bitsandbytes
* refactor a bit - remove init8bit since it is useless
* more refactoring - fixed init empty weights issue - added threshold param
* small hack to make it work
* Update src/transformers/modeling_utils.py
* Update src/transformers/modeling_utils.py
* remove the small hack
* modify utils file
* make style + refactor a bit
* correctly create device map
* add correct dtype for device map creation
* Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]>
* apply suggestions - remove with torch.grad - do not rely on Python bool magic!
* add docstring - add docstring for new kwargs
* add docstring - comment `replace_8bit_linear` function - fix weird formatting
* added more documentation - added new utility function for memory footprint tracking - colab demo to add
* few modifs - typo doc - force cast into float16 when load_in_8bit is enabled
* added colab link
* add test architecture + docstring a bit
* refactor a bit testing class
* make style + refactor a bit
* enhance checks - add more checks - start writing saving test
* clean up a bit
* make style
* add more details on doc
* add more tests - still needs to fix 2 tests
* replace by "or" - could not fix it from GitHub GUI Co-authored-by: Sylvain Gugger <[email protected]>
* refactor a bit testing code + add readme
* make style
* fix import issue
* Update src/transformers/modeling_utils.py Co-authored-by: Michael Benayoun <[email protected]>
* add few comments
* add more docstring + make style
* more docstring
* raise error when loaded in 8bit
* make style
* add warning if loaded on CPU
* add small sanity check
* fix small comment
* add bitsandbytes on dockerfile
* Improve documentation - improve documentation from comments
* add few comments
* slow tests pass on the VM but not on the CI VM
* Fix merge conflict
* make style
* another test should pass on a multi gpu setup
* fix bad import in testing file
* Fix slow tests - remove dummy batches - no more CUDA illegal memory errors
* modify dockerfile
* Update docs/source/en/main_classes/model.mdx
* Update Dockerfile
* Update model.mdx
* Update Dockerfile
* Apply suggestions from code review
* few modifications - lm head can stay on disk/cpu - change model name so that test pass
* change test value - change test value to the correct output - torch bmm changed to baddmm in bloom modeling when merging
* modify installation guidelines
* Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]>
* Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]>
* Apply suggestions from code review Co-authored-by: Sylvain Gugger <[email protected]>
* replace `n` by `name`
* merge `load_in_8bit` and `low_cpu_mem_usage`
* first try - keep the lm head in full precision
* better check - check the attribute `base_model_prefix` instead of computing the number of parameters
* added more tests
* Update src/transformers/utils/bitsandbytes.py Co-authored-by: Sylvain Gugger <[email protected]>
* Merge branch 'integration-8bit' of https://github.com/younesbelkada/transformers into integration-8bit
* improve documentation - fix typos for installation - change title in the documentation

Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Michael Benayoun <[email protected]>
1 parent 8cf4a6f commit 4a51075

9 files changed: +534 additions, -8 deletions


docker/transformers-all-latest-gpu/Dockerfile

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0"
 
 RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate
 
+# Add bitsandbytes for mixed int8 testing
+RUN python3 -m pip install -i https://test.pypi.org/simple/ bitsandbytes==0.31.5
+
 RUN python3 -m pip install --no-cache-dir decord
 
 # When installing in editable mode, `transformers` is not recognized as a package.

docs/source/en/main_classes/model.mdx

Lines changed: 40 additions & 1 deletion
@@ -105,7 +105,7 @@ You can also write your own device map following the same format (a dictionary l
 device_map = {"shared": 0, "encoder": 0, "decoder": 1, "lm_head": 1}
 ```
 
-Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`).
+Another way to minimize the memory impact of your model is to instantiate it at a lower precision dtype (like `torch.float16`) or use direct quantization techniques as described below.
 
 ### Model Instantiation dtype
 
@@ -133,6 +133,45 @@ model = AutoModel.from_config(config)
 
 Due to Pytorch design, this functionality is only available for floating dtypes.
 
+### `bitsandbytes` integration for Int8 mixed-precision matrix decomposition
+
+From the paper `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale`, we support Hugging Face 🤗 integration for all models in the Hub with a few lines of code.
+For models trained in half precision (either `float16` or `bfloat16`) or full precision, this method aims to reduce the size of `nn.Linear` layers by 2x (if trained in half precision) or 4x (if trained in full precision), without affecting quality too much, by handling the outliers in half precision.
+This technique works well for billion-scale models (>1B parameters), so we advise you to use it only for models of that scale. It has been tested on models ranging from 2 billion to 176 billion parameters and supports only PyTorch models.
+
+![HFxbitsandbytes.png](https://s3.amazonaws.com/moonup/production/uploads/1659861207959-62441d1d9fdefb55a0b7d12c.png)
+
+Int8 mixed-precision matrix decomposition works by separating a matrix multiplication into two streams: (1) a systematic feature outlier stream matrix-multiplied in fp16 (0.01% of values), and (2) a regular stream of int8 matrix multiplication (99.9% of values). With this method, int8 inference with no predictive degradation is possible for very large models (>=176B parameters).
+Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models (small models, fine-tuning).
+
+Note also that you need a GPU to run mixed-8bit models, as the kernels have been compiled for GPUs only. Make sure that you have enough GPU RAM to store a quarter of the model (or half, if your model weights are natively in half precision) before using this feature.
+
+Below are some notes to help you use this module, or follow this demo on Google Colab: [![Open In Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qOjXfQIAULfKvZqwCen8-MoWKGdSatZ4?usp=sharing)
+
+#### Requirements
+
+- Make sure you run this on an NVIDIA GPU that supports 8-bit tensor cores (Turing or Ampere GPUs - e.g. T4, RTX 20s, RTX 30s, A40-A100). Note that previous generations of NVIDIA GPUs do not support 8-bit tensor cores.
+- Install the correct version of `bitsandbytes` by running:
+`pip install -i https://test.pypi.org/simple/ bitsandbytes`
+- Install `accelerate`:
+`pip install accelerate`
+
+#### Running mixed-int8 models
+
+After carefully installing the required libraries, the way to load your mixed 8-bit model is as follows:
+```py
+model_name = "bigscience/bloom-2b5"
+model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
+```
+The implementation supports multi-GPU setups thanks to `accelerate` as the backend. If you want to control how much GPU memory to allocate on each GPU, use the `max_memory` argument as follows:
+(For example, to allocate `1GB` on GPU 0 and `2GB` on GPU 1, use `max_memory={0: "1GB", 1: "2GB"}`.)
+```py
+max_memory_mapping = {0: "1GB", 1: "2GB"}
+model_name = "bigscience/bloom-3b"
+model_8bit = AutoModelForCausalLM.from_pretrained(
+    model_name, device_map="auto", load_in_8bit=True, max_memory=max_memory_mapping
+)
+```
 
 
 ## ModuleUtilsMixin
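The `modeling_utils.py` changes below also expose an `int8_threshold` kwarg on `from_pretrained` for tuning the outlier threshold discussed in the documentation above. The following is a minimal sketch of how it could be combined with the documented loading flow; the model name is reused from the examples above and the threshold value is purely illustrative, not a recommendation:

```py
from transformers import AutoModelForCausalLM

# Load BLOOM-3B in mixed int8 with a lower outlier threshold than the default of 6.0.
# A lower threshold keeps more values in fp16, which can help smaller or less stable models.
model_name = "bigscience/bloom-3b"
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=True,
    int8_threshold=4.0,  # illustrative value; 6.0 is the default documented in the docstring below
)
```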

src/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -462,6 +462,7 @@
         "is_vision_available",
         "logging",
     ],
+    "utils.bitsandbytes": [],
 }
 
 # sentencepiece-backed objects

src/transformers/modeling_utils.py

Lines changed: 95 additions & 7 deletions
@@ -61,6 +61,7 @@
     copy_func,
     has_file,
     is_accelerate_available,
+    is_bitsandbytes_available,
     is_offline_mode,
     logging,
     replace_return_docstrings,
@@ -83,6 +84,9 @@
 else:
     get_balanced_memory = None
 
+if is_bitsandbytes_available():
+    from .utils.bitsandbytes import get_key_to_not_convert, replace_8bit_linear, set_module_8bit_tensor_to_device
+
 logger = logging.get_logger(__name__)
 
 
@@ -501,6 +505,7 @@ def _load_state_dict_into_meta_model(
     state_dict_folder=None,
     state_dict_index=None,
     dtype=None,
+    load_in_8bit=False,
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -561,13 +566,14 @@
                 # TODO: group all errors and raise at the end.
                 raise ValueError(f"{param_name} doesn't have any device set.")
             param_device = device_map[module_name]
-
             if param_device == "disk":
                 offload_index = offload_weight(param, param_name, offload_folder, offload_index)
             elif param_device == "cpu" and state_dict_index is not None:
                 state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
-            else:
+            elif not load_in_8bit:
                 set_module_tensor_to_device(model, param_name, param_device, value=param)
+            else:
+                set_module_8bit_tensor_to_device(model, param_name, param_device, value=param)
 
     return error_msgs, offload_index, state_dict_index
 
@@ -1578,6 +1584,24 @@ def save_pretrained(
             save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
         )
 
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and are not registered as parameters, e.g. the mean and std
+                in batch norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
         r"""
@@ -1707,6 +1731,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
                 RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
                 `True` when there is some disk offload.
+            load_in_8bit (`bool`, *optional*, defaults to `False`):
+                If `True`, will convert the loaded model into a mixed-8bit quantized model. To use this feature please
+                install `bitsandbytes` compiled with your CUDA version by running `pip install -i
+                https://test.pypi.org/simple/ bitsandbytes-cudaXXX` where XXX is your CUDA version (e.g. 11.6 = 116).
+                Also make sure that you have enough GPU RAM to store half of the model size since the 8-bit modules are
+                not compiled and adapted for CPUs.
+            int8_threshold (`float`, *optional*, defaults to 6):
+                Works together with `load_in_8bit`. This corresponds to the outlier threshold for outlier detection as
+                described in the `GPT3.int8() : 8-bit Matrix Multiplication for Transformers at Scale` paper. Any hidden
+                states value that is above this threshold will be considered an outlier and the operation on those
+                values will be done in fp16. Values are usually normally distributed, that is, most values are in the
+                range [-3.5, 3.5], but there are some exceptional systematic outliers that are very differently
+                distributed for large models. These outliers are often in the interval [-60, -6] or [6, 60]. Int8
+                quantization works well for values of magnitude ~5, but beyond that, there is a significant performance
+                penalty. A good default threshold is 6, but a lower threshold might be needed for more unstable models
+                (small models, fine-tuning).
             subfolder (`str`, *optional*, defaults to `""`):
                 In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                 specify the folder name here.
@@ -1796,15 +1836,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
         offload_folder = kwargs.pop("offload_folder", None)
-        offload_state_dict = kwargs.pop("offload_state_dict", None)
+        offload_state_dict = kwargs.pop("offload_state_dict", False)
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
+        int8_threshold = kwargs.pop("int8_threshold", 6.0)
         subfolder = kwargs.pop("subfolder", "")
 
         if trust_remote_code is True:
             logger.warning(
                 "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
                 " ignored."
             )
-
         if device_map is not None:
             if low_cpu_mem_usage is None:
                 low_cpu_mem_usage = True
@@ -1824,6 +1865,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 "Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
             )
 
+        if load_in_8bit:
+            if not (is_accelerate_available() and is_bitsandbytes_available()):
+                raise ImportError(
+                    "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
+                    " bitsandbytes: `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
+                    " `pip install bitsandbytes`"
+                )
+            if torch_dtype == "auto" or torch_dtype != torch.float16:
+                # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
+                torch_dtype = torch.float16
+                logger.info("Loading the model in mixed int8 - forcing the weights to be cast in float16")
+            if device_map is None:
+                raise ValueError(
+                    "A device map needs to be passed to convert models into mixed-int8 format. Please run"
+                    " `.from_pretrained` with `device_map='auto'`"
+                )
+            if from_tf or from_flax:
+                raise ValueError(
+                    "Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make"
+                    " sure the weights are in PyTorch format."
+                )
+
         from_pt = not (from_tf | from_flax)
 
         user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -2063,12 +2126,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
-        elif low_cpu_mem_usage:
+        elif load_in_8bit or low_cpu_mem_usage:
             init_contexts.append(init_empty_weights())
 
         with ContextManagers(init_contexts):
             model = cls(config, *model_args, **model_kwargs)
 
+        if load_in_8bit:
+            logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
+
+            # We never convert lm_head or any last modules for numerical stability reasons
+            modules_to_not_convert = get_key_to_not_convert(model)
+            model = replace_8bit_linear(model, threshold=int8_threshold, modules_to_not_convert=modules_to_not_convert)
+
         if isinstance(device_map, str):
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
@@ -2091,9 +2161,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             # Make sure tied weights are tied before creating the device map.
             model.tie_weights()
             device_map = infer_auto_device_map(
-                model, no_split_module_classes=no_split_modules, dtype=torch_dtype, max_memory=max_memory
+                model,
+                no_split_module_classes=no_split_modules,
+                dtype=torch_dtype if not load_in_8bit else torch.int8,
+                max_memory=max_memory,
             )
 
+            if load_in_8bit:
+                # The LM head can stay on disk / CPU
+                device_map_without_lm_head = {
+                    key: device_map[key] for key in device_map.keys() if key != modules_to_not_convert
+                }
+                if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
+                    raise ValueError("8-bit operations on `bitsandbytes` are not supported under CPU!")
+                del device_map_without_lm_head
+
         if from_tf:
             if resolved_archive_file.endswith(".index"):
                 # Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -2145,6 +2227,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             offload_folder=offload_folder,
             offload_state_dict=offload_state_dict,
             dtype=torch_dtype,
+            load_in_8bit=load_in_8bit,
         )
 
         # make sure token embedding weights are still tied if needed
@@ -2185,6 +2268,7 @@ def _load_pretrained_model(
         offload_folder=None,
         offload_state_dict=None,
         dtype=None,
+        load_in_8bit=False,
     ):
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
@@ -2250,7 +2334,10 @@ def _fix_key(key):
                     key = ".".join(key.split(".")[1:])
                 param = model_state_dict[key]
                 if param.device == torch.device("meta"):
-                    set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
+                    if not load_in_8bit:
+                        set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
+                    else:
+                        set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
 
         # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
         if _fast_init:
@@ -2359,6 +2446,7 @@ def _find_mismatched_keys(
                         state_dict_folder=state_dict_folder,
                         state_dict_index=state_dict_index,
                         dtype=dtype,
+                        load_in_8bit=load_in_8bit,
                     )
                     error_msgs += new_error_msgs
                 else:
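The `get_memory_footprint` helper added above makes it easy to verify the memory savings from 8-bit loading. Below is a minimal sketch, assuming a GPU environment with `bitsandbytes` and `accelerate` installed as described in the documentation changes; the model name is reused from the doc examples and the printed numbers will vary by model and setup:

```py
import torch
from transformers import AutoModelForCausalLM

model_name = "bigscience/bloom-2b5"

# Reference model in fp16 vs. the mixed-int8 version loaded with the new kwarg.
model_fp16 = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
model_8bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)

# get_memory_footprint returns bytes (parameters plus buffers when return_buffers=True, the default).
print(f"fp16 footprint: {model_fp16.get_memory_footprint() / 1024**3:.2f} GiB")
print(f"int8 footprint: {model_8bit.get_memory_footprint() / 1024**3:.2f} GiB")
```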
