
Commit 558ca23

ChaiBapchya and Rohit Kumar Srivastava authored and committed
Sequence last fix (apache#16156)
* seq last fix
* index tensor to have int64
* fix dtypes
* revert unnecessary changes
* if seq len not passed, pass int64 dtype
* dtype comment
* use int32 or int64 as index dtype based on build flag
* Trigger notification
* Trigger notification
* lint fix
1 parent 4e508f7 commit 558ca23

File tree: 3 files changed (+17, -9 lines)


src/operator/sequence_last-inl.h

Lines changed: 6 additions & 6 deletions
@@ -101,8 +101,8 @@ class SequenceLastOp : public Operator {
     using namespace mshadow::expr;

     int axis = param_.axis;
-    int out_size = out.size(0) * out.size(1);
-    int max_seq_len = data.size(axis);
+    index_t out_size = out.size(0) * out.size(1);
+    index_t max_seq_len = data.size(axis);
     index_t offset1 = axis ? out.size(1) : out_size;
     index_t offset2 = axis ? (max_seq_len * out.size(1)) : out.size(1);

@@ -121,11 +121,11 @@ class SequenceLastOp : public Operator {
     using namespace mshadow::expr;

     auto axis = param_.axis;
-    int batch = out_grad.size(0);
-    int rest = out_grad.size(1);
-    int out_size = batch * rest;
+    index_t batch = out_grad.size(0);
+    index_t rest = out_grad.size(1);
+    index_t out_size = batch * rest;

-    int max_seq_len = in_grad.size(axis);
+    index_t max_seq_len = in_grad.size(axis);
     index_t offset1 = axis ? rest : out_size;
     index_t offset2 = axis ? (max_seq_len * rest) : rest;
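The widening above matters because out.size(0) * out.size(1) can exceed the 32-bit range for the tensors exercised by the nightly large-vector tests; index_t is a 64-bit type when MXNet is built with the large-tensor flag. Below is a minimal standalone sketch of the overflow, not MXNet source; the index_t typedef here is an assumption for illustration only.

// Minimal standalone sketch (not MXNet source): why a 32-bit out_size holds a
// wrong value for large tensors, while a 64-bit index type does not.
#include <cstdint>
#include <iostream>

typedef int64_t index_t;  // assumption for illustration, mirroring a large-tensor build

int main() {
  // Dimensions similar to the nightly large-vector tests: the flattened size
  // exceeds 2^31 - 1 elements.
  int64_t dim0 = 2;
  int64_t dim1 = 3000000000;  // 3e9 elements

  int narrow = static_cast<int>(dim0 * dim1);        // truncated to 32 bits: wrong value
  index_t wide = static_cast<index_t>(dim0 * dim1);  // exact in 64 bits

  std::cout << "int out_size:     " << narrow << std::endl;  // not 6000000000
  std::cout << "index_t out_size: " << wide << std::endl;    // 6000000000
  return 0;
}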

src/operator/sequence_last.cc

Lines changed: 8 additions & 2 deletions
@@ -46,8 +46,14 @@ Operator *SequenceLastProp::CreateOperatorEx(Context ctx,
     DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[1]);
   }

-  // sequence_length not passed in, so fall back to using input array dtype for second argument
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_type)[0]);
+  // sequence_length not passed in, so fall back to using int32/int64 dtype for second argument
+  // second argument is the dtype of the sequence_length NDArray
+  // use int32 or int64 as index dtype based on build flag
+#if MXNET_USE_INT64_TENSOR_SIZE == 1
+  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], mshadow::kInt64);
+#else
+  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], mshadow::kInt32);
+#endif
 }

 DMLC_REGISTER_PARAMETER(SequenceLastParam);
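The dispatch above fixes the sequence_length dtype at compile time from the build flag instead of reusing the input dtype. The following is a standalone sketch of that selection pattern under stated assumptions: the enum values are stand-ins rather than the real mshadow type codes, and the flag is normally supplied by the build system, not defaulted in source.

// Standalone sketch of compile-time index-dtype selection via a build flag.
// MXNET_USE_INT64_TENSOR_SIZE is normally set by the build system
// (e.g. -DMXNET_USE_INT64_TENSOR_SIZE=1); defaulting it here is an assumption.
#include <cstdio>

#ifndef MXNET_USE_INT64_TENSOR_SIZE
#define MXNET_USE_INT64_TENSOR_SIZE 0
#endif

enum DType { kInt32, kInt64 };  // stand-ins for mshadow::kInt32 / mshadow::kInt64

DType DefaultSequenceLengthDType() {
#if MXNET_USE_INT64_TENSOR_SIZE == 1
  return kInt64;  // large-tensor build: sequence positions may exceed 2^31 - 1
#else
  return kInt32;  // default build: 32-bit indices are enough
#endif
}

int main() {
  std::printf("default sequence_length dtype: %s\n",
              DefaultSequenceLengthDType() == kInt64 ? "int64" : "int32");
  return 0;
}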

tests/nightly/test_large_vector.py

Lines changed: 3 additions & 1 deletion
@@ -356,7 +356,9 @@ def test_sequence_last():
     # test with sequence length
     # parameter sequence_length - NDArray with shape (batch_size)
     # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
-    b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
+    # need to mention dtype = int64 for sequence_length ndarray to support large indices
+    # else it defaults to float32 and errors
+    b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3], dtype="int64"),
                         use_sequence_length=True)
     # check if it takes 2nd sequence from the first batch
     assert b[0] == a[1][0]
