
Commit 4b372cc

Add sum boolean gpu compute

1 parent e64fdc7

File tree: 5 files changed, +52 −17 lines changed

contrib/tvmop/core/fromnumeric.py
Lines changed: 25 additions & 6 deletions

@@ -34,11 +34,30 @@ def _compute_sum(itype, otype, ndim, reduce1st_dim, req):
        otype=['float32', 'float64', 'int32', 'int64'],
        ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
        attrs=["reduce1st_dim", "req"])
-def _sum(itype, otype, ndim, reduce1st_dim, req):
-    s, a, output_placeholder, final_output, expr_list = _compute_sum(
+def _sum_cpu(itype, otype, ndim, reduce1st_dim, req):
+    s, a, output_placeholder, final_output, tensor_list = _compute_sum(
         itype, otype, ndim, reduce1st_dim, req)
-    for expr in expr_list:
-        axes = [axis for axis in expr.op.axis]
-        fused = s[expr].fuse(*axes)
-        s[expr].parallel(fused)
+    for t in tensor_list:
+        axes = [axis for axis in t.op.axis]
+        fused = s[t].fuse(*axes)
+        s[t].parallel(fused)
+    return s, [a, output_placeholder, final_output]
+
+
+@defop(name='sum_gpu', target='gpu', itype=['bool'],
+       otype=['float32', 'float64', 'int32', 'int64'],
+       ndim=[5], req=['kWriteTo', 'kAddTo'], reduce1st_dim=[0, 1],
+       attrs=["reduce1st_dim", "req"])
+def _sum_gpu(itype, otype, ndim, reduce1st_dim, req):
+    s, a, output_placeholder, final_output, tensor_list = _compute_sum(
+        itype, otype, ndim, reduce1st_dim, req)
+    num_threads = 64
+    for t in tensor_list:
+        block_x = tvm.thread_axis("blockIdx.x")
+        thread_x = tvm.thread_axis("threadIdx.x")
+        axes = [axis for axis in t.op.axis]
+        fused = s[t].fuse(*axes)
+        bx, tx = s[t].split(fused, factor=num_threads)
+        s[t].bind(bx, block_x)
+        s[t].bind(tx, thread_x)
     return s, [a, output_placeholder, final_output]
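
The GPU schedule added above follows the standard TVM pattern: fuse all output axes into one, split the fused axis by a fixed thread count, and bind the outer and inner pieces to blockIdx.x and threadIdx.x. Below is a minimal self-contained sketch of that pattern, assuming the pre-0.7 flat tvm.* API that contrib/tvmop targets; the 3-D reduction is illustrative, not the operator's actual _compute_sum.

    import tvm  # pre-0.7 TVM, as used by contrib/tvmop

    # Illustrative compute: sum a (64, 64, 1024) tensor over its last axis.
    n0, n1, m = 64, 64, 1024
    A = tvm.placeholder((n0, n1, m), name='A', dtype='float32')
    k = tvm.reduce_axis((0, m), name='k')
    B = tvm.compute((n0, n1), lambda i, j: tvm.sum(A[i, j, k], axis=k), name='B')

    s = tvm.create_schedule(B.op)
    num_threads = 64
    fused = s[B].fuse(*[axis for axis in B.op.axis])    # flatten the output axes
    bx, tx = s[B].split(fused, factor=num_threads)      # 64 threads per block
    s[B].bind(bx, tvm.thread_axis("blockIdx.x"))        # outer piece -> CUDA blocks
    s[B].bind(tx, tvm.thread_axis("threadIdx.x"))       # inner piece -> CUDA threads

    func = tvm.build(s, [A, B], target='cuda')          # requires a CUDA toolchain

Each thread then performs its own serial reduction over k, which is why no cross-thread reduction primitive is needed for this layout.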

contrib/tvmop/opdef.py
Lines changed: 0 additions & 2 deletions

@@ -80,8 +80,6 @@ def invoke_all(self):
                     + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \
                     + ''.join(["%s_%d" % (arg.dtype, len(arg.shape))
                                for arg in args if hasattr(arg, 'shape')])
-                if 'sum' in name:
-                    print(name)
                 yield sch, args, name

     def get_binds(self, args):
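
The concatenation in this hunk is what generates the exported kernel names; the deleted lines were a leftover debug print of those names for sum ops. As a hypothetical illustration of the naming scheme (the values are made up, but the fields mirror the ones used above):

    op_name = 'sum_gpu'
    attrs = {'reduce1st_dim': 0, 'req': 'kWriteTo'}
    arg_sigs = [('bool', 5), ('float32', 5), ('float32', 5)]  # input + two outputs

    name = op_name \
        + ''.join("{}_{}".format(k, v) for k, v in attrs.items()) \
        + ''.join("%s_%d" % (dtype, ndim) for dtype, ndim in arg_sigs)
    print(name)  # sum_gpureduce1st_dim_0req_kWriteTobool_5float32_5float32_5

This is the same string that the C++ side reassembles in TVMOpReduce (below) when dispatching into the compiled module.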

src/operator/numpy/np_broadcast_reduce_op.h
Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@

 #include <algorithm>
 #include <vector>
+#include <string>
 #include "../tensor/broadcast_reduce_op.h"

 namespace mxnet {

src/operator/numpy/np_broadcast_reduce_op_value.cc
Lines changed: 3 additions & 2 deletions

@@ -43,7 +43,9 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs,

   if (param.dtype.has_value()) {
     if (in_attrs->at(0) == mshadow::kBool) {
-      CHECK(param.dtype.value() == mshadow::kInt64 || param.dtype.value() == mshadow::kFloat32
+      CHECK(param.dtype.value() == mshadow::kInt32
+            || param.dtype.value() == mshadow::kInt64
+            || param.dtype.value() == mshadow::kFloat32
             || param.dtype.value() == mshadow::kFloat64) << "Only support the following output "
                                                             "dtypes when input dtype is bool: "
                                                             "int32, int64, float32, float64.";
@@ -110,7 +112,6 @@ void TVMOpReduce(const OpContext& ctx,
             << (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kCPU ? "cpu" : "gpu")
             << "reduce1st_dim_" << reduce1st_dim
             << "req_" << (req == kWriteTo ? "kWriteTo" : "kAddTo");
-  LOG(INFO) << "sum func name: " << func_name.str();
   tvm::runtime::TVMOpModule::Get()->Call(func_name.str(), ctx, {input_tvm, output_tvm, output_tvm});
 #else
   LOG(FATAL) << "Please add USE_TVM_OP=1 to enable kernels generated by TVM."

tests/python/unittest/test_numpy_op.py
Lines changed: 23 additions & 7 deletions

@@ -232,27 +232,43 @@ def is_int(dtype):
     in_data_dim = random.choice([2, 3, 4])
     shape = rand_shape_nd(in_data_dim, dim=3)
     acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
-                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+                'int8': 'int32', 'int32': 'int64', 'int64': 'int64', 'bool': 'int64'}
     for hybridize in [False, True]:
         for keepdims in [True, False]:
             for axis in ([i for i in range(in_data_dim)] + [(), None]):
-                for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
+                for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool']:
                     for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
-                        if is_int(dtype) and not is_int(itype):
+                        print("==========================")
+                        print(shape)
+                        print(itype)
+                        print(axis)
+                        print(dtype)
+                        print(keepdims)
+                        print(hybridize)
+                        if (is_int(dtype) and not is_int(itype))\
+                                or (itype == 'bool' and dtype not in ('float32', 'float64', 'int32', 'int64')):
                             continue
                         # test gluon
                         test_sum = TestSum(axis=axis, dtype=dtype, keepdims=keepdims)
                         if hybridize:
                             test_sum.hybridize()
                         if is_int(itype):
                             x = _np.random.randint(-128, 128, shape, dtype=itype)
-                            x = mx.nd.array(x)
+                            x = np.array(x)
+                        elif itype == 'bool':
+                            x = _np.random.randint(0, 2, shape) < 1
+                            x = np.array(x, dtype='bool')
                         else:
-                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
-                            x = x.as_np_ndarray()
-                        x.attach_grad()
+                            x = np.random.uniform(-1.0, 1.0, size=shape, dtype=itype)
                         expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
                         expected_ret = expected_ret.astype(dtype)
+                        if itype == 'bool':  # special handling of boolean ndarray
+                            y = test_sum(x)
+                            assert y.dtype == expected_ret.dtype
+                            assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-4, atol=1e-5, use_broadcast=False)
+                            continue
+
+                        x.attach_grad()
                         with mx.autograd.record():
                             y = test_sum(x)
                         assert y.shape == expected_ret.shape