How can we use the Nadam optimizer in place of sgdm when training deep learning networks?

5 views (last 30 days)
% Training configuration using the 'sgdm' solver. Validation uses
% Resized_Validation_Data (defined elsewhere), with ValidationFrequency set
% to 40. The training-progress plot is enabled and command-window output
% is turned off.
Training_Options = trainingOptions('sgdm', ...
    MiniBatchSize = 32, ...
    MaxEpochs = 50, ...
    InitialLearnRate = 1e-5, ...
    Shuffle = 'every-epoch', ...
    ValidationData = Resized_Validation_Data, ...
    ValidationFrequency = 40, ...
    ExecutionEnvironment = "gpu", ...
    Plots = 'training-progress', ...
    Verbose = false);

Answers (2)

Joss Knight
Joss Knight on 4 Apr 2023
You cannot do this using trainNetwork. You need to use a dlnetwork with a custom training loop so you can author your own update rule. Perhaps adam will work for you instead.

Amanjit Dulai
Amanjit Dulai on 25 Oct 2024
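You can train with Nadam by defining a custom training loop. The function dlupdate can be used to define custom update rules for training. The rules for Nadam are shown below, where $g_t$ is the gradient at iteration $t$, $\theta_t$ are the learnable parameters, $\alpha$ is the learning rate (learnRate), $\beta_1$ is the gradient decay factor (gradientDecay), $\beta_2$ is the squared gradient decay factor (squaredGradientDecay) and $\epsilon$ is a small constant (epsilon):
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$
$$n_t = \beta_2 n_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\hat{m}_t = \frac{\mu_{t+1}\, m_t}{1 - \prod_{i=1}^{t+1} \mu_i} + \frac{(1 - \mu_t)\, g_t}{1 - \prod_{i=1}^{t} \mu_i}$$
$$\hat{n}_t = \frac{n_t}{1 - \beta_2^{\,t}}$$
$$\theta_t = \theta_{t-1} - \frac{\alpha\, \hat{m}_t}{\sqrt{\hat{n}_t} + \epsilon}$$
where the momentum is given by:
$$\mu_t = \beta_1 \left(1 - \tfrac{1}{2} \cdot 0.96^{\,t\,\psi}\right)$$
with $\psi$ the momentum decay (momentumDecay).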
Below is an example of how to train a digit classification network using Nadam in a custom training loop:
% Train a digit classification network with the Nadam update rule in a
% custom training loop. Requires Deep Learning Toolbox. The local functions
% modelLoss and preprocessMiniBatch are defined at the end of this script.

% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);

% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
    imageInputLayer([28 28 1], Normalization="none")
    convolution2dLayer(5, 20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    ]);

% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;            % beta_1: decay of the gradient moving average
squaredGradientDecay = 0.99;    % beta_2: decay of the squared-gradient moving average
momentumDecay = 0.004;          % psi in the momentum schedule below
epsilon = 1e-08;                % keeps the update denominator away from zero

% Nadam state. The momentum schedule is
%   mu_t = gradientDecay*(1 - 0.5*0.96^(t*momentumDecay)).
% Only mu_t and the running product mu_1*...*mu_t are needed at iteration t,
% so they are kept as scalars instead of growing a vector and re-evaluating
% prod() for every learnable on every iteration.
currentMomentum = gradientDecay*(1 - 0.5*0.96^momentumDecay);   % mu_1
momentumProduct = currentMomentum;                               % mu_1*...*mu_t for t = 1
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);

% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize = miniBatchSize,...
    MiniBatchFcn = @preprocessMiniBatch,...
    MiniBatchFormat = {'SSCB',''});

% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);

% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
    Metrics = "Loss", ...
    Info = "Epoch", ...
    XLabel = "Iteration");

% Train the network
numObservationsTrain = numel(TTrain);
% Round up: by default minibatchqueue also returns the final partial
% mini-batch, so rounding down would undercount iterations and push
% monitor.Progress above 100 when the data size is not a multiple of
% miniBatchSize.
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);

        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;

        % Momentum schedule for this step: mu_{t+1} and mu_1*...*mu_{t+1}.
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        nextMomentumProduct = momentumProduct*nextMomentum;

        % Update the dlnetwork according to Nadam
        % Moving averages of the gradient and of the squared gradient.
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        % Bias-corrected Nesterov look-ahead of the first moment.
        velocityHat = dlupdate(@(v,g)(nextMomentum .* v) ./ (1 - nextMomentumProduct) + ...
            ((1 - currentMomentum) .* g) ./ (1 - momentumProduct), ...
            velocity, gradients);
        % Bias-corrected second moment.
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat );

        % Advance the momentum schedule to the next iteration.
        currentMomentum = nextMomentum;
        momentumProduct = nextMomentumProduct;

        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end

% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
% modelLoss  Loss, gradients and network state for one mini-batch.
%   Passes X through NET with forward (which also returns the updated
%   network state), scores the output against the targets T with
%   crossentropy, and differentiates that loss with respect to
%   net.Learnables. The training loop calls this through dlfeval.
[predictions, state] = forward(net, X);
loss = crossentropy(predictions, T);
gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
% preprocessMiniBatch  Collate per-observation cells into batch arrays.
%   X is the contents of XCell concatenated along dimension 4.
%   T is the contents of TCell concatenated along dimension 2 and then
%   one-hot encoded along dimension 1.
X = cat(4, XCell{:});
labels = cat(2, TCell{:});
T = onehotencode(labels, 1);
end
  1 Comment
Amanjit Dulai
Amanjit Dulai on 25 Oct 2024
Also, if you want to use weight decay only on the weights, you can modify the example as shown below:
% Train a digit classification network with the Nadam update rule and L2
% regularization of the weights, using a custom training loop. Requires
% Deep Learning Toolbox. The local functions modelLoss and
% preprocessMiniBatch are defined at the end of this script.

% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);

% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
    imageInputLayer([28 28 1], Normalization="none")
    convolution2dLayer(5, 20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    ]);

% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;            % beta_1: decay of the gradient moving average
squaredGradientDecay = 0.99;    % beta_2: decay of the squared-gradient moving average
momentumDecay = 0.004;          % psi in the momentum schedule below
epsilon = 1e-08;                % keeps the update denominator away from zero
l2RegularizationFactor = 0.0001;

% Nadam state. The momentum schedule is
%   mu_t = gradientDecay*(1 - 0.5*0.96^(t*momentumDecay)).
% Only mu_t and the running product mu_1*...*mu_t are needed at iteration t,
% so they are kept as scalars instead of growing a vector and re-evaluating
% prod() for every learnable on every iteration.
currentMomentum = gradientDecay*(1 - 0.5*0.96^momentumDecay);   % mu_1
momentumProduct = currentMomentum;                               % mu_1*...*mu_t for t = 1
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);

% Regularize the weights only. Selecting "everything except Bias" would
% also pick up the batch-normalization Offset and Scale parameters.
l2Indices = net.Learnables.Parameter == "Weights";

% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize = miniBatchSize,...
    MiniBatchFcn = @preprocessMiniBatch,...
    MiniBatchFormat = {'SSCB',''});

% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);

% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
    Metrics = "Loss", ...
    Info = "Epoch", ...
    XLabel = "Iteration");

% Train the network
numObservationsTrain = numel(TTrain);
% Round up: by default minibatchqueue also returns the final partial
% mini-batch, so rounding down would undercount iterations and push
% monitor.Progress above 100 when the data size is not a multiple of
% miniBatchSize.
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);

        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;

        % Apply weight regularization: add the L2 penalty gradient to the
        % gradients of the selected parameters before the Nadam update.
        gradients(l2Indices,:) = dlupdate( @(g,w)g + l2RegularizationFactor*w, ...
            gradients(l2Indices,:), net.Learnables(l2Indices,:) );

        % Momentum schedule for this step: mu_{t+1} and mu_1*...*mu_{t+1}.
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        nextMomentumProduct = momentumProduct*nextMomentum;

        % Update the dlnetwork according to Nadam
        % Moving averages of the gradient and of the squared gradient.
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        % Bias-corrected Nesterov look-ahead of the first moment.
        velocityHat = dlupdate(@(v,g)(nextMomentum .* v) ./ (1 - nextMomentumProduct) + ...
            ((1 - currentMomentum) .* g) ./ (1 - momentumProduct), ...
            velocity, gradients);
        % Bias-corrected second moment.
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat );

        % Advance the momentum schedule to the next iteration.
        currentMomentum = nextMomentum;
        momentumProduct = nextMomentumProduct;

        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end

% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
% modelLoss  Evaluate the cross-entropy loss and its gradients.
%   Returns the crossentropy loss between the output of forward(net,X) and
%   the targets T, the gradients of that loss with respect to
%   net.Learnables, and the network state returned by forward. The
%   training loop calls this through dlfeval.
[netOutput, state] = forward(net, X);
loss = crossentropy(netOutput, T);
gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
% preprocessMiniBatch  Build the batch arrays handed to the training loop.
%   X stacks every element of XCell along dimension 4.
%   T joins every element of TCell along dimension 2 and one-hot encodes
%   the result along dimension 1.
X = cat(4, XCell{:});
T = onehotencode(cat(2, TCell{:}), 1);
end
One thing to note is that with adaptive learning rules like Adam and Nadam, it is often more effective to apply weight decay directly to the weights instead of to the gradients. Applying this to Nadam gives the algorithm NadamW. Below is an example of how to use NadamW.
% Train a digit classification network with NadamW (Nadam plus decoupled
% weight decay) in a custom training loop. Requires Deep Learning Toolbox.
% The local functions modelLoss and preprocessMiniBatch are defined at the
% end of this script.

% Load the data
[XTrain, TTrain] = digitTrain4DArrayData;
dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);

% Define the architecture
numClasses = numel(categories(TTrain));
net = dlnetwork([
    imageInputLayer([28 28 1], Normalization="none")
    convolution2dLayer(5, 20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    ]);

% Set training options
numEpochs = 4;
miniBatchSize = 100;
learnRate = 0.001;
gradientDecay = 0.9;            % beta_1: decay of the gradient moving average
squaredGradientDecay = 0.99;    % beta_2: decay of the squared-gradient moving average
momentumDecay = 0.004;          % psi in the momentum schedule below
epsilon = 1e-08;                % keeps the update denominator away from zero
l2RegularizationFactor = 0.0001;

% Nadam state. The momentum schedule is
%   mu_t = gradientDecay*(1 - 0.5*0.96^(t*momentumDecay)).
% Only mu_t and the running product mu_1*...*mu_t are needed at iteration t,
% so they are kept as scalars instead of growing a vector and re-evaluating
% prod() for every learnable on every iteration.
currentMomentum = gradientDecay*(1 - 0.5*0.96^momentumDecay);   % mu_1
momentumProduct = currentMomentum;                               % mu_1*...*mu_t for t = 1
velocity = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);
squaredGradients = dlupdate(@(x)zeros(size(x),"like",x), net.Learnables);

% Decay the weights only. Selecting "everything except Bias" would also
% pick up the batch-normalization Offset and Scale parameters.
l2Indices = net.Learnables.Parameter == "Weights";

% Create mini-batch queue
mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize = miniBatchSize,...
    MiniBatchFcn = @preprocessMiniBatch,...
    MiniBatchFormat = {'SSCB',''});

% Use acceleration to speed up training
acceleratedFcn = dlaccelerate(@modelLoss);

% Initialize the training progress monitor
monitor = trainingProgressMonitor( ...
    Metrics = "Loss", ...
    Info = "Epoch", ...
    XLabel = "Iteration");

% Train the network
numObservationsTrain = numel(TTrain);
% Round up: by default minibatchqueue also returns the final partial
% mini-batch, so rounding down would undercount iterations and push
% monitor.Progress above 100 when the data size is not a multiple of
% miniBatchSize.
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
iteration = 1;
for epoch = 1:numEpochs
    % Shuffle data
    shuffle(mbq)
    while hasdata(mbq) && ~monitor.Stop
        % Read mini-batch of data.
        [XBatch, TBatch] = next(mbq);

        % Evaluate the model gradients, state, and loss.
        [loss, gradients, state] = dlfeval(acceleratedFcn, net, XBatch, TBatch);
        net.State = state;

        % Apply decoupled weight regularization (NadamW): shrink the
        % selected parameters directly instead of adding the penalty to
        % the gradients, so the decay is not rescaled by the adaptive step.
        net.Learnables(l2Indices,:) = dlupdate( @(w)w - learnRate*l2RegularizationFactor*w, ...
            net.Learnables(l2Indices,:) );

        % Momentum schedule for this step: mu_{t+1} and mu_1*...*mu_{t+1}.
        nextMomentum = gradientDecay*(1 - 0.5*0.96^((iteration + 1)*momentumDecay));
        nextMomentumProduct = momentumProduct*nextMomentum;

        % Update the dlnetwork according to Nadam
        % Moving averages of the gradient and of the squared gradient.
        velocity = dlupdate(@(v,g)gradientDecay.*v + (1 - gradientDecay).*g, velocity, gradients);
        squaredGradients = dlupdate(@(n,g)squaredGradientDecay.*n + (1 - squaredGradientDecay).*(g.^2), squaredGradients, gradients);
        % Bias-corrected Nesterov look-ahead of the first moment.
        velocityHat = dlupdate(@(v,g)(nextMomentum .* v) ./ (1 - nextMomentumProduct) + ...
            ((1 - currentMomentum) .* g) ./ (1 - momentumProduct), ...
            velocity, gradients);
        % Bias-corrected second moment.
        squaredGradientsHat = dlupdate(@(n)n ./ (1 - squaredGradientDecay.^iteration), squaredGradients);
        net.Learnables = dlupdate(@(w,v,n)w - (learnRate .* v) ./ (sqrt(n) + epsilon), ...
            net.Learnables, ...
            velocityHat, ...
            squaredGradientsHat );

        % Advance the momentum schedule to the next iteration.
        currentMomentum = nextMomentum;
        momentumProduct = nextMomentumProduct;

        % Update the training progress monitor.
        recordMetrics(monitor, iteration, Loss = loss);
        updateInfo(monitor, Epoch = epoch);
        monitor.Progress = 100 * iteration/numIterations;
        iteration = iteration + 1;
    end
end

% Calculate the test accuracy
[XTest, TTest] = digitTest4DArrayData;
accuracy = testnet(net, XTest, TTest,"accuracy");
%% Helpers
function [loss, gradients, state] = modelLoss(net, X, T)
% modelLoss  Objective evaluated by dlfeval inside the training loop.
%   Outputs, in order: the crossentropy loss of forward(net,X) against the
%   targets T, the gradients of that loss with respect to net.Learnables,
%   and the network state returned by forward.
[scores, state] = forward(net, X);
loss = crossentropy(scores, T);
gradients = dlgradient(loss, net.Learnables);
end
function [X,T] = preprocessMiniBatch(XCell,TCell)
% preprocessMiniBatch  Mini-batch collation function for minibatchqueue.
%   X holds the elements of XCell concatenated along dimension 4.
%   T holds the elements of TCell concatenated along dimension 2, then
%   one-hot encoded along dimension 1.
X = cat(4, XCell{:});
targets = cat(2, TCell{:});
T = onehotencode(targets, 1);
end

Sign in to comment.

Categories

Find more on Deep Learning Toolbox in Help Center and File Exchange

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!