Compute deep learning network output for training


Some deep learning layers behave differently during training and inference (prediction). For example, during training, dropout layers randomly set input elements to zero to help prevent overfitting, but during inference, dropout layers do not change the input.

To compute network outputs for training, use the forward function. To compute network outputs for inference, use the predict function.
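As a minimal sketch of this difference, the following builds a tiny network whose only non-input layer is a dropout layer and runs the same data through both functions. The layer names, sizes, and data here are hypothetical and chosen only for illustration; `featureInputLayer` is assumed to be available (R2020b or later).

```matlab
% Hypothetical two-layer network: feature input followed by dropout.
layers = [
    featureInputLayer(10,'Name','in')
    dropoutLayer(0.5,'Name','drop')];
dlnet = dlnetwork(layerGraph(layers));

% Ten features, four observations, formatted as channel-by-batch.
dlX = dlarray(ones(10,4,'single'),'CB');

% Training mode: dropout zeroes random elements and rescales the rest.
dlYTrain = forward(dlnet,dlX);

% Inference mode: dropout passes the input through unchanged.
dlYInfer = predict(dlnet,dlX);
```

In this sketch, `dlYInfer` equals the input, while `dlYTrain` changes from call to call because the dropout mask is random.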


dlY = forward(dlnet,dlX) returns the network output dlY during training given the input data dlX.

dlY = forward(dlnet,dlX1,...,dlXM) returns the network output dlY during training given the M inputs dlX1, …, dlXM and the network dlnet that has M inputs and a single output.

[dlY1,...,dlYN] = forward(___) returns the N outputs dlY1, …, dlYN during training for networks that have N outputs using any of the previous syntaxes.

[dlY1,...,dlYK] = forward(___,'Outputs',layerNames) returns the outputs dlY1, …, dlYK during training for the specified layers using any of the previous syntaxes.

[___] = forward(___,'Acceleration',acceleration) also specifies the performance optimization to use during training, in addition to the input arguments in previous syntaxes.

[___,state] = forward(___) also returns the updated network state.
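The 'Outputs' syntax can be sketched as follows. The network, layer names, and data are hypothetical; the only point is that each requested layer name yields one output, in the order given.

```matlab
% Hypothetical small network with named layers.
layers = [
    featureInputLayer(4,'Name','in')
    fullyConnectedLayer(3,'Name','fc1')
    reluLayer('Name','relu')
    fullyConnectedLayer(2,'Name','fc2')];
dlnet = dlnetwork(layerGraph(layers));

% Four features, eight observations, formatted as channel-by-batch.
dlX = dlarray(rand(4,8,'single'),'CB');

% Request the activations of two specific layers during training.
[dlRelu,dlOut] = forward(dlnet,dlX,'Outputs',["relu" "fc2"]);
```

Here `dlRelu` holds the 3-by-8 activations of the layer named 'relu' and `dlOut` holds the 2-by-8 output of 'fc2'.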


This example shows how to train a network that classifies handwritten digits with a custom learning rate schedule.

If trainingOptions does not provide the options you need (for example, a custom learning rate schedule), then you can define your own custom training loop using automatic differentiation.

This example trains a network to classify handwritten digits with the time-based decay learning rate schedule: for each iteration, the solver uses the learning rate given by $\rho_t = \frac{\rho_0}{1 + k t}$, where $t$ is the iteration number, $\rho_0$ is the initial learning rate, and $k$ is the decay.
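To get a feel for the time-based decay schedule, the short sketch below evaluates the learning rate at a few iteration numbers. It uses the initial learning rate and decay values that this example specifies later (0.01 and 0.01); the chosen iteration numbers are arbitrary.

```matlab
initialLearnRate = 0.01;    % rho_0, the initial learning rate
decay = 0.01;               % k, the decay
iterations = [1 100 1000];  % t, arbitrary iteration numbers

% Time-based decay: rho_t = rho_0 / (1 + k*t), evaluated elementwise.
learnRate = initialLearnRate ./ (1 + decay*iterations);
disp(learnRate)
```

With these values, the learning rate has halved to 0.005 by iteration 100 and keeps shrinking as training continues.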

Load Training Data

Load the digits data as an image datastore using the imageDatastore function and specify the folder containing the image data.

dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Partition the data into training and validation sets. Set aside 10% of the data for validation using the splitEachLabel function.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomize');

The network used in this example requires input images of size 28-by-28-by-1. To automatically resize the training images, use an augmented image datastore. Specify additional augmentation operations to perform on the training images: randomly translate the images up to 5 pixels in the horizontal and vertical axes. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

inputSize = [28 28 1];
pixelRange = [-5 5];
imageAugmenter = imageDataAugmenter( ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);

To automatically resize the validation images without performing further data augmentation, use an augmented image datastore without specifying any additional preprocessing operations.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Determine the number of classes in the training data.

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

Define Network

Define the network for image classification.

layers = [
lgraph = layerGraph(layers);
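The individual layer definitions are not listed here. As a hedged sketch, one layer array consistent with the dlnetwork summary shown in this example (12 layers, 14 learnable parameters, 6 state values, input named 'input', output named 'softmax') is given below. The filter sizes, filter counts, padding, and intermediate layer names are assumptions, not values taken from this page.

```matlab
inputSize = [28 28 1];
numClasses = 10;   % assumed: ten digit classes

% Assumed architecture: three conv/batchnorm/ReLU blocks, then a classifier.
layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
```

The network ends in a softmax layer rather than a classification output layer because a dlnetwork object does not use output layers; the loss is computed in the custom training loop instead.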

Create a dlnetwork object from the layer graph.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}

Define Model Gradients Function

Create the function modelGradients, listed at the end of the example, that takes a dlnetwork object and a mini-batch of input data with corresponding labels, and returns the gradients of the loss with respect to the learnable parameters in the network, the network state, and the corresponding loss.

Specify Training Options

Train for ten epochs with a mini-batch size of 128.

numEpochs = 10;
miniBatchSize = 128;

Specify the options for SGDM optimization. Specify an initial learn rate of 0.01 with a decay of 0.01, and momentum 0.9.

initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

Train Model

Create a minibatchqueue object that processes and manages mini-batches of images during training. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to convert the labels to one-hot encoded variables.

  • Format the image data with the dimension labels 'SSCB' (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not add a format to the class labels.

  • Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Support by Release (Parallel Computing Toolbox).

mbq = minibatchqueue(augimdsTrain,'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatch,'MiniBatchFormat',{'SSCB',''});

Initialize the training progress plot.

lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
grid on

Initialize the velocity parameter for the SGDM solver.

velocity = [];

Train the network using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. For each mini-batch:

  • Evaluate the model gradients, state, and loss using the dlfeval and modelGradients functions and update the network state.

  • Determine the learning rate for the time-based decay learning rate schedule.

  • Update the network parameters using the sgdmupdate function.

  • Display the training progress.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq);
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        % Read mini-batch of data.
        [dlX, dlY] = next(mbq);
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        % Update the network parameters using the SGDM optimizer.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Test Model

Test the classification accuracy of the model by comparing the predictions on the validation set with the true labels.

After training, making predictions on new data does not require the labels. Create a minibatchqueue object containing only the predictors of the test data:

  • To ignore the labels for testing, set the number of outputs of the mini-batch queue to 1.

  • Specify the same mini-batch size used for training.

  • Preprocess the predictors using the preprocessMiniBatchPredictors function, listed at the end of the example.

  • For the single output of the datastore, specify the mini-batch format 'SSCB' (spatial, spatial, channel, batch).

numOutputs = 1;
mbqTest = minibatchqueue(augimdsValidation,numOutputs, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatchPredictors, ...
    'MiniBatchFormat','SSCB');

Loop over the mini-batches and classify the images using the modelPredictions function, listed at the end of the example.

predictions = modelPredictions(dlnet,mbqTest,classes);

Evaluate the classification accuracy.

YTest = imdsValidation.Labels;
accuracy = mean(predictions == YTest)
accuracy = 0.9530

Model Gradients Function

The modelGradients function takes a dlnetwork object dlnet and a mini-batch of input data dlX with corresponding labels Y, and returns the gradients of the loss with respect to the learnable parameters in dlnet, the network state, and the loss. To compute the gradients automatically, use the dlgradient function.

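function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

% Forward pass in training mode; also returns the updated network state.
[dlYPred,state] = forward(dlnet,dlX);

% Compute the cross-entropy loss and its gradients with respect to the learnables.
loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

% Convert the loss to a plain double for plotting.
loss = double(gather(extractdata(loss)));

end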


Model Predictions Function

The modelPredictions function takes a dlnetwork object dlnet, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.

function predictions = modelPredictions(dlnet,mbq,classes)

predictions = [];

while hasdata(mbq)
    dlXTest = next(mbq);
    dlYPred = predict(dlnet,dlXTest);
    YPred = onehotdecode(dlYPred,classes,1)';
    predictions = [predictions; YPred];
end

end


Mini-Batch Preprocessing Function

The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels using the following steps:

  1. Preprocess the images using the preprocessMiniBatchPredictors function.

  2. Extract the label data from the incoming cell array and concatenate into a categorical array along the second dimension.

  3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.

function [X,Y] = preprocessMiniBatch(XCell,YCell)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,1);

end
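The encoding step, and the matching decoding step used by modelPredictions, can be illustrated on a few labels. This is a small sketch with hypothetical labels and a hypothetical set of three classes; it is not part of the example's own code.

```matlab
% Hypothetical labels for three observations drawn from three classes.
labels = categorical(["3" "7" "3"],["3" "7" "9"]);   % 1-by-3 categorical

% Encode along the first dimension: one row per class, one column per observation.
Y = onehotencode(labels,1);                          % 3-by-3 numeric array

% Decode along the same dimension to recover the class labels.
decoded = onehotdecode(Y,categories(labels),1);      % 1-by-3 categorical
```

Because the labels are concatenated along the second dimension before encoding, the first dimension is free to hold the classes, which is the layout the network output uses.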


Mini-Batch Predictors Preprocessing Function

The preprocessMiniBatchPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension.

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

end
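The effect of concatenating along the fourth dimension can be checked on placeholder data. In this sketch, the cell array of three random 28-by-28 matrices is a hypothetical stand-in for a mini-batch of grayscale images.

```matlab
% Hypothetical mini-batch: three 28-by-28 grayscale images in a cell array.
XCell = {rand(28,28), rand(28,28), rand(28,28)};

% Concatenating along dimension 4 inserts a singleton third (channel) dimension.
X = cat(4,XCell{1:end});

disp(size(X))
```

The result is a 28-by-28-by-1-by-3 array, which matches the 'SSCB' (spatial, spatial, channel, batch) layout.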


Input Arguments

dlnet: Network for custom training loops, specified as a dlnetwork object.

dlX: Input data, specified as a formatted dlarray. For more information about dlarray formats, see the fmt input argument of dlarray.

layerNames: Layers to extract outputs from, specified as a string array or a cell array of character vectors containing the layer names.

  • If layerNames(i) corresponds to a layer with a single output, then layerNames(i) is the name of the layer.

  • If layerNames(i) corresponds to a layer with multiple outputs, then layerNames(i) is the layer name followed by the character "/" and the name of the layer output: 'layerName/outputName'.

acceleration: Performance optimization, specified as one of the following:

  • 'auto' — Automatically apply a number of optimizations suitable for the input network and hardware resources.

  • 'none' — Disable all acceleration.

The default option is 'auto'.

Using the 'auto' acceleration option can offer performance benefits, but at the expense of an increased initial run time. Subsequent calls with compatible parameters are faster. Use performance optimization when you plan to call the function multiple times using different input data with the same size and shape.

Output Arguments

dlY: Output data, returned as a formatted dlarray. For more information about dlarray formats, see the fmt input argument of dlarray.

state: Updated network state, returned as a table.

The network state is a table with three columns:

  • Layer – Layer name, specified as a string scalar.

  • Parameter – Parameter name, specified as a string scalar.

  • Value – Value of parameter, specified as a dlarray object.

The network state contains information remembered by the network between iterations, for example, the state of LSTM and batch normalization layers.

Update the state of a dlnetwork using the State property.
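As a minimal sketch of this workflow, the following runs one training-mode pass through a network that contains a batch normalization layer, inspects the returned state table, and assigns it back to the network. The network, layer names, and data are hypothetical, and batch normalization of feature input is assumed to be supported (R2020b or later).

```matlab
% Hypothetical network with one stateful layer (batch normalization).
layers = [
    featureInputLayer(4,'Name','in')
    batchNormalizationLayer('Name','bn')];
dlnet = dlnetwork(layerGraph(layers));

% Four features, sixteen observations, formatted as channel-by-batch.
dlX = dlarray(rand(4,16,'single'),'CB');

% The last output of forward is the updated state table.
[~,state] = forward(dlnet,dlX);
disp(state.Properties.VariableNames)

% Store the updated state in the network before the next iteration.
dlnet.State = state;
```

In a custom training loop, this assignment is typically done once per iteration, right after the forward pass, so that the running statistics carry over to the next mini-batch.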

Compatibility Considerations

Behavior changed in R2021a

Extended Capabilities

Introduced in R2019b