Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 6d528f7

Browse files
committed
Expose get_all_registered_operators and get_operator_arguments in the Python API.
1 parent 7fe478a commit 6d528f7

File tree

3 files changed

+79
-77
lines changed

3 files changed

+79
-77
lines changed

benchmark/opperf/utils/op_registry_utils.py

Lines changed: 6 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import ctypes
2020
import sys
2121
from mxnet.base import _LIB, check_call, py_str, OpHandle, c_str, mx_uint
22+
import mxnet as mx
2223

2324
from benchmark.opperf.rules.default_params import DEFAULTS_INPUTS
2425

@@ -77,89 +78,19 @@ def _select_ops(operator_names, filters=("_contrib", "_"), merge_op_forward_back
7778
return mx_operators
7879

7980

80-
def _get_all_registered_ops():
81-
"""Get all registered MXNet operator names.
82-
83-
84-
Returns
85-
-------
86-
["operator_name"]
87-
"""
88-
plist = ctypes.POINTER(ctypes.c_char_p)()
89-
size = ctypes.c_uint()
90-
91-
check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
92-
ctypes.byref(plist)))
93-
94-
mx_registered_operator_names = [py_str(plist[i]) for i in range(size.value)]
95-
return mx_registered_operator_names
96-
97-
98-
def _get_op_handles(op_name):
99-
"""Get handle for an operator with given name - op_name.
100-
101-
Parameters
102-
----------
103-
op_name: str
104-
Name of operator to get handle for.
105-
"""
106-
op_handle = OpHandle()
107-
check_call(_LIB.NNGetOpHandle(c_str(op_name), ctypes.byref(op_handle)))
108-
return op_handle
109-
110-
111-
def _get_op_arguments(op_handle):
112-
"""Given operator name and handle, fetch operator arguments - number of arguments,
113-
argument names, argument types.
114-
115-
Parameters
116-
----------
117-
op_handle: OpHandle
118-
Handle for the operator
119-
120-
Returns
121-
-------
122-
(narg, arg_names, arg_types)
123-
"""
124-
real_name = ctypes.c_char_p()
125-
desc = ctypes.c_char_p()
126-
num_args = mx_uint()
127-
arg_names = ctypes.POINTER(ctypes.c_char_p)()
128-
arg_types = ctypes.POINTER(ctypes.c_char_p)()
129-
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
130-
key_var_num_args = ctypes.c_char_p()
131-
ret_type = ctypes.c_char_p()
132-
133-
check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
134-
op_handle, ctypes.byref(real_name), ctypes.byref(desc),
135-
ctypes.byref(num_args),
136-
ctypes.byref(arg_names),
137-
ctypes.byref(arg_types),
138-
ctypes.byref(arg_descs),
139-
ctypes.byref(key_var_num_args),
140-
ctypes.byref(ret_type)))
141-
142-
narg = int(num_args.value)
143-
arg_names = [py_str(arg_names[i]) for i in range(narg)]
144-
arg_types = [py_str(arg_types[i]) for i in range(narg)]
145-
146-
return narg, arg_names, arg_types
147-
148-
14981
def _set_op_arguments(mx_operators):
15082
"""Fetch and set operator arguments - nargs, arg_names, arg_types
15183
"""
15284
for op_name in mx_operators:
153-
op_handle = _get_op_handles(op_name)
154-
narg, arg_names, arg_types = _get_op_arguments(op_handle)
155-
mx_operators[op_name]["params"] = {"narg": narg,
156-
"arg_names": arg_names,
157-
"arg_types": arg_types}
85+
operator_arguments = mx.operator.get_operator_arguments(op_name)
86+
mx_operators[op_name]["params"] = {"narg": operator_arguments.narg,
87+
"arg_names": operator_arguments.names,
88+
"arg_types": operator_arguments.types}
15889

15990

16091
def _get_all_mxnet_operators():
16192
# Step 1 - Get all registered op names and filter it
162-
operator_names = _get_all_registered_ops()
93+
operator_names = mx.operator.get_all_registered_operators()
16394
mx_operators = _select_ops(operator_names)
16495

16596
# Step 2 - Get all parameters for the operators

python/mxnet/operator.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@
2222

2323
import traceback
2424
import warnings
25+
import collections
2526

2627
from array import array
2728
from threading import Lock
29+
import ctypes
2830
from ctypes import CFUNCTYPE, POINTER, Structure, pointer
2931
from ctypes import c_void_p, c_int, c_char, c_char_p, cast, c_bool
3032

31-
from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int
33+
from .base import _LIB, check_call, MXCallbackList, c_array, c_array_buf, mx_int, OpHandle
3234
from .base import c_str, mx_uint, mx_float, ctypes2numpy_shared, NDArrayHandle, py_str
3335
from . import symbol, context
3436
from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
@@ -1099,3 +1101,60 @@ def delete_entry(_):
10991101
return do_register
11001102

11011103
register("custom_op")(CustomOpProp)
1104+
1105+
1106+
def get_all_registered_operators():
1107+
"""Get all registered MXNet operator names.
1108+
1109+
Returns
1110+
-------
1111+
operator_names : list of string
1112+
"""
1113+
plist = ctypes.POINTER(ctypes.c_char_p)()
1114+
size = ctypes.c_uint()
1115+
1116+
check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
1117+
ctypes.byref(plist)))
1118+
1119+
mx_registered_operator_names = [py_str(plist[i]) for i in range(size.value)]
1120+
return mx_registered_operator_names
1121+
1122+
OperatorArguments = collections.namedtuple('OperatorArguments', ['narg', 'names', 'types'])
1123+
1124+
def get_operator_arguments(op_name):
1125+
"""Given operator name, fetch operator arguments - number of arguments,
1126+
argument names, argument types.
1127+
1128+
Parameters
1129+
----------
1130+
op_name: str
1131+
Handle for the operator
1132+
1133+
Returns
1134+
-------
1135+
operator_arguments : OperatorArguments, namedtuple with number of arguments, names and types
1136+
"""
1137+
op_handle = OpHandle()
1138+
check_call(_LIB.NNGetOpHandle(c_str(op_name), ctypes.byref(op_handle)))
1139+
real_name = ctypes.c_char_p()
1140+
desc = ctypes.c_char_p()
1141+
num_args = mx_uint()
1142+
arg_names = ctypes.POINTER(ctypes.c_char_p)()
1143+
arg_types = ctypes.POINTER(ctypes.c_char_p)()
1144+
arg_descs = ctypes.POINTER(ctypes.c_char_p)()
1145+
key_var_num_args = ctypes.c_char_p()
1146+
ret_type = ctypes.c_char_p()
1147+
1148+
check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
1149+
op_handle, ctypes.byref(real_name), ctypes.byref(desc),
1150+
ctypes.byref(num_args),
1151+
ctypes.byref(arg_names),
1152+
ctypes.byref(arg_types),
1153+
ctypes.byref(arg_descs),
1154+
ctypes.byref(key_var_num_args),
1155+
ctypes.byref(ret_type)))
1156+
1157+
narg = int(num_args.value)
1158+
arg_names = [py_str(arg_names[i]) for i in range(narg)]
1159+
arg_types = [py_str(arg_types[i]) for i in range(narg)]
1160+
return OperatorArguments(narg, arg_names, arg_types)

tests/python/unittest/test_operator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@
2727
from distutils.version import LooseVersion
2828
from numpy.testing import assert_allclose, assert_array_equal
2929
from mxnet.test_utils import *
30+
from mxnet.operator import *
3031
from mxnet.base import py_str, MXNetError, _as_list
3132
from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
3233
from common import run_in_spawned_process
33-
from nose.tools import assert_raises
34+
from nose.tools import assert_raises, ok_
3435
import unittest
3536
import os
3637

@@ -8655,6 +8656,17 @@ def test_add_n():
86558656
assert_almost_equal(rslt.asnumpy(), add_n_rslt.asnumpy(), atol=1e-5)
86568657

86578658

8659+
def test_get_all_registered_operators():
8660+
ops = get_all_registered_operators()
8661+
ok_(isinstance(ops, list))
8662+
ok_(len(ops) > 0)
8663+
8664+
8665+
def test_get_operator_arguments():
8666+
operator_arguments = get_operator_arguments(mx.operator.get_all_registered_operators()[0])
8667+
ok_(isinstance(operator_arguments, OperatorArguments))
8668+
8669+
86588670
if __name__ == '__main__':
86598671
import nose
86608672
nose.runmodule()

0 commit comments

Comments
 (0)