@@ -441,11 +441,18 @@ template <typename T>
 __global__ void __launch_bounds__(32, 1)
     allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
                       size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
-                      Op op, size_t nelems, uint32_t flag) {
+                      Op op, size_t nelems, uint64_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer) {
   // This version of allreduce only works for single nodes
   if (worldSize != nRanksPerNode) return;
   if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
   const int nPeers = nRanksPerNode - 1;
+
+  uint64_t commFlag = *deviceFlag;
+  uint32_t flag = (uint32_t)commFlag;
+
+  size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE / 2 : 0;
+  channelScratchOffset = scratchBaseOffset;
+
   const int nBlocksPerPeer = gridDim.x / nPeers;
   const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
   const int tid = threadIdx.x + localBlockIdx * blockDim.x;
@@ -478,13 +485,20 @@ __global__ void __launch_bounds__(1024, 1)
     }
     dst[idx] = data;
   }
+  __syncthreads();
+
+  deviceSyncer->sync(gridDim.x);
+
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    *deviceFlag = *deviceFlag + 1;
+  }
 }
 
 template <typename T>
 __global__ void __launch_bounds__(1024, 1)
     allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
                size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, Op op,
-               size_t nelems, uint32_t flag
+               size_t nelems, uint64_t* deviceFlag, mscclpp::DeviceSyncer* deviceSyncer
 #if defined(ENABLE_NPKIT)
                ,
                NpKitEventCollectContext* npKitEventCollectContexts, uint64_t* cpuTimestamp) {
@@ -527,6 +541,13 @@ __global__ void __launch_bounds__(1024, 1)
   const int nPeers = nRanksPerNode - 1;
   const size_t nPkts = nelems / 2;
 
+  uint64_t commFlag = *deviceFlag;
+  uint32_t flag = (uint32_t)commFlag;
+
+  size_t scratchBaseOffset = (flag % 2) ? SCRATCH_SIZE / 2 : 0;
+  channelScratchOffset = scratchBaseOffset;
+
+
   int nelemsPerRank = nelems / worldSize;
   if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T);
 
@@ -580,6 +601,7 @@ __global__ void __launch_bounds__(1024, 1)
       channels[index].write(offset, packet);
     }
   }
+  __syncthreads();
   // step 3: get data result from scratch buffer
   mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
   const int dstOffset = remoteRank * nPktsPerRank;
@@ -589,6 +611,7 @@ __global__ void __launch_bounds__(1024, 1)
     result[idx].x = data.x;
     result[idx].y = data.y;
   }
+  __syncthreads();
 #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
     defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
   NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,
@@ -599,6 +622,11 @@ __global__ void __launch_bounds__(1024, 1)
 #if defined(ENABLE_NPKIT)
   NpKit::StoreGpuEventShm(npKitEventCollectContexts, event_buffer, event_buffer_head);
 #endif
+  deviceSyncer->sync(gridDim.x);
+
+  if (blockIdx.x == 0 && threadIdx.x == 0) {
+    *deviceFlag = *deviceFlag + 1;
+  }
 }
 
 template <typename T>
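Note on the kernel-side pattern the hunks above introduce: each kernel now reads a device-resident call counter, uses its parity to pick one half of the scratch buffer, synchronizes the whole grid, and then lets a single thread advance the counter. A minimal toy sketch of that pattern follows; it assumes `mscclpp::DeviceSyncer` is available from `mscclpp/concurrency_device.hpp`, and the kernel name and `scratchBytes` parameter are illustrative, not part of this patch.

```cpp
#include <cstdint>
#include <mscclpp/concurrency_device.hpp>  // assumed header for mscclpp::DeviceSyncer

// Toy kernel mirroring the pattern added above: the low bit of a device-resident
// call counter selects which half of the scratch buffer this call may touch, and
// exactly one thread bumps the counter after every block has passed the barrier.
__global__ void doubleBufferedStep(char* scratch, size_t scratchBytes, uint64_t* deviceFlag,
                                   mscclpp::DeviceSyncer* syncer) {
  uint32_t flag = (uint32_t)(*deviceFlag);
  size_t base = (flag % 2) ? scratchBytes / 2 : 0;  // parity selects the scratch half
  // ... write/read packets tagged with `flag` into scratch + base ...
  (void)base;
  syncer->sync(gridDim.x);  // grid-wide barrier: nobody may flip halves early
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    *deviceFlag = *deviceFlag + 1;  // next call uses the other half
  }
}
```

Because the counter lives in device memory and is advanced on the GPU, the host no longer needs to pass a fresh flag value per launch, which is what lets the `allreduce()` entry point in the hunks below drop its `flag++` arguments.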
@@ -741,15 +769,15 @@ template <typename T>
 cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
                       mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
                       size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
-                      Op op, size_t nelems, cudaStream_t stream) {
+                      Op op, size_t nelems, cudaStream_t stream, uint64_t* deviceFlag, mscclpp::DeviceSyncer* syncer) {
   static uint32_t flag = 1;
 
   if (sizeof(T) * nelems < worldSize * sizeof(int)) {
     int nBlocks = 7;
     int nThreadsPerBlock = 32;
     allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels,
                                                                 channelInOffset, channelScratchOffset, rank,
-                                                                nRanksPerNode, worldSize, op, nelems, flag++);
+                                                                nRanksPerNode, worldSize, op, nelems, deviceFlag, syncer);
   } else if (sizeof(T) * nelems <= (1 << 20)) {
     int nBlocks = 28;
     int nThreadsPerBlock = 1024;
@@ -765,7 +793,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
 #else
     allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels, channelInOffset,
                                                          channelScratchOffset, rank, nRanksPerNode, worldSize, op,
-                                                         nelems, flag++);
+                                                         nelems, deviceFlag, syncer);
 #endif
   } else {
     int nBlocks = 35;
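The updated `allreduce()` signature expects the caller to supply the device flag and the `mscclpp::DeviceSyncer`. Below is a minimal host-side setup sketch; it assumes the syncer can live in `cudaMalloc`-ed memory and must be zero-initialized before its first `sync()`, that the flag starts at 1 to mirror the old `static uint32_t flag = 1;`, and that `mscclpp/concurrency_device.hpp` is the right header. The helper name is illustrative, not part of this patch.

```cpp
#include <cstdint>
#include <cuda_runtime.h>
#include <mscclpp/concurrency_device.hpp>  // assumed header for mscclpp::DeviceSyncer

// Hypothetical helper: allocate the device-resident state the new allreduce() expects.
static cudaError_t setupAllreduceState(uint64_t** deviceFlag, mscclpp::DeviceSyncer** syncer) {
  cudaError_t err = cudaMalloc(deviceFlag, sizeof(uint64_t));
  if (err != cudaSuccess) return err;
  uint64_t initialFlag = 1;  // mirrors the old host-side `static uint32_t flag = 1;`
  err = cudaMemcpy(*deviceFlag, &initialFlag, sizeof(uint64_t), cudaMemcpyHostToDevice);
  if (err != cudaSuccess) return err;
  err = cudaMalloc(syncer, sizeof(mscclpp::DeviceSyncer));
  if (err != cudaSuccess) return err;
  // The syncer's internal counters must start at zero before the first grid sync.
  return cudaMemset(*syncer, 0, sizeof(mscclpp::DeviceSyncer));
}
```

Both allocations must outlive every allreduce call on the stream, since the kernels read and increment the flag in place across invocations and alternate scratch halves based on its parity.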