Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 11280e8

Browse files
author
Rohit Kumar Srivastava
committed
Adding Large Index Support for slice operator
1 parent 77254f2 commit 11280e8

File tree

12 files changed

+418
-72
lines changed

12 files changed

+418
-72
lines changed

include/mxnet/c_api.h

Lines changed: 149 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ extern "C" {
5555
#endif
5656

5757
/*! \brief manually define unsigned int */
58-
typedef unsigned int mx_uint;
58+
typedef uint32_t mx_uint;
59+
/*! \brief manually define 64-bit int */
60+
typedef int64_t mx_int64;
5961
/*! \brief manually define float */
6062
typedef float mx_float;
6163
/*! \brief data type to store dim size */
@@ -565,6 +567,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
565567
int dtype,
566568
NDArrayHandle *out);
567569

570+
MXNET_DLL int MXNDArrayCreateExInt64(const mx_int64 *shape,
571+
mx_uint ndim,
572+
int dev_type,
573+
int dev_id,
574+
int delay_alloc,
575+
int dtype,
576+
NDArrayHandle *out);
568577

569578
/*!
570579
* \brief create an empty sparse NDArray with specified shape and data type
@@ -596,6 +605,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
596605
const mx_uint *aux_shape,
597606
NDArrayHandle *out);
598607

608+
MXNET_DLL int MXNDArrayCreateSparseExInt64(int storage_type,
609+
const mx_int64 *shape,
610+
mx_int64 ndim,
611+
int dev_type,
612+
int dev_id,
613+
int delay_alloc,
614+
int dtype,
615+
mx_uint num_aux,
616+
int *aux_type,
617+
mx_uint *aux_ndims,
618+
const mx_uint *aux_shape,
619+
NDArrayHandle *out);
620+
599621
/*!
600622
* \brief create a NDArray handle that is loaded from raw bytes.
601623
* \param buf the head of the raw bytes
@@ -643,6 +665,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
643665
mx_uint *out_name_size,
644666
const char*** out_names);
645667

668+
MXNET_DLL int MXNDArrayLoadInt64(const char* fname,
669+
mx_int64 *out_size,
670+
NDArrayHandle** out_arr,
671+
mx_int64 *out_name_size,
672+
const char*** out_names);
673+
646674
/*!
647675
* \brief Load list / dictionary of narrays from file content loaded into memory.
648676
* This will load a list of ndarrays in a similar
@@ -658,11 +686,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
658686
* \return 0 when success, -1 when failure happens
659687
*/
660688
MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
661-
size_t size,
662-
mx_uint *out_size,
663-
NDArrayHandle** out_arr,
664-
mx_uint *out_name_size,
665-
const char*** out_names);
689+
size_t size,
690+
mx_uint *out_size,
691+
NDArrayHandle** out_arr,
692+
mx_uint *out_name_size,
693+
const char*** out_names);
694+
695+
MXNET_DLL int MXNDArrayLoadFromBufferInt64(const void *ndarray_buffer,
696+
size_t size,
697+
mx_int64 *out_size,
698+
NDArrayHandle** out_arr,
699+
mx_int64 *out_name_size,
700+
const char*** out_names);
666701

667702
/*!
668703
* \brief Perform a synchronize copy from a continugous CPU memory region.
@@ -802,6 +837,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
802837
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
803838
mx_uint *out_dim,
804839
const mx_uint **out_pdata);
840+
841+
MXNET_DLL int MXNDArrayGetShapeInt64(NDArrayHandle handle,
842+
mx_int64 *out_dim,
843+
const mx_int64 **out_pdata);
844+
805845
/*!
806846
* \brief get the shape of the array
807847
* \param handle the handle to the narray
@@ -812,6 +852,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
812852
MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
813853
int *out_dim,
814854
const int **out_pdata);
855+
856+
MXNET_DLL int MXNDArrayGetShapeExInt64(NDArrayHandle handle,
857+
int *out_dim,
858+
const mx_int64 **out_pdata);
859+
815860
/*!
816861
* \brief get the content of the data in NDArray
817862
* \param handle the handle to the ndarray
@@ -895,6 +940,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
895940
mx_uint i,
896941
int *out_type);
897942

943+
MXNET_DLL int MXNDArrayGetAuxTypeInt64(NDArrayHandle handle,
944+
mx_int64 i,
945+
int *out_type);
946+
898947
/*!
899948
* \brief Get a deep copy of the ith aux data blob
900949
* in the form of an NDArray of default storage type.
@@ -904,6 +953,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
904953
mx_uint i,
905954
NDArrayHandle *out);
906955

956+
MXNET_DLL int MXNDArrayGetAuxNDArrayInt64(NDArrayHandle handle,
957+
mx_int64 i,
958+
NDArrayHandle *out);
959+
907960
/*!
908961
* \brief Get a deep copy of the data blob
909962
* in the form of an NDArray of default storage type.
@@ -959,6 +1012,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
9591012
*/
9601013
MXNET_DLL int MXListFunctions(mx_uint *out_size,
9611014
FunctionHandle **out_array);
1015+
1016+
MXNET_DLL int MXListFunctionsInt64(mx_int64 *out_size,
1017+
FunctionHandle **out_array);
1018+
9621019
/*!
9631020
* \brief get the function handle by name
9641021
* \param name the name of the function
@@ -1226,6 +1283,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
12261283
*/
12271284
MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
12281285
const char ***out_array);
1286+
1287+
MXNET_DLL int MXListAllOpNamesInt64(mx_int64 *out_size,
1288+
const char ***out_array);
1289+
12291290
/*!
12301291
* \brief list all the available AtomicSymbolEntry
12311292
* \param out_size the size of returned array
@@ -1235,6 +1296,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
12351296
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
12361297
AtomicSymbolCreator **out_array);
12371298

1299+
MXNET_DLL int MXSymbolListAtomicSymbolCreatorsInt64(mx_int64 *out_size,
1300+
AtomicSymbolCreator **out_array);
1301+
12381302
/*!
12391303
* \brief Get the name of an atomic symbol.
12401304
* \param creator the AtomicSymbolCreator.
@@ -1447,6 +1511,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
14471511
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
14481512
mx_uint *out_size,
14491513
const char ***out_str_array);
1514+
1515+
MXNET_DLL int MXSymbolListArgumentsInt64(SymbolHandle symbol,
1516+
mx_int64 *out_size,
1517+
const char ***out_str_array);
1518+
14501519
/*!
14511520
* \brief List returns in the symbol.
14521521
* \param symbol the symbol
@@ -1458,14 +1527,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
14581527
mx_uint *out_size,
14591528
const char ***out_str_array);
14601529

1530+
MXNET_DLL int MXSymbolListOutputsInt64(SymbolHandle symbol,
1531+
mx_int64 *out_size,
1532+
const char ***out_str_array);
1533+
14611534
/*!
14621535
* \brief Get number of outputs of the symbol.
14631536
* \param symbol The symbol
14641537
* \param out_size number of outputs
14651538
* \return 0 when success, -1 when failure happens
14661539
*/
14671540
MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
1468-
mx_uint *output_count);
1541+
mx_uint *output_count);
14691542

14701543
/*!
14711544
* \brief Get a symbol that contains all the internals.
@@ -1504,6 +1577,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
15041577
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
15051578
mx_uint *out_size,
15061579
const char ***out_str_array);
1580+
1581+
MXNET_DLL int MXSymbolListAuxiliaryStatesInt64(SymbolHandle symbol,
1582+
mx_int64 *out_size,
1583+
const char ***out_str_array);
1584+
15071585
/*!
15081586
* \brief Compose the symbol on other symbols.
15091587
*
@@ -1575,6 +1653,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
15751653
const mx_uint ***aux_shape_data,
15761654
int *complete);
15771655

1656+
MXNET_DLL int MXSymbolInferShapeInt64(SymbolHandle sym,
1657+
mx_uint num_args,
1658+
const char** keys,
1659+
const mx_int64 *arg_ind_ptr,
1660+
const mx_int64 *arg_shape_data,
1661+
mx_int64 *in_shape_size,
1662+
const mx_int64 **in_shape_ndim,
1663+
const mx_int64 ***in_shape_data,
1664+
mx_int64 *out_shape_size,
1665+
const mx_int64 **out_shape_ndim,
1666+
const mx_int64 ***out_shape_data,
1667+
mx_int64 *aux_shape_size,
1668+
const mx_int64 **aux_shape_ndim,
1669+
const mx_int64 ***aux_shape_data,
1670+
int *complete);
1671+
15781672
/*!
15791673
* \brief infer shape of unknown input shapes given the known one.
15801674
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1612,6 +1706,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
16121706
const int **aux_shape_ndim,
16131707
const int ***aux_shape_data,
16141708
int *complete);
1709+
1710+
MXNET_DLL int MXSymbolInferShapeExInt64(SymbolHandle sym,
1711+
mx_uint num_args,
1712+
const char** keys,
1713+
const mx_uint *arg_ind_ptr,
1714+
const int *arg_shape_data,
1715+
mx_uint *in_shape_size,
1716+
const int **in_shape_ndim,
1717+
const int64_t ***in_shape_data,
1718+
mx_uint *out_shape_size,
1719+
const int **out_shape_ndim,
1720+
const int64_t ***out_shape_data,
1721+
mx_uint *aux_shape_size,
1722+
const int **aux_shape_ndim,
1723+
const int64_t ***aux_shape_data,
1724+
int *complete);
1725+
16151726
/*!
16161727
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
16171728
* partially infer shape of unknown input shapes given the known one.
@@ -1653,6 +1764,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
16531764
const mx_uint ***aux_shape_data,
16541765
int *complete);
16551766

1767+
MXNET_DLL int MXSymbolInferShapePartialInt64(SymbolHandle sym,
1768+
mx_uint num_args,
1769+
const char** keys,
1770+
const mx_int64 *arg_ind_ptr,
1771+
const mx_int64 *arg_shape_data,
1772+
mx_int64 *in_shape_size,
1773+
const mx_int64 **in_shape_ndim,
1774+
const mx_int64 ***in_shape_data,
1775+
mx_int64 *out_shape_size,
1776+
const mx_int64 **out_shape_ndim,
1777+
const mx_int64 ***out_shape_data,
1778+
mx_int64 *aux_shape_size,
1779+
const mx_int64 **aux_shape_ndim,
1780+
const mx_int64 ***aux_shape_data,
1781+
int *complete);
16561782

16571783
/*!
16581784
* \brief partially infer shape of unknown input shapes given the known one.
@@ -1694,6 +1820,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
16941820
const int ***aux_shape_data,
16951821
int *complete);
16961822

1823+
MXNET_DLL int MXSymbolInferShapePartialExInt64(SymbolHandle sym,
1824+
mx_uint num_args,
1825+
const char** keys,
1826+
const mx_int64 *arg_ind_ptr,
1827+
const int *arg_shape_data,
1828+
mx_int64 *in_shape_size,
1829+
const int **in_shape_ndim,
1830+
const int ***in_shape_data,
1831+
mx_int64 *out_shape_size,
1832+
const int **out_shape_ndim,
1833+
const int ***out_shape_data,
1834+
mx_int64 *aux_shape_size,
1835+
const int **aux_shape_ndim,
1836+
const int ***aux_shape_data,
1837+
int *complete);
1838+
16971839
/*!
16981840
* \brief infer type of unknown input types given the known one.
16991841
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data

include/mxnet/c_predict_api.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ extern "C" {
4242
#endif
4343

4444
/*! \brief manually define unsigned int */
45-
typedef unsigned int mx_uint;
45+
typedef uint32_t mx_uint;
46+
/*! \brief manually define 64-bit int */
47+
typedef int64_t mx_int64;
4648
/*! \brief manually define float */
4749
typedef float mx_float;
4850
/*! \brief handle to Predictor */

include/mxnet/tuple.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,6 @@ class Tuple {
366366
}
367367
};
368368

369-
370369
/*! brief check if a shape's ndim is known. */
371370
inline bool ndim_is_known(const int ndim) {
372371
CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim;

python/mxnet/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _load_lib():
215215
# type definitions
216216
mx_int = ctypes.c_int
217217
mx_uint = ctypes.c_uint
218+
mx_int64 = ctypes.c_int64
218219
mx_float = ctypes.c_float
219220
mx_float_p = ctypes.POINTER(mx_float)
220221
mx_real_t = _np.float32

python/mxnet/ndarray/ndarray.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
import numpy as np
3636
from ..base import _LIB, numeric_types, integer_types
3737
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
38-
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
38+
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
3939
from ..base import ctypes2buffer
40+
from ..runtime import Features
4041
from ..context import Context, current_context
4142
from . import _internal
4243
from . import op
@@ -131,14 +132,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
131132
A new empty `NDArray` handle.
132133
"""
133134
hdl = NDArrayHandle()
134-
check_call(_LIB.MXNDArrayCreateEx(
135-
c_array_buf(mx_uint, native_array('I', shape)),
136-
mx_uint(len(shape)),
137-
ctypes.c_int(ctx.device_typeid),
138-
ctypes.c_int(ctx.device_id),
139-
ctypes.c_int(int(delay_alloc)),
140-
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
141-
ctypes.byref(hdl)))
135+
if Features().is_enabled('INT64_TENSOR_SIZE'):
136+
check_call(_LIB.MXNDArrayCreateExInt64(
137+
c_array_buf(mx_int64, native_array('q', shape)),
138+
mx_int64(len(shape)),
139+
ctypes.c_int(ctx.device_typeid),
140+
ctypes.c_int(ctx.device_id),
141+
ctypes.c_int(int(delay_alloc)),
142+
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
143+
ctypes.byref(hdl)))
144+
else:
145+
check_call(_LIB.MXNDArrayCreateEx(
146+
c_array_buf(mx_uint, native_array('I', shape)),
147+
mx_uint(len(shape)),
148+
ctypes.c_int(ctx.device_typeid),
149+
ctypes.c_int(ctx.device_id),
150+
ctypes.c_int(int(delay_alloc)),
151+
ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
152+
ctypes.byref(hdl)))
142153
return hdl
143154

144155

@@ -1847,9 +1858,14 @@ def shape(self):
18471858
(2L, 3L, 4L)
18481859
"""
18491860
ndim = mx_int()
1850-
pdata = ctypes.POINTER(mx_int)()
1851-
check_call(_LIB.MXNDArrayGetShapeEx(
1852-
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
1861+
if Features().is_enabled('INT64_TENSOR_SIZE'):
1862+
pdata = ctypes.POINTER(mx_int64)()
1863+
check_call(_LIB.MXNDArrayGetShapeExInt64(
1864+
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
1865+
else:
1866+
pdata = ctypes.POINTER(mx_int)()
1867+
check_call(_LIB.MXNDArrayGetShapeEx(
1868+
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
18531869
if ndim.value == -1:
18541870
return None
18551871
else:

0 commit comments

Comments
 (0)