Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 0eb213d

Browse files
authored
Fix backward_clip num inputs and type of clip params (#15688)
* Fix backward_clip num inputs and type of clip params * Clip test * Trigger CI * Changes to clip docs * Fix docstring * Trigger CI
1 parent bfd3bb8 commit 0eb213d

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

src/operator/tensor/matrix_op-inl.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,7 +1481,7 @@ struct ClipParam : public dmlc::Parameter<ClipParam> {
14811481
struct clip {
14821482
template<typename DType>
14831483
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* datas,
1484-
DType a_min, DType a_max) {
1484+
const float a_min, const float a_max) {
14851485
DType data = datas[i];
14861486
if (data > a_max) {
14871487
out[i] = a_max;
@@ -1497,7 +1497,7 @@ struct clip {
14971497
struct clip_grad {
14981498
template<typename DType>
14991499
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* grad, const DType* datas,
1500-
DType a_min, DType a_max) {
1500+
const float a_min, const float a_max) {
15011501
DType data = datas[i];
15021502
if (data > a_max) {
15031503
out[i] = 0;
@@ -1524,7 +1524,7 @@ void Clip(const nnvm::NodeAttrs& attrs,
15241524
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
15251525
mxnet_op::Kernel<mxnet::op::clip, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
15261526
inputs[0].dptr<DType>(),
1527-
DType(param.a_min), DType(param.a_max));
1527+
param.a_min, param.a_max);
15281528
});
15291529
}
15301530

@@ -1553,7 +1553,7 @@ void ClipGrad_(const nnvm::NodeAttrs& attrs,
15531553
Stream<xpu> *s = ctx.get_stream<xpu>();
15541554
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
15551555
Kernel<clip_grad, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
1556-
inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), DType(param.a_min), DType(param.a_max));
1556+
inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), param.a_min, param.a_max);
15571557
});
15581558
}
15591559

src/operator/tensor/matrix_op.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,11 @@ MXNET_ADD_SPARSE_OP_ALIAS(clip)
702702
.describe(R"code(Clips (limits) the values in an array.
703703
704704
Given an interval, values outside the interval are clipped to the interval edges.
705-
Clipping ``x`` between `a_min` and `a_x` would be::
705+
Clipping ``x`` between `a_min` and `a_max` would be::
706706
707-
clip(x, a_min, a_max) = max(min(x, a_max), a_min))
707+
.. math::
708+
709+
clip(x, a_min, a_max) = \max(\min(x, a_max), a_min))
708710
709711
Example::
710712
@@ -766,7 +768,7 @@ parameter values:
766768
.add_arguments(ClipParam::__FIELDS__());
767769

768770
NNVM_REGISTER_OP(_backward_clip)
769-
.set_num_inputs(1)
771+
.set_num_inputs(2)
770772
.set_num_outputs(1)
771773
.set_attr_parser(ParamParser<ClipParam>)
772774
.set_attr<nnvm::TIsBackward>("TIsBackward", true)

tests/python/unittest/test_operator.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4174,15 +4174,25 @@ def test_special_functions_using_scipy():
41744174

41754175

41764176
@with_seed()
4177-
@unittest.skip("Flaky test, tracked at https://github.com/apache/incubator-mxnet/issues/12901")
41784177
def test_clip():
41794178
data = mx.symbol.Variable('data')
41804179
shape = (30, 30)
4181-
data_tmp = np.random.uniform(-1, 1, shape)
4180+
data_tmp = np.random.uniform(-1, 1, shape).astype('float32')
41824181
test = mx.sym.clip(data, a_max=0.6, a_min=-0.6)
41834182
check_symbolic_forward(test, [data_tmp], [np.clip(data_tmp, -0.6, 0.6)])
41844183
check_symbolic_backward(test, [data_tmp], [np.ones(shape)],
4185-
[np.where(data_tmp < 0.6, [1], [0]) * np.where(data_tmp > -0.6, [1], [0])])
4184+
[np.where(data_tmp <= 0.6, [1], [0]) * np.where(data_tmp >= -0.6, [1], [0])])
4185+
4186+
# Test monitor on symbol using clip
4187+
4188+
def simple_callback(name, arr):
4189+
pass
4190+
4191+
exe = test.simple_bind(ctx=mx.current_context(), data=shape)
4192+
exe.set_monitor_callback(simple_callback, monitor_all=True)
4193+
exe.forward(is_train=True)
4194+
exe.backward(out_grads=mx.nd.ones(shape))
4195+
mx.nd.waitall()
41864196

41874197

41884198
@with_seed()

0 commit comments

Comments
 (0)