
Commit 6a7a6c3

apps/nccl: fix a bug in allreduce kernels for graph mode
allreduce7 and allreduce8 were updating the LL protocol flag on the host side, so the update was not captured in graph mode and every replay reused the capture-time flag. This PR fixes the issue by updating the flag inside the kernels.
1 parent adc9ee5 commit 6a7a6c3
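
For context, here is a minimal standalone sketch of the failure mode this commit fixes. Everything below is illustrative (hypothetical kernel and variable names, not from this repo): kernel arguments are evaluated on the host and baked into a CUDA graph at capture time, so a host-incremented flag replays with its capture-time value forever, while a flag kept in device memory and bumped by the kernel itself advances on every replay. HIP behaves the same with the hip* equivalents.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void hostFlagKernel(uint32_t flag) {
  // 'flag' was evaluated on the host at capture time; it never changes across replays.
  printf("host-side flag = %u\n", flag);
}

__global__ void deviceFlagKernel(uint64_t* deviceFlag) {
  // The pointer is baked into the graph, but the value it points to is read and
  // bumped at replay time, so it advances on every launch.
  printf("device-side flag = %llu\n", (unsigned long long)*deviceFlag);
  *deviceFlag += 1;  // single thread, so no race
}

int main() {
  uint64_t init = 1;
  uint64_t* deviceFlag;
  cudaMalloc(&deviceFlag, sizeof(uint64_t));
  cudaMemcpy(deviceFlag, &init, sizeof(uint64_t), cudaMemcpyHostToDevice);

  uint32_t flag = 1;  // mirrors the old host-side 'static uint32_t flag'
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaGraphExec_t exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  hostFlagKernel<<<1, 1, 0, stream>>>(flag++);        // frozen at 1 inside the graph
  deviceFlagKernel<<<1, 1, 0, stream>>>(deviceFlag);  // stays live across replays
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

  for (int i = 0; i < 3; ++i) cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);
  // Prints "host-side flag = 1" three times, but device-side flag = 1, 2, 3.
  return 0;
}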

File tree

2 files changed: +49 -9 lines


apps/nccl/src/allreduce.hpp

+33 -5
@@ -441,11 +441,18 @@ template <typename T>
 __global__ void __launch_bounds__(32, 1)
     allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
                       size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
-                      Op op, size_t nelems, uint32_t flag) {
+                      Op op, size_t nelems, uint64_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer) {
   // This version of allreduce only works for single nodes
   if (worldSize != nRanksPerNode) return;
   if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
   const int nPeers = nRanksPerNode - 1;
+
+  uint64_t commFlag = *deviceFlag;
+  uint32_t flag = (uint32_t) commFlag;
+
+  size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE/2 : 0;
+  channelScratchOffset = scratchBaseOffset;
+
   const int nBlocksPerPeer = gridDim.x / nPeers;
   const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
   const int tid = threadIdx.x + localBlockIdx * blockDim.x;
@@ -478,13 +485,20 @@ __global__ void __launch_bounds__(32, 1)
     }
     dst[idx] = data;
   }
+  __syncthreads();
+
+  deviceSyncer->sync(gridDim.x);
+
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    *deviceFlag = *deviceFlag + 1;
+  }
 }
 
 template <typename T>
 __global__ void __launch_bounds__(1024, 1)
     allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
                size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, Op op,
-               size_t nelems, uint32_t flag
+               size_t nelems, uint64_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer
 #if defined(ENABLE_NPKIT)
                ,
                NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
@@ -527,6 +541,13 @@ __global__ void __launch_bounds__(1024, 1)
   const int nPeers = nRanksPerNode - 1;
   const size_t nPkts = nelems / 2;
 
+  uint64_t commFlag = *deviceFlag;
+  uint32_t flag = (uint32_t) commFlag;
+
+  size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE/2 : 0;
+  channelScratchOffset = scratchBaseOffset;
+
+
   int nelemsPerRank = nelems / worldSize;
   if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T);
 
@@ -580,6 +601,7 @@ __global__ void __launch_bounds__(1024, 1)
       channels[index].write(offset, packet);
     }
   }
+  __syncthreads();
   // step 3: get data result from scratch buffer
   mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
   const int dstOffset = remoteRank * nPktsPerRank;
@@ -589,6 +611,7 @@ __global__ void __launch_bounds__(1024, 1)
     result[idx].x = data.x;
     result[idx].y = data.y;
   }
+  __syncthreads();
 #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
     defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
   NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,
@@ -599,6 +622,11 @@ __global__ void __launch_bounds__(1024, 1)
 #if defined(ENABLE_NPKIT)
   NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head);
 #endif
+  deviceSyncer->sync(gridDim.x);
+
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    *deviceFlag = *deviceFlag + 1;
+  }
 }
 
 template <typename T>
@@ -741,15 +769,15 @@ template <typename T>
 cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
                       mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
                       size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
-                      Op op, size_t nelems, cudaStream_t stream) {
+                      Op op, size_t nelems, cudaStream_t stream, uint64_t* deviceFlag, mscclpp::DeviceSyncer* syncer) {
   static uint32_t flag = 1;
 
   if (sizeof(T) * nelems < worldSize * sizeof(int)) {
     int nBlocks = 7;
     int nThreadsPerBlock = 32;
     allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels,
                                                                 channelInOffset, channelScratchOffset, rank,
-                                                                nRanksPerNode, worldSize, op, nelems, flag++);
+                                                                nRanksPerNode, worldSize, op, nelems, deviceFlag, syncer);
   } else if (sizeof(T) * nelems <= (1 << 20)) {
     int nBlocks = 28;
     int nThreadsPerBlock = 1024;
@@ -765,7 +793,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
 #else
     allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels, channelInOffset,
                                                          channelScratchOffset, rank, nRanksPerNode, worldSize, op,
-                                                         nelems, flag++);
+                                                         nelems, deviceFlag, syncer);
 #endif
   } else {
     int nBlocks = 35;
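
The kernels now derive the 32-bit LL flag from the device-resident 64-bit counter and use its parity to pick which half of the scratch buffer the call writes into, so consecutive calls (including graph replays) never land in packets that may still be in flight; the deviceSyncer->sync(gridDim.x) barrier guarantees every block has finished with the current half before thread (0, 0) advances the counter. Below is a minimal host-side sketch of that selection logic, with an assumed SCRATCH_SIZE value and a hypothetical helper name:

#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr size_t SCRATCH_SIZE = 8 * 1024 * 1024;  // assumed value, for illustration only

// Mirrors the per-call choice in allreduceAllToAll/allreduce7: odd flags use the
// upper half of the scratch buffer, even flags the lower half.
size_t scratchOffsetForFlag(uint64_t deviceFlagValue) {
  uint32_t flag = (uint32_t)deviceFlagValue;
  return (flag % 2) ? SCRATCH_SIZE / 2 : 0;
}

int main() {
  // The device flag starts at 1 and is bumped by the kernel after each call,
  // so successive calls alternate: upper, lower, upper, lower, ...
  for (uint64_t f = 1; f <= 4; ++f)
    printf("flag %llu -> scratch offset %zu\n", (unsigned long long)f, scratchOffsetForFlag(f));
  return 0;
}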

apps/nccl/src/nccl.cu

+16 -4
@@ -194,6 +194,9 @@ struct ncclComm {
   uint32_t numScratchBuff;
   uint32_t buffFlag;
 
+  uint64_t* deviceFlag;
+  mscclpp::DeviceSyncer* syncer;
+
   void* mscclppNcclComm;
 };
 
@@ -384,23 +387,25 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
     case ncclFloat16:
       CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, memoryChannels,
                           memoryOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
+                          comm->comm->bootstrap()->getNranks(), reduceOp, count, stream, comm->deviceFlag, comm->syncer));
       break;
     case ncclFloat32:
       CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, memoryChannels,
                           memoryOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
-                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), reduceOp, count, stream,
+                          comm->deviceFlag, comm->syncer));
       break;
     case ncclBfloat16:
       CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
                           memoryChannels, memoryOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
+                          comm->comm->bootstrap()->getNranks(), reduceOp, count, stream, comm->deviceFlag, comm->syncer));
       break;
     case ncclInt32:
     case ncclUint32:
       CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, memoryChannels,
                           memoryOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
-                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), reduceOp, count, stream,
+                          comm->deviceFlag, comm->syncer));
       break;
     default:
       WARN("datatype is invalid, datatype: %d", datatype);
@@ -524,6 +529,13 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
   commPtr->scratchBuff = mscclpp::GpuBuffer<char>(SCRATCH_SIZE).memory();
   commPtr->remoteScratchRegMemories =
       setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
+
+  hipMalloc((void**)&(commPtr->syncer), sizeof(mscclpp::DeviceSyncer));
+  hipMemset((void*)(commPtr->syncer), 0, sizeof(mscclpp::DeviceSyncer));
+
+  uint64_t initFlag = 1;
+  hipMalloc((void**)&(commPtr->deviceFlag), sizeof(uint64_t));
+  hipMemcpy((void*)(commPtr->deviceFlag), &initFlag, sizeof(uint64_t), hipMemcpyHostToDevice);
 }
 
 NCCL_API ncclResult_t ncclGetVersion(int* version) {
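
End to end, the scenario this commit targets looks roughly like the harness below: capture one allreduce into a graph, then replay it many times. This is a hedged sketch, assuming a single-process, single-GPU communicator via the standard NCCL API (ncclCommInitAll); multi-rank runs would initialize one communicator per rank instead. CUDA calls are shown; the HIP equivalents map one-to-one.

#include <cuda_runtime.h>
#include <nccl.h>

int main() {
  ncclComm_t comm;
  ncclCommInitAll(&comm, 1, nullptr);  // one communicator on device 0

  float* buf;
  cudaMalloc(&buf, 1024 * sizeof(float));
  cudaMemset(buf, 0, 1024 * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture a single allreduce into a graph.
  cudaGraph_t graph;
  cudaGraphExec_t exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  ncclAllReduce(buf, buf, 1024, ncclFloat, ncclSum, comm, stream);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

  // Before this commit, every replay reused the capture-time LL flag, so stale
  // scratch packets could be mistaken for fresh data; with the flag bumped inside
  // the kernels, each replay sees a new flag and the alternating scratch half.
  for (int i = 0; i < 100; ++i) cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  ncclCommDestroy(comm);
  cudaFree(buf);
  return 0;
}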
