Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit c5705c8

Browse files
author
Rohit Kumar Srivastava
committed
init_commit
1 parent 4d07d78 commit c5705c8

File tree

13 files changed

+404
-75
lines changed

13 files changed

+404
-75
lines changed

debug.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import numpy as np
2+
import mxnet as mx
3+
from mxnet import ndarray as nd
4+
5+
ctx=mx.cpu()
6+
7+
def create_2d_tensor(rows, columns):
8+
a = np.arange(0, columns).reshape(1, columns)
9+
# a = np.arange(0, columns)
10+
#b = np.broadcast_to(a, shape=(rows, columns))
11+
return nd.array(a, dtype=np.int64)
12+
13+
b = create_2d_tensor(rows=1, columns=5000000000)
14+
print(b.shape)

include/mxnet/c_api.h

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ extern "C" {
5555
#endif
5656

5757
/*! \brief manually define unsigned int */
58-
typedef unsigned int mx_uint;
58+
typedef int64_t mx_int64;
59+
typedef uint32_t mx_uint;
5960
/*! \brief manually define float */
6061
typedef float mx_float;
6162
/*! \brief data type to store dim size */
@@ -556,6 +557,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
556557
int dtype,
557558
NDArrayHandle *out);
558559

560+
MXNET_DLL int MXNDArrayCreateExInt64(const mx_int64 *shape,
561+
mx_uint ndim,
562+
int dev_type,
563+
int dev_id,
564+
int delay_alloc,
565+
int dtype,
566+
NDArrayHandle *out);
559567

560568
/*!
561569
* \brief create an empty sparse NDArray with specified shape and data type
@@ -587,6 +595,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
587595
const mx_uint *aux_shape,
588596
NDArrayHandle *out);
589597

598+
MXNET_DLL int MXNDArrayCreateSparseExInt64(int storage_type,
599+
const mx_int64 *shape,
600+
mx_int64 ndim,
601+
int dev_type,
602+
int dev_id,
603+
int delay_alloc,
604+
int dtype,
605+
mx_uint num_aux,
606+
int *aux_type,
607+
mx_uint *aux_ndims,
608+
const mx_uint *aux_shape,
609+
NDArrayHandle *out);
610+
590611
/*!
591612
* \brief create a NDArray handle that is loaded from raw bytes.
592613
* \param buf the head of the raw bytes
@@ -634,6 +655,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
634655
mx_uint *out_name_size,
635656
const char*** out_names);
636657

658+
MXNET_DLL int MXNDArrayLoadInt64(const char* fname,
659+
mx_int64 *out_size,
660+
NDArrayHandle** out_arr,
661+
mx_int64 *out_name_size,
662+
const char*** out_names);
663+
637664
/*!
638665
* \brief Load list / dictionary of narrays from file content loaded into memory.
639666
* This will load a list of ndarrays in a similar
@@ -655,6 +682,13 @@ MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
655682
mx_uint *out_name_size,
656683
const char*** out_names);
657684

685+
MXNET_DLL int MXNDArrayLoadFromBufferInt64(const void *ndarray_buffer,
686+
size_t size,
687+
mx_int64 *out_size,
688+
NDArrayHandle** out_arr,
689+
mx_int64 *out_name_size,
690+
const char*** out_names);
691+
658692
/*!
659693
* \brief Perform a synchronize copy from a continugous CPU memory region.
660694
*
@@ -793,6 +827,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
793827
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
794828
mx_uint *out_dim,
795829
const mx_uint **out_pdata);
830+
831+
MXNET_DLL int MXNDArrayGetShapeInt64(NDArrayHandle handle,
832+
mx_int64 *out_dim,
833+
const mx_int64 **out_pdata);
834+
796835
/*!
797836
* \brief get the shape of the array
798837
* \param handle the handle to the narray
@@ -803,6 +842,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
803842
MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
804843
int *out_dim,
805844
const int **out_pdata);
845+
846+
MXNET_DLL int MXNDArrayGetShapeExInt64(NDArrayHandle handle,
847+
int *out_dim,
848+
const int64_t **out_pdata);
849+
806850
/*!
807851
* \brief get the content of the data in NDArray
808852
* \param handle the handle to the ndarray
@@ -886,6 +930,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
886930
mx_uint i,
887931
int *out_type);
888932

933+
MXNET_DLL int MXNDArrayGetAuxTypeInt64(NDArrayHandle handle,
934+
mx_int64 i,
935+
int *out_type);
936+
889937
/*!
890938
* \brief Get a deep copy of the ith aux data blob
891939
* in the form of an NDArray of default storage type.
@@ -895,6 +943,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
895943
mx_uint i,
896944
NDArrayHandle *out);
897945

946+
MXNET_DLL int MXNDArrayGetAuxNDArrayInt64(NDArrayHandle handle,
947+
mx_int64 i,
948+
NDArrayHandle *out);
949+
898950
/*!
899951
* \brief Get a deep copy of the data blob
900952
* in the form of an NDArray of default storage type.
@@ -950,6 +1002,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
9501002
*/
9511003
MXNET_DLL int MXListFunctions(mx_uint *out_size,
9521004
FunctionHandle **out_array);
1005+
1006+
MXNET_DLL int MXListFunctionsInt64(mx_int64 *out_size,
1007+
FunctionHandle **out_array);
1008+
9531009
/*!
9541010
* \brief get the function handle by name
9551011
* \param name the name of the function
@@ -1217,6 +1273,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
12171273
*/
12181274
MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
12191275
const char ***out_array);
1276+
1277+
MXNET_DLL int MXListAllOpNamesInt64(mx_int64 *out_size,
1278+
const char ***out_array);
1279+
12201280
/*!
12211281
* \brief list all the available AtomicSymbolEntry
12221282
* \param out_size the size of returned array
@@ -1226,6 +1286,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
12261286
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
12271287
AtomicSymbolCreator **out_array);
12281288

1289+
MXNET_DLL int MXSymbolListAtomicSymbolCreatorsInt64(mx_int64 *out_size,
1290+
AtomicSymbolCreator **out_array);
1291+
12291292
/*!
12301293
* \brief Get the name of an atomic symbol.
12311294
* \param creator the AtomicSymbolCreator.
@@ -1438,6 +1501,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
14381501
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
14391502
mx_uint *out_size,
14401503
const char ***out_str_array);
1504+
1505+
MXNET_DLL int MXSymbolListArgumentsInt64(SymbolHandle symbol,
1506+
mx_int64 *out_size,
1507+
const char ***out_str_array);
1508+
14411509
/*!
14421510
* \brief List returns in the symbol.
14431511
* \param symbol the symbol
@@ -1449,14 +1517,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
14491517
mx_uint *out_size,
14501518
const char ***out_str_array);
14511519

1520+
MXNET_DLL int MXSymbolListOutputsInt64(SymbolHandle symbol,
1521+
mx_int64 *out_size,
1522+
const char ***out_str_array);
1523+
14521524
/*!
14531525
* \brief Get number of outputs of the symbol.
14541526
* \param symbol The symbol
14551527
* \param out_size number of outputs
14561528
* \return 0 when success, -1 when failure happens
14571529
*/
14581530
MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
1459-
mx_uint *output_count);
1531+
mx_uint *output_count);
14601532

14611533
/*!
14621534
* \brief Get a symbol that contains all the internals.
@@ -1495,6 +1567,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
14951567
MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
14961568
mx_uint *out_size,
14971569
const char ***out_str_array);
1570+
1571+
MXNET_DLL int MXSymbolListAuxiliaryStatesInt64(SymbolHandle symbol,
1572+
mx_int64 *out_size,
1573+
const char ***out_str_array);
1574+
14981575
/*!
14991576
* \brief Compose the symbol on other symbols.
15001577
*
@@ -1566,6 +1643,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
15661643
const mx_uint ***aux_shape_data,
15671644
int *complete);
15681645

1646+
MXNET_DLL int MXSymbolInferShapeInt64(SymbolHandle sym,
1647+
mx_uint num_args,
1648+
const char** keys,
1649+
const mx_int64 *arg_ind_ptr,
1650+
const mx_int64 *arg_shape_data,
1651+
mx_int64 *in_shape_size,
1652+
const mx_int64 **in_shape_ndim,
1653+
const mx_int64 ***in_shape_data,
1654+
mx_int64 *out_shape_size,
1655+
const mx_int64 **out_shape_ndim,
1656+
const mx_int64 ***out_shape_data,
1657+
mx_int64 *aux_shape_size,
1658+
const mx_int64 **aux_shape_ndim,
1659+
const mx_int64 ***aux_shape_data,
1660+
int *complete);
1661+
15691662
/*!
15701663
* \brief infer shape of unknown input shapes given the known one.
15711664
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1603,6 +1696,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
16031696
const int **aux_shape_ndim,
16041697
const int ***aux_shape_data,
16051698
int *complete);
1699+
1700+
MXNET_DLL int MXSymbolInferShapeExInt64(SymbolHandle sym,
1701+
mx_uint num_args,
1702+
const char** keys,
1703+
const mx_uint *arg_ind_ptr,
1704+
const int *arg_shape_data,
1705+
mx_uint *in_shape_size,
1706+
const int **in_shape_ndim,
1707+
const int64_t ***in_shape_data,
1708+
mx_uint *out_shape_size,
1709+
const int **out_shape_ndim,
1710+
const int64_t ***out_shape_data,
1711+
mx_uint *aux_shape_size,
1712+
const int **aux_shape_ndim,
1713+
const int64_t ***aux_shape_data,
1714+
int *complete);
1715+
16061716
/*!
16071717
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
16081718
* partially infer shape of unknown input shapes given the known one.
@@ -1644,6 +1754,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
16441754
const mx_uint ***aux_shape_data,
16451755
int *complete);
16461756

1757+
MXNET_DLL int MXSymbolInferShapePartialInt64(SymbolHandle sym,
1758+
mx_uint num_args,
1759+
const char** keys,
1760+
const mx_int64 *arg_ind_ptr,
1761+
const mx_int64 *arg_shape_data,
1762+
mx_int64 *in_shape_size,
1763+
const mx_int64 **in_shape_ndim,
1764+
const mx_int64 ***in_shape_data,
1765+
mx_int64 *out_shape_size,
1766+
const mx_int64 **out_shape_ndim,
1767+
const mx_int64 ***out_shape_data,
1768+
mx_int64 *aux_shape_size,
1769+
const mx_int64 **aux_shape_ndim,
1770+
const mx_int64 ***aux_shape_data,
1771+
int *complete);
16471772

16481773
/*!
16491774
* \brief partially infer shape of unknown input shapes given the known one.
@@ -1685,6 +1810,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
16851810
const int ***aux_shape_data,
16861811
int *complete);
16871812

1813+
MXNET_DLL int MXSymbolInferShapePartialExInt64(SymbolHandle sym,
1814+
mx_uint num_args,
1815+
const char** keys,
1816+
const mx_int64 *arg_ind_ptr,
1817+
const int *arg_shape_data,
1818+
mx_int64 *in_shape_size,
1819+
const int **in_shape_ndim,
1820+
const int ***in_shape_data,
1821+
mx_int64 *out_shape_size,
1822+
const int **out_shape_ndim,
1823+
const int ***out_shape_data,
1824+
mx_int64 *aux_shape_size,
1825+
const int **aux_shape_ndim,
1826+
const int ***aux_shape_data,
1827+
int *complete);
1828+
16881829
/*!
16891830
* \brief infer type of unknown input types given the known one.
16901831
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data

include/mxnet/c_predict_api.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ extern "C" {
4242
#endif
4343

4444
/*! \brief manually define unsigned int */
45-
typedef unsigned int mx_uint;
45+
typedef int64_t mx_int64;
46+
typedef uint32_t mx_uint;
4647
/*! \brief manually define float */
4748
typedef float mx_float;
4849
/*! \brief handle to Predictor */

include/mxnet/tuple.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,6 @@ class Tuple {
366366
}
367367
};
368368

369-
370369
/*! brief check if a shape's ndim is known. */
371370
inline bool ndim_is_known(const int ndim) {
372371
CHECK_GE(ndim, -1) << "shape ndim must be >= -1, while received " << ndim;

python/mxnet/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _load_lib():
215215
# type definitions
216216
mx_int = ctypes.c_int
217217
mx_uint = ctypes.c_uint
218+
mx_int64 = ctypes.c_int64
218219
mx_float = ctypes.c_float
219220
mx_float_p = ctypes.POINTER(mx_float)
220221
mx_real_t = _np.float32

python/mxnet/ndarray/ndarray.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
import numpy as np
3636
from ..base import _LIB, numeric_types, integer_types
3737
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
38-
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
38+
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
3939
from ..base import ctypes2buffer
4040
from ..context import Context, current_context
4141
from . import _internal
@@ -130,10 +130,12 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
130130
handle
131131
A new empty `NDArray` handle.
132132
"""
133+
print("shape={}".format(shape))
133134
hdl = NDArrayHandle()
134-
check_call(_LIB.MXNDArrayCreateEx(
135-
c_array_buf(mx_uint, native_array('I', shape)),
136-
mx_uint(len(shape)),
135+
# check_call(_LIB.MXNDArrayCreateEx(
136+
check_call(_LIB.MXNDArrayCreateExInt64(
137+
c_array_buf(mx_int64, native_array('q', shape)),
138+
mx_int64(len(shape)),
137139
ctypes.c_int(ctx.device_typeid),
138140
ctypes.c_int(ctx.device_id),
139141
ctypes.c_int(int(delay_alloc)),
@@ -695,7 +697,6 @@ def _set_nd_basic_indexing(self, key, value):
695697
raise IndexError('index %d is out of bounds for axis 0 with size %d'
696698
% (key, shape[0]))
697699
key = py_slice(key, key+1) # key must be >= 0 here
698-
699700
if isinstance(key, py_slice):
700701
assign_to_self = key.step is None or key.step == 1
701702
assign_to_self &= key.start is None or key.start == 0
@@ -711,6 +712,7 @@ def _set_nd_basic_indexing(self, key, value):
711712
dtype=self.dtype, value=float(value), out=self)
712713
elif isinstance(value, (np.ndarray, np.generic)):
713714
if isinstance(value, np.generic) or value.shape != shape:
715+
print("shape={}, value.shape={}".format(shape, value.shape))
714716
value = np.broadcast_to(value, shape)
715717
self._sync_copyfrom(value)
716718
else: # value might be a list or a tuple
@@ -1847,8 +1849,8 @@ def shape(self):
18471849
(2L, 3L, 4L)
18481850
"""
18491851
ndim = mx_int()
1850-
pdata = ctypes.POINTER(mx_int)()
1851-
check_call(_LIB.MXNDArrayGetShapeEx(
1852+
pdata = ctypes.POINTER(mx_int64)()
1853+
check_call(_LIB.MXNDArrayGetShapeExInt64(
18521854
self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
18531855
if ndim.value == -1:
18541856
return None

0 commit comments

Comments
 (0)