@@ -14986,6 +14986,9 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
14986
14986
SYCL_CHECK(ggml_sycl_set_device(g_main_device));
14987
14987
dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
14988
14988
14989
+ bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
14990
+ main_stream->get_backend() == sycl::backend::ext_oneapi_hip;
14991
+
14989
14992
SYCL_CHECK(
14990
14993
CHECK_TRY_ERROR(g_sycl_handles[g_main_device] = main_stream));
14991
14994
@@ -15016,24 +15019,38 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
15016
15019
15017
15020
dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
15018
15021
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
15022
+ if (no_mixed_dtypes) {
15023
+ cu_compute_type = dpct::library_data_t::real_half;
15024
+ cu_data_type = dpct::library_data_t::real_half;
15025
+ }
15019
15026
15020
15027
// dst strides
15021
15028
size_t nbd2 = dst->nb[2];
15022
15029
size_t nbd3 = dst->nb[3];
15023
15030
15031
+ const float alpha_f32 = 1.0f;
15032
+ const float beta_f32 = 0.0f;
15033
+
15024
15034
const sycl::half alpha_f16 = 1.0f;
15025
15035
const sycl::half beta_f16 = 0.0f;
15026
15036
15027
- const float alpha_f32 = 1.0f;
15028
- const float beta_f32 = 0.0f;
15029
-
15030
15037
const void * alpha = &alpha_f32;
15031
15038
const void * beta = &beta_f32;
15039
+ if (no_mixed_dtypes) {
15040
+ alpha = &alpha_f16;
15041
+ beta = &beta_f16;
15042
+ }
15032
15043
15033
15044
// TODO: Re-enable (dst->op_params[0] != GGML_PREC_DEFAULT) pathway
15034
- // oneMKL open source supports half, half, float, float: datatypes
15045
+ // when oneMKL open source supports half, half, float, float: datatypes
15035
15046
15036
15047
dst_t = (char *) dst_ddf;
15048
+ if (no_mixed_dtypes) {
15049
+ dst_t = (char *) dst_f16.alloc(ne_dst);
15050
+
15051
+ nbd2 /= sizeof(float) / sizeof(sycl::half);
15052
+ nbd3 /= sizeof(float) / sizeof(sycl::half);
15053
+ }
15037
15054
15038
15055
GGML_ASSERT(ne12 % ne02 == 0);
15039
15056
GGML_ASSERT(ne13 % ne03 == 0);
@@ -15119,6 +15136,10 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
15119
15136
}
15120
15137
#endif
15121
15138
15139
+ if (no_mixed_dtypes) {
15140
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
15141
+ to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
15142
+ }
15122
15143
}
15123
15144
catch (sycl::exception const &exc) {
15124
15145
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
0 commit comments