Classify Text Data Using Deep Learning

This example shows how to classify text descriptions of weather reports using a deep learning long short-term memory (LSTM) network.

Text data is naturally sequential. A piece of text is a sequence of words, which might have dependencies between them. To learn and use long-term dependencies to classify sequence data, use an LSTM neural network. An LSTM network is a type of recurrent neural network (RNN) that can learn long-term dependencies between time steps of sequence data.

To input text to an LSTM network, first convert the text data into numeric sequences. You can do this using a word encoding, which maps documents to sequences of numeric indices. For better results, also include a word embedding layer in the network. Word embeddings map words in a vocabulary to numeric vectors rather than scalar indices. These embeddings capture semantic details of the words, so that words with similar meanings have similar vectors. They also model relationships between words through vector arithmetic. For example, the relationship "king is to queen as man is to woman" is described by the equation king - man + woman = queen.
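As a toy illustration of the vector arithmetic described above, the following sketch uses made-up 3-D vectors (the words, values, and variable names are hypothetical, not taken from any trained embedding). It computes king - man + woman and finds the remaining vocabulary word whose vector has the highest cosine similarity to the result. Real embeddings have many more dimensions, and this analogy behavior is usually reported for embeddings trained on large corpora, not necessarily for the small embedding learned in this example.

```matlab
% Hypothetical 3-D embeddings; columns loosely mean [royal, male, female]
words = ["king"; "queen"; "man"; "woman"; "apple"];
E = [0.9 0.8 0.1;    % king
     0.9 0.1 0.8;    % queen
     0.1 0.9 0.1;    % man
     0.1 0.1 0.9;    % woman
     0.0 0.3 0.3];   % apple

% Compute king - man + woman
v = E(words == "king",:) - E(words == "man",:) + E(words == "woman",:);

% Cosine similarity between v and every word vector
sim = (E*v.') ./ (sqrt(sum(E.^2,2)) * norm(v));

% Exclude the query words themselves, then take the best match
sim(ismember(words,["king" "man" "woman"])) = -Inf;
[~,best] = max(sim);
bestWord = words(best)    % "queen" for these toy vectors
```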

There are four steps in training and using the LSTM network in this example:

  • Import and preprocess the data.

  • Convert the words to numeric sequences using a word encoding.

  • Create and train an LSTM network with a word embedding layer.

  • Classify new text data using the trained LSTM network.

Import Data

Import the weather reports data. This data contains labeled textual descriptions of weather events. To import the text data as strings, specify the text type to be 'string'.

filename = "weatherReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×16 table
            Time             event_id          state              event_type         damage_property    damage_crops    begin_lat    begin_lon    end_lat    end_lon                                                                                             event_narrative                                                                                             storm_duration    begin_day    end_day    year       end_timestamp    
    ____________________    __________    ________________    ___________________    _______________    ____________    _________    _________    _______    _______    _________________________________________________________________________________________________________________________________________________________________________________________________    ______________    _________    _______    ____    ____________________

    22-Jul-2016 16:10:00    6.4433e+05    "MISSISSIPPI"       "Thunderstorm Wind"       ""                "0.00K"         34.14        -88.63     34.122     -88.626    "Large tree down between Plantersville and Nettleton."                                                                                                                                                  00:05:00          22          22       2016    22-Jul-0016 16:15:00
    15-Jul-2016 17:15:00    6.5182e+05    "SOUTH CAROLINA"    "Heavy Rain"              "2.00K"           "0.00K"         34.94        -81.03      34.94      -81.03    "One to two feet of deep standing water developed on a street on the Winthrop University campus after more than an inch of rain fell in less than an hour. One vehicle was stalled in the water."       00:00:00          15          15       2016    15-Jul-0016 17:15:00
    15-Jul-2016 17:25:00    6.5183e+05    "SOUTH CAROLINA"    "Thunderstorm Wind"       "0.00K"           "0.00K"         35.01        -80.93      35.01      -80.93    "NWS Columbia relayed a report of trees blown down along Tom Hall St."                                                                                                                                  00:00:00          15          15       2016    15-Jul-0016 17:25:00
    16-Jul-2016 12:46:00    6.5183e+05    "NORTH CAROLINA"    "Thunderstorm Wind"       "0.00K"           "0.00K"         35.64        -82.14      35.64      -82.14    "Media reported two trees blown down along I-40 in the Old Fort area."                                                                                                                                  00:00:00          16          16       2016    16-Jul-0016 12:46:00
    15-Jul-2016 14:28:00    6.4332e+05    "MISSOURI"          "Hail"                    ""                ""              36.45        -89.97      36.45      -89.97    ""                                                                                                                                                                                                      00:07:00          15          15       2016    15-Jul-0016 14:35:00
    15-Jul-2016 16:31:00    6.4332e+05    "ARKANSAS"          "Thunderstorm Wind"       ""                "0.00K"         35.85         -90.1     35.838     -90.087    "A few tree limbs greater than 6 inches down on HWY 18 in Roseland."                                                                                                                                    00:09:00          15          15       2016    15-Jul-0016 16:40:00
    15-Jul-2016 16:03:00    6.4343e+05    "TENNESSEE"         "Thunderstorm Wind"       "20.00K"          "0.00K"        35.056       -89.937      35.05     -89.904    "Awning blown off a building on Lamar Avenue. Multiple trees down near the intersection of Winchester and Perkins."                                                                                     00:07:00          15          15       2016    15-Jul-0016 16:10:00
    15-Jul-2016 17:27:00    6.4344e+05    "TENNESSEE"         "Hail"                    ""                ""             35.385        -89.78     35.385      -89.78    "Quarter size hail near Rosemark."                                                                                                                                                                      00:05:00          15          15       2016    15-Jul-0016 17:32:00

Remove the rows of the table with empty reports.

idxEmpty = strlength(data.event_narrative) == 0;
data(idxEmpty,:) = [];
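The test strlength(...) == 0 only catches narratives that are exactly empty. If, as an assumption, some narratives in your copy of the data contain only spaces, trimming them first also catches those. The toy narratives below are hypothetical.

```matlab
% Toy narratives: one real report, one empty string, one whitespace-only string
narratives = ["Quarter size hail near Rosemark."; ""; "   "];

% Exact-empty test, as used in the example
idxEmptyExact = strlength(narratives) == 0;

% Also treat whitespace-only narratives as empty
idxEmptyTrimmed = strlength(strtrim(narratives)) == 0;
```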

The goal of this example is to classify events by the label in the event_type column. To divide the data into classes, convert these labels to categorical.

data.event_type = categorical(data.event_type);

View the distribution of the classes in the data using a histogram. To make the labels easier to read, increase the width of the figure.

f = figure;
f.Position(3) = 1.5*f.Position(3);

h = histogram(data.event_type);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

The classes of the data are imbalanced, with many classes containing few observations. When the classes are imbalanced in this way, the network might converge to a less accurate model. To prevent this problem, remove any classes that appear fewer than ten times.

Get the frequency counts of the classes and the class names from the histogram.

classCounts = h.BinCounts;
classNames = h.Categories;

Find the classes containing fewer than ten observations.

idxLowCounts = classCounts < 10;
infrequentClasses = classNames(idxLowCounts)
infrequentClasses = 1×8 cell array
    {'Freezing Fog'}    {'Hurricane'}    {'Lakeshore Flood'}    {'Marine Dense Fog'}    {'Marine Strong Wind'}    {'Marine Tropical Depression'}    {'Seiche'}    {'Sneakerwave'}

Remove these infrequent classes from the data. Use removecats to remove the unused categories from the categorical data.

idxInfrequent = ismember(data.event_type,infrequentClasses);
data(idxInfrequent,:) = [];
data.event_type = removecats(data.event_type);
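The class counts can also be obtained without drawing a histogram, using countcats. The following sketch applies the same filter-then-removecats steps to a tiny hypothetical label vector; the threshold is 2 rather than 10 only because the toy data set is so small.

```matlab
% Hypothetical labels
labels = categorical(["Hail";"Hail";"Seiche";"Flood";"Flood";"Hail"]);

counts = countcats(labels);        % counts in the order of categories(labels)
names = categories(labels);        % {'Flood';'Hail';'Seiche'}

minCount = 2;                      % the example uses 10
rare = names(counts < minCount);   % classes to drop

labels(ismember(labels,rare)) = [];   % delete observations of rare classes
labels = removecats(labels);          % drop the now-unused categories
```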

Now the data is sorted into classes of reasonable size. The next step is to partition it into sets for training, validation, and testing. Partition the data into a training partition and a held-out partition for validation and testing. Specify the holdout percentage to be 30%.

cvp = cvpartition(data.event_type,'Holdout',0.3);
dataTrain = data(training(cvp),:);
dataHeldOut = data(test(cvp),:);

Partition the held-out set again to get a validation set. Specify the holdout percentage to be 50%. This results in a partitioning of 70% training observations, 15% validation observations, and 15% test observations.

cvp = cvpartition(dataHeldOut.event_type,'Holdout',0.5);
dataValidation = dataHeldOut(training(cvp),:);
dataTest = dataHeldOut(test(cvp),:);
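To sanity-check the 70/15/15 split, the following sketch repeats the two-stage holdout on 100 hypothetical labels and confirms that every observation lands in exactly one partition. When the first input to cvpartition is a grouping variable, the holdout partition is stratified by class, so the fractions come out approximately, not necessarily exactly, 0.70, 0.15, and 0.15. The labels and the rng seed are assumptions for illustration.

```matlab
rng(0)                                             % reproducible toy partition
labels = categorical([repmat("A",60,1); repmat("B",40,1)]);   % 100 hypothetical labels

% Stage 1: hold out 30% of the observations
cvp1 = cvpartition(labels,'Holdout',0.3);
idxTrain = training(cvp1);
idxTrain = idxTrain(:);                            % ensure a column vector
heldOut = find(test(cvp1));                        % row numbers of held-out set

% Stage 2: split the held-out rows 50/50 into validation and test
cvp2 = cvpartition(labels(heldOut),'Holdout',0.5);
idxValidation = false(size(labels));
idxValidation(heldOut(training(cvp2))) = true;
idxTest = false(size(labels));
idxTest(heldOut(test(cvp2))) = true;

fractions = [mean(idxTrain) mean(idxValidation) mean(idxTest)]
```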

Extract the text data and labels from the partitioned tables.

textDataTrain = dataTrain.event_narrative;
textDataValidation = dataValidation.event_narrative;
textDataTest = dataTest.event_narrative;
YTrain = dataTrain.event_type;
YValidation = dataValidation.event_type;
YTest = dataTest.event_type;

To check that you have imported the data correctly, visualize the training text data using a word cloud.

figure
wordcloud(textDataTrain);
title("Training Data")

Preprocess Text Data

Preprocess the training and validation data. Convert the text to lowercase, tokenize it, and then erase the punctuation. Do not stem words or remove words (such as stop words), as these steps can lead to a worse word embedding fit.

textDataTrain = lower(textDataTrain);
documentsTrain = tokenizedDocument(textDataTrain);
documentsTrain = erasePunctuation(documentsTrain);

textDataValidation = lower(textDataValidation);
documentsValidation = tokenizedDocument(textDataValidation);
documentsValidation = erasePunctuation(documentsValidation);
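The same three preprocessing steps are repeated for the training, validation, test, and new data. One way to keep them consistent is to wrap them in a function. The name preprocessWeatherNarratives is hypothetical; the body uses only the functions already shown in this example.

```matlab
function documents = preprocessWeatherNarratives(textData)
% Hypothetical helper: apply this example's preprocessing to a string array.
textData = lower(textData);                 % convert to lowercase
documents = tokenizedDocument(textData);    % split the text into tokens
documents = erasePunctuation(documents);    % erase punctuation
end
```

For example, documentsTrain = preprocessWeatherNarratives(textDataTrain); reproduces the three training lines above. In a script, define the function at the end of the file or save it in its own .m file.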

View the first few preprocessed training documents.

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

   (1,1)   7 tokens: large tree down between plantersville and nettleton
   (2,1)  37 tokens: one to two feet of deep standing water developed on a stre…
   (3,1)  13 tokens: nws columbia relayed a report of trees blown down along to…
   (4,1)  13 tokens: media reported two trees blown down along i40 in the old f…
   (5,1)  14 tokens: a few tree limbs greater than 6 inches down on hwy 18 in r…

Convert Documents to Sequences

To input the documents into an LSTM network, use a word encoding to convert the documents into sequences of numeric indices.

To create a word encoding, use the wordEncoding function.

enc = wordEncoding(documentsTrain);
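To see what the encoding stores, the following sketch builds one from two hypothetical documents and maps a word to its index and back with word2ind and ind2word. The exact index assigned to a word depends on how the encoding orders its vocabulary, so the sketch does not hard-code it.

```matlab
% Two toy documents (hypothetical)
docs = tokenizedDocument(["hail near town"; "tree down near town"]);
enc = wordEncoding(docs);

numWords = enc.NumWords            % 5 distinct words: hail, near, town, tree, down
idx = word2ind(enc,"near");        % numeric index assigned to "near"
word = ind2word(enc,idx)           % maps the index back to "near"
```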

The next conversion step is to pad and truncate the documents so that they all have the same length. The trainingOptions function provides options to pad and truncate input sequences automatically, but these options are not well suited for sequences of word vectors. Instead, pad and truncate the sequences manually. Left-padding and truncating the sequences this way might improve training.
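The idea behind left-padding and truncation can be sketched on a plain vector of word indices. The helper leftPadOrTruncate is hypothetical and only illustrates the concept; in this example, doc2sequence does the work. The padding value of 0 and the choice to keep the first targetLength indices of a long sequence are assumptions, so check the doc2sequence reference page for the truncation side it actually uses.

```matlab
function padded = leftPadOrTruncate(seq,targetLength,padValue)
% Hypothetical helper: make a sequence of word indices exactly targetLength long.
% Assumptions: padValue defaults to 0, and long sequences keep their first
% targetLength indices (doc2sequence may truncate differently).
if nargin < 3
    padValue = 0;
end
seq = seq(:).';                          % work with a row vector
n = numel(seq);
if n >= targetLength
    padded = seq(1:targetLength);        % truncate
else
    padded = [repmat(padValue,1,targetLength - n) seq];   % left-pad
end
end
```

For example, leftPadOrTruncate([5 6 7],5) returns [0 0 5 6 7]. Padding on the left places the real words at the end of the sequence, next to the time step whose LSTM output the network uses when 'OutputMode' is 'last'.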

To pad and truncate the documents, first choose a target length, and then truncate documents that are longer than it and left-pad documents that are shorter than it. For best results, the target length should be short without discarding large amounts of data. To find a suitable target length, view a histogram of the training document lengths.

documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")

Most of the training documents have fewer than 75 tokens. Use this as your target length for truncation and padding.
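To put a number on "most of the training documents", you can compute the fraction of documents that fit within the target length, or pick the smallest length that covers a chosen fraction. The lengths and the 80% coverage target below are hypothetical; in the example, the lengths come from doclength(documentsTrain).

```matlab
% Hypothetical document lengths
documentLengths = [7 37 13 13 14 80 120];
targetLength = 75;

% Fraction of documents that fit without truncation
coverage = mean(documentLengths <= targetLength)      % 5 of 7 documents

% Alternative: smallest length that covers at least 80% of the documents
sortedLengths = sort(documentLengths);
lengthFor80 = sortedLengths(ceil(0.8*numel(sortedLengths)))   % 80 here
```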

Convert the documents to sequences of numeric indices using doc2sequence. To truncate or left-pad the sequences to have length 75, set the 'Length' option to 75.

XTrain = doc2sequence(enc,documentsTrain,'Length',75);
XTrain(1:5)
ans = 5×1 cell array
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}

Convert the validation documents to sequences using the same options.

XValidation = doc2sequence(enc,documentsValidation,'Length',75);

Create and Train LSTM Network

Define the LSTM network architecture. To input sequence data into the network, include a sequence input layer and set the input size to 1. Next, include a word embedding layer of dimension 100 and the same number of words as the word encoding. Next, include an LSTM layer and set the number of hidden units to 180. To use the LSTM layer for a sequence-to-label classification problem, set the output mode to 'last'. Finally, add a fully connected layer with the same size as the number of classes, a softmax layer, and a classification layer.

inputSize = 1;
embeddingDimension = 100;
numWords = enc.NumWords;        % vocabulary size of the word encoding
numHiddenUnits = 180;           % number of LSTM hidden units
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 100 dimensions and 16954 unique words
     3   ''   LSTM                    LSTM with 180 hidden units
     4   ''   Fully Connected         39 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex
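For a rough sense of model size, the learnable parameters of the layers above can be counted by hand. This sketch assumes the standard LSTM formulation with four gates, each having input weights, recurrent weights, and a bias, and assumes the embedding stores one vector per vocabulary word (some releases may store an extra out-of-vocabulary column). The values are copied from the layer display; analyzeNetwork(layers) can show the sizes MATLAB actually uses.

```matlab
embeddingDimension = 100;
numWords = 16954;        % from the layer display above
numHiddenUnits = 180;
numClasses = 39;

% Word embedding: one embeddingDimension-element vector per word (assumed)
embeddingParams = embeddingDimension*numWords;

% LSTM: 4 gates x (input weights + recurrent weights + bias)
lstmParams = 4*numHiddenUnits*(embeddingDimension + numHiddenUnits) + 4*numHiddenUnits;

% Fully connected: weights + bias
fcParams = numClasses*numHiddenUnits + numClasses;

totalParams = embeddingParams + lstmParams + fcParams
```

Under these assumptions, the embedding layer accounts for most of the roughly 1.9 million parameters.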

Specify the training options. Set the solver to 'adam', train for 10 epochs, and set the gradient threshold to 1. Set the initial learn rate to 0.01. To monitor the training progress, set the 'Plots' option to 'training-progress'. Specify the validation data using the 'ValidationData' option. To suppress verbose output, set 'Verbose' to false.

By default, trainNetwork uses a GPU if one is available (requires Parallel Computing Toolbox™ and a CUDA®-enabled GPU with compute capability 3.0 or higher). Otherwise, it uses the CPU. To specify the execution environment manually, use the 'ExecutionEnvironment' name-value pair argument of trainingOptions. Training on a CPU can take significantly longer than training on a GPU.

options = trainingOptions('adam', ...
    'MaxEpochs',10, ...    
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);
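If you want to choose the hardware explicitly, as described in the paragraph on execution environments, a minimal sketch is to add 'ExecutionEnvironment' to the same options. Here 'cpu' forces CPU training; 'auto' (the default) and 'gpu' are other accepted values. The validation data is omitted only to keep the snippet self-contained.

```matlab
% Force training on the CPU even if a supported GPU is present
options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',10, ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'Verbose',false);
```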

Train the LSTM network using the trainNetwork function.

net = trainNetwork(XTrain,YTrain,layers,options);

Test LSTM Network

To test the LSTM network, first prepare the test data in the same way as the training data. Then make predictions on the preprocessed test data using the trained LSTM network net.

Preprocess the test data using the same steps as the training documents.

textDataTest = lower(textDataTest);
documentsTest = tokenizedDocument(textDataTest);
documentsTest = erasePunctuation(documentsTest);

Convert the test documents to sequences using doc2sequence with the same options as when creating the training sequences.

XTest = doc2sequence(enc,documentsTest,'Length',75);
XTest(1:5)
ans = 5×1 cell array
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}

Classify the test documents using the trained LSTM network.

YPred = classify(net,XTest);

Calculate the classification accuracy. The accuracy is the proportion of labels that the network predicts correctly.

accuracy = sum(YPred == YTest)/numel(YPred)
accuracy = 0.8691
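With imbalanced classes, overall accuracy can hide poor performance on small classes. The following sketch computes the same accuracy plus per-class recall (the fraction of each true class predicted correctly) on hypothetical labels. For the real test set, confusionchart(YTest,YPred) gives a visual breakdown.

```matlab
% Hypothetical true and predicted labels (same set of categories)
YTest = categorical(["Hail";"Hail";"Flash Flood";"Thunderstorm Wind";"Hail"]);
YPred = categorical(["Hail";"Flash Flood";"Flash Flood";"Thunderstorm Wind";"Hail"]);

accuracy = mean(YPred == YTest)          % 4 of 5 correct

% Per-class recall, in the order of categories(YTest)
classNames = categories(YTest);
recall = zeros(numel(classNames),1);
for i = 1:numel(classNames)
    isClass = YTest == classNames{i};               % observations of class i
    recall(i) = mean(YPred(isClass) == classNames{i});
end
```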

Predict Using New Data

Classify the event type of three new weather reports. Create a string array containing the new weather reports.

reportsNew = [ ...
    "Lots of water damage to computer equipment inside the office."
    "A large tree is downed and blocking traffic outside Apple Hill."
    "Damage to many car windshields in parking lot."];

Preprocess the text data using the same steps as the training documents.

reportsNew = lower(reportsNew);
documentsNew = tokenizedDocument(reportsNew);
documentsNew = erasePunctuation(documentsNew);

Convert the text data to sequences using doc2sequence with the same options as when creating the training sequences.

XNew = doc2sequence(enc,documentsNew,'Length',75);

Classify the new sequences using the trained LSTM network.

[labelsNew,score] = classify(net,XNew);
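Besides the labels, classify returns score, a matrix with one row per report and one column per class that holds the softmax probabilities. The column order follows the network's class list, which you can inspect with net.Layers(end).Classes. The sketch below uses a hypothetical score matrix to extract each report's top probability and to flag low-confidence predictions; the 0.7 threshold is an arbitrary assumption.

```matlab
% Hypothetical scores for 2 reports and 3 classes (each row sums to 1)
score = [0.05 0.80 0.15;
         0.60 0.30 0.10];
[topScore,topIdx] = max(score,[],2)   % topScore = [0.80; 0.60], topIdx = [2; 1]

% Flag predictions whose top score is below a hypothetical threshold
lowConfidence = topScore < 0.7        % [false; true]
```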

Show the weather reports with their predicted labels.

[reportsNew string(labelsNew)]
ans = 3×2 string array
    "lots of water damage to computer equipment inside the office."      "Flash Flood"      
    "a large tree is downed and blocking traffic outside apple hill."    "Thunderstorm Wind"
    "damage to many car windshields in parking lot."                     "Hail"             

Related Topics