Skip to content

Commit 032c7d2

Browse files
committed
Cleaner implementation.
1 parent 1a73970 commit 032c7d2

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

src/accelerate/utils/operations.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -612,18 +612,11 @@ def concatenate(data, dim=0):
612612
"""
613613
if isinstance(data[0], (tuple, list)):
614614
first_inner = data[0][0] if len(data[0]) > 0 else None
615-
616-
if isinstance(first_inner, (torch.Tensor, tuple, list, Mapping)):
617-
return honor_type(
618-
data[0],
619-
(
620-
concatenate([d[i] for d in data], dim=dim)
621-
for i in range(len(data[0]))
622-
),
623-
)
624-
else:
625-
# If inner element are not nested, flatten
615+
616+
if isinstance(first_inner, str):
626617
return honor_type(data[0], [item for sublist in data for item in sublist])
618+
else:
619+
return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
627620

628621
elif isinstance(data[0], Mapping):
629622
return type(data[0])(

0 commit comments

Comments
 (0)