
Commit 6eecde3

HIP: fix flash_attn_stream_k_fixup warning (ggml-org#11604)
1 parent 396856b · commit 6eecde3

2 files changed, 12 additions (+), 2 deletions (−)

ggml/src/ggml-cuda/fattn-common.cuh (+10)
@@ -516,6 +516,12 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+
 template<int D, int ncols, int KQ_stride> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -614,6 +620,10 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 }
 
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
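
For context, below is a minimal standalone sketch of the suppression pattern this hunk applies: a kernel whose #pragma unroll loop contains a data-dependent bounds check of the form jt*ncols + j >= ne01, which HIP's clang reports via -Wpass-failed when it cannot fully unroll the loop. The kernel name scale_rows, its shapes, and the launch parameters are illustrative assumptions and do not appear in ggml.

// Hypothetical example (not ggml code): clang/HIP -Wpass-failed suppression
// scoped around a kernel with a conditional inside a #pragma unroll loop.
#include <cuda_runtime.h>

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed" // HIP's clang warns it cannot unroll past the early return
#endif // __clang__

template<int ncols> // ncols == columns handled per block
static __global__ void scale_rows(float * dst, const int ne01) {
    const int jt = blockIdx.x;
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        if (jt*ncols + j >= ne01) { // data-dependent exit that triggers the diagnostic
            return;
        }
        dst[jt*ncols + j] *= 2.0f;
    }
}

#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__

int main() {
    float * d = nullptr;
    cudaMalloc(&d, 128*sizeof(float));
    cudaMemset(d, 0, 128*sizeof(float));
    scale_rows<32><<<4, 1>>>(d, 100); // 4 blocks of 32 columns; the last block runs past ne01
    cudaDeviceSynchronize();
    cudaFree(d);
    return 0;
}

Scoping the suppression with push/pop, as both files in this commit do, keeps -Wpass-failed active for the rest of the translation unit instead of silencing it globally.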

ggml/src/ggml-cuda/softmax.cu (+2 −2)
@@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #ifdef __clang__
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wpass-failed"
-#endif
+#endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
         const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@@ -126,7 +126,7 @@ static __global__ void soft_max_f32(
 }
 #ifdef __clang__
 #pragma clang diagnostic pop
-#endif
+#endif // __clang__
 
 static __global__ void soft_max_back_f32(
         const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
