Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit b4f903d

Browse files
committed
work around operator dispatch for half_t
1 parent 1808ee5 commit b4f903d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/operator/pooling-inl.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,12 @@ class PoolingOp : public Operator {
9090
param_.global_pool ? 1 : param_.stride[0],
9191
param_.global_pool ? 1 : param_.stride[1]));
9292
} else if (param_.pool_type == pool_enum::kAvgPooling) {
93+
ScalarExp<DType> x = ScalarExp<DType>(1.0f / (param_.global_pool ?
94+
data.shape_[2] * data.shape_[3] :
95+
param_.kernel[0] * param_.kernel[1]));
9396
Assign(out,
9497
req[pool_enum::kOut],
95-
(1.0f / (param_.global_pool ?
96-
data.shape_[2] * data.shape_[3] :
97-
param_.kernel[0] * param_.kernel[1])) * \
98-
pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
98+
x * pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
9999
out_shape,
100100
param_.global_pool ? data.shape_[2] : param_.kernel[0],
101101
param_.global_pool ? data.shape_[3] : param_.kernel[1],
@@ -140,9 +140,9 @@ class PoolingOp : public Operator {
140140
param_.pad[0],
141141
param_.pad[1]));
142142
} else if (param_.pool_type == pool_enum::kAvgPooling) {
143+
ScalarExp<DType> x = ScalarExp<DType>(1.0f / param_.kernel[0] / param_.kernel[1]);
143144
Assign(input_grad, req[pool_enum::kData],
144-
(1.0f / param_.kernel[0] / param_.kernel[1]) *\
145-
crop(unpool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
145+
x * crop(unpool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
146146
pad(output_data, 0, 0),
147147
pad(grad, 0, 0),
148148
param_.global_pool ? in_shape[0] : param_.kernel[0],

0 commit comments

Comments
 (0)