@@ -101,6 +101,7 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
         using arch_t = compute::gpu_arch_t;
         auto *compute_engine
                 = utils::downcast<compute::compute_engine_t *>(engine);
+        const arch_t arch = compute_engine->device_info()->gpu_arch();

         const memory_desc_wrapper src_mdw(src_md());
         const memory_desc_wrapper dst_mdw(dst_md());
@@ -118,6 +119,13 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
                         utils::one_of(dst_dt, f64, f32, f16, bf16, u8, s8),
                 VERBOSE_UNSUPPORTED_DT);

+        VDISPATCH_SOFTMAX(IMPLICATION(utils::one_of(src_dt, f16, bf16),
+                                  arch == arch_t::xe_hpc),
+                VERBOSE_UNSUPPORTED_DT_CFG);
+        VDISPATCH_SOFTMAX(IMPLICATION(utils::one_of(dst_dt, f16, bf16),
+                                  arch == arch_t::xe_hpc),
+                VERBOSE_UNSUPPORTED_DT_CFG);
+
         VDISPATCH_SOFTMAX(IMPLICATION(utils::one_of(f16, src_dt, dst_dt),
                         compute_engine->mayiuse(
                                 compute::device_ext_t::khr_fp16)),
@@ -193,11 +201,10 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
             }
         }

-        const arch_t arch_ = compute_engine->device_info()->gpu_arch();
         const auto nelems = src_mdw.nelems();

         conf.algorithm_number = [&]() { // -> int
-            if (arch_ != arch_t::xe_hpg) {
+            if (arch != arch_t::xe_hpg) {
                 if (rt_conf.softmax_axis_stride == 1
                         && rt_conf.softmax_axis_size >= 128
                         && nelems > (1 << 17)
0 commit comments