How to use a vanilla SGD solver in training options?
When I use vanilla SGD instead of the Adam solver, the code throws an error: invalid solver name.
How can I use vanilla SGD instead of the Adam solver?
This is my code for the training options part:
options = trainingOptions('sgdm', ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');
Mira mosad
on 11 May 2023
Answers (1)
Meet
on 12 Sep 2024
Hi Mira,
There is no separate solver name for vanilla SGD in the “trainingOptions” function, which is why you get the "invalid solver name" error. However, you can define a custom SGD update function and a custom training loop that apply plain SGD exactly as you want.
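There is also a built-in route. The 'sgdm' solver updates parameters as theta(l+1) = theta(l) - alpha*gradE(theta(l)) + gamma*(theta(l) - theta(l-1)), where gamma is the 'Momentum' value (default 0.9). Setting 'Momentum' to 0 removes the momentum term, and setting 'L2Regularization' to 0 removes the default weight decay of 1e-4; with the default 'LearnRateSchedule' of 'none', the learning rate stays constant, so each step is a plain SGD step. A minimal sketch, assuming the rest of the question's settings stay as they are:

```matlab
% 'sgdm' with zero momentum and zero L2 regularization performs plain
% mini-batch SGD updates: parameters - learnRate*gradients.
options = trainingOptions('sgdm', ...
    'Momentum',0, ...            % no contribution from the previous step
    'L2Regularization',0, ...    % disable the default weight decay (1e-4)
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress');
```

This keeps the built-in training function and its progress plot, at the cost of less control over the update than a custom loop gives you.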
Below is code that defines a custom SGD update function and a custom training loop:
Custom SGD Function:
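function parameters = sgdStep(parameters,gradients,learnRate)
% Plain (vanilla) SGD: move each parameter against its gradient,
% scaled by a fixed learning rate. Save as sgdStep.m, or place it at the
% end of the training script as a local function.
parameters = parameters - learnRate .* gradients;
end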
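The update in sgdStep is the element-wise rule theta <- theta - learnRate*gradient. To check the arithmetic without any toolbox, here is the same rule applied to ordinary arrays; the values are made up for illustration, and sgdStepFcn is a hypothetical anonymous-function copy of sgdStep. In the training loop, dlupdate applies this rule once to each learnable parameter of the network.

```matlab
% Hypothetical anonymous-function copy of sgdStep, for a quick check.
sgdStepFcn = @(parameters,gradients,learnRate) parameters - learnRate .* gradients;

% Made-up values, only to show one SGD step.
parameters = [1 2 3];
gradients  = [0.5 -1 2];
learnRate  = 0.1;

% Each element moves by -learnRate times its own gradient.
updated = sgdStepFcn(parameters,gradients,learnRate);   % approximately [0.95 2.1 2.8]
disp(updated)
```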
Custom Training Loop:
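% The loop below assumes these already exist (they are not defined here):
%   net           - an initialized dlnetwork
%   mbq           - a minibatchqueue that returns mini-batches [X,T]
%   modelLoss     - a function returning [loss,gradients,state]
%   numIterations - total iterations, e.g. numEpochs*ceil(numObservations/miniBatchSize)
numEpochs = 20;     % same as 'MaxEpochs',20 in the question
learnRate = 1e-4;   % same as 'InitialLearnRate',1e-4 in the question
% Progress monitor, in place of 'Plots','training-progress'.
monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");
% The learning rate is fixed, so the update function can be created once.
updateFcn = @(parameters,gradients) sgdStep(parameters,gradients,learnRate);
epoch = 0;
iteration = 0;
% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;
    % Shuffle data.
    shuffle(mbq);
    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;
        % Read mini-batch of data.
        [X,T] = next(mbq);
        % Evaluate the loss, gradients, and state using dlfeval and the
        % modelLoss function, then update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;
        % Update the network parameters using plain SGD.
        net = dlupdate(updateFcn,net,gradients);
        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch);
        monitor.Progress = 100 * iteration/numIterations;
    end
end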
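The loop calls modelLoss, which is not defined above. Below is a minimal sketch for a classification network, assuming net ends in a softmax layer and T holds one-hot encoded targets; the cross-entropy loss is an assumption, so swap in whatever loss your task needs. It uses forward (rather than predict) so the updated network state is returned, and dlgradient to differentiate the loss with respect to net.Learnables.

```matlab
function [loss,gradients,state] = modelLoss(net,X,T)
% Forward pass; forward also returns the updated network state
% (for example, batch normalization statistics).
[Y,state] = forward(net,X);

% Cross-entropy loss (assumes softmax outputs and one-hot targets T).
loss = crossentropy(Y,T);

% Gradients of the loss with respect to the learnable parameters.
% dlgradient must be evaluated inside a call to dlfeval.
gradients = dlgradient(loss,net.Learnables);
end
```

Save it as modelLoss.m, or place it at the end of the training script as a local function, since the loop calls it through dlfeval(@modelLoss,net,X,T).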
You can refer to the resource below for more information: