@@ -55,7 +55,9 @@ extern "C" {
55
55
#endif
56
56
57
57
/*! \brief manually define unsigned int */
58
- typedef unsigned int mx_uint ;
58
+ typedef uint32_t mx_uint ;
59
+ /*! \brief manually define 64-bit int */
60
+ typedef int64_t mx_int64 ;
59
61
/*! \brief manually define float */
60
62
typedef float mx_float ;
61
63
/*! \brief data type to store dim size */
@@ -556,6 +558,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
556
558
int dtype ,
557
559
NDArrayHandle * out );
558
560
561
+ MXNET_DLL int MXNDArrayCreateExInt64 (const mx_int64 * shape ,
562
+ mx_uint ndim ,
563
+ int dev_type ,
564
+ int dev_id ,
565
+ int delay_alloc ,
566
+ int dtype ,
567
+ NDArrayHandle * out );
559
568
560
569
/*!
561
570
* \brief create an empty sparse NDArray with specified shape and data type
@@ -587,6 +596,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
587
596
const mx_uint * aux_shape ,
588
597
NDArrayHandle * out );
589
598
599
+ MXNET_DLL int MXNDArrayCreateSparseExInt64 (int storage_type ,
600
+ const mx_int64 * shape ,
601
+ mx_int64 ndim ,
602
+ int dev_type ,
603
+ int dev_id ,
604
+ int delay_alloc ,
605
+ int dtype ,
606
+ mx_uint num_aux ,
607
+ int * aux_type ,
608
+ mx_uint * aux_ndims ,
609
+ const mx_uint * aux_shape ,
610
+ NDArrayHandle * out );
611
+
590
612
/*!
591
613
* \brief create a NDArray handle that is loaded from raw bytes.
592
614
* \param buf the head of the raw bytes
@@ -634,6 +656,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
634
656
mx_uint * out_name_size ,
635
657
const char * * * out_names );
636
658
659
+ MXNET_DLL int MXNDArrayLoadInt64 (const char * fname ,
660
+ mx_int64 * out_size ,
661
+ NDArrayHandle * * out_arr ,
662
+ mx_int64 * out_name_size ,
663
+ const char * * * out_names );
664
+
637
665
/*!
638
666
* \brief Load list / dictionary of narrays from file content loaded into memory.
639
667
* This will load a list of ndarrays in a similar
@@ -649,11 +677,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
649
677
* \return 0 when success, -1 when failure happens
650
678
*/
651
679
MXNET_DLL int MXNDArrayLoadFromBuffer (const void * ndarray_buffer ,
652
- size_t size ,
653
- mx_uint * out_size ,
654
- NDArrayHandle * * out_arr ,
655
- mx_uint * out_name_size ,
656
- const char * * * out_names );
680
+ size_t size ,
681
+ mx_uint * out_size ,
682
+ NDArrayHandle * * out_arr ,
683
+ mx_uint * out_name_size ,
684
+ const char * * * out_names );
685
+
686
+ MXNET_DLL int MXNDArrayLoadFromBufferInt64 (const void * ndarray_buffer ,
687
+ size_t size ,
688
+ mx_int64 * out_size ,
689
+ NDArrayHandle * * out_arr ,
690
+ mx_int64 * out_name_size ,
691
+ const char * * * out_names );
657
692
658
693
/*!
659
694
* \brief Perform a synchronize copy from a continugous CPU memory region.
@@ -793,6 +828,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
793
828
MXNET_DLL int MXNDArrayGetShape (NDArrayHandle handle ,
794
829
mx_uint * out_dim ,
795
830
const mx_uint * * out_pdata );
831
+
832
+ MXNET_DLL int MXNDArrayGetShapeInt64 (NDArrayHandle handle ,
833
+ mx_int64 * out_dim ,
834
+ const mx_int64 * * out_pdata );
835
+
796
836
/*!
797
837
* \brief get the shape of the array
798
838
* \param handle the handle to the narray
@@ -803,6 +843,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
803
843
MXNET_DLL int MXNDArrayGetShapeEx (NDArrayHandle handle ,
804
844
int * out_dim ,
805
845
const int * * out_pdata );
846
+
847
+ MXNET_DLL int MXNDArrayGetShapeExInt64 (NDArrayHandle handle ,
848
+ int * out_dim ,
849
+ const int64_t * * out_pdata );
850
+
806
851
/*!
807
852
* \brief get the content of the data in NDArray
808
853
* \param handle the handle to the ndarray
@@ -886,6 +931,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
886
931
mx_uint i ,
887
932
int * out_type );
888
933
934
+ MXNET_DLL int MXNDArrayGetAuxTypeInt64 (NDArrayHandle handle ,
935
+ mx_int64 i ,
936
+ int * out_type );
937
+
889
938
/*!
890
939
* \brief Get a deep copy of the ith aux data blob
891
940
* in the form of an NDArray of default storage type.
@@ -895,6 +944,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
895
944
mx_uint i ,
896
945
NDArrayHandle * out );
897
946
947
+ MXNET_DLL int MXNDArrayGetAuxNDArrayInt64 (NDArrayHandle handle ,
948
+ mx_int64 i ,
949
+ NDArrayHandle * out );
950
+
898
951
/*!
899
952
* \brief Get a deep copy of the data blob
900
953
* in the form of an NDArray of default storage type.
@@ -950,6 +1003,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
950
1003
*/
951
1004
MXNET_DLL int MXListFunctions (mx_uint * out_size ,
952
1005
FunctionHandle * * out_array );
1006
+
1007
+ MXNET_DLL int MXListFunctionsInt64 (mx_int64 * out_size ,
1008
+ FunctionHandle * * out_array );
1009
+
953
1010
/*!
954
1011
* \brief get the function handle by name
955
1012
* \param name the name of the function
@@ -1217,6 +1274,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
1217
1274
*/
1218
1275
MXNET_DLL int MXListAllOpNames (mx_uint * out_size ,
1219
1276
const char * * * out_array );
1277
+
1278
+ MXNET_DLL int MXListAllOpNamesInt64 (mx_int64 * out_size ,
1279
+ const char * * * out_array );
1280
+
1220
1281
/*!
1221
1282
* \brief list all the available AtomicSymbolEntry
1222
1283
* \param out_size the size of returned array
@@ -1226,6 +1287,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
1226
1287
MXNET_DLL int MXSymbolListAtomicSymbolCreators (mx_uint * out_size ,
1227
1288
AtomicSymbolCreator * * out_array );
1228
1289
1290
+ MXNET_DLL int MXSymbolListAtomicSymbolCreatorsInt64 (mx_int64 * out_size ,
1291
+ AtomicSymbolCreator * * out_array );
1292
+
1229
1293
/*!
1230
1294
* \brief Get the name of an atomic symbol.
1231
1295
* \param creator the AtomicSymbolCreator.
@@ -1438,6 +1502,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
1438
1502
MXNET_DLL int MXSymbolListArguments (SymbolHandle symbol ,
1439
1503
mx_uint * out_size ,
1440
1504
const char * * * out_str_array );
1505
+
1506
+ MXNET_DLL int MXSymbolListArgumentsInt64 (SymbolHandle symbol ,
1507
+ mx_int64 * out_size ,
1508
+ const char * * * out_str_array );
1509
+
1441
1510
/*!
1442
1511
* \brief List returns in the symbol.
1443
1512
* \param symbol the symbol
@@ -1449,14 +1518,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
1449
1518
mx_uint * out_size ,
1450
1519
const char * * * out_str_array );
1451
1520
1521
+ MXNET_DLL int MXSymbolListOutputsInt64 (SymbolHandle symbol ,
1522
+ mx_int64 * out_size ,
1523
+ const char * * * out_str_array );
1524
+
1452
1525
/*!
1453
1526
* \brief Get number of outputs of the symbol.
1454
1527
* \param symbol The symbol
1455
1528
* \param out_size number of outputs
1456
1529
* \return 0 when success, -1 when failure happens
1457
1530
*/
1458
1531
MXNET_DLL int MXSymbolGetNumOutputs (SymbolHandle symbol ,
1459
- mx_uint * output_count );
1532
+ mx_uint * output_count );
1460
1533
1461
1534
/*!
1462
1535
* \brief Get a symbol that contains all the internals.
@@ -1495,6 +1568,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
1495
1568
MXNET_DLL int MXSymbolListAuxiliaryStates (SymbolHandle symbol ,
1496
1569
mx_uint * out_size ,
1497
1570
const char * * * out_str_array );
1571
+
1572
+ MXNET_DLL int MXSymbolListAuxiliaryStatesInt64 (SymbolHandle symbol ,
1573
+ mx_int64 * out_size ,
1574
+ const char * * * out_str_array );
1575
+
1498
1576
/*!
1499
1577
* \brief Compose the symbol on other symbols.
1500
1578
*
@@ -1566,6 +1644,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
1566
1644
const mx_uint * * * aux_shape_data ,
1567
1645
int * complete );
1568
1646
1647
+ MXNET_DLL int MXSymbolInferShapeInt64 (SymbolHandle sym ,
1648
+ mx_uint num_args ,
1649
+ const char * * keys ,
1650
+ const mx_int64 * arg_ind_ptr ,
1651
+ const mx_int64 * arg_shape_data ,
1652
+ mx_int64 * in_shape_size ,
1653
+ const mx_int64 * * in_shape_ndim ,
1654
+ const mx_int64 * * * in_shape_data ,
1655
+ mx_int64 * out_shape_size ,
1656
+ const mx_int64 * * out_shape_ndim ,
1657
+ const mx_int64 * * * out_shape_data ,
1658
+ mx_int64 * aux_shape_size ,
1659
+ const mx_int64 * * aux_shape_ndim ,
1660
+ const mx_int64 * * * aux_shape_data ,
1661
+ int * complete );
1662
+
1569
1663
/*!
1570
1664
* \brief infer shape of unknown input shapes given the known one.
1571
1665
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1603,6 +1697,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
1603
1697
const int * * aux_shape_ndim ,
1604
1698
const int * * * aux_shape_data ,
1605
1699
int * complete );
1700
+
1701
+ MXNET_DLL int MXSymbolInferShapeExInt64 (SymbolHandle sym ,
1702
+ mx_uint num_args ,
1703
+ const char * * keys ,
1704
+ const mx_uint * arg_ind_ptr ,
1705
+ const int * arg_shape_data ,
1706
+ mx_uint * in_shape_size ,
1707
+ const int * * in_shape_ndim ,
1708
+ const int64_t * * * in_shape_data ,
1709
+ mx_uint * out_shape_size ,
1710
+ const int * * out_shape_ndim ,
1711
+ const int64_t * * * out_shape_data ,
1712
+ mx_uint * aux_shape_size ,
1713
+ const int * * aux_shape_ndim ,
1714
+ const int64_t * * * aux_shape_data ,
1715
+ int * complete );
1716
+
1606
1717
/*!
1607
1718
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
1608
1719
* partially infer shape of unknown input shapes given the known one.
@@ -1644,6 +1755,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
1644
1755
const mx_uint * * * aux_shape_data ,
1645
1756
int * complete );
1646
1757
1758
+ MXNET_DLL int MXSymbolInferShapePartialInt64 (SymbolHandle sym ,
1759
+ mx_uint num_args ,
1760
+ const char * * keys ,
1761
+ const mx_int64 * arg_ind_ptr ,
1762
+ const mx_int64 * arg_shape_data ,
1763
+ mx_int64 * in_shape_size ,
1764
+ const mx_int64 * * in_shape_ndim ,
1765
+ const mx_int64 * * * in_shape_data ,
1766
+ mx_int64 * out_shape_size ,
1767
+ const mx_int64 * * out_shape_ndim ,
1768
+ const mx_int64 * * * out_shape_data ,
1769
+ mx_int64 * aux_shape_size ,
1770
+ const mx_int64 * * aux_shape_ndim ,
1771
+ const mx_int64 * * * aux_shape_data ,
1772
+ int * complete );
1647
1773
1648
1774
/*!
1649
1775
* \brief partially infer shape of unknown input shapes given the known one.
@@ -1685,6 +1811,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
1685
1811
const int * * * aux_shape_data ,
1686
1812
int * complete );
1687
1813
1814
+ MXNET_DLL int MXSymbolInferShapePartialExInt64 (SymbolHandle sym ,
1815
+ mx_uint num_args ,
1816
+ const char * * keys ,
1817
+ const mx_int64 * arg_ind_ptr ,
1818
+ const int * arg_shape_data ,
1819
+ mx_int64 * in_shape_size ,
1820
+ const int * * in_shape_ndim ,
1821
+ const int * * * in_shape_data ,
1822
+ mx_int64 * out_shape_size ,
1823
+ const int * * out_shape_ndim ,
1824
+ const int * * * out_shape_data ,
1825
+ mx_int64 * aux_shape_size ,
1826
+ const int * * aux_shape_ndim ,
1827
+ const int * * * aux_shape_data ,
1828
+ int * complete );
1829
+
1688
1830
/*!
1689
1831
* \brief infer type of unknown input types given the known one.
1690
1832
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
0 commit comments