Skip to content

Concatenate str support for IterableDataset #3686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
12 changes: 9 additions & 3 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,18 +599,24 @@ def _slice_tensor(tensor, tensor_slice):

def concatenate(data, dim=0):
"""
Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
Recursively concatenates elements in a nested structure of tensors or strings.

Supports nested lists, tuples, or dictionaries that contain either:
- torch.Tensors (with the same shape except along `dim`)
- strings (concatenated as flat lists)

Args:
data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`):
data (nested list/tuple/dictionary of lists of tensors `torch.Tensor` or `str`):
The data to concatenate.
dim (`int`, *optional*, defaults to 0):
The dimension on which to concatenate.

Returns:
The same data structure as `data` with all the tensors concatenated.
"""
if isinstance(data[0], (tuple, list)):
if isinstance(data[0], list) and all(isinstance(x, str) for x in data[0]):
return honor_type(data[0], [item for sublist in data for item in sublist])
elif isinstance(data[0], (tuple, list)):
return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
elif isinstance(data[0], Mapping):
return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
Expand Down
73 changes: 72 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
save,
send_to_device,
)
from accelerate.utils.operations import is_namedtuple
from accelerate.utils.operations import concatenate, is_namedtuple


if is_torch_xla_available():
Expand Down Expand Up @@ -391,6 +391,77 @@ def test_slice_and_concatenate(self):
# We should expect there to be 66 items now
assert result.shape == torch.Size([66, 4, 4])

def test_concatenate_batches(self):
    """Exercise `concatenate` over tensors, strings, and nested containers.

    Covers: flat dicts of tensors, dicts mixing tensors with string lists,
    nested dicts, tuples, and the error path when the two batches disagree
    on the type stored under a key.
    """
    # --- flat dict of tensors: concatenated along dim 0 ---
    first = {
        "x": torch.rand(4, 1),
        "y": torch.from_numpy(np.array([[1.0, 2.0, 3.0]] * 4, dtype=np.float32)),
    }
    second = {
        "x": torch.rand(4, 1),
        "y": torch.from_numpy(np.array([[1.0, 2.0, 3.0]] * 4, dtype=np.float32)),
    }

    merged = concatenate([first, second], dim=0)

    assert merged["x"].shape == (8, 1)
    assert merged["y"].shape == (8, 3)

    # --- dict mixing tensors and string lists: strings are flattened ---
    first = {"x": torch.rand(4, 1), "animals": ["dog", "cat", "baby", "penguin"]}
    second = {
        "x": torch.rand(4, 1),
        "animals": ["koala", "samurai", "iguana", "rabbit"],
    }

    merged = concatenate([first, second], dim=0)

    assert merged["x"].shape == (8, 1)
    assert merged["animals"] == first["animals"] + second["animals"]

    # --- nested dict: recursion reaches every leaf ---
    first = {
        "dict": {
            "a": torch.rand(4, 4),
            "b": torch.tensor([1, 2, 3, 4]),
            "c": ["bit", "byte", "gigabyte", "terabyte"],
        }
    }
    second = {
        "dict": {
            "a": torch.rand(4, 4),
            "b": torch.tensor([5, 6, 7, 8]),
            "c": ["kilobyte", "megabyte", "gigabit", "terabit"],
        }
    }

    merged = concatenate([first, second], dim=0)

    assert merged["dict"]["a"].shape == (8, 4)
    assert merged["dict"]["b"].shape == (8,)
    assert merged["dict"]["c"] == first["dict"]["c"] + second["dict"]["c"]

    # --- tuples: per-position concatenation, type of the tuple preserved ---
    first = {
        "tuple_key": (torch.rand(4, 2), ["Blackbeard", "Captain Kidd", "Anne Bonny", "Calico Jack"])
    }
    second = {
        "tuple_key": (torch.rand(4, 2), ["Bartholomew Roberts", "Stede Bonnet", "Calico Jack", "Captain Kidd"])
    }

    merged = concatenate([first, second], dim=0)

    assert merged["tuple_key"][0].shape == (8, 2)
    assert merged["tuple_key"][1] == first["tuple_key"][1] + second["tuple_key"][1]

    # --- mismatched leaf types (tensor vs. list of str) must raise ---
    first = {"mix": torch.rand(4, 1)}
    second = {"mix": ["Basketball", "Baseball", "Surf", "Bilboquet"]}

    with pytest.raises(TypeError):
        concatenate([first, second], dim=0)

def test_send_to_device_compiles(self):
    # Smoke test: `send_to_device` must compile under torch.compile with
    # fullgraph=True (i.e. without graph breaks) and run on a bfloat16
    # tensor being moved to "cpu". No return value is checked — compiling
    # and executing without raising is the assertion.
    # NOTE(review): the diff view may collapse further lines of this method
    # ("Expand Down" follows) — confirm against the full file.
    compiled_send_to_device = torch.compile(send_to_device, fullgraph=True)
    compiled_send_to_device(torch.zeros([1], dtype=torch.bfloat16), "cpu")
Expand Down