Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 53b325c

Browse files
committed
work around operator dispatch for half_t
1 parent 0d7e446 commit 53b325c

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/operator/leaky_relu-inl.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ class LeakyReLUOp : public Operator {
9595
}
9696
switch (param_.act_type) {
9797
case leakyrelu::kLeakyReLU: {
98-
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, param_.slope));
98+
ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
99+
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, slope));
99100
break;
100101
}
101102
case leakyrelu::kPReLU: {
@@ -111,13 +112,14 @@ class LeakyReLUOp : public Operator {
111112
mask = mask * (param_.upper_bound - param_.lower_bound) + param_.lower_bound;
112113
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, mask));
113114
} else {
114-
const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f;
115+
ScalarExp<DType> slope = ScalarExp<DType>((param_.lower_bound + param_.upper_bound) / 2.0f);
115116
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, slope));
116117
}
117118
break;
118119
}
119120
case leakyrelu::kELU: {
120-
Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, param_.slope));
121+
const DType slope = DType(param_.slope);
122+
Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, slope));
121123
break;
122124
}
123125
default:
@@ -171,7 +173,8 @@ class LeakyReLUOp : public Operator {
171173
}
172174
switch (param_.act_type) {
173175
case leakyrelu::kLeakyReLU: {
174-
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, param_.slope) * grad);
176+
ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
177+
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, slope) * grad);
175178
break;
176179
}
177180
case leakyrelu::kPReLU: {
@@ -186,7 +189,8 @@ class LeakyReLUOp : public Operator {
186189
break;
187190
}
188191
case leakyrelu::kELU: {
189-
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::elu_grad>(output, param_.slope) * grad);
192+
ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
193+
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::elu_grad>(output, slope) * grad);
190194
break;
191195
}
192196
default:

0 commit comments

Comments
 (0)