This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 3545697

TVM bridge support to JIT NDArray Function by TVM (#9880)
* TVM bridge support: support wrapping a TVM-compiled function as an NDArray function.
* Test cases and CI to include TVM as a dependency.
* Address review comments.
* Add more comments; change to constexpr.
* Change to a log warning.
* Update the comment on the type code.
1 parent dc39343 commit 3545697
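
What the bridge enables, end to end: a TVM-compiled PackedFunc is wrapped so that each call is pushed through MXNet's async engine and operates directly on NDArrays. A minimal sketch, assuming the TVM checkout pinned by this commit's CI script, whose python/tvm/contrib/mxnet.py exposes to_mxnet_func (all variable names here are illustrative):

import tvm
import mxnet as mx
from tvm.contrib import mxnet as tvm_mxnet

# Build a trivial TVM kernel: y[i] = x[i] + 1.
n = 1024
x = tvm.placeholder((n,), name="x", dtype="float32")
y = tvm.compute((n,), lambda i: x[i] + 1.0, name="y")
s = tvm.create_schedule(y.op)
f = tvm.build(s, [x, y], target="llvm")  # a PackedFunc

# Wrap it so each call is pushed to MXNet's engine;
# const_loc=[0] marks argument 0 (x) as read-only.
mx_f = tvm_mxnet.to_mxnet_func(f, const_loc=[0])

a = mx.nd.ones((n,))
b = mx.nd.empty((n,))
mx_f(a, b)           # returns immediately; work is queued on the engine
print(b.asnumpy())   # synchronizes; prints an array of 2.0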

File tree

12 files changed: +349 −10 lines

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -234,6 +234,7 @@ include_directories("include")
 include_directories("mshadow")
 include_directories("3rdparty/cub")
 include_directories("nnvm/include")
+include_directories("nnvm/tvm/include")
 include_directories("dmlc-core/include")
 include_directories("dlpack/include")

@@ -696,4 +697,3 @@ endif()
 set(LINT_DIRS "include src plugin cpp-package tests")
 set(EXCLUDE_PATH "src/operator/contrib/ctc_include")
 add_custom_target(mxnet_lint COMMAND ${CMAKE_COMMAND} -DMSVC=${MSVC} -DPYTHON_EXECUTABLE=${PYTHON_EXECUTABLE} -DLINT_DIRS=${LINT_DIRS} -DPROJECT_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR} -DPROJECT_NAME=mxnet -DEXCLUDE_PATH=${EXCLUDE_PATH} -P ${CMAKE_CURRENT_SOURCE_DIR}/dmlc-core/cmake/lint.cmake)
-

Jenkinsfile

Lines changed: 5 additions & 5 deletions
@@ -38,12 +38,12 @@ def init_git() {
   deleteDir()
   retry(5) {
     try {
-      // Make sure wait long enough for api.github.com request quota. Important: Don't increase the amount of
+      // Make sure wait long enough for api.github.com request quota. Important: Don't increase the amount of
       // retries as this will increase the amount of requests and worsen the throttling
       timeout(time: 15, unit: 'MINUTES') {
         checkout scm
-        sh 'git submodule update --init'
-        sh 'git clean -d -f'
+        sh 'git submodule update --init --recursive'
+        sh 'git clean -d -f'
       }
     } catch (exc) {
       deleteDir()

@@ -61,8 +61,8 @@ def init_git_win() {
     // retries as this will increase the amount of requests and worsen the throttling
     timeout(time: 15, unit: 'MINUTES') {
       checkout scm
-      bat 'git submodule update --init'
-      bat 'git clean -d -f'
+      bat 'git submodule update --init --recursive'
+      bat 'git clean -d -f'
     }
   } catch (exc) {
     deleteDir()

Makefile

Lines changed: 2 additions & 2 deletions
@@ -91,7 +91,7 @@ ifeq ($(DEBUG), 1)
 else
 CFLAGS += -O3 -DNDEBUG=1
 endif
-CFLAGS += -I$(ROOTDIR)/mshadow/ -I$(ROOTDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -Iinclude $(MSHADOW_CFLAGS)
+CFLAGS += -I$(ROOTDIR)/mshadow/ -I$(ROOTDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(NNVM_PATH)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
 ifeq ($(DEBUG), 1)
 NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)

@@ -356,7 +356,7 @@ ifeq ($(USE_CUDA), 1)
 LDFLAGS += -lcuda -lnvrtc
 CFLAGS += -DMXNET_ENABLE_CUDA_RTC=1
 endif
-# Make sure to add stubs as fallback in order to be able to build
+# Make sure to add stubs as fallback in order to be able to build
 # without full CUDA install (especially if run without nvidia-docker)
 LDFLAGS += -L/usr/local/cuda/lib64/stubs
 SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-gpu

include/mxnet/tensor_blob.h

Lines changed: 10 additions & 0 deletions
@@ -36,8 +36,18 @@
 #include <utility>
 #include <algorithm>
 #include "./base.h"
+
 namespace mxnet {

+// Redefine the DLPack enumeration to be backward compatible.
+constexpr const int kCPU = kDLCPU;
+constexpr const int kGPU = kDLGPU;
+// Extension type code under the TVM function protocol.
+// NNVM currently reserves TVM type codes 16 to 19;
+// 16, 17, and 18 are already used by the NNVM compiler.
+// Pick code 19 for MXNet NDArray.
+constexpr const int kTVMNDArrayTypeCode = 19;
+
 /* Forward declaration for friend declaration in TBlob */
 class NDArray;
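
The Python side mirrors this constant (see python/mxnet/ndarray/ndarray.py below). A minimal consistency check, as a sketch:

import mxnet as mx

# _tvm_tcode on the Python class must match kTVMNDArrayTypeCode in C++.
assert mx.nd.NDArray._tvm_tcode == 19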

nnvm

Submodule nnvm updated 73 files

python/mxnet/ndarray/ndarray.py

Lines changed: 7 additions & 0 deletions
@@ -174,8 +174,15 @@ class NDArray(NDArrayBase):
     __slots__ = []
     # make numpy functions return NDArray instead of numpy object array
     __array_priority__ = 1000.0
+    # Extension type code for TVM function.
+    # See C++ side definition (kTVMNDArrayTypeCode) in include/mxnet/tensor_blob.h
+    _tvm_tcode = 19
     # pylint: disable= no-member, undefined-variable

+    @property
+    def _tvm_handle(self):
+        return self.handle.value
+
     def __repr__(self):
         """Returns a string representation of the array."""
         shape_info = 'x'.join(['%d' % x for x in self.shape])
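
TVM's extension-object protocol discovers these attributes by convention: when such an object is passed to a PackedFunc, the FFI reads _tvm_tcode to tag the argument slot and _tvm_handle to get the raw pointer forwarded to C++. A rough stand-in for that packing step (a hypothetical helper, not TVM's actual code):

import ctypes

def pack_extension_arg(obj):
    # Tag with the extension type code (19 == kTVMNDArrayTypeCode) and
    # pass the underlying NDArrayHandle as an opaque pointer.
    tcode = type(obj)._tvm_tcode
    handle = ctypes.c_void_p(obj._tvm_handle)
    return handle, tcode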

src/nnvm/tvm_bridge.cc

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm_bridge.cc
+ * \brief Bridge to run TVM's PackedFunc in MXNet's async engine.
+ *
+ * This bridge is mainly used to expose MXNet's async engine push to
+ * TVM. It only uses the TVM runtime in header-only mode, which means
+ * there are no link dependencies.
+ *
+ * Support for TVM is optional even though this code
+ * is always compiled and built with the project.
+ * We choose this strategy because we do not yet want
+ * LLVM as a dependency (which TVM uses). So instead we expose a hook
+ * to TVM and let users use this feature when they have TVM installed.
+ *
+ * We do require TVM and MXNet to be built with the same C++ ABI for std::function.
+ */
+#define TVM_RUNTIME_HEADER_ONLY 1
+#include <tvm/runtime/packed_func.h>
+#include <mxnet/c_api.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/engine.h>
+
+#include <memory>
+
+namespace mxnet {
+
+using tvm::runtime::PackedFunc;
+using tvm::runtime::TVMArgs;
+using tvm::runtime::TVMRetValue;
+
+/*!
+ * \brief Async functor object that holds the
+ *  calling arguments of the function.
+ */
+class TVMFunctor {
+ public:
+  // constructor
+  explicit TVMFunctor(PackedFunc func, PackedFunc fset_stream)
+      : func_(func), fset_stream_(fset_stream) {}
+
+  void Init(const TVMArgs& args,
+            const std::vector<int>& const_loc,
+            std::vector<Engine::VarHandle>* const_vars,
+            std::vector<Engine::VarHandle>* mutate_vars) {
+    values_.clear();
+    type_codes_.clear();
+    values_.insert(values_.end(), args.values, args.values + args.size());
+    type_codes_.insert(
+        type_codes_.end(), args.type_codes, args.type_codes + args.size());
+
+    size_t const_loc_ptr = 0;
+    for (int i = 0; i < args.size(); ++i) {
+      if (args.type_codes[i] == kTVMNDArrayTypeCode) {
+        const NDArray& nd =
+            static_cast<NDArray*>(args.values[i].v_handle)[0];
+        // We cannot set the value until Run(), when the DLTensor is ready.
+        type_codes_[i] = kArrayHandle;
+        array_data_.push_back(nd);
+        array_loc_.push_back(i);
+        // check if there is read or mutate
+        // by default assume we mutate the array.
+        if (const_loc_ptr < const_loc.size() &&
+            i == const_loc[const_loc_ptr]) {
+          const_vars->push_back(nd.var());
+          ++const_loc_ptr;
+        } else {
+          mutate_vars->push_back(nd.var());
+        }
+      } else {
+        CHECK_LT(args.type_codes[i], kTVMType)
+            << "Only allow POD type in mxnet async call";
+      }
+    }
+  }
+
+  Context ctx() {
+    return array_data_[0].ctx();
+  }
+
+  void Run(const RunContext& rctx) {
+    // setup DLTensor
+    for (size_t i = 0; i < array_loc_.size(); ++i) {
+      values_[array_loc_[i]].v_handle =
+          const_cast<DLTensor*>(&(array_data_[i].data().dltensor()));
+    }
+    // run the packed function
+    TVMRetValue rv;
+    TVMArgs args(&values_[0], &type_codes_[0], values_.size());
+    if (ctx().dev_type == Context::kGPU) {
+#if MXNET_USE_CUDA
+      // pass stream via last argument.
+      void* strm = static_cast<void*>(rctx.get_stream<gpu>()->stream_);
+      int dev_type = kDLGPU;
+      fset_stream_(dev_type, rctx.ctx.dev_id, strm);
+      func_.CallPacked(args, &rv);
+      fset_stream_(dev_type, rctx.ctx.dev_id, nullptr);
+#else
+      LOG(FATAL) << "Please compile with CUDA enabled for cuda features";
+#endif
+    } else {
+      func_.CallPacked(args, &rv);
+    }
+  }
+
+ private:
+  /*! \brief The function */
+  PackedFunc func_;
+  /*! \brief Set stream */
+  PackedFunc fset_stream_;
+  /*! \brief Values field */
+  std::vector<TVMValue> values_;
+  /*! \brief type code field */
+  std::vector<int> type_codes_;
+  /*! \brief arrays field */
+  std::vector<NDArray> array_data_;
+  /*! \brief position of array in arguments */
+  std::vector<int> array_loc_;
+};
+
+
+// Wrap a TVM function into a function that invokes MXNet's Engine.
+// It does two things: calls the engine properly and
+// sets up the NDArray-to-DLTensor conversion during invocation.
+void WrapAsyncCall(TVMArgs wrap_args, TVMRetValue* wrap_rv) {
+  PackedFunc f = wrap_args[0];
+  PackedFunc fset_stream = wrap_args[1];
+  int num_const = wrap_args[2];
+
+  // sorted positions of constant arguments
+  std::vector<int> const_loc;
+  for (int i = 0; i < num_const; ++i) {
+    const_loc.push_back(wrap_args[i + 3].operator int());
+  }
+  std::sort(const_loc.begin(), const_loc.end());
+  // wrapped function
+  // This is the function that is called by the user.
+  auto wrapped = [f, fset_stream, const_loc](TVMArgs args, TVMRetValue* rv) {
+    std::shared_ptr<TVMFunctor> func =
+        std::make_shared<TVMFunctor>(f, fset_stream);
+    std::vector<Engine::VarHandle> const_vars, mutate_vars;
+    func->Init(args, const_loc, &const_vars, &mutate_vars);
+    Engine *engine = Engine::Get();
+    engine->DeduplicateVarHandle(&const_vars, &mutate_vars);
+    engine->PushSync([func](RunContext ctx) {
+        func->Run(ctx);
+      }, func->ctx(), const_vars, mutate_vars);
+  };
+  *wrap_rv = PackedFunc(wrapped);
+}
+
+}  // namespace mxnet
+
+// C callback that can be used by TVM to extract
+// the WrapAsyncCall function.
+extern "C" MXNET_DLL int MXTVMBridge(TVMFunctionHandle pregister) {
+  using tvm::runtime::PackedFunc;
+  const PackedFunc& fregister =
+      *static_cast<PackedFunc*>(pregister);
+  fregister("WrapAsyncCall", PackedFunc(mxnet::WrapAsyncCall));
+  return 0;
+}
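
On the GPU path, Run() hands the engine's CUDA stream to TVM through fset_stream before the packed call and clears it afterwards, so the TVM kernel is ordered correctly with MXNet's other work on that stream. From the user's side only the build target changes; continuing the illustrative sketch from the top of this page (assuming CUDA builds of both libraries):

# CUDA targets need explicit thread binding in the schedule.
s = tvm.create_schedule(y.op)
bx, tx = s[y].split(y.op.axis[0], factor=64)
s[y].bind(bx, tvm.thread_axis("blockIdx.x"))
s[y].bind(tx, tvm.thread_axis("threadIdx.x"))
f_gpu = tvm.build(s, [x, y], target="cuda")

mx_f_gpu = tvm_mxnet.to_mxnet_func(f_gpu, const_loc=[0])
a = mx.nd.ones((1024,), ctx=mx.gpu(0))
b = mx.nd.empty((1024,), ctx=mx.gpu(0))
mx_f_gpu(a, b)  # fset_stream is called around CallPacked internally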

tests/ci_build/Dockerfile.gpu

Lines changed: 6 additions & 0 deletions
@@ -12,3 +12,9 @@ COPY install/ubuntu_install_r.sh /install/
 RUN /install/ubuntu_install_r.sh
 COPY install/ubuntu_install_perl.sh /install/
 RUN /install/ubuntu_install_perl.sh
+
+COPY install/ubuntu_install_llvm.sh /install/
+RUN /install/ubuntu_install_llvm.sh
+
+COPY install/ubuntu_install_tvm.sh /install/
+RUN /install/ubuntu_install_tvm.sh

tests/ci_build/install/ubuntu_install_llvm.sh

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+echo deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main\
+     >> /etc/apt/sources.list.d/llvm.list
+echo deb-src http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main\
+     >> /etc/apt/sources.list.d/llvm.list
+
+wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
+apt-get update && apt-get install -y --force-yes llvm-5.0

tests/ci_build/install/ubuntu_install_tvm.sh

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Build and install TVM
+cd /tmp
+git clone https://github.com/dmlc/tvm/ --recursive
+cd tvm
+
+# This is a stable commit that supports the MXNet TVM bridge.
+# We pin it because bridge support was only just merged into
+# master and there is not yet a version tag.
+git checkout 30eaf463e34d7c301357c31a010945d11df16537
+
+cp make/config.mk .
+echo USE_CUDA=1 >> config.mk
+echo LLVM_CONFIG=llvm-config-5.0 >> config.mk
+echo USE_RPC=1 >> config.mk
+echo USE_GRAPH_RUNTIME=1 >> config.mk
+echo CUDA_PATH=/usr/local/cuda >> config.mk
+make -j`nproc`
+
+cd python
+python setup.py install
+cd -
+
+cd topi/python
+python setup.py install
+cd -
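
A quick post-install smoke test that both packages import from the built tree (a sketch):

import tvm
import topi

print("tvm:", tvm.__file__)
print("topi:", topi.__file__)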
