Time-Frequency Feature Embedding with Deep Metric Learning

Since R2023a

This example shows how to use deep metric learning with a supervised contrastive loss to construct feature embeddings based on a time-frequency analysis of electroencephalographic (EEG) signals. The learned time-frequency embeddings reduce the dimensionality of the time-series data by a factor of 16. You can use these embeddings to classify EEG time-series from persons with and without epilepsy using a support vector machine classifier.

Deep Metric Learning

Deep metric learning attempts to learn a nonlinear feature embedding, or encoder, that reduces the distance (a metric) between examples from the same class and increases the distance between examples from different classes. Loss functions that work in this way are often referred to as contrastive. This example uses supervised deep metric learning with a particular contrastive loss function called the normalized temperature-scaled cross-entropy loss [3],[4],[8]. The figure shows the general workflow for this supervised deep metric learning.

[Figure: workflow for supervised deep metric learning (metric_learning_text.png)]

Positive pairs refer to training samples with the same label, while negative pairs refer to training samples with different labels. A distance, or similarity, matrix is formed from the positive and negative pairs. In this example, the cosine similarity matrix is used. From these distances, losses are computed and aggregated (reduced) to form a single scalar-valued loss for use in gradient-descent learning.
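
As a rough illustration of the pair construction described above, the following sketch builds a cosine similarity matrix and the positive-pair and negative-pair masks for a toy mini-batch. The embeddings E and labels y are made-up values chosen for illustration and are not data from this example.

```matlab
% Toy mini-batch: 4 embeddings (columns) of dimension 3 with class labels.
% E and y are hypothetical values chosen only for illustration.
E = [1 2 0 0;
     0 0 1 3;
     0 0 0 0];
y = [0 0 1 1];

% Normalize each column to unit length, then form the cosine similarity matrix.
En = E ./ sqrt(sum(E.^2,1));
S = En' * En;                 % 4-by-4, S(i,j) is the cosine similarity of i and j

% Positive pairs share a label (self-pairs excluded); negative pairs do not.
sameClass = (y' == y);
posMask = sameClass & ~eye(numel(y));
negMask = ~sameClass;

disp(S)
disp(posMask)
disp(negMask)
```

In this toy batch, embeddings from the same class point in the same direction, so their similarity is 1, while embeddings from different classes are orthogonal and have similarity 0.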

Deep metric learning is also applicable in weakly supervised, self-supervised, and unsupervised contexts. A wide variety of distance measures (metrics), losses, reducers, and regularizers are employed in deep metric learning.

Data — Description, Attribution, and Download Instructions

The data used in this example is the Bonn EEG Data Set. The data is currently available at EEG Data Download and Ralph Andrzejak's EEG data download page. See Ralph Andrzejak's EEG data for legal conditions on the use of the data. The authors have kindly permitted the use of the data in this example.

The data in this example were first analyzed and reported in:

Andrzejak, Ralph G., Klaus Lehnertz, Florian Mormann, Christoph Rieke, Peter David, and Christian E. Elger. "Indications of Nonlinear Deterministic and Finite-Dimensional Structures in Time Series of Brain Electrical Activity: Dependence on Recording Region and Brain State." Physical Review E 64, no. 6 (2001). <https://doi.org/10.1103/physreve.64.061907>

The data consists of five sets of 100 single-channel EEG recordings. The single-channel recordings were selected from 128-channel EEG recordings after each channel was visually inspected for obvious artifacts and checked against a weak stationarity criterion. See the linked paper for details.

The original paper designates these five sets as A-E. Each recording is 23.6 seconds in duration sampled at 173.61 Hz. Each time series contains 4097 samples. The conditions are as follows:
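
The stated duration follows directly from the sample count and the sample rate. A minimal check, using only the values quoted above:

```matlab
fs = 173.61;          % sample rate in Hz (from the text)
numSamples = 4097;    % samples per recording (from the text)
durationSec = numSamples/fs;
fprintf('Duration: %.1f s\n',durationSec)
```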

A -- Normal subjects with eyes open

B -- Normal subjects with eyes closed

C -- Seizure-free recordings from patients with epilepsy. Recordings from the hippocampus in the hemisphere opposite the epileptogenic zone.

D -- Seizure-free recordings from patients with epilepsy. Recordings from the epileptogenic zone.

E -- Recordings from patients with epilepsy showing seizure activity.

The zip files corresponding to this data are labeled as z.zip (A), o.zip (B), n.zip (C), f.zip (D), and s.zip (E).

The example assumes you have downloaded and unzipped the zip files into folders named Z, O, N, F, and S respectively. In MATLAB® you can do this by creating a parent folder and using that as the OUTPUTDIR variable in the unzip command. This example uses the folder designated by MATLAB as tempdir as the parent folder. If you choose to use a different folder, adjust the value of parentDir accordingly. The following code assumes that all the .zip files have been downloaded into parentDir. Unzip the files by folder into a subfolder called BonnEEG.

parentDir = tempdir;
cd(parentDir)
mkdir('BonnEEG')
dataDir = fullfile(parentDir,'BonnEEG');
unzip('z.zip',dataDir)
unzip('o.zip',dataDir)
unzip('n.zip',dataDir)
unzip('f.zip',dataDir)
unzip('s.zip',dataDir)

Creating In-Memory Data and Labels

The individual EEG time series are stored as .txt files in each of the Z, O, N, F, and S folders under dataDir. Use a tabularTextDatastore to read the data. Create a tabular text datastore that includes the subfolders. The categorical array of signal labels is created from the file names in a later step.

tds = tabularTextDatastore(dataDir,'IncludeSubfolders',true,'FileExtensions','.txt');

The zip files were created on macOS, so unzip may create a __MACOSX folder that contains extra files. If those files exist, remove them from the datastore.

extraTXT = contains(tds.Files,'__MACOSX');
tds.Files(extraTXT) = [];

Create labels for the data based on the first letter of the text file name.

labels = filenames2labels(tds.Files,'ExtractBetween',[1 1]);
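
To make the labeling rule concrete, the following toolbox-free sketch extracts the first character of each file name. The file paths are hypothetical stand-ins for the entries of tds.Files, and the loop only approximates what the ExtractBetween option is doing here.

```matlab
% Hypothetical file paths; the real list comes from tds.Files.
files = {fullfile('BonnEEG','Z','Z001.txt'); ...
         fullfile('BonnEEG','S','S042.txt')};
firstLetters = cell(size(files));
for k = 1:numel(files)
    [~,name] = fileparts(files{k});   % file name without folder or extension
    firstLetters{k} = name(1);        % first character of the file name
end
disp(firstLetters)
```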

Each read of the tabular text datastore creates a table containing the data. Create a cell array of all signals reshaped as row vectors so they conform with the deep learning networks used in the example.

ii = 1;
eegData = cell(numel(labels),1);
while hasdata(tds)
    tsTable = read(tds);
    ts = tsTable.Var1;
    eegData{ii} = reshape(ts,1,[]);
    ii = ii+1;
end

Time-Frequency Feature Embedding Deep Network

Here we construct a deep learning network that creates an embedding of the input signal based on a time-frequency analysis.

TFnet = [sequenceInputLayer(1,'MinLength',4097,'Name',"input")
    cwtLayer('SignalLength',4097,'IncludeLowpass',true,'Wavelet','amor',...
    'FrequencyLimits',[0 0.23])
    convolution2dLayer([5,10],1,'stride',2)
    maxPooling2dLayer([5,10])
    convolution2dLayer([5,10],5,'Padding','same')
    maxPooling2dLayer([5,10])
    batchNormalizationLayer
    reluLayer
    convolution2dLayer([5,10],10,'Padding','same')
    maxPooling2dLayer([2,4])
    batchNormalizationLayer
    reluLayer
    flattenLayer
    globalAveragePooling1dLayer
    fullyConnectedLayer(256)];
TFnet = dlnetwork(TFnet);

After the input layer, the network obtains the continuous wavelet transform (CWT) of the data using the analytic Morlet wavelet. The output of cwtLayer is the magnitude of the CWT, or scalogram. Unlike the analyses in [1], [2], and [7], no pre-processing bandpass filter is used in this network. Instead, the CWT is obtained only over the frequency range of [0, 0.23] cycles/sample, which is equivalent to [0, 39.93] Hz for the sample rate of 173.61 Hz. This is the approximate range of the bandpass filter applied to the data before analysis in [1]. After the network obtains the scalogram, the network cascades a series of 2-D convolutional, max pooling, batch normalization, and ReLU layers. The final layer is a fully connected layer with 256 output units. This results in a 16-fold reduction in the size of the input. See [7] for another scalogram-based analysis of this data and [2] for another wavelet-based analysis using the tunable Q-factor wavelet transform.
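
The frequency conversion and the reduction factor quoted above can be checked with a few lines of arithmetic. This sketch uses only the numbers given in the text.

```matlab
fs = 173.61;              % sample rate in Hz
fNorm = 0.23;             % upper CWT frequency limit in cycles/sample
fUpperHz = fNorm*fs;      % upper limit in Hz, about 39.93
reduction = 4097/256;     % input samples per embedding element, about 16
fprintf('Upper limit: %.2f Hz, reduction factor: %.2f\n',fUpperHz,reduction)
```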

Differentiating Normal, Pre-Seizure, and Seizure EEG

Given the five conditions present in the data, there are multiple meaningful and clinically informative ways to partition the data. One relevant way is to group the Z and O labels (non-epileptic subjects with eyes open and closed) as "Normal". Similarly, the two conditions recorded in the persons with epilepsy without overt seizure activity (N and F) may be grouped as "Pre-seizure". Finally, we designate the recordings obtained in epileptic subjects with seizure activity as "Seizure". To create labels, which may be cast to numeric values during training, designate these three classes as:

  • 0 -- "Normal"

  • 1 -- "Pre-seizure"

  • 2 -- "Seizure"

Partition the data into training and test sets. First, create the new labels in order to partition the data. Examine the number of examples in each class.

labelsPS = labels;
labelsPS = removecats(labelsPS,{'F','N','O','S','Z'});
labelsPS(labels == categorical("Z") | labels == categorical("O")) = categorical("0");
labelsPS(labels == categorical("N") | labels == categorical("F")) = categorical("1");
labelsPS(labels == categorical("S")) = categorical("2");
labelsPS(isundefined(labelsPS)) = [];
summary(labelsPS)
labelsPS: 500×1 categorical

     0                200 
     1                200 
     2                100 
     <undefined>        0 

The resulting classes are unbalanced, with twice as many signals in each of the "Normal" and "Pre-seizure" categories as in the "Seizure" category. Partition the data into a training set for the encoder and a hold-out test set. Allocate 80% of the data to the training set and 20% to the test set.

idxPS = splitlabels(labelsPS,[0.8 0.2]);
TrainDataPS = eegData(idxPS{1});
TrainLabelsPS = labelsPS(idxPS{1});
testDataPS = eegData(idxPS{2});
testLabelsPS = labelsPS(idxPS{2});
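
Assuming the split preserves the class proportions within each set (an assumption about how splitlabels behaves, not something verified here), the expected class counts can be sketched as follows.

```matlab
classCounts = [200 200 100];            % Normal, Pre-seizure, Seizure
trainCounts = round(0.8*classCounts);   % expected training counts per class
testCounts = classCounts - trainCounts; % remaining examples form the test set
fprintf('Train: %d %d %d, Test: %d %d %d\n',trainCounts,testCounts)
```

Under that assumption, the training set holds 400 signals and the test set holds 100.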

Training the Encoder

To train the encoder, set trainEmbedder to true. To skip the training and load a pretrained encoder and corresponding embeddings, set trainEmbedder to false and go to the Test Data Embeddings section.

trainEmbedder = true;

Because this example uses a custom loss function, you must use a custom training loop. To manage data through the custom training loop, use a signalDatastore (Signal Processing Toolbox) with a custom read function that normalizes the input signals to have zero mean and unit standard deviation.

if trainEmbedder
    sdsTrain = signalDatastore(TrainDataPS,MemberNames = string(TrainLabelsPS));
    transTrainDS = transform(sdsTrain,@(x,info)helperReadData(x,info),'IncludeInfo',true);
end
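
The normalization applied by the helperReadData function listed at the end of this example can be illustrated on a short row vector. The signal values below are hypothetical.

```matlab
x = [2 4 4 4 5 5 7 9];     % toy row-vector signal (hypothetical values)
mu = mean(x,2);
sigma = std(x,1,2);        % standard deviation normalized by N, as in the helper
z = (x - mu)./sigma;       % zero mean and unit standard deviation
fprintf('mean = %.3g, std = %.3g\n',mean(z,2),std(z,1,2))
```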

Train the network by measuring the normalized temperature-scaled cross-entropy loss between embeddings obtained from identical classes (positive pairs) and disparate classes (negative pairs) in each mini-batch. The custom loss function computes the cosine similarity between every pair of training examples in the mini-batch, obtaining an M-by-M similarity matrix, where M is the mini-batch size. The function computes the normalized temperature-scaled cross-entropy loss from the similarity matrix with the temperature parameter equal to 0.07. The function calculates the scalar loss as the mean of the per-pair losses in the mini-batch.
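
Reading the helper functions listed at the end of this example, the loss for one positive pair appears to take the following form. The notation is introduced here for illustration only: s_ij is the cosine similarity between embeddings i and j, P is the set of positive pairs in the mini-batch, N(i) is the set of examples whose label differs from that of anchor i, and tau is the temperature (0.07).

```latex
\ell_{i,p} = -\log
  \frac{\exp(s_{ip}/\tau)}
       {\exp(s_{ip}/\tau) + \sum_{n \in \mathcal{N}(i)} \exp(s_{in}/\tau)},
\qquad
\mathcal{L} = \frac{1}{|\mathcal{P}|} \sum_{(i,p) \in \mathcal{P}} \ell_{i,p}
```

Because cosine similarities lie in [-1, 1], dividing by a temperature of 0.07 stretches them to roughly [-14.3, 14.3] before the exponential, which sharpens the contrast between positive and negative pairs.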

Specify Training Options

The model parameters are updated based on the loss using an Adam optimizer.

Train the encoder for 150 epochs with a mini-batch size of 50, a learning rate of 0.001, and an L2-regularization rate of 0.01.

if trainEmbedder
    NumEpochs = 150;
    minibatchSize = 50;
    learnRate = 0.001;
    l2Regularization = 1e-2;    
end
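
To show how the L2 term and the Adam update interact, here is a sketch of one textbook Adam step on a single scalar weight. The values of beta1, beta2, and epsilon are commonly used defaults and are assumptions here, as are the weight and gradient values. The toolbox implementation may differ in detail.

```matlab
% One textbook Adam step on a single weight with an L2 penalty.
% beta1, beta2, and epsilon are assumed values, not taken from this example.
w = 0.5;  g = 0.2;                 % hypothetical weight and loss gradient
learnRate = 0.001;  l2Regularization = 1e-2;
beta1 = 0.9;  beta2 = 0.999;  epsilon = 1e-8;

g = g + l2Regularization*w;        % L2 term: gradient of (l2/2)*w^2
m = 0;  v = 0;  t = 1;             % moment estimates start at zero
m = beta1*m + (1-beta1)*g;         % first-moment (mean) estimate
v = beta2*v + (1-beta2)*g^2;       % second-moment estimate
mHat = m/(1-beta1^t);              % bias-corrected first moment
vHat = v/(1-beta2^t);              % bias-corrected second moment
wNew = w - learnRate*mHat/(sqrt(vHat)+epsilon);
fprintf('Updated weight: %.6f\n',wNew)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the weight moves by approximately the learning rate.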

Calculate the number of iterations per epoch and the total number of iterations to display training progress.

if trainEmbedder
    numObservations = numel(TrainDataPS);
    numIterationsPerEpoch = floor(numObservations./minibatchSize);
    numIterations = NumEpochs*numIterationsPerEpoch;
end
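
For concreteness, the same calculation with the numbers used in this example looks as follows, assuming the training set contains 400 signals (80% of the 500 recordings).

```matlab
numObservations = 400;    % assumed size of the training set
minibatchSize = 50;
NumEpochs = 150;
numIterationsPerEpoch = floor(numObservations/minibatchSize);
numIterations = NumEpochs*numIterationsPerEpoch;
fprintf('%d iterations per epoch, %d in total\n', ...
    numIterationsPerEpoch,numIterations)
```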

Create a minibatchqueue (Deep Learning Toolbox) object to manage data flow through the custom training loop and train the encoder.

if trainEmbedder
    numOutputs = 2;
    mbqTrain = minibatchqueue(transTrainDS,numOutputs,...
        'minibatchSize',minibatchSize,...
        'OutputAsDlarray',[1,1],...
        'minibatchFcn',@processMB,...
        'OutputCast',{'single','single'},...
        'minibatchFormat', {'CBT','B'});  
    epoch = 0;
    iteration = 0;
    monitor = trainingProgressMonitor(Metrics= "TrainingLoss");
    monitor.Info = ["LearningRate","Epoch","Iteration","ExecutionEnvironment"];
    monitor.XLabel = "Iteration";
    monitor.Status = "Configuring";
    monitor.Progress = 0;
    trainENV = string(mbqTrain.OutputEnvironment);
    if canUseGPU && trainENV(1) == "auto" || trainENV(1) == "gpu"
        updateInfo(monitor,ExecutionEnvironment="GPU");
    else
        updateInfo(monitor,ExecutionEnvironment="CPU");
    end
    % Initialize some training loop variables
    trailingAvg = [];
    trailingAvgSq = [];  
    % shuffle once
    shuffle(mbqTrain)
    while epoch < NumEpochs 
        epoch = epoch+1;
        reset(mbqTrain)        
        % Loop over mini-batches
        while hasdata(mbqTrain)
            iteration = iteration+1;
            % Get the next mini-batch of signals and numeric class labels
            [dlX,Y] = next(mbqTrain);
            % Evaluate the model gradients and contrastive loss
            [gradients, loss, state] = dlfeval(@modelGradcontrastiveLoss,TFnet,dlX,Y);            
            % Update the gradients with the L2-regularization rate
            idx = TFnet.Learnables.Parameter == "Weights";
            gradients(idx,:) = ...
                dlupdate(@(g,w) g + l2Regularization*w, gradients(idx,:), TFnet.Learnables(idx,:));
            % Update the network state
            TFnet.State = state;
            % Update the network parameters using an Adam optimizer
            [TFnet,trailingAvg,trailingAvgSq] = adamupdate(...
                TFnet,gradients,trailingAvg,trailingAvgSq,iteration,learnRate);

            % Display the training progress
            recordMetrics(monitor,iteration,TrainingLoss=loss);

            % Update learning rate, epoch, and iteration information values.
            updateInfo(monitor, ...
                LearningRate=learnRate, ...
                Epoch=string(epoch) + " of " + string(NumEpochs), ...
                Iteration=string(iteration) + " of " + string(numIterations));
            % Update progress percentage.
            monitor.Progress = 100*iteration/numIterations;            
        end        
    end    
end

Test Data Embeddings

Obtain the embeddings for the test data. If you set trainEmbedder to false, load the trained encoder and the embeddings that were obtained using the helperEmbedFeatures function.

if trainEmbedder
    finalEmbeddingsTableTrain = helperEmbedFeatures(TFnet,TrainDataPS,TrainLabelsPS);    
else
    load('TFnet.mat'); %#ok<*UNRCH>
    load('finalEmbeddingsTableTrain.mat');
    load('embeddingsTableTest.mat');
end

Use a support vector machine (SVM) classifier with a Gaussian kernel to classify the embeddings.

template = templateSVM(...
    'KernelFunction', 'gaussian', ...
    'PolynomialOrder', [], ...
    'KernelScale', 4, ...
    'BoxConstraint', 1, ...
    'Standardize', true);
classificationSVM = fitcecoc(...
    finalEmbeddingsTableTrain, ...
    "EEGClass", ...
    'Learners', template, ...
    'Coding', 'onevsone');

Determine the cross-validation accuracy of the SVM classifier on the training-set embeddings. Use five-fold cross-validation.

partitionedModel = crossval(classificationSVM, 'KFold', 5);
[validationPredictions, validationScores] = kfoldPredict(partitionedModel);
validationAccuracy =  ...
    (1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError'))*100
validationAccuracy = single

96

The cross-validation accuracy is excellent at approximately 96%. If you trained the encoder, obtain the embeddings for the held-out test set.

if trainEmbedder
    embeddingsTableTest = helperEmbedFeatures(TFnet,testDataPS,testLabelsPS);    
end

Show the final test performance of the trained encoder. The recall and precision performance for all three classes is excellent. The learned feature embeddings provide nearly 100% recall and precision for the normal (0), pre-seizure (1), and seizure (2) classes. Each embedding represents a reduction in the input size from 4097 samples to 256 features.

predLabelsFinal = predict(classificationSVM,embeddingsTableTest);
testAccuracyFinal = sum(predLabelsFinal == testLabelsPS)/numel(testLabelsPS)*100
testAccuracyFinal = 
99
hf = figure;
confusionchart(hf,testLabelsPS,predLabelsFinal,'RowSummary','row-normalized',...
    'ColumnSummary','column-normalized');
set(gca,'Title','Confusion Chart -- Trained Embeddings')

[Figure: confusion chart titled "Confusion Chart -- Trained Embeddings"]
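
The row-normalized and column-normalized summaries in the confusion chart correspond to per-class recall and precision. The following toolbox-free sketch computes both from a small confusion matrix. The label vectors are hypothetical and are not results from this example.

```matlab
% Hypothetical true and predicted class labels (0, 1, 2) for illustration only.
trueLabels = [0 0 0 1 1 1 2 2];
predLabels = [0 0 1 1 1 1 2 0];

% Confusion matrix: rows are true classes, columns are predicted classes.
C = accumarray([trueLabels'+1, predLabels'+1],1,[3 3]);

recall = diag(C)./sum(C,2);        % diagonal of the row-normalized matrix
precision = diag(C)./sum(C,1)';    % diagonal of the column-normalized matrix
accuracy = 100*sum(diag(C))/sum(C(:));
disp(C)
disp([recall precision])
```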

Note that all 256 embedding features are used in the SVM model, but the embeddings returned by the encoder are amenable to further reduction using techniques such as neighborhood component analysis, minimum redundancy maximum relevance (MRMR), or principal component analysis. See Introduction to Feature Selection (Statistics and Machine Learning Toolbox) for more details.
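
As one possible way to explore such a reduction, the sketch below applies principal component analysis through the singular value decomposition in base MATLAB. The matrix X is a small hypothetical stand-in for the N-by-256 embedding matrix, and the 95% variance threshold is an arbitrary choice made for illustration.

```matlab
% Toy stand-in for an N-by-D embedding matrix (hypothetical values).
X = [1 2 3; 2 4 6.1; 3 6 8.9; 4 8 12];
Xc = X - mean(X,1);                        % center each column
[~,S,V] = svd(Xc,'econ');
explained = diag(S).^2/sum(diag(S).^2);    % fraction of variance per component
k = find(cumsum(explained) >= 0.95,1);     % fewest components reaching 95%
Xreduced = Xc*V(:,1:k);                    % N-by-k component scores
fprintf('Kept %d of %d components\n',k,size(X,2))
```

The reduced matrix could then replace the full embedding table as input to the classifier, at the cost of refitting the classifier on the new features.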

Summary

In this example, a time-frequency convolutional network was used as the basis for learning feature embeddings using a deep metric model. Specifically, the normalized temperature-scaled cross-entropy loss with cosine similarities was used to obtain the embeddings. The embeddings were then used with an SVM with a Gaussian kernel to achieve near-perfect test performance. This deep metric network can be optimized in a number of ways that are not explored in this example. For example, the size of the embeddings can likely be reduced further without affecting performance while achieving further dimensionality reduction. Additionally, there are a large number of similarity measures (metrics), loss functions, regularizers, and reducers that are not explored in this example. Finally, the resulting embeddings are compatible with any machine learning algorithm. An SVM was used in this example, but you can explore the feature embeddings in the Classification Learner app and may find that another classification algorithm is more robust for your application.

References

[1] Andrzejak, Ralph G., Klaus Lehnertz, Florian Mormann, Christoph Rieke, Peter David, and Christian E. Elger. "Indications of Nonlinear Deterministic and Finite-Dimensional Structures in Time Series of Brain Electrical Activity: Dependence on Recording Region and Brain State." Physical Review E 64, no. 6 (2001). https://doi.org/10.1103/physreve.64.061907.

[2] Bhattacharyya, Abhijit, Ram Pachori, Abhay Upadhyay, and U. Acharya. "Tunable-Q Wavelet Transform Based Multiscale Entropy Measure for Automated Classification of Epileptic EEG Signals." Applied Sciences 7, no. 4 (2017): 385. https://doi.org/10.3390/app7040385.

[3] Chen, Ting, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. "A Simple Framework for Contrastive Learning of Visual Representations." (2020). https://arxiv.org/abs/2002.05709

[4] He, Kaiming, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. "Momentum Contrast for Unsupervised Visual Representation Learning." (2020). https://arxiv.org/abs/1911.05722

[6] Musgrave, Kevin. "PyTorch Metric Learning." https://kevinmusgrave.github.io/pytorch-metric-learning/

[7] Türk, Ömer, and Mehmet Siraç Özerdem. "Epilepsy Detection by Using Scalogram Based Convolutional Neural Network from EEG Signals." Brain Sciences 9, no. 5 (2019): 115. https://doi.org/10.3390/brainsci9050115.

[8] Van den Oord, Aaron, Yazhe Li, and Oriol Vinyals. "Representation Learning with Contrastive Predictive Coding." (2019). https://arxiv.org/abs/1807.03748

function [grads,loss,state] = modelGradcontrastiveLoss(net,X,T)
% This function is only for use in the "Time-Frequency Feature Embedding
% with Deep Metric Learning" example. It may change or be removed in a
% future release.

% Copyright 2022, The Mathworks, Inc.
[y,state] = net.forward(X);
loss = contrastiveLoss(y,T);
grads = dlgradient(loss,net.Learnables);
loss = double(gather(extractdata(loss)));
end

function [out,info] = helperReadData(x,info)
% This function is only for use in the "Time-Frequency Feature Embedding
% with Deep Metric Learning" example. It may change or be removed in a
% future release.

% Copyright 2022, The Mathworks, Inc.
mu = mean(x,2);
stdev = std(x,1,2);
z = (x-mu)./stdev;
out = {z,info.MemberName};
end

function [dlX,dlY] = processMB(Xcell,Ycell)
% This function is only for use in the "Time-Frequency Feature Embedding
% with Deep Metric Learning" example. It may change or be removed in a
% future release.

% Copyright 2022, The Mathworks, Inc.
Xcell = cellfun(@(x)reshape(x,1,1,[]),Xcell,'uni',false);
Ycell = cellfun(@(x)str2double(x),Ycell,'uni',false);
dlX = cat(2,Xcell{:});
dlY = cat(1,Ycell{:});
end

function testFeatureTable = helperEmbedFeatures(net,testdata,testlabels)
% This function is only for use in the "Time-Frequency Feature Embedding
% with Deep Metric Learning" example. It may change or be removed in a
% future release.

% Copyright 2022, The Mathworks, Inc.
testFeatures = zeros(length(testlabels),256,'single');
for ii = 1:length(testdata)
    yhat = predict(net,dlarray(reshape(testdata{ii},1,1,[]),'CBT'));
    yhat= extractdata(gather(yhat));
    testFeatures(ii,:) = yhat;
end
testFeatureTable = array2table(testFeatures);
testFeatureTable = addvars(testFeatureTable,testlabels,...
    'NewVariableNames',"EEGClass");
end

function loss = contrastiveLoss(features,targets)
% This function is only for use in the "Time-Frequency Feature
% Embedding with Deep Metric Learning" example. It may change or be removed
% in a future release.
%
% Replicates code in PyTorch Metric Learning 
% https://github.com/KevinMusgrave/pytorch-metric-learning.
% Python algorithms due to Kevin Musgrave

% Copyright 2022, The Mathworks, Inc. 
    loss = infoNCE(features,targets);
end

function loss = infoNCE(embed,labels)
    ref_embed = embed;
    [posR,posC,negR,negC] = convertToPairs(labels);
    dist = cosineSimilarity(embed,ref_embed);
    loss = pairBasedLoss(dist,posR,posC,negR,negC);
end

function [posR,posC,negR,negC] = convertToPairs(labels)
    Nr = length(labels);
    % The following builds a logical matrix that indicates whether
    % element (i,j) of the similarity matrix of features comes from
    % two examples of the same class. Each (i,j) from the same class
    % is 1, and each (i,j) from different classes is 0. The diagonal
    % is therefore all 1s.
    labels = stripdims(labels);
    matches = (labels == labels');
    % Logically negate the matches matrix to obtain differences.
    differences = ~matches;
    % Set the diagonal of the matches matrix to false so that an
    % example is never paired with itself, which would bias the
    % learning. The diagonal elements are then skipped when the
    % positive and negative indices are identified.
    matches(1:Nr+1:end) = false;
    [posR,posC,negR,negC] = getAllPairIndices(matches,differences);

end

function dist = cosineSimilarity(emb,ref_embed)
    emb = stripdims(emb);
    ref_embed = stripdims(ref_embed);
    normEMB = emb./sqrt(sum(emb.*emb,1));
    normREF = ref_embed./sqrt(sum(ref_embed.*ref_embed,1));
    dist = normEMB'*normREF;
end

function loss = pairBasedLoss(dist,posR,posC,negR,negC)
    if any([isempty(posR),isempty(posC),isempty(negR),isempty(negC)])
        loss = dlarray(zeros(1,1,'like',dist));
        return;
    end
    Temperature = 0.07;
    dtype = underlyingType(dist);
    idxPos = sub2ind(size(dist),posR,posC);
    pos_pair = dist(idxPos);
    pos_pair = reshape(pos_pair,[],1);
    idxNeg = sub2ind(size(dist),negR,negC);
    neg_pair = dist(idxNeg);
    neg_pair = reshape(neg_pair,[],1);
    pos_pair = pos_pair./Temperature;
    neg_pair = neg_pair./Temperature;
    n_per_p = negR' == posR;
    neg_pairs = neg_pair'.*n_per_p;
    neg_pairs(n_per_p==0) = -realmax(dtype);
    maxNeg = max(neg_pairs,[],2);
    maxPos = max(pos_pair,[],2);
    maxVal = max(maxPos,maxNeg);
    numerator = exp(pos_pair-maxVal);
    denominator = sum(exp(neg_pairs-maxVal),2)+numerator;
    logexp = log((numerator./denominator)+realmin(dtype));
    loss = mean(-logexp,'all');
end

function [posR,posC,negR,negC] = getAllPairIndices(matches,differences)
    % Here we just get the row and column indices of the anchor
    % positive and anchor negative elements.
    [posR, posC] = find(extractdata(matches));
    [negR,negC] = find(extractdata(differences));
end
