@@ -2636,6 +2636,7 @@ static __global__ void mul_mat_q(
2636
2636
2637
2637
ids_dst_shared[j] = j;
2638
2638
}
2639
+ __syncthreads ();
2639
2640
2640
2641
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
2641
2642
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
@@ -2664,6 +2665,7 @@ static __global__ void mul_mat_q(
2664
2665
return ;
2665
2666
}
2666
2667
2668
+ // __syncthreads(); // There is no previous tile that could cause a race condition.
2667
2669
#pragma unroll
2668
2670
for (int j0 = 0 ; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2669
2671
const int j = j0 + threadIdx .y *WARP_SIZE + threadIdx .x ;
@@ -2674,6 +2676,7 @@ static __global__ void mul_mat_q(
2674
2676
2675
2677
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
2676
2678
}
2679
+ __syncthreads ();
2677
2680
}
2678
2681
2679
2682
offset_y += (col_low + jt*mmq_x)*(sizeof (block_q8_1_mmq)/sizeof (int ));
@@ -2740,6 +2743,7 @@ static __global__ void mul_mat_q(
2740
2743
continue ;
2741
2744
}
2742
2745
2746
+ __syncthreads ();
2743
2747
#pragma unroll
2744
2748
for (int j0 = 0 ; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2745
2749
const int j = j0 + threadIdx .y *WARP_SIZE + threadIdx .x ;
@@ -2750,6 +2754,7 @@ static __global__ void mul_mat_q(
2750
2754
2751
2755
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
2752
2756
}
2757
+ __syncthreads ();
2753
2758
}
2754
2759
2755
2760
offset_y += (col_low + jt*mmq_x)*(sizeof (block_q8_1_mmq)/sizeof (int ));
@@ -2805,6 +2810,7 @@ static __global__ void mul_mat_q(
2805
2810
}
2806
2811
2807
2812
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
2813
+ __syncthreads ();
2808
2814
#pragma unroll
2809
2815
for (int j0 = 0 ; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
2810
2816
const int j = j0 + threadIdx .y *WARP_SIZE + threadIdx .x ;
@@ -2815,6 +2821,7 @@ static __global__ void mul_mat_q(
2815
2821
2816
2822
ids_dst_shared[j] = j;
2817
2823
}
2824
+ __syncthreads ();
2818
2825
}
2819
2826
2820
2827
offset_y += (col_low + jt*mmq_x)*(sizeof (block_q8_1_mmq)/sizeof (int ));
0 commit comments