Skip to content

Commit ca50184

Browse files
access2rohit (Rohit Kumar Srivastava)
authored and committed
Large Index Support for Slice (apache#15593)
* Adding Large Index Support for slice operator
* adding changes to fix py2 related error in CI/CD
* fixing base.py
* rearrange system call and slower Feature() call
* refactoring c_api, c_symbolic_api, c_api_common
* templatizing code
* caching results of runtime features and minor refactoring
* fixing local caching in ndarray shape
1 parent c0d48f0 commit ca50184

File tree

17 files changed

+550
-184
lines changed

17 files changed

+550
-184
lines changed

include/mxnet/c_api.h

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ extern "C" {
5555
#endif
5656

5757
/*! \brief manually define unsigned int */
58-
typedef unsigned int mx_uint;
58+
typedef uint32_t mx_uint;
59+
/*! \brief manually define 64-bit int */
60+
typedef int64_t mx_int64;
5961
/*! \brief manually define float */
6062
typedef float mx_float;
6163
/*! \brief data type to store dim size */
@@ -572,6 +574,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
572574
int dtype,
573575
NDArrayHandle *out);
574576

577+
MXNET_DLL int MXNDArrayCreateEx64(const mx_int64 *shape,
578+
int ndim,
579+
int dev_type,
580+
int dev_id,
581+
int delay_alloc,
582+
int dtype,
583+
NDArrayHandle *out);
575584

576585
/*!
577586
* \brief create an empty sparse NDArray with specified shape and data type
@@ -603,6 +612,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
603612
const mx_uint *aux_shape,
604613
NDArrayHandle *out);
605614

615+
MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type,
616+
const mx_int64 *shape,
617+
int ndim,
618+
int dev_type,
619+
int dev_id,
620+
int delay_alloc,
621+
int dtype,
622+
mx_uint num_aux,
623+
int *aux_type,
624+
int *aux_ndims,
625+
const mx_int64 *aux_shape,
626+
NDArrayHandle *out);
627+
606628
/*!
607629
* \brief create a NDArray handle that is loaded from raw bytes.
608630
* \param buf the head of the raw bytes
@@ -650,6 +672,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
650672
mx_uint *out_name_size,
651673
const char*** out_names);
652674

675+
MXNET_DLL int MXNDArrayLoad64(const char* fname,
676+
mx_int64 *out_size,
677+
NDArrayHandle** out_arr,
678+
mx_int64 *out_name_size,
679+
const char*** out_names);
680+
653681
/*!
654682
* \brief Load a list / dictionary of NDArrays from file content loaded into memory.
655683
* This will load a list of ndarrays in a similar
@@ -665,11 +693,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
665693
* \return 0 when success, -1 when failure happens
666694
*/
667695
MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
668-
size_t size,
669-
mx_uint *out_size,
670-
NDArrayHandle** out_arr,
671-
mx_uint *out_name_size,
672-
const char*** out_names);
696+
size_t size,
697+
mx_uint *out_size,
698+
NDArrayHandle** out_arr,
699+
mx_uint *out_name_size,
700+
const char*** out_names);
701+
702+
MXNET_DLL int MXNDArrayLoadFromBuffer64(const void *ndarray_buffer,
703+
size_t size,
704+
mx_int64 *out_size,
705+
NDArrayHandle** out_arr,
706+
mx_int64 *out_name_size,
707+
const char*** out_names);
673708

674709
/*!
675710
* \brief Perform a synchronous copy from a contiguous CPU memory region.
@@ -809,6 +844,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
809844
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
810845
mx_uint *out_dim,
811846
const mx_uint **out_pdata);
847+
848+
MXNET_DLL int MXNDArrayGetShape64(NDArrayHandle handle,
849+
int *out_dim,
850+
const int64_t **out_pdata);
851+
812852
/*!
813853
* \brief get the shape of the array
814854
* \param handle the handle to the NDArray
@@ -819,6 +859,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
819859
MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
820860
int *out_dim,
821861
const int **out_pdata);
862+
863+
MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle,
864+
int *out_dim,
865+
const mx_int64 **out_pdata);
866+
822867
/*!
823868
* \brief get the content of the data in NDArray
824869
* \param handle the handle to the ndarray
@@ -902,6 +947,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
902947
mx_uint i,
903948
int *out_type);
904949

950+
MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle,
951+
mx_int64 i,
952+
int *out_type);
953+
905954
/*!
906955
* \brief Get a deep copy of the ith aux data blob
907956
* in the form of an NDArray of default storage type.
@@ -911,6 +960,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
911960
mx_uint i,
912961
NDArrayHandle *out);
913962

963+
MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle,
964+
mx_int64 i,
965+
NDArrayHandle *out);
966+
914967
/*!
915968
* \brief Get a deep copy of the data blob
916969
* in the form of an NDArray of default storage type.
@@ -966,6 +1019,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
9661019
*/
9671020
MXNET_DLL int MXListFunctions(mx_uint *out_size,
9681021
FunctionHandle **out_array);
1022+
1023+
MXNET_DLL int MXListFunctions64(mx_int64 *out_size,
1024+
FunctionHandle **out_array);
1025+
9691026
/*!
9701027
* \brief get the function handle by name
9711028
* \param name the name of the function
@@ -1233,6 +1290,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
12331290
*/
12341291
MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
12351292
const char ***out_array);
1293+
1294+
MXNET_DLL int MXListAllOpNames64(mx_int64 *out_size,
1295+
const char ***out_array);
1296+
12361297
/*!
12371298
* \brief list all the available AtomicSymbolEntry
12381299
* \param out_size the size of returned array
@@ -1242,6 +1303,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
12421303
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
12431304
AtomicSymbolCreator **out_array);
12441305

1306+
MXNET_DLL int MXSymbolListAtomicSymbolCreators64(mx_int64 *out_size,
1307+
AtomicSymbolCreator **out_array);
1308+
12451309
/*!
12461310
* \brief Get the name of an atomic symbol.
12471311
* \param creator the AtomicSymbolCreator.
@@ -1454,6 +1518,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
14541518
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
14551519
mx_uint *out_size,
14561520
const char ***out_str_array);
1521+
1522+
MXNET_DLL int MXSymbolListArguments64(SymbolHandle symbol,
1523+
size_t *out_size,
1524+
const char ***out_str_array);
1525+
14571526
/*!
14581527
* \brief List returns in the symbol.
14591528
* \param symbol the symbol
@@ -1465,14 +1534,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
14651534
mx_uint *out_size,
14661535
const char ***out_str_array);
14671536

1537+
MXNET_DLL int MXSymbolListOutputs64(SymbolHandle symbol,
1538+
size_t *out_size,
1539+
const char ***out_str_array);
1540+
14681541
/*!
14691542
* \brief Get number of outputs of the symbol.
14701543
* \param symbol The symbol
14711544
* \param out_size number of outputs
14721545
* \return 0 when success, -1 when failure happens
14731546
*/
14741547
MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
1475-
mx_uint *output_count);
1548+
mx_uint *output_count);
14761549

14771550
/*!
14781551
* \brief Get a symbol that contains all the internals.
@@ -1511,6 +1584,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
15111584
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
15121585
mx_uint *out_size,
15131586
const char ***out_str_array);
1587+
1588+
MXNET_DLL int MXSymbolListAuxiliaryStates64(SymbolHandle symbol,
1589+
size_t *out_size,
1590+
const char ***out_str_array);
1591+
15141592
/*!
15151593
* \brief Compose the symbol on other symbols.
15161594
*
@@ -1582,6 +1660,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
15821660
const mx_uint ***aux_shape_data,
15831661
int *complete);
15841662

1663+
MXNET_DLL int MXSymbolInferShape64(SymbolHandle sym,
1664+
mx_uint num_args,
1665+
const char** keys,
1666+
const mx_int64 *arg_ind_ptr,
1667+
const mx_int64 *arg_shape_data,
1668+
size_t *in_shape_size,
1669+
const int **in_shape_ndim,
1670+
const mx_int64 ***in_shape_data,
1671+
size_t *out_shape_size,
1672+
const int **out_shape_ndim,
1673+
const mx_int64 ***out_shape_data,
1674+
size_t *aux_shape_size,
1675+
const int **aux_shape_ndim,
1676+
const mx_int64 ***aux_shape_data,
1677+
int *complete);
1678+
15851679
/*!
15861680
* \brief infer shape of unknown input shapes given the known one.
15871681
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1619,6 +1713,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
16191713
const int **aux_shape_ndim,
16201714
const int ***aux_shape_data,
16211715
int *complete);
1716+
1717+
MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym,
1718+
mx_uint num_args,
1719+
const char** keys,
1720+
const mx_int64 *arg_ind_ptr,
1721+
const mx_int64 *arg_shape_data,
1722+
size_t *in_shape_size,
1723+
const int **in_shape_ndim,
1724+
const mx_int64 ***in_shape_data,
1725+
size_t *out_shape_size,
1726+
const int **out_shape_ndim,
1727+
const mx_int64 ***out_shape_data,
1728+
size_t *aux_shape_size,
1729+
const int **aux_shape_ndim,
1730+
const mx_int64 ***aux_shape_data,
1731+
int *complete);
1732+
16221733
/*!
16231734
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
16241735
* partially infer shape of unknown input shapes given the known one.
@@ -1660,6 +1771,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
16601771
const mx_uint ***aux_shape_data,
16611772
int *complete);
16621773

1774+
MXNET_DLL int MXSymbolInferShapePartial64(SymbolHandle sym,
1775+
mx_uint num_args,
1776+
const char** keys,
1777+
const mx_int64 *arg_ind_ptr,
1778+
const mx_int64 *arg_shape_data,
1779+
size_t *in_shape_size,
1780+
const int **in_shape_ndim,
1781+
const mx_int64 ***in_shape_data,
1782+
size_t *out_shape_size,
1783+
const int **out_shape_ndim,
1784+
const mx_int64 ***out_shape_data,
1785+
size_t *aux_shape_size,
1786+
const int **aux_shape_ndim,
1787+
const mx_int64 ***aux_shape_data,
1788+
int *complete);
16631789

16641790
/*!
16651791
* \brief partially infer shape of unknown input shapes given the known one.
@@ -1701,6 +1827,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
17011827
const int ***aux_shape_data,
17021828
int *complete);
17031829

1830+
MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym,
1831+
mx_uint num_args,
1832+
const char** keys,
1833+
const mx_int64 *arg_ind_ptr,
1834+
const mx_int64 *arg_shape_data,
1835+
size_t *in_shape_size,
1836+
const int **in_shape_ndim,
1837+
const mx_int64 ***in_shape_data,
1838+
size_t *out_shape_size,
1839+
const int **out_shape_ndim,
1840+
const mx_int64 ***out_shape_data,
1841+
size_t *aux_shape_size,
1842+
const int **aux_shape_ndim,
1843+
const mx_int64 ***aux_shape_data,
1844+
int *complete);
1845+
17041846
/*!
17051847
* \brief infer type of unknown input types given the known one.
17061848
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data

include/mxnet/c_predict_api.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ extern "C" {
4242
#endif
4343

4444
/*! \brief manually define unsigned int */
45-
typedef unsigned int mx_uint;
45+
typedef uint32_t mx_uint;
46+
/*! \brief manually define 64-bit int */
47+
typedef int64_t mx_int64;
4648
/*! \brief manually define float */
4749
typedef float mx_float;
4850
/*! \brief handle to Predictor */

python/mxnet/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _load_lib():
215215
# type definitions
216216
mx_int = ctypes.c_int
217217
mx_uint = ctypes.c_uint
218+
mx_int64 = ctypes.c_int64
218219
mx_float = ctypes.c_float
219220
mx_float_p = ctypes.POINTER(mx_float)
220221
mx_real_t = _np.float32

python/mxnet/ndarray/ndarray.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@
3333
import warnings
3434
import operator
3535
from functools import reduce # pylint: disable=redefined-builtin
36+
import sys
3637
import numpy as np
3738
from ..base import _LIB, numeric_types, integer_types
3839
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
39-
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
40+
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
4041
from ..base import ctypes2buffer
42+
from ..runtime import Features
4143
from ..context import Context, current_context
4244
from . import _internal
4345
from . import op
@@ -105,6 +107,14 @@
105107
_NDARRAY_BASIC_INDEXING = 0
106108
_NDARRAY_ADVANCED_INDEXING = 1
107109

110+
# Caching whether MXNet was built with INT64 support or not
111+
_INT64_TENSOR_SIZE_ENABLED = None
112+
113+
def _int64_enabled():
114+
global _INT64_TENSOR_SIZE_ENABLED
115+
if _INT64_TENSOR_SIZE_ENABLED is None:
116+
_INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
117+
return _INT64_TENSOR_SIZE_ENABLED
108118

109119
def _new_empty_handle():
110120
"""Returns a new empty handle.
@@ -132,14 +142,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
132142
A new empty `NDArray` handle.
133143
"""
134144
hdl = NDArrayHandle()
135-
check_call(_LIB.MXNDArrayCreateEx(
136-
c_array_buf(mx_uint, native_array('I', shape)),
137-
mx_uint(len(shape)),
138-
ctypes.c_int(ctx.device_typeid),
139-
ctypes.c_int(ctx.device_id),
140-
ctypes.c_int(int(delay_alloc)),
141-
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
142-
ctypes.byref(hdl)))
145+
if sys.version_info[0] > 2 and _int64_enabled():
146+
check_call(_LIB.MXNDArrayCreateEx64(
147+
c_array_buf(mx_int64, native_array('q', shape)),
148+
ctypes.c_int(len(shape)),
149+
ctypes.c_int(ctx.device_typeid),
150+
ctypes.c_int(ctx.device_id),
151+
ctypes.c_int(int(delay_alloc)),
152+
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
153+
ctypes.byref(hdl)))
154+
else:
155+
check_call(_LIB.MXNDArrayCreateEx(
156+
c_array_buf(mx_uint, native_array('I', shape)),
157+
mx_uint(len(shape)),
158+
ctypes.c_int(ctx.device_typeid),
159+
ctypes.c_int(ctx.device_id),
160+
ctypes.c_int(int(delay_alloc)),
161+
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
162+
ctypes.byref(hdl)))
143163
return hdl
144164

145165

@@ -2118,9 +2138,14 @@ def shape(self):
21182138
(2L, 3L, 4L)
21192139
"""
21202140
ndim = mx_int()
2121-
pdata = ctypes.POINTER(mx_int)()
2122-
check_call(_LIB.MXNDArrayGetShapeEx(
2123-
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
2141+
if _int64_enabled():
2142+
pdata = ctypes.POINTER(mx_int64)()
2143+
check_call(_LIB.MXNDArrayGetShapeEx64(
2144+
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
2145+
else:
2146+
pdata = ctypes.POINTER(mx_int)()
2147+
check_call(_LIB.MXNDArrayGetShapeEx(
2148+
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
21242149
if ndim.value == -1:
21252150
return None
21262151
else:

0 commit comments

Comments
 (0)