This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 1e76151

work around operator dispatch for half_t

Parent: e7ec695

1 file changed: +16 −8 lines changed

src/operator/leaky_relu-inl.h

Lines changed: 16 additions & 8 deletions
@@ -95,7 +95,8 @@ class LeakyReLUOp : public Operator {
     }
     switch (param_.act_type) {
       case leakyrelu::kLeakyReLU: {
-        Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, param_.slope));
+        ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
+        Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, slope));
         break;
       }
       case leakyrelu::kPReLU: {
@@ -106,18 +107,23 @@ class LeakyReLUOp : public Operator {
       }
       case leakyrelu::kRReLU: {
         if (ctx.is_train) {
-          Random<xpu, DType>* prnd = ctx.requested[leakyrelu::kRandom].get_random<xpu, DType>(s);
-          mask = prnd->uniform(mask.shape_);
-          mask = mask * (param_.upper_bound - param_.lower_bound) + param_.lower_bound;
+          // TODO: Random doesn't work with Float16; this also leads to
+          // reduced entropy for Float64.
+          Random<xpu>* prnd = ctx.requested[leakyrelu::kRandom].get_random<xpu, real_t>(s);
+          mask = tcast<DType>(prnd->uniform(mask.shape_));
+          mask = mask * ScalarExp<DType>(param_.upper_bound - param_.lower_bound)
+                 + ScalarExp<DType>(param_.lower_bound);
           Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, mask));
         } else {
-          const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f;
+          ScalarExp<DType> slope =
+            ScalarExp<DType>((param_.lower_bound + param_.upper_bound) / 2.0f);
           Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, slope));
         }
         break;
       }
       case leakyrelu::kELU: {
-        Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, param_.slope));
+        ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
+        Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, slope));
         break;
       }
       default:
@@ -171,7 +177,8 @@ class LeakyReLUOp : public Operator {
     }
     switch (param_.act_type) {
       case leakyrelu::kLeakyReLU: {
-        Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, param_.slope) * grad);
+        ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
+        Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, slope) * grad);
         break;
       }
       case leakyrelu::kPReLU: {
@@ -186,7 +193,8 @@ class LeakyReLUOp : public Operator {
         break;
       }
       case leakyrelu::kELU: {
-        Assign(gdata, req[leakyrelu::kData], F<mshadow_op::elu_grad>(output, param_.slope) * grad);
+        ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
+        Assign(gdata, req[leakyrelu::kData], F<mshadow_op::elu_grad>(output, slope) * grad);
         break;
       }
       default:
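
Why wrapping the scalar helps: mshadow's F<op>(lhs, rhs) expression templates resolve against a single element type, and the commit message suggests dispatch fails when a half_t tensor expression is combined with a plain float scalar; constructing ScalarExp<DType>(param_.slope) converts the scalar up front so both operands agree. Below is a standalone toy sketch of the same deduction failure and fix; half_t and ApplyMul here are simplified stand-ins for illustration, not mshadow's real types.

// Toy stand-ins, for illustration only (not mshadow's actual types).
#include <cassert>

struct half_t {
  float v;                             // a real half_t packs 16 bits
  explicit half_t(float f) : v(f) {}   // explicit, so no silent float mixing
};

inline half_t operator*(half_t a, half_t b) { return half_t(a.v * b.v); }

// Mirrors the shape of F<op>(lhs, rhs): one DType deduced from both args.
template <typename DType>
DType ApplyMul(DType a, DType b) { return a * b; }

int main() {
  half_t x(2.0f);
  float slope = 0.25f;
  // ApplyMul(x, slope);                  // error: DType deduced as both
                                          // half_t and float
  half_t y = ApplyMul(x, half_t(slope));  // convert the scalar first, as
                                          // ScalarExp<DType>(param_.slope) does
  assert(y.v == 0.5f);
  return 0;
}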

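The RReLU training branch adds a second workaround: the random engine is requested at real_t (it has no Float16 path, per the TODO), and the sampled mask is cast to DType with tcast. A minimal standalone sketch of that sample-then-cast pattern, using <random> in place of mshadow's Random; the bounds and element count below are made up for illustration. Note the diff casts first and does the affine scaling in DType, while this sketch scales in float before the cast, which differs only in rounding.

#include <cstdio>
#include <random>
#include <vector>

struct half_t {                        // simplified stand-in, as above
  float v;
  explicit half_t(float f) : v(f) {}
};

int main() {
  const float lower = 0.125f, upper = 0.333f;  // illustrative RReLU bounds
  std::mt19937 rng(0);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);

  // Sample in single precision, scale into [lower, upper], then cast
  // element-wise -- the role tcast<DType>(prnd->uniform(...)) plays above.
  std::vector<half_t> mask;
  for (int i = 0; i < 4; ++i) {
    float u = uniform(rng) * (upper - lower) + lower;
    mask.push_back(half_t(u));
  }
  for (const half_t &m : mask) std::printf("%f\n", m.v);
  return 0;
}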