Make Predictions Using Model Function
This example shows how to make predictions using a model function by splitting data into mini-batches.
For large data sets, or when predicting on hardware with limited memory, make predictions by splitting the data into mini-batches. When you make predictions with a dlnetwork object, the minibatchpredict function automatically splits the input data into mini-batches. For model functions, you must split the data into mini-batches manually.
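For comparison, this is a minimal sketch of the automatic dlnetwork workflow, assuming a hypothetical dlnetwork object named net and an image datastore imds (neither is part of this example in this form):

% Hypothetical dlnetwork workflow for comparison. The minibatchpredict
% function splits the datastore into mini-batches automatically.
scores = minibatchpredict(net,imds,MiniBatchSize=128);
Y = scores2label(scores,classNames);

The rest of this example shows the manual equivalent for a model function.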
Create Model Function and Load Parameters
Load the model parameters from the MAT file digitsMIMO.mat. The MAT file contains the model parameters in a struct named parameters, the model state in a struct named state, and the class names in classNames.
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;
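Optionally, list the fields of the loaded structs to see which layers have learnable parameters and state:

% Optional: inspect which layers have parameters and state.
fieldnames(parameters)
fieldnames(state)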
The model function model, listed at the end of the example, defines the model given the model parameters and state.
Load Data for Prediction
Load the digits data for prediction.
unzip("DigitsData.zip"); dataFolder = "DigitsData"; imds = imageDatastore(dataFolder, ... IncludeSubfolders=true, ... LabelSource="foldernames"); numObservations = numel(imds.Files);
Make Predictions
Loop over the mini-batches of the test data and make predictions using a custom prediction loop.
Use minibatchqueue to process and manage the mini-batches of images. Specify a mini-batch size of 128. Set the ReadSize property of the image datastore to the mini-batch size.
For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to concatenate the data into a batch and normalize the images.
- Format the images with the dimensions "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single.
- Make predictions on a GPU if one is available. By default, the minibatchqueue object converts the output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;

mbq = minibatchqueue(imds, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
Loop over the mini-batches of data and make predictions using the model function. Use the scores2label function to determine the class labels. Store the predicted class labels.
doTraining = false;

Y1 = [];
Y2 = [];

% Loop over mini-batches.
while hasdata(mbq)
    % Read mini-batch of data.
    X = next(mbq);

    % Make predictions using the model function.
    [Y1Batch,Y2Batch] = model(parameters,X,doTraining,state);

    % Determine corresponding classes.
    Y1Batch = scores2label(Y1Batch,classNames);
    Y1 = [Y1 Y1Batch];

    Y2Batch = extractdata(Y2Batch);
    Y2 = [Y2 Y2Batch];
end
View some of the images with their predictions.
idx = randperm(numObservations,9);
figure
for i = 1:9
    subplot(3,3,i)
    I = imread(imds.Files{idx(i)});
    imshow(I)
    hold on

    sz = size(I,1);
    offset = sz/2;

    thetaPred = Y2(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],"r--")
    hold off

    label = string(Y1(idx(i)));
    title("Label: " + label)
end
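Optionally, you can estimate the classification accuracy by comparing the predicted labels with the labels stored in the datastore. This sketch assumes the predictions are in datastore order, which holds here because the minibatchqueue reads the datastore without shuffling:

% Optional check: compare predictions with the datastore labels.
% Assumes Y1 is in the same order as imds.Files (no shuffling).
TTest = imds.Labels;
accuracy = mean(Y1(:) == TTest(:))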
Model Function
The function model takes the model parameters parameters, the input data X, the flag doTraining, which specifies whether the model should return outputs for training or prediction, and the network state state. The network outputs the predictions for the labels, the predictions for the angles, and the updated network state.
function [Y1,Y2,state] = model(parameters,X,doTraining,state)

% Convolution
weights = parameters.conv1.Weights;
bias = parameters.conv1.Bias;
Y = dlconv(X,weights,bias,Padding="same");

% Batch normalization, ReLU
offset = parameters.batchnorm1.Offset;
scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
end

Y = relu(Y);

% Convolution, batch normalization (skip connection)
weights = parameters.convSkip.Weights;
bias = parameters.convSkip.Bias;
YSkip = dlconv(Y,weights,bias,Stride=2);

offset = parameters.batchnormSkip.Offset;
scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [YSkip,trainedMean,trainedVariance] = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    YSkip = batchnorm(YSkip,offset,scale,trainedMean,trainedVariance);
end

% Convolution
weights = parameters.conv2.Weights;
bias = parameters.conv2.Bias;
Y = dlconv(Y,weights,bias,Padding="same",Stride=2);

% Batch normalization, ReLU
offset = parameters.batchnorm2.Offset;
scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
end

Y = relu(Y);

% Convolution
weights = parameters.conv3.Weights;
bias = parameters.conv3.Bias;
Y = dlconv(Y,weights,bias,Padding="same");

% Batch normalization
offset = parameters.batchnorm3.Offset;
scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [Y,trainedMean,trainedVariance] = batchnorm(Y,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    Y = batchnorm(Y,offset,scale,trainedMean,trainedVariance);
end

% Addition, ReLU
Y = YSkip + Y;
Y = relu(Y);

% Fully connect, softmax (labels)
weights = parameters.fc1.Weights;
bias = parameters.fc1.Bias;
Y1 = fullyconnect(Y,weights,bias);
Y1 = softmax(Y1);

% Fully connect (angles)
weights = parameters.fc2.Weights;
bias = parameters.fc2.Bias;
Y2 = fullyconnect(Y,weights,bias);

end
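As an optional sanity check, you can call the model function directly on a single image. This is a minimal sketch; it assumes the images in DigitsData are grayscale, so one image forms an "SSCB" array with singleton channel and batch dimensions:

% Optional sketch: run the model function on one formatted image.
% Assumes grayscale input images, as in DigitsData.
I = single(imread(imds.Files{1}))/255;
X = dlarray(I,"SSCB");
[scores,angle] = model(parameters,X,false,state);
label = scores2label(scores,classNames)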
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses the data using the following steps:
- Extract the data from the incoming cell array and concatenate it into a numeric array. Concatenating over the fourth dimension adds a third dimension to each image, to be used as a singleton channel dimension.
- Normalize the pixel values between 0 and 1.
function X = preprocessMiniBatch(data)
    % Extract image data from cell array and concatenate.
    X = cat(4,data{:});

    % Normalize the images.
    X = X/255;
end
See Also
dlarray | dlgradient | dlfeval | sgdmupdate | dlconv | batchnorm | relu | fullyconnect | softmax | minibatchqueue | onehotdecode
Related Topics
- Define Custom Training Loops, Loss Functions, and Networks
- Define Model Loss Function for Custom Training Loop
- Train Network Using Model Function
- Update Batch Normalization Statistics Using Model Function
- Initialize Learnable Parameters for Model Function
- Specify Training Options in Custom Training Loop
- List of Functions with dlarray Support