Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit a40e541

Browse files
author
Ubuntu
committed
* impl - FFi for linalg op
* fix - cpplint * impl - benchmark ffi for ops * rm - FFI for ops with param * fix - makefile * fix - not include unordered_map and use num_inputs * ci - compiler error
1 parent 34010ea commit a40e541

File tree

22 files changed

+445
-21
lines changed

22 files changed

+445
-21
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,9 @@ endif
464464

465465
all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages extension_libs
466466

467-
SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
467+
SRC = $(wildcard src/*/*/*/*/*.cc src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
468468
OBJ = $(patsubst %.cc, build/%.o, $(SRC))
469-
CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu)
469+
CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu)
470470
CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC))
471471

472472
ifeq ($(USE_TVM_OP), 1)

benchmark/python/ffi/benchmark_ffi.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def prepare_workloads():
5555
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
5656
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
5757
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
58+
OpArgMngr.add_workload("linalg.cholesky", pool['1x1'])
59+
OpArgMngr.add_workload("linalg.eigvals", pool['1x1'])
60+
OpArgMngr.add_workload("linalg.eigvalsh", pool['1x1'], UPLO='L')
61+
OpArgMngr.add_workload("linalg.inv", pool['1x1'])
62+
OpArgMngr.add_workload("linalg.pinv", pool['2x3x3'], pool['1'], hermitian=False)
63+
OpArgMngr.add_workload("linalg.solve", pool['1x1'], pool['1'])
64+
OpArgMngr.add_workload("linalg.tensorinv", pool['1x1'], ind=2)
65+
OpArgMngr.add_workload("linalg.tensorsolve", pool['1x1x1'], pool['1x1x1'], (2, 0, 1))
5866

5967

6068
def benchmark_helper(f, *args, **kwargs):

python/mxnet/ndarray/numpy/linalg.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from . import _op as _mx_nd_np
2121
from . import _internal as _npi
22+
from . import _api_internal
2223

2324
__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv',
2425
'eigvals', 'eig', 'eigvalsh', 'eigh']
@@ -91,9 +92,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
9192
"""
9293
if hermitian is True:
9394
raise NotImplementedError("hermitian is not supported yet...")
94-
if _mx_nd_np._np.isscalar(rcond):
95-
return _npi.pinv_scalar_rcond(a, rcond, hermitian)
96-
return _npi.pinv(a, rcond, hermitian)
95+
return _api_internal.pinv(a, rcond, hermitian)
9796

9897

9998
# pylint: disable=too-many-return-statements
@@ -332,7 +331,7 @@ def svd(a):
332331
return tuple(_npi.svd(a))
333332

334333

335-
def cholesky(a):
334+
def cholesky(a, lower=True):
336335
r"""
337336
Cholesky decomposition.
338337
@@ -388,7 +387,7 @@ def cholesky(a):
388387
array([[16., 4.],
389388
[ 4., 10.]])
390389
"""
391-
return _npi.cholesky(a)
390+
return _api_internal.cholesky(a, lower)
392391

393392

394393
def inv(a):
@@ -430,7 +429,7 @@ def inv(a):
430429
[[-1.2500001 , 0.75000006],
431430
[ 0.75000006, -0.25000003]]])
432431
"""
433-
return _npi.inv(a)
432+
return _api_internal.inv(a)
434433

435434

436435
def det(a):
@@ -594,7 +593,7 @@ def solve(a, b):
594593
>>> np.allclose(np.dot(a, x), b)
595594
True
596595
"""
597-
return _npi.solve(a, b)
596+
return _api_internal.solve(a, b)
598597

599598

600599
def tensorinv(a, ind=2):
@@ -649,7 +648,7 @@ def tensorinv(a, ind=2):
649648
>>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
650649
True
651650
"""
652-
return _npi.tensorinv(a, ind)
651+
return _api_internal.tensorinv(a, ind)
653652

654653

655654
def tensorsolve(a, b, axes=None):
@@ -697,7 +696,7 @@ def tensorsolve(a, b, axes=None):
697696
>>> np.allclose(np.tensordot(a, x, axes=3), b)
698697
True
699698
"""
700-
return _npi.tensorsolve(a, b, axes)
699+
return _api_internal.tensorsolve(a, b, axes)
701700

702701

703702
def eigvals(a):
@@ -765,7 +764,7 @@ def eigvals(a):
765764
>>> LA.eigvals(A)
766765
array([ 1., -1.]) # random
767766
"""
768-
return _npi.eigvals(a)
767+
return _api_internal.eigvals(a)
769768

770769

771770
def eigvalsh(a, UPLO='L'):
@@ -824,7 +823,7 @@ def eigvalsh(a, UPLO='L'):
824823
>>> LA.eigvalsh(a, UPLO='L')
825824
array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order
826825
"""
827-
return _npi.eigvalsh(a, UPLO)
826+
return _api_internal.eigvalsh(a, UPLO)
828827

829828

830829
def eig(a):

python/mxnet/numpy/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def svd(a):
232232
return _mx_nd_np.linalg.svd(a)
233233

234234

235-
def cholesky(a):
235+
def cholesky(a, lower=True):
236236
r"""
237237
Cholesky decomposition.
238238
@@ -288,7 +288,7 @@ def cholesky(a):
288288
array([[16., 4.],
289289
[ 4., 10.]])
290290
"""
291-
return _mx_nd_np.linalg.cholesky(a)
291+
return _mx_nd_np.linalg.cholesky(a, lower)
292292

293293

294294
def inv(a):

python/mxnet/symbol/numpy/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def svd(a):
322322
return _npi.svd(a)
323323

324324

325-
def cholesky(a):
325+
def cholesky(a, lower=True):
326326
r"""
327327
Cholesky decomposition.
328328
@@ -378,7 +378,7 @@ def cholesky(a):
378378
array([[16., 4.],
379379
[ 4., 10.]])
380380
"""
381-
return _npi.cholesky(a)
381+
return _npi.cholesky(a, lower)
382382

383383

384384
def inv(a):
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file np_eigvals.cc
22+
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_eigvals.cc
23+
*/
24+
#include <mxnet/api_registry.h>
25+
#include <mxnet/runtime/packed_func.h>
26+
#include "../../utils.h"
27+
#include "../../../../operator/numpy/linalg/np_eigvals-inl.h"
28+
29+
namespace mxnet {
30+
31+
MXNET_REGISTER_API("_npi.eigvals")
32+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
33+
using namespace runtime;
34+
const nnvm::Op* op = Op::Get("_npi_eigvals");
35+
nnvm::NodeAttrs attrs;
36+
attrs.op = op;
37+
int num_inputs = 1;
38+
int num_outputs = 0;
39+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
40+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
41+
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
42+
});
43+
44+
MXNET_REGISTER_API("_npi.eigvalsh")
45+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
46+
using namespace runtime;
47+
const nnvm::Op* op = Op::Get("_npi_eigvalsh");
48+
nnvm::NodeAttrs attrs;
49+
op::EigvalshParam param;
50+
param.UPLO = *((args[1].operator std::string()).c_str());
51+
attrs.parsed = param;
52+
attrs.op = op;
53+
SetAttrDict<op::EigvalshParam>(&attrs);
54+
int num_inputs = 1;
55+
int num_outputs = 0;
56+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
57+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
58+
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
59+
});
60+
61+
} // namespace mxnet
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file np_inv.cc
22+
* \brief Implementation of the API of functions in src/operator/tensor/la_op.cc
23+
*/
24+
#include <mxnet/api_registry.h>
25+
#include <mxnet/runtime/packed_func.h>
26+
#include "../../utils.h"
27+
28+
namespace mxnet {
29+
30+
MXNET_REGISTER_API("_npi.inv")
31+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
32+
using namespace runtime;
33+
const nnvm::Op* op = Op::Get("_npi_inv");
34+
nnvm::NodeAttrs attrs;
35+
attrs.op = op;
36+
int num_inputs = 1;
37+
int num_outputs = 0;
38+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
39+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
40+
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
41+
});
42+
43+
} // namespace mxnet
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file np_pinv.cc
22+
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_pinv.cc
23+
*/
24+
#include <mxnet/api_registry.h>
25+
#include <mxnet/runtime/packed_func.h>
26+
#include "../../utils.h"
27+
#include "../../../../operator/numpy/linalg/np_pinv-inl.h"
28+
29+
namespace mxnet {
30+
31+
inline static void _npi_pinv(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
32+
using namespace runtime;
33+
const nnvm::Op* op = Op::Get("_npi_pinv");
34+
op::PinvParam param;
35+
nnvm::NodeAttrs attrs;
36+
param.hermitian = args[2].operator bool();
37+
attrs.parsed = param;
38+
attrs.op = op;
39+
SetAttrDict<op::PinvParam>(&attrs);
40+
int num_inputs = 2;
41+
int num_outputs = 0;
42+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
43+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
44+
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
45+
}
46+
47+
inline static void _npi_pinv_scalar_rcond(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
48+
using namespace runtime;
49+
const nnvm::Op* op = Op::Get("_npi_pinv_scalar_rcond");
50+
op::PinvScalarRcondParam param;
51+
nnvm::NodeAttrs attrs;
52+
param.rcond = args[1].operator double();
53+
param.hermitian = args[2].operator bool();
54+
attrs.parsed = param;
55+
attrs.op = op;
56+
SetAttrDict<op::PinvScalarRcondParam>(&attrs);
57+
int num_inputs = 1;
58+
int num_outputs = 0;
59+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
60+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
61+
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
62+
}
63+
64+
MXNET_REGISTER_API("_npi.pinv")
65+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
66+
if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) {
67+
_npi_pinv_scalar_rcond(args, ret);
68+
} else {
69+
_npi_pinv(args, ret);
70+
}
71+
});
72+
73+
} // namespace mxnet
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file np_potrf.cc
22+
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_potrf.cc
23+
*/
24+
#include <mxnet/api_registry.h>
25+
#include <mxnet/runtime/packed_func.h>
26+
#include "../../utils.h"
27+
#include "../../../../operator/numpy/linalg/np_potrf-inl.h"
28+
29+
namespace mxnet {
30+
31+
MXNET_REGISTER_API("_npi.cholesky")
32+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
33+
using namespace runtime;
34+
const nnvm::Op* op = Op::Get("_npi_cholesky");
35+
nnvm::NodeAttrs attrs;
36+
op::LaCholeskyParam param;
37+
param.lower = args[1].operator bool();
38+
attrs.parsed = param;
39+
attrs.op = op;
40+
SetAttrDict<op::LaCholeskyParam>(&attrs);
41+
int num_inputs = 1;
42+
int num_outputs = 0;
43+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
44+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
45+
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
46+
});
47+
48+
} // namespace mxnet

0 commit comments

Comments
 (0)