This repository was archived by the owner on Nov 17, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +31
-4
lines changed Expand file tree Collapse file tree 2 files changed +31
-4
lines changed Original file line number Diff line number Diff line change @@ -1950,10 +1950,10 @@ struct ReverseParam : public dmlc::Parameter<ReverseParam> {
1950
1950
#define REVERSE_MAX_DIM 10U
1951
1951
1952
1952
struct reverse {
1953
- MSHADOW_XINLINE static int ReverseIndex (index_t idx,
1954
- index_t nreversedim,
1955
- const index_t * stride_,
1956
- const index_t * trailing_) {
1953
+ MSHADOW_XINLINE static index_t ReverseIndex (index_t idx,
1954
+ index_t nreversedim,
1955
+ const index_t * stride_,
1956
+ const index_t * trailing_) {
1957
1957
index_t outputIndex = idx;
1958
1958
for (index_t i = 0 ; i < nreversedim; ++i) {
1959
1959
const index_t low = outputIndex % trailing_[i];
Original file line number Diff line number Diff line change @@ -292,6 +292,33 @@ def test_unravel_index():
292
292
assert (indices_2d .asnumpy () == np .array (original_2d_indices )).all ()
293
293
294
294
295
+ def create_2d_tensor (rows , columns ):
296
+ a = np .arange (0 , rows ).reshape (rows , 1 )
297
+ b = np .broadcast_to (a , shape = (a .shape [0 ], columns ))
298
+ return nd .array (b , dtype = np .int64 )
299
+
300
+
301
+ def test_transpose ():
302
+ b = create_2d_tensor (rows = LARGE_X , columns = SMALL_Y )
303
+ t = b .T
304
+ assert t .shape == (SMALL_Y , LARGE_X )
305
+ assert np .sum (t [:, - 1 ].asnumpy () == (LARGE_X - 1 )) == b .shape [1 ]
306
+
307
+
308
+ def test_swapaxes ():
309
+ b = create_2d_tensor (rows = LARGE_X , columns = SMALL_Y )
310
+ t = nd .swapaxes (b , dim1 = 0 , dim2 = 1 )
311
+ assert t .shape == (SMALL_Y , LARGE_X )
312
+ assert np .sum (t [:, - 1 ].asnumpy () == (LARGE_X - 1 )) == b .shape [1 ]
313
+
314
+
315
+ def test_flip ():
316
+ b = create_2d_tensor (rows = LARGE_X , columns = SMALL_Y )
317
+ t = nd .flip (b , axis = 0 )
318
+ assert t .shape == (LARGE_X , SMALL_Y )
319
+ assert np .sum (t [- 1 , :].asnumpy () == 0 ) == b .shape [1 ]
320
+
321
+
295
322
if __name__ == '__main__' :
296
323
import nose
297
324
nose .runmodule ()
You can’t perform that action at this time.
0 commit comments