@@ -67,7 +67,11 @@ def is_namedtuple(data):
67
67
Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a
68
68
`namedtuple` perfectly.
69
69
"""
70
- return isinstance (data , tuple ) and hasattr (data , "_asdict" ) and hasattr (data , "_fields" )
70
+ return (
71
+ isinstance (data , tuple )
72
+ and hasattr (data , "_asdict" )
73
+ and hasattr (data , "_fields" )
74
+ )
71
75
72
76
73
77
def honor_type (obj , generator ):
@@ -81,7 +85,9 @@ def honor_type(obj, generator):
81
85
return type (obj )(generator )
82
86
83
87
84
- def recursively_apply (func , data , * args , test_type = is_torch_tensor , error_on_other_type = False , ** kwargs ):
88
+ def recursively_apply (
89
+ func , data , * args , test_type = is_torch_tensor , error_on_other_type = False , ** kwargs
90
+ ):
85
91
"""
86
92
Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
87
93
@@ -108,7 +114,12 @@ def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_oth
108
114
data ,
109
115
(
110
116
recursively_apply (
111
- func , o , * args , test_type = test_type , error_on_other_type = error_on_other_type , ** kwargs
117
+ func ,
118
+ o ,
119
+ * args ,
120
+ test_type = test_type ,
121
+ error_on_other_type = error_on_other_type ,
122
+ ** kwargs ,
112
123
)
113
124
for o in data
114
125
),
@@ -117,7 +128,12 @@ def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_oth
117
128
return type (data )(
118
129
{
119
130
k : recursively_apply (
120
- func , v , * args , test_type = test_type , error_on_other_type = error_on_other_type , ** kwargs
131
+ func ,
132
+ v ,
133
+ * args ,
134
+ test_type = test_type ,
135
+ error_on_other_type = error_on_other_type ,
136
+ ** kwargs ,
121
137
)
122
138
for k , v in data .items ()
123
139
}
@@ -167,7 +183,13 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
167
183
return tensor .to (device )
168
184
elif isinstance (tensor , (tuple , list )):
169
185
return honor_type (
170
- tensor , (send_to_device (t , device , non_blocking = non_blocking , skip_keys = skip_keys ) for t in tensor )
186
+ tensor ,
187
+ (
188
+ send_to_device (
189
+ t , device , non_blocking = non_blocking , skip_keys = skip_keys
190
+ )
191
+ for t in tensor
192
+ ),
171
193
)
172
194
elif isinstance (tensor , Mapping ):
173
195
if isinstance (skip_keys , str ):
@@ -176,7 +198,13 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
176
198
skip_keys = []
177
199
return type (tensor )(
178
200
{
179
- k : t if k in skip_keys else send_to_device (t , device , non_blocking = non_blocking , skip_keys = skip_keys )
201
+ k : (
202
+ t
203
+ if k in skip_keys
204
+ else send_to_device (
205
+ t , device , non_blocking = non_blocking , skip_keys = skip_keys
206
+ )
207
+ )
180
208
for k , t in tensor .items ()
181
209
}
182
210
)
@@ -231,7 +259,9 @@ def initialize_tensors(data_structure):
231
259
def _initialize_tensor (tensor_info ):
232
260
return torch .empty (* tensor_info .shape , dtype = tensor_info .dtype )
233
261
234
- return recursively_apply (_initialize_tensor , data_structure , test_type = is_tensor_information )
262
+ return recursively_apply (
263
+ _initialize_tensor , data_structure , test_type = is_tensor_information
264
+ )
235
265
236
266
237
267
def find_batch_size (data ):
@@ -253,7 +283,9 @@ def find_batch_size(data):
253
283
for k in data .keys ():
254
284
return find_batch_size (data [k ])
255
285
elif not isinstance (data , torch .Tensor ):
256
- raise TypeError (f"Can only find the batch size of tensors but got { type (data )} ." )
286
+ raise TypeError (
287
+ f"Can only find the batch size of tensors but got { type (data )} ."
288
+ )
257
289
return data .shape [0 ]
258
290
259
291
@@ -344,7 +376,9 @@ def _gpu_gather_one(tensor):
344
376
# a backend of `None` is always CPU
345
377
# also gloo does not support `all_gather_into_tensor`,
346
378
# which will result in a larger memory overhead for the op
347
- output_tensors = [torch .empty_like (tensor ) for _ in range (state .num_processes )]
379
+ output_tensors = [
380
+ torch .empty_like (tensor ) for _ in range (state .num_processes )
381
+ ]
348
382
torch .distributed .all_gather (output_tensors , tensor )
349
383
return torch .cat (output_tensors , dim = 0 )
350
384
@@ -367,7 +401,10 @@ def verify_operation(function):
367
401
368
402
@wraps (function )
369
403
def wrapper (* args , ** kwargs ):
370
- if PartialState ().distributed_type == DistributedType .NO or not PartialState ().debug :
404
+ if (
405
+ PartialState ().distributed_type == DistributedType .NO
406
+ or not PartialState ().debug
407
+ ):
371
408
return function (* args , ** kwargs )
372
409
operation = f"{ function .__module__ } .{ function .__name__ } "
373
410
if "tensor" in kwargs :
@@ -384,7 +421,9 @@ def wrapper(*args, **kwargs):
384
421
if output [0 ] is not None :
385
422
are_same = output .count (output [0 ]) == len (output )
386
423
if not are_same :
387
- process_shape_str = "\n - " .join ([f"Process { i } : { shape } " for i , shape in enumerate (output )])
424
+ process_shape_str = "\n - " .join (
425
+ [f"Process { i } : { shape } " for i , shape in enumerate (output )]
426
+ )
388
427
raise DistributedOperationException (
389
428
f"Cannot apply desired operation due to shape mismatches. "
390
429
"All shapes across devices must be valid."
@@ -465,14 +504,21 @@ def _gpu_broadcast_one(tensor, src=0):
465
504
torch .distributed .broadcast (tensor , src = src )
466
505
return tensor
467
506
468
- return recursively_apply (_gpu_broadcast_one , data , error_on_other_type = True , src = src )
507
+ return recursively_apply (
508
+ _gpu_broadcast_one , data , error_on_other_type = True , src = src
509
+ )
469
510
470
511
471
512
def _tpu_broadcast (tensor , src = 0 , name = "broadcast tensor" ):
472
513
if isinstance (tensor , (list , tuple )):
473
- return honor_type (tensor , (_tpu_broadcast (t , name = f"{ name } _{ i } " ) for i , t in enumerate (tensor )))
514
+ return honor_type (
515
+ tensor ,
516
+ (_tpu_broadcast (t , name = f"{ name } _{ i } " ) for i , t in enumerate (tensor )),
517
+ )
474
518
elif isinstance (tensor , Mapping ):
475
- return type (tensor )({k : _tpu_broadcast (v , name = f"{ name } _{ k } " ) for k , v in tensor .items ()})
519
+ return type (tensor )(
520
+ {k : _tpu_broadcast (v , name = f"{ name } _{ k } " ) for k , v in tensor .items ()}
521
+ )
476
522
return xm .mesh_reduce (name , tensor , lambda x : x [src ])
477
523
478
524
@@ -499,15 +545,19 @@ def gather_tensor_shape(tensor):
499
545
# Allocate 80 bytes to store the shape
500
546
max_tensor_dimension = 2 ** 20
501
547
state = PartialState ()
502
- base_tensor = torch .empty (max_tensor_dimension , dtype = torch .int , device = state .device )
548
+ base_tensor = torch .empty (
549
+ max_tensor_dimension , dtype = torch .int , device = state .device
550
+ )
503
551
504
552
# Since PyTorch can't just send a tensor to another GPU without
505
553
# knowing its size, we store the size of the tensor with data
506
554
# in an allocation
507
555
if tensor is not None :
508
556
shape = tensor .shape
509
557
tensor_dtype = TENSOR_TYPE_TO_INT [tensor .dtype ]
510
- base_tensor [: len (shape ) + 1 ] = torch .tensor (list (shape ) + [tensor_dtype ], dtype = int )
558
+ base_tensor [: len (shape ) + 1 ] = torch .tensor (
559
+ list (shape ) + [tensor_dtype ], dtype = int
560
+ )
511
561
# Perform a reduction to copy the size data onto all GPUs
512
562
base_tensor = reduce (base_tensor , reduction = "sum" )
513
563
base_tensor = base_tensor [base_tensor .nonzero ()]
@@ -549,7 +599,9 @@ def broadcast(tensor, from_process: int = 0):
549
599
The same data structure as `tensor` with all tensors broadcasted to the proper device.
550
600
"""
551
601
if PartialState ().distributed_type == DistributedType .XLA :
552
- return _tpu_broadcast (tensor , src = from_process , name = "accelerate.utils.broadcast" )
602
+ return _tpu_broadcast (
603
+ tensor , src = from_process , name = "accelerate.utils.broadcast"
604
+ )
553
605
elif PartialState ().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES :
554
606
return _gpu_broadcast (tensor , src = from_process )
555
607
else :
@@ -571,7 +623,9 @@ def broadcast_object_list(object_list, from_process: int = 0):
571
623
"""
572
624
if PartialState ().distributed_type == DistributedType .XLA :
573
625
for i , obj in enumerate (object_list ):
574
- object_list [i ] = xm .mesh_reduce ("accelerate.utils.broadcast_object_list" , obj , lambda x : x [from_process ])
626
+ object_list [i ] = xm .mesh_reduce (
627
+ "accelerate.utils.broadcast_object_list" , obj , lambda x : x [from_process ]
628
+ )
575
629
elif PartialState ().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES :
576
630
torch .distributed .broadcast_object_list (object_list , src = from_process )
577
631
return object_list
@@ -599,10 +653,14 @@ def _slice_tensor(tensor, tensor_slice):
599
653
600
654
def concatenate (data , dim = 0 ):
601
655
"""
602
- Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
656
+ Recursively concatenates elements in a nested structure of tensors or strings.
657
+
658
+ Supports nested lists, tuples, or dictionaries that contain either:
659
+ - torch.Tensors (with the same shape except along `dim`)
660
+ - strings (concatenated as flat lists)
603
661
604
662
Args:
605
- data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):
663
+ data (nested list/tuple/dictionary of lists of tensors `torch.Tensor` or `str` ):
606
664
The data to concatenate.
607
665
dim (`int`, *optional*, defaults to 0):
608
666
The dimension on which to concatenate.
@@ -612,11 +670,17 @@ def concatenate(data, dim=0):
612
670
"""
613
671
if isinstance (data [0 ], (tuple , list )):
614
672
first_inner = data [0 ][0 ] if len (data [0 ]) > 0 else None
615
-
673
+
616
674
if isinstance (first_inner , str ):
617
675
return honor_type (data [0 ], [item for sublist in data for item in sublist ])
618
676
else :
619
- return honor_type (data [0 ], (concatenate ([d [i ] for d in data ], dim = dim ) for i in range (len (data [0 ]))))
677
+ return honor_type (
678
+ data [0 ],
679
+ (
680
+ concatenate ([d [i ] for d in data ], dim = dim )
681
+ for i in range (len (data [0 ]))
682
+ ),
683
+ )
620
684
621
685
elif isinstance (data [0 ], Mapping ):
622
686
return type (data [0 ])(
@@ -675,15 +739,24 @@ def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
675
739
new_tensor = tensor .new_zeros (tuple (new_size )) + pad_index
676
740
if pad_first :
677
741
indices = tuple (
678
- slice (max_size - old_size [dim ], max_size ) if i == dim else slice (None ) for i in range (len (new_size ))
742
+ slice (max_size - old_size [dim ], max_size ) if i == dim else slice (None )
743
+ for i in range (len (new_size ))
679
744
)
680
745
else :
681
- indices = tuple (slice (0 , old_size [dim ]) if i == dim else slice (None ) for i in range (len (new_size )))
746
+ indices = tuple (
747
+ slice (0 , old_size [dim ]) if i == dim else slice (None )
748
+ for i in range (len (new_size ))
749
+ )
682
750
new_tensor [indices ] = tensor
683
751
return new_tensor
684
752
685
753
return recursively_apply (
686
- _pad_across_processes , tensor , error_on_other_type = True , dim = dim , pad_index = pad_index , pad_first = pad_first
754
+ _pad_across_processes ,
755
+ tensor ,
756
+ error_on_other_type = True ,
757
+ dim = dim ,
758
+ pad_index = pad_index ,
759
+ pad_first = pad_first ,
687
760
)
688
761
689
762
@@ -713,7 +786,10 @@ def _pad_input_tensors(tensor, batch_size, num_processes, dim=0):
713
786
new_size = list (old_size )
714
787
new_size [0 ] = batch_size + to_pad
715
788
new_tensor = tensor .new_zeros (tuple (new_size ))
716
- indices = tuple (slice (0 , old_size [dim ]) if i == dim else slice (None ) for i in range (len (new_size )))
789
+ indices = tuple (
790
+ slice (0 , old_size [dim ]) if i == dim else slice (None )
791
+ for i in range (len (new_size ))
792
+ )
717
793
new_tensor [indices ] = tensor
718
794
return new_tensor
719
795
@@ -765,7 +841,11 @@ def _reduce_across_processes(tensor, reduction="mean", scale=1.0):
765
841
return cloned_tensor
766
842
767
843
return recursively_apply (
768
- _reduce_across_processes , tensor , error_on_other_type = True , reduction = reduction , scale = scale
844
+ _reduce_across_processes ,
845
+ tensor ,
846
+ error_on_other_type = True ,
847
+ reduction = reduction ,
848
+ scale = scale ,
769
849
)
770
850
771
851
@@ -785,7 +865,9 @@ def _convert_to_fp32(tensor):
785
865
return tensor .float ()
786
866
787
867
def _is_fp16_bf16_tensor (tensor ):
788
- return (is_torch_tensor (tensor ) or hasattr (tensor , "dtype" )) and tensor .dtype in (
868
+ return (
869
+ is_torch_tensor (tensor ) or hasattr (tensor , "dtype" )
870
+ ) and tensor .dtype in (
789
871
torch .float16 ,
790
872
torch .bfloat16 ,
791
873
)
0 commit comments