
Commit 4c0464d

support pure boolean elemwise/broadcast binary op
1 parent: 3c404a5 · commit: 4c0464d

File tree: 9 files changed (+241 −52 lines)
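
In effect, the arithmetic tensor ops can now run when both inputs are bool. A standalone sketch of what that arithmetic means under the C++ conversion rules the new bool overloads rely on (plain C++, no MXNet code):

#include <iostream>

// Pure-bool arithmetic, mirroring the new `return (a + b)` / `return (a * b)`
// overloads with bool arguments and a bool return: operands promote to int,
// compute, then narrow back to bool.
int main() {
  bool t = true, f = false;
  std::cout << static_cast<bool>(t + t) << "\n";  // 1: int 2 narrows to true
  std::cout << static_cast<bool>(t * f) << "\n";  // 0: multiply acts like AND
  std::cout << static_cast<bool>(t + f) << "\n";  // 1: add acts like OR
}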

src/operator/mshadow_op.h

Lines changed: 17 additions & 5 deletions
@@ -97,6 +97,18 @@ using std::is_integral;
   } \
 }

+#define MXNET_BINARY_MATH_OP_NC_WITH_BOOL(name, expr) \
+struct name : public mxnet_op::tunable { \
+  template<typename DType, \
+           typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0> \
+  MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+    return (expr); \
+  } \
+  MSHADOW_XINLINE static bool Map(bool a, bool b) { \
+    return (expr); \
+  } \
+}
+
 #define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
 struct name : public mxnet_op::tunable { \
   template<typename DType> \
@@ -192,8 +204,6 @@ MXNET_BINARY_MATH_OP_NC(left, a);

 MXNET_BINARY_MATH_OP_NC(right, b);

-MXNET_BINARY_MATH_OP_NC(mul, a * b);
-
 #ifndef _WIN32
 struct mixed_plus {
   template<typename DType,
@@ -288,11 +298,13 @@ struct mixed_mul {
 };
 #endif

-MXNET_BINARY_MATH_OP_NC(div, a / b);
+MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b);
+
+MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b);

-MXNET_BINARY_MATH_OP_NC(plus, a + b);
+MXNET_BINARY_MATH_OP_NC_WITH_BOOL(plus, a + b);

-MXNET_BINARY_MATH_OP_NC(minus, a - b);
+MXNET_BINARY_MATH_OP_NC_WITH_BOOL(minus, a - b);

 MXNET_UNARY_MATH_OP(negation, -a);
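
For reference, here is the dispatch pattern the new macro encodes, reduced to a self-contained toy (MSHADOW_XINLINE and the mxnet_op::tunable base are stripped; the rest mirrors the macro). The enable_if removes the generic template from overload resolution when DType is bool, so bool arguments fall through to the dedicated bool overload:

#include <iostream>
#include <type_traits>

// Toy re-creation of MXNET_BINARY_MATH_OP_NC_WITH_BOOL with the MXNet-specific
// pieces removed. The enable_if disables the generic template for bool, so
// bool operands resolve to the non-template overload and return bool.
#define BINARY_OP_WITH_BOOL(name, expr)                                   \
  struct name {                                                           \
    template <typename DType,                                             \
              typename std::enable_if<!std::is_same<DType, bool>::value,  \
                                      int>::type = 0>                     \
    static DType Map(DType a, DType b) { return (expr); }                 \
    static bool Map(bool a, bool b) { return (expr); }                    \
  }

BINARY_OP_WITH_BOOL(plus, a + b);
BINARY_OP_WITH_BOOL(mul, a * b);

int main() {
  std::cout << plus::Map(2.5, 4.0) << "\n";    // 6.5 via the generic template
  std::cout << mul::Map(true, false) << "\n";  // 0: bool overload, bool result
}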

src/operator/numpy/np_elemwise_broadcast_op.cc

Lines changed: 13 additions & 16 deletions
@@ -70,7 +70,7 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
   return true;
 }

-#ifdef _WIN32
+#ifndef _WIN32
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \
   NNVM_REGISTER_OP(name) \
   .set_num_inputs(2) \
@@ -85,10 +85,6 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
     [](const NodeAttrs& attrs){ \
       return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
     }) \
-  .set_attr<FResourceRequest>("FResourceRequest", \
-    [](const NodeAttrs& attrs) { \
-      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
-    }) \
   .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
   .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
 #else
@@ -106,6 +102,10 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
     [](const NodeAttrs& attrs){ \
       return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
     }) \
+  .set_attr<FResourceRequest>("FResourceRequest", \
+    [](const NodeAttrs& attrs) { \
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
+    }) \
   .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \
   .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
 #endif
@@ -114,41 +114,38 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute<cpu>",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
-                              op::mshadow_op::mixed_plus>)
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
+                                      op::mshadow_op::mixed_plus>)
 #else
 .set_attr<FCompute>(
   "FCompute<cpu>",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::plus, op::mshadow_op::plus,
-                              op::mshadow_op::plus>)
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});

 MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute<cpu>",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
+  NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
                               op::mshadow_op::mixed_rminus>)
 #else
 .set_attr<FCompute>(
   "FCompute<cpu>",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::minus,
-                              op::mshadow_op::minus>)
+  NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});

 MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute<cpu>",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
-                              op::mshadow_op::mixed_mul>)
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
+                                      op::mshadow_op::mixed_mul>)
 #else
 .set_attr<FCompute>(
   "FCompute<cpu>",
-  MixedBinaryBroadcastCompute<cpu, op::mshadow_op::mul, op::mshadow_op::mul,
-                              op::mshadow_op::mul>)
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
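
Two details stand out in these registrations. The macro guard is inverted to #ifndef _WIN32 (matching the guards used around each registration below), and the FResourceRequest attribute moves across so that, as before, kTempSpace is requested only on Windows, whose fallback path casts one operand into a temporary buffer. Also, _npi_add and _npi_multiply switch to NumpyBinaryBroadcastComputeWithBool while _npi_subtract keeps the plain NumpyBinaryBroadcastCompute, consistent with NumPy rejecting `-` on boolean arrays. A toy mirror of the platform split follows; the names are illustrative, not MXNet APIs:

#include <iostream>

// Non-Windows builds instantiate dedicated mixed-dtype kernels as LOP/ROP;
// Windows builds pass only OP and cast one operand up front instead.
struct plus_op { static int Map(int a, int b) { return a + b; } };

#ifndef _WIN32
template <typename OP, typename LOP, typename ROP>
int dispatch(int a, int b, bool same_dtype) {
  // same dtype: fast path; mixed dtypes: dedicated left/right kernels
  return same_dtype ? OP::Map(a, b) : LOP::Map(a, b);
}
#define REGISTER_COMPUTE(op) dispatch<op, op, op>
#else
template <typename OP>
int dispatch(int a, int b, bool same_dtype) {
  return OP::Map(a, b);  // the real Windows path casts via temp space first
}
#define REGISTER_COMPUTE(op) dispatch<op>
#endif

int main() {
  std::cout << REGISTER_COMPUTE(plus_op)(1, 2, true) << "\n";  // 3
}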

src/operator/numpy/np_elemwise_broadcast_op.cu

Lines changed: 8 additions & 11 deletions
@@ -32,39 +32,36 @@ NNVM_REGISTER_OP(_npi_add)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute<gpu>",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
-                              op::mshadow_op::mixed_plus>);
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus,
+                                      op::mshadow_op::mixed_plus>);
 #else
 .set_attr<FCompute>(
   "FCompute<gpu>",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::plus, op::mshadow_op::plus,
-                              op::mshadow_op::plus>);
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus>);
 #endif

 NNVM_REGISTER_OP(_npi_subtract)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute<gpu>",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
+  NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
                               op::mshadow_op::mixed_rminus>);
 #else
 .set_attr<FCompute>(
   "FCompute<gpu>",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::minus,
-                              op::mshadow_op::minus>);
+  NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
 #endif

 NNVM_REGISTER_OP(_npi_multiply)
 #ifndef _WIN32
 .set_attr<FCompute>(
   "FCompute<gpu>",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
-                              op::mshadow_op::mixed_mul>);
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
+                                      op::mshadow_op::mixed_mul>);
 #else
 .set_attr<FCompute>(
   "FCompute<gpu>",
-  MixedBinaryBroadcastCompute<gpu, op::mshadow_op::mul, op::mshadow_op::mul,
-                              op::mshadow_op::mul>);
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
 #endif

 NNVM_REGISTER_OP(_npi_mod)

src/operator/numpy/np_elemwise_broadcast_op.h

Lines changed: 74 additions & 11 deletions
@@ -34,8 +34,8 @@
 namespace mxnet {
 namespace op {

-inline void PrintErrorMessage(const std::string& name, const int dtype1, const int dtype2) {
-  LOG(FATAL) << "Operator " << name << " does not support combination of "
+inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
+  LOG(FATAL) << "Operator " << op_name << " does not support combination of "
              << common::dtype_string(dtype1) << " with " << common::dtype_string(dtype2)
              << " yet...";
 }
@@ -218,7 +218,11 @@ void MixedAllRealBinaryBroadcastCompute(const std::string& op_name,
 }
 #endif

+#ifndef _WIN32
 template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
 void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
                                  const std::vector<TBlob>& inputs,
@@ -233,13 +237,6 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   const TBlob& rhs = inputs[1];
   const TBlob& out = outputs[0];

-  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;
-
-  if (lhs.type_flag_ == rhs.type_flag_) {
-    BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
-    return;
-  }
-
 #ifndef _WIN32
   mxnet::TShape new_lshape, new_rshape, new_oshape;
   int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
@@ -299,7 +296,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
       temp_tblob = TBlob(temp_tensor);
     });
     CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
-    BinaryBroadcastCompute<xpu, OP>(
+    BinaryBroadcastCompute<xpu, OP, allow_bool>(
       attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
   } else {
@@ -308,7 +305,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
       temp_tblob = TBlob(temp_tensor);
     });
     CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
-    BinaryBroadcastCompute<xpu, OP>(
+    BinaryBroadcastCompute<xpu, OP, allow_bool>(
       attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
   }
 } else {
@@ -317,6 +314,72 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
 #endif
 }

+#ifndef _WIN32
+template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
+void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<TBlob>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+
+  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;
+
+  if (lhs.type_flag_ == rhs.type_flag_) {
+    BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+
+#ifndef _WIN32
+  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
+#else
+  MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+#endif
+}
+
+#ifndef _WIN32
+template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
+void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
+                                         const OpContext& ctx,
+                                         const std::vector<TBlob>& inputs,
+                                         const std::vector<OpReqType>& req,
+                                         const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+
+  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;
+
+  if (lhs.type_flag_ == rhs.type_flag_) {
+    BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+
+#ifndef _WIN32
+  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
+#else
+  MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+#endif
+}
+
 template<typename xpu, typename LOP, typename ROP>
 void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
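
This refactor is mostly a lift: the early-exit and same-dtype checks move out of MixedBinaryBroadcastCompute into the two new entry points, so MixedBinaryBroadcastCompute now handles only genuinely mixed dtypes. NumpyBinaryBroadcastComputeWithBool differs from NumpyBinaryBroadcastCompute in a single call: its same-dtype fast path uses BinaryBroadcastComputeWithBool, which can instantiate kernels for bool. The control flow, distilled to a runnable toy (dtype flags reduced to ints, kernel launches to stand-in functions):

#include <cstddef>
#include <iostream>

// Only the branching mirrors the real code; everything else is a stand-in.
enum OpReqType { kNullOp, kWriteTo };

int broadcast_with_bool(int a, int b) { return a * b; }  // same-dtype fast path
int mixed_broadcast(int a, int b) { return a * b; }      // mixed-dtype fallback

int compute_with_bool(int lhs_dtype, int rhs_dtype, int a, int b,
                      OpReqType req, std::size_t out_size) {
  if (out_size == 0 || req == kNullOp) return 0;  // early exit, as in the diff
  if (lhs_dtype == rhs_dtype)                     // including bool == bool
    return broadcast_with_bool(a, b);
  return mixed_broadcast(a, b);                   // delegates to Mixed...Compute
}

int main() {
  std::cout << compute_with_bool(0, 0, 2, 3, kWriteTo, 6) << "\n";  // 6, fast path
  std::cout << compute_with_bool(0, 2, 2, 3, kWriteTo, 6) << "\n";  // 6, mixed path
}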

src/operator/operator_tune-inl.h

Lines changed: 8 additions & 4 deletions
@@ -116,6 +116,10 @@ class OperatorTune : public OperatorTuneByType<DType> {
     TuneAll();
   }

+  ~OperatorTune() {
+    delete[] data_set_;
+  }
+
   /*!
    * \brief Initialize the OperatorTune object
    * \return Whether the OperatorTune object was successfully initialized
@@ -124,7 +128,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
     if (!initialized_) {
       initialized_ = true;
       // Generate some random data for calling the operator kernels
-      data_set_.reserve(0x100);
+      data_set_ = reinterpret_cast<DType*>(new char[0x100 * sizeof(DType)]);
       std::random_device rd;
       std::mt19937 gen(rd());
       if (!std::is_integral<DType>::value) {
@@ -136,7 +140,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
           --n;
           continue;
         }
-        data_set_.emplace_back(val);
+        data_set_[n] = val;
       }
     } else {
       std::uniform_int_distribution<> dis(-128, 127);
@@ -147,7 +151,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
           --n;
           continue;
         }
-        data_set_.emplace_back(val);
+        data_set_[n] = val;
       }
     }
     // Use this environment variable to generate new tuning statistics
@@ -517,7 +521,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
   /*! \brief Number of passes to obtain an average */
   static constexpr duration_t OUTSIDE_COUNT = (1 << OUTSIDE_COUNT_SHIFT);
   /*! \brief Random data for timing operator calls */
-  static std::vector<DType> data_set_;
+  static DType* data_set_;
   /*! \brief Operators tuned */
   static std::unordered_set<std::string> operator_names_;
   /*! \brief Arbitary object to modify in OMP loop */
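
The switch from std::vector<DType> to a raw DType* buffer looks incidental, but it is what lets the tuner instantiate with DType = bool: std::vector<bool> is a packed specialization whose operator[] returns a proxy object, so generic code that needs element addresses or bool& breaks on it. A standalone illustration of the problem (the motivation is inferred from the change, not stated in the commit):

#include <vector>

// Why a raw buffer instead of std::vector for the tuning data: the
// std::vector<bool> specialization packs bits and hands out proxies,
// so a bool* into its storage cannot be formed the way generic code needs.
int main() {
  std::vector<bool> v(8, true);
  v[0] = false;                // proxy assignment works...
  // bool* q = &v[0];          // ...but this is ill-formed: v[0] is a proxy
  bool* data = new bool[8]();  // raw storage sidesteps the specialization
  bool* p = &data[0];          // fine
  *p = false;
  delete[] data;
  return 0;
}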

src/operator/operator_tune.cc

Lines changed: 5 additions & 5 deletions
@@ -39,7 +39,7 @@ double OperatorTuneBase::tuning_weight_scale_ = 0.0;
  */
 #define IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(__typ$) \
   template<> bool OperatorTune<__typ$>::initialized_ = false; \
-  template<> std::vector<__typ$> OperatorTune<__typ$>::data_set_ = {}; \
+  template<> __typ$* OperatorTune<__typ$>::data_set_ = nullptr; \
   template<> volatile tune::TuningMode OperatorTuneByType<__typ$>::tuning_mode_ = tune::kAuto; \
   template<> volatile int OperatorTune<__typ$>::volatile_int_ = 9; /* arbitrary number */ \
   template<> std::unordered_set<std::string> OperatorTune<__typ$>::operator_names_({}); \
@@ -314,10 +314,10 @@ IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_logical_not);
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::nt);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::clip);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::clip);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::plus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::minus);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::mul);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::div);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::true_divide);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus_sign);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
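
With the storage type changed, the per-dtype statics stamped out by IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE become plain pointer specializations, and the four arithmetic workloads re-register through the _WITH_BOOL variant, presumably so the tuner also times a bool instantiation. A toy mirror of the statics macro (names simplified):

#include <iostream>

// One macro stamps out the per-dtype static definition, now a raw pointer so
// that the bool instantiation is as legal as any other dtype's.
template <typename DType>
struct Tune { static DType* data_set_; };

#define IMPLEMENT_STATICS(T) template <> T* Tune<T>::data_set_ = nullptr

IMPLEMENT_STATICS(float);
IMPLEMENT_STATICS(bool);  // unusable generically if data_set_ were std::vector<bool>

int main() {
  std::cout << (Tune<bool>::data_set_ == nullptr) << "\n";  // 1
}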
