 import uuid
 
 from contextlib import contextmanager
+from datetime import timedelta
 from functools import partial
 from typing import Callable, Dict, Generator, List, Optional, Tuple
 
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
-from torch._C._distributed_c10d import _SymmetricMemory
+from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work
 
 _group_name_to_store: Dict[str, c10d.Store] = {}
 
@@ -263,6 +264,10 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor:
 lib.define(
     "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor"
 )
+lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor")
+lib.define(
+    "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor"
+)
 
 
 @torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
@@ -579,3 +584,224 @@ def shard_consumer(shard: torch.Tensor, rank: int) -> None:
         group_name,
     )
     return unflatten(ag_out), [unflatten(output) for output in outputs]
+
+
+class Work(_Work):
+    def __init__(self) -> None:
+        super().__init__()
+        self.event = torch.cuda.Event()
+        self.event.record()
+
+    def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool:
+        self.event.wait()
+        return True
+
+
| 600 | +""" |
| 601 | +NOTE [low-contention collectives] |
| 602 | +When a collective is overlapped with abundant compute, it makes sense to |
| 603 | +prioritize reducing the contention between the collective and the overlapped |
| 604 | +compute, even at the cost of a slightly slower collective. |
| 605 | +
|
| 606 | +Common collective implementations (e.g., NCCL without user buffer |
| 607 | +registration) optimize for throughput with no ambient compute. However, such |
| 608 | +implementations may not be optimal when they are overlapped with compute: |
| 609 | +- These implementations typically fuse the entire collective into a single |
| 610 | +kernel and reserve SM resources based on the most demanding portion of the |
| 611 | +collective, even when a large portion of the collective does not require this |
| 612 | +much resource. |
| 613 | +- These implementations often use SM-based P2P copy as opposed to copy |
| 614 | +engine-based P2P copy. Copy engine-based P2P copy may not have a significant |
| 615 | +advantage when there's no ambient compute. However, it may significantly |
| 616 | +improve overall resource utilization in the presence of ambient compute. |
| 617 | +
|
| 618 | +When overlapped with intensive compute (e.g., persistent matmul kernels), the |
| 619 | +SM-usage of a collective can lead to inefficient overlapping. |
| 620 | +
|
| 621 | +Low-contention collectives achieve their goals with the following strategies: |
| 622 | +- Use copy engine-based copy whenever possible. |
| 623 | +- Break down portions of a collective with different resource requirements |
| 624 | +into multiple kernels. This improves the overlapping efficiency at the cost |
| 625 | +of additional launching overhead. |
| 626 | +""" |
+
+
+@torch.library.impl(lib, "_low_contention_all_gather", "Meta")
+def _low_contention_all_gather_meta(
+    tensor: torch.Tensor,
+    group_name: str,
+) -> torch.Tensor:
+    group_size = c10d._get_group_size_by_name(group_name)
+    return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:])
+
+
+@torch.library.impl(lib, "_low_contention_all_gather", "CUDA")
+def _low_contention_all_gather(
+    tensor: torch.Tensor,
+    group_name: str,
+) -> torch.Tensor:
+    """
+    Performs all-gather with symmetric memory in a low-contention fashion.
+
+    When `tensor` is already in symmetric memory:
+        - The collective is carried out without using SMs.
+        - No symmetric memory workspace is required.
+
+    When `tensor` is not in symmetric memory:
+        - An extra SM-based copy is performed to copy the input data into the
+          symmetric memory workspace.
+        - Symmetric memory workspace size requirement: the size of `tensor`.
+    """
+    symm_mem = _SymmetricMemory.rendezvous(tensor)
+    if symm_mem is not None:
+        input_is_symm_mem = True
+    else:
+        symm_mem = get_symm_mem_workspace(
+            group_name, tensor.numel() * tensor.element_size()
+        )
+        input_is_symm_mem = False
+
+    rank = symm_mem.rank
+    world_size = symm_mem.world_size
+
+    output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:])
+    chunks = output.chunk(world_size)
+
+    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(_get_backend_stream()):
+        if not input_is_symm_mem:
+            local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype)
+            local_buf.copy_(tensor)
+        # pull
+        symm_mem.barrier()
+        for step in range(0, world_size):
+            remote_rank = (rank - step) % world_size
+            src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype)
+            chunks[remote_rank].copy_(src_buf)
+        symm_mem.barrier()
+        torch._C._distributed_c10d._register_work(output, Work())
+        return output
+
+
+@torch.library.impl(lib, "_low_contention_reduce_scatter", "Meta")
+def _low_contention_reduce_scatter_meta(
+    tensor: torch.Tensor,
+    reduce_op: str,
+    group_name: str,
+) -> torch.Tensor:
+    group_size = c10d._get_group_size_by_name(group_name)
+    return tensor.unflatten(0, (group_size, -1)).mean(dim=0)
+
+
+def _low_contention_reduce_scatter_with_symm_mem_input(
+    tensor: torch.Tensor,
+    reduce_op: str,
+    symm_mem: _SymmetricMemory,
+) -> torch.Tensor:
+    rank = symm_mem.rank
+    world_size = symm_mem.world_size
+
+    assert tensor.shape[0] % world_size == 0
+    a2a_res = torch.empty_like(tensor)
+    chunks = a2a_res.chunk(world_size)
+
+    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(_get_backend_stream()):
+        # pull + offline reduction
+        symm_mem.barrier()
+        for step in range(0, world_size):
+            remote_rank = (rank - step) % world_size
+            src_buf = symm_mem.get_buffer(
+                remote_rank,
+                chunks[0].shape,
+                chunks[0].dtype,
+                chunks[0].numel() * rank,
+            )
+            chunks[remote_rank].copy_(src_buf)
+        symm_mem.barrier()
+
+        ret = a2a_res.unflatten(0, (world_size, -1))
+        if reduce_op == "sum":
+            ret = ret.sum(dim=0)
+        elif reduce_op == "avg":
+            ret = ret.mean(dim=0)
+        else:
+            raise ValueError(f"reduce_op ({reduce_op}) is not supported")
+        torch._C._distributed_c10d._register_work(ret, Work())
+        return ret
+
+
+def _low_contention_reduce_scatter_with_workspace(
+    tensor: torch.Tensor,
+    reduce_op: str,
+    workspace: _SymmetricMemory,
+) -> torch.Tensor:
+    rank = workspace.rank
+    world_size = workspace.world_size
+
+    assert tensor.shape[0] % world_size == 0
+    chunks = tensor.chunk(world_size)
+
+    _get_backend_stream().wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(_get_backend_stream()):
+        # push + offline reduction
+        workspace.barrier()
+        for step in range(0, world_size):
+            remote_rank = (rank - step) % world_size
+            dst_buf = workspace.get_buffer(
+                remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank
+            )
+            dst_buf.copy_(chunks[remote_rank])
+        workspace.barrier()
+
+        buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype)
+        ret = buf.unflatten(0, (world_size, -1))
+        if reduce_op == "sum":
+            ret = ret.sum(dim=0)
+        elif reduce_op == "avg":
+            ret = ret.mean(dim=0)
+        else:
+            raise ValueError(f"reduce_op ({reduce_op}) is not supported")
+        torch._C._distributed_c10d._register_work(ret, Work())
+        return ret
+
+
+@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA")
+def _low_contention_reduce_scatter(
+    tensor: torch.Tensor,
+    reduce_op: str,
+    group_name: str,
+) -> torch.Tensor:
+    """
+    Performs reduce-scatter with symmetric memory in a low-contention fashion.
+
+    This implementation performs a P2P-based all-to-all followed by an offline
+    reduction.
+
+    When `tensor` is already in symmetric memory:
+        - Pull-based all-to-all is used.
+        - No symmetric memory workspace is required.
+
+    When `tensor` is not in symmetric memory:
+        - Push-based all-to-all is used.
+        - Symmetric memory workspace size requirement: the size of `tensor`.
+
+    SM-usage:
+        - SM-based copy of the rank's own chunk for the all-to-all.
+        - Reduction on the all-to-all result.
+
+    TODO(yifu): the SM-based copy can be avoided with a list-based reduction
+    kernel.
+    """
+    symm_mem = _SymmetricMemory.rendezvous(tensor)
+    if symm_mem is not None:
+        return _low_contention_reduce_scatter_with_symm_mem_input(
+            tensor, reduce_op, symm_mem
+        )
+    else:
+        workspace = get_symm_mem_workspace(
+            group_name, tensor.numel() * tensor.element_size()
+        )
+        return _low_contention_reduce_scatter_with_workspace(
+            tensor, reduce_op, workspace
+        )
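
For context, here is a minimal usage sketch of the two ops added above, overlapped with compute on the current stream. It is not part of the diff, and it makes a few assumptions: that the library `lib` is registered under the `symm_mem` op namespace used elsewhere in this file, that `enable_symm_mem_for_group` (defined earlier in this module alongside `_group_name_to_store`) is the intended per-group setup call, and that the `Work` registered via `_register_work` is consumed through `funcol.wait_tensor`.

# Minimal usage sketch (assumptions noted above; launch one process per GPU
# on a single node, e.g. with torchrun).
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

group_name = dist.group.WORLD.group_name
enable_symm_mem_for_group(group_name)  # assumed setup helper from this module

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# The all-gather is queued on the backend stream and moves data with copy
# engines, so the matmul below keeps the SMs largely to itself.
ag = torch.ops.symm_mem._low_contention_all_gather(x, group_name)
y = x @ w  # ambient compute on the current stream

# Assumption: the registered Work is consumed via wait_tensor(), which makes
# the current stream wait on the recorded CUDA event (no host blocking).
ag = funcol.wait_tensor(ag)

# `ag` is a regular tensor here, so the reduce-scatter takes the push-based
# path through a symmetric-memory workspace of the same size as `ag`.
rs = torch.ops.symm_mem._low_contention_reduce_scatter(ag, "avg", group_name)
rs = funcol.wait_tensor(rs)

Because `Work.wait()` only enqueues a wait on the recorded CUDA event, synchronization stays on-stream rather than blocking the host, and the collective itself runs on the backend stream with copy engine-based P2P copies, which is what keeps SM contention with the overlapped matmul low.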