Skip to content

Commit 0b1ec85

Browse files
hzfanyajiedesign
authored and committed
Infra for tvm op runtime dispatch (apache#16100)
* infra for dispatch tvm op * fix ci and sanity error * disable shape with hint and fix coding style * rename to avoid conflict with original dot * update tvm and use soft link * config file moves to lib/ when using Makefile * add tvmop.conf to ci * fix rebase * fix rebase * use inspect to detect dispatchable func
1 parent 9d2d1a7 commit 0b1ec85

File tree

20 files changed

+841
-25
lines changed

20 files changed

+841
-25
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,7 @@ if(USE_TVM_OP)
763763
endif()
764764
endif()
765765

766-
set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so")
766+
set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so" "--config" "${CMAKE_CURRENT_BINARY_DIR}/tvmop.conf")
767767
if(CUDA_ARCH_BIN)
768768
set(TVM_OP_COMPILE_OPTIONS "${TVM_OP_COMPILE_OPTIONS}" "--cuda-arch" "${CUDA_ARCH_BIN}")
769769
endif()

Makefile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,7 @@ DMLCCORE:
622622

623623
lib/libtvm_runtime.so:
624624
echo "Compile TVM"
625+
@mkdir -p $(@D)
625626
[ -e $(LLVM_PATH)/bin/llvm-config ] || sh $(ROOTDIR)/contrib/tvmop/prepare_tvm.sh; \
626627
cd $(TVM_PATH)/build; \
627628
cmake -DUSE_LLVM="$(LLVM_PATH)/bin/llvm-config" \
@@ -632,12 +633,13 @@ lib/libtvm_runtime.so:
632633
ls $(ROOTDIR)/lib; \
633634
cd $(ROOTDIR)
634635

635-
TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so
636+
TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so --config $(ROOTDIR)/lib/tvmop.conf
636637
ifneq ($(CUDA_ARCH),)
637638
TVM_OP_COMPILE_OPTIONS += --cuda-arch "$(CUDA_ARCH)"
638639
endif
639640
lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py)
640641
echo "Compile TVM operators"
642+
@mkdir -p $(@D)
641643
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib \
642644
LD_LIBRARY_PATH=$(ROOTDIR)/lib \
643645
python3 $(ROOTDIR)/contrib/tvmop/compile.py $(TVM_OP_COMPILE_OPTIONS)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import time
19+
import mxnet as mx
20+
import numpy as _np
21+
from mxnet import np, npx
22+
23+
def measure_cost(repeat, func_name, *args, **kwargs):
    """Return the average wall-clock cost, in seconds, of one call to a function.

    Parameters
    ----------
    repeat : int
        Number of times the function is invoked; must be at least 1.
    func_name : callable
        The function to benchmark (despite its name, this is the callable
        itself, not a string).
    *args, **kwargs
        Forwarded unchanged to every invocation.

    Returns
    -------
    float
        Total elapsed time divided by ``repeat``.

    Raises
    ------
    ValueError
        If ``repeat`` is smaller than 1, since no average can be computed.
    """
    if repeat < 1:
        raise ValueError("repeat must be at least 1, got {}".format(repeat))
    # waitall() is called on both sides of the timed region so that the
    # measurement brackets the calls (presumably draining queued asynchronous
    # work -- see the MXNet documentation for mx.nd.waitall).
    mx.nd.waitall()
    # perf_counter() is monotonic and high resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(repeat):
        func_name(*args, **kwargs)
    mx.nd.waitall()
    return (time.perf_counter() - start) / repeat
34+
35+
36+
def test_tvm_dot():
    """Benchmark tvm_dot, tvm_dot_fallback and dot on square float32 matrices.

    For each size from 1000 up to (but not including) 1100 in steps of 4, the
    problem shape is printed, followed by the average cost in milliseconds of
    each of the three operators over two runs.
    """
    # Each entry pairs the report line with a lazy lookup of the operator, so
    # the attribute is only resolved right before it is measured.
    candidates = (
        ("dispatch cost: {} ms", lambda: mx.nd.contrib.tvm_dot),
        ("fallback cost: {} ms", lambda: mx.nd.contrib.tvm_dot_fallback),
        ("dot cost: {} ms", lambda: mx.nd.dot),
    )
    for size in range(1000, 1100, 4):
        m = k = n = size
        print("{} * {} X {} * {}".format(m, k, k, n))
        for report, resolve_op in candidates:
            # Fresh random operands are drawn for every measurement.
            lhs = mx.nd.random.uniform(shape=(m, k), dtype='float32')
            rhs = mx.nd.random.uniform(shape=(k, n), dtype='float32')
            elapsed = measure_cost(2, resolve_op(), lhs, rhs)
            print(report.format(elapsed * 1000))
55+
56+
if __name__ == "__main__":
57+
test_tvm_dot()

ci/jenkins/Jenkins_steps.groovy

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,24 @@
2323
utils = load('ci/Jenkinsfile_utils.groovy')
2424

2525
// mxnet libraries
26-
mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
27-
mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
26+
mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
27+
mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
2828

2929
// Python wheels
3030
mx_pip = 'build/*.whl'
3131

3232
// mxnet cmake libraries, in cmake builds we do not produce a libnnvm static library by default.
33-
mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
33+
mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
3434
mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
35-
mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
35+
mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
3636
// mxnet cmake libraries, in cmake builds we do not produce a libnnvm static library by default.
37-
mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
38-
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
39-
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
40-
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
41-
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
37+
mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
38+
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
39+
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
40+
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
41+
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
4242
mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
43-
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/cpp-package/example/*'
43+
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*'
4444

4545
// Python unittest for CPU
4646
// Python 2

contrib/tvmop/compile.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
# coding: utf-8
1919
"""TVM Operator compile entry point"""
2020
import tvm
21+
from tvm import autotvm
2122

2223
import os
2324
import argparse
2425
import re
26+
import json
2527
import logging
2628
from tvmop.opdef import __OP_DEF__
29+
from tvmop.space import ConfigSpaces, ConfigSpace
2730
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
2831

2932
logging.basicConfig(level=logging.INFO)
@@ -70,6 +73,8 @@ def get_cuda_arch(arch):
7073
help="Target path which stores compiled library")
7174
parser.add_argument('--cuda-arch', type=str, default=None, dest='cuda_arch',
7275
help='The cuda arch for compiling kernels for')
76+
parser.add_argument("--config", action="store", required=True, dest="config_path",
77+
help="Path which stores the config file")
7378
arguments = parser.parse_args()
7479

7580
func_list_llvm = []
@@ -78,6 +83,7 @@ def get_cuda_arch(arch):
7883
# TODO: attach instruction features to the library, e.g., avx-512, etc.
7984
for operator_def in __OP_DEF__:
8085
for sch, args, name in operator_def.invoke_all():
86+
name = operator_def.get_op_name(name, args)
8187
if tvm.module.enabled(get_target(operator_def.target)):
8288
func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
8389
func_lower = tvm.lower(sch, args,
@@ -96,3 +102,10 @@ def get_cuda_arch(arch):
96102
set_cuda_target_arch(cuda_arch)
97103
func_binary = tvm.build(lowered_funcs, name="tvmop")
98104
func_binary.export_library(arguments.target_path)
105+
106+
config_spaces = ConfigSpaces()
107+
for operator_def in __OP_DEF__:
108+
for config_space, name in operator_def.get_config_spaces():
109+
config_spaces[name] = ConfigSpace.from_tvm(config_space)
110+
with open(arguments.config_path, "w") as f:
111+
json.dump(config_spaces.to_json_dict(), f)

contrib/tvmop/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from . import umath, fromnumeric
18+
from . import umath, fromnumeric, multiarray

contrib/tvmop/core/multiarray.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# coding: utf-8
19+
import tvm
20+
from tvm import autotvm
21+
from .. import defop, AllTypes
22+
from .. import assign_by_req, reduce_axes
23+
24+
def compute_dot(A, B):
    """Build the TVM compute expression for the 2-D matrix product of A and B.

    The output has shape ``(A.shape[0], B.shape[1])`` and each element is the
    sum over ``k`` of ``A[x, k] * B[k, y]``, with ``k`` ranging over
    ``A.shape[1]``.
    """
    rows, inner = A.shape[0], A.shape[1]
    cols = B.shape[1]
    k = tvm.reduce_axis((0, inner), 'k')

    def element(x, y):
        # Reduce the products along the shared dimension for output (x, y).
        return tvm.sum(A[x, k] * B[k, y], axis=k)

    return tvm.compute((rows, cols), element, name='C')
33+
34+
35+
@defop(name="dot", target="cpu", dtype=AllTypes)
def dot(dtype, fallback):
    """Define a tiled CPU schedule for a 2-D matrix product with symbolic shapes.

    Parameters
    ----------
    dtype : str
        Element type used for the ``A`` and ``B`` placeholders.
    fallback : bool
        When true, only one tile size (64) is registered for the ``bn`` knob;
        otherwise two candidates (64 and 32) are registered.

    Returns
    -------
    tuple
        ``(schedule, [A, B, C])`` where ``C`` is the product computed by
        ``compute_dot``.
    """
    cfg = autotvm.get_config()
    cfg.define_knob("bn", [64] if fallback else [64, 32])
    # Both variants use the same single split factor, so no conditional is needed.
    cfg.define_knob("factor", [4])
    M = tvm.var("M")
    K = tvm.var("K")
    N = tvm.var("N")
    A = tvm.placeholder((M, K), name='A', dtype=dtype)
    B = tvm.placeholder((K, N), name='B', dtype=dtype)
    C = compute_dot(A, B)
    s = tvm.create_schedule(C.op)
    # Blocking by loop tiling
    xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], cfg["bn"].val, cfg["bn"].val)
    k, = s[C].op.reduce_axis
    ko, ki = s[C].split(k, factor=cfg["factor"].val)
    # Hoist reduction domain outside the blocking loop
    s[C].reorder(xo, yo, ko, ki, xi, yi)
    return s, [A, B, C]

contrib/tvmop/opdef.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
# coding: utf-8
1919
import tvm
20+
import inspect
21+
from tvm import autotvm
2022
from itertools import product
2123

2224
__OP_DEF__ = []
@@ -68,19 +70,49 @@ def __init__(self, func, name, target, auto_broadcast, **kwargs):
6870
self.name = name
6971
self.target = target
7072
self.auto_broadcast = auto_broadcast
73+
self.dispatchable = 'fallback' in inspect.signature(self.func).parameters
7174

7275
def __call__(self, *args, **kwargs):
7376
return self.func(*args, **kwargs)
7477

7578
def invoke_all(self):
7679
for each_kwargs in self.arg_combination:
7780
if self.attrs_valid(**each_kwargs):
78-
sch, args = self.func(**each_kwargs)
7981
name = self.name \
80-
+ ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \
81-
+ ''.join(["%s_%d" % (arg.dtype, len(arg.shape))
82-
for arg in args if hasattr(arg, 'shape')])
83-
yield sch, args, name
82+
+ ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs])
83+
if self.dispatchable is False:
84+
sch, args = self.func(**each_kwargs)
85+
yield sch, args, name
86+
else:
87+
# register dispatch schedules
88+
config_space = autotvm.ConfigSpace()
89+
with autotvm.task.ApplyConfig(config_space):
90+
sch, args = self.func(fallback=False, **each_kwargs)
91+
for i in range(len(config_space)):
92+
config_entity = config_space.get(i)
93+
with autotvm.task.ApplyConfig(config_entity):
94+
sch, args = self.func(fallback=False, **each_kwargs)
95+
subname = name + "index_" + str(i)
96+
yield sch, args, subname
97+
# register fallback schedule
98+
config_space = autotvm.ConfigSpace()
99+
with autotvm.task.ApplyConfig(config_space):
100+
sch, args = self.func(fallback=True, **each_kwargs)
101+
subname = name + "fallback"
102+
yield sch, args, subname
103+
104+
def get_op_name(self, name, args):
    """Return ``name`` with a ``<dtype>_<ndim>`` suffix per shaped argument.

    Arguments without a ``shape`` attribute contribute nothing to the suffix.
    """
    suffix_parts = []
    for arg in args:
        if hasattr(arg, 'shape'):
            suffix_parts.append("%s_%d" % (arg.dtype, len(arg.shape)))
    return name + ''.join(suffix_parts)
106+
107+
def get_config_spaces(self):
    """Yield ``(config_space, name)`` for each valid argument combination.

    Nothing is yielded unless the operator is dispatchable. The name is the
    operator name followed by ``<attr>_<value>`` for every attribute. The
    config space is filled by calling the operator function with
    ``fallback=False`` under ``autotvm.task.ApplyConfig``.
    """
    for each_kwargs in self.arg_combination:
        if not self.attrs_valid(**each_kwargs):
            continue
        if self.dispatchable is not True:
            continue
        attr_suffix = ''.join(
            ["{}_{}".format(attr, each_kwargs[attr]) for attr in self.attrs])
        name = self.name + attr_suffix
        space = autotvm.ConfigSpace()
        # The schedule result is discarded; the call only records the knobs
        # the function defines into ``space``.
        with autotvm.task.ApplyConfig(space):
            self.func(fallback=False, **each_kwargs)
        yield space, name
84116

85117
def get_binds(self, args):
86118
if self.auto_broadcast:

0 commit comments

Comments
 (0)