Convert Classification Network into Regression Network

This example shows how to convert a trained classification network into a regression network.

Many pretrained image classification networks have been trained on over a million images and can classify images into 1000 object categories, such as keyboard, coffee mug, pencil, and many animals. These networks have learned rich feature representations for a wide range of images. Such a network takes an image as input and outputs a label for the object in the image together with the probabilities for each of the object categories.

Transfer learning is commonly used in deep learning applications. You can take a pretrained network and use it as a starting point to learn a new task. This example shows how to take a pretrained classification network and retrain it for regression tasks.

The example loads a pretrained convolutional neural network for classification, replaces the classification layers, and retrains the network to predict the angles of rotated handwritten digits.

Load Pretrained Network

Load the pretrained network from the supporting file digitsClassificationConvolutionNet.mat. This file contains a classification network that classifies handwritten digits.

load digitsClassificationConvolutionNet
layers = net.Layers
layers = 
  13x1 Layer array with layers:

     1   'imageinput'    Image Input                  28x28x1 images
     2   'conv_1'        2-D Convolution              10 3x3x1 convolutions with stride [2  2] and padding [0  0  0  0]
     3   'batchnorm_1'   Batch Normalization          Batch normalization with 10 channels
     4   'relu_1'        ReLU                         ReLU
     5   'conv_2'        2-D Convolution              20 3x3x10 convolutions with stride [2  2] and padding [0  0  0  0]
     6   'batchnorm_2'   Batch Normalization          Batch normalization with 20 channels
     7   'relu_2'        ReLU                         ReLU
     8   'conv_3'        2-D Convolution              40 3x3x20 convolutions with stride [2  2] and padding [0  0  0  0]
     9   'batchnorm_3'   Batch Normalization          Batch normalization with 40 channels
    10   'relu_3'        ReLU                         ReLU
    11   'gap'           2-D Global Average Pooling   2-D global average pooling
    12   'fc'            Fully Connected              10 fully connected layer
    13   'softmax'       Softmax                      softmax

Load Data

The data set contains synthetic images of handwritten digits together with the corresponding angles (in degrees) by which each image is rotated.

Load the training and test images as 4-D arrays from the supporting files DigitsDataTrain.mat and DigitsDataTest.mat. The variables anglesTrain and anglesTest are the rotation angles in degrees. The training and test data sets each contain 5000 images.

load DigitsDataTrain
load DigitsDataTest

Display 20 random training images using imshow.

numTrainImages = numel(anglesTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Figure: 20 randomly selected training images displayed in a 4-by-5 grid.

Replace Final Layers

The convolutional layers of the network extract image features that the last learnable layer uses to classify the input image. The layer 'fc' contains the information on how to combine the features that the network extracts into class probabilities. To retrain a pretrained network for regression, replace this layer and the following softmax layer with a new layer adapted to the task.

Replace the final fully connected layer with a fully connected layer of size 1 (the number of responses).

numResponses = 1;
layer = fullyConnectedLayer(numResponses,Name="fc");

net = replaceLayer(net,"fc",layer)
net = 
  dlnetwork with properties:

         Layers: [13x1 nnet.cnn.layer.Layer]
    Connections: [12x2 table]
     Learnables: [14x3 table]
          State: [6x3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 0

  View summary with summary.

Remove the softmax layer.

net = removeLayers(net,"softmax");
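As a quick sanity check of this layer surgery, the following sketch applies the same two steps to a small stand-in network and confirms that the result produces one response per image. The stand-in architecture and the variable names standIn, X, Y, and outputSize are assumptions made for illustration; they are not part of the pretrained digits network.

```matlab
% Small stand-in network (hypothetical; not the pretrained digits network).
layers = [
    imageInputLayer([28 28 1],Normalization="none",Name="imageinput")
    convolution2dLayer(3,10,Stride=2,Name="conv_1")
    reluLayer(Name="relu_1")
    globalAveragePooling2dLayer(Name="gap")
    fullyConnectedLayer(10,Name="fc")
    softmaxLayer(Name="softmax")];
standIn = dlnetwork(layers);

% Same surgery as in the example: swap the 10-class head for a single
% response, then drop the softmax layer.
standIn = replaceLayer(standIn,"fc",fullyConnectedLayer(1,Name="fc"));
standIn = removeLayers(standIn,"softmax");

% Initialize the parameters of the new layer so that predict can run.
standIn = initialize(standIn);

% Forward a batch of 4 random images and inspect the output size.
X = dlarray(rand(28,28,1,4,"single"),"SSCB");
Y = predict(standIn,X);
outputSize = size(Y)
```

If the replacement worked as intended, the output has one channel and one column per input image.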

Adjust Layer Learning Rate Factors

The network is now ready to be retrained on the new data. Optionally, you can slow down the training of the weights of earlier layers in the network by increasing the learning rate of the new fully connected layer and reducing the global learning rate when you specify the training options.

Increase the learning rates of the fully connected layer parameters by a factor of 10 using the setLearnRateFactor function.

net = setLearnRateFactor(net,"fc","Weights",10);
net = setLearnRateFactor(net,"fc","Bias",10);
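To see what a factor of 10 means in practice, the sketch below reads the factors back with getLearnRateFactor and multiplies one of them by a global learning rate. The tiny stand-in network and the names standIn, globalLearnRate, and effectiveLearnRate are assumptions for illustration; the value 0.001 mirrors the InitialLearnRate value in the training options code of this example.

```matlab
% Tiny stand-in network (hypothetical) with a layer named "fc".
layers = [
    featureInputLayer(3,Name="in")
    fullyConnectedLayer(1,Name="fc")];
standIn = dlnetwork(layers);

% Apply the same factors as in the example.
standIn = setLearnRateFactor(standIn,"fc","Weights",10);
standIn = setLearnRateFactor(standIn,"fc","Bias",10);

% Read the factors back to confirm they were stored.
weightFactor = getLearnRateFactor(standIn,"fc","Weights");
biasFactor = getLearnRateFactor(standIn,"fc","Bias");

% The learning rate applied to a parameter is the global rate
% multiplied by that parameter's factor.
globalLearnRate = 0.001;
effectiveLearnRate = globalLearnRate*weightFactor
```

With these settings, the new layer trains at roughly ten times the rate of the pretrained layers, whose factors remain at their default of 1.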

Specify Training Options

Specify the training options. Choosing among the options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager app.

  • Specify a reduced initial learning rate of 0.001.

  • Display the training progress in a plot.

  • Disable the verbose output.

options = trainingOptions("sgdm",...
    InitialLearnRate=0.001, ...
    Plots="training-progress",...
    Verbose=false);

Train Neural Network

Train the neural network using the trainnet function. For regression, use mean squared error loss. By default, the trainnet function uses a GPU if one is available. Otherwise, the function uses the CPU. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). To specify the execution environment, use the ExecutionEnvironment training option.

net = trainnet(XTrain,anglesTrain,net,"mse",options);
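To train on the CPU even when a GPU is available, the options could be created as in the following sketch. The variable name cpuOptions is an assumption for illustration; the other settings repeat the ones used in this example.

```matlab
% Same options as in the example, with the execution environment
% set explicitly to the CPU.
cpuOptions = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    ExecutionEnvironment="cpu", ...
    Plots="training-progress", ...
    Verbose=false);

% Confirm the setting before passing the options to trainnet.
disp(cpuOptions.ExecutionEnvironment)
```

Passing cpuOptions to trainnet in place of options would then run the same training on the CPU.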

Test Network

Test the performance of the network by evaluating its predictions on the test data.

Use predict to predict the angles of rotation of the test images.

YTest = predict(net,XTest);

Visualize the predictions in a scatter plot. Plot the predicted values against the true values.

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],"r--")

Figure: Scatter plot of predicted values (x-axis) against true values (y-axis), with a reference line.
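The scatter plot can be complemented with numeric summaries of the prediction error. The sketch below computes the root mean squared error (RMSE) and the fraction of predictions within a chosen threshold. The vectors trueAngles and predAngles hold made-up stand-in values, and the 10-degree threshold is an arbitrary assumption; in this example, anglesTest and YTest would take their place.

```matlab
% Hypothetical stand-in values, in degrees. In the example these would be
% anglesTest and YTest (converted with extractdata if YTest is a dlarray).
trueAngles = [-30; -10; 0; 15; 40];
predAngles = [-27; -12; 3; 11; 52];

% Signed prediction error for each observation.
predictionError = trueAngles - predAngles;

% Root mean squared error, in degrees.
rmse = sqrt(mean(predictionError.^2));

% Fraction of predictions whose absolute error is below the threshold.
thr = 10;   % assumed threshold in degrees
fractionWithinThr = mean(abs(predictionError) < thr);

fprintf("RMSE: %.2f degrees\n",rmse)
fprintf("Within %d degrees: %.0f%%\n",thr,100*fractionWithinThr)
```

A lower RMSE and a higher fraction within the threshold both indicate predictions that lie closer to the diagonal reference line in the plot.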
