@@ -55,7 +55,8 @@ extern "C" {
55
55
#endif
56
56
57
57
/*! \brief manually define unsigned int */
58
- typedef unsigned int mx_uint ;
58
+ typedef int64_t mx_int64 ;
59
+ typedef uint32_t mx_uint ;
59
60
/*! \brief manually define float */
60
61
typedef float mx_float ;
61
62
/*! \brief data type to store dim size */
@@ -556,6 +557,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
556
557
int dtype ,
557
558
NDArrayHandle * out );
558
559
560
+ MXNET_DLL int MXNDArrayCreateExInt64 (const mx_int64 * shape ,
561
+ mx_uint ndim ,
562
+ int dev_type ,
563
+ int dev_id ,
564
+ int delay_alloc ,
565
+ int dtype ,
566
+ NDArrayHandle * out );
559
567
560
568
/*!
561
569
* \brief create an empty sparse NDArray with specified shape and data type
@@ -587,6 +595,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
587
595
const mx_uint * aux_shape ,
588
596
NDArrayHandle * out );
589
597
598
+ MXNET_DLL int MXNDArrayCreateSparseExInt64 (int storage_type ,
599
+ const mx_int64 * shape ,
600
+ mx_int64 ndim ,
601
+ int dev_type ,
602
+ int dev_id ,
603
+ int delay_alloc ,
604
+ int dtype ,
605
+ mx_uint num_aux ,
606
+ int * aux_type ,
607
+ mx_uint * aux_ndims ,
608
+ const mx_uint * aux_shape ,
609
+ NDArrayHandle * out );
610
+
590
611
/*!
591
612
* \brief create a NDArray handle that is loaded from raw bytes.
592
613
* \param buf the head of the raw bytes
@@ -634,6 +655,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
634
655
mx_uint * out_name_size ,
635
656
const char * * * out_names );
636
657
658
+ MXNET_DLL int MXNDArrayLoadInt64 (const char * fname ,
659
+ mx_int64 * out_size ,
660
+ NDArrayHandle * * out_arr ,
661
+ mx_int64 * out_name_size ,
662
+ const char * * * out_names );
663
+
637
664
/*!
638
665
* \brief Load list / dictionary of narrays from file content loaded into memory.
639
666
* This will load a list of ndarrays in a similar
@@ -655,6 +682,13 @@ MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
655
682
mx_uint * out_name_size ,
656
683
const char * * * out_names );
657
684
685
+ MXNET_DLL int MXNDArrayLoadFromBufferInt64 (const void * ndarray_buffer ,
686
+ size_t size ,
687
+ mx_int64 * out_size ,
688
+ NDArrayHandle * * out_arr ,
689
+ mx_int64 * out_name_size ,
690
+ const char * * * out_names );
691
+
658
692
/*!
659
693
* \brief Perform a synchronize copy from a continugous CPU memory region.
660
694
*
@@ -793,6 +827,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
793
827
MXNET_DLL int MXNDArrayGetShape (NDArrayHandle handle ,
794
828
mx_uint * out_dim ,
795
829
const mx_uint * * out_pdata );
830
+
831
+ MXNET_DLL int MXNDArrayGetShapeInt64 (NDArrayHandle handle ,
832
+ mx_int64 * out_dim ,
833
+ const mx_int64 * * out_pdata );
834
+
796
835
/*!
797
836
* \brief get the shape of the array
798
837
* \param handle the handle to the narray
@@ -803,6 +842,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
803
842
MXNET_DLL int MXNDArrayGetShapeEx (NDArrayHandle handle ,
804
843
int * out_dim ,
805
844
const int * * out_pdata );
845
+
846
+ MXNET_DLL int MXNDArrayGetShapeExInt64 (NDArrayHandle handle ,
847
+ int * out_dim ,
848
+ const int64_t * * out_pdata );
849
+
806
850
/*!
807
851
* \brief get the content of the data in NDArray
808
852
* \param handle the handle to the ndarray
@@ -886,6 +930,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
886
930
mx_uint i ,
887
931
int * out_type );
888
932
933
+ MXNET_DLL int MXNDArrayGetAuxTypeInt64 (NDArrayHandle handle ,
934
+ mx_int64 i ,
935
+ int * out_type );
936
+
889
937
/*!
890
938
* \brief Get a deep copy of the ith aux data blob
891
939
* in the form of an NDArray of default storage type.
@@ -895,6 +943,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
895
943
mx_uint i ,
896
944
NDArrayHandle * out );
897
945
946
+ MXNET_DLL int MXNDArrayGetAuxNDArrayInt64 (NDArrayHandle handle ,
947
+ mx_int64 i ,
948
+ NDArrayHandle * out );
949
+
898
950
/*!
899
951
* \brief Get a deep copy of the data blob
900
952
* in the form of an NDArray of default storage type.
@@ -950,6 +1002,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
950
1002
*/
951
1003
MXNET_DLL int MXListFunctions (mx_uint * out_size ,
952
1004
FunctionHandle * * out_array );
1005
+
1006
+ MXNET_DLL int MXListFunctionsInt64 (mx_int64 * out_size ,
1007
+ FunctionHandle * * out_array );
1008
+
953
1009
/*!
954
1010
* \brief get the function handle by name
955
1011
* \param name the name of the function
@@ -1217,6 +1273,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
1217
1273
*/
1218
1274
MXNET_DLL int MXListAllOpNames (mx_uint * out_size ,
1219
1275
const char * * * out_array );
1276
+
1277
+ MXNET_DLL int MXListAllOpNamesInt64 (mx_int64 * out_size ,
1278
+ const char * * * out_array );
1279
+
1220
1280
/*!
1221
1281
* \brief list all the available AtomicSymbolEntry
1222
1282
* \param out_size the size of returned array
@@ -1226,6 +1286,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
1226
1286
MXNET_DLL int MXSymbolListAtomicSymbolCreators (mx_uint * out_size ,
1227
1287
AtomicSymbolCreator * * out_array );
1228
1288
1289
+ MXNET_DLL int MXSymbolListAtomicSymbolCreatorsInt64 (mx_int64 * out_size ,
1290
+ AtomicSymbolCreator * * out_array );
1291
+
1229
1292
/*!
1230
1293
* \brief Get the name of an atomic symbol.
1231
1294
* \param creator the AtomicSymbolCreator.
@@ -1438,6 +1501,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
1438
1501
MXNET_DLL int MXSymbolListArguments (SymbolHandle symbol ,
1439
1502
mx_uint * out_size ,
1440
1503
const char * * * out_str_array );
1504
+
1505
+ MXNET_DLL int MXSymbolListArgumentsInt64 (SymbolHandle symbol ,
1506
+ mx_int64 * out_size ,
1507
+ const char * * * out_str_array );
1508
+
1441
1509
/*!
1442
1510
* \brief List returns in the symbol.
1443
1511
* \param symbol the symbol
@@ -1449,14 +1517,18 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
1449
1517
mx_uint * out_size ,
1450
1518
const char * * * out_str_array );
1451
1519
1520
+ MXNET_DLL int MXSymbolListOutputsInt64 (SymbolHandle symbol ,
1521
+ mx_int64 * out_size ,
1522
+ const char * * * out_str_array );
1523
+
1452
1524
/*!
1453
1525
* \brief Get number of outputs of the symbol.
1454
1526
* \param symbol The symbol
1455
1527
* \param out_size number of outputs
1456
1528
* \return 0 when success, -1 when failure happens
1457
1529
*/
1458
1530
MXNET_DLL int MXSymbolGetNumOutputs (SymbolHandle symbol ,
1459
- mx_uint * output_count );
1531
+ mx_uint * output_count );
1460
1532
1461
1533
/*!
1462
1534
* \brief Get a symbol that contains all the internals.
@@ -1495,6 +1567,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
1495
1567
MXNET_DLL int MXSymbolListAuxiliaryStates (SymbolHandle symbol ,
1496
1568
mx_uint * out_size ,
1497
1569
const char * * * out_str_array );
1570
+
1571
+ MXNET_DLL int MXSymbolListAuxiliaryStatesInt64 (SymbolHandle symbol ,
1572
+ mx_int64 * out_size ,
1573
+ const char * * * out_str_array );
1574
+
1498
1575
/*!
1499
1576
* \brief Compose the symbol on other symbols.
1500
1577
*
@@ -1566,6 +1643,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
1566
1643
const mx_uint * * * aux_shape_data ,
1567
1644
int * complete );
1568
1645
1646
+ MXNET_DLL int MXSymbolInferShapeInt64 (SymbolHandle sym ,
1647
+ mx_uint num_args ,
1648
+ const char * * keys ,
1649
+ const mx_int64 * arg_ind_ptr ,
1650
+ const mx_int64 * arg_shape_data ,
1651
+ mx_int64 * in_shape_size ,
1652
+ const mx_int64 * * in_shape_ndim ,
1653
+ const mx_int64 * * * in_shape_data ,
1654
+ mx_int64 * out_shape_size ,
1655
+ const mx_int64 * * out_shape_ndim ,
1656
+ const mx_int64 * * * out_shape_data ,
1657
+ mx_int64 * aux_shape_size ,
1658
+ const mx_int64 * * aux_shape_ndim ,
1659
+ const mx_int64 * * * aux_shape_data ,
1660
+ int * complete );
1661
+
1569
1662
/*!
1570
1663
* \brief infer shape of unknown input shapes given the known one.
1571
1664
* The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1603,6 +1696,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
1603
1696
const int * * aux_shape_ndim ,
1604
1697
const int * * * aux_shape_data ,
1605
1698
int * complete );
1699
+
1700
+ MXNET_DLL int MXSymbolInferShapeExInt64 (SymbolHandle sym ,
1701
+ mx_uint num_args ,
1702
+ const char * * keys ,
1703
+ const mx_uint * arg_ind_ptr ,
1704
+ const int * arg_shape_data ,
1705
+ mx_uint * in_shape_size ,
1706
+ const int * * in_shape_ndim ,
1707
+ const int64_t * * * in_shape_data ,
1708
+ mx_uint * out_shape_size ,
1709
+ const int * * out_shape_ndim ,
1710
+ const int64_t * * * out_shape_data ,
1711
+ mx_uint * aux_shape_size ,
1712
+ const int * * aux_shape_ndim ,
1713
+ const int64_t * * * aux_shape_data ,
1714
+ int * complete );
1715
+
1606
1716
/*!
1607
1717
* \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
1608
1718
* partially infer shape of unknown input shapes given the known one.
@@ -1644,6 +1754,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
1644
1754
const mx_uint * * * aux_shape_data ,
1645
1755
int * complete );
1646
1756
1757
+ MXNET_DLL int MXSymbolInferShapePartialInt64 (SymbolHandle sym ,
1758
+ mx_uint num_args ,
1759
+ const char * * keys ,
1760
+ const mx_int64 * arg_ind_ptr ,
1761
+ const mx_int64 * arg_shape_data ,
1762
+ mx_int64 * in_shape_size ,
1763
+ const mx_int64 * * in_shape_ndim ,
1764
+ const mx_int64 * * * in_shape_data ,
1765
+ mx_int64 * out_shape_size ,
1766
+ const mx_int64 * * out_shape_ndim ,
1767
+ const mx_int64 * * * out_shape_data ,
1768
+ mx_int64 * aux_shape_size ,
1769
+ const mx_int64 * * aux_shape_ndim ,
1770
+ const mx_int64 * * * aux_shape_data ,
1771
+ int * complete );
1647
1772
1648
1773
/*!
1649
1774
* \brief partially infer shape of unknown input shapes given the known one.
@@ -1685,6 +1810,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
1685
1810
const int * * * aux_shape_data ,
1686
1811
int * complete );
1687
1812
1813
+ MXNET_DLL int MXSymbolInferShapePartialExInt64 (SymbolHandle sym ,
1814
+ mx_uint num_args ,
1815
+ const char * * keys ,
1816
+ const mx_int64 * arg_ind_ptr ,
1817
+ const int * arg_shape_data ,
1818
+ mx_int64 * in_shape_size ,
1819
+ const int * * in_shape_ndim ,
1820
+ const int * * * in_shape_data ,
1821
+ mx_int64 * out_shape_size ,
1822
+ const int * * out_shape_ndim ,
1823
+ const int * * * out_shape_data ,
1824
+ mx_int64 * aux_shape_size ,
1825
+ const int * * aux_shape_ndim ,
1826
+ const int * * * aux_shape_data ,
1827
+ int * complete );
1828
+
1688
1829
/*!
1689
1830
* \brief infer type of unknown input types given the known one.
1690
1831
* The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
0 commit comments