
DataLoader with IterableDataset fails after accelerator.prepare() #3624

@leeruibin

Description

System Info

Name: accelerate
Version: 1.7.0
Summary: Accelerate
Name: torch
Version: 2.5.1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

from accelerate import Accelerator
from torch.utils.data import DataLoader
import torch
import webdataset as wds

# ---- Step 1: Initialize Accelerator ---- #
accelerator = Accelerator()
device = accelerator.device

# ---- Step 2: Define WebDataset with IterableDataset ---- #
wds_urls = ["/path/to/shard1.tar", "/path/to/shard2.tar"]  # Add your shard paths here

def custom_collate(batch):
    # Create batch as dict of lists
    collated = {}
    for sample in batch:
        for k, v in sample.items():
            if k not in collated:
                collated[k] = []
            collated[k].append(v)
    # Convert tensors to stacked tensors
    for k in collated:
        if torch.is_tensor(collated[k][0]):
            collated[k] = torch.stack(collated[k])
    return collated

wds_dataset = (
    wds.WebDataset(wds_urls, handler=wds.handlers.warn_and_continue)
    .decode("torchrgb")
    .to_tuple("mp4", "txt", "json")
    .map(lambda sample: {
        "video": sample[0],
        "text": sample[1],
        "meta": sample[2]
    })
)

# ---- Step 3: Define DataLoader ---- #
dataloader = DataLoader(
    wds_dataset,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    collate_fn=custom_collate
)

# ---- Step 4: Define model and optimizer ---- #
model = torch.nn.Linear(1000, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# ---- Step 5: Prepare model, optimizer, and dataloader ---- #
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# ---- Step 6: Training Loop ---- #
for epoch in range(10):
    for batch in dataloader:
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}

With this code, iterating over the prepared dataloader fails with:

TypeError: Can only concatenate tensors but got <class 'str'>

However, if I do not wrap the dataloader with accelerator.prepare(), it returns data successfully.
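
The str in the message presumably comes from the caption field. As a quick check, the sketch below runs the same custom_collate logic on two synthetic samples (the shapes, captions, and metadata are invented, and the "video" field is assumed to already be decoded to a tensor). Only "video" is stacked into a tensor, while "text" stays a list of strings and "meta" stays a list of dicts.

```python
import torch


def custom_collate(batch):
    # Same logic as in the reproduction: build a dict of lists keyed by field name
    collated = {}
    for sample in batch:
        for k, v in sample.items():
            if k not in collated:
                collated[k] = []
            collated[k].append(v)
    # Stack only the fields whose values are tensors; everything else stays a list
    for k in collated:
        if torch.is_tensor(collated[k][0]):
            collated[k] = torch.stack(collated[k])
    return collated


# Two synthetic samples (made-up shapes and values) mimicking the mapped WebDataset output
samples = [
    {"video": torch.zeros(3, 4, 4), "text": "a cat", "meta": {"fps": 24}},
    {"video": torch.ones(3, 4, 4), "text": "a dog", "meta": {"fps": 30}},
]
batch = custom_collate(samples)
print({k: type(v).__name__ for k, v in batch.items()})
# {'video': 'Tensor', 'text': 'list', 'meta': 'list'}
```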

The issue seems to occur only with an IterableDataset. I noticed that when the dataset is an IterableDataset, accelerate wraps the dataloader in DataLoaderDispatcher, and the error is raised from its _fetch_batches call. Is there a solution for this?
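
The error message is consistent with a dispatcher that fetches batches on the main process and merges them by recursing through the batch structure and calling torch.cat on every leaf. The sketch below is a simplified, hypothetical model of that idea (the name concatenate_batches is made up, and this is not Accelerate's actual source). It succeeds when every leaf is a tensor and raises the same TypeError on a batch shaped like custom_collate's output, where "text" is a list of strings.

```python
from collections.abc import Mapping

import torch


def concatenate_batches(batches, dim=0):
    # Hypothetical, simplified merge of several batches into one; NOT Accelerate's source.
    first = batches[0]
    if isinstance(first, (list, tuple)):
        # Recurse position by position across the batches
        return type(first)(
            concatenate_batches([b[i] for b in batches], dim) for i in range(len(first))
        )
    if isinstance(first, Mapping):
        # Recurse key by key across the batches
        return {k: concatenate_batches([b[k] for b in batches], dim) for k in first}
    if not torch.is_tensor(first):
        # Leaves must be tensors; strings, ints, etc. cannot be concatenated
        raise TypeError(f"Can only concatenate tensors but got {type(first)}")
    return torch.cat(batches, dim=dim)


# Two batches containing only tensors: merging works
tensor_only = [{"video": torch.zeros(2, 3)}, {"video": torch.ones(2, 3)}]
merged = concatenate_batches(tensor_only)
print(merged["video"].shape)  # torch.Size([4, 3])

# Two batches shaped like custom_collate output, with a list of caption strings
mixed = [
    {"video": torch.zeros(2, 3), "text": ["a cat", "a dog"]},
    {"video": torch.ones(2, 3), "text": ["a fish", "a bird"]},
]
error_message = None
try:
    concatenate_batches(mixed)
except TypeError as err:
    error_message = str(err)
print(error_message)  # Can only concatenate tensors but got <class 'str'>
```

If this is the mechanism, two directions may be worth trying (not verified against this setup): keep only tensors in the collated batch, or construct the accelerator as Accelerator(dataloader_config=DataLoaderConfiguration(dispatch_batches=False)), with DataLoaderConfiguration imported from accelerate.utils, so that the main process does not fetch and concatenate batches on behalf of every process.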

Expected behavior

The dataloader should work normally after being wrapped with accelerator.prepare(), even when its dataset is an IterableDataset.
