@@ -222,6 +222,8 @@ def set_module_tensor_to_device(
222
222
dtype : Optional [Union [str , torch .dtype ]] = None ,
223
223
fp16_statistics : Optional [torch .HalfTensor ] = None ,
224
224
tied_params_map : Optional [dict [int , dict [torch .device , torch .Tensor ]]] = None ,
225
+ non_blocking : bool = False ,
226
+ clear_cache : bool = True ,
225
227
):
226
228
"""
227
229
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
@@ -245,6 +247,10 @@ def set_module_tensor_to_device(
245
247
A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
246
248
execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
247
249
device for all others, instead of duplicating memory.
250
+ non_blocking (`bool`, *optional*, defaults to `False`):
251
+ If `True`, the device transfer will be asynchronous with respect to the host, if possible.
252
+ clear_cache (`bool`, *optional*, defaults to `True`):
253
+ Whether or not to clear the device cache after setting the tensor on the device.
248
254
"""
249
255
# Recurse if needed
250
256
if "." in tensor_name :
@@ -295,9 +301,9 @@ def set_module_tensor_to_device(
295
301
296
302
if dtype is None :
297
303
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
298
- value = value .to (old_value .dtype )
304
+ value = value .to (old_value .dtype , non_blocking = non_blocking )
299
305
elif not str (value .dtype ).startswith (("torch.uint" , "torch.int" , "torch.bool" )):
300
- value = value .to (dtype )
306
+ value = value .to (dtype , non_blocking = non_blocking )
301
307
302
308
device_quantization = None
303
309
with torch .no_grad ():
@@ -326,15 +332,15 @@ def set_module_tensor_to_device(
326
332
if "xpu" in str (device ) and not is_xpu_available ():
327
333
raise ValueError (f'{ device } is not available, you should use device="cpu" instead' )
328
334
if value is None :
329
- new_value = old_value .to (device )
335
+ new_value = old_value .to (device , non_blocking = non_blocking )
330
336
if dtype is not None and device in ["meta" , torch .device ("meta" )]:
331
337
if not str (old_value .dtype ).startswith (("torch.uint" , "torch.int" , "torch.bool" )):
332
- new_value = new_value .to (dtype )
338
+ new_value = new_value .to (dtype , non_blocking = non_blocking )
333
339
334
340
if not is_buffer :
335
341
module ._parameters [tensor_name ] = param_cls (new_value , requires_grad = old_value .requires_grad )
336
342
elif isinstance (value , torch .Tensor ):
337
- new_value = value .to (device )
343
+ new_value = value .to (device , non_blocking = non_blocking )
338
344
else :
339
345
new_value = torch .tensor (value , device = device )
340
346
if device_quantization is not None :
@@ -347,24 +353,30 @@ def set_module_tensor_to_device(
347
353
if param_cls .__name__ in ["Int8Params" , "FP4Params" , "Params4bit" ]:
348
354
if param_cls .__name__ == "Int8Params" and new_value .dtype == torch .float32 :
349
355
# downcast to fp16 if any - needed for 8bit serialization
350
- new_value = new_value .to (torch .float16 )
356
+ new_value = new_value .to (torch .float16 , non_blocking = non_blocking )
351
357
# quantize module that are going to stay on the cpu so that we offload quantized weights
352
358
if device == "cpu" and param_cls .__name__ == "Int8Params" :
353
359
new_value = param_cls (new_value , requires_grad = old_value .requires_grad , ** kwargs ).to (0 ).to ("cpu" )
354
360
new_value .CB = new_value .CB .to ("cpu" )
355
361
new_value .SCB = new_value .SCB .to ("cpu" )
356
362
else :
357
- new_value = param_cls (new_value , requires_grad = old_value .requires_grad , ** kwargs ).to (device )
363
+ new_value = param_cls (new_value , requires_grad = old_value .requires_grad , ** kwargs ).to (
364
+ device , non_blocking = non_blocking
365
+ )
358
366
elif param_cls .__name__ in ["QTensor" , "QBitsTensor" ]:
359
- new_value = torch .nn .Parameter (new_value , requires_grad = old_value .requires_grad ).to (device )
367
+ new_value = torch .nn .Parameter (new_value , requires_grad = old_value .requires_grad ).to (
368
+ device , non_blocking = non_blocking
369
+ )
360
370
elif param_cls .__name__ in ["AffineQuantizedTensor" ]:
361
- new_value = new_value .to (device )
371
+ new_value = new_value .to (device , non_blocking = non_blocking )
362
372
else :
363
- new_value = param_cls (new_value , requires_grad = old_value .requires_grad ).to (device )
373
+ new_value = param_cls (new_value , requires_grad = old_value .requires_grad ).to (
374
+ device , non_blocking = non_blocking
375
+ )
364
376
365
377
module ._parameters [tensor_name ] = new_value
366
378
if fp16_statistics is not None :
367
- module ._parameters [tensor_name ].SCB = fp16_statistics .to (device )
379
+ module ._parameters [tensor_name ].SCB = fp16_statistics .to (device , non_blocking = non_blocking )
368
380
del fp16_statistics
369
381
# as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
370
382
if (
@@ -390,8 +402,9 @@ def set_module_tensor_to_device(
390
402
device_index = torch .device (device ).index if torch .device (device ).type == "cuda" else None
391
403
if not getattr (module .weight , "quant_state" , None ) and device_index is not None :
392
404
module .weight = module .weight .cuda (device_index )
405
+
393
406
# clean pre and post forward hook
394
- if device != "cpu" :
407
+ if clear_cache and device != "cpu" :
395
408
clear_device_cache ()
396
409
397
410
# When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
0 commit comments