multi-gpu training with torchdata nodes #1472

Open
karinazad opened this issue Apr 11, 2025 · 2 comments

@karinazad

What's the best way to handle multi-gpu and multi-node training with torchdata nodes?

@keunwoochoi and I have the following nodes:

  • FileListNode which lists all files (shards)
  • StreamNode which reads individual files and streams records

This works well on a single GPU, but in a multi-GPU setting we end up listing all files and streaming from all of them on every worker.

What would be the best solution? Listing files in the main process and streaming every n-th record on each worker? Do you have an example of a similar setup?
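For context, FileListNode is roughly like the sketch below (simplified; the real node has more bookkeeping, and StreamNode then consumes these paths and yields individual records). The reset/next/get_state structure follows the torchdata.nodes BaseNode docs; everything else is illustrative:

import os
from typing import Any, Dict, List, Optional

from torchdata.nodes import BaseNode


class FileListNode(BaseNode[str]):
    """Yields the path of every shard file in a directory (simplified sketch)."""

    def __init__(self, dataset_dir: str) -> None:
        super().__init__()
        self.dataset_dir = dataset_dir
        self._paths: List[str] = []
        self._idx = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None) -> None:
        super().reset(initial_state)
        # every process lists the full directory -- this is exactly what we want to avoid
        self._paths = sorted(
            os.path.join(self.dataset_dir, name) for name in os.listdir(self.dataset_dir)
        )
        self._idx = initial_state["idx"] if initial_state else 0

    def next(self) -> str:
        if self._idx >= len(self._paths):
            raise StopIteration
        path = self._paths[self._idx]
        self._idx += 1
        return path

    def get_state(self) -> Dict[str, Any]:
        return {"idx": self._idx}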

@divyanshk
Contributor

Hi @karinazad, thanks for the question. The splitting or sharding of data happens at the dataset level, or in this case at the base node level. One way it can be done for a typical file- or directory-based setup is:

import os
from typing import TypeVar

from torch import distributed as dist
from torchdata.nodes import BaseNode

T = TypeVar("T")


class JsonLNode(BaseNode[T]):
    def __init__(self, path_to_dataset_dir: str) -> None:
        super().__init__()

        # initializations here

        # get rank and world size
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank = int(os.environ.get("RANK", "0"))
            world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.rank, self.world_size = rank, world_size

        # use rank to distribute work, e.g. build a map from rank -> list of file paths
        rank_to_file_map = self.split_directory_by_rank()
        self.paths_to_load = rank_to_file_map[self.rank]

        # use self.paths_to_load in your file iterator

        ...

tl;dr: leverage torch.distributed (from torch import distributed as dist) to know which rank your code is running on, and use that to "allocate" work when accessing the data.

Alternatively, if you want the familiar map-style setup, you can use a SamplerWrapper and use DistributedSampler with it. See an example of using a sampler here.
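For example, something along these lines (a rough sketch; ShardDataset, load_shard, and paths are placeholders for your own code, and exact node constructor arguments may vary between torchdata versions):

from torch.utils.data import Dataset, DistributedSampler
from torchdata.nodes import Batcher, Loader, ParallelMapper, SamplerWrapper


class ShardDataset(Dataset):
    """Hypothetical map-style dataset over the list of shard files."""

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return load_shard(self.paths[idx])  # load_shard: your own reader


dataset = ShardDataset(paths)          # paths: full file list, identical on every rank
sampler = DistributedSampler(dataset)  # rank/world size come from the initialized process group

node = SamplerWrapper(sampler)         # yields only this rank's indices
node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=4, method="thread")
node = Batcher(node, batch_size=32)
loader = Loader(node)

for batch in loader:
    ...

Here DistributedSampler does the per-rank sharding, so each rank only ever looks up its own subset of shards.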

@karinazad
Author

thank you!
