This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-978] Higher Order Gradient Support for broadcast_to, broadcast_power, power, elemwise_mul and elemwise_sub, and add unit test function check_nth_order_binary #17754

Open
wants to merge 19 commits into master
Changes from 12 commits
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -251,6 +251,7 @@ List of Contributors
* [Piljae Chae](https://github.com/IHateMint)
* [Oliver Kowalke](https://github.com/olk)
* [Connor Goggins](https://github.com/connorgoggins)
* [Deng, Wenqi](https://github.com/tobecontinued)

Label Bot
---------
10 changes: 10 additions & 0 deletions src/operator/operator_common.h
@@ -479,6 +479,16 @@ inline std::vector<nnvm::NodeEntry> MakeNonlossGradNode(
return CreateNodeEntries(p, &ograds, &inputs);
}

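/*! \brief FGradient wrapper for non-loss operators: if every incoming output gradient
 *         is zero, return zero gradients for all inputs instead of invoking the wrapped
 *         gradient function.
 */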
struct NonlossGradFGradient {
nnvm::FGradient grad_func;
std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
if (CheckGradAllZero(ograds))
return MakeZeroGradNodes(n, ograds);
return grad_func(n, ograds);
}
};
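
A minimal Python paraphrase of the wrapper's control flow, with assumed helper names rather than MXNet's C++ API, just to show the zero-gradient short-circuit:

```python
import numpy as np

# Stand-in for NonlossGradFGradient: skip the real backward graph when every
# head gradient is identically zero; otherwise delegate to the wrapped function.
def nonloss_grad(grad_func, num_inputs):
    def wrapped(ograds):
        if all(not np.any(g) for g in ograds):              # plays the role of CheckGradAllZero
            return [np.zeros_like(ograds[0])] * num_inputs  # plays the role of MakeZeroGradNodes
        return grad_func(ograds)
    return wrapped

mul_grad = nonloss_grad(lambda og: [og[0] * 2.0, og[0] * 3.0], num_inputs=2)
print(mul_grad([np.zeros(3)]))  # short-circuits to zero gradients
print(mul_grad([np.ones(3)]))   # delegates to the wrapped gradient function
```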

/*! \brief Parse keyword arguments as PType arguments and save to parsed */
template<typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
70 changes: 61 additions & 9 deletions src/operator/tensor/broadcast_reduce_op_value.cc
@@ -130,6 +130,58 @@ NNVM_REGISTER_OP(_broadcast_backward)
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
});

// If rhs can be broadcast to the shape of lhs, reduce lhs to the shape of rhs by summing over the broadcast axes.
NNVM_REGISTER_OP(_reduce_sum_brodcasted)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& lhs_shape = (*in_attrs)[0];
mxnet::TShape& rhs_shape = (*in_attrs)[1];

if (!mxnet::ndim_is_known(lhs_shape) || !mxnet::ndim_is_known(rhs_shape)) {
return false;
}

// lhs and rhs must be compatible: they need the same number of dimensions
CHECK_EQ(lhs_shape.ndim(), rhs_shape.ndim())
<< "Operand of shape " << lhs_shape << " cannot be reduced to " << rhs_shape;

for (int i = 0; i < lhs_shape.ndim(); ++i) {
if (rhs_shape[i] != -1) {
CHECK(lhs_shape[i] == rhs_shape[i] || rhs_shape[i] == 1)
<< "Array cannot be reduced from " << lhs_shape << " to " << rhs_shape;
}
}
auto oshape = mxnet::TShape(rhs_shape);

SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return true;
})
.set_attr<FCompute>("FCompute<cpu>", [](const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
ReduceAxesComputeImpl<cpu, mshadow::red::sum, false, false,
op::mshadow_op::identity>(ctx, inputs, req, outputs, inputs[1].shape_);
})
.set_attr<nnvm::FGradient>("FGradient", NonlossGradFGradient{
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto head_grad = ograds[0];
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("broadcast_like", n->attrs.name + "_lhs_backward",
{head_grad, n->inputs[0]}, nullptr, &n));
ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
{n->inputs[1]}, nullptr, &n));
return ret;
}});
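
A hedged NumPy sketch of the semantics registered above (the helper name is mine, not the MXNet kernel): the first input is summed, keeping dims, over every axis where the second input's shape is 1, so the output takes the second input's shape — the reduction needed to map a broadcasted gradient back to its pre-broadcast shape.

```python
import numpy as np

def reduce_sum_broadcasted(lhs, rhs_shape):
    # Mirrors the shape check above: same ndim, and each rhs dim is 1 or equal to lhs's.
    assert lhs.ndim == len(rhs_shape)
    axes = tuple(i for i, d in enumerate(rhs_shape) if d == 1 and lhs.shape[i] != 1)
    return lhs.sum(axis=axes, keepdims=True) if axes else lhs.copy()

lhs = np.arange(24.0).reshape(2, 3, 4)
print(reduce_sum_broadcasted(lhs, (2, 1, 4)).shape)  # (2, 1, 4)
```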

NNVM_REGISTER_OP(broadcast_like)
.set_num_inputs(2)
.set_num_outputs(1)
@@ -138,25 +190,25 @@ NNVM_REGISTER_OP(broadcast_like)
return std::vector<std::string>{"lhs", "rhs"};
})
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
.set_attr<nnvm::FGradient>("FGradient", NonlossGradFGradient{
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
if (CheckGradAllZero(ograds))
return MakeZeroGradNodes(n, ograds);
std::vector<nnvm::NodeEntry> lhs = MakeNonlossGradNode("_broadcast_backward", n, ograds, {},
{{"keepdims", "true"}});
lhs.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("_reduce_sum_brodcasted", n->attrs.name + "_lhs_backward",
{ograds[0], n->inputs[0]}, nullptr, &n));
ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
{n->inputs[1]}, nullptr, &n));
return lhs;
})
return ret;
}})
.add_argument("lhs", "NDArray-or-Symbol", "First input.")
.add_argument("rhs", "NDArray-or-Symbol", "Second input.")
.describe(R"code(Broadcasts lhs to have the same shape as rhs.

Broadcasting is a mechanism that allows NDArrays to perform arithmetic operations
with arrays of different shapes efficiently without creating multiple copies of arrays.
Also see, `Broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ for more explanation.

Also see, `Broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ for more explanation.
Broadcasting is allowed on axes with size 1, such as from `(2,1,3,1)` to
`(2,8,3,9)`. Elements will be duplicated on the broadcasted axes.

8 changes: 8 additions & 0 deletions src/operator/tensor/broadcast_reduce_op_value.cu
@@ -36,6 +36,14 @@ NNVM_REGISTER_OP(broadcast_to)
NNVM_REGISTER_OP(broadcast_like)
.set_attr<FCompute>("FCompute<gpu>", BroadcastCompute<gpu>);

NNVM_REGISTER_OP(_reduce_sum_brodcasted)
.set_attr<FCompute>("FCompute<gpu>", [](const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
ReduceAxesComputeImpl<gpu, mshadow::red::sum, false, false,
op::mshadow_op::identity>(ctx, inputs, req, outputs, inputs[1].shape_);
});

NNVM_REGISTER_OP(_broadcast_backward)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);

52 changes: 36 additions & 16 deletions src/operator/tensor/elemwise_binary_broadcast_op_extended.cc
@@ -44,22 +44,42 @@ Example::

)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});

NNVM_REGISTER_OP(_backward_broadcast_power)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::power_grad,
mshadow_op::power_rgrad>);
.set_attr<nnvm::FGradient>("FGradient", NonlossGradFGradient{
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto head_grad_z = ograds[0];
auto x = nnvm::NodeEntry{mxnet::op::MakeNode("broadcast_like",
n->attrs.name + "_broadcast_like", {n->inputs[0], head_grad_z}, nullptr, &n)};
auto y = nnvm::NodeEntry{mxnet::op::MakeNode("broadcast_like",
n->attrs.name + "_broadcast_like", {n->inputs[1], head_grad_z}, nullptr, &n)};

auto one_like = nnvm::NodeEntry{mxnet::op::MakeNode("ones_like",
n->attrs.name + "_ones_like", {y}, nullptr, &n)};
auto y_sub_1 = nnvm::NodeEntry{MakeNode("elemwise_sub",
n->attrs.name + "_rhs_sub_1", {y, one_like}, nullptr, &n)};
auto x_power_y_sub_1 = nnvm::NodeEntry{MakeNode("broadcast_power",
n->attrs.name + "_lhs_power_rhs_sub_1", {x, y_sub_1}, nullptr, &n)};
auto dzdx = nnvm::NodeEntry{MakeNode("elemwise_mul",
n->attrs.name + "dpower/dlhs", {y, x_power_y_sub_1}, nullptr, &n)};

auto lnx = nnvm::NodeEntry{MakeNode("log",
n->attrs.name + "_ln_lhs", {x}, nullptr, &n)};
auto x_power_y = nnvm::NodeEntry{MakeNode("elemwise_mul",
n->attrs.name + "_lhs_power_rhs", {x_power_y_sub_1, x}, nullptr, &n)};
auto dzdy = nnvm::NodeEntry{MakeNode("elemwise_mul",
n->attrs.name + "dpower/drhs", {x_power_y, lnx}, nullptr, &n)};

auto broadcasted_lhs_backward = nnvm::NodeEntry{MakeNode("elemwise_mul",
n->attrs.name + "_broadcasted_lhs_backward", {head_grad_z, dzdx}, nullptr, &n)};
auto broadcasted_rhs_backward = nnvm::NodeEntry{MakeNode("elemwise_mul",
n->attrs.name + "_broadcasted_rhs_backward", {head_grad_z, dzdy}, nullptr, &n)};

std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("_reduce_sum_brodcasted", n->attrs.name + "_lhs_backward",
{broadcasted_lhs_backward, n->inputs[0]}, nullptr, &n));
ret.emplace_back(MakeNode("_reduce_sum_brodcasted", n->attrs.name + "rhs_backward",
{broadcasted_rhs_backward, n->inputs[1]}, nullptr, &n));
return ret;
}});
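
For reference, the node graph above implements the standard partial derivatives of z = x^y, with the head gradient multiplied in elementwise and then reduced back to each input's original shape via _reduce_sum_brodcasted; expressing the backward pass as ordinary differentiable ops is what enables the higher-order gradients this PR targets:

$$\frac{\partial z}{\partial x} = y\,x^{\,y-1}, \qquad \frac{\partial z}{\partial y} = x^{y}\ln x.$$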

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_maximum)
.add_alias("_npi_maximum")
59 changes: 25 additions & 34 deletions src/operator/tensor/elemwise_binary_op_basic.cc
@@ -192,23 +192,19 @@ The storage type of ``elemwise_sub`` output depends on storage types of inputs
- otherwise, ``elemwise_sub`` generates output with default storage

)code")
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_sub"});
.set_attr<nnvm::FGradient>("FGradient", NonlossGradFGradient{
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto head_grad = ograds[0];
auto x = n->inputs[0];
auto y = n->inputs[1];

NNVM_REGISTER_OP(_backward_sub)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs &attrs) {
return std::vector<std::pair<int, int> >{{0, 0},
{0, 1}};
})
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::BackwardUseNone<cpu,
mshadow_op::identity, mshadow_op::negation>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::BackwardUseNoneEx<cpu,
mshadow_op::identity, mshadow_op::negation>)
.set_attr<FInferStorageType>("FInferStorageType",
ElemwiseStorageType<1, 2, true, true, true>);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("identity", n->attrs.name + "_lhs_backward",
{head_grad}, nullptr, &n));
ret.emplace_back(MakeNode("negative", n->attrs.name + "_rhs_backward",
{head_grad}, nullptr, &n));
return ret;
}});
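
The replacement gradient above follows directly from z = x - y, so the head gradient passes through unchanged for lhs (identity) and negated for rhs (negative):

$$\frac{\partial z}{\partial x} = 1, \qquad \frac{\partial z}{\partial y} = -1.$$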

MXNET_OPERATOR_REGISTER_BINARY(elemwise_mul)
MXNET_ADD_SPARSE_OP_ALIAS(elemwise_mul)
@@ -235,25 +231,20 @@ The storage type of ``elemwise_mul`` output depends on storage types of inputs
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.add_alias("_mul").add_alias("_Mul")
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_mul"});
.set_attr<nnvm::FGradient>("FGradient", NonlossGradFGradient{
[](const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto head_grad = ograds[0];
auto x = n->inputs[0];
auto y = n->inputs[1];

std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_lhs_backward",
{head_grad, y}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_rhs_backward",
{head_grad, x}, nullptr, &n));
return ret;
}});
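
A small hedged NumPy check (names are mine, not MXNet code) of the product-rule gradients built above — for z = x * y with head gradient g, the lhs gradient is g * y and the rhs gradient is g * x — verified against a central finite difference of L = sum(g * x * y):

```python
import numpy as np

rng = np.random.default_rng(0)
x, y, g = rng.normal(size=4), rng.normal(size=4), rng.normal(size=4)

# Analytic gradients matching the two elemwise_mul nodes above.
dx, dy = g * y, g * x

# Central finite differences of L = sum(g * x * y) w.r.t. x[0] and y[0].
def L(xv, yv):
    return float(np.sum(g * xv * yv))

eps, e0 = 1e-6, np.eye(4)[0]
num_dx0 = (L(x + eps * e0, y) - L(x - eps * e0, y)) / (2 * eps)
num_dy0 = (L(x, y + eps * e0) - L(x, y - eps * e0)) / (2 * eps)
print(np.isclose(num_dx0, dx[0]), np.isclose(num_dy0, dy[0]))  # True True
```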

NNVM_REGISTER_OP(_backward_mul)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs &attrs) {
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseBinaryOp::BackwardUseInStorageType)
.set_attr<FResourceRequest>("FResourceRequest", /* For Sparse CSR */
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::BackwardUseIn<
cpu, mshadow_op::right, mshadow_op::left>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::BackwardUseInEx<
cpu, mshadow_op::right, mshadow_op::left>);

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(elemwise_div, op::mshadow_op::div)
MXNET_ADD_SPARSE_OP_ALIAS(elemwise_div)
10 changes: 9 additions & 1 deletion tests/cpp/operator/runner/core_op_runner_test.cc
@@ -46,7 +46,15 @@ static const std::vector<std::pair<std::string, std::string>> test_unary_operato

static const std::vector<std::pair<std::string, std::string>> test_binary_operators = {
{ "elemwise_add", "_backward_add" },
{ "elemwise_mul", "_backward_mul" }
// TODO(Deng, Wenqi): In https://github.com/apache/incubator-mxnet/pull/17754,
// the backward of elemwise_mul was changed from a single backward op to a graph of ops,
// but CoreOpExecutor in tests/cpp/include/test_core_op.h cannot handle that case, even
// though it provides CoreOpExecutor::GetBackward for it. For example, it assumes every
// op has exactly one backward node, while elemwise_mul now has two:
// CoreOpExecutor::GetBackwardDependency picks the wrong dependency for the second
// backward node (because of "return igrad_entries[0].node;"), and
// CoreOpExecutor::bwd_inputs() and CoreOpExecutor::bwd_outputs() fail on
// "CHECK_EQ(backward_.size(), 1U)".
// { "elemwise_mul", "_backward_mul" }
Contributor

I don't know much about this, but I don't think removing an existing test would be a good idea.

@apeforest @larroy @sxjscience would be better people to make this call.

Contributor Author

To explain further: because the test-case failures are caused by defects in CoreOpExecutor, I want this PR to accomplish only one task. Maybe we need another PR or a JIRA issue to fix the CoreOpExecutor problem.

Contributor

Yes. That makes sense.
However, I feel the mentioned people are in a much better position to make the call.

};

template<typename TT>
3 changes: 2 additions & 1 deletion tests/cpp/operator/tune/operator_tune_test.cc
@@ -106,7 +106,8 @@ static float EvaluateTune(const bool verbose = true) {
{"sigmoid", ""},
{"sqrt", ""},
{"elemwise_add", "_backward_add"},
{"elemwise_mul", "_backward_mul"},
// TODO(Deng, Wenqi): See comment in tests/cpp/operator/runner/core_op_runner_test.cc:49
// {"elemwise_mul", "_backward_mul"},
{"elemwise_div", "_backward_div"}
};
} else {