Commit 1cbe057

Author: Yifu Wang (committed)

SymmetricMemory-based, low contention intra-node all-gather and reduce-scatter

ghstack-source-id: a28a292
Pull Request resolved: pytorch#130583

1 parent 72d553d, commit 1cbe057

4 files changed: +295, -7 lines changed

test/distributed/test_symmetric_memory.py

Lines changed: 64 additions & 2 deletions
@@ -117,8 +117,7 @@ def test_empty_strided_p2p(self) -> None:
         alloc_args = (shape, stride, dtype, device, group_name)

         t = torch.empty(shape, dtype=dtype, device=device)
-        with self.assertRaises(RuntimeError):
-            _SymmetricMemory.rendezvous(t)
+        self.assertIsNone(_SymmetricMemory.rendezvous(t))

         t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
         symm_mem = _SymmetricMemory.rendezvous(t)
@@ -295,6 +294,69 @@ def test_optimal_layout(self, dim: int) -> None:
         self.assertTrue(x.movedim(dim, 0).is_contiguous())
         self.assertTrue(torch.allclose(x, t))

+    @skip_if_lt_x_gpu(2)
+    @parametrize("symm_mem_input", [True, False])
+    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
+        self._init_process()
+
+        if symm_mem_input:
+            t = _SymmetricMemory.empty_strided_p2p(
+                size=(64, 64),
+                stride=(64, 1),
+                dtype=torch.float32,
+                device=self.device,
+                group_name="0",
+            ).fill_(self.rank)
+        else:
+            t = torch.full((64, 64), self.rank, dtype=torch.float32, device=self.device)
+
+        res = torch.ops.symm_mem._low_contention_all_gather(t, "0")
+        res = torch.ops._c10d_functional.wait_tensor(res)
+        self.assertEqual(res.shape, (64 * self.world_size, 64))
+
+        chunks = res.chunk(self.world_size)
+        for r in range(self.world_size):
+            self.assertTrue(chunks[r].eq(r).all())
+
+        dist.destroy_process_group()
+
+    @skip_if_lt_x_gpu(2)
+    @parametrize("reduce_op", ["sum", "avg"])
+    @parametrize("symm_mem_input", [True, False])
+    def test_low_contention_reduce_scatter(
+        self, reduce_op: str, symm_mem_input: bool
+    ) -> None:
+        self._init_process()
+
+        if symm_mem_input:
+            t = _SymmetricMemory.empty_strided_p2p(
+                size=(64, 64),
+                stride=(64, 1),
+                dtype=torch.float32,
+                device=self.device,
+                group_name="0",
+            )
+        else:
+            t = torch.empty((64, 64), dtype=torch.float32, device=self.device)
+
+        chunks = t.chunk(self.world_size)
+        for r in range(self.world_size):
+            chunks[r].fill_(r)
+
+        res = torch.ops.symm_mem._low_contention_reduce_scatter(t, reduce_op, "0")
+        res = torch.ops._c10d_functional.wait_tensor(res)
+        self.assertEqual(res.shape, (64 // self.world_size, 64))
+
+        if reduce_op == "sum":
+            expect = self.rank * self.world_size
+        elif reduce_op == "avg":
+            expect = self.rank
+        else:
+            raise AssertionError(f"Unexpected reduce_op: {reduce_op}")
+        self.assertTrue(res.eq(expect).all())
+
+        dist.destroy_process_group()
+

 if __name__ == "__main__":
     run_tests()
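
The expected values in test_low_contention_reduce_scatter follow directly from the fill pattern: every rank fills chunk r of its input with the value r, so the chunk that lands on rank r after a "sum" reduce-scatter is r accumulated over world_size identical contributions, i.e. rank * world_size, while "avg" divides that back down to rank. A single-process sketch of this arithmetic (the fixed world_size = 4 is an assumption for illustration; the test uses the real group size):

import torch

world_size = 4

# Every rank holds an identical (64, 64) input whose r-th row-chunk is filled with r.
inputs = [
    torch.cat([torch.full((64 // world_size, 64), float(r)) for r in range(world_size)])
    for _ in range(world_size)
]

# Reduce-scatter: rank r receives the element-wise reduction of chunk r across all ranks.
for rank in range(world_size):
    contributions = torch.stack([inp.chunk(world_size)[rank] for inp in inputs])
    assert contributions.sum(dim=0).eq(rank * world_size).all()  # the "sum" expectation
    assert contributions.mean(dim=0).eq(rank).all()  # the "avg" expectation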

torch/_C/_distributed_c10d.pyi

Lines changed: 1 addition & 0 deletions
@@ -634,6 +634,7 @@ def _register_process_group(
     process_group: ProcessGroup,
 ) -> None: ...
 def _resolve_process_group(group_name: str) -> ProcessGroup: ...
+def _register_work(tensor: torch.Tensor, work: Work) -> ProcessGroup: ...
 def _unregister_all_process_groups() -> None: ...
 def _unregister_process_group(group_name: str) -> None: ...
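
The new _register_work binding is what lets the Python-defined ops in this PR return a plain tensor while still honoring the functional-collectives completion protocol: the op records a Work object against its output tensor, and a later torch.ops._c10d_functional.wait_tensor call on that tensor blocks on the recorded work. A minimal sketch of the calling pattern; the event-backed Work subclass mirrors the one added in torch/distributed/_symmetric_memory/__init__.py below, and the output tensor here is purely illustrative:

from datetime import timedelta

import torch
from torch._C._distributed_c10d import Work as _Work


class _EventWork(_Work):
    # Completion is tracked by a CUDA event recorded on the stream that
    # produced the output (here, the current stream).
    def __init__(self) -> None:
        super().__init__()
        self.event = torch.cuda.Event()
        self.event.record()

    def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
        self.event.wait()  # the calling stream waits on the recorded event
        return True


output = torch.empty(64, device="cuda")
torch._C._distributed_c10d._register_work(output, _EventWork())
# A downstream consumer synchronizes through the usual protocol:
output = torch.ops._c10d_functional.wait_tensor(output)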

torch/csrc/distributed/c10d/CUDASymmetricMemory.cu

Lines changed: 3 additions & 4 deletions
@@ -545,10 +545,9 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
     void* ptr) {
 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
   auto block = find_block(ptr);
-  TORCH_CHECK(
-      block != nullptr,
-      "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ",
-      "via CUDASymmetricMemoryAllocator::alloc");
+  if (block == nullptr) {
+    return nullptr;
+  }

   if (block->symm_mem != nullptr) {
     return block->symm_mem;
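
The behavioral change: rendezvous on a pointer that was not allocated through CUDASymmetricMemoryAllocator::alloc now returns nullptr (surfaced to Python as None) instead of throwing. This lets the Python ops in this PR probe whether an input already lives in symmetric memory and quietly fall back to a workspace otherwise, roughly as follows (mirroring _low_contention_all_gather below; tensor and group_name are whatever the caller passed in):

symm_mem = _SymmetricMemory.rendezvous(tensor)
if symm_mem is not None:
    # The input was allocated via _SymmetricMemory.empty_strided_p2p:
    # remote ranks can read it directly, so no staging copy is needed.
    input_is_symm_mem = True
else:
    # Ordinary CUDA tensor: stage it through a symmetric-memory workspace
    # sized to hold the input.
    symm_mem = get_symm_mem_workspace(
        group_name, tensor.numel() * tensor.element_size()
    )
    input_is_symm_mem = False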

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 227 additions & 1 deletion
@@ -2,13 +2,14 @@
 import uuid

 from contextlib import contextmanager
+from datetime import timedelta
 from functools import partial
 from typing import Callable, Dict, Generator, List, Optional, Tuple

 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
-from torch._C._distributed_c10d import _SymmetricMemory
+from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work

 _group_name_to_store: Dict[str, c10d.Store] = {}

@@ -263,6 +264,10 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor:
 lib.define(
     "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor"
 )
+lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor")
+lib.define(
+    "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor"
+)


 @torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
@@ -579,3 +584,224 @@ def shard_consumer(shard: torch.Tensor, rank: int) -> None:
         group_name,
     )
     return unflatten(ag_out), [unflatten(output) for output in outputs]
+
+
+class Work(_Work):
+    def __init__(self) -> None:
+        super().__init__()
+        self.event = torch.cuda.Event()
+        self.event.record()
+
+    def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
+        self.event.wait()
+        return True
+
+
+"""
+NOTE [low-contention collectives]
+When a collective is overlapped with abundant compute, it makes sense to
+prioritize reducing the contention between the collective and the overlapped
+compute, even at the cost of a slightly slower collective.
+
+Common collective implementations (e.g., NCCL without user buffer
+registration) optimize for throughput with no ambient compute. However, such
+implementations may not be optimal when they are overlapped with compute:
+- These implementations typically fuse the entire collective into a single
+kernel and reserve SM resources based on the most demanding portion of the
+collective, even when a large portion of the collective does not require this
+much resource.
+- These implementations often use SM-based P2P copy as opposed to copy
+engine-based P2P copy. Copy engine-based P2P copy may not have a significant
+advantage when there's no ambient compute. However, it may significantly
+improve overall resource utilization in the presence of ambient compute.
+
+When overlapped with intensive compute (e.g., persistent matmul kernels), the
+SM-usage of a collective can lead to inefficient overlapping.
+
+Low-contention collectives achieve their goals with the following strategies:
+- Use copy engine-based copy whenever possible.
+- Break down portions of a collective with different resource requirements
+into multiple kernels. This improves the overlapping efficiency at the cost
+of additional launching overhead.
+"""
+
+
+@torch.library.impl(lib, "_low_contention_all_gather", "Meta")
+def _low_contention_all_gather_meta(
+    tensor: torch.Tensor,
+    group_name: str,
+) -> torch.Tensor:
+    group_size = c10d._get_group_size_by_name(group_name)
+    return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
+
+
+@torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
+def _low_contention_all_gather(
+    tensor: torch.Tensor,
+    group_name: str,
+) -> torch.Tensor:
+    """
+    Performs all-gather with symmetric memory in a low-contention fashion.
+
+    When `tensor` is already in symmetric memory:
+        - The collective is carried out without using SMs.
+        - No symmetric memory workspace is required.
+
+    When `tensor` is not in symmetric memory:
+        - An extra SM-based copy is performed to copy the input data into the
+          symmetric memory workspace.
+        - Symmetric memory workspace size requirement: the size of `tensor`.
+    """
+    symm_mem = _SymmetricMemory.rendezvous(tensor)
+    if symm_mem is not None:
+        input_is_symm_mem = True
+    else:
+        symm_mem = get_symm_mem_workspace(
+            group_name, tensor.numel() * tensor.element_size()
+        )
+        input_is_symm_mem = False
+
+    rank = symm_mem.rank
+    world_size = symm_mem.world_size
+
+    output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:])
+    chunks = output.chunk(world_size)
+
+    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(_get_backend_stream()):
+        if not input_is_symm_mem:
+            local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype)
+            local_buf.copy_(tensor)
+        # pull
+        symm_mem.barrier()
+        for step in range(0, world_size):
+            remote_rank = (rank - step) % world_size
+            src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype)
+            chunks[remote_rank].copy_(src_buf)
+        symm_mem.barrier()
+        torch._C._distributed_c10d._register_work(output, Work())
+        return output
+
+
+@torch.library.impl(lib, "_low_contention_reduce_scatter", "Meta")
+def _low_contention_reduce_scatter_meta(
+    tensor: torch.Tensor,
+    reduce_op: str,
+    group_name: str,
+) -> torch.Tensor:
+    group_size = c10d._get_group_size_by_name(group_name)
+    return tensor.unflatten(0, (group_size, -1)).mean(dim=0)
+
+
+def _low_contention_reduce_scatter_with_symm_mem_input(
+    tensor: torch.Tensor,
+    reduce_op: str,
+    symm_mem: _SymmetricMemory,
+) -> torch.Tensor:
+    rank = symm_mem.rank
+    world_size = symm_mem.world_size
+
+    assert tensor.shape[0] % world_size == 0
+    a2a_res = torch.empty_like(tensor)
+    chunks = a2a_res.chunk(world_size)
+
+    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(_get_backend_stream()):
+        # pull + offline reduction
+        symm_mem.barrier()
+        for step in range(0, world_size):
+            remote_rank = (rank - step) % world_size
+            src_buf = symm_mem.get_buffer(
+                remote_rank,
+                chunks[0].shape,
+                chunks[0].dtype,
+                chunks[0].numel() * rank,
+            )
+            chunks[remote_rank].copy_(src_buf)
+        symm_mem.barrier()
+
+        ret = a2a_res.unflatten(0, (world_size, -1))
+        if reduce_op == "sum":
+            ret = ret.sum(dim=0)
+        elif reduce_op == "avg":
+            ret = ret.mean(dim=0)
+        else:
+            raise ValueError(f"reduce_op ({reduce_op}) is not supported")
+        torch._C._distributed_c10d._register_work(ret, Work())
+        return ret
+
+
+def _low_contention_reduce_scatter_with_workspace(
+    tensor: torch.Tensor,
+    reduce_op: str,
+    workspace: _SymmetricMemory,
+) -> torch.Tensor:
+    rank = workspace.rank
+    world_size = workspace.world_size
+
+    assert tensor.shape[0] % world_size == 0
+    chunks = tensor.chunk(world_size)
+
+    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(_get_backend_stream()):
+        # push + offline reduction
+        workspace.barrier()
+        for step in range(0, world_size):
+            remote_rank = (rank - step) % world_size
+            dst_buf = workspace.get_buffer(
+                remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank
+            )
+            dst_buf.copy_(chunks[remote_rank])
+        workspace.barrier()
+
+        buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype)
+        ret = buf.unflatten(0, (world_size, -1))
+        if reduce_op == "sum":
+            ret = ret.sum(dim=0)
+        elif reduce_op == "avg":
+            ret = ret.mean(dim=0)
+        else:
+            raise ValueError(f"reduce_op ({reduce_op}) is not supported")
+        torch._C._distributed_c10d._register_work(ret, Work())
+        return ret
+
+
+@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA")
+def _low_contention_reduce_scatter(
+    tensor: torch.Tensor,
+    reduce_op: str,
+    group_name: str,
+) -> torch.Tensor:
+    """
+    Performs reduce-scatter with symmetric memory in a low-contention fashion.
+
+    This implementation performs a P2P-based all-to-all followed by an offline
+    reduction.
+
+    When `tensor` is already in symmetric memory:
+        - Pull-based all-to-all is used.
+        - No symmetric memory workspace is required.
+
+    When `tensor` is not in symmetric memory:
+        - Push-based all-to-all is used.
+        - Symmetric memory workspace size requirement: the size of `tensor`.
+
+    SM-usage:
+        - SM-based copy of the rank's own chunk for the all-to-all.
+        - Reduction on the all-to-all result.
+
+    TODO(yifu): the SM-based copy can be avoided with a list-based reduction
+    kernel.
+    """
+    symm_mem = _SymmetricMemory.rendezvous(tensor)
+    if symm_mem is not None:
+        return _low_contention_reduce_scatter_with_symm_mem_input(
+            tensor, reduce_op, symm_mem
+        )
+    else:
+        workspace = get_symm_mem_workspace(
+            group_name, tensor.numel() * tensor.element_size()
+        )
+        return _low_contention_reduce_scatter_with_workspace(
+            tensor, reduce_op, workspace
+        )