Commit 9d49356

Updated docstring.

1 parent 032c7d2 commit 9d49356
File tree

1 file changed: src/accelerate/utils/operations.py (+110 −28 lines)
@@ -67,7 +67,11 @@ def is_namedtuple(data):
     Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a
     `namedtuple` perfectly.
     """
-    return isinstance(data, tuple) and hasattr(data, "_asdict") and hasattr(data, "_fields")
+    return (
+        isinstance(data, tuple)
+        and hasattr(data, "_asdict")
+        and hasattr(data, "_fields")
+    )
 
 
 def honor_type(obj, generator):
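
Note on the reformatted check: it is pure duck typing, which is what the docstring's false-positive caveat means. A minimal standalone sketch (not part of the commit) showing both a genuine namedtuple and a deliberate mimic passing:

    from collections import namedtuple

    from accelerate.utils.operations import is_namedtuple

    Point = namedtuple("Point", ["x", "y"])
    print(is_namedtuple(Point(1, 2)))  # True: a real namedtuple

    class FakePoint(tuple):
        # A hypothetical class mimicking the namedtuple API; it triggers
        # the documented false positive.
        _fields = ("x", "y")

        def _asdict(self):
            return dict(zip(self._fields, self))

    print(is_namedtuple(FakePoint((1, 2))))  # Also True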
@@ -81,7 +85,9 @@ def honor_type(obj, generator):
     return type(obj)(generator)
 
 
-def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs):
+def recursively_apply(
+    func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs
+):
     """
     Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
 
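
For reference, `recursively_apply` walks nested lists/tuples/dicts and applies `func` to every leaf matching `test_type`. A minimal local sketch of how the reformatted signature is called:

    import torch

    from accelerate.utils.operations import recursively_apply

    data = {"inputs": torch.ones(2), "labels": [torch.zeros(2), torch.ones(2)]}

    # Halve every tensor leaf; with error_on_other_type=False (the default),
    # non-tensor leaves would be passed through unchanged.
    halved = recursively_apply(lambda t: t / 2, data)
    print(halved["labels"][1])  # tensor([0.5000, 0.5000])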
@@ -108,7 +114,12 @@ def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_oth
             data,
             (
                 recursively_apply(
-                    func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
+                    func,
+                    o,
+                    *args,
+                    test_type=test_type,
+                    error_on_other_type=error_on_other_type,
+                    **kwargs,
                 )
                 for o in data
             ),
@@ -117,7 +128,12 @@ def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_oth
         return type(data)(
             {
                 k: recursively_apply(
-                    func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
+                    func,
+                    v,
+                    *args,
+                    test_type=test_type,
+                    error_on_other_type=error_on_other_type,
+                    **kwargs,
                 )
                 for k, v in data.items()
             }
@@ -167,7 +183,13 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
         return tensor.to(device)
     elif isinstance(tensor, (tuple, list)):
         return honor_type(
-            tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor)
+            tensor,
+            (
+                send_to_device(
+                    t, device, non_blocking=non_blocking, skip_keys=skip_keys
+                )
+                for t in tensor
+            ),
         )
     elif isinstance(tensor, Mapping):
         if isinstance(skip_keys, str):
@@ -176,7 +198,13 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
             skip_keys = []
         return type(tensor)(
             {
-                k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys)
+                k: (
+                    t
+                    if k in skip_keys
+                    else send_to_device(
+                        t, device, non_blocking=non_blocking, skip_keys=skip_keys
+                    )
+                )
                 for k, t in tensor.items()
             }
         )
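
The `skip_keys` path above leaves selected mapping entries where they are. A small sketch of the public API (it targets CPU so the snippet runs anywhere):

    import torch

    from accelerate.utils import send_to_device

    batch = {"input_ids": torch.ones(2, 4), "raw_text": torch.zeros(1)}

    # "raw_text" is returned untouched; every other value is moved via .to().
    moved = send_to_device(batch, "cpu", skip_keys=["raw_text"])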
@@ -231,7 +259,9 @@ def initialize_tensors(data_structure):
     def _initialize_tensor(tensor_info):
         return torch.empty(*tensor_info.shape, dtype=tensor_info.dtype)
 
-    return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)
+    return recursively_apply(
+        _initialize_tensor, data_structure, test_type=is_tensor_information
+    )
 
 
 def find_batch_size(data):
@@ -253,7 +283,9 @@ def find_batch_size(data):
         for k in data.keys():
             return find_batch_size(data[k])
     elif not isinstance(data, torch.Tensor):
-        raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.")
+        raise TypeError(
+            f"Can only find the batch size of tensors but got {type(data)}."
+        )
     return data.shape[0]
 
 
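For context, `find_batch_size` recurses until it reaches the first tensor and returns its leading dimension; the reformatted TypeError fires for any non-tensor leaf. A quick local sketch:

    import torch

    from accelerate.utils import find_batch_size

    batch = {"input_ids": torch.ones(8, 16), "labels": torch.zeros(8)}
    print(find_batch_size(batch))  # 8, from the first tensor encountered

    # find_batch_size("not a tensor") would raise the TypeError shown above.
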
@@ -344,7 +376,9 @@ def _gpu_gather_one(tensor):
         # a backend of `None` is always CPU
         # also gloo does not support `all_gather_into_tensor`,
         # which will result in a larger memory overhead for the op
-        output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
+        output_tensors = [
+            torch.empty_like(tensor) for _ in range(state.num_processes)
+        ]
         torch.distributed.all_gather(output_tensors, tensor)
         return torch.cat(output_tensors, dim=0)
 
@@ -367,7 +401,10 @@ def verify_operation(function):
 
     @wraps(function)
     def wrapper(*args, **kwargs):
-        if PartialState().distributed_type == DistributedType.NO or not PartialState().debug:
+        if (
+            PartialState().distributed_type == DistributedType.NO
+            or not PartialState().debug
+        ):
             return function(*args, **kwargs)
         operation = f"{function.__module__}.{function.__name__}"
         if "tensor" in kwargs:
@@ -384,7 +421,9 @@ def wrapper(*args, **kwargs):
         if output[0] is not None:
             are_same = output.count(output[0]) == len(output)
             if not are_same:
-                process_shape_str = "\n - ".join([f"Process {i}: {shape}" for i, shape in enumerate(output)])
+                process_shape_str = "\n - ".join(
+                    [f"Process {i}: {shape}" for i, shape in enumerate(output)]
+                )
                 raise DistributedOperationException(
                     f"Cannot apply desired operation due to shape mismatches. "
                     "All shapes across devices must be valid."
@@ -465,14 +504,21 @@ def _gpu_broadcast_one(tensor, src=0):
         torch.distributed.broadcast(tensor, src=src)
         return tensor
 
-    return recursively_apply(_gpu_broadcast_one, data, error_on_other_type=True, src=src)
+    return recursively_apply(
+        _gpu_broadcast_one, data, error_on_other_type=True, src=src
+    )
 
 
 def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
     if isinstance(tensor, (list, tuple)):
-        return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor)))
+        return honor_type(
+            tensor,
+            (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor)),
+        )
     elif isinstance(tensor, Mapping):
-        return type(tensor)({k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()})
+        return type(tensor)(
+            {k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()}
+        )
     return xm.mesh_reduce(name, tensor, lambda x: x[src])
 
 
@@ -499,15 +545,19 @@ def gather_tensor_shape(tensor):
     # Allocate 80 bytes to store the shape
     max_tensor_dimension = 2**20
     state = PartialState()
-    base_tensor = torch.empty(max_tensor_dimension, dtype=torch.int, device=state.device)
+    base_tensor = torch.empty(
+        max_tensor_dimension, dtype=torch.int, device=state.device
+    )
 
     # Since PyTorch can't just send a tensor to another GPU without
     # knowing its size, we store the size of the tensor with data
     # in an allocation
     if tensor is not None:
         shape = tensor.shape
         tensor_dtype = TENSOR_TYPE_TO_INT[tensor.dtype]
-        base_tensor[: len(shape) + 1] = torch.tensor(list(shape) + [tensor_dtype], dtype=int)
+        base_tensor[: len(shape) + 1] = torch.tensor(
+            list(shape) + [tensor_dtype], dtype=int
+        )
     # Perform a reduction to copy the size data onto all GPUs
     base_tensor = reduce(base_tensor, reduction="sum")
     base_tensor = base_tensor[base_tensor.nonzero()]
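
The comments above describe the trick: the shape and a dtype code are packed into a fixed-size integer buffer, zeros mark unused slots, and a sum-reduction makes the metadata visible on every rank. A local illustration of the encode/decode step (the dtype mapping here is a hypothetical stand-in, not the library's `TENSOR_TYPE_TO_INT`):

    import torch

    DTYPE_TO_INT = {torch.float32: 1}  # hypothetical stand-in mapping

    tensor = torch.ones(3, 4)
    base = torch.zeros(16, dtype=torch.int)
    payload = list(tensor.shape) + [DTYPE_TO_INT[tensor.dtype]]
    base[: len(payload)] = torch.tensor(payload, dtype=torch.int)

    # After the (sum) reduction, the non-zero entries recover the metadata.
    print(base[base.nonzero()].flatten().tolist())  # [3, 4, 1]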
@@ -549,7 +599,9 @@ def broadcast(tensor, from_process: int = 0):
         The same data structure as `tensor` with all tensors broadcasted to the proper device.
     """
     if PartialState().distributed_type == DistributedType.XLA:
-        return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast")
+        return _tpu_broadcast(
+            tensor, src=from_process, name="accelerate.utils.broadcast"
+        )
     elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
         return _gpu_broadcast(tensor, src=from_process)
     else:
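
Usage of the public wrapper is unchanged by the reformat; a minimal sketch assuming a distributed launch (e.g. `accelerate launch`):

    import torch

    from accelerate import PartialState
    from accelerate.utils import broadcast

    state = PartialState()
    data = {"step": torch.tensor([state.process_index], device=state.device)}

    # Afterwards every process holds process 0's tensor values.
    data = broadcast(data, from_process=0)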
@@ -571,7 +623,9 @@ def broadcast_object_list(object_list, from_process: int = 0):
     """
    if PartialState().distributed_type == DistributedType.XLA:
         for i, obj in enumerate(object_list):
-            object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process])
+            object_list[i] = xm.mesh_reduce(
+                "accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process]
+            )
     elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
         torch.distributed.broadcast_object_list(object_list, src=from_process)
     return object_list
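
Likewise for arbitrary picklable objects; a sketch assuming the same kind of launch:

    from accelerate import PartialState
    from accelerate.utils import broadcast_object_list

    state = PartialState()
    objects = [{"seed": 42}, "config"] if state.is_main_process else [None, None]

    # Modified in place and returned: all ranks end up with rank 0's objects.
    objects = broadcast_object_list(objects, from_process=0)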
@@ -599,10 +653,14 @@ def _slice_tensor(tensor, tensor_slice):
 
 def concatenate(data, dim=0):
     """
-    Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
+    Recursively concatenates elements in a nested structure of tensors or strings.
+
+    Supports nested lists, tuples, or dictionaries that contain either:
+    - torch.Tensors (with the same shape except along `dim`)
+    - strings (concatenated as flat lists)
 
     Args:
-        data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):
+        data (nested list/tuple/dictionary of lists of tensors `torch.Tensor` or `str`):
             The data to concatenate.
         dim (`int`, *optional*, defaults to 0):
             The dimension on which to concatenate.
@@ -612,11 +670,17 @@ def concatenate(data, dim=0):
     """
     if isinstance(data[0], (tuple, list)):
         first_inner = data[0][0] if len(data[0]) > 0 else None
-
+
         if isinstance(first_inner, str):
             return honor_type(data[0], [item for sublist in data for item in sublist])
         else:
-            return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
+            return honor_type(
+                data[0],
+                (
+                    concatenate([d[i] for d in data], dim=dim)
+                    for i in range(len(data[0]))
+                ),
+            )
 
     elif isinstance(data[0], Mapping):
         return type(data[0])(
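
A quick sketch of both documented cases, assuming the string-aware `concatenate` from this branch:

    import torch

    from accelerate.utils import concatenate

    # Tensors: shapes must match except along `dim`.
    batches = [{"logits": torch.ones(2, 5)}, {"logits": torch.zeros(3, 5)}]
    print(concatenate(batches, dim=0)["logits"].shape)  # torch.Size([5, 5])

    # Strings: nested lists are flattened into a single list.
    print(concatenate([["a", "b"], ["c"]]))  # ['a', 'b', 'c']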
@@ -675,15 +739,24 @@ def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
         new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
         if pad_first:
             indices = tuple(
-                slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size))
+                slice(max_size - old_size[dim], max_size) if i == dim else slice(None)
+                for i in range(len(new_size))
             )
         else:
-            indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
+            indices = tuple(
+                slice(0, old_size[dim]) if i == dim else slice(None)
+                for i in range(len(new_size))
+            )
         new_tensor[indices] = tensor
         return new_tensor
 
     return recursively_apply(
-        _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
+        _pad_across_processes,
+        tensor,
+        error_on_other_type=True,
+        dim=dim,
+        pad_index=pad_index,
+        pad_first=pad_first,
     )
 
 
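The tuple-of-slices construction above is how the padded region is written without knowing the tensor's rank in advance: only `dim` gets a bounded slice, every other axis gets `slice(None)`. A standalone illustration:

    import torch

    tensor = torch.arange(6).reshape(2, 3)
    dim, max_size = 0, 4

    padded = tensor.new_zeros((max_size, 3))
    indices = tuple(
        slice(0, tensor.shape[dim]) if i == dim else slice(None)
        for i in range(padded.dim())
    )
    padded[indices] = tensor  # rows 0-1 carry the data, rows 2-3 stay zero
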
@@ -713,7 +786,10 @@ def _pad_input_tensors(tensor, batch_size, num_processes, dim=0):
         new_size = list(old_size)
         new_size[0] = batch_size + to_pad
         new_tensor = tensor.new_zeros(tuple(new_size))
-        indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size)))
+        indices = tuple(
+            slice(0, old_size[dim]) if i == dim else slice(None)
+            for i in range(len(new_size))
+        )
         new_tensor[indices] = tensor
         return new_tensor
 
@@ -765,7 +841,11 @@ def _reduce_across_processes(tensor, reduction="mean", scale=1.0):
         return cloned_tensor
 
     return recursively_apply(
-        _reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction, scale=scale
+        _reduce_across_processes,
+        tensor,
+        error_on_other_type=True,
+        reduction=reduction,
+        scale=scale,
     )
 
 
@@ -785,7 +865,9 @@ def _convert_to_fp32(tensor):
         return tensor.float()
 
     def _is_fp16_bf16_tensor(tensor):
-        return (is_torch_tensor(tensor) or hasattr(tensor, "dtype")) and tensor.dtype in (
+        return (
+            is_torch_tensor(tensor) or hasattr(tensor, "dtype")
+        ) and tensor.dtype in (
             torch.float16,
             torch.bfloat16,
         )
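
For reference, the predicate above drives the public `convert_to_fp32`, which upcasts only half-precision leaves in a nested structure. A local sketch:

    import torch

    from accelerate.utils import convert_to_fp32

    outputs = {"logits": torch.ones(2, dtype=torch.bfloat16), "ids": torch.ones(2)}

    # Only fp16/bf16 tensors are upcast; the float32 tensor passes through.
    print(convert_to_fp32(outputs)["logits"].dtype)  # torch.float32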