@@ -55,7 +55,9 @@ extern "C" {
55
55
#endif
56
56
57
57
/*! \brief manually define unsigned int */
58
- typedef unsigned int mx_uint ;
58
+ typedef uint32_t mx_uint ;
59
+ /*! \brief manually define 64-bit int */
60
+ typedef int64_t mx_int64 ;
59
61
/*! \brief manually define float */
60
62
typedef float mx_float ;
61
63
/*! \brief data type to store dim size */
@@ -565,6 +567,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
565
567
int dtype ,
566
568
NDArrayHandle * out );
567
569
570
+ MXNET_DLL int MXNDArrayCreateExInt64 (const mx_int64 * shape ,
571
+ mx_uint ndim ,
572
+ int dev_type ,
573
+ int dev_id ,
574
+ int delay_alloc ,
575
+ int dtype ,
576
+ NDArrayHandle * out );
568
577
569
578
/*!
570
579
* \brief create an empty sparse NDArray with specified shape and data type
@@ -596,6 +605,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
596
605
const mx_uint * aux_shape ,
597
606
NDArrayHandle * out );
598
607
608
+ MXNET_DLL int MXNDArrayCreateSparseExInt64 (int storage_type ,
609
+ const mx_int64 * shape ,
610
+ mx_int64 ndim ,
611
+ int dev_type ,
612
+ int dev_id ,
613
+ int delay_alloc ,
614
+ int dtype ,
615
+ mx_uint num_aux ,
616
+ int * aux_type ,
617
+ mx_uint * aux_ndims ,
618
+ const mx_uint * aux_shape ,
619
+ NDArrayHandle * out );
620
+
599
621
/*!
600
622
* \brief create a NDArray handle that is loaded from raw bytes.
601
623
* \param buf the head of the raw bytes
@@ -643,6 +665,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
643
665
mx_uint * out_name_size ,
644
666
const char * * * out_names );
645
667
668
+ MXNET_DLL int MXNDArrayLoadInt64 (const char * fname ,
669
+ mx_int64 * out_size ,
670
+ NDArrayHandle * * out_arr ,
671
+ mx_int64 * out_name_size ,
672
+ const char * * * out_names );
673
+
646
674
/*!
647
675
* \brief Load list / dictionary of narrays from file content loaded into memory.
648
676
* This will load a list of ndarrays in a similar
@@ -658,11 +686,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
658
686
* \return 0 when success, -1 when failure happens
659
687
*/
660
688
MXNET_DLL int MXNDArrayLoadFromBuffer (const void * ndarray_buffer ,
661
- size_t size ,
662
- mx_uint * out_size ,
663
- NDArrayHandle * * out_arr ,
664
- mx_uint * out_name_size ,
665
- const char * * * out_names );
689
+ size_t size ,
690
+ mx_uint * out_size ,
691
+ NDArrayHandle * * out_arr ,
692
+ mx_uint * out_name_size ,
693
+ const char * * * out_names );
694
+
695
+ MXNET_DLL int MXNDArrayLoadFromBufferInt64 (const void * ndarray_buffer ,
696
+ size_t size ,
697
+ mx_int64 * out_size ,
698
+ NDArrayHandle * * out_arr ,
699
+ mx_int64 * out_name_size ,
700
+ const char * * * out_names );
666
701
667
702
/*!
668
703
* \brief Perform a synchronize copy from a continugous CPU memory region.
@@ -802,6 +837,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
802
837
MXNET_DLL int MXNDArrayGetShape (NDArrayHandle handle ,
803
838
mx_uint * out_dim ,
804
839
const mx_uint * * out_pdata );
840
+
841
+ MXNET_DLL int MXNDArrayGetShapeInt64 (NDArrayHandle handle ,
842
+ mx_int64 * out_dim ,
843
+ const mx_int64 * * out_pdata );
844
+
805
845
/*!
806
846
* \brief get the shape of the array
807
847
* \param handle the handle to the narray
@@ -812,6 +852,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
812
852
MXNET_DLL int MXNDArrayGetShapeEx (NDArrayHandle handle ,
813
853
int * out_dim ,
814
854
const int * * out_pdata );
855
+
856
+ MXNET_DLL int MXNDArrayGetShapeExInt64 (NDArrayHandle handle ,
857
+ int * out_dim ,
858
+ const mx_int64 * * out_pdata );
859
+
815
860
/*!
816
861
* \brief get the content of the data in NDArray
817
862
* \param handle the handle to the ndarray
@@ -895,6 +940,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
895
940
mx_uint i ,
896
941
int * out_type );
897
942
943
+ MXNET_DLL int MXNDArrayGetAuxTypeInt64 (NDArrayHandle handle ,
944
+ mx_int64 i ,
945
+ int * out_type );
946
+
898
947
/*!
899
948
* \brief Get a deep copy of the ith aux data blob
900
949
* in the form of an NDArray of default storage type.
@@ -904,6 +953,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
904
953
mx_uint i ,
905
954
NDArrayHandle * out );
906
955
956
+ MXNET_DLL int MXNDArrayGetAuxNDArrayInt64 (NDArrayHandle handle ,
957
+ mx_int64 i ,
958
+ NDArrayHandle * out );
959
+
907
960
/*!
908
961
* \brief Get a deep copy of the data blob
909
962
* in the form of an NDArray of default storage type.
@@ -959,6 +1012,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
959
1012
*/
960
1013
MXNET_DLL int MXListFunctions (mx_uint * out_size ,
961
1014
FunctionHandle * * out_array );
1015
+
1016
+ MXNET_DLL int MXListFunctionsInt64 (mx_int64 * out_size ,
1017
+ FunctionHandle * * out_array );
1018
+
962
1019
/*!
963
1020
* \brief get the function handle by name
964
1021
* \param name the name of the function
@@ -1226,6 +1283,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
1226
1283
*/
1227
1284
MXNET_DLL int MXListAllOpNames (mx_uint * out_size ,
1228
1285
const char * * * out_array );
1286
+
1287
+ MXNET_DLL int MXListAllOpNamesInt64 (mx_int64 * out_size ,
1288
+ const char * * * out_array );
1289
+
1229
1290
/*!
1230
1291
* \brief list all the available AtomicSymbolEntry
1231
1292
* \param out_size the size of returned array
@@ -1235,6 +1296,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
1235
1296
MXNET_DLL int MXSymbolListAtomicSymbolCreators (mx_uint * out_size ,
1236
1297
AtomicSymbolCreator * * out_array );
1237
1298
1299
+ MXNET_DLL int MXSymbolListAtomicSymbolCreatorsInt64 (mx_int64 * out_size ,
1300
+ AtomicSymbolCreator * * out_array );
1301
+
1238
1302
/*!
1239
1303
* \brief Get the name of an atomic symbol.
1240
1304
* \param creator the AtomicSymbolCreator.
@@ -1447,6 +1511,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
1447
1511
MXNET_DLL int MXSymbolListArguments (SymbolHandle symbol ,
1448
1512
mx_uint * out_size ,
1449
1513
const char * * * out_str_array );
1514
+
1515
+ MXNET_DLL int MXSymbolListArgumentsInt64 (SymbolHandle symbol ,
1516
+ mx_int64 * out_size ,
1517
+ const char * * * out_str_array );
1518
+
1450
1519
/*!
1451
1520
* \brief List returns in the symbol.
1452
1521
* \param symbol the symbol
@@ -1458,14 +1527,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
1458
1527
mx_uint * out_size ,
1459
1528
const char * * * out_str_array );
1460
1529
1530
+ MXNET_DLL int MXSymbolListOutputsInt64 (SymbolHandle symbol ,
1531
+ mx_int64 * out_size ,
1532
+ const char * * * out_str_array );
1533
+
1461
1534
/*!
1462
1535
* \brief Get number of outputs of the symbol.
1463
1536
* \param symbol The symbol
1464
1537
* \param out_size number of outputs
1465
1538
* \return 0 when success, -1 when failure happens
1466
1539
*/
1467
1540
MXNET_DLL int MXSymbolGetNumOutputs (SymbolHandle symbol ,
1468
- mx_uint * output_count );
1541
+ mx_uint * output_count );
1469
1542
1470
1543
/*!
1471
1544
* \brief Get a symbol that contains all the internals.
@@ -1504,6 +1577,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
1504
1577
MXNET_DLL int MXSymbolListAuxiliaryStates (SymbolHandle symbol ,
1505
1578
mx_uint * out_size ,
1506
1579
const char * * * out_str_array );
1580
+
1581
+ MXNET_DLL int MXSymbolListAuxiliaryStatesInt64 (SymbolHandle symbol ,
1582
+ mx_int64 * out_size ,
1583
+ const char * * * out_str_array );
1584
+
1507
1585
/*!
1508
1586
* \brief Compose the symbol on other symbols.
1509
1587
*
@@ -1575,6 +1653,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
1575
1653
const mx_uint * * * aux_shape_data ,
1576
1654
int * complete );
1577
1655
1656
+ MXNET_DLL int MXSymbolInferShapeInt64 (SymbolHandle sym ,
1657
+ mx_uint num_args ,
1658
+ const char * * keys ,
1659
+ const mx_int64 * arg_ind_ptr ,
1660
+ const mx_int64 * arg_shape_data ,
1661
+ mx_int64 * in_shape_size ,
1662
+ const mx_int64 * * in_shape_ndim ,
1663
+ const mx_int64 * * * in_shape_data ,
1664
+ mx_int64 * out_shape_size ,
1665
+ const mx_int64 * * out_shape_ndim ,
1666
+ const mx_int64 * * * out_shape_data ,
1667
+ mx_int64 * aux_shape_size ,
1668
+ const mx_int64 * * aux_shape_ndim ,
1669
+ const mx_int64 * * * aux_shape_data ,
1670
+ int * complete );
1671
+
1578
1672
/*!
1579
1673
* \brief infer shape of unknown input shapes given the known one.
1580
1674
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1612,6 +1706,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
1612
1706
const int * * aux_shape_ndim ,
1613
1707
const int * * * aux_shape_data ,
1614
1708
int * complete );
1709
+
1710
+ MXNET_DLL int MXSymbolInferShapeExInt64 (SymbolHandle sym ,
1711
+ mx_uint num_args ,
1712
+ const char * * keys ,
1713
+ const mx_uint * arg_ind_ptr ,
1714
+ const int * arg_shape_data ,
1715
+ mx_uint * in_shape_size ,
1716
+ const int * * in_shape_ndim ,
1717
+ const int64_t * * * in_shape_data ,
1718
+ mx_uint * out_shape_size ,
1719
+ const int * * out_shape_ndim ,
1720
+ const int64_t * * * out_shape_data ,
1721
+ mx_uint * aux_shape_size ,
1722
+ const int * * aux_shape_ndim ,
1723
+ const int64_t * * * aux_shape_data ,
1724
+ int * complete );
1725
+
1615
1726
/*!
1616
1727
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
1617
1728
* partially infer shape of unknown input shapes given the known one.
@@ -1653,6 +1764,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
1653
1764
const mx_uint * * * aux_shape_data ,
1654
1765
int * complete );
1655
1766
1767
+ MXNET_DLL int MXSymbolInferShapePartialInt64 (SymbolHandle sym ,
1768
+ mx_uint num_args ,
1769
+ const char * * keys ,
1770
+ const mx_int64 * arg_ind_ptr ,
1771
+ const mx_int64 * arg_shape_data ,
1772
+ mx_int64 * in_shape_size ,
1773
+ const mx_int64 * * in_shape_ndim ,
1774
+ const mx_int64 * * * in_shape_data ,
1775
+ mx_int64 * out_shape_size ,
1776
+ const mx_int64 * * out_shape_ndim ,
1777
+ const mx_int64 * * * out_shape_data ,
1778
+ mx_int64 * aux_shape_size ,
1779
+ const mx_int64 * * aux_shape_ndim ,
1780
+ const mx_int64 * * * aux_shape_data ,
1781
+ int * complete );
1656
1782
1657
1783
/*!
1658
1784
* \brief partially infer shape of unknown input shapes given the known one.
@@ -1694,6 +1820,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
1694
1820
const int * * * aux_shape_data ,
1695
1821
int * complete );
1696
1822
1823
+ MXNET_DLL int MXSymbolInferShapePartialExInt64 (SymbolHandle sym ,
1824
+ mx_uint num_args ,
1825
+ const char * * keys ,
1826
+ const mx_int64 * arg_ind_ptr ,
1827
+ const int * arg_shape_data ,
1828
+ mx_int64 * in_shape_size ,
1829
+ const int * * in_shape_ndim ,
1830
+ const int * * * in_shape_data ,
1831
+ mx_int64 * out_shape_size ,
1832
+ const int * * out_shape_ndim ,
1833
+ const int * * * out_shape_data ,
1834
+ mx_int64 * aux_shape_size ,
1835
+ const int * * aux_shape_ndim ,
1836
+ const int * * * aux_shape_data ,
1837
+ int * complete );
1838
+
1697
1839
/*!
1698
1840
* \brief infer type of unknown input types given the known one.
1699
1841
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
0 commit comments