Skip to content

Commit 9aa5cad

Browse files
access2rohit, Rohit Kumar Srivastava
authored and committed
Large Index Support for Slice (apache#15593)
* Adding Large Index Support for slice operator
* adding changes to fix py2 related error in CI/CD
* fixing base.py
* rearrange system call and slower Feature() call
* refactoring c_api, c_symbolic_api, c_api_common
* templatizing code
* caching results of runtime features and minor refactoring
* fixing local caching in ndarray shape
1 parent 2a2a9d7 commit 9aa5cad

File tree

17 files changed

+550
-184
lines changed

17 files changed

+550
-184
lines changed

include/mxnet/c_api.h

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ extern "C" {
5555
#endif
5656

5757
/*! \brief manually define unsigned int */
58-
typedef unsigned int mx_uint;
58+
typedef uint32_t mx_uint;
59+
/*! \brief manually define 64-bit int */
60+
typedef int64_t mx_int64;
5961
/*! \brief manually define float */
6062
typedef float mx_float;
6163
/*! \brief data type to store dim size */
@@ -556,6 +558,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
556558
int dtype,
557559
NDArrayHandle *out);
558560

561+
MXNET_DLL int MXNDArrayCreateEx64(const mx_int64 *shape,
562+
int ndim,
563+
int dev_type,
564+
int dev_id,
565+
int delay_alloc,
566+
int dtype,
567+
NDArrayHandle *out);
559568

560569
/*!
561570
* \brief create an empty sparse NDArray with specified shape and data type
@@ -587,6 +596,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
587596
const mx_uint *aux_shape,
588597
NDArrayHandle *out);
589598

599+
MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type,
600+
const mx_int64 *shape,
601+
int ndim,
602+
int dev_type,
603+
int dev_id,
604+
int delay_alloc,
605+
int dtype,
606+
mx_uint num_aux,
607+
int *aux_type,
608+
int *aux_ndims,
609+
const mx_int64 *aux_shape,
610+
NDArrayHandle *out);
611+
590612
/*!
591613
* \brief create a NDArray handle that is loaded from raw bytes.
592614
* \param buf the head of the raw bytes
@@ -634,6 +656,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
634656
mx_uint *out_name_size,
635657
const char*** out_names);
636658

659+
MXNET_DLL int MXNDArrayLoad64(const char* fname,
660+
mx_int64 *out_size,
661+
NDArrayHandle** out_arr,
662+
mx_int64 *out_name_size,
663+
const char*** out_names);
664+
637665
/*!
638666
* \brief Load list / dictionary of ndarrays from file content loaded into memory.
639667
* This will load a list of ndarrays in a similar
@@ -649,11 +677,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
649677
* \return 0 when success, -1 when failure happens
650678
*/
651679
MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
652-
size_t size,
653-
mx_uint *out_size,
654-
NDArrayHandle** out_arr,
655-
mx_uint *out_name_size,
656-
const char*** out_names);
680+
size_t size,
681+
mx_uint *out_size,
682+
NDArrayHandle** out_arr,
683+
mx_uint *out_name_size,
684+
const char*** out_names);
685+
686+
MXNET_DLL int MXNDArrayLoadFromBuffer64(const void *ndarray_buffer,
687+
size_t size,
688+
mx_int64 *out_size,
689+
NDArrayHandle** out_arr,
690+
mx_int64 *out_name_size,
691+
const char*** out_names);
657692

658693
/*!
659694
* \brief Perform a synchronous copy from a contiguous CPU memory region.
@@ -793,6 +828,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
793828
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
794829
mx_uint *out_dim,
795830
const mx_uint **out_pdata);
831+
832+
MXNET_DLL int MXNDArrayGetShape64(NDArrayHandle handle,
833+
int *out_dim,
834+
const int64_t **out_pdata);
835+
796836
/*!
797837
* \brief get the shape of the array
798838
* \param handle the handle to the ndarray
@@ -803,6 +843,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
803843
MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
804844
int *out_dim,
805845
const int **out_pdata);
846+
847+
MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle,
848+
int *out_dim,
849+
const mx_int64 **out_pdata);
850+
806851
/*!
807852
* \brief get the content of the data in NDArray
808853
* \param handle the handle to the ndarray
@@ -886,6 +931,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
886931
mx_uint i,
887932
int *out_type);
888933

934+
MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle,
935+
mx_int64 i,
936+
int *out_type);
937+
889938
/*!
890939
* \brief Get a deep copy of the ith aux data blob
891940
* in the form of an NDArray of default storage type.
@@ -895,6 +944,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
895944
mx_uint i,
896945
NDArrayHandle *out);
897946

947+
MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle,
948+
mx_int64 i,
949+
NDArrayHandle *out);
950+
898951
/*!
899952
* \brief Get a deep copy of the data blob
900953
* in the form of an NDArray of default storage type.
@@ -950,6 +1003,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
9501003
*/
9511004
MXNET_DLL int MXListFunctions(mx_uint *out_size,
9521005
FunctionHandle **out_array);
1006+
1007+
MXNET_DLL int MXListFunctions64(mx_int64 *out_size,
1008+
FunctionHandle **out_array);
1009+
9531010
/*!
9541011
* \brief get the function handle by name
9551012
* \param name the name of the function
@@ -1217,6 +1274,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
12171274
*/
12181275
MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
12191276
const char ***out_array);
1277+
1278+
MXNET_DLL int MXListAllOpNames64(mx_int64 *out_size,
1279+
const char ***out_array);
1280+
12201281
/*!
12211282
* \brief list all the available AtomicSymbolEntry
12221283
* \param out_size the size of returned array
@@ -1226,6 +1287,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
12261287
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
12271288
AtomicSymbolCreator **out_array);
12281289

1290+
MXNET_DLL int MXSymbolListAtomicSymbolCreators64(mx_int64 *out_size,
1291+
AtomicSymbolCreator **out_array);
1292+
12291293
/*!
12301294
* \brief Get the name of an atomic symbol.
12311295
* \param creator the AtomicSymbolCreator.
@@ -1438,6 +1502,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
14381502
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
14391503
mx_uint *out_size,
14401504
const char ***out_str_array);
1505+
1506+
MXNET_DLL int MXSymbolListArguments64(SymbolHandle symbol,
1507+
size_t *out_size,
1508+
const char ***out_str_array);
1509+
14411510
/*!
14421511
* \brief List returns in the symbol.
14431512
* \param symbol the symbol
@@ -1449,14 +1518,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
14491518
mx_uint *out_size,
14501519
const char ***out_str_array);
14511520

1521+
MXNET_DLL int MXSymbolListOutputs64(SymbolHandle symbol,
1522+
size_t *out_size,
1523+
const char ***out_str_array);
1524+
14521525
/*!
14531526
* \brief Get number of outputs of the symbol.
14541527
* \param symbol The symbol
14551528
* \param out_size number of outputs
14561529
* \return 0 when success, -1 when failure happens
14571530
*/
14581531
MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
1459-
mx_uint *output_count);
1532+
mx_uint *output_count);
14601533

14611534
/*!
14621535
* \brief Get a symbol that contains all the internals.
@@ -1495,6 +1568,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
14951568
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
14961569
mx_uint *out_size,
14971570
const char ***out_str_array);
1571+
1572+
MXNET_DLL int MXSymbolListAuxiliaryStates64(SymbolHandle symbol,
1573+
size_t *out_size,
1574+
const char ***out_str_array);
1575+
14981576
/*!
14991577
* \brief Compose the symbol on other symbols.
15001578
*
@@ -1566,6 +1644,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
15661644
const mx_uint ***aux_shape_data,
15671645
int *complete);
15681646

1647+
MXNET_DLL int MXSymbolInferShape64(SymbolHandle sym,
1648+
mx_uint num_args,
1649+
const char** keys,
1650+
const mx_int64 *arg_ind_ptr,
1651+
const mx_int64 *arg_shape_data,
1652+
size_t *in_shape_size,
1653+
const int **in_shape_ndim,
1654+
const mx_int64 ***in_shape_data,
1655+
size_t *out_shape_size,
1656+
const int **out_shape_ndim,
1657+
const mx_int64 ***out_shape_data,
1658+
size_t *aux_shape_size,
1659+
const int **aux_shape_ndim,
1660+
const mx_int64 ***aux_shape_data,
1661+
int *complete);
1662+
15691663
/*!
15701664
* \brief infer shape of unknown input shapes given the known one.
15711665
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1603,6 +1697,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
16031697
const int **aux_shape_ndim,
16041698
const int ***aux_shape_data,
16051699
int *complete);
1700+
1701+
MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym,
1702+
mx_uint num_args,
1703+
const char** keys,
1704+
const mx_int64 *arg_ind_ptr,
1705+
const mx_int64 *arg_shape_data,
1706+
size_t *in_shape_size,
1707+
const int **in_shape_ndim,
1708+
const mx_int64 ***in_shape_data,
1709+
size_t *out_shape_size,
1710+
const int **out_shape_ndim,
1711+
const mx_int64 ***out_shape_data,
1712+
size_t *aux_shape_size,
1713+
const int **aux_shape_ndim,
1714+
const mx_int64 ***aux_shape_data,
1715+
int *complete);
1716+
16061717
/*!
16071718
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
16081719
* partially infer shape of unknown input shapes given the known one.
@@ -1644,6 +1755,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
16441755
const mx_uint ***aux_shape_data,
16451756
int *complete);
16461757

1758+
MXNET_DLL int MXSymbolInferShapePartial64(SymbolHandle sym,
1759+
mx_uint num_args,
1760+
const char** keys,
1761+
const mx_int64 *arg_ind_ptr,
1762+
const mx_int64 *arg_shape_data,
1763+
size_t *in_shape_size,
1764+
const int **in_shape_ndim,
1765+
const mx_int64 ***in_shape_data,
1766+
size_t *out_shape_size,
1767+
const int **out_shape_ndim,
1768+
const mx_int64 ***out_shape_data,
1769+
size_t *aux_shape_size,
1770+
const int **aux_shape_ndim,
1771+
const mx_int64 ***aux_shape_data,
1772+
int *complete);
16471773

16481774
/*!
16491775
* \brief partially infer shape of unknown input shapes given the known one.
@@ -1685,6 +1811,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
16851811
const int ***aux_shape_data,
16861812
int *complete);
16871813

1814+
MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym,
1815+
mx_uint num_args,
1816+
const char** keys,
1817+
const mx_int64 *arg_ind_ptr,
1818+
const mx_int64 *arg_shape_data,
1819+
size_t *in_shape_size,
1820+
const int **in_shape_ndim,
1821+
const mx_int64 ***in_shape_data,
1822+
size_t *out_shape_size,
1823+
const int **out_shape_ndim,
1824+
const mx_int64 ***out_shape_data,
1825+
size_t *aux_shape_size,
1826+
const int **aux_shape_ndim,
1827+
const mx_int64 ***aux_shape_data,
1828+
int *complete);
1829+
16881830
/*!
16891831
* \brief infer type of unknown input types given the known one.
16901832
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data

include/mxnet/c_predict_api.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ extern "C" {
4242
#endif
4343

4444
/*! \brief manually define unsigned int */
45-
typedef unsigned int mx_uint;
45+
typedef uint32_t mx_uint;
46+
/*! \brief manually define 64-bit int */
47+
typedef int64_t mx_int64;
4648
/*! \brief manually define float */
4749
typedef float mx_float;
4850
/*! \brief handle to Predictor */

python/mxnet/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _load_lib():
215215
# type definitions
216216
mx_int = ctypes.c_int
217217
mx_uint = ctypes.c_uint
218+
mx_int64 = ctypes.c_int64
218219
mx_float = ctypes.c_float
219220
mx_float_p = ctypes.POINTER(mx_float)
220221
mx_real_t = _np.float32

python/mxnet/ndarray/ndarray.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@
3333
import warnings
3434
import operator
3535
from functools import reduce # pylint: disable=redefined-builtin
36+
import sys
3637
import numpy as np
3738
from ..base import _LIB, numeric_types, integer_types
3839
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
39-
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
40+
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
4041
from ..base import ctypes2buffer
42+
from ..runtime import Features
4143
from ..context import Context, current_context
4244
from . import _internal
4345
from . import op
@@ -105,6 +107,14 @@
105107
_NDARRAY_BASIC_INDEXING = 0
106108
_NDARRAY_ADVANCED_INDEXING = 1
107109

110+
# Caching whether MXNet was built with INT64 support or not
111+
_INT64_TENSOR_SIZE_ENABLED = None
112+
113+
def _int64_enabled():
114+
global _INT64_TENSOR_SIZE_ENABLED
115+
if _INT64_TENSOR_SIZE_ENABLED is None:
116+
_INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
117+
return _INT64_TENSOR_SIZE_ENABLED
108118

109119
def _new_empty_handle():
110120
"""Returns a new empty handle.
@@ -132,14 +142,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
132142
A new empty `NDArray` handle.
133143
"""
134144
hdl = NDArrayHandle()
135-
check_call(_LIB.MXNDArrayCreateEx(
136-
c_array_buf(mx_uint, native_array('I', shape)),
137-
mx_uint(len(shape)),
138-
ctypes.c_int(ctx.device_typeid),
139-
ctypes.c_int(ctx.device_id),
140-
ctypes.c_int(int(delay_alloc)),
141-
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
142-
ctypes.byref(hdl)))
145+
if sys.version_info[0] > 2 and _int64_enabled():
146+
check_call(_LIB.MXNDArrayCreateEx64(
147+
c_array_buf(mx_int64, native_array('q', shape)),
148+
ctypes.c_int(len(shape)),
149+
ctypes.c_int(ctx.device_typeid),
150+
ctypes.c_int(ctx.device_id),
151+
ctypes.c_int(int(delay_alloc)),
152+
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
153+
ctypes.byref(hdl)))
154+
else:
155+
check_call(_LIB.MXNDArrayCreateEx(
156+
c_array_buf(mx_uint, native_array('I', shape)),
157+
mx_uint(len(shape)),
158+
ctypes.c_int(ctx.device_typeid),
159+
ctypes.c_int(ctx.device_id),
160+
ctypes.c_int(int(delay_alloc)),
161+
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
162+
ctypes.byref(hdl)))
143163
return hdl
144164

145165

@@ -2118,9 +2138,14 @@ def shape(self):
21182138
(2L, 3L, 4L)
21192139
"""
21202140
ndim = mx_int()
2121-
pdata = ctypes.POINTER(mx_int)()
2122-
check_call(_LIB.MXNDArrayGetShapeEx(
2123-
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
2141+
if _int64_enabled():
2142+
pdata = ctypes.POINTER(mx_int64)()
2143+
check_call(_LIB.MXNDArrayGetShapeEx64(
2144+
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
2145+
else:
2146+
pdata = ctypes.POINTER(mx_int)()
2147+
check_call(_LIB.MXNDArrayGetShapeEx(
2148+
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
21242149
if ndim.value == -1:
21252150
return None
21262151
else:

0 commit comments

Comments
 (0)