How to classify with a DAG network loaded from a checkpoint
I want to use classify() with a DAG network loaded from a checkpoint.
I trained inceptionv3 by transfer learning for many epochs, and training succeeded. I set 'CheckpointPath', so I have a network saved at each epoch. I want to evaluate these networks, so I loaded one and called classify(). But an error occurred, and the message said "Use trainNetwork". How can I use classify() with a network loaded from a checkpoint?
3 Comments
Naoya
on 15 Oct 2018
Could you please provide the whole error message and the exact command that you executed?
Yoshinori Abe
on 15 Oct 2018
carlos arizmendi
on 23 Nov 2019
I now have the same problem with classify. How did you fix this bug? Thanks a lot.
Accepted Answer
More Answers (3)
Katja Mogalle
on 30 Apr 2021
@Gediminas Simkus had the right idea for the workaround. I can sketch this out a bit more.
Background information
To make predictions with the network after training, batch normalization requires a fixed mean and variance to normalize the data. By default, this fixed mean and variance are calculated from the training data at the very end of training, using the entire training data set. Checkpoint networks, however, are saved during training, before that final step is reached, so their mean and variance values are not set.
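To make the background concrete, here is a minimal MATLAB sketch of inference-time batch normalization for a single channel. The activations, scale, offset, and epsilon are hypothetical values, and this is plain arithmetic rather than the Deep Learning Toolbox implementation; it only illustrates that inference relies on a stored mean and variance, which a checkpoint saved mid-training does not have yet.

```matlab
% Illustrative batch normalization at inference time for one channel.
% All values are hypothetical; this is not the toolbox implementation.
x = [1 2 3 4 5];        % activations of one channel over the training set
gamma = 1;              % learned scale
beta = 0;               % learned offset
epsilon = 1e-5;         % small constant for numerical stability

% Fixed statistics, normally computed once at the very end of training.
trainedMean = mean(x);
trainedVar = var(x, 1); % population variance (divide by N)

% Normalize (here the same activations, for simplicity) using only the
% stored statistics, not statistics of the current mini-batch.
y = gamma * (x - trainedMean) ./ sqrt(trainedVar + epsilon) + beta;
disp(y)
```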
Two possible solutions
There are two things you can try in order to use checkpoint networks for inference:
- Since R2021a, running statistics can be enabled for batch normalization layers. The batch normalization statistics are then calculated during training rather than at the end of training, so the checkpoint networks can be used directly without further modification. To do this, set the 'BatchNormalizationStatistics' name-value argument of trainingOptions to 'moving' when training the network with checkpointing.
- Use trainNetwork with minimal training to convert the checkpoint network into a network with fixed batch normalization mean and variance that can be used for inference. The workaround is based on the process to Resume Training from Checkpoint Network but with some slight tweaks in order to modify the checkpointed network as little as possible.
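For the first option, a minimal sketch of the training options might look like this, assuming R2021a or later. The checkpoint folder is a placeholder, and the CheckpointPath folder must already exist, so the sketch creates it first. Any other options you normally use (MaxEpochs, learning rate, and so on) still need to be added.

```matlab
% Option 1 (R2021a or later): batch normalization statistics are updated
% during training, so every saved checkpoint can be used for inference.
checkpointDir = fullfile(tempdir, 'checkpoints');   % placeholder folder
if ~exist(checkpointDir, 'dir')
    mkdir(checkpointDir);
end
options = trainingOptions('sgdm', ...
    'CheckpointPath', checkpointDir, ...
    'BatchNormalizationStatistics', 'moving');

% After training, a checkpoint can then be classified directly, e.g.:
%   s = load(fullfile(checkpointDir, checkpointFile), 'net');
%   YPred = classify(s.net, XTest);
```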
Example steps for second workaround using trainNetwork (tested in R2020a and R2020b)
Load the checkpoint network into the workspace (replace this with your own file).
load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')
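Since the original goal was to evaluate the network saved at every epoch, it helps to process the checkpoint files in training order. The number after net_checkpoint__ in the file name is the iteration number (195 in the example above, per the Resume Training from Checkpoint Network example), but dir does not order files by that number: in a name-ordered listing, ..._1000_... comes before ..._195_... . A small sketch that sorts by the parsed iteration number, using hypothetical file names that follow the same pattern:

```matlab
% Hypothetical checkpoint file names. In practice you could get them with
%   files = dir(fullfile(checkpointDir, 'net_checkpoint__*.mat'));
%   names = {files.name};
names = {'net_checkpoint__1000__2018_07_13__12_30_02.mat', ...
         'net_checkpoint__195__2018_07_13__11_59_10.mat', ...
         'net_checkpoint__390__2018_07_13__12_05_41.mat'};

% Extract the iteration number that follows 'net_checkpoint__'.
tokens = regexp(names, '^net_checkpoint__(\d+)__', 'tokens', 'once');
iterations = cellfun(@(t) str2double(t{1}), tokens);

% Sort the file names by iteration number rather than alphabetically.
[iterations, order] = sort(iterations);
sortedNames = names(order);
disp(sortedNames')
```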
Specify the training options such that training is only run for one iteration, the input data statistics of the input layer are not recomputed, and the learnable parameters are only changed minimally.
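options = trainingOptions('sgdm', ...
'InitialLearnRate',eps, ... % tiny learning rate so the weights barely change
'ResetInputNormalization',false,... % keep the input layer's normalization statistics
'OutputFcn',@(~)true ); % an output function that returns true stops training early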
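Now "resume" training using the layers of the checkpoint network you loaded with the new training options. If the checkpoint network is a DAG network, such as the Inception-v3 network in the question, use layerGraph(net) as the argument instead of net.Layers.
net2 = trainNetwork(XTrain,YTrain,net.Layers,options); % series network
% net2 = trainNetwork(XTrain,YTrain,layerGraph(net),options); % DAG network, e.g. Inception-v3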
The returned network can be used for inference.
YPred = classify(net2,XTrain);
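To evaluate every checkpoint, which was the original goal, the same steps can be wrapped in a loop. This is an untested sketch: checkpointDir, XTrain/YTrain and XVal/YVal are placeholders for your own folder and data, and the isa check simply chooses between net.Layers and layerGraph(net) as described in the steps. Each call to trainNetwork recomputes the batch normalization statistics over the whole training set, so the loop can be slow for large data sets.

```matlab
% Sketch: evaluate every checkpoint network on held-out data.
% checkpointDir, XTrain/YTrain and XVal/YVal are placeholders for your own
% checkpoint folder and data; replace them before running.
files = dir(fullfile(checkpointDir, 'net_checkpoint__*.mat'));

% Same near-no-op training options as in the steps, with output turned off.
options = trainingOptions('sgdm', ...
    'InitialLearnRate', eps, ...
    'ResetInputNormalization', false, ...
    'OutputFcn', @(~)true, ...
    'Verbose', false);

accuracy = zeros(numel(files), 1);
for k = 1:numel(files)
    s = load(fullfile(files(k).folder, files(k).name), 'net');
    if isa(s.net, 'DAGNetwork')
        layers = layerGraph(s.net);   % DAG networks such as Inception-v3
    else
        layers = s.net.Layers;        % series networks
    end
    % Training is stopped by OutputFcn; the end-of-training step then sets
    % the fixed batch normalization mean and variance.
    net2 = trainNetwork(XTrain, YTrain, layers, options);
    YPred = classify(net2, XVal);
    accuracy(k) = mean(YPred == YVal);
    fprintf('%s: validation accuracy %.4f\n', files(k).name, accuracy(k));
end
```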
I hope this helps.
3 Comments
Andrea Daou
on 7 Oct 2021
Hello,
Regarding the 'ResetInputNormalization' training option for trainNetwork: in which MATLAB release was it introduced? I am using MATLAB R2019a, where trainingOptions does not have this option, and I am getting an error.
Thank you in advance!!
Katja Mogalle
on 19 Oct 2021
The 'ResetInputNormalization' option of trainingOptions was added in R2019b.
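If the same script has to run on several releases, a minimal sketch can branch on the MATLAB version with verLessThan. It uses the release-to-version mapping R2019b = MATLAB 9.7 and R2021a = MATLAB 9.10; the strategy labels are only illustrative strings. On R2019a neither option from the answer above is available.

```matlab
% Choose a checkpoint strategy that the running release supports.
% R2019b = MATLAB 9.7 (adds 'ResetInputNormalization'),
% R2021a = MATLAB 9.10 (adds 'BatchNormalizationStatistics').
if ~verLessThan('matlab', '9.10')
    strategy = 'moving statistics';     % option 1: classify checkpoints directly
elseif ~verLessThan('matlab', '9.7')
    strategy = 'minimal retraining';    % option 2: trainNetwork workaround
else
    strategy = 'not available';         % neither option exists (e.g. R2019a)
end
fprintf('Checkpoint strategy for MATLAB %s: %s\n', version, strategy);
```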
Nithin M
on 29 Oct 2021
Thank you for the detailed post.
I have a question: what is the impact on training time of setting BatchNormalizationStatistics to 'moving'? Will training take considerably longer, or will it not have much effect?
AnaMota
on 27 Apr 2021
Is there any solution to this? I am facing the same issue with MATLAB 2020...
Andrea Daou
on 8 Oct 2021
Hello,
I know an answer was accepted for this question but I have a response that might be useful.
If using a network from a checkpoint does not work in your MATLAB version, you can write an output function similar to the example on this documentation page: https://fr.mathworks.com/help/deeplearning/ug/customize-output-during-deep-learning-training.html
For example, instead of stopping based on validation accuracy, the function can stop based on validation loss:
function stop = stopIfValidationLossNotDecreasing(info,N,StartPoint)
stop = false;
% Keep track of the validation loss and the number of successive validations for which
% there has not been a decrease in the loss.
persistent ValLoss
persistent valLag
% Clear the variables when training starts.
if info.State == "start"
ValLoss = StartPoint; % Value chosen depending on the problem case; check first validation loss.
valLag = 0;
elseif ~isempty(info.ValidationLoss)
% Compare the current validation loss to the last validation loss; if
% the new validation loss is less than the validation loss that
% precedes it then reset valLag else increment valLag by 1. Now the new
% ValLoss to compare with is the last one reached.
if info.ValidationLoss < ValLoss
valLag = 0;
ValLoss = info.ValidationLoss;
else
valLag = valLag + 1;
ValLoss = info.ValidationLoss;
end
% If the validation lag is at least N, that is, the validation loss
% has not decreased for at least N validations in a row, then return true and
% stop training.
if valLag >= N
stop = true;
end
end
end
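Because the function takes two extra arguments, it has to be wrapped in an anonymous function when it is passed to trainingOptions, the same way the linked documentation example passes its stopping function. In this sketch, N = 5, StartPoint = Inf (so the first validation loss always counts as a decrease), ValidationFrequency = 50, and the validation data XVal/YVal are all assumptions to replace with your own values.

```matlab
% Hypothetical settings: stop after 5 validations without a decrease.
% StartPoint = Inf makes the first validation loss count as a decrease.
N = 5;
StartPoint = Inf;
options = trainingOptions('sgdm', ...
    'ValidationData', {XVal, YVal}, ...   % placeholder validation data
    'ValidationFrequency', 50, ...
    'OutputFcn', @(info)stopIfValidationLossNotDecreasing(info, N, StartPoint));
```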
1 Comment
Katja Mogalle
on 19 Oct 2021
Hi Andrea,
The training option "ValidationPatience" actually does exactly what you're showing in your code. To stop training when the loss on the validation set stops decreasing, simply specify validation data and a validation patience using the 'ValidationData' and the 'ValidationPatience' name-value pair arguments of trainingOptions, respectively. The validation patience is the number of times that the loss on the validation set can be larger than or equal to the previously smallest loss before network training stops.
See this doc page as a reference: https://uk.mathworks.com/help/deeplearning/ref/trainingoptions.html?s_tid=doc_ta#d123e136007
But perhaps I am not fully understanding what you are trying to achieve. In that case, perhaps you could provide some clarification?
Thanks
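One detail worth checking when comparing the two approaches: the function above compares each validation loss with the previous one, while the ValidationPatience description compares it with the smallest loss so far. A small simulation on hypothetical losses, using the same "stop when the count reaches N" convention for both rules (an assumption; the exact off-by-one of ValidationPatience is not modeled), shows where they can differ:

```matlab
% Hypothetical validation losses, one per validation check.
valLoss = [0.90 0.70 0.75 0.72 0.74 0.73];
N = 3;                      % allowed checks without a decrease

% Rule 1: compare with the previous loss (as in the function above).
lag = 0; ref = Inf; stopPrev = NaN;
for k = 1:numel(valLoss)
    if valLoss(k) < ref
        lag = 0;
    else
        lag = lag + 1;
    end
    ref = valLoss(k);       % reference always moves to the latest loss
    if lag >= N && isnan(stopPrev)
        stopPrev = k;
    end
end

% Rule 2: compare with the smallest loss so far (ValidationPatience wording).
lag = 0; best = Inf; stopBest = NaN;
for k = 1:numel(valLoss)
    if valLoss(k) < best
        lag = 0;
        best = valLoss(k);
    else
        lag = lag + 1;
    end
    if lag >= N && isnan(stopBest)
        stopBest = k;
    end
end
fprintf('previous-loss rule stops at check %g, smallest-loss rule at check %g\n', ...
    stopPrev, stopBest);
```

With these numbers the previous-loss rule never stops (NaN), because the loss alternates up and down, whereas the smallest-loss rule stops at the fifth check. For a loss that keeps rising after its minimum, both rules stop at the same check.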