Why Are the Hidden State and Cell State Vectors Zero After Training an LSTM Model with trainNetwork?

I am training an LSTM model using trainNetwork, and the following is the architecture of my model:
layers = [ ...
sequenceInputLayer(size(X_train{1},1))
layerNormalizationLayer
lstmLayer(x.num_hidden_units,'OutputMode','sequence')
fullyConnectedLayer(x.num_layers_ffnn)
dropoutLayer(0.1)
fullyConnectedLayer(1)
regressionLayer];
I train it with the following options and command:
options = trainingOptions('adam', ...
'MaxEpochs', 75, ...
'MiniBatchSize', x.batch_size, ...
'SequenceLength', 'longest', ...
'Shuffle', 'once', ...
'L2Regularization',0.01,...
'ValidationData',{X_val,Y_val}, ...
'ValidationFrequency',10,...
'Verbose',false,...
'ExecutionEnvironment','multi-gpu');
% Train the LSTM network
net = trainNetwork(X_train, Y_train, layers, options);
After training, the HiddenState and CellState values of the LSTM layer are vectors of zeros. Why is this happening? I expected these vectors to contain non-zero values, showing that the long-term dependency between the input and output parameters has been captured.

Accepted Answer

Neha on 20 Oct 2023
Hi Shubham,
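The LSTM (Long Short-Term Memory) layer is designed to retain information over arbitrary time intervals, which is what lets it learn long-term dependencies. However, the HiddenState and CellState properties of a trained LSTM layer do not store what the network has learned. They hold the initial state that the layer starts from at the beginning of each sequence, and during training each mini-batch starts again from that initial state, so after training these vectors are zeros. This is standard behavior for LSTMs, and it does not mean that the LSTM layer has not learned anything or that it is not working properly. What the layer has learned is stored in its InputWeights, RecurrentWeights, and Bias properties; the hidden and cell states are recomputed from these weights at every time step while a sequence is being processed.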
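For reference, this is the standard LSTM recurrence in the notation of the Deep Learning Toolbox lstmLayer documentation (default gate activation σ = sigmoid, default state activation tanh, ⊙ = element-wise product). W, R, and b correspond to the layer's InputWeights, RecurrentWeights, and Bias properties and are fixed once training ends. The states h_t and c_t are recomputed for every sequence, starting from h_0 and c_0, which is what HiddenState and CellState hold (zeros by default).

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + R_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + R_f h_{t-1} + b_f) \\
g_t &= \tanh(W_g x_t + R_g h_{t-1} + b_g) \\
o_t &= \sigma(W_o x_t + R_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Because h_t and c_t depend on the current sequence, there is no single "trained" hidden state to store; a zero h_0 and c_0 only means that every sequence starts from the same neutral state.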
If you want the LSTM to keep its state across calls (for example, in time-series prediction, where the model should carry over the state from the previous part of the sequence), refer to the explanation of open-loop and closed-loop forecasting in the following documentation link:
That example uses the predictAndUpdateState function, which updates the network state at every time step.
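A minimal sketch of both points, assuming Deep Learning Toolbox with trainNetwork available. It uses hypothetical synthetic data (3 input features) and a reduced version of the layer stack from the question (no layerNormalizationLayer or dropout). It first checks that the trained layer's weights are non-zero while its HiddenState is zero, then calls predictAndUpdateState on a multivariate sequence one time step at a time. It assumes, as recent releases appear to behave, that the updated state is exposed through the HiddenState property of the LSTM layer in the returned network.

```matlab
% Hypothetical toy data: 20 sequences, 3 features, 50 time steps each.
rng(0);
numFeatures = 3;
numSeq = 20;
seqLen = 50;
XTrain = cell(numSeq, 1);
YTrain = cell(numSeq, 1);
for k = 1:numSeq
    XTrain{k} = randn(numFeatures, seqLen);          % features-by-time
    YTrain{k} = cumsum(sum(XTrain{k}, 1)) / seqLen;  % 1-by-time target
end

% Reduced version of the question's sequence-to-sequence regression network.
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(8, 'OutputMode', 'sequence')
    fullyConnectedLayer(1)
    regressionLayer];
opts = trainingOptions('adam', 'MaxEpochs', 5, 'Verbose', false);
net = trainNetwork(XTrain, YTrain, layers, opts);

% 1) After training: learned weights are non-zero, while HiddenState is the
%    initial (zero) state. norm() returns 0 for both a zero vector and [].
lstm = net.Layers(2);
h0norm = norm(lstm.HiddenState);
fprintf('nnz(InputWeights) = %d\n', nnz(lstm.InputWeights));
fprintf('norm(HiddenState) after training = %g\n', h0norm);

% 2) Stateful prediction on a multivariate sequence: feed all features of one
%    time step per call; the LSTM state carries over between calls.
XTest = randn(numFeatures, seqLen);   % hypothetical test sequence
net = resetState(net);                % start from the initial state
YPred = zeros(1, seqLen);
for t = 1:seqLen
    [net, YPred(t)] = predictAndUpdateState(net, XTest(:, t));
end
fprintf('norm(HiddenState) after %d steps = %g\n', seqLen, ...
    norm(net.Layers(2).HiddenState));
```

Since predictAndUpdateState takes a numFeatures-by-numTimeSteps array, nothing in it is specific to univariate signals. Calling `[net, YPred] = predictAndUpdateState(net, XTest)` on the whole sequence at once should leave the network in the same final state as the loop.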
Hope this helps!
Shubham Baisthakur on 20 Oct 2023
Thanks, Neha! Regarding the predictAndUpdateState function you mentioned, does it also apply to LSTM networks with multivariate input features? The example in the linked documentation uses previous time steps of a signal to predict its future steps, which is not the kind of problem I am working on.
