Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 124a282

Browse files
author
Rohit Kumar Srivastava
committed
caching results of runtime features and minor refactoring
1 parent 3c1d76b commit 124a282

File tree

7 files changed

+281
-279
lines changed

7 files changed

+281
-279
lines changed

include/mxnet/c_api.h

Lines changed: 117 additions & 117 deletions
Large diffs are not rendered by default.

include/mxnet/tuple.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ class Tuple {
374374
}
375375
};
376376

377+
377378
/*! brief check if a shape's ndim is known. */
378379
inline bool ndim_is_known(const int ndim) {
379380
CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim;

python/mxnet/ndarray/ndarray.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107
_NDARRAY_BASIC_INDEXING = 0
108108
_NDARRAY_ADVANCED_INDEXING = 1
109109

110+
# Caching whether MXNet was built with INT64 support or not
111+
_INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
110112

111113
def _new_empty_handle():
112114
"""Returns a new empty handle.
@@ -134,8 +136,8 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
134136
A new empty `NDArray` handle.
135137
"""
136138
hdl = NDArrayHandle()
137-
if Features().is_enabled('INT64_TENSOR_SIZE') and sys.version_info[0] > 2:
138-
check_call(_LIB.MXNDArrayCreateExInt64(
139+
if sys.version_info[0] > 2 and _INT64_TENSOR_SIZE_ENABLED:
140+
check_call(_LIB.MXNDArrayCreateEx64(
139141
c_array_buf(mx_int64, native_array('q', shape)),
140142
ctypes.c_int(len(shape)),
141143
ctypes.c_int(ctx.device_typeid),
@@ -2230,9 +2232,9 @@ def shape(self):
22302232
(2L, 3L, 4L)
22312233
"""
22322234
ndim = mx_int()
2233-
if sys.version_info[0] > 2 and Features().is_enabled('INT64_TENSOR_SIZE'):
2235+
if _INT64_TENSOR_SIZE_ENABLED:
22342236
pdata = ctypes.POINTER(mx_int64)()
2235-
check_call(_LIB.MXNDArrayGetShapeExInt64(
2237+
check_call(_LIB.MXNDArrayGetShapeEx64(
22362238
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
22372239
else:
22382240
pdata = ctypes.POINTER(mx_int)()

python/mxnet/symbol/symbol.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@
3939
from ..base import check_call, MXNetError, NotImplementedForSymbol
4040
from ..context import Context, current_context
4141
from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
42-
from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
42+
from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _INT64_TENSOR_SIZE_ENABLED
4343
from ..ndarray import _ndarray_cls
44-
from ..runtime import Features
4544
from ..executor import Executor
4645
from . import _internal
4746
from . import op
@@ -1213,14 +1212,14 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
12131212
aux_shape_size = mx_uint()
12141213
aux_shape_ndim = ctypes.POINTER(mx_int)()
12151214
complete = ctypes.c_int()
1216-
if Features().is_enabled('INT64_TENSOR_SIZE') and sys.version_info[0] > 2:
1215+
if sys.version_info[0] > 2 and _INT64_TENSOR_SIZE_ENABLED:
12171216
arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
12181217
out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
12191218
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
12201219
if partial:
1221-
infer_func = _LIB.MXSymbolInferShapePartialExInt64
1220+
infer_func = _LIB.MXSymbolInferShapePartialEx64
12221221
else:
1223-
infer_func = _LIB.MXSymbolInferShapeExInt64
1222+
infer_func = _LIB.MXSymbolInferShapeEx64
12241223
check_call(infer_func(
12251224
self.handle,
12261225
mx_uint(len(indptr) - 1),

src/c_api/c_api.cc

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,13 @@ int MXNDArrayCreateNone(NDArrayHandle *out) {
190190
}
191191

192192
template<typename DataType, typename dimtype>
193-
void CreateArray(const DataType* shape, dimtype ndim, int dev_type, int dev_id, int delay_alloc,
194-
int dtype, NDArrayHandle* out) {
193+
void CreateNDArray(const DataType* shape,
194+
dimtype ndim,
195+
int dev_type,
196+
int dev_id,
197+
int delay_alloc,
198+
int dtype,
199+
NDArrayHandle* out) {
195200
*out = new NDArray(mxnet::TShape(shape, shape + ndim),
196201
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
197202
delay_alloc != 0, dtype);
@@ -210,15 +215,15 @@ int MXNDArrayCreate(const mx_uint *shape,
210215
API_END();
211216
}
212217

213-
int MXNDArrayCreateExInt64(const mx_int64 *shape,
214-
int ndim,
215-
int dev_type,
216-
int dev_id,
217-
int delay_alloc,
218-
int dtype,
219-
NDArrayHandle *out) {
218+
int MXNDArrayCreateEx64(const mx_int64 *shape,
219+
int ndim,
220+
int dev_type,
221+
int dev_id,
222+
int delay_alloc,
223+
int dtype,
224+
NDArrayHandle *out) {
220225
API_BEGIN();
221-
CreateArray<mx_int64, int>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
226+
CreateNDArray<mx_int64, int>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
222227
API_END();
223228
}
224229

@@ -230,11 +235,7 @@ int MXNDArrayCreateEx(const mx_uint *shape,
230235
int dtype,
231236
NDArrayHandle *out) {
232237
API_BEGIN();
233-
*out = new NDArray(
234-
mxnet::TShape(shape, shape + ndim),
235-
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
236-
delay_alloc != 0,
237-
dtype);
238+
CreateNDArray<mx_uint, mx_uint>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
238239
API_END();
239240
}
240241

@@ -558,7 +559,7 @@ int MXNDArrayGetShape(NDArrayHandle handle,
558559

559560
template<typename dtype>
560561
inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim,
561-
MXAPIThreadLocalEntry<dtype>* ret) {
562+
MXAPIThreadLocalEntry<dtype>* ret) {
562563
NDArray* arr = static_cast<NDArray*>(handle);
563564
if (!arr->is_none()) {
564565
mxnet::TShape s = arr->shape();
@@ -590,9 +591,9 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
590591
API_END();
591592
}
592593

593-
int MXNDArrayGetShapeExInt64(NDArrayHandle handle,
594-
int *out_dim,
595-
const mx_int64 **out_pdata) {
594+
int MXNDArrayGetShapeEx64(NDArrayHandle handle,
595+
int *out_dim,
596+
const mx_int64 **out_pdata) {
596597
MXAPIThreadLocalEntry<int64_t> *ret = MXAPIThreadLocalStore<int64_t>::Get();
597598
API_BEGIN();
598599
GetShape<mx_int64>(handle, out_pdata, out_dim, ret);

0 commit comments

Comments
 (0)