Prune Image Classification Network Using Taylor Scores
This example shows how to reduce the size of a deep neural network using Taylor pruning. By using the taylorPrunableNetwork function to remove convolution layer filters, you can reduce the overall network size and increase the inference speed.
Network pruning is a powerful model compression tool that helps identify redundancies that can be removed with little impact on the final network output. Pruning is particularly useful in transfer learning, where the network is often overparameterized.
This example uses a residual network trained on the CIFAR-10 data set. For more information, see Train Residual Network for Image Classification.
Load Pretrained Network and Data
Download the CIFAR-10 data set [1] using the downloadCIFARData function, attached to this example as a supporting file. To access this file, open the example as a live script. The data set contains 60,000 images. Each image is 32-by-32 in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.
datadir = tempdir;
downloadCIFARData(datadir);
Downloading CIFAR-10 dataset (175 MB). This can take a while...done.
Load the trained network for pruning.
load("CIFARNetDlNetwork","trainedNet");
Load the CIFAR-10 training and test images as 4-D arrays. The training set contains 50,000 images and the test set contains 10,000 images. Convert the images to augmentedImageDatastore objects for training and validation.
[XTrain,TTrain,XTest,TTest] = loadCIFARData(datadir);
inputSize = trainedNet.Layers(1).InputSize;
augimdsTrain = augmentedImageDatastore(inputSize,XTrain,TTrain);
augimdsTest = augmentedImageDatastore(inputSize,XTest,TTest);
classes = categories(TTest);
Create a minibatchqueue object that processes and manages mini-batches of images during training. For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatchTraining (defined at the end of this example) to convert the labels to one-hot encoded variables.
- Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.
- Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
mbqTrain = minibatchqueue(augimdsTrain, ...
    MiniBatchSize = 256, ...
    MiniBatchFcn = @preprocessMiniBatchTraining, ...
    OutputAsDlarray = [1 1], ...
    OutputEnvironment = ["auto","auto"], ...
    PartialMiniBatch = "return", ...
    MiniBatchFormat = ["SSCB",""]);

mbqTest = minibatchqueue(augimdsTest, ...
    MiniBatchSize = 256, ...
    MiniBatchFcn = @preprocessMiniBatchTraining, ...
    OutputAsDlarray = [1 1], ...
    OutputEnvironment = ["auto","auto"], ...
    PartialMiniBatch = "return", ...
    MiniBatchFormat = ["SSCB",""]);
Calculate the accuracy of the trained network on the test data.
YTest = modelPredictions(trainedNet, mbqTest, classes);
accuracyOfTrainedNet = mean(YTest == TTest)*100
accuracyOfTrainedNet = 90.2400
Analyze the network. The Deep Network Analyzer displays the network architecture and the total number of learnable parameters in the network.
analyzeNetwork(trainedNet)
Prune Network
Prune the network using the taylorPrunableNetwork function. The network computes an importance score for each convolution filter in the network based on Taylor expansion [2][3]. Pruning is iterative: in each pass of the loop, until a stopping criterion is met, the function removes a small number of the least important convolution filters and updates the network architecture.
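To make the scoring idea concrete, the following is a minimal sketch of a first-order Taylor criterion in the spirit of [2], written with plain arrays. The activation and gradient data are random placeholders, and the exact computation that the prunable network performs internally may differ. The sketch only illustrates that a filter whose activation-gradient product is small receives a low score.

```matlab
% Toy activations and gradients for one layer, sized H-by-W-by-C-by-N.
% These arrays are hypothetical placeholders, not data from this example.
rng(0)
activations = randn(4,4,3,8);
gradients   = randn(4,4,3,8);

% Zero the activations of filter 2 so that it is clearly unimportant.
activations(:,:,2,:) = 0;

% First-order Taylor criterion: average activation.*gradient over the
% spatial dimensions, take the absolute value, then average over the batch.
perSample = abs(mean(activations .* gradients,[1 2]));   % 1-by-1-by-C-by-N
scores = squeeze(mean(perSample,4));                     % C-by-1

% The filter with the lowest score is the first pruning candidate.
[~,leastImportant] = min(scores);
disp(leastImportant)
```

In this toy setup, filter 2 has a score of exactly zero, so it is the first candidate for removal.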
Specify Pruning and Fine-Tuning Options
Set the pruning options.
- maxPruningIterations specifies the maximum number of iterations to use for the pruning process.
- maxToPrune specifies the maximum number of filters to prune in each iteration of the pruning cycle.
maxPruningIterations = 30;
maxToPrune = 8;
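As a quick sanity check, these two options together bound how many filters the loop can remove. The sketch below assumes every iteration runs and prunes the full maxToPrune filters, which is the worst case rather than a guaranteed outcome. The variable maxFiltersRemoved is introduced here for illustration only.

```matlab
maxPruningIterations = 30;
maxToPrune = 8;

% Upper bound on the total number of filters removed by the pruning loop.
maxFiltersRemoved = maxPruningIterations * maxToPrune;
disp(maxFiltersRemoved)
```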
Set the fine-tuning options.
learnRate = 1e-2;
momentum = 0.9;
miniBatchSize = 256;
numMinibatchUpdates = 50;
validationFrequency = 1;
Prune Network Using Custom Pruning Loop
Create a Taylor prunable network from the original network.
prunableNet = taylorPrunableNetwork(trainedNet);
maxPrunableFilters = prunableNet.NumPrunables;
Initialize the training progress plots.
figure("Position",[10,10,700,700])
tl = tiledlayout(3,1);

lossAx = nexttile;
lineLossFinetune = animatedline(Color=[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Fine-Tuning Iteration")
ylabel("Loss")
grid on
title("Mini-Batch Loss During Pruning")
xTickPos = [];

accuracyAx = nexttile;
lineAccuracyPruning = animatedline(Color=[0.098 0.325 0.85],LineWidth=2,Marker="o");
ylim([50 100])
xlabel("Pruning Iteration")
ylabel("Accuracy")
grid on
addpoints(lineAccuracyPruning,0,accuracyOfTrainedNet)
title("Validation Accuracy After Pruning")

numPrunablesAx = nexttile;
lineNumPrunables = animatedline(Color=[0.4660 0.6740 0.1880],LineWidth=2,Marker="^");
ylim([200 700])
xlabel("Pruning Iteration")
ylabel("Prunable Filters")
grid on
addpoints(lineNumPrunables,0,maxPrunableFilters)
title("Number of Prunable Convolution Filters After Pruning")
Prune the network by repeatedly fine-tuning the network and removing the low scoring filters.
In each pruning iteration, the following steps are performed:
1. Fine-tune the network and accumulate Taylor scores for convolution filters for numMinibatchUpdates mini-batch updates.
2. Prune the network using the updatePrunables function to remove up to maxToPrune convolution filters.
3. Compute the validation accuracy.
To fine-tune the network, loop over the mini-batches of the training data. For each mini-batch in the fine-tuning iteration, the following steps are performed:
1. Evaluate the pruning loss, gradients of the pruning activations, pruning activations, model gradients, and state using the dlfeval and modelLossPruning functions.
2. Update the network state.
3. Update the network parameters using the sgdmupdate function.
4. Update the Taylor scores of the prunable network using the updateScore function.
5. Display the training progress.
start = tic;
iteration = 0;

for pruningIteration = 1:maxPruningIterations

    % Shuffle data.
    shuffle(mbqTrain);

    % Reset the velocity parameter for the SGDM solver in every pruning
    % iteration.
    velocity = [];

    % Loop over mini-batches.
    fineTuningIteration = 0;
    while hasdata(mbqTrain)
        iteration = iteration + 1;
        fineTuningIteration = fineTuningIteration + 1;

        % Read mini-batch of data.
        [X, T] = next(mbqTrain);

        % Evaluate the loss, gradients of the pruning activations, pruning
        % activations, model gradients, and state using the dlfeval and
        % modelLossPruning functions. The output order follows the
        % modelLossPruning signature.
        [loss,pruningGradients,pruningActivations,netGradients,state] = ...
            dlfeval(@modelLossPruning, prunableNet, X, T);

        % Update the network state.
        prunableNet.State = state;

        % Update the network parameters using the SGDM optimizer.
        [prunableNet, velocity] = sgdmupdate(prunableNet, netGradients, velocity, learnRate, momentum);

        % Compute first-order Taylor scores and accumulate the score across
        % previous mini-batches of data.
        prunableNet = updateScore(prunableNet, pruningActivations, pruningGradients);

        % Display the training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        addpoints(lineLossFinetune, iteration, loss)
        title(tl,"Processing Pruning Iteration: " + pruningIteration + " of " + maxPruningIterations + ...
            ", Elapsed Time: " + string(D))

        % Synchronize the x-axis of the accuracy and numPrunables plots with the loss plot.
        xlim(accuracyAx,lossAx.XLim)
        xlim(numPrunablesAx,lossAx.XLim)
        drawnow

        % Stop the fine-tuning loop once numMinibatchUpdates updates have run.
        if (fineTuningIteration >= numMinibatchUpdates)
            break
        end
    end

    % Prune filters based on previously computed Taylor scores.
    prunableNet = updatePrunables(prunableNet, MaxToPrune = maxToPrune);

    % Show results on the validation data set in a subset of pruning iterations.
    isLastPruningIteration = pruningIteration == maxPruningIterations;
    if (mod(pruningIteration, validationFrequency) == 0 || isLastPruningIteration)
        accuracy = modelAccuracy(prunableNet, mbqTest, classes, augimdsTest.NumObservations);
        addpoints(lineAccuracyPruning, iteration, accuracy)
        addpoints(lineNumPrunables,iteration,prunableNet.NumPrunables)
    end

    % Set x-axis tick values at the end of each pruning iteration.
    xTickPos = [xTickPos, iteration]; %#ok<AGROW>
    xticks(lossAx,xTickPos)
    xticks(accuracyAx,[0,xTickPos])
    xticks(numPrunablesAx,[0,xTickPos])
    xticklabels(accuracyAx,["Unpruned",string(1:pruningIteration)])
    xticklabels(numPrunablesAx,["Unpruned",string(1:pruningIteration)])
    drawnow
end
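The parameter update in the loop above relies on sgdmupdate. As a rough illustration of what a stochastic gradient descent with momentum step does, the following sketch applies the standard SGDM rule to a plain numeric vector. The parameter and gradient values are hypothetical, and the sketch assumes the velocity starts at zero after a reset, mirroring the empty velocity passed at the start of each pruning iteration.

```matlab
learnRate = 1e-2;
momentum = 0.9;

w = [1; -2];          % hypothetical parameters
g = [0.5; -1];        % hypothetical gradient, held constant for simplicity
v = zeros(size(w));   % velocity after a reset

for step = 1:2
    % Blend the previous velocity with the new gradient step, then move.
    v = momentum*v - learnRate*g;
    w = w + v;
end
disp(w)
```

Because the velocity carries over between steps, the second step is larger than the first even though the gradient is unchanged. Resetting the velocity after each pruning step avoids carrying momentum across a change in network architecture.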
In contrast to typical training where the loss decreases with each iteration, pruning can increase the loss and reduce the validation accuracy because the network structure changes when convolution filters are pruned. To further improve the accuracy of the network, you can retrain the network.
Once pruning is complete, convert the taylorPrunableNetwork object back to a dlnetwork object for retraining.
prunedNet = dlnetwork(prunableNet);
Retrain Network After Pruning
Retrain the network after pruning using the trainnet function to recover any accuracy lost during pruning. You can also use a custom training loop to train the network. For more information, see Train Network Using Custom Training Loop.
Specify training options for stochastic gradient descent with momentum. Set the maximum number of retraining epochs to 4 and start the training with an initial learning rate of 0.01 that drops by a factor of 0.1 every 2 epochs.
options = trainingOptions("sgdm", ...
    MaxEpochs = 4, ...
    MiniBatchSize = 256, ...
    InitialLearnRate = 1e-2, ...
    LearnRateSchedule = "piecewise", ...
    LearnRateDropFactor = 0.1, ...
    LearnRateDropPeriod = 2, ...
    L2Regularization = 0.02, ...
    ValidationData = augimdsTest, ...
    ValidationFrequency = 200, ...
    Verbose = false, ...
    Shuffle = "every-epoch", ...
    Plots = "training-progress");
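To see what the piecewise schedule implies for these settings, the following sketch computes the learning rate for each epoch, assuming the rate is multiplied by the drop factor once every LearnRateDropPeriod epochs. The variable names are illustrative and are not part of the training options object.

```matlab
initialLearnRate = 1e-2;
dropFactor = 0.1;
dropPeriod = 2;
maxEpochs = 4;

% Learning rate in effect during each epoch under a piecewise schedule.
epochs = 1:maxEpochs;
learnRates = initialLearnRate * dropFactor.^floor((epochs-1)/dropPeriod);
disp(learnRates)
```

Under this assumption, the first two epochs run at 0.01 and the last two at 0.001.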
Train the network.
prunedDLNet = trainnet(augimdsTrain,prunedNet, "crossentropy",options);
Compare Original Network and Pruned Network
Determine the impact of pruning on each layer.
[originalNetFilters,layerNames] = numConvLayerFilters(trainedNet);
prunedNetFilters = numConvLayerFilters(prunedDLNet);
Visualize the number of filters in the original network and in the pruned network.
figure("Position",[10,10,900,900])
bar([originalNetFilters,prunedNetFilters])
xlabel("Layer")
ylabel("Number of Filters")
title("Number of Filters per Layer")
xticks(1:(numel(layerNames)))
xticklabels(layerNames)
xtickangle(90)
ax = gca;
ax.TickLabelInterpreter = "none";
legend("Original Network Filters","Pruned Network Filters","Location","southoutside")
The chart shows the layers where pruning removed less important filters.
Next, compare the accuracy of the original network and the pruned network.
YPredOriginal = modelPredictions(trainedNet,mbqTest,classes);
accuOriginal = mean(YPredOriginal == TTest)
accuOriginal = 0.9024
YPredPruned = modelPredictions(prunedDLNet,mbqTest,classes);
accuPruned = mean(YPredPruned == TTest)
accuPruned = 0.8783
Pruning can unequally affect the classification of different classes and introduce bias into the model, which might not be apparent from the accuracy value. To assess the impact of pruning at a class level, use a confusion matrix chart.
figure
confusionchart(TTest,YPredOriginal,Normalization="row-normalized");
title("Original Network")
figure
confusionchart(TTest,YPredPruned,Normalization="row-normalized");
title("Pruned Network")
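The diagonal of a row-normalized confusion chart is the per-class accuracy. If you want those numbers directly, a sketch along the following lines computes them from true and predicted labels. The labels here are a small hypothetical set, not CIFAR-10 results. In this example you would pass TTest and YPredPruned instead.

```matlab
% Hypothetical true and predicted labels for illustration only.
trueLabels = categorical(["cat";"cat";"dog";"dog";"dog";"ship"]);
predLabels = categorical(["cat";"dog";"dog";"dog";"ship";"ship"]);

classNames = categories(trueLabels);
perClassAccuracy = zeros(numel(classNames),1);

for k = 1:numel(classNames)
    % Select the observations whose true label is class k.
    isClass = trueLabels == classNames{k};
    % Fraction of those observations that were predicted correctly.
    perClassAccuracy(k) = mean(predLabels(isClass) == trueLabels(isClass));
end

disp(table(classNames,perClassAccuracy))
```

Comparing this vector before and after pruning shows whether the accuracy drop is spread evenly or concentrated in a few classes.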
Next, estimate the model parameters for the original network and the pruned network to understand the impact of pruning on the overall network learnables and size.
analyzeNetworkMetrics(trainedNet,prunedDLNet,accuOriginal,accuPruned)
ans=3×3 table
                         Network Learnables    Approx. Network Memory (MB)    Accuracy
                         __________________    ___________________________    ________
    Original Network         2.7169e+05                  1.0364                0.9024
    Pruned Network           1.1982e+05                 0.45709                0.8783
    Percentage Change           -55.897                 -55.897               -2.6707
This table compares the size and classification accuracy of the original network and the pruned network. A decrease in network memory and similar accuracy values indicate a good pruning operation. For an example showing how to further reduce the size of the network for deployment using quantization, see Quantize Residual Network Trained for Image Classification and Generate CUDA Code.
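The Percentage Change row is the relative change with respect to the original network. The following sketch recomputes it from the rounded values shown in the table, so the results agree with the table only up to rounding.

```matlab
% Rounded values copied from the table above.
originalLearnables = 2.7169e5;
prunedLearnables   = 1.1982e5;
originalAccuracy   = 0.9024;
prunedAccuracy     = 0.8783;

% Relative change with respect to the original network, in percent.
pctLearnables = 100*(prunedLearnables - originalLearnables)/originalLearnables;
pctAccuracy   = 100*(prunedAccuracy - originalAccuracy)/originalAccuracy;

fprintf("Learnables: %.2f%%, accuracy: %.2f%%\n",pctLearnables,pctAccuracy)
```

Note that the accuracy change is relative: a drop from 0.9024 to 0.8783 is about 2.4 percentage points, which is about 2.67% of the original accuracy.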
Helper Functions
Evaluate Model Accuracy
The modelAccuracy function takes as input the network (dlnetwork), minibatchqueue object, classes, and number of observations and returns the accuracy.
function accuracy = modelAccuracy(net, mbq, classes, numObservations)
% This function computes the model accuracy of a dlnetwork on the minibatchqueue 'mbq'.

totalCorrect = 0;

classes = int32(categorical(classes));

reset(mbq);
while hasdata(mbq)
    [dlX, Y] = next(mbq);

    dlYPred = extractdata(predict(net, dlX));

    YPred = onehotdecode(dlYPred,classes,1)';
    YReal = onehotdecode(Y,classes,1)';

    miniBatchCorrect = nnz(YPred == YReal);

    totalCorrect = totalCorrect + miniBatchCorrect;
end

accuracy = totalCorrect / numObservations * 100;
end
Model Gradients Function for Fine-Tuning and Pruning
The modelLossPruning function takes as input a deep.prune.TaylorPrunableNetwork object prunableNet and a mini-batch of input data X with corresponding labels T, and returns the loss, gradients of the loss with respect to the pruning activations, pruning activations, gradients of the loss with respect to the learnable parameters in prunableNet, and the network state. To compute the gradients automatically, use the dlgradient function.
function [loss,pruningGradient,pruningActivations,netGradients,state] = modelLossPruning(prunableNet, X, T)

[dlYPred,state,pruningActivations] = forward(prunableNet,X);

loss = crossentropy(dlYPred,T);
[pruningGradient,netGradients] = dlgradient(loss,pruningActivations,prunableNet.Learnables);

end
Model Predictions Function
The modelPredictions function takes as input a dlnetwork object net, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.
function Y = modelPredictions(net,mbq,classes)

Y = [];

% Start from the first mini-batch, even if the queue was read earlier.
reset(mbq);

% Loop over mini-batches.
while hasdata(mbq)
    X = next(mbq);

    % Make prediction.
    scores = predict(net,X);

    % Decode labels and append to output.
    labels = onehotdecode(scores,classes,1)';
    Y = [Y; labels];
end

reset(mbq);
end
Mini-Batch Preprocessing Function
The preprocessMiniBatchTraining function preprocesses a mini-batch of predictors and labels for loss computation during training.
function [X,T] = preprocessMiniBatchTraining(XCell,TCell)

% Concatenate.
X = cat(4,XCell{1:end});

% Extract label data from cell and concatenate.
T = cat(2,TCell{1:end});

% One-hot encode labels.
T = onehotencode(T,1);

end
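As a small illustration of the label handling in this function, the following sketch one-hot encodes a row of categorical labels along the first dimension and then decodes them again. The labels are hypothetical, and the onehotencode and onehotdecode functions require Deep Learning Toolbox.

```matlab
% Hypothetical 1-by-3 row of labels, laid out like T after concatenation.
T = categorical(["cat" "dog" "cat"]);

% Encode along dimension 1: one row per class, one column per observation.
encoded = onehotencode(T,1);
disp(encoded)

% Decode back to categorical labels using the same class list.
decoded = onehotdecode(encoded,categories(T),1);
disp(decoded)
```

Each column of the encoded matrix contains a single 1 in the row of its class, which is the layout that the crossentropy loss and the prediction helper functions in this example operate on.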
Evaluate Number of Filters in Convolution Layers
The numConvLayerFilters function returns the number of filters in each convolution layer.
function [nFilters, convNames] = numConvLayerFilters(net)

numLayers = numel(net.Layers);
convNames = [];
nFilters = [];

% Check for convolution layers and extract the number of filters.
for cnt = 1:numLayers
    if isa(net.Layers(cnt),"nnet.cnn.layer.Convolution2DLayer")
        sizeW = size(net.Layers(cnt).Weights);
        nFilters = [nFilters; sizeW(end)];
        convNames = [convNames; string(net.Layers(cnt).Name)];
    end
end

end
Evaluate Network Statistics of Original Network and Pruned Network
The analyzeNetworkMetrics function takes as input the original network, pruned network, accuracy of the original network, and accuracy of the pruned network and returns a table with three metrics: the network learnables, network memory, and the accuracy on the test data.
function [statistics] = analyzeNetworkMetrics(originalNet,prunedNet,accuracyOriginal,accuracyPruned)

originalNetMetrics = estimateNetworkMetrics(originalNet);
prunedNetMetrics = estimateNetworkMetrics(prunedNet);

% Accuracy of original network and pruned network
perChangeAccu = 100*(accuracyPruned - accuracyOriginal)/accuracyOriginal;
AccuracyForNetworks = [accuracyOriginal;accuracyPruned;perChangeAccu];

% Total learnables in both networks
originalNetLearnables = sum(originalNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
prunedNetLearnables = sum(prunedNetMetrics(1:end,"NumberOfLearnables").NumberOfLearnables);
percentageChangeLearnables = 100*(prunedNetLearnables - originalNetLearnables)/originalNetLearnables;
LearnablesForNetwork = [originalNetLearnables;prunedNetLearnables;percentageChangeLearnables];

% Approximate parameter memory
approxOriginalMemory = sum(originalNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
approxPrunedMemory = sum(prunedNetMetrics(1:end,"ParameterMemory (MB)").("ParameterMemory (MB)"));
percentageChangeMemory = 100*(approxPrunedMemory - approxOriginalMemory)/approxOriginalMemory;
NetworkMemory = [approxOriginalMemory;approxPrunedMemory;percentageChangeMemory];

% Create the summary table
statistics = table(LearnablesForNetwork,NetworkMemory,AccuracyForNetworks, ...
    'VariableNames',["Network Learnables","Approx. Network Memory (MB)","Accuracy"], ...
    'RowNames',{'Original Network','Pruned Network','Percentage Change'});

end
References
[1] Krizhevsky, Alex. 2009. "Learning Multiple Layers of Features from Tiny Images." https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf.
[2] Molchanov, Pavlo, Stephen Tyree, Tero Karras, Timo Aila, and Jan Kautz. “Pruning Convolutional Neural Networks for Resource Efficient Inference.” Preprint, submitted June 8, 2017. https://arxiv.org/abs/1611.06440.
[3] Molchanov, Pavlo, Arun Mallya, Stephen Tyree, Iuri Frosio, and Jan Kautz. “Importance Estimation for Neural Network Pruning.” In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 11256–64. Long Beach, CA, USA: IEEE, 2019. https://doi.org/10.1109/CVPR.2019.01152.