
Commit e64fdc7

Add sum op for boolean ndarrays using tvm op module
1 parent bc868b7 commit e64fdc7
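
In user-facing terms, the commit enables np.sum on boolean ndarrays when MXNet is built with USE_TVM_OP=1. A rough usage sketch (illustrative, not taken from the commit; the default output dtype for boolean input follows the NumpySumType change below, which maps it to int64):

import mxnet as mx  # assumes a build with USE_TVM_OP=1 so the TVM-generated kernels are available

a = mx.np.array([[True, False, True],
                 [True, True, False]], dtype=bool)
print(mx.np.sum(a))                           # expected: 4, promoted to int64 by default
print(mx.np.sum(a, axis=0, dtype='float64'))  # expected: [2. 1. 1.]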

7 files changed: +148 -10 lines

contrib/tvmop/core/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -15,4 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from . import umath
+from . import umath, fromnumeric

contrib/tvmop/core/fromnumeric.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import tvm
+from .. import defop
+from ..utils import reduce_axes, assign_by_req
+
+
+def _compute_sum(itype, otype, ndim, reduce1st_dim, req):
+    axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
+    a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=itype)
+    reduce_output = reduce_axes(a, axes, tvm.sum, otype)
+    output_placeholder, final_output = assign_by_req(reduce_output, req)
+    s = tvm.create_schedule(final_output.op)
+    return s, a, output_placeholder, final_output, [reduce_output, final_output]
+
+
+@defop(name='sum_cpu', target='cpu', itype=['bool'],
+       otype=['float32', 'float64', 'int32', 'int64'],
+       ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
+       attrs=["reduce1st_dim", "req"])
+def _sum(itype, otype, ndim, reduce1st_dim, req):
+    s, a, output_placeholder, final_output, expr_list = _compute_sum(
+        itype, otype, ndim, reduce1st_dim, req)
+    for expr in expr_list:
+        axes = [axis for axis in expr.op.axis]
+        fused = s[expr].fuse(*axes)
+        s[expr].parallel(fused)
+    return s, [a, output_placeholder, final_output]
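
As a side note (not part of the diff), the axes mask built in _compute_sum simply alternates kept and reduced dimensions; a value of 1 at position i means reduce_axes sums over dimension i. A plain-Python illustration of the two patterns the ndim=5 kernels are compiled for:

ndim = 5
for reduce1st_dim in (0, 1):
    # axes[i] == 1 marks a dimension that will be reduced
    axes = ([reduce1st_dim, 1 - reduce1st_dim] * ndim)[:ndim]
    print(reduce1st_dim, axes)
# 0 [0, 1, 0, 1, 0]  -> reduce the 2nd and 4th (collapsed) dimensions
# 1 [1, 0, 1, 0, 1]  -> reduce the 1st, 3rd and 5th (collapsed) dimensions

Together with the dimension collapsing done on the C++ side (TVMOpReduce further down), these two patterns cover any reduction over an input with up to five effective dimensions.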

contrib/tvmop/opdef.py

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,8 @@ def invoke_all(self):
                     + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \
                     + ''.join(["%s_%d" % (arg.dtype, len(arg.shape))
                                for arg in args if hasattr(arg, 'shape')])
+                if 'sum' in name:
+                    print(name)
                 yield sch, args, name
 
     def get_binds(self, args):

contrib/tvmop/utils.py

Lines changed: 10 additions & 6 deletions
@@ -21,16 +21,18 @@
 AllTypes = ["float32", "float64", "float16", "uint8", "int8", "int32", "int64"]
 RealTypes = ["float32", "float64", "float16"]
 
-def assign_by_req(a, req):
+
+def assign_by_req(a, req, otype=None):
     b = tvm.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
-    if (req == "kAddTo"):
-        c = tvm.compute(a.shape, lambda *idx: a[idx] + b[idx])
+    if req == "kAddTo":
+        c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) + b[idx]
+                                              if otype else a[idx] + b[idx])
     else:
-        c = tvm.compute(a.shape, lambda *idx: a[idx])
+        c = tvm.compute(a.shape, lambda *idx: a[idx].astype(otype) if otype else a[idx])
     return b, c
 
 
-def reduce_axes(X, axes, reducer):
+def reduce_axes(X, axes, reducer, atype=None):
     def get_index(idx, ridx):
         j = 0
         k = 0
@@ -45,5 +47,7 @@ def get_index(idx, ridx):
     odim = (len(ishape) + 1 - axes[0]) // 2
     oshape = [tvm.var() for _ in range(odim)]
     ridx = [tvm.reduce_axis((0, ishape[i])) for (i, val) in enumerate(axes) if val == 1]
-    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx), name='ret')
+    ret = tvm.compute(oshape, lambda *idx: reducer(X[get_index(idx, ridx)].astype(atype)
+                                                   if atype else X[get_index(idx, ridx)],
+                                                   axis=ridx), name='ret')
     return ret
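
Functionally, reduce_axes followed by assign_by_req corresponds to the eager NumPy sketch below (illustrative only; the *_ref helpers are hypothetical, and the real functions build TVM compute expressions rather than computing arrays):

import numpy as np

def reduce_axes_ref(x, axes, atype=None):
    # sum over every dimension whose mask value is 1, optionally casting the input first
    x = x.astype(atype) if atype else x
    return np.sum(x, axis=tuple(i for i, v in enumerate(axes) if v == 1))

def assign_by_req_ref(a, b, req, otype=None):
    # 'kAddTo' accumulates into the existing output b; 'kWriteTo' overwrites it
    a = a.astype(otype) if otype else a
    return b + a if req == "kAddTo" else a

x = np.array([[True, False, True], [True, True, False]])
out = np.zeros(2, dtype='int64')
out = assign_by_req_ref(reduce_axes_ref(x, [0, 1], atype='int64'), out, "kWriteTo")
print(out)  # [2 2]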

src/operator/numpy/np_broadcast_reduce_op.h

Lines changed: 19 additions & 1 deletion
@@ -67,7 +67,6 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> {
 inline TShape NumpyReduceAxesShapeImpl(const TShape& ishape,
                                        const dmlc::optional<mxnet::Tuple<int>>& axis,
                                        bool keepdims) {
-  // TODO(junwu): improve the logic
   // If input is a scalar, output should be a scalar too
   if (ishape.ndim() == 0) {
     if (axis.has_value()) {
@@ -158,6 +157,10 @@ inline bool NeedSafeAcc(int itype, int otype) {
   return safe_acc_hint && rule;
 }
 
+void TVMOpReduce(const OpContext& ctx, const TBlob& input,
+                 const dmlc::optional<mxnet::Tuple<int>>& axis,
+                 const TBlob& output, const OpReqType req, const std::string& reducer_name);
+
 template<typename xpu, typename reducer, bool safe_acc_hint = false, bool normalize = false,
          typename OP = op::mshadow_op::identity>
 void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
@@ -169,6 +172,19 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
   if (param.initial.has_value()) {
     LOG(FATAL) << "initial is not supported yet";
   }
+  if (req[0] == kNullOp) return;
+  CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place";
+  // If boolean ndarray, use the kernel generated by TVM
+  if (inputs[0].type_flag_ == mshadow::kBool) {
+    std::string reducer_name;
+    if (std::is_same<reducer, mshadow_op::sum>::value) {
+      reducer_name = "sum";
+    } else {
+      LOG(FATAL) << "Only reduce op: `sum` is supported for boolean ndarrays";
+    }
+    TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name);
+    return;
+  }
   if (param.axis.has_value() && param.axis.value().ndim() == 0) {
     UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
   }
@@ -194,6 +210,8 @@ inline void NumpyReduceAxesBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                            const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   using namespace mshadow::expr;
+  CHECK_NE(outputs[0].type_flag_, kBool) << "reduce operators do not support gradient calculation "
+                                            "for input tensors of boolean type.";
   const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
   TShape small;
   if (param.keepdims) {

src/operator/numpy/np_broadcast_reduce_op_value.cc

Lines changed: 70 additions & 0 deletions
@@ -23,6 +23,10 @@
  * \brief CPU Implementation of broadcast and reduce functions based on value.
  */
 
+#if MXNET_USE_TVM_OP
+#include "../tvmop/op_module.h"
+#endif  // MXNET_USE_TVM_OP
+
 #include "np_broadcast_reduce_op.h"
 
 namespace mxnet {
@@ -38,7 +42,15 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
   const NumpyReduceAxesParam &param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
 
   if (param.dtype.has_value()) {
+    if (in_attrs->at(0) == mshadow::kBool) {
+      CHECK(param.dtype.value() == mshadow::kInt64 || param.dtype.value() == mshadow::kFloat32
+            || param.dtype.value() == mshadow::kFloat64) << "Only support the following output "
+                                                            "dtypes when input dtype is bool: "
+                                                            "int32, int64, float32, float64.";
+    }
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
+  } else if (in_attrs->at(0) == mshadow::kBool) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
   } else {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
@@ -47,6 +59,64 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
 }
 
+#if MXNET_USE_TVM_OP
+static constexpr int max_reduce_ndim = 5;
+TBlob PrependAxes(const TBlob& src, const int dst_ndim);
+#endif  // MXNET_USE_TVM_OP
+
+void TVMOpReduce(const OpContext& ctx,
+                 const TBlob& input,
+                 const dmlc::optional<mxnet::Tuple<int>>& axis,
+                 const TBlob& output,
+                 const OpReqType req,
+                 const std::string& reducer_name) {
+#if MXNET_USE_TVM_OP
+  CHECK_GE(input.ndim(), output.ndim());
+  CHECK_LE(input.ndim(), max_reduce_ndim) << "TVMOpReduce only supports ndim <= "
+                                          << max_reduce_ndim;
+
+  const TBlob expanded_output = (input.ndim() == output.ndim() ?
+      output : output.reshape(NumpyReduceAxesShapeImpl(input.shape_, axis, true)));
+  CHECK_EQ(input.ndim(), expanded_output.ndim());
+  int reduce1st_dim = 0;
+  if (input.ndim() > 0 && input.size(0) != expanded_output.size(0)) {
+    reduce1st_dim = 1;
+  }
+  // collapse consecutive dimensions where reduction are performed or not performed
+  std::vector<index_t> ishape_vec;
+  for (int i = 0; i < input.ndim(); ++i) {
+    if (i == 0 || ((input.size(i) != expanded_output.size(i))
+                   != (input.size(i-1) != expanded_output.size(i-1)))) {
+      ishape_vec.push_back(input.size(i));
+    } else {
+      ishape_vec.back() *= input.size(i);
+    }
+  }
+  // append axes after collapsed ishape to reach the max ndim allowed
+  for (int i = ishape_vec.size(); i < max_reduce_ndim; ++i) {
+    ishape_vec.push_back(1);
+  }
+  std::vector<index_t> oshape_vec;
+  for (size_t i = reduce1st_dim; i < ishape_vec.size(); i += 2) {
+    oshape_vec.push_back(ishape_vec[i]);
+  }
+  TShape ishape(ishape_vec.begin(), ishape_vec.end()), oshape(oshape_vec.begin(), oshape_vec.end());
+  TBlob input_tvm = input.reshape(ishape);
+  TBlob output_tvm = output.reshape(oshape);
+  const std::string ctx_name =
+      (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kCPU) ? "cpu" : "gpu";
+  std::ostringstream func_name;
+  func_name << reducer_name << "_"
+            << (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kCPU ? "cpu" : "gpu")
+            << "reduce1st_dim_" << reduce1st_dim
+            << "req_" << (req == kWriteTo ? "kWriteTo" : "kAddTo");
+  LOG(INFO) << "sum func name: " << func_name.str();
+  tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, {input_tvm, output_tvm, output_tvm});
+#else
+  LOG(FATAL) << "Please add USE_TVM_OP=1 to enable kernels generated by TVM.";
+#endif  // MXNET_USE_TVM_OP
+}
+
 NNVM_REGISTER_OP(_np_sum)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
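
The dimension-collapsing logic in TVMOpReduce above is what lets the fixed ndim=5, alternating-axes kernels serve arbitrary reduction patterns. A pure-Python sketch of the shape bookkeeping (collapse_for_tvm is a hypothetical helper, not part of the commit):

MAX_REDUCE_NDIM = 5  # mirrors max_reduce_ndim above

def collapse_for_tvm(ishape, reduced_axes):
    # keepdims-style expanded output shape, as produced by NumpyReduceAxesShapeImpl
    oshape = [1 if i in reduced_axes else d for i, d in enumerate(ishape)]
    reduce1st_dim = 1 if ishape and ishape[0] != oshape[0] else 0
    # merge runs of consecutive dimensions that are all reduced or all kept
    collapsed = []
    for i, d in enumerate(ishape):
        if i == 0 or (ishape[i] != oshape[i]) != (ishape[i - 1] != oshape[i - 1]):
            collapsed.append(d)
        else:
            collapsed[-1] *= d
    # pad with trailing 1s so every call hits the ndim=5 kernel
    collapsed += [1] * (MAX_REDUCE_NDIM - len(collapsed))
    out = collapsed[reduce1st_dim::2]  # the dimensions that survive the reduction
    return reduce1st_dim, collapsed, out

print(collapse_for_tvm((2, 3, 4, 5), {1, 2}))
# -> (0, [2, 12, 5, 1, 1], [2, 5, 1])

The input is then reshaped to the collapsed shape, the output to the surviving dimensions, and the pre-compiled kernel selected by reduce1st_dim and the write request is looked up by name and invoked through TVMOpModule.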

src/operator/numpy/np_elemwise_broadcast_op.cc

Lines changed: 2 additions & 2 deletions
@@ -26,9 +26,9 @@
 #if MXNET_USE_TVM_OP
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/packed_func.h>
+#include "../tvmop/op_module.h"
 #endif  // MXNET_USE_TVM_OP
 
-#include "../tvmop/op_module.h"
 #include "../tensor/elemwise_binary_broadcast_op.h"
 #include "../tensor/elemwise_binary_scalar_op.h"
 
@@ -140,7 +140,7 @@ bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-static TBlob PrependAxes(const TBlob& src, const int dst_ndim) {
+TBlob PrependAxes(const TBlob& src, const int dst_ndim) {
   CHECK_LE(src.shape_.ndim(), dst_ndim);
   const int src_ndim = src.shape_.ndim();
   if (src_ndim == dst_ndim) return src;