@@ -188,7 +188,31 @@ The storage type of ``arcsin`` output depends upon the input storage type:
188
188
.set_attr<nnvm::FGradient>(" FGradient" , ElemwiseGradUseIn{ " _backward_arcsin" });
189
189
190
190
// Registers _backward_arcsin and attaches an nnvm::FGradient lambda to it, so the
// backward node of arcsin can itself be differentiated (second-order gradient).
// The lambda returns two entries, both products with ograds[0]: one with x_grad
// and one with x_grad_grad (see the derivation in the comments below).
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR (_backward_arcsin,
191
- unary_bwd<mshadow_op::arcsin_grad>);
191
+ unary_bwd<mshadow_op::arcsin_grad>)
192
+ .set_attr<nnvm::FGradient>(" FGradient" ,
193
+ [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
194
+ // ograds[0]: head_grad_grads (dL/dxgrad), i.e. the gradient flowing into this backward node's output
195
+ // inputs[0]: dL/dy (head gradient of the first-order backward pass)
196
+ // inputs[1]: x (ElemwiseGradUseIn), the input of the forward arcsin
197
+ // f(x) = arcsin(x)
198
+ // n: presumably outputs dL/dy * f'(x) (cf. dydx_mul_grad_x below), with f'(x) = 1/(1-x^2)^(1/2)
199
+ // f''(x) = f'(x) * x/(1-x^2) = x/(1-x^2)^(3/2)
200
+ // Note: x/(1-x^2) = x * f'(x)^2, so f''(x) can be built from f'(x) and x alone
201
+ auto dydx = n->inputs [0 ];  // NOTE(review): named dydx, but documented above as dL/dy
202
+ auto x = n->inputs [1 ];  // forward input x
203
+ auto dydx_mul_grad_x = nnvm::NodeEntry{n};  // entry referring to node n itself
204
+ auto op = mxnet::util::NodeOpGen{n};  // helper used below for div/square/mul; presumably emits new graph nodes - confirm
205
+
206
+ auto x_grad = op.div (dydx_mul_grad_x, dydx);  // = f'(x) if n = dL/dy * f'(x); NOTE(review): presumably undefined where dL/dy == 0 - confirm
207
+ auto x_grad_square = op.square (x_grad);  // f'(x)^2
208
+ auto x_grad_square_mul_x = op.mul (x_grad_square, x);  // x * f'(x)^2 = x/(1-x^2)
209
+ auto x_grad_grad = op.mul (dydx_mul_grad_x, x_grad_square_mul_x);  // dL/dy * f'(x) * x/(1-x^2) = dL/dy * f''(x)
210
+
211
+ std::vector<nnvm::NodeEntry> ret;
212
+ ret.emplace_back (op.mul (ograds[0 ], x_grad));  // presumably the gradient w.r.t. inputs[0] (dL/dy) - confirm ordering convention
213
+ ret.emplace_back (op.mul (ograds[0 ], x_grad_grad));  // presumably the gradient w.r.t. inputs[1] (x)
214
+ return ret;
215
+ });
192
216
193
217
// arccos
194
218
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR (arccos, cpu, mshadow_op::arccos)
@@ -207,7 +231,32 @@ The storage type of ``arccos`` output is always dense
207
231
.set_attr<nnvm::FGradient>(" FGradient" , ElemwiseGradUseIn{ " _backward_arccos" });
208
232
209
233
// Registers _backward_arccos and attaches an nnvm::FGradient lambda to it, so the
// backward node of arccos can itself be differentiated (second-order gradient).
// The lambda returns two entries, both products with ograds[0]: one with x_grad
// and one with x_grad_grad (see the derivation in the comments below).
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR (_backward_arccos,
210
- unary_bwd<mshadow_op::arccos_grad>);
234
+ unary_bwd<mshadow_op::arccos_grad>)
235
+ .set_attr<nnvm::FGradient>(" FGradient" ,
236
+ [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
237
+ // ograds[0]: head_grad_grads (dL/dxgrad), i.e. the gradient flowing into this backward node's output
238
+ // inputs[0]: dL/dy (head gradient of the first-order backward pass)
239
+ // inputs[1]: x (ElemwiseGradUseIn), the input of the forward arccos
240
+ // f(x) = arccos(x)
241
+ // n: presumably outputs dL/dy * f'(x) (cf. dydx_mul_grad_x below), with f'(x) = -1/(1-x^2)^(1/2)
242
+ // f''(x) = f'(x) * x/(1-x^2) = -x/(1-x^2)^(3/2)
243
+ // Note: x/(1-x^2) = x * f'(x)^2 (the sign of f'(x) cancels when squared)
244
+ auto dydx = n->inputs [0 ];  // NOTE(review): named dydx, but documented above as dL/dy
245
+ auto x = n->inputs [1 ];  // forward input x
246
+ auto dydx_mul_grad_x = nnvm::NodeEntry{n};  // entry referring to node n itself
247
+ auto op = mxnet::util::NodeOpGen{n};  // helper used below for div/square/mul; presumably emits new graph nodes - confirm
248
+
249
+ auto x_grad = op.div (dydx_mul_grad_x, dydx);  // = f'(x) if n = dL/dy * f'(x); NOTE(review): presumably undefined where dL/dy == 0 - confirm
250
+ auto x_grad_square = op.square (x_grad);  // f'(x)^2
251
+ auto x_grad_square_mul_x = op.mul (x_grad_square, x);  // x * f'(x)^2 = x/(1-x^2)
252
+ auto x_grad_grad = op.mul (dydx_mul_grad_x, x_grad_square_mul_x);  // dL/dy * f'(x) * x/(1-x^2) = dL/dy * f''(x)
253
+
254
+ std::vector<nnvm::NodeEntry> ret;
255
+ ret.emplace_back (op.mul (ograds[0 ], x_grad));  // presumably the gradient w.r.t. inputs[0] (dL/dy) - confirm ordering convention
256
+ ret.emplace_back (op.mul (ograds[0 ], x_grad_grad));  // presumably the gradient w.r.t. inputs[1] (x)
257
+ return ret;
258
+ });
259
+
211
260
212
261
// arctan
213
262
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR (arctan, cpu, mshadow_op::arctan)
0 commit comments