Backward doesn't work on LSTM with sequence_length #15268
Closed
Description
LSTM with out-of-the-box variable-length support was introduced in this PR. I tried to use it, and while the forward pass works well, the backward pass fails.
I provide a minimal reproducible example below. To the best of my knowledge, the backward pass is not covered by a unit test.
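For reference, a unit test covering the backward pass might look roughly like the sketch below. It mirrors the repro further down; the shapes, valid lengths, and the finite-gradient assertion are just illustrative assumptions, not an existing test.

import mxnet as mx
import numpy as np
from mxnet.gluon.rnn import LSTM

def check_lstm_sequence_length_backward():
    # use_sequence_length is GPU-only, so the check has to run on a GPU context
    ctx = mx.gpu(0)
    seq_len, batch_size, input_size, hidden_size = 7, 2, 100, 100
    lstm = LSTM(hidden_size=hidden_size, num_layers=1,
                bidirectional=True, use_sequence_length=True)
    lstm.initialize(ctx=ctx)
    x = mx.nd.random.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx)
    x.attach_grad()
    # per-sample valid lengths, shorter than seq_len to exercise the masking
    valid_len = mx.nd.array([4, 6], ctx=ctx).astype(np.int32)
    state = lstm.begin_state(batch_size=batch_size, ctx=ctx)
    with mx.autograd.record():
        out, _ = lstm(x, states=state, sequence_length=valid_len)
        loss = out.sum()
    loss.backward()
    mx.nd.waitall()  # force synchronization so a backward failure shows up here
    assert np.all(np.isfinite(x.grad.asnumpy()))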
Environment info (Required)
The latest pre-release version, installed with --pre.
Package used (Python/R/Scala/Julia):
Python
Error Message:
MXNetError: [17:18:04] src/operator/./rnn-inl.h:1006: Check failed: in_data.size() == num_inputs (4 vs. 5) :
Stack trace:
[bt] (0) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x4a157b) [0x7fdedb45957b]
[bt] (1) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x507b9ad) [0x7fdee00339ad]
[bt] (2) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x50b5cac) [0x7fdee006dcac]
[bt] (3) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(mxnet::imperative::PushOperator(mxnet::OpStatePtr const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, mxnet::DispatchMode)::{lambda(mxnet::RunContext, mxnet::engine::CallbackOnComplete)#3}::operator()(mxnet::RunContext, mxnet::engine::CallbackOnComplete) const+0x396) [0x7fdedd6b3d36]
[bt] (4) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(std::_Function_handler<void (mxnet::RunContext), mxnet::imperative::PushOperator(mxnet::OpStatePtr const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, mxnet::DispatchMode)::{lambda(mxnet::RunContext)#4}>::_M_invoke(std::_Any_data const&, mxnet::RunContext)+0x5d) [0x7fdedd6b43cd]
[bt] (5) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x264c4f9) [0x7fdedd6044f9]
[bt] (6) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2658961) [0x7fdedd610961]
[bt] (7) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x265be70) [0x7fdedd613e70]
[bt] (8) /home/ubuntu/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x265c106) [0x7fdedd614106]
Minimum reproducible example
You have to use a GPU to run it, as this feature is GPU-only. The catch is that the backward pass fails silently; mx.nd.waitall() at the end is necessary to surface the error.
import mxnet as mx
import numpy as np
from mxnet.gluon import nn
from mxnet.gluon.rnn import LSTM

ctx = mx.gpu(0)

label = mx.nd.array([1, 2, 3, 4, 5, 6, 7], ctx=ctx)
# random numbers, but with ones at the end as a padding symbol
x = mx.nd.array([[5434, 3232, 776, 323, 1, 1, 1],
                 [4353, 545, 37, 23, 23, 545, 1]], ctx=ctx)

embedding = nn.Embedding(input_dim=6000,
                         output_dim=100,
                         weight_initializer=mx.initializer.Uniform(0.001))
lstm = LSTM(hidden_size=100,
            num_layers=1, dropout=0.2, bidirectional=True,
            use_sequence_length=True)
dense = nn.Dense(1)
l1 = mx.gluon.loss.L1Loss()

embedding.initialize(ctx=ctx)
lstm.initialize(ctx=ctx)
dense.initialize(ctx=ctx)

with mx.autograd.record():
    x_mask = x != 1
    x_len = mx.nd.sum(x_mask, axis=1).astype(np.int32)
    state = lstm.begin_state(batch_size=x.shape[0], ctx=x.context)
    x_emb = embedding(x)
    # LSTM expects (sequence_length, batch_size, features) layout
    x_emb = x_emb.transpose((1, 0, 2))
    a, _ = lstm(x_emb, states=state, sequence_length=x_len)
    out = dense(a)
    loss = l1(out, label)

# this prints the loss, showing that the forward pass works fine
print(loss)
# this one will fail
loss.backward()
mx.nd.waitall()
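Because execution is asynchronous, the failure only materializes at the synchronization point. Wrapping the last two lines as below (just an illustration) makes the deferred error visible:

try:
    loss.backward()
    mx.nd.waitall()  # the MXNetError is raised here, not at backward()
except mx.base.MXNetError as err:
    print('backward failed:', err)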