diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index f63b2412077b..6178447e4b4e 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -252,6 +252,7 @@ List of Contributors
 * [Jonathan Tan](https://github.com/jonatan1626)
 * [Oliver Kowalke](https://github.com/olk)
 * [Connor Goggins](https://github.com/connorgoggins)
+* [Deng, Wenqi](https://github.com/tobecontinued)
 * [Wei Chu](https://github.com/waytrue17)
 * [Yang Shi](https://github.com/ys2843)
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index ccfebf597f67..2d3f6974842b 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -479,6 +479,16 @@ inline std::vector<nnvm::NodeEntry> MakeNonlossGradNode(
   return CreateNodeEntries(p, &ograds, &inputs);
 }
 
+// Wraps an FGradient implementation: short-circuits to zero gradient nodes
+// when all incoming gradients are zero, otherwise delegates to grad_func.
+struct NonlossGradFGradient {
+  nnvm::FGradient grad_func;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    if (CheckGradAllZero(ograds))
+      return MakeZeroGradNodes(n, ograds);
+    return grad_func(n, ograds);
+  }
+};
+
 /*! \brief Parse keyword arguments as PType arguments and save to parsed */
 template <typename PType>
 inline void ParamParser(nnvm::NodeAttrs* attrs) {
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 71be8f814f3b..57aa8f129c41 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -130,6 +130,58 @@ NNVM_REGISTER_OP(_broadcast_backward)
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   });
 
+// If rhs can be broadcast to the shape of lhs, sum-reduce lhs to the shape of rhs.
+NNVM_REGISTER_OP(_reduce_sum_broadcasted)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+                                                mxnet::ShapeVector *in_attrs,
+                                                mxnet::ShapeVector *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  mxnet::TShape& lhs_shape = (*in_attrs)[0];
+  mxnet::TShape& rhs_shape = (*in_attrs)[1];
+
+  if (!mxnet::ndim_is_known(lhs_shape) || !mxnet::ndim_is_known(rhs_shape)) {
+    return false;
+  }
+
+  // lhs and rhs must be broadcast-compatible: equal ndim, and every rhs axis
+  // either matches the corresponding lhs axis or has size 1
+  CHECK_EQ(lhs_shape.ndim(), rhs_shape.ndim())
+    << "Operand of shape " << lhs_shape << " cannot be reduced to " << rhs_shape;
+
+  for (int i = 0; i < lhs_shape.ndim(); ++i) {
+    if (rhs_shape[i] != -1) {
+      CHECK(lhs_shape[i] == rhs_shape[i] || rhs_shape[i] == 1)
+        << "Array cannot be reduced from " << lhs_shape << " to " << rhs_shape;
+    }
+  }
+  auto oshape = mxnet::TShape(rhs_shape);
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+  return true;
+})
+.set_attr<FCompute>("FCompute", [](const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  ReduceAxesComputeImpl<cpu, mshadow::red::sum>(ctx, inputs, req, outputs, inputs[1].shape_);
+})
+.set_attr<nnvm::FGradient>("FGradient", NonlossGradFGradient{
+  [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    auto head_grad = ograds[0];
+    std::vector<nnvm::NodeEntry> ret;
+    // d(reduce)/d(lhs): broadcast the head gradient back to the shape of lhs
+    ret.emplace_back(MakeNode("broadcast_like", n->attrs.name + "_lhs_backward",
+                              {head_grad, n->inputs[0]}, nullptr, &n));
+    // rhs only supplies the target shape, so its gradient is zero
+    ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
+                              {n->inputs[1]}, nullptr, &n));
+    return ret;
+  }});
+
 NNVM_REGISTER_OP(broadcast_like)
 .add_alias("_npx_broadcast_like")
 .set_num_inputs(2)
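The new `_reduce_sum_broadcasted` operator is, in effect, a sum-reduction of `lhs` over every axis on which `rhs` has size 1, so that the result takes exactly the shape of `rhs`. A minimal NumPy sketch of the intended semantics (illustration only, not part of the patch; `reduce_sum_broadcasted` is a hypothetical reference helper):

    import numpy as np

    def reduce_sum_broadcasted(lhs, rhs):
        # Sum lhs over every axis where rhs (same ndim) has size 1 but lhs
        # does not, keeping dims, so the result has exactly rhs's shape.
        assert lhs.ndim == rhs.ndim
        axes = tuple(i for i in range(rhs.ndim)
                     if rhs.shape[i] == 1 and lhs.shape[i] != 1)
        return lhs.sum(axis=axes, keepdims=True)

    lhs = np.random.rand(2, 8, 3)
    rhs = np.random.rand(2, 1, 3)
    assert reduce_sum_broadcasted(lhs, rhs).shape == rhs.shape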
@@ -148,25 +200,23 @@ NNVM_REGISTER_OP(broadcast_like)
     (*in_attrs)[0] = checked_in_attrs[0];
     return ret;
   })
-.set_attr<nnvm::FGradient>("FGradient",
-  [](const nnvm::ObjectPtr& n,
-     const std::vector<nnvm::NodeEntry>& ograds) {
-    if (CheckGradAllZero(ograds))
-      return MakeZeroGradNodes(n, ograds);
-    std::vector<nnvm::NodeEntry> lhs = MakeNonlossGradNode("_broadcast_backward", n, ograds, {},
-                                                           {{"keepdims", "true"}});
-    lhs.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
+.set_attr<nnvm::FGradient>("FGradient", NonlossGradFGradient{
+  [](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    std::vector<nnvm::NodeEntry> ret;
+    // sum-reduce the head gradient back to the shape of lhs
+    ret.emplace_back(MakeNode("_reduce_sum_broadcasted", n->attrs.name + "_lhs_backward",
+                              {ograds[0], n->inputs[0]}, nullptr, &n));
+    ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
                               {n->inputs[1]}, nullptr, &n));
-    return lhs;
-  })
+    return ret;
+}})
 .add_argument("lhs", "NDArray-or-Symbol", "First input.")
 .add_argument("rhs", "NDArray-or-Symbol", "Second input.")
 .describe(R"code(Broadcasts lhs to have the same shape as rhs.
 
 Broadcasting is a mechanism that allows NDArrays to perform arithmetic operations
 with arrays of different shapes efficiently without creating multiple copies of arrays.
-Also see, `Broadcasting `_ for more explanation.
+Also see, `Broadcasting `_ for more explanation.
 
 Broadcasting is allowed on axes with size 1, such as from `(2,1,3,1)` to
 `(2,8,3,9)`. Elements will be duplicated on the broadcasted axes.
 
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index 35b3c0272db8..2b545a5452c4 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -36,6 +36,14 @@ NNVM_REGISTER_OP(broadcast_to)
 NNVM_REGISTER_OP(broadcast_like)
 .set_attr<FCompute>("FCompute", BroadcastCompute<gpu>);
 
+NNVM_REGISTER_OP(_reduce_sum_broadcasted)
+.set_attr<FCompute>("FCompute", [](const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                                   const std::vector<TBlob>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<TBlob>& outputs) {
+  ReduceAxesComputeImpl<gpu, mshadow::red::sum>(ctx, inputs, req, outputs, inputs[1].shape_);
+  });
+
 NNVM_REGISTER_OP(_broadcast_backward)
 .set_attr<FCompute>("FCompute", ReduceAxesCompute<gpu, mshadow::red::sum>);
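Broadcasting duplicates elements, so its adjoint sums over the duplicated axes; the rewritten `FGradient` of `broadcast_like` encodes exactly that by emitting the new reduce operator instead of the old `_broadcast_backward` node. A quick check against the Python API (a sketch, assuming a build that includes this patch):

    from mxnet import autograd, nd

    lhs = nd.random.uniform(shape=(2, 1, 3))
    rhs = nd.random.uniform(shape=(2, 8, 3))
    lhs.attach_grad()
    with autograd.record():
        out = nd.broadcast_like(lhs, rhs)
    out.backward(nd.ones_like(out))
    # every lhs element was duplicated 8 times, so its gradient is 8
    assert (lhs.grad.asnumpy() == 8).all()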
nnvm::NodeEntry{MakeNode("elemwise_sub", + n->attrs.name + "_rhs_sub_1", {y, one_like}, nullptr, &n)}; + auto x_power_y_sub_1 = nnvm::NodeEntry{MakeNode("broadcast_power", + n->attrs.name + "_lhs_power_rhs_sub_1", {x, y_sub_1}, nullptr, &n)}; + auto dzdx = nnvm::NodeEntry{MakeNode("elemwise_mul", + n->attrs.name + "dpower/dlhs", {y, x_power_y_sub_1}, nullptr, &n)}; + + auto lnx = nnvm::NodeEntry{MakeNode("log", + n->attrs.name + "_ln_lhs", {x}, nullptr, &n)}; + auto x_power_y = nnvm::NodeEntry{n}; + auto dzdy = nnvm::NodeEntry{MakeNode("elemwise_mul", + n->attrs.name + "dpower/drhs", {x_power_y, lnx}, nullptr, &n)}; + + auto broadcasted_lhs_backward = nnvm::NodeEntry{MakeNode("elemwise_mul", + n->attrs.name + "_broadcasted_lhs_backward", {head_grad_z, dzdx}, nullptr, &n)}; + auto broadcasted_rhs_backward = nnvm::NodeEntry{MakeNode("elemwise_mul", + n->attrs.name + "_broadcasted_rhs_backward", {head_grad_z, dzdy}, nullptr, &n)}; + + std::vector ret; + ret.emplace_back(MakeNode("_reduce_sum_brodcasted", n->attrs.name + "_lhs_backward", + {broadcasted_lhs_backward, n->inputs[0]}, nullptr, &n)); + ret.emplace_back(MakeNode("_reduce_sum_brodcasted", n->attrs.name + "rhs_backward", + {broadcasted_rhs_backward, n->inputs[1]}, nullptr, &n)); + return ret; +}}); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_maximum) .add_alias("_npi_maximum") diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index 469081682b2e..1b65cc6ba7c0 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -192,23 +192,19 @@ The storage type of ``elemwise_sub`` output depends on storage types of inputs - otherwise, ``elemwise_sub`` generates output with default storage )code") -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_sub"}); +.set_attr("FGradient", NonlossGradFGradient{ + [](const nnvm::ObjectPtr& n, const std::vector& ograds) { + auto head_grad = ograds[0]; + auto x = n->inputs[0]; + auto y = n->inputs[1]; -NNVM_REGISTER_OP(_backward_sub) -.set_num_inputs(1) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FInplaceOption", - [](const NodeAttrs &attrs) { - return std::vector >{{0, 0}, - {0, 1}}; - }) -.set_attr("FCompute", ElemwiseBinaryOp::BackwardUseNone) -.set_attr("FComputeEx", ElemwiseBinaryOp::BackwardUseNoneEx) -.set_attr("FInferStorageType", - ElemwiseStorageType<1, 2, true, true, true>); + std::vector ret; + ret.emplace_back(MakeNode("identity", n->attrs.name + "_lhs_backward", + {head_grad}, nullptr, &n)); + ret.emplace_back(MakeNode("negative", n->attrs.name + "_rhs_backward", + {head_grad}, nullptr, &n)); + return ret; +}}); MXNET_OPERATOR_REGISTER_BINARY(elemwise_mul) MXNET_ADD_SPARSE_OP_ALIAS(elemwise_mul) @@ -235,25 +231,20 @@ The storage type of ``elemwise_mul`` output depends on storage types of inputs }) .set_attr("THasDeterministicOutput", true) .add_alias("_mul").add_alias("_Mul") -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_mul"}); +.set_attr("FGradient", NonlossGradFGradient{ + [](const nnvm::ObjectPtr& n, const std::vector& ograds) { + auto head_grad = ograds[0]; + auto x = n->inputs[0]; + auto y = n->inputs[1]; + + std::vector ret; + ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_lhs_backward", + {head_grad, y}, nullptr, &n)); + ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_rhs_backward", + {head_grad, x}, nullptr, &n)); + return ret; +}}); -NNVM_REGISTER_OP(_backward_mul) -.set_num_inputs(3) -.set_num_outputs(2) 
-.set_attr("TIsBackward", true) -.set_attr("FInplaceOption", - [](const NodeAttrs &attrs) { - return std::vector >{{0, 1}}; - }) -.set_attr("FInferStorageType", ElemwiseBinaryOp::BackwardUseInStorageType) -.set_attr("FResourceRequest", /* For Sparse CSR */ - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", ElemwiseBinaryOp::BackwardUseIn< - cpu, mshadow_op::right, mshadow_op::left>) -.set_attr("FComputeEx", ElemwiseBinaryOp::BackwardUseInEx< - cpu, mshadow_op::right, mshadow_op::left>); MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(elemwise_div, op::mshadow_op::div) MXNET_ADD_SPARSE_OP_ALIAS(elemwise_div) diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h index ecbfcd5d7d3a..04aeb3ae974f 100644 --- a/tests/cpp/include/test_core_op.h +++ b/tests/cpp/include/test_core_op.h @@ -25,9 +25,12 @@ #include #include #include +#include +#include #include "./test_op.h" #include "profiler/vtune.h" #include "../../../src/imperative/imperative_utils.h" +#include "../../../src/operator/operator_common.h" namespace mxnet { namespace test { @@ -129,6 +132,11 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer nnvm::ObjectPtr MakeNode() const { nnvm::ObjectPtr node = nnvm::Node::Create(); node->attrs = attrs_; + for (uint32_t i = 0; i < node->num_inputs(); ++i) { + auto input = nnvm::Node::Create(); + input->attrs.name = "input_" + std::to_string(i); + node->inputs.emplace_back(nnvm::NodeEntry(input, 0, 0)); + } return node; } @@ -142,20 +150,184 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer nnvm::FGradient grad_fun = gradient.get(op_, nullptr); if (grad_fun) { auto n = MakeNode(); + uint32_t index = -1; + std::map name_to_indexs; + std::vector index_to_nodes; + std::vector out_grads(n->num_outputs()); - std::vector entries = grad_fun(n, out_grads); + for (auto i = 0; i < out_grads.size(); ++i) { + out_grads[i].node = nnvm::Node::Create(); + out_grads[i].node->attrs.name = "out_grad_" + std::to_string(i); + name_to_indexs[out_grads[i].node->attrs.name] = ++index; + out_grads[i].index = index; + index_to_nodes.push_back(out_grads[i]); + index_to_nodes[index].index = index; + bwd_inputs_.emplace_back(CreateRandArray(output_shapes_[i], ctx_.run_ctx, + output_types_[i])); + } + for (auto i = 0; i < n->num_inputs(); ++i) { + name_to_indexs[n->inputs[i].node->attrs.name] = ++index; + n->inputs[i].index = index; + index_to_nodes.push_back(n->inputs[i]); + index_to_nodes[index].index = index; + bwd_outputs_.emplace_back(CreateZeroArray(input_shapes_[i], ctx_.run_ctx, + input_types_[i])); + } + + std::vector< nnvm::NodeEntry> entries = grad_fun(n, out_grads); CHECK_GE(entries.size(), 1U); - res.reserve(entries.size()); - for (const nnvm::NodeEntry& node_entry : entries) { - CHECK_NOTNULL(node_entry.node.get()); - CHECK_NOTNULL(node_entry.node->op()); - CHECK_GT(node_entry.node->op()->name.size(), 0); + + std::map bwd_output_nodes; + std::queue queue; + std::unordered_set visited; + for (auto i = 0; i < entries.size(); ++i) { + nnvm::NodeEntry& node_entry = entries[i]; + queue.push(&node_entry); + visited.insert(node_entry.node->attrs.name); + name_to_indexs[node_entry.node->attrs.name] = ++index; + node_entry.index = index; + index_to_nodes.push_back(node_entry); + index_to_nodes[index].index = index; + bwd_output_nodes[node_entry.index] = i; + } + while (!queue.empty()) { + auto* node_entry = queue.front(); + queue.pop(); + + CHECK_NOTNULL(node_entry); + 
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index ecbfcd5d7d3a..04aeb3ae974f 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -25,9 +25,12 @@
 #include
 #include
 #include
+#include <queue>
+#include <unordered_set>
 #include "./test_op.h"
 #include "profiler/vtune.h"
 #include "../../../src/imperative/imperative_utils.h"
+#include "../../../src/operator/operator_common.h"
 
 namespace mxnet {
 namespace test {
@@ -129,6 +132,11 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer
   nnvm::ObjectPtr MakeNode() const {
     nnvm::ObjectPtr node = nnvm::Node::Create();
     node->attrs = attrs_;
+    // give the node named variable inputs so a gradient graph built from it
+    // can refer back to the forward inputs by name
+    for (uint32_t i = 0; i < node->num_inputs(); ++i) {
+      auto input = nnvm::Node::Create();
+      input->attrs.name = "input_" + std::to_string(i);
+      node->inputs.emplace_back(nnvm::NodeEntry(input, 0, 0));
+    }
     return node;
   }
 
@@ -142,20 +150,184 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer
     nnvm::FGradient grad_fun = gradient.get(op_, nullptr);
     if (grad_fun) {
       auto n = MakeNode();
+      uint32_t index = -1;  // wraps to 0 at the first ++index
+      std::map<std::string, uint32_t> name_to_indexs;
+      std::vector<nnvm::NodeEntry> index_to_nodes;
+
       std::vector<nnvm::NodeEntry> out_grads(n->num_outputs());
-      std::vector<nnvm::NodeEntry> entries = grad_fun(n, out_grads);
+      // assign indexes to the head gradients and the forward inputs first
+      for (auto i = 0; i < out_grads.size(); ++i) {
+        out_grads[i].node = nnvm::Node::Create();
+        out_grads[i].node->attrs.name = "out_grad_" + std::to_string(i);
+        name_to_indexs[out_grads[i].node->attrs.name] = ++index;
+        out_grads[i].index = index;
+        index_to_nodes.push_back(out_grads[i]);
+        index_to_nodes[index].index = index;
+        bwd_inputs_.emplace_back(CreateRandArray(output_shapes_[i], ctx_.run_ctx,
+                                                 output_types_[i]));
+      }
+      for (auto i = 0; i < n->num_inputs(); ++i) {
+        name_to_indexs[n->inputs[i].node->attrs.name] = ++index;
+        n->inputs[i].index = index;
+        index_to_nodes.push_back(n->inputs[i]);
+        index_to_nodes[index].index = index;
+        bwd_outputs_.emplace_back(CreateZeroArray(input_shapes_[i], ctx_.run_ctx,
+                                                  input_types_[i]));
+      }
+
+      std::vector<nnvm::NodeEntry> entries = grad_fun(n, out_grads);
       CHECK_GE(entries.size(), 1U);
-      res.reserve(entries.size());
-      for (const nnvm::NodeEntry& node_entry : entries) {
-        CHECK_NOTNULL(node_entry.node.get());
-        CHECK_NOTNULL(node_entry.node->op());
-        CHECK_GT(node_entry.node->op()->name.size(), 0);
+
+      // bottom-up BFS over the gradient graph: give every reachable node a
+      // unique index, keyed by node name
+      std::map<uint32_t, size_t> bwd_output_nodes;
+      std::queue<nnvm::NodeEntry*> queue;
+      std::unordered_set<std::string> visited;
+      for (auto i = 0; i < entries.size(); ++i) {
+        nnvm::NodeEntry& node_entry = entries[i];
+        queue.push(&node_entry);
+        visited.insert(node_entry.node->attrs.name);
+        name_to_indexs[node_entry.node->attrs.name] = ++index;
+        node_entry.index = index;
+        index_to_nodes.push_back(node_entry);
+        index_to_nodes[index].index = index;
+        bwd_output_nodes[node_entry.index] = i;
+      }
+      while (!queue.empty()) {
+        auto* node_entry = queue.front();
+        queue.pop();
+
+        CHECK_NOTNULL(node_entry);
+        CHECK_NOTNULL(node_entry->node.get());
+        if (!node_entry->node->is_variable()) {
+          CHECK_NOTNULL(node_entry->node->op());
+          CHECK_GT(node_entry->node->op()->name.size(), 0);
+        }
+
+        auto it = name_to_indexs.find(node_entry->node->attrs.name);
+        if (it == name_to_indexs.end()) {
+          name_to_indexs[node_entry->node->attrs.name] = ++index;
+          node_entry->index = index;
+          index_to_nodes.push_back(*node_entry);
+        } else {
+          node_entry->index = it->second;
+        }
+
+        for (nnvm::NodeEntry& input : node_entry->node->inputs) {
+          if (visited.find(input.node->attrs.name) != visited.end()) {
+            continue;
+          } else {
+            auto input_it = name_to_indexs.find(input.node->attrs.name);
+            if (input_it != name_to_indexs.end()) {
+              input.index = input_it->second;
+            }
+          }
+          visited.insert(input.node->attrs.name);
+          queue.push(&input);
+        }
+      }
+
+      // reverse edges: for every node, record which nodes consume its output
+      std::map<uint32_t, std::unordered_set<uint32_t>> node_outputs;
+      std::queue<uint32_t> index_queue;
+      std::unordered_set<uint32_t> visited_indexes;
+      for (auto i = 0; i < entries.size(); ++i) {
+        nnvm::NodeEntry& node_entry = entries[i];
+        visited_indexes.insert(node_entry.index);
+        index_queue.push(node_entry.index);
+      }
+      while (!index_queue.empty()) {
+        auto node_index = index_queue.front();
+        index_queue.pop();
+        const auto& node_entry = index_to_nodes.at(node_index);
+
+        for (nnvm::NodeEntry& input : node_entry.node->inputs) {
+          node_outputs[input.index].insert(node_entry.index);
+          if (visited_indexes.find(input.index) != visited_indexes.end()) {
+            continue;
+          }
+
+          visited_indexes.insert(input.index);
+          index_queue.push(input.index);
+        }
+      }
+
+      visited_indexes.clear();
+      std::map<uint32_t, mxnet::TShape> node_shapes;
+      std::map<uint32_t, int> node_types;
+      for (auto i = 0; i < n->num_inputs(); ++i) {
+        index_queue.push(n->inputs[i].index);
+        visited_indexes.insert(n->inputs[i].index);
+        node_shapes[n->inputs[i].index] = input_shapes_[i];
+        node_types[n->inputs[i].index] = input_types_[i];
+      }
+      for (auto i = 0; i < out_grads.size(); ++i) {
+        index_queue.push(out_grads[i].index);
+        visited_indexes.insert(out_grads[i].index);
+        node_shapes[out_grads[i].index] = output_shapes_[i];
+        node_types[out_grads[i].index] = output_types_[i];
+      }
+      bwd_output_executors_.resize(entries.size());
+      std::map<uint32_t, CoreOpExecutor*> node_executors;
+      // The node order produced by the bottom-up BFS above can't be used to
+      // execute the CoreOpExecutors, so re-walk the graph top-down from the
+      // leaves (forward inputs and head gradients) instead.
+      while (!index_queue.empty()) {
+        auto node_index = index_queue.front();
+        index_queue.pop();
+        const auto& node_entry = index_to_nodes.at(node_index);
+
+        for (auto output_index : node_outputs[node_entry.index]) {
+          if (visited_indexes.find(output_index) != visited_indexes.end()) {
+            continue;
+          }
+          visited_indexes.insert(output_index);
+          index_queue.push(output_index);
+        }
+
+        mxnet::ShapeVector input_shapes;
+        input_shapes.reserve(node_entry.node->num_inputs());
+        for (nnvm::NodeEntry& input : node_entry.node->inputs) {
+          CHECK(node_shapes.find(input.index) != node_shapes.end());
+          input_shapes.push_back(node_shapes[input.index]);
+        }
+
+        if (node_entry.node->is_variable() || node_entry.node->op()->name.empty()) {
+          continue;
+        }
+        if (verbose_) {
+          std::cout << node_entry.node->op()->name << std::endl;
+        }
+
+        mxnet::ShapeVector output_shapes;
+        output_shapes.resize(node_entry.node->num_outputs());
+        static auto& finfer_shape = Op::GetAttr<mxnet::FInferShape>("FInferShape");
+        if (finfer_shape.count(node_entry.node->op())) {
+          mxnet::FInferShape call_infer_shapes = finfer_shape[node_entry.node->op()];
+          call_infer_shapes(node_entry.node->attrs, &input_shapes, &output_shapes);
+          // TODO: handle ops with more than one output
+          CHECK_EQ(output_shapes.size(), 1U);
+          node_shapes[node_entry.index] = output_shapes[0];
+        } else {
+          CHECK(false) << "can't find finfer_shape for op " << node_entry.node->op()->name;
+        }
+
         std::shared_ptr<CoreOpExecutor> pOp = std::make_shared<CoreOpExecutor>(
-            ctx().run_ctx.ctx.dev_type == Context::kGPU, ShapesOf(outputs()));
-        res.push_back({ pOp, node_entry.node->op()->name });
+            ctx().run_ctx.ctx.dev_type == Context::kGPU, input_shapes);
+
+        node_executors[node_entry.index] = pOp.get();
+        bwd_input_map_[pOp.get()];  // ensure an (initially empty) entry exists
+        res.push_back({pOp, node_entry.node->op()->name});
+        auto index_iter = bwd_output_nodes.find(node_entry.index);
+        if (index_iter != bwd_output_nodes.end()) {
+          bwd_output_executors_[index_iter->second] = pOp.get();
+        }
+
+        // record where each executor input comes from: a head gradient, a
+        // forward input, or the output of another backward executor
+        for (const auto& input : node_entry.node->inputs) {
+          ExecutorInputSources executor_input_source;
+          if (input.index < bwd_inputs_.size()) {
+            executor_input_source.type = ExecutorInputSources::GRAD;
+            executor_input_source.index = input.index;
+          } else if (input.index < bwd_inputs_.size() + inputs_.size()) {
+            executor_input_source.type = ExecutorInputSources::INPUT;
+            executor_input_source.index = input.index - bwd_inputs_.size();
+          } else {
+            executor_input_source.type = ExecutorInputSources::EXECUTOR;
+            executor_input_source.index = 0;
+            executor_input_source.executor = node_executors[input.index];
+          }
+          bwd_input_map_[pOp.get()].push_back(executor_input_source);
+        }
+      }
     }
     return res;
@@ -353,7 +525,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer
     std::string op_name, bwd_op_name;
     kwargs_t args = ArgsSansOpName(in_args, &op_name, &bwd_op_name);
-    CHECK(op_name.empty() == false);
+    CHECK(!op_name.empty());
     CHECK(!backward_for_op || bwd_op_name.empty())
       << "Backward op should not be supplied another backward operator";
 
@@ -366,7 +538,6 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer
     CHECK_NOTNULL(op_);
 
     std::map<int, NDArray> index2array;
-    nnvm::ObjectPtr bwd_node_ptr;
     if (backward_for_op) {
       bwd_node_ptr = backward_for_op->CalcBackwardPass(&index2array);
     }
@@ -463,6 +634,7 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer
         }
       }
     }
+    input_types_ = input_types;
 
     // Output arrays
     if (outputs_.empty()) {
@@ -511,7 +683,9 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer
                            : NDArray()));
         outputs_p.emplace_back(&*outputs_.rbegin());
       }
+      output_shapes_ = std::move(output_shapes);
     }
+    output_types_ = std::move(output_types);
 
     for (size_t i = 0; i < static_cast<size_t>(num_inputs); ++i) {
       CHECK_LT(i, static_cast<size_t>(input_shapes.size()));
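`CalcBackwardPass` above discovers the gradient graph bottom-up from the returned entries, then re-walks it top-down from the leaves (forward inputs and head gradients) so that each sub-executor is created, shape-inferred, and later run only after all of its producers. The ordering idea, reduced to a few lines of Python (Kahn's algorithm; illustration only, not part of the patch):

    from collections import deque

    def topo_order(nodes, inputs_of):
        # nodes: iterable of ids; inputs_of(n): ids consumed by n
        indegree = {n: len(inputs_of(n)) for n in nodes}
        consumers = {n: [] for n in nodes}
        for n in nodes:
            for i in inputs_of(n):
                consumers[i].append(n)
        ready = deque(n for n, d in indegree.items() if d == 0)
        order = []
        while ready:
            n = ready.popleft()
            order.append(n)
            for c in consumers[n]:
                indegree[c] -= 1
                if indegree[c] == 0:
                    ready.append(c)
        return order

    # leaves (ids 0 and 1) come first, the gradient node (id 2) last
    assert topo_order([0, 1, 2], lambda n: {2: [0, 1]}.get(n, [])) == [0, 1, 2]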
Please specify"; - for (std::pair, std::string> &bw_item : bwd) { bw_item.first->set_verbose(verbose_); backward_.emplace_back(bw_item.first); @@ -671,8 +844,28 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer if (!backward_.empty()) { // Avoid locked ref count here for (std::shared_ptr &p : backward_) { + if (bwd_input_map_.find(p.get()) != bwd_input_map_.end()) { + p->inputs().clear(); + for (const auto &bwd_input_source : bwd_input_map_[p.get()]) { + if (bwd_input_source.type == ExecutorInputSources::Type::EXECUTOR) { + p->inputs().push_back(bwd_input_source.executor->outputs()[bwd_input_source.index]); + } else if (bwd_input_source.type == ExecutorInputSources::Type::GRAD) { + p->inputs().push_back(bwd_inputs_[bwd_input_source.index]); + } else { + p->inputs().push_back(inputs_[bwd_input_source.index]); + } + } + } p->Execute(); } + if (!bwd_output_executors_.empty()) { + bwd_outputs().clear(); + for (auto& p : bwd_output_executors_) { + for (auto& output : p->outputs()) { + bwd_outputs_.push_back(output); + } + } + } return true; } return false; @@ -733,13 +926,11 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer * \return reference to NDArray vector of backward inputs */ std::vector& bwd_inputs() { - CHECK_EQ(backward_.size(), 1U); - return backward_[0]->inputs(); + return bwd_inputs_.empty() ? backward_[0]->inputs() : bwd_inputs_; } const std::vector& bwd_inputs() const { - CHECK_EQ(backward_.size(), 1U); - return backward_[0]->inputs(); + return bwd_inputs_.empty() ? backward_[0]->inputs() : bwd_inputs_; } /*! @@ -747,13 +938,11 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer * \return reference to NDArray vector of backward outputs */ std::vector& bwd_outputs() { - CHECK_EQ(backward_.size(), 1U); - return backward_[0]->outputs(); + return bwd_outputs_.empty() ? backward_[0]->outputs() : bwd_outputs_; } const std::vector& bwd_outputs() const { - CHECK_EQ(backward_.size(), 1U); - return backward_[0]->outputs(); + return bwd_outputs_.empty() ? backward_[0]->outputs() : bwd_outputs_; } void set_verbose(bool verbose) { @@ -788,7 +977,23 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer /*! * \brief Input data shape */ + mxnet::ShapeVector input_shapes_; + /*! + * \brief Input data type + */ + std::vector input_types_; + + /*! + * \brief Output data shape + */ + mxnet::ShapeVector output_shapes_; + + /*! + * \brief Output data type + */ + std::vector output_types_; + /* * \brief Pointer to the operator object */ @@ -798,9 +1003,10 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer */ nnvm::NodeAttrs attrs_; /*! - * \brief Input and output NDArray vectors + * \brief Input, output, bwd input and bwd output NDArray vectors */ - std::vector inputs_, outputs_; + std::vector inputs_, outputs_, bwd_inputs_, bwd_outputs_; + /*! * \brief Vectors of the TBlob objects associated with the NDArrays in inputs_ and outputs_ */ @@ -834,6 +1040,27 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer * \brief Backward executors (if any) */ std::vector> backward_; + + /*! + * \brief The backwards which generated bwd output + */ + std::vector bwd_output_executors_; + + struct ExecutorInputSources { + enum Type { + INPUT, + GRAD, + EXECUTOR + }; + Type type; + uint32_t index; + // used for EXECUTOR + CoreOpExecutor* executor = nullptr; + }; + /*! 
+  /*!
+   * \brief The sources of the inputs of the backward CoreOpExecutors
+   */
+  std::map<CoreOpExecutor*, std::vector<ExecutorInputSources>> bwd_input_map_;
 };
 
 class CoreOpProp {
diff --git a/tests/cpp/operator/runner/core_op_runner_test.cc b/tests/cpp/operator/runner/core_op_runner_test.cc
index 96458cd1c713..2afb4b9e7818 100644
--- a/tests/cpp/operator/runner/core_op_runner_test.cc
+++ b/tests/cpp/operator/runner/core_op_runner_test.cc
@@ -46,7 +46,7 @@ static const std::vector<std::pair<std::string, std::string>> test_unary_operators
 
 static const std::vector<std::pair<std::string, std::string>> test_binary_operators = {
   { "elemwise_add", "_backward_add" },
-  { "elemwise_mul", "_backward_mul" }
+  { "elemwise_mul", "" }
 };
 
 template
diff --git a/tests/cpp/operator/tune/operator_tune_test.cc b/tests/cpp/operator/tune/operator_tune_test.cc
index 00a062698b17..0f8e699bda5d 100644
--- a/tests/cpp/operator/tune/operator_tune_test.cc
+++ b/tests/cpp/operator/tune/operator_tune_test.cc
@@ -106,7 +106,7 @@ static float EvaluateTune(const bool verbose = true) {
       {"sigmoid", ""},
       {"sqrt", ""},
       {"elemwise_add", "_backward_add"},
-      {"elemwise_mul", "_backward_mul"},
+      {"elemwise_mul", ""},
      {"elemwise_div", "_backward_div"}
     };
   } else {
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index ae3c33a4d9b7..f768bdac1605 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -554,6 +554,7 @@ def check_nth_order_unary(x, op, grad_ops, orders, rtol=None, atol=None):
     computed_grads = []
     head_grads = []
 
+    # Perform compute.
     with autograd.record():
         y = op(x)
 
@@ -576,6 +577,284 @@ def check_nth_order_unary(x, op, grad_ops, orders, rtol=None, atol=None):
     assert_almost_equal(
         expected_grad, computed_grad.asnumpy(), rtol=rtol, atol=atol)
 
+
+@with_seed()
+def test_elemwise_sub():
+    def sub(inputs):
+        return nd.elemwise_sub(inputs[0], inputs[1])
+    def grad_op(inputs):
+        return [nd.ones_like(inputs[0]), nd.negative(nd.ones_like(inputs[1]))]
+    def grad_grad_op(inputs):
+        return [nd.zeros_like(inputs[0]), nd.zeros_like(inputs[1])]
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        x, y = random_arrays(shape, shape)
+        check_nth_order_binary([x, y], sub, [grad_op, grad_grad_op], [1, 2])
+
+
+@with_seed()
+def test_elemwise_mul():
+    def mul(inputs):
+        return nd.elemwise_mul(inputs[0], inputs[1])
+    def grad_op(inputs):
+        return [inputs[1], inputs[0]]
+    def grad_grad_op(inputs):
+        return [nd.zeros_like(inputs[0]), nd.zeros_like(inputs[1])]
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        x, y = random_arrays(shape, shape)
+        check_nth_order_binary([x, y], mul, [grad_op, grad_grad_op], [1, 2])
+
+
+@with_seed()
+def test_power():
+    def power(inputs):
+        return nd.power(inputs[0], inputs[1])
+
+    def grad_op(inputs):
+        x, y = inputs
+        return [y * nd.power(x, y - 1), nd.power(x, y) * nd.log(x)]
+
+    def grad_grad_op(inputs):
+        x, y = inputs
+        return [y * (y - 1) * nd.power(x, y - 2), nd.power(x, y) * (nd.log(x) ** 2)]
+
+    def grad_grad_grad_op(inputs):
+        x, y = inputs
+        return [y * (y - 1) * (y - 2) * nd.power(x, y - 3),
+                nd.power(x, y) * (nd.log(x) ** 3)]
+
+    low = 1.0
+    high = 3.0
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        x = nd.random.uniform(low, high, shape)
+        y = nd.random.uniform(low, high, shape)
+        check_nth_order_binary([x, y], power,
+                               [grad_op, grad_grad_op, grad_grad_grad_op], [1, 2, 3])
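The closed forms in `test_power` follow from repeatedly differentiating x^y with respect to x: each order pulls down one more factor of the decremented exponent. A quick NumPy finite-difference cross-check of the second- and third-order forms (illustration only, not part of the test suite):

    import numpy as np

    f = lambda x, y: x ** y
    x, y = 1.7, 2.3
    eps = 1e-4
    d2 = (f(x + eps, y) - 2 * f(x, y) + f(x - eps, y)) / eps ** 2
    assert np.isclose(d2, y * (y - 1) * x ** (y - 2), rtol=1e-4)
    eps = 1e-3
    d3 = (f(x + 2 * eps, y) - 2 * f(x + eps, y)
          + 2 * f(x - eps, y) - f(x - 2 * eps, y)) / (2 * eps ** 3)
    assert np.isclose(d3, y * (y - 1) * (y - 2) * x ** (y - 3), rtol=1e-3)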
+
+
+# based on gen_broadcast_data in test_operation.py
+def gen_broadcast_shape(idx):
+    # Manually set test cases
+    binary_op_data_shape = nd.array(
+        [[[2, 5, 1, 30, 7], [1, 5, 448, 30, 1]],
+         [[10, 49, 1, 77, 17], [10, 1, 2, 1, 17]],
+         [[13, 2, 65, 2, 1], [13, 1, 65, 1, 225]],
+         [[9, 434, 4, 2, 37], [9, 1, 4, 1, 37]],
+         [[2, 52, 1, 4, 1], [1, 52, 60, 1, 37]],
+         [[1, 23, 7, 122, 50], [2, 1, 7, 1, 50]],
+         [[1, 17, 1, 5, 1], [22, 1, 2, 1, 28]],
+         [[29, 1, 2, 1, 8], [29, 22, 1, 130, 1]],
+         [[2, 36, 1, 427, 3], [1, 36, 11, 427, 1]],
+         [[1, 2, 1, 100, 7], [1, 2, 448, 100, 1]],
+         [[1, 2, 495, 77, 7], [1, 2, 1, 1, 7]],
+         [[1, 43, 65, 2, 1], [1, 43, 65, 1, 225]],
+         [[1, 92, 434, 2, 2], [1, 92, 1, 2, 2]],
+         [[1, 92, 1, 4, 1], [1, 92, 134, 1, 17]],
+         [[1, 53, 2, 122, 143], [1, 1, 2, 1, 143]],
+         [[1, 179, 1, 87, 17], [1, 179, 1, 1, 17]],
+         [[1, 1, 17, 5, 1], [1, 22, 1, 1, 28]],
+         [[1, 2, 1, 1, 8], [1, 2, 52, 430, 1]],
+         [[1, 163, 1, 22, 3], [1, 163, 116, 22, 1]],
+         [[1, 1, 44, 30, 7], [1, 1, 44, 30, 1]],
+         [[1, 1, 1, 1, 28], [1, 127, 1, 5, 28]],
+         [[1, 2, 394, 38, 1], [1, 2, 394, 38, 16]],
+         [[1, 10, 49, 77, 17], [1, 1, 1, 1, 17]],
+         [[1, 431, 6, 2, 225], [1, 1, 6, 2, 225]],
+         [[1, 15, 1, 28, 1], [1, 15, 1, 28, 463]],
+         [[1, 129, 2, 48, 96], [1, 129, 2, 1, 1]],
+         [[1, 1, 403, 17, 2], [1, 44, 403, 17, 2]],
+         [[1, 1, 65, 2, 22], [1, 1, 65, 1, 1]],
+         [[1, 24, 103, 17, 18], [1, 24, 1, 1, 1]],
+         [[1, 1, 1, 1, 2], [1, 24, 194, 50, 1]],
+         [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]]])
+    if idx < binary_op_data_shape.shape[0]:
+        l_shape = binary_op_data_shape[idx][0]
+        r_shape = binary_op_data_shape[idx][1]
+    else:
+        # Generate random shapes with ndim between 1 and 5 and all the
+        # shape dims between 1 and 5
+        ndim = nd.random.randint(1, 6)
+        shape = nd.random.randint(1, 6, size=(ndim,))
+        l_same_dim = nd.random.randint(0, 5)
+        r_same_dim = nd.random.randint(0, 5)
+        l_axis_flags = nd.random.randint(0, 2, size=ndim)
+        r_axis_flags = nd.random.randint(0, 2, size=ndim)
+        if l_same_dim == 4:
+            l_axis_flags = nd.ones(ndim)
+        if r_same_dim == 4:
+            r_axis_flags = nd.ones(ndim)
+        l_shape = shape.copy()
+        r_shape = shape.copy()
+        l_shape[nd.where(l_axis_flags == 0)] = 1
+        r_shape[nd.where(r_axis_flags == 0)] = 1
+    return tuple(l_shape.asnumpy().astype(int)), tuple(r_shape.asnumpy().astype(int))
+
+
+# from test_operation.py
+def reduce_op(shape, x):
+    if shape == x.shape:
+        return x
+    keepdims_shape = list(x.shape)
+    for i in range(len(shape)):
+        if x.shape[i] != shape[i]:
+            keepdims_shape[i] = 1
+            x = nd.sum(x, axis=i).reshape(keepdims_shape)
+    return x
+
+
+@with_seed()
+def test_broadcast_power():
+    def broadcast_power(inputs):
+        return nd.broadcast_power(inputs[0], inputs[1])
+
+    def unreduced_grad_op(inputs):
+        x, y = inputs
+        return [y * nd.broadcast_power(x, y - 1), nd.broadcast_power(x, y) * nd.log(x)]
+
+    def unreduced_grad_grad_op(inputs):
+        x, y = inputs
+        return [y * (y - 1) * nd.broadcast_power(x, y - 2),
+                nd.broadcast_power(x, y) * (nd.log(x) ** 2)]
+
+    def unreduced_grad_grad_grad_op(inputs):
+        x, y = inputs
+        return [y * (y - 1) * (y - 2) * nd.broadcast_power(x, y - 3),
+                nd.broadcast_power(x, y) * (nd.log(x) ** 3)]
+
+    low = 1.0
+    high = 3.0
+    for dim in range(1, 5):
+        x_shape, y_shape = gen_broadcast_shape(dim)
+        x = nd.random.uniform(low, high, x_shape)
+        y = nd.random.uniform(low, high, y_shape)
+
+        check_nth_order_binary([x, y], broadcast_power,
+                               [unreduced_grad_op, unreduced_grad_grad_op,
+                                unreduced_grad_grad_grad_op], [1, 2, 3], True,
+                               rtol=1e-3, atol=1e-5)
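`reduce_op` collapses an unreduced, broadcast-shaped gradient back to an operand's shape axis by axis; it is the NumPy-side counterpart of the `_reduce_sum_broadcasted` operator added above, and rests on the fact that sum-reduction is the adjoint of broadcasting. A small NumPy illustration:

    import numpy as np

    x_shape = (2, 1, 3)
    g = np.random.rand(2, 4, 3)             # gradient w.r.t. broadcast(x)
    reduced = g.sum(axis=1, keepdims=True)  # what reduce_op does in this case
    assert reduced.shape == x_shape
    # each x element feeds 4 outputs, so its gradient sums 4 terms
    assert np.allclose(reduced[:, 0, :], g.sum(axis=1))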
+
+
+def autograd_grad_ex(heads, variables, head_grads=None, retain_graph=None, create_graph=False,
+                     train_mode=True):
+    """ If some variables are not on the path used to compute heads, set their head
+    gradients to zero instead of raising an exception.
+
+    autograd.grad requires the user to know which variables are involved in computing
+    the heads. That is fine for the first-order gradient, but for higher-order gradients
+    a variable that was used to compute the heads may not appear in a particular
+    higher-order gradient, and the user cannot be expected to know the formula of
+    every order of gradient in advance.
+
+    E.g. we use such code to compute the 2nd-order gradient:
+        with autograd.record():
+            z = op(x, y)
+            head_grad = nd.ones_like(z)
+            dz_dx, _ = autograd.grad(heads=z, variables=[x, y], head_grads=nd.ones_like(z),
+                                     create_graph=True, retain_graph=True)
+            d2z_d2x, _ = autograd.grad(heads=dz_dx, variables=[x, y],
+                                       head_grads=nd.ones_like(dz_dx),
+                                       create_graph=True, retain_graph=True)
+    If z = x * y, then d2z_d2x = 0 and MXNet reports that the input is unreachable
+    from the output, whereas returning zeros is more reasonable in that case.
+    """
+    # xxx: only consider one head currently
+    argument_names = autograd.get_symbol(heads).list_arguments()
+
+    # XXX: a variable may have more than one output, in which case we would need
+    # another way to get the variable's name. For this unittest it is fine.
+    variable_names = [autograd.get_symbol(variable).list_outputs()[0] for variable in variables]
+    involved_variable_indexes = []
+    involved_variables = []
+    for i in range(0, len(variables)):
+        if variable_names[i] in argument_names:
+            involved_variables.append(variables[i])
+            involved_variable_indexes.append(i)
+
+    if involved_variables:
+        partial_grads = autograd.grad(heads, involved_variables, head_grads,
+                                      retain_graph, create_graph, train_mode)
+    else:
+        partial_grads = []
+
+    grads = []
+    partial_grads_index = 0
+    for i in range(0, len(variables)):
+        if i in involved_variable_indexes:
+            grads.append(partial_grads[partial_grads_index])
+            partial_grads_index += 1
+        else:
+            grads.append(nd.zeros_like(variables[i]))
+    return grads
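A minimal sketch of the failure mode this helper works around (uses this module's imports; nothing here is executed by the test suite):

    from mxnet import autograd, nd

    x = nd.array([1., 2.])
    y = nd.array([3., 4.])
    x.attach_grad()
    y.attach_grad()
    with autograd.record():
        z = x * y
        dz_dx = autograd.grad(z, [x], head_grads=nd.ones_like(z),
                              create_graph=True, retain_graph=True)[0]
        # dz_dx equals y and no longer depends on x, so
        #     autograd.grad(dz_dx, [x, y], ...)
        # raises an unreachable-input error for x, while
        #     autograd_grad_ex(dz_dx, [x, y], ...)
        # returns a zero gradient for x and the correct gradient for y.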
+
+
+def check_nth_order_binary(inputs, op, grad_ops, orders, broadcast_op=False, rtol=None, atol=None):
+    """Assert n-th order autograd gradient against expected gradient.
+
+    Multiple orders of gradients can be checked by passing a list of
+    functions computing the particular order gradient and the corresponding list of orders.
+
+    Note
+    ----
+    1. Orders should always be monotonically increasing.
+    2. Elements of grad_ops should correspond to elements of orders,
+       i.e. grad_ops = [grad_op, grad_grad_grad_op] should be passed with
+       orders = [1, 3]
+
+    Parameters
+    ----------
+    inputs : tuple of mxnet.NDArray (x, y)
+        Input Array.
+    op : Callable (x, y) -> z
+        Operation to perform on Input Array.
+    grad_ops : Callable or List of Callable
+        Function (x, y) -> (n_grad_x, n_grad_y) to compute and assert gradient of given order.
+    orders : int or List of int
+        Order/s to assert expected and computed gradients.
+    broadcast_op : bool
+        Whether op broadcasts its inputs; if so, the expected gradients are sum-reduced
+        back to the input shapes before comparison.
+
+    Returns
+    -------
+    None
+
+    """
+    if isinstance(orders, int):
+        orders = [orders]
+        grad_ops = [grad_ops]
+
+    assert all(i < j for i, j in zip(orders[0:-1], orders[1:])), \
+        "orders should be monotonically increasing"
+    assert len(set(orders)) == len(orders), \
+        "orders should have unique elements"
+    highest_order = max(orders)
+
+    inputs = [nd.array(input) for input in inputs]
+    for input in inputs:
+        input.attach_grad()
+
+    expected_grads = [grad_op(inputs) for grad_op in grad_ops]
+    computed_grads = []
+    head_grads = [[]]
+
+    # Perform compute.
+    with autograd.record():
+        z = op(inputs)
+        heads = [z for _ in inputs]
+        for current_order in range(1, highest_order + 1):
+            grads = []
+            new_head_grads = []
+            new_heads = []
+            for i in range(0, len(heads)):
+                head = heads[i]
+                head_grad = nd.random.normal(shape=head.shape)
+                new_head_grads.append(head_grad)
+                grads.append(autograd_grad_ex(heads=head, variables=inputs,
+                                              head_grads=head_grad, create_graph=True,
+                                              retain_graph=True)[i])
+                # If we reused a single autograd call with head_grads=head_grad in
+                # every iteration, then in the i-th iteration the head would be
+                # derivative_(i-1) * head_grad_(i-1), while the expected computation
+                # uses head = derivative_(i-1); so recompute the next head with unit
+                # head gradients.
+                new_heads.append(autograd_grad_ex(heads=head, variables=inputs,
+                                                  head_grads=nd.ones_like(head),
+                                                  create_graph=True, retain_graph=True)[i])
+            heads = new_heads
+            if current_order in orders:
+                computed_grads.append(grads)
+            head_grads.append(new_head_grads)
+
+    # Validate all the gradients.
+    for order, grad_list, computed_grad_list in \
+            zip(orders, expected_grads, computed_grads):
+        # Compute expected values: scale each expected gradient by the random
+        # head gradient of that order (elementwise).
+        expected_grad_list = [grad for grad in grad_list]
+        for expected_grad, head_grad, computed_grad, input in \
+                zip(expected_grad_list, head_grads[order], computed_grad_list, inputs):
+            if broadcast_op:
+                expected_grad = reduce_op(input.shape, expected_grad * head_grad)
+            else:
+                expected_grad *= head_grad
+            assert_almost_equal(expected_grad.asnumpy(), computed_grad.asnumpy(),
+                                rtol=rtol, atol=atol)
+
 
 def arange_shape_like(y):
     shape = y.shape