Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit d5443f0

Browse files
author
Rohit Kumar Srivastava
committed
[MXNET-1410]Adding Large Tensor Support for tensor transpose
1 parent 8a9dd72 commit d5443f0

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

src/operator/tensor/matrix_op-inl.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1950,10 +1950,10 @@ struct ReverseParam : public dmlc::Parameter<ReverseParam> {
19501950
#define REVERSE_MAX_DIM 10U
19511951

19521952
struct reverse {
1953-
MSHADOW_XINLINE static int ReverseIndex(index_t idx,
1954-
index_t nreversedim,
1955-
const index_t * stride_,
1956-
const index_t * trailing_) {
1953+
MSHADOW_XINLINE static index_t ReverseIndex(index_t idx,
1954+
index_t nreversedim,
1955+
const index_t * stride_,
1956+
const index_t * trailing_) {
19571957
index_t outputIndex = idx;
19581958
for (index_t i = 0; i < nreversedim; ++i) {
19591959
const index_t low = outputIndex % trailing_[i];

tests/nightly/test_large_array.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,29 @@ def test_diag():
279279
assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
280280

281281

282+
def test_transpose():
283+
a = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
284+
b = nd.broadcast_to(a, shape=(a.shape[0], SMALL_Y))
285+
t = b.T
286+
assert t.shape == (SMALL_Y, LARGE_X)
287+
assert np.sum(t[:, -1].asnumpy() == LARGE_X) == b.shape[1]
288+
289+
290+
def test_swapaxes():
291+
a = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
292+
b = nd.broadcast_to(a, shape=(a.shape[0], SMALL_Y))
293+
t = nd.swapaxes(b, dim1=0, dim2=1)
294+
assert t.shape == (SMALL_Y, LARGE_X)
295+
assert np.sum(t[:, -1].asnumpy() == LARGE_X) == b.shape[1]
296+
297+
def test_flip():
298+
a = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
299+
b = nd.broadcast_to(a, shape=(a.shape[0], SMALL_Y))
300+
t = nd.flip(b, axis=0)
301+
assert t.shape == (LARGE_X, SMALL_Y)
302+
assert np.sum(t[-1, :].asnumpy() == 0) == b.shape[1]
303+
304+
282305
if __name__ == '__main__':
283306
import nose
284307
nose.runmodule()

0 commit comments

Comments
 (0)