
Commit c33b8b0

Author: Ying
numpy operator nonzero
* add cpu test and handle 0-dim
* add FGradient with MakeZeroGradNodes
* handle 0-dim and 0-shape and add test on gpu
* add doc
* fix bug in review
* do not use thrust::inclusive_scan on cpu
* fix format error
* edit test and remove gpu test (the output is the same as numpy.transpose(numpy.nonzero(x)))
* fix errors from review
1 parent 90091b1

5 files changed (+405, -0)

python/mxnet/_numpy_op_doc.py

Lines changed: 48 additions & 0 deletions
@@ -109,7 +109,55 @@ def _np_cumsum(a, axis=None, dtype=None, out=None):

    >>> np.cumsum(a, axis=1)  # sum over columns for each of the 2 rows
    array([[ 1,  3,  6],
           [ 4,  9, 15]])
    """
    pass


def _npx_nonzero(a):
    """
    nonzero(a)

    Return the indices of the elements that are non-zero.

    Returns a 2-D ndarray with one row per non-zero element; each row
    holds the indices of that element. The values in `a` are always
    tested and returned in row-major, C-style order.

    Parameters
    ----------
    a : array_like
        Input array.

    Returns
    -------
    array : ndarray
        Indices of the elements that are non-zero.

    Notes
    -----
    This function differs from the original numpy.nonzero in the following aspects:

    - Does not support python numeric input.
    - The return value is the same as numpy.transpose(numpy.nonzero(a)).

    Examples
    --------
    >>> x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
    >>> x
    array([[3, 0, 0],
           [0, 4, 0],
           [5, 6, 0]])
    >>> npx.nonzero(x)
    array([[0, 0],
           [1, 1],
           [2, 0],
           [2, 1]], dtype=int64)

    >>> np.transpose(npx.nonzero(x))
    array([[0, 1, 2, 2],
           [0, 1, 0, 1]], dtype=int64)
    """
    pass
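For reference, the documented contract can be checked against plain NumPy. The sketch below is a hypothetical equivalence test, not part of the commit; nonzero_ref is an illustrative name:

import numpy as np

def nonzero_ref(a):
    # Same contract as npx.nonzero: one output row per non-zero element,
    # indices tested and returned in row-major (C-style) order.
    return np.transpose(np.nonzero(np.asarray(a))).astype(np.int64)

x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
print(nonzero_ref(x))    # [[0 0] [1 1] [2 0] [2 1]]
print(nonzero_ref(x).T)  # [[0 1 2 2] [0 1 0 1]]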

src/operator/numpy/np_nonzero_op-inl.h

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * Copyright (c) 2018 by Contributors
 * \file np_nonzero_op-inl.h
 */

#ifndef MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/ndarray.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include "../operator_common.h"
#include "../mxnet_op.h"
#include "../tensor/init_op.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

struct NonzeroForwardKernel {
  template<int ndim>
  MSHADOW_XINLINE static void Map(int i,
                                  int64_t* out,
                                  const int32_t* idx,
                                  const mshadow::Shape<ndim> shape) {
    // idx is the inclusive prefix sum of the "is non-zero" mask, so
    // idx[i] > idx[i - 1] exactly when element i is non-zero, and
    // idx[i - 1] is the output row its coordinates belong to.
    int32_t prev = (i == 0) ? 0 : idx[i - 1];
    int32_t curr = idx[i];
    if (prev != curr) {
      mshadow::Shape<ndim> coord = mxnet_op::unravel<ndim>(i, shape);
      for (int j = 0; j < ndim; j++) {
        out[prev * ndim + j] = coord[j];
      }
    }
  }
};

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
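The kernel's contract is easiest to see in a small Python model. This is a sketch under assumed names (nonzero_kernel_model, with np.unravel_index standing in for mxnet_op::unravel), not code from the commit:

import numpy as np

def nonzero_kernel_model(flat_in, shape):
    flat_in = np.asarray(flat_in)
    # idx[i] = number of non-zero elements among flat_in[0..i] (inclusive scan),
    # which is what the CPU loop and cub::DeviceScan::InclusiveSum both produce.
    idx = np.cumsum(flat_in != 0)
    out = np.empty((int(idx[-1]), len(shape)), dtype=np.int64)
    for i in range(flat_in.size):                   # one Map(i, ...) call per element
        prev = 0 if i == 0 else idx[i - 1]
        if idx[i] != prev:                          # element i is non-zero
            out[prev] = np.unravel_index(i, shape)  # row-major coordinates
    return out

print(nonzero_kernel_model([0, 7, 0, 3], (2, 2)))   # [[0 1] [1 1]]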

src/operator/numpy/np_nonzero_op.cc

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * Copyright (c) 2018 by Contributors
 * \file np_nonzero_op.cc
 */
#include "np_nonzero_op-inl.h"

namespace mxnet {
namespace op {

bool NonzeroType(const nnvm::NodeAttrs& attrs,
                 std::vector<int> *in_attrs,
                 std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1);
  CHECK_EQ(out_attrs->size(), 1);
  // Output must be int64.
  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
  return out_attrs->at(0) != -1;
}

#define MAXDIM 5

bool NonzeroStorageType(const nnvm::NodeAttrs& attrs,
                        const int dev_mask,
                        DispatchMode* dispatch_mode,
                        std::vector<int> *in_attrs,
                        std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1);
  CHECK_EQ(out_attrs->size(), 1);
  for (int &attr : *in_attrs) {
    CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
  }
  for (int &attr : *out_attrs) {
    attr = kDefaultStorage;
  }
  *dispatch_mode = DispatchMode::kFComputeEx;
  return true;
}

void NonzeroForwardCPU(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<NDArray> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<NDArray> &outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  const NDArray &in = inputs[0];
  const NDArray &out = outputs[0];
  CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot be larger than " << MAXDIM;
  // 0-dim input: output shape is (1, 1) if the scalar is non-zero, (0, 1) otherwise.
  if (0 == in.shape().ndim()) {
    MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
      DType* in_dptr = in.data().dptr<DType>();
      if (*in_dptr) {
        mxnet::TShape s(2, 1);
        const_cast<NDArray &>(out).Init(s);
        *(out.data().dptr<int64_t>()) = 0;
      } else {
        mxnet::TShape s(2, 1);
        s[0] = 0;
        const_cast<NDArray &>(out).Init(s);
      }
    });
    return;
  }
  size_t in_size = in.shape().Size();
  // 0-size input: no elements, so the output is an empty (0, ndim) array.
  if (0 == in_size) {
    mxnet::TShape s(2, in.shape().ndim());
    s[0] = 0;
    const_cast<NDArray &>(out).Init(s);
    return;
  }
  std::vector<int32_t> prefix_sum(in_size, 0);
  size_t valid_num = 0;
  // Calculate the inclusive prefix sum of the "is non-zero" mask.
  MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
    DType* in_dptr = in.data().dptr<DType>();
    for (size_t i = 0; i < in_size; i++) {
      prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
      prefix_sum[i] += (in_dptr[i]) ? 1 : 0;
    }
  });
  valid_num = prefix_sum[in_size - 1];
  // Set the output shape forcefully: one row per non-zero element.
  mxnet::TShape s(2, in.shape().ndim());
  s[0] = valid_num;
  const_cast<NDArray &>(out).Init(s);
  // Get the shape from the input and scatter the coordinates.
  MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
    mshadow::Shape<ndim> shape = in.shape().get<ndim>();
    mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
    mxnet_op::Kernel<NonzeroForwardKernel, cpu>::Launch(
      stream, in_size, out.data().dptr<int64_t>(), prefix_sum.data(), shape);
  })
}

NNVM_REGISTER_OP(_npx_nonzero)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"x"};
  })
.set_attr<nnvm::FInferType>("FInferType", NonzeroType)
.set_attr<FComputeEx>("FComputeEx<cpu>", NonzeroForwardCPU)
.set_attr<FInferStorageType>("FInferStorageType", NonzeroStorageType)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("x", "NDArray-or-Symbol", "The input array.");

}  // namespace op
}  // namespace mxnet
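The two early-return branches above pin down the output shapes for degenerate inputs. A short NumPy sketch of those shape rules (expected_nonzero_shape is a hypothetical helper, not commit code):

import numpy as np

def expected_nonzero_shape(a):
    # Mirrors the special cases in NonzeroForwardCPU.
    a = np.asarray(a)
    if a.ndim == 0:                    # 0-dim (scalar) input
        return (1, 1) if a else (0, 1)
    if a.size == 0:                    # 0-size input
        return (0, a.ndim)
    return (int(np.count_nonzero(a)), a.ndim)

print(expected_nonzero_shape(np.array(5.0)))        # (1, 1)
print(expected_nonzero_shape(np.array(0.0)))        # (0, 1)
print(expected_nonzero_shape(np.zeros((2, 0, 3))))  # (0, 3)
print(expected_nonzero_shape(np.eye(3)))            # (3, 2)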

src/operator/numpy/np_nonzero_op.cu

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * Copyright (c) 2018 by Contributors
 * \file np_nonzero_op.cu
 */

#include "np_nonzero_op-inl.h"
#include <cub/cub.cuh>

namespace mxnet {
namespace op {

struct PrefixSumInit {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i,
                                  int32_t* out,
                                  DType* in) {
    // Turn the input into a 0/1 mask so the scan counts non-zero elements.
    if (in[i]) {
      out[i] = 1;
    } else {
      out[i] = 0;
    }
  }
};

#define MAXDIM 5

void NonzeroForwardGPU(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<NDArray> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<NDArray> &outputs) {
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  const NDArray &in = inputs[0];
  const NDArray &out = outputs[0];
  CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot be larger than " << MAXDIM;
  size_t in_size = in.shape().Size();
  // 0-size input: no elements, so the output is an empty (0, ndim) array.
  if (0 == in_size) {
    mxnet::TShape s(2, in.shape().ndim());
    s[0] = 0;
    const_cast<NDArray &>(out).Init(s);
    return;
  }
  int32_t valid_num = 0;
  Stream<gpu>* stream = ctx.get_stream<gpu>();
  int32_t* prefix_sum = nullptr;
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // First CUB call with a null buffer only queries the scratch size.
  cub::DeviceScan::InclusiveSum(d_temp_storage,
                                temp_storage_bytes,
                                prefix_sum,
                                prefix_sum,
                                in_size,
                                Stream<gpu>::GetStream(stream));
  size_t buffer_size = in_size * sizeof(int32_t);
  temp_storage_bytes += buffer_size;
  // Allocate one workspace on the GPU and split it into the prefix-sum
  // buffer followed by CUB's scratch space.
  Tensor<gpu, 1, char> workspace =
    ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), stream);
  prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
  d_temp_storage = workspace.dptr_ + buffer_size;
  MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
    mxnet_op::Kernel<PrefixSumInit, gpu>::Launch(
      stream, in_size, prefix_sum, in.data().dptr<DType>());
  });
  // Calculate the inclusive prefix sum of the mask.
  cub::DeviceScan::InclusiveSum(d_temp_storage,
                                temp_storage_bytes,
                                prefix_sum,
                                prefix_sum,
                                in_size,
                                Stream<gpu>::GetStream(stream));
  CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[in_size - 1], sizeof(int32_t),
                       cudaMemcpyDeviceToHost));
  // 0-dim input: output shape is (1, 1) if the scalar is non-zero, (0, 1) otherwise.
  if (0 == in.shape().ndim()) {
    mxnet::TShape s(2, 1);
    if (valid_num) {
      const_cast<NDArray &>(out).Init(s);
      int64_t temp = 0;
      CUDA_CALL(cudaMemcpy(out.data().dptr<int64_t>(), &temp, sizeof(int64_t),
                           cudaMemcpyHostToDevice));
    } else {
      s[0] = 0;
      const_cast<NDArray &>(out).Init(s);
    }
    return;
  }
  // Set the output shape forcefully: one row per non-zero element.
  mxnet::TShape s(2, in.shape().ndim());
  s[0] = valid_num;
  const_cast<NDArray &>(out).Init(s);
  // Get the shape from the input and scatter the coordinates.
  MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
    mshadow::Shape<ndim> shape = in.shape().get<ndim>();
    mxnet_op::Kernel<NonzeroForwardKernel, gpu>::Launch(
      stream, in_size, out.data().dptr<int64_t>(), prefix_sum, shape);
  })
}

NNVM_REGISTER_OP(_npx_nonzero)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<FComputeEx>("FComputeEx<gpu>", NonzeroForwardGPU);

}  // namespace op
}  // namespace mxnet
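The GPU path uses the standard two-call CUB idiom: the first InclusiveSum call with a null buffer only reports the scratch size, one workspace is then carved into the int32 prefix-sum buffer plus that scratch space, and after the real scan the last prefix-sum entry is copied back to the host to size the output. A hedged NumPy stand-in for that data flow (np.cumsum plays the role of the device scan; the function name is illustrative):

import numpy as np

def nonzero_forward_gpu_model(a):
    a = np.asarray(a)
    if a.size == 0:                                  # 0-size early return
        return np.empty((0, a.ndim), np.int64)
    mask = (a.reshape(-1) != 0).astype(np.int32)     # PrefixSumInit kernel
    prefix_sum = np.cumsum(mask, dtype=np.int32)     # cub::DeviceScan::InclusiveSum
    valid_num = int(prefix_sum[-1])                  # cudaMemcpy of the last entry
    if a.ndim == 0:                                  # 0-dim handled after the scan
        return np.zeros((1, 1), np.int64) if valid_num else np.empty((0, 1), np.int64)
    out = np.empty((valid_num, a.ndim), np.int64)    # out.Init(s) with s[0] = valid_num
    for i in np.flatnonzero(mask):                   # NonzeroForwardKernel scatter
        out[prefix_sum[i] - 1] = np.unravel_index(i, a.shape)
    return out

print(nonzero_forward_gpu_model(np.array([[3, 0], [0, 4]])))  # [[0 0] [1 1]]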
