trainnet difference between TrainingLoss and manually computed MSE loss
I train a dlnetwork using trainnet with a custom OutputFcn that, at a given frequency, loads the most recent checkpoint saved through the checkpoint options.
This is how the training options are defined:
```matlab
options = trainingOptions(algo, ...
    'MaxEpochs', epochs, ...
    'MiniBatchSize', 1024, ...
    'InitialLearnRate', learnrate, ...
    'Verbose', false, ...
    'CheckpointPath', checkdir, ...
    'CheckpointFrequencyUnit', 'iteration', ...
    'CheckpointFrequency', check_freq, ...
    OutputFcn=@(info)updatePlotAndStopTraining(info, lines, checkdir, check_freq, XTest, YTrain, XTrain));
```
This is the custom OutputFcn where I also manually calculate the mse:
```matlab
function stop = updatePlotAndStopTraining(info, lines, directory__, checkFreq, XTest, YTrain, XTrain)
global msee
iteration = info.Iteration;
trainingLoss = info.TrainingLoss;
if (~isempty(iteration)) && (mod(iteration,checkFreq)==0) && (iteration ~= 0)
    d = dir(fullfile(directory__, '*.mat'));
    dates = {d.date};
    files = {d.name};
    [~, idx] = sort(datenum(dates));
    latest_file_name = files{idx(end)};
    checknet = load(fullfile(directory__, latest_file_name));
    msee = (mse(predict(checknet.net, XTrain.'), YTrain));
end
if iteration<checkFreq
    if isvalid(lines.distanceToBase)
        addpoints(lines.trainingLossLine,iteration,1.0)
        addpoints(lines.mse,iteration,0.0)
    end
elseif ~isempty(trainingLoss)
    if isvalid(lines.distanceToBase)
        addpoints(lines.trainingLossLine,iteration,trainingLoss)
        addpoints(lines.mse,iteration,msee)
    end
end
stop = false;
end
```
This is the training call:
```matlab
[finalnet, info] = trainnet(XTrain.', YTrain.', resetNet, 'mse', options);
```
I would expect the MSE and the training loss to be very close, differing only because TrainingLoss is normalized, but they move in opposite directions: my manually computed MSE suggests the model is not converging, while TrainingLoss shows some convergence.
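The OutputFcn above picks the newest checkpoint by sorting the `date` text returned by `dir`, which is recorded only to the second, so checkpoints written within the same second cannot be ordered that way. A minimal sketch of an alternative selects the file with the highest iteration number instead. It assumes checkpoint names of the form `net_checkpoint__<iteration>__<timestamp>.mat` (the pattern used by trainNetwork checkpoints; check the actual names in your CheckpointPath), and the file names below are hypothetical.

```matlab
% Hypothetical checkpoint names, all written within the same second.
% Assumed pattern: net_checkpoint__<iteration>__<timestamp>.mat
files = {'net_checkpoint__200__2025_11_20__09_33_01.mat', ...
         'net_checkpoint__1000__2025_11_20__09_33_01.mat', ...
         'net_checkpoint__600__2025_11_20__09_33_01.mat'};

% Parse the iteration number out of each file name.
iters = zeros(1, numel(files));
for k = 1:numel(files)
    tok = regexp(files{k}, '^net_checkpoint__(\d+)__', 'tokens', 'once');
    iters(k) = str2double(tok{1});
end

% The highest iteration is the most recent checkpoint, whatever the timestamps say.
% (Sorting the names as text would rank '1000' before '200'.)
[latestIter, idx] = max(iters);
latestFile = files{idx};
fprintf('latest checkpoint: %s (iteration %d)\n', latestFile, latestIter);
```

Inside the OutputFcn, `files` would be `{d.name}` from the existing `dir` call, and the chosen file would be loaded with `load(fullfile(directory__, latestFile))`. A checkpoint still holds the network as of the iteration it was saved, so it can trail the iteration reported in `info`.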

Answers (1)
Leepakshi
on 20 Nov 2025 at 9:33
Hi,
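You computed the MSE with a checkpointed network, which may lag behind the current training state. In addition, predict(checknet.net, XTrain.') evaluates the entire training set, while TrainingLoss is computed on the current mini-batch and normalized. Finally, the predictions are compared with YTrain, whereas trainnet was given YTrain.', so predictions and targets are not in the same orientation. These orientation and timing mismatches cause misleading trends.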
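The orientation problem can be reproduced in plain MATLAB. trainnet receives `YTrain.'`, so `predict` returns one row per observation, while the manual check compares those predictions with the untransposed `YTrain`. The sketch below uses hypothetical toy data (5 observations, 1 response) and assumes the mismatched arrays are combined element-wise with implicit expansion, as MATLAB arithmetic does; it is an illustration, not a claim about how `mse` itself handles mismatched shapes.

```matlab
% Hypothetical toy data: 5 observations, 1 response.
N = 5;
YTrain = 1:N;          % 1-by-N targets, laid out like the question's YTrain
preds  = YTrain.';     % N-by-1 perfect predictions, laid out like predict(net, XTrain.')

% Mismatched layout: implicit expansion forms an N-by-N matrix of
% (every prediction - every target), so the error is nonzero even here.
errMismatched = mean((preds - YTrain).^2, 'all');

% Matched layout: transpose the targets so both arrays are N-by-1.
errMatched = mean((preds - YTrain.').^2, 'all');

fprintf('mismatched: %g, matched: %g\n', errMismatched, errMatched);
% prints: mismatched: 4, matched: 0
```

Averaged over all pairs, the mismatched value equals var(p) + var(y) + (mean(p) - mean(y))^2 with population variances (here 2 + 2 + 0 = 4). A network that starts with nearly constant predictions has var(p) close to 0, and as it learns to track the targets var(p) grows toward var(y), so this value can rise from about var(y) toward 2·var(y) even while the true error falls.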
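These approaches can be used to fix it:
- Use the current network (info.TrainedNetwork) instead of loading checkpoints, provided the info structure passed to your OutputFcn includes it in your release.
- Compute the MSE on the same mini-batch or validation set for consistency, with the targets in the same orientation as the predictions:
```matlab
preds = predict(info.TrainedNetwork, XTrain.');
msee = mse(preds, YTrain.');   % YTrain.' matches the layout passed to trainnet
```
This aligns your metric with training progress.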
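To keep the manual metric comparable from one check to the next, it can be evaluated on the same fixed observations every time, with a guard that fails loudly when predictions and targets differ in shape. A minimal sketch in plain MATLAB: `evalIdx` is a hypothetical name, and the predictions are stand-in numbers rather than output from `predict`.

```matlab
% Hypothetical fixed evaluation subset: the same observations at every check.
YTrain  = [0.5 1.0 1.5 2.0 2.5 3.0];   % 1-by-N targets, laid out as in the question
evalIdx = 1:4;                          % fixed subset of observation indices
targets = YTrain(:, evalIdx).';         % one row per observation, like the predictions

% Stand-in predictions with one row per observation. In the OutputFcn these
% would come from predict(net, XTrain(:, evalIdx).').
preds = [0.4; 1.1; 1.5; 2.2];

% Fail loudly instead of letting implicit expansion pair every prediction
% with every target.
assert(isequal(size(preds), size(targets)), ...
    'preds is %s but targets is %s', mat2str(size(preds)), mat2str(size(targets)));

% Mean over all elements; TrainingLoss may use a different normalization.
msee = mean((preds - targets).^2, 'all');
fprintf('MSE on fixed subset: %.4f\n', msee);
```

Using a fixed subset removes the variation that comes from evaluating different data at each check, so any remaining gap to TrainingLoss reflects the mini-batch sampling and the normalization rather than the evaluation set.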
Hope it helps!