Visualize Activations of LSTM Network

This example shows how to investigate and visualize the features learned by LSTM networks by extracting the activations.

Load the pretrained network. JapaneseVowelsNet is an LSTM network trained on the Japanese Vowels dataset as described in [1] and [2]. It was trained on sequences sorted by sequence length, with a mini-batch size of 27.
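
Sorting by length is typically done because the sequences in a mini-batch are padded to a common length. The sketch below compares the total padding with and without sorting, assuming each mini-batch of 27 consecutive sequences is padded to its longest member. The sequence lengths, their range, and the names seqLengths and paddingFor are hypothetical stand-ins, not the actual dataset.

```matlab
% Hypothetical sequence lengths: random stand-ins, not the Japanese Vowels data
rng(0)
numSequences = 270;
seqLengths = randi([5 30],numSequences,1);
miniBatchSize = 27;

% Total padding (in time steps) when each mini-batch of consecutive
% sequences is padded to the length of its longest sequence
batchStarts = 1:miniBatchSize:numSequences;
paddingFor = @(lens) sum(arrayfun(@(k) ...
    sum(max(lens(k:k+miniBatchSize-1)) - lens(k:k+miniBatchSize-1)), ...
    batchStarts));

unsortedPadding = paddingFor(seqLengths);
sortedPadding = paddingFor(sort(seqLengths));

fprintf("Padding without sorting: %d time steps\n",unsortedPadding)
fprintf("Padding after sorting:   %d time steps\n",sortedPadding)
```

Sorting places sequences of similar length in the same mini-batch, so each batch needs less padding. Here 270 is a multiple of 27, so every mini-batch is full.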

load JapaneseVowelsNet

View the network architecture.

net.Layers

ans = 
  4x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input    Sequence input with 12 dimensions
     2   'lstm'            LSTM              LSTM with 100 hidden units
     3   'fc'              Fully Connected   9 fully connected layer
     4   'softmax'         Softmax           softmax
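
As a check on the architecture, the number of learnable parameters follows from the sizes in the display. This sketch assumes the standard LSTM parameterization, in which four gates (input, forget, cell candidate, output) each have input weights, recurrent weights, and a bias. This matches the InputWeights (4H-by-D), RecurrentWeights (4H-by-H), and Bias (4H-by-1) properties of an LSTM layer with H hidden units and D input features.

```matlab
% Sizes taken from the layer display above
inputSize = 12;        % Sequence input with 12 dimensions
numHiddenUnits = 100;  % LSTM with 100 hidden units
numClasses = 9;        % fully connected layer with 9 outputs

% Four stacked gates, each with input weights, recurrent weights, and a bias
lstmInputWeights = 4*numHiddenUnits*inputSize;           % 400-by-12
lstmRecurrentWeights = 4*numHiddenUnits*numHiddenUnits;  % 400-by-100
lstmBias = 4*numHiddenUnits;                             % 400-by-1
lstmParams = lstmInputWeights + lstmRecurrentWeights + lstmBias;  % 45200

% Fully connected layer: 9-by-100 weights plus 9 biases
fcParams = numClasses*numHiddenUnits + numClasses;       % 909

fprintf("LSTM learnables: %d\n",lstmParams)
fprintf("FC learnables:   %d\n",fcParams)
```

With the network loaded, these counts can be compared against numel(net.Layers(2).InputWeights), numel(net.Layers(2).RecurrentWeights), and numel(net.Layers(2).Bias).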

Load the test data.

load JapaneseVowelsTestData

Visualize the first time series in a plot. Each line corresponds to a feature.

X = XTest{1};

figure
plot(X')
xlabel("Time Step")
title("Test Observation 1")
numFeatures = size(X,1);
legend("Feature " + string(1:numFeatures),Location="northeastoutside")

For each time step of the sequence, get the activations output by the LSTM layer (layer 2) at that time step, and then update the network state.

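sequenceLength = size(X,2);
idxLayer = 2;
outputSize = net.Layers(idxLayer).NumHiddenUnits;

for i = 1:sequenceLength
    % Pass time step i as a 1-by-numFeatures row and keep the updated state
    [features(i,:),state] = predict(net,X(:,i)',Outputs="lstm");
    net.State = state;
end

% Arrange hidden units in rows and time steps in columns
features = features';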
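
To show what the loop records, here is a plain MATLAB sketch of the standard LSTM update. It uses toy sizes and random weights, so all values and the helper names sigmoid and gateRows are hypothetical. The gate order (input, forget, cell candidate, output) and the sigmoid gate and tanh state activations follow the LSTM layer defaults. This illustrates the technique; it is not the network's internal implementation.

```matlab
% Toy sizes and random weights: hypothetical stand-ins, not JapaneseVowelsNet
rng(0)
numFeatures = 12;
numHiddenUnits = 5;
sequenceLength = 8;

% Weights stacked by gate in the order input, forget, cell candidate, output
W = 0.1*randn(4*numHiddenUnits,numFeatures);     % input weights
R = 0.1*randn(4*numHiddenUnits,numHiddenUnits);  % recurrent weights
b = zeros(4*numHiddenUnits,1);                   % bias
X = randn(numFeatures,sequenceLength);           % one sequence, features-by-time

sigmoid = @(z) 1./(1 + exp(-z));
gateRows = @(k) (k-1)*numHiddenUnits+1 : k*numHiddenUnits;

h = zeros(numHiddenUnits,1);   % hidden state: the layer's output (activations)
c = zeros(numHiddenUnits,1);   % cell state
features = zeros(numHiddenUnits,sequenceLength);

for t = 1:sequenceLength
    z = W*X(:,t) + R*h + b;
    ig = sigmoid(z(gateRows(1)));   % input gate
    fg = sigmoid(z(gateRows(2)));   % forget gate
    g  = tanh(z(gateRows(3)));      % cell candidate
    og = sigmoid(z(gateRows(4)));   % output gate
    c = fg.*c + ig.*g;              % carry the cell state forward
    h = og.*tanh(c);                % new hidden state
    features(:,t) = h;              % activation recorded at time step t
end
```

Each column of features is the hidden state after that time step, matching one column of the activation matrix in the example. Carrying h and c into the next iteration plays the same role as setting net.State = state for the real network.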

Visualize the first 10 hidden units using a heatmap.

figure
heatmap(features(1:10,:));
xlabel("Time Step")
ylabel("Hidden Unit")
title("LSTM Activations")

The heatmap shows how strongly each hidden unit activates and highlights how the activations change over time.
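
Plotting the first 10 hidden units is an arbitrary choice. One possible alternative, not part of the original example, is to rank the units by the standard deviation of their activations over time and plot the ten that vary most. The sketch uses a random stand-in for features so that it runs on its own. The names unitStd and topUnits are illustrative.

```matlab
% Stand-in for the numHiddenUnits-by-sequenceLength activation matrix
% computed in the loop above (random values so this block runs on its own)
rng(0)
features = tanh(randn(100,20));

% Variation of each hidden unit's activation over time
unitStd = std(features,0,2);

% The 10 hidden units whose activations vary most
[~,order] = sort(unitStd,'descend');
topUnits = order(1:10);

% With the real activations, plot these rows instead of the first 10:
% heatmap(features(topUnits,:))
```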

References

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, 1999, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset.