Skip to content

Commit ed52f36

Browse files
authored
sycl: Remove not needed copy f16->f32 for dnnl mul mat (ggml-org#14125)
1 parent a681b4b commit ed52f36

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

ggml/src/ggml-sycl/gemm.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class DnnlGemmWrapper {
6565

6666
dnnl::primitive_attr primitive_attr;
6767
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
68+
#ifdef GGML_SYCL_F16
69+
primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
70+
#endif
6871

6972
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
7073
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,21 +2127,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
21272127
const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
21282128
? (const sycl::half *)src1->data + src1_padded_row_size
21292129
: src1_as_f16.get();
2130-
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
21312130

21322131
#if GGML_SYCL_DNNL
21332132
if (!g_ggml_sycl_disable_dnn) {
21342133
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
21352134
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2136-
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
2137-
scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
2138-
" : converting dst to fp32");
2139-
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2140-
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
2135+
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
21412136
}
21422137
else
21432138
#endif
21442139
{
2140+
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
2141+
21452142
const sycl::half alpha_f16 = 1.0f;
21462143
const sycl::half beta_f16 = 0.0f;
21472144
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(

0 commit comments

Comments
 (0)