Imbalance in sequence-to-sequence classification

I am using an LSTM network for binary sequence-to-sequence classification. My input is a time series, and I need to predict a 0 or 1 at every time step (YTrain). The problem is that YTrain contains far fewer 1s than 0s, so the network basically predicts 0 at every time step and still achieves very high accuracy. I am looking for a way to penalize misclassification of the 1s in YTrain. I would be grateful for any suggestions!
numFeatures = 1;
numHiddenUnits = 200;
numClasses = 2;

layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,'OutputMode','sequence')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

options = trainingOptions('adam', ...
    'MaxEpochs',60, ...
    'GradientThreshold',2, ...
    'Verbose',0, ...
    'Plots','training-progress');

Accepted Answer

Harsh on 20 Dec 2024
Hi Cedric,
You can handle the class imbalance by passing a weighted cross-entropy loss function to the “trainnet” function. Choose the class weights based on how imbalanced your dataset is: since class 1 is underrepresented, assign it a higher weight than class 0, so that misclassifying a 1 contributes more to the loss. The following page describes how to specify a weighted cross-entropy loss as a function handle of the form @(Y,T)crossentropy(Y,T,weights): https://www.mathworks.com/help/deeplearning/ref/trainnet.html#:~:text=For%20weighted%20cross%2Dentropy%2C%20use%20the%20function%20handle%20%40(Y%2CT)crossentropy(Y%2CT%2Cweights)
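One common way to choose the weights is inverse class frequency. The sketch below is only an illustration of that idea: the toy YTrain values are made up, and the scaling (weights that average to 1 when the classes are balanced) is one convention among several, not something the trainnet documentation prescribes.

```matlab
% Toy labels standing in for YTrain: a cell array of categorical row
% vectors, one label per time step. These values are made up.
YTrain = {categorical([0 0 0 1 0 0 0 0],[0 1]); ...
          categorical([0 0 0 0 0 1 1 0 0 0 0 0],[0 1])};

% Pool the labels of all sequences and count each class.
allLabels   = [YTrain{:}];            % 1-by-20 categorical
classNames  = categories(allLabels);  % {'0'; '1'}
classCounts = countcats(allLabels);   % counts in the order of classNames

% Inverse-frequency weights: the rarer the class, the larger its weight.
numClasses   = numel(classCounts);
classWeights = sum(classCounts) ./ (numClasses * classCounts);

disp(table(classNames, classCounts(:), classWeights(:), ...
    'VariableNames', {'Class','Count','Weight'}))
```

With 17 zeros and 3 ones, class 1 receives a weight roughly 5.7 times that of class 0, which matches the 17:3 imbalance.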
The following MATLAB Answers thread is also relevant to your question: https://www.mathworks.com/matlabcentral/answers/434918-weighted-classification-layer-for-time-series-lstm
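Applied to the network from the question, the weighted loss could be wired up roughly as follows. This is a sketch under several assumptions rather than a tested solution: the weight values are hypothetical, the WeightsFormat="C" option is used here to mark the vector as per-class weights (please confirm against the crossentropy documentation for your release), and XTrain and YTrain are assumed to be laid out the way trainnet expects sequence data. Note that trainnet takes a network without an output layer, so classificationLayer is dropped.

```matlab
numFeatures = 1;
numHiddenUnits = 200;
numClasses = 2;

% Same architecture as in the question, minus classificationLayer:
% with trainnet the loss is passed separately instead of as a layer.
layers = [
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits,OutputMode="sequence")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

% Hypothetical weights in class order [0 1]; derive real values from
% the class frequencies in your own YTrain.
classWeights = [0.6 3.3];

% Weighted cross-entropy as a function handle, following the form given
% on the trainnet page. WeightsFormat="C" is an assumption (see above).
lossFcn = @(Y,T) crossentropy(Y,T,classWeights,WeightsFormat="C");

options = trainingOptions("adam", ...
    MaxEpochs=60, ...
    GradientThreshold=2, ...
    Verbose=false, ...
    Plots="training-progress");

% XTrain and YTrain are the training data from the question and are not
% defined in this sketch, so the call is skipped if they are missing.
if exist("XTrain","var") && exist("YTrain","var")
    net = trainnet(XTrain,YTrain,layers,lossFcn,options);
end
```

Since accuracy is dominated by the majority class, it is worth checking per-class recall or a confusion matrix on held-out data after training to see whether the 1s are actually being detected.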
