Use Signal Feature Extraction to Train PyTorch Fault Detection Model

Since R2026a

This example shows how to use signal feature extraction objects to identify faulty bearing signals in mechanical systems. The example uses signal feature extraction objects and the MATLAB® Python® interface to train a PyTorch® fault detection model.

Feature extraction objects compute multiple features efficiently by minimizing repeated domain transformations.

This example extends the Machine Learning and Deep Learning Classification Using Signal Feature Extraction Objects example by showing how to use signal feature extraction objects to train a PyTorch model. To learn how to accelerate feature extraction using a GPU or a parallel pool of CPU workers, see Accelerate Signal Feature Extraction and Classification Using a GPU and Accelerate Signal Feature Extraction and Classification Using a Parallel Pool of Workers.

Download Data

The data set in this example contains acceleration signals collected from rotating machines in a bearing test rig and from real-world machines such as oil pump bearings, intermediate-speed bearings, and planet bearings. There are 34 files in the data set. The signals in the files are sampled at a frequency of 48828 Hz. The filenames describe the signals they contain:

  • HealthySignal_*.mat: Healthy signals

  • InnerRaceFault_*.mat: Signals with inner race faults

  • OuterRaceFault_*.mat: Signals with outer race faults

Download the data files into your temporary directory, whose location is specified by the tempdir command in MATLAB. If you want to place the data files in a folder different from tempdir, change the directory name in the subsequent instructions.

dataURL = "https://www.mathworks.com/supportfiles/SPT/data/rollingBearingDataset.zip";
datasetFolder = fullfile(tempdir,"rollingBearingDataset");
zipFile = fullfile(tempdir,"rollingBearingDataset.zip");
if ~exist(datasetFolder,"dir")
    websave(zipFile,dataURL);
    unzip(zipFile,datasetFolder);
end

Create a signalDatastore object to access the data in the files and obtain the labels. Set the OutputDataType property to single for memory and computational efficiency, especially when using a GPU.

sds = signalDatastore(datasetFolder,OutputDataType="single");

The filenames in the data set include the labels. Get a list of labels from the filenames in the datastore using the filenames2labels function.

labels = filenames2labels(sds,ExtractBefore=pattern("Signal"|"Fault"));
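The ExtractBefore pattern keeps the text before "Signal" or "Fault" in each filename, so "InnerRaceFault_03.mat" yields the label "InnerRace". As an illustration only (the filenames2labels call above already does this in MATLAB), the same idea in Python with a hypothetical helper:

```python
import re

def label_from_filename(name):
    """Return the text that precedes 'Signal' or 'Fault' in a filename."""
    m = re.match(r"(.+?)(?:Signal|Fault)", name)
    return m.group(1) if m else None

print(label_from_filename("HealthySignal_01.mat"))   # Healthy
print(label_from_filename("InnerRaceFault_03.mat"))  # InnerRace
```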

Set Up Feature Extraction Objects

In this section, you set up feature extraction objects that extract multidomain features from the signals. You will then use these features to implement machine learning and deep learning solutions to classify signals as healthy, as having inner race faults, or as having outer race faults.

Use the signalTimeFeatureExtractor, signalFrequencyFeatureExtractor, and signalTimeFrequencyFeatureExtractor objects to extract features from all the signals.

  • For the time domain, use the root-mean-square (RMS) value, impulse factor, standard deviation, and clearance factor as features.

  • For the frequency domain, use the median frequency, band power, power bandwidth, and peak amplitude of the power spectral density (PSD) as features.

  • For the time-frequency domain, use the time-averaged wavelet spectrum as a feature.
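The time-domain features above have standard definitions: the RMS value, the impulse factor (peak absolute value over mean absolute value), the sample standard deviation, and the clearance factor (peak absolute value over the squared mean of the square-root of absolute values). A hedged NumPy sketch of these definitions, not the toolbox implementation:

```python
import numpy as np

def time_features(x):
    """Compute four standard time-domain features for one signal frame."""
    ax = np.abs(x)
    rms = np.sqrt(np.mean(x**2))                           # root-mean-square value
    impulse_factor = ax.max() / ax.mean()                  # peak over mean absolute value
    sd = np.std(x, ddof=1)                                 # sample standard deviation
    clearance_factor = ax.max() / np.mean(np.sqrt(ax))**2  # peak over squared mean root
    return np.array([rms, impulse_factor, sd, clearance_factor])

frame = np.array([1.0, -1.0, 1.0, -1.0])
print(time_features(frame))  # rms, impulse factor, std, clearance factor
```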

fs = 48828;

Create a time-domain feature extractor to extract time-domain features.

timeFE = signalTimeFeatureExtractor(SampleRate=fs, ...
    RMS=true, ...
    ImpulseFactor=true, ...
    StandardDeviation=true, ...
    ClearanceFactor=true);

Create a frequency-domain feature extractor to extract frequency-domain features.

freqFE = signalFrequencyFeatureExtractor(SampleRate=fs, ...
    MedianFrequency=true, ...
    BandPower=true, ...
    PowerBandwidth=true, ...
    PeakAmplitude=true);

Create a time-frequency feature extractor to extract time-frequency features from the scalogram.

timeFreqFE = signalTimeFrequencyFeatureExtractor(SampleRate=fs, ...
    TimeSpectrum=true);

setExtractorParameters(timeFreqFE,"scalogram", ...
    VoicesPerOctave=16,FrequencyLimits=[50 20000]);

Extract Multidomain Features

Each signal in the signalDatastore object sds has around 150,000 samples. Window each signal into 2000-sample frames and extract multidomain features from each frame using all three feature extractors. To window the signals, set the FrameSize property of all three feature extractors to 2000.

frameSize = 2000;
timeFE.FrameSize = frameSize;
freqFE.FrameSize = frameSize;
timeFreqFE.FrameSize = frameSize;
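Conceptually, setting FrameSize buffers each signal into non-overlapping 2000-sample frames before feature extraction; trailing samples that do not fill a frame are dropped. A minimal NumPy sketch of that windowing, shown for illustration only (the extractors handle this internally):

```python
import numpy as np

def frame_signal(x, frame_size=2000):
    """Split a 1-D signal into non-overlapping frames, dropping the remainder."""
    n_frames = len(x) // frame_size
    return x[:n_frames * frame_size].reshape(n_frames, frame_size)

signal = np.random.randn(150000).astype(np.float32)
print(frame_signal(signal).shape)  # (75, 2000)
```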

Features extracted from frames correspond to a sequence of features over time that have a lower dimension than the original signal. Reducing the dimension helps the LSTM network to train faster. The workflow follows these steps:

  1. Split the signal datastore and labels into training and test sets.

  2. For each signal in the training and test sets, use all three feature extractor objects to extract features for multiple signal frames. Concatenate the multidomain features to obtain the feature matrix.

  3. Normalize the training and testing feature matrices.

  4. Train the PyTorch LSTM model using the labels and feature matrices.

  5. Classify the signals using the trained PyTorch network.

Split the labels into training and test sets. Use 70% of the labels to train the network and the remaining 30% to test it. Use splitlabels to partition the labels, which guarantees that each partition has label proportions similar to those of the entire data set. Obtain the corresponding datastore subsets using the subset function. Reset the random number generator for reproducible results.

rng("default")

splitIndices = splitlabels(labels,0.7,"randomized");
trainIdx = splitIndices{1};
testIdx = splitIndices{2};
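splitlabels draws a randomized split within each class, so both partitions keep roughly the label proportions of the full data set. A stratified split sketched in NumPy, using toy class counts rather than the actual file counts:

```python
import numpy as np

def split_labels(labels, p=0.7, seed=0):
    """Stratified index split: roughly p of each class goes to the first set."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        k = int(round(p * len(idx)))        # per-class training count
        train_idx.extend(idx[:k])
        test_idx.extend(idx[k:])
    return np.sort(train_idx), np.sort(test_idx)

labels = np.array(["Healthy"] * 14 + ["InnerRace"] * 10 + ["OuterRace"] * 10)
tr, te = split_labels(labels)
print(len(tr), len(te))  # 24 10
```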

trainDs = subset(sds,trainIdx);
trainLabels = labels(trainIdx);
allCategories = categories(labels);  % Use the category order from the full label set
trainLabelsEncoded = onehotencode( ...
    categorical(trainLabels,allCategories), ...
    2,"single");

testDs = subset(sds,testIdx);
testLabels = labels(testIdx);
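onehotencode expands each categorical label into a row vector along dimension 2, using the category order from the full label set. An equivalent sketch in NumPy:

```python
import numpy as np

def one_hot(labels, categories):
    """One-hot encode labels as single-precision rows in category order."""
    index = {c: i for i, c in enumerate(categories)}
    rows = np.array([index[lab] for lab in labels])
    return np.eye(len(categories), dtype=np.float32)[rows]

cats = ["Healthy", "InnerRace", "OuterRace"]
print(one_hot(["InnerRace", "Healthy"], cats))
```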

Obtain features from the training datastore using all three feature extractors.

trainFeatures = cellfun(@(a,b,c) [a b c], ...
    extract(timeFE,trainDs), ...
    extract(freqFE,trainDs), ...
    extract(timeFreqFE,trainDs), ...
    UniformOutput=false);

Follow the same workflow to obtain test features.

testFeatures = cellfun(@(a,b,c) [a b c], ...
    extract(timeFE,testDs), ...
    extract(freqFE,testDs), ...
    extract(timeFreqFE,testDs), ...
    UniformOutput=false);

Normalize the training and test feature matrices using the helperGetNormalizedLSTMFeatureMatrices helper function, which applies the mean and standard deviation of the training features to both sets.

[trainFeaturesNorm,testFeaturesNorm] = helperGetNormalizedLSTMFeatureMatrices(trainFeatures,testFeatures);
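The helper (defined at the end of the example) z-scores every feature column with statistics computed from the training set only, so no test-set information leaks into normalization. A NumPy sketch of the same scheme:

```python
import numpy as np

def normalize_sequences(train_seqs, test_seqs):
    """Z-score all sequences using the mean/std of the stacked training frames."""
    stacked = np.vstack(train_seqs)
    mu = np.nanmean(stacked, axis=0)
    sigma = np.nanstd(stacked, axis=0, ddof=1)  # sample std, as in MATLAB's std
    sigma[sigma == 0] = 1                       # guard zero-variance features
    norm = lambda s: np.nan_to_num((s - mu) / sigma, nan=0.0, posinf=0.0, neginf=0.0)
    return [norm(s) for s in train_seqs], [norm(s) for s in test_seqs]

train = [np.array([[0.0, 5.0], [2.0, 5.0]])]
test = [np.array([[1.0, 5.0]])]
train_n, test_n = normalize_sequences(train, test)
print(train_n[0], test_n[0])
```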

Set Up Python Environment

To install a supported Python implementation, see Configure Your System to Use Python. To avoid library conflicts, use the External Languages side panel in MATLAB to create a Python virtual environment using the requirements_faultdetect.txt file. For details on the External Languages side panel, see Manage Python Environments Using External Languages Panel. For details on Python environment execution modes and on debugging Python from MATLAB, see Python Coexecution.

Use the helperCheckPyenv function to verify that the current PythonEnvironment object contains the libraries listed in the requirements_faultdetect.txt file. This example was tested using Python 3.11.2.

reqFileName = "requirements_faultdetect.txt";
currentPyenv = helperCheckPyenv(reqFileName,Verbose=true);
Checking Python environment
Parsing requirements_faultdetect.txt 
Checking required package 'torch'
Checking required package 'numpy'
Required Python libraries are installed.

You can use the following process ID and name to attach a debugger to the Python interface and debug the example code.

fprintf("Process ID for '%s' is %s.\n", ...
    currentPyenv.ProcessName,currentPyenv.ProcessID)
Process ID for 'MATLABPyHost' is 3596083.

The fault_detector.py helper module contains the neural network definition, training methods, and other functionality. If the Python helper modules change after you load them into memory, rerun the helperCheckPyenv function to reload them.

Train LSTM Network

Initialize an LSTM network with 147 features, 50 hidden units, and three classes. The faultDetector variable holds the PyTorch LSTM model. The model uses a GPU for training if a GPU is available and you have installed a GPU-enabled PyTorch build. Otherwise, the model trains on the CPU.

numFeatures = size(trainFeatures{1},2);
numHiddenUnits = 50;
numClasses = length(categories(trainLabels));

faultDetector = py.fault_detector.construct_model( ...
    numFeatures, ...
    numHiddenUnits, ...
    numClasses);
py.fault_detector.info(faultDetector);
Model architecture:
lstmNN(
  (lstm): LSTM(147, 50, batch_first=True)
  (linear): Linear(in_features=50, out_features=3, bias=True)
)

Total number of parameters: 39953
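The printed parameter count follows directly from the architecture: a PyTorch LSTM layer stores input weights, recurrent weights, and two bias vectors (b_ih and b_hh) for each of its four gates, and the linear layer adds its weight matrix and bias. A quick arithmetic check:

```python
# Parameter count for LSTM(147, 50) followed by Linear(50, 3)
n_in, n_hidden, n_classes = 147, 50, 3

# 4 gates x (input weights + recurrent weights + b_ih + b_hh)
lstm_params = 4 * (n_hidden * n_in + n_hidden * n_hidden + 2 * n_hidden)
linear_params = n_hidden * n_classes + n_classes  # weights + bias

print(lstm_params + linear_params)  # 39953
```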
device = py.fault_detector.select_device(verbose=true);
Selected device: GPU (NVIDIA RTX A5000, 23.56 GB)

The training data set contains 24 signals. Use the full data set in each training iteration.

trainTensors = py.fault_detector.to_tensor(trainFeaturesNorm', device);
trainTensors.shape
ans = 
  Python Size with no properties.

    torch.Size([24, 29, 147])

trainLabelsTensors = py.fault_detector.to_tensor(trainLabelsEncoded, device);
trainLabelsTensors.shape
ans = 
  Python Size with no properties.

    torch.Size([24, 3])

Set the training options using the fault_detector Python helper module. Use a trainingProgressMonitor (Deep Learning Toolbox) object to plot the training loss. Alternatively, set verbose to true to display the training progress information in the Command Window.

maxEpochs           = 80;
initialLearningRate = 1e-3;
verbose             = false;
verboseFrequency    = 10; % Iterations
progressPlot        = true;

if progressPlot
    monitor = trainingProgressMonitor(Metrics="Loss", ...
        Info=["Epoch","LearningRate","ExecutionEnvironment"], ...
        XLabel="Iteration");
    monitor.Status = "Running";
    if strcmp(string(device.type),"cuda")
        environment = "GPU";
    else
        environment = "CPU";
    end
end

trainerObject = py.fault_detector.trainer(faultDetector, ...
    initialLearningRate, ...
    device);

tStart = tic;

% Training loop
for iteration = 1:maxEpochs

    trainerObject.train_step(trainTensors,trainLabelsTensors);

    if verbose && mod(iteration,verboseFrequency)==0
        t = seconds(toc(tStart)); %#ok
        t.Format = "hh:mm:ss";
        loss = trainerObject.loss_vector{iteration};
        disp(string(t) + ...
            " - Iteration "+iteration+"/" + maxEpochs + ...
            " - " + "Cross entropy loss: "+string(loss))
    end

    % Update the training progress monitor
    if progressPlot
        recordMetrics(monitor,iteration, ...
            Loss=trainerObject.loss_vector{iteration});
        updateInfo(monitor, ...
            Epoch=iteration + " of " + maxEpochs, ...
            LearningRate=initialLearningRate, ...
            ExecutionEnvironment=environment);
        monitor.Progress = 100*iteration/maxEpochs;
        if monitor.Stop
            monitor.Status = "Stopped by user";
            break
        end
    end
end

if progressPlot
    monitor.Status = "Training complete";
end

Set saveModel to true to save the trained model weights. Use the save_dir argument to specify the directory in which to save the model.

saveModel = false;

if saveModel
    modelFileName = sprintf("fault_detector_iter%d",maxEpochs); %#ok
    py.fault_detector.save_model_weights(faultDetector, ...
        modelFileName, ...
        save_dir=pwd);
end

Use the trained network to classify the signals in the test data set and analyze the accuracy of the network.

testTensors = py.fault_detector.to_tensor(testFeaturesNorm',device);
predictionsTensors = trainerObject.predict(testTensors);
predictions = py.fault_detector.from_tensor(predictionsTensors);

predictedLabels = onehotdecode(single(predictions),allCategories,2);
cm = confusionchart(testLabels,predictedLabels, ...
    ColumnSummary="column-normalized",RowSummary="row-normalized");

Calculate the classifier accuracy.

accuracy = trace(cm.NormalizedValues)/sum(cm.NormalizedValues,"all");
fprintf("The classification accuracy on the test partition is %2.1f%%",accuracy*100)
The classification accuracy on the test partition is 100.0%
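The accuracy is the trace of the confusion matrix (correct predictions on the diagonal) divided by the total number of predictions, which is what the trace/sum expression above computes. The same calculation in NumPy:

```python
import numpy as np

def accuracy_from_confusion(cm):
    """Overall accuracy: diagonal (correct) counts over all counts."""
    cm = np.asarray(cm)
    return np.trace(cm) / cm.sum()

cm = np.array([[4, 0, 0],
               [0, 3, 0],
               [0, 0, 3]])
print(f"{accuracy_from_confusion(cm) * 100:.1f}%")  # 100.0%
```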

Summary

This example uses multidomain signal feature extraction together with a PyTorch LSTM deep learning network for motor bearing fault detection. To learn how to extract features and train models using a GPU, see Accelerate Signal Feature Extraction and Classification Using a GPU.

Helper Function

helperGetNormalizedLSTMFeatureMatrices – This function normalizes the training and test feature matrices using the mean and standard deviation of the training feature matrix.

function [trainFeaturesNorm,testFeaturesNorm] = helperGetNormalizedLSTMFeatureMatrices(trainFeatures,testFeatures)
%   This function is only intended to support examples in the Signal
%   Processing Toolbox. It may be changed or removed in a future release.

% Compute normalization parameters from training data
trainMatrix = cell2mat(trainFeatures);
featureMean = mean(trainMatrix,1,"omitnan");
featureStd = std(trainMatrix,0,1,"omitnan");

% Handle zero-variance features
zeroVarIdx = featureStd == 0;
featureStd(zeroVarIdx) = 1;  % Avoid division by zero

% Normalize training sequences
trainFeaturesNorm = cell(size(trainFeatures));
for i = 1:numel(trainFeatures)
    trainFeaturesNorm{i} = (trainFeatures{i} - featureMean)./featureStd;
    trainFeaturesNorm{i}(~isfinite(trainFeaturesNorm{i})) = 0;
end

% Normalize test sequences using training parameters
testFeaturesNorm = cell(size(testFeatures));
for i = 1:numel(testFeatures)
    testFeaturesNorm{i} = (testFeatures{i} - featureMean)./featureStd;
    testFeaturesNorm{i}(~isfinite(testFeaturesNorm{i})) = 0;
end

end
