Skip to content

Commit 19cc34d

Browse files
committed
xe: softmax: restrict new kernel from PR2525 to xe_hpc only
1 parent 1686471 commit 19cc34d

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/gpu/intel/ocl/reusable_softmax.hpp

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,7 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
101101
using arch_t = compute::gpu_arch_t;
102102
auto *compute_engine
103103
= utils::downcast<compute::compute_engine_t *>(engine);
104+
const arch_t arch = compute_engine->device_info()->gpu_arch();
104105

105106
const memory_desc_wrapper src_mdw(src_md());
106107
const memory_desc_wrapper dst_mdw(dst_md());
@@ -118,6 +119,13 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
118119
utils::one_of(dst_dt, f64, f32, f16, bf16, u8, s8),
119120
VERBOSE_UNSUPPORTED_DT);
120121

122+
VDISPATCH_SOFTMAX(IMPLICATION(utils::one_of(src_dt, f16, bf16),
123+
arch == arch_t::xe_hpc),
124+
VERBOSE_UNSUPPORTED_DT_CFG);
125+
VDISPATCH_SOFTMAX(IMPLICATION(utils::one_of(dst_dt, f16, bf16),
126+
arch == arch_t::xe_hpc),
127+
VERBOSE_UNSUPPORTED_DT_CFG);
128+
121129
VDISPATCH_SOFTMAX(IMPLICATION(utils::one_of(f16, src_dt, dst_dt),
122130
compute_engine->mayiuse(
123131
compute::device_ext_t::khr_fp16)),
@@ -193,11 +201,10 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
193201
}
194202
}
195203

196-
const arch_t arch_ = compute_engine->device_info()->gpu_arch();
197204
const auto nelems = src_mdw.nelems();
198205

199206
conf.algorithm_number = [&]() { // -> int
200-
if (arch_ != arch_t::xe_hpg) {
207+
if (arch != arch_t::xe_hpg) {
201208
if (rt_conf.softmax_axis_stride == 1
202209
&& rt_conf.softmax_axis_size >= 128
203210
&& nelems > (1 << 17)

0 commit comments

Comments
 (0)