What's the best way to handle multi-GPU and multi-node training with torchdata nodes?

@keunwoochoi and I have the following nodes:

FileListNode, which lists all files (shards)
StreamNode, which reads individual files and streams records

This works well on a single GPU, but in a multi-GPU setting we end up listing all files and streaming from all files on each worker.

What would be the best solution? To list files in the main process and stream every n-th record on each worker? Do you have examples of a similar setup?
Hi @karinazad, thanks for the question. The splitting or sharding of data happens at the dataset level, or in this case at the base node level. One way to do it for a typical file- or directory-based setup is:
import os
from typing import TypeVar

from torch import distributed as dist
from torchdata.nodes import BaseNode

T = TypeVar("T")


class JsonLNode(BaseNode[T]):
    def __init__(self, path_to_dataset_dir: str) -> None:
        super().__init__()
        self.path_to_dataset_dir = path_to_dataset_dir
        # other initializations here

        # get rank and world size
        if dist.is_available() and dist.is_initialized():
            self.rank, self.world_size = dist.get_rank(), dist.get_world_size()
        else:
            self.rank = int(os.environ.get("RANK", "0"))
            self.world_size = int(os.environ.get("WORLD_SIZE", "1"))

        # use the rank to distribute work: each rank only loads its own shard files
        rank_to_file_map = self.split_directory_by_rank()
        self.paths_to_load = rank_to_file_map[self.rank]
        # use self.paths_to_load in your file iterator
        ...
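Note that split_directory_by_rank above is not an existing torchdata helper; a minimal sketch of such a method, assuming path_to_dataset_dir is a flat directory of shard files and that files are assigned to ranks round-robin, might look like this:

    def split_directory_by_rank(self) -> dict[int, list[str]]:
        # list shard files in a deterministic order so that every rank
        # builds the same rank-to-file mapping
        files = sorted(
            os.path.join(self.path_to_dataset_dir, name)
            for name in os.listdir(self.path_to_dataset_dir)
        )
        # round-robin assignment: rank r gets files r, r + world_size, r + 2 * world_size, ...
        return {rank: files[rank :: self.world_size] for rank in range(self.world_size)}

Sorting the listing before splitting matters: every rank must see the same ordering, otherwise the ranks can end up with overlapping or missing shards.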
tl;dr: use torch.distributed (from torch import distributed as dist) to find out which rank your code is running on, and use that rank to allocate work when accessing the data.
Alternatively, if you want the familiar map-style setup, you can use a SamplerWrapper together with a DistributedSampler. See an example of using a sampler here.
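As a rough sketch of that alternative, wiring torchdata.nodes' SamplerWrapper, Mapper, and Loader around torch.utils.data.DistributedSampler (ToyDataset below is only a placeholder for your own map-style dataset):

import torch.distributed as dist
from torch.utils.data import Dataset, DistributedSampler
from torchdata.nodes import Loader, Mapper, SamplerWrapper


class ToyDataset(Dataset):
    # placeholder map-style dataset; replace with your own
    def __len__(self) -> int:
        return 1000

    def __getitem__(self, idx: int) -> int:
        return idx


dataset = ToyDataset()

# DistributedSampler shards the indices across ranks, so each worker only
# sees its own slice of the dataset (requires torch.distributed to be
# initialized, otherwise pass num_replicas/rank explicitly)
sampler = DistributedSampler(dataset, shuffle=True)

node = SamplerWrapper(sampler)                    # yields indices as a node
node = Mapper(node, map_fn=dataset.__getitem__)   # look items up by index
loader = Loader(node)

# for multi-epoch training, update the sampler's epoch between epochs
# (e.g. via sampler.set_epoch) so shuffling differs per epoch
for item in loader:
    pass

Here DistributedSampler does the per-rank sharding, so the nodes downstream never see data belonging to other ranks.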