Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ed09547

Browse files
kshitij12345sxjscience
authored andcommitted
[MXNET-978] Higher Order Gradient Support arcsin, arccos. (#15515)
* support arcsin, arccos for higher order grad * add relevant tests * add small note for computation * update comments * use NodeOpGen * retrigger CI * address comment * rename grad_x -> x_grad * retrigger CI * retrigger CI
1 parent faa2832 commit ed09547

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

src/operator/tensor/elemwise_unary_op_trig.cc

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,31 @@ The storage type of ``arcsin`` output depends upon the input storage type:
188188
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arcsin" });
189189

190190
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arcsin,
191-
unary_bwd<mshadow_op::arcsin_grad>);
191+
unary_bwd<mshadow_op::arcsin_grad>)
192+
.set_attr<nnvm::FGradient>("FGradient",
193+
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
194+
// ograds[0]: head_grad_grads (dL/dxgrad)
195+
// inputs[0]: dL/dy
196+
// inputs[1]: x (ElemwiseGradUseIn)
197+
// f(x) = arcsin(x)
198+
// n: f'(x) = 1/(1-x^2)^1/2
199+
// f''(x) = f'(x) * x/(1-x^2)
200+
// Note: x/(1-x^2) = x * f'(x)^2
201+
auto dydx = n->inputs[0];
202+
auto x = n->inputs[1];
203+
auto dydx_mul_grad_x = nnvm::NodeEntry{n};
204+
auto op = mxnet::util::NodeOpGen{n};
205+
206+
auto x_grad = op.div(dydx_mul_grad_x, dydx);
207+
auto x_grad_square = op.square(x_grad);
208+
auto x_grad_square_mul_x = op.mul(x_grad_square, x);
209+
auto x_grad_grad = op.mul(dydx_mul_grad_x, x_grad_square_mul_x);
210+
211+
std::vector<nnvm::NodeEntry> ret;
212+
ret.emplace_back(op.mul(ograds[0], x_grad));
213+
ret.emplace_back(op.mul(ograds[0], x_grad_grad));
214+
return ret;
215+
});
192216

193217
// arccos
194218
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(arccos, cpu, mshadow_op::arccos)
@@ -207,7 +231,32 @@ The storage type of ``arccos`` output is always dense
207231
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arccos" });
208232

209233
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arccos,
210-
unary_bwd<mshadow_op::arccos_grad>);
234+
unary_bwd<mshadow_op::arccos_grad>)
235+
.set_attr<nnvm::FGradient>("FGradient",
236+
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
237+
// ograds[0]: head_grad_grads (dL/dxgrad)
238+
// inputs[0]: dL/dy
239+
// inputs[1]: x (ElemwiseGradUseIn)
240+
// f(x) = arccos(x)
241+
// n: f'(x) = -1/(1-x^2)^1/2
242+
// f''(x) = f'(x) * x/(1-x^2)
243+
// Note: x/(1-x^2) = x * f'(x)^2
244+
auto dydx = n->inputs[0];
245+
auto x = n->inputs[1];
246+
auto dydx_mul_grad_x = nnvm::NodeEntry{n};
247+
auto op = mxnet::util::NodeOpGen{n};
248+
249+
auto x_grad = op.div(dydx_mul_grad_x, dydx);
250+
auto x_grad_square = op.square(x_grad);
251+
auto x_grad_square_mul_x = op.mul(x_grad_square, x);
252+
auto x_grad_grad = op.mul(dydx_mul_grad_x, x_grad_square_mul_x);
253+
254+
std::vector<nnvm::NodeEntry> ret;
255+
ret.emplace_back(op.mul(ograds[0], x_grad));
256+
ret.emplace_back(op.mul(ograds[0], x_grad_grad));
257+
return ret;
258+
});
259+
211260

212261
// arctan
213262
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arctan, cpu, mshadow_op::arctan)

tests/python/unittest/test_higher_order_grad.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,44 @@ def grad_grad_op(x):
133133
array, tanh, grad_grad_op, rtol=1e-6, atol=1e-6)
134134

135135

136+
@with_seed()
137+
def test_arcsin():
138+
def arcsin(x):
139+
return nd.arcsin(x)
140+
141+
def grad_grad_op(x):
142+
return x / nd.sqrt((1-x**2)**3)
143+
144+
for dim in range(1, 5):
145+
shape = rand_shape_nd(dim)
146+
array = random_arrays(shape)
147+
# Hack: Decrease std_dev to make
148+
# sure all elements
149+
# are in range -1 to 1
150+
# i.e. Domain of arcsin
151+
array *= 0.2
152+
check_second_order_unary(array, arcsin, grad_grad_op)
153+
154+
155+
@with_seed()
156+
def test_arccos():
157+
def arccos(x):
158+
return nd.arccos(x)
159+
160+
def grad_grad_op(x):
161+
return -x / nd.sqrt((1-x**2)**3)
162+
163+
for dim in range(1, 5):
164+
shape = rand_shape_nd(dim)
165+
array = random_arrays(shape)
166+
# Hack: Decrease std_dev to make
167+
# sure all elements
168+
# are in range -1 to 1
169+
# i.e. Domain of arccos
170+
array *= 0.2
171+
check_second_order_unary(array, arccos, grad_grad_op)
172+
173+
136174
@with_seed()
137175
def test_arctan():
138176
def arctan(x):

0 commit comments

Comments
 (0)