Commit 524e5f9

Speedup model loading by 4-5x in Diffusers ⚡ (#3674)
* update
* update
* make style
* update
* merge if statements
1 parent d6c986c commit 524e5f9

File tree

1 file changed: +25 -12 lines changed


src/accelerate/utils/modeling.py

Lines changed: 25 additions & 12 deletions
@@ -222,6 +222,8 @@ def set_module_tensor_to_device(
     dtype: Optional[Union[str, torch.dtype]] = None,
     fp16_statistics: Optional[torch.HalfTensor] = None,
     tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
+    non_blocking: bool = False,
+    clear_cache: bool = True,
 ):
     """
     A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
@@ -245,6 +247,10 @@ def set_module_tensor_to_device(
             A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
             execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
             device for all others, instead of duplicating memory.
+        non_blocking (`bool`, *optional*, defaults to `False`):
+            If `True`, the device transfer will be asynchronous with respect to the host, if possible.
+        clear_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not to clear the device cache after setting the tensor on the device.
     """
     # Recurse if needed
     if "." in tensor_name:
@@ -295,9 +301,9 @@ def set_module_tensor_to_device(
 
     if dtype is None:
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
-        value = value.to(old_value.dtype)
+        value = value.to(old_value.dtype, non_blocking=non_blocking)
     elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-        value = value.to(dtype)
+        value = value.to(dtype, non_blocking=non_blocking)
 
     device_quantization = None
     with torch.no_grad():
@@ -326,15 +332,15 @@ def set_module_tensor_to_device(
         if "xpu" in str(device) and not is_xpu_available():
             raise ValueError(f'{device} is not available, you should use device="cpu" instead')
         if value is None:
-            new_value = old_value.to(device)
+            new_value = old_value.to(device, non_blocking=non_blocking)
             if dtype is not None and device in ["meta", torch.device("meta")]:
                 if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-                    new_value = new_value.to(dtype)
+                    new_value = new_value.to(dtype, non_blocking=non_blocking)
 
                 if not is_buffer:
                     module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
         elif isinstance(value, torch.Tensor):
-            new_value = value.to(device)
+            new_value = value.to(device, non_blocking=non_blocking)
         else:
             new_value = torch.tensor(value, device=device)
         if device_quantization is not None:
@@ -347,24 +353,30 @@ def set_module_tensor_to_device(
             if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
                 if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
                     # downcast to fp16 if any - needed for 8bit serialization
-                    new_value = new_value.to(torch.float16)
+                    new_value = new_value.to(torch.float16, non_blocking=non_blocking)
                 # quantize module that are going to stay on the cpu so that we offload quantized weights
                 if device == "cpu" and param_cls.__name__ == "Int8Params":
                     new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
                     new_value.CB = new_value.CB.to("cpu")
                     new_value.SCB = new_value.SCB.to("cpu")
                 else:
-                    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
+                    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(
+                        device, non_blocking=non_blocking
+                    )
             elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
-                new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device)
+                new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(
+                    device, non_blocking=non_blocking
+                )
             elif param_cls.__name__ in ["AffineQuantizedTensor"]:
-                new_value = new_value.to(device)
+                new_value = new_value.to(device, non_blocking=non_blocking)
             else:
-                new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
+                new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(
+                    device, non_blocking=non_blocking
+                )
 
             module._parameters[tensor_name] = new_value
             if fp16_statistics is not None:
-                module._parameters[tensor_name].SCB = fp16_statistics.to(device)
+                module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking)
                 del fp16_statistics
             # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
             if (
@@ -390,8 +402,9 @@ def set_module_tensor_to_device(
                 device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
                 if not getattr(module.weight, "quant_state", None) and device_index is not None:
                     module.weight = module.weight.cuda(device_index)
+
     # clean pre and post forward hook
-    if device != "cpu":
+    if clear_cache and device != "cpu":
         clear_device_cache()
 
     # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
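Together, the two new flags let a caller queue many tensor placements without blocking the host and defer the cache clear, paying the synchronization cost once instead of per tensor, which is where the loading speedup comes from. A minimal sketch of such a loading loop; the `fast_load` helper and the `model`/`state_dict` names are illustrative and not part of this commit:

import torch
from accelerate.utils import set_module_tensor_to_device

def fast_load(model: torch.nn.Module, state_dict: dict, device: str = "cuda") -> None:
    # Queue every transfer asynchronously and skip the per-tensor cache clear.
    for name, tensor in state_dict.items():
        set_module_tensor_to_device(
            model,
            name,
            device,
            value=tensor,
            non_blocking=True,   # don't block the host on each copy
            clear_cache=False,   # one cache clear at the end beats one per tensor
        )
    # Synchronize once so all queued copies have landed before the model is used.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize()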
