System Info
Name: accelerate
Version: 1.7.0
Summary: Accelerate
Name: torch
Version: 2.5.1
Information
- The official example scripts
- My own modified scripts
Tasks
- One of the scripts in the examples/ folder of Accelerate, or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
- My own task or dataset (give details below)
Reproduction
from accelerate import Accelerator
from torch.utils.data import DataLoader
import torch
import webdataset as wds
# ---- Step 1: Initialize Accelerator ---- #
accelerator = Accelerator()
device = accelerator.device
# ---- Step 2: Define WebDataset with IterableDataset ---- #
wds_urls = ["/path/to/shard1.tar", "/path/to/shard2.tar"] # Add your shard paths here
def custom_collate(batch):
    # Gather the batch into a dict of lists, keyed by sample field
    collated = {}
    for sample in batch:
        for k, v in sample.items():
            if k not in collated:
                collated[k] = []
            collated[k].append(v)
    # Stack tensor fields; non-tensor fields (e.g. strings) stay as lists
    for k in collated:
        if torch.is_tensor(collated[k][0]):
            collated[k] = torch.stack(collated[k])
    return collated
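For reference, here is what this collate function produces on a hypothetical two-sample batch (shapes and strings invented purely for illustration): tensor fields are stacked along a new batch dimension, while string fields stay as plain Python lists.
dummy_batch = [
    {"video": torch.zeros(3, 224, 224), "text": "a cat"},
    {"video": torch.ones(3, 224, 224), "text": "a dog"},
]
out = custom_collate(dummy_batch)
print(out["video"].shape)  # torch.Size([2, 3, 224, 224]) -- tensors are stacked
print(out["text"])         # ['a cat', 'a dog'] -- a list of str, not a tensor
These list-of-str fields are exactly what the dispatcher later fails to concatenate.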
wds_dataset = (
    wds.WebDataset(wds_urls, handler=wds.handlers.warn_and_continue)
    .decode("torchrgb")
    .to_tuple("mp4", "txt", "json")
    .map(lambda sample: {
        "video": sample[0],
        "text": sample[1],
        "meta": sample[2],
    })
)
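As a quick sanity check on the pipeline above (assuming the shard paths exist and contain matching .mp4/.txt/.json entries), pulling one sample should yield the remapped dict; note that the decoded "txt" field comes back as a plain Python string, which is what later ends up in each batch.
sample = next(iter(wds_dataset))
print(sample.keys())         # dict_keys(['video', 'text', 'meta'])
print(type(sample["text"]))  # expected: <class 'str'>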
# ---- Step 3: Define DataLoader ---- #
dataloader = DataLoader(
    wds_dataset,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    collate_fn=custom_collate,
)
# ---- Step 4: Define model and optimizer ---- #
model = torch.nn.Linear(1000, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# ---- Step 5: Prepare model, optimizer, and dataloader ---- #
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
# ---- Step 6: Training Loop ---- #
for epoch in range(10):
    for batch in dataloader:
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
With this code, the dataloader cannot work normally, as it raises:
TypeError: Can only concatenate tensors but got <class 'str'>
However, if I do not wrap the dataloader with accelerator.prepare(), it successfully returns data.
This issue seems to occur only with an IterableDataset: in that case, Accelerate wraps the dataloader in DataLoaderDispatcher and calls _fetch_batches, which triggers the error. Is there a solution for this issue?
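A possible workaround, if batch dispatching is not required: DataLoaderDispatcher is only used when dispatch_batches is enabled (which is the default when the underlying dataset is an IterableDataset), so turning it off should make each process iterate the dataloader itself and skip the tensor-only concatenation in _fetch_batches. A minimal sketch, assuming accelerate's DataLoaderConfiguration behaves as documented:
from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# Assumption: with dispatch_batches=False, each process iterates its own copy
# of the dataloader instead of process 0 fetching batches and slicing them
# across processes, so string fields are never concatenated.
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(dispatch_batches=False)
)
Alternatively, keeping only tensors in the batch (for example, tokenizing the text inside custom_collate and moving meta out of the batch) would also sidestep the error, at the cost of changing the batch layout.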
Expected behavior
The dataloader should work normally when I use accelerator.prepare to wrap a DataLoader built on an IterableDataset.