Skip to content

Commit 8afbd96

Browse files
CUDA: fix race condition in MMQ ids_dst (ggml-org#13294)
1 parent: 8ae5ebc · commit: 8afbd96

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

ggml/src/ggml-cuda/mmq.cuh

+7
Original file line number | Diff line number | Diff line change
@@ -2636,6 +2636,7 @@ static __global__ void mul_mat_q(
26362636

26372637
ids_dst_shared[j] = j;
26382638
}
2639+
__syncthreads();
26392640

26402641
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
26412642
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
@@ -2664,6 +2665,7 @@ static __global__ void mul_mat_q(
26642665
return;
26652666
}
26662667

2668+
// __syncthreads(); // There is no previous tile that could cause a race condition.
26672669
#pragma unroll
26682670
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
26692671
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2674,6 +2676,7 @@ static __global__ void mul_mat_q(
26742676

26752677
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
26762678
}
2679+
__syncthreads();
26772680
}
26782681

26792682
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
@@ -2740,6 +2743,7 @@ static __global__ void mul_mat_q(
27402743
continue;
27412744
}
27422745

2746+
__syncthreads();
27432747
#pragma unroll
27442748
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
27452749
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2750,6 +2754,7 @@ static __global__ void mul_mat_q(
27502754

27512755
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
27522756
}
2757+
__syncthreads();
27532758
}
27542759

27552760
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
@@ -2805,6 +2810,7 @@ static __global__ void mul_mat_q(
28052810
}
28062811

28072812
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
2813+
__syncthreads();
28082814
#pragma unroll
28092815
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
28102816
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2815,6 +2821,7 @@ static __global__ void mul_mat_q(
28152821

28162822
ids_dst_shared[j] = j;
28172823
}
2824+
__syncthreads();
28182825
}
28192826

28202827
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));

0 commit comments

Comments
 (0)