Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit e8883e4

Browse files
committed
fix dtypes
1 parent 5b6d697 commit e8883e4

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

src/operator/sequence_last-inl.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ class SequenceLastOp : public Operator {
121121
using namespace mshadow::expr;
122122

123123
auto axis = param_.axis;
124-
int batch = out_grad.size(0);
125-
int rest = out_grad.size(1);
126-
int out_size = batch * rest;
124+
index_t batch = out_grad.size(0);
125+
index_t rest = out_grad.size(1);
126+
index_t out_size = batch * rest;
127127

128-
int max_seq_len = in_grad.size(axis);
128+
index_t max_seq_len = in_grad.size(axis);
129129
index_t offset1 = axis ? rest : out_size;
130130
index_t offset2 = axis ? (max_seq_len * rest) : rest;
131131

src/operator/sequence_last.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ template <>
3131
Operator *CreateOp<cpu>(SequenceLastParam param, int dtype, int itype) {
3232
Operator *op = nullptr;
3333
MSHADOW_TYPE_SWITCH(dtype, DType, {
34-
// MSHADOW_TYPE_SWITCH(itype, IType, {
35-
op = new SequenceLastOp<cpu, DType, int64_t>(param);
36-
// });
34+
MSHADOW_TYPE_SWITCH(itype, IType, {
35+
op = new SequenceLastOp<cpu, DType, IType>(param);
36+
});
3737
});
3838
return op;
3939
}

tests/nightly/test_large_vector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def test_sequence_reverse():
346346

347347

348348
def test_sequence_last():
349-
a = nd.arange(0, LARGE_X * 2).reshape(LARGE_X, 2)
349+
a = nd.arange(0, LARGE_X * 2, dtype="int64").reshape(LARGE_X, 2)
350350

351351
# test if returns last sequence
352352
b = nd.SequenceLast(a)
@@ -356,7 +356,7 @@ def test_sequence_last():
356356
# test with sequence length
357357
# parameter sequence_length - NDArray with shape (batch_size)
358358
# (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
359-
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
359+
b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"),
360360
use_sequence_length=True)
361361
# check if it takes 2nd sequence from the first batch
362362
assert b[0] == a[1][0]

0 commit comments

Comments
 (0)