Skip to content

Commit 30eaf46

Browse files
authored
MXNet NDArray bridge. (#930)
* MXNet NDArray bridge. Support convert a tvm Function as MXNet's async NDArray function. * fix lint * update comment
1 parent d9e4ccc commit 30eaf46

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

python/tvm/contrib/mxnet.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""MXNet bridge wrap Function MXNet's async function."""
2+
from __future__ import absolute_import as _abs
3+
4+
from .. import api, _api_internal, ndarray
5+
from ..module import Module
6+
7+
# pylint: disable=invalid-name
8+
_wrap_async = None
9+
10+
11+
def to_mxnet_func(func, const_loc=None):
12+
"""Wrap a TVM function as MXNet function
13+
14+
MXNet function runs asynchrously via its engine.
15+
16+
Parameters
17+
----------
18+
func : Function
19+
A TVM function that can take positional arguments
20+
21+
const_loc : list of int
22+
List of integers indicating the argument position
23+
of read only NDArray argument.
24+
The NDArray argument location that are not annotated
25+
will be viewed as mutable arrays in MXNet's engine.
26+
27+
Returns
28+
-------
29+
async_func : Function
30+
A function that can take MXNet NDArray as argument
31+
in places that used to expect TVM NDArray.
32+
Run asynchrously in MXNet's async engine.
33+
"""
34+
# only import mxnet when wrap get called.
35+
# pylint: disable=import-self
36+
import mxnet
37+
if isinstance(func, Module):
38+
func = func.entry_func
39+
40+
def _get_bridge_func():
41+
"""Get MXNet bridge function"""
42+
if not mxnet.base._LIB.MXTVMBridge:
43+
raise RuntimeError(
44+
"MXTVMBridge not exist in mxnet package,"
45+
" please update to latest version")
46+
47+
fdict = api.extract_ext_funcs(mxnet.base._LIB.MXTVMBridge)
48+
ret = fdict["WrapAsyncCall"]
49+
ret.is_global = True
50+
return ret
51+
global _wrap_async
52+
53+
if _wrap_async is None:
54+
# Register extension type in first time
55+
_wrap_async = _get_bridge_func()
56+
ndarray.register_extension(mxnet.nd.NDArray)
57+
58+
const_loc = const_loc if const_loc else []
59+
return _wrap_async(func, _api_internal._TVMSetStream, len(const_loc), *const_loc)

src/api/api_base.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,9 @@ TVM_REGISTER_API("_load_json")
3636
TVM_REGISTER_API("_nop")
3737
.set_body([](TVMArgs args, TVMRetValue *ret) {
3838
});
39+
40+
TVM_REGISTER_API("_TVMSetStream")
41+
.set_body([](TVMArgs args, TVMRetValue *ret) {
42+
TVMSetStream(args[0], args[1], args[2]);
43+
});
3944
} // namespace tvm
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
def mxnet_check():
2+
"""This is a simple test function for MXNet bridge
3+
4+
It is not included as nosetests, because of its dependency on mxnet
5+
6+
User can directly run this script to verify correctness.
7+
"""
8+
import mxnet as mx
9+
import topi
10+
import tvm
11+
import numpy as np
12+
from tvm.contrib.mxnet import to_mxnet_func
13+
14+
# build a TVM function through topi
15+
n = 20
16+
shape = (20,)
17+
scale = tvm.var("scale", dtype="float32")
18+
x = tvm.placeholder(shape)
19+
y = tvm.placeholder(shape)
20+
z = topi.broadcast_add(x, y)
21+
zz = tvm.compute(shape, lambda *i: z(*i) * scale)
22+
23+
target = tvm.target.cuda()
24+
25+
# build the function
26+
with target:
27+
s = topi.generic.schedule_injective(zz)
28+
f = tvm.build(s, [x, y, zz, scale])
29+
30+
# get a mxnet version
31+
mxf = to_mxnet_func(f, const_loc=[0, 1])
32+
33+
ctx = mx.gpu(0)
34+
xx = mx.nd.uniform(shape=shape, ctx=ctx)
35+
yy = mx.nd.uniform(shape=shape, ctx=ctx)
36+
zz = mx.nd.empty(shape=shape, ctx=ctx)
37+
38+
# invoke myf: this runs in mxnet engine
39+
mxf(xx, yy, zz, 10.0)
40+
mxf(xx, yy, zz, 10.0)
41+
42+
43+
np.testing.assert_allclose(
44+
zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10)
45+
46+
47+
if __name__ == "__main__":
48+
mxnet_check()

0 commit comments

Comments
 (0)