trainnet: difference between TrainingLoss and manually computed MSE loss

I train a dlnetwork using trainnet with a custom OutputFcn that, at a given frequency, loads the latest network saved by the checkpoint options.
These are the training options:
options = trainingOptions(algo, ...
'MaxEpochs', epochs, ...
'MiniBatchSize', 1024, ...
'InitialLearnRate', learnrate,...
'Verbose',false,...
'CheckpointPath',checkdir,...
'CheckpointFrequencyUnit', 'iteration', ...
'CheckpointFrequency', check_freq,...
OutputFcn=@(info)updatePlotAndStopTraining(info,lines, checkdir, check_freq, XTest, YTrain, XTrain));
This is the custom OutputFcn, in which I also compute the MSE manually:
function stop = updatePlotAndStopTraining(info,lines, directory__, checkFreq, XTest, YTrain, XTrain)
global msee
iteration = info.Iteration;
trainingLoss = info.TrainingLoss;
if (~isempty(iteration)) && (mod(iteration,checkFreq)==0) && (iteration ~= 0)
d = dir(fullfile(directory__, '*.mat'));
dates = {d.date};
files = {d.name};
[~, idx] = sort(datenum(dates));
latest_file_name = files{idx(end)};
checknet = load(fullfile(directory__, latest_file_name));
msee = (mse(predict(checknet.net, XTrain.'), YTrain));
end
if iteration<checkFreq
if isvalid(lines.distanceToBase)
addpoints(lines.trainingLossLine,iteration,1.0)
addpoints(lines.mse,iteration,0.0)
end
elseif ~isempty(trainingLoss)
if isvalid(lines.distanceToBase)
addpoints(lines.trainingLossLine,iteration,trainingLoss)
addpoints(lines.mse,iteration,msee)
end
end
stop = false;
end
This is the training call:
[finalnet,info] = trainnet(XTrain.', YTrain.', resetNet,'mse', options);
I would expect the manually computed MSE and TrainingLoss to be very close, differing only because TrainingLoss is normalized. Instead they move in opposite directions: my manually computed MSE suggests the model is not converging, while TrainingLoss shows some convergence.

Answers (1)

Leepakshi on 20 Nov 2025 at 9:33
Hi,
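You computed the MSE with a checkpointed network, which may lag behind the current training state. Also, predict(checknet.net, XTrain.') evaluates the entire training set, whereas TrainingLoss is computed on the current mini-batch only and is normalized. Finally, the predictions come back in the orientation of the targets you passed to trainnet (YTrain.'), but you compare them with YTrain; unless the two arrays have the same size, the error is not computed between each prediction and its own target. These timing, scope, and orientation mismatches can produce misleading trends.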
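A minimal illustration of the orientation problem, using hypothetical values: five observations and one response, with the targets stored as a row (like YTrain) and simulated predictions stored as a column (assuming the network returns one observation per row, as with XTrain.'). With plain element-wise arithmetic, MATLAB's implicit expansion turns the mismatched subtraction into a 5-by-5 matrix instead of raising an error:

```matlab
% Hypothetical data: 5 observations, 1 response
YTrain = [1 2 3 4 5];            % targets as a row (responses-by-observations)
preds  = YTrain.' + 0.5;         % simulated predictions as a column (observations-by-responses)

% Mismatched orientation: implicit expansion builds a 5-by-5 matrix
% of preds(i) - YTrain(j) for every pair (i, j)
wrongMSE = mean((preds - YTrain).^2, 'all')    % 4.25

% Matching orientation: each prediction is compared with its own target
rightMSE = mean((preds - YTrain.').^2, 'all')  % 0.25
```

Every simulated prediction is off by exactly 0.5, so the true MSE is 0.25. In this toy case the mismatched version reports 4.25 because it is dominated by the spread of the targets, which does not shrink as the network improves.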
To make the two values comparable:
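  1. Use the current network instead of loading checkpoints, if your release passes it to the output function (for example as info.TrainedNetwork; check which fields the OutputFcn info structure provides in your release).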
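  2. Compute the MSE with predictions and targets in the same orientation, ideally on the same mini-batch or on a fixed validation set for consistency. For the full training set:
preds = predict(info.TrainedNetwork, XTrain.');
msee = mse(preds, YTrain.');   % YTrain.' matches the orientation of the targets passed to trainnet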
This aligns your metric with training progress.
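If the current network is not available inside the output function, the checkpoint approach from the question can be kept. Below is a sketch of a hypothetical helper, latestCheckpointMSE, that picks the newest checkpoint by the numeric datenum field returned by dir (rather than converting the date strings), loads the variable net as in the question, and computes the plain MSE with predictions and targets in the orientation used for trainnet. It assumes, as in the question, that XTrain.' and YTrain.' are the inputs and targets passed to trainnet. It still evaluates the full training set, so it will not match the per-mini-batch TrainingLoss exactly.

```matlab
function msee = latestCheckpointMSE(checkdir, XTrain, YTrain)
% Hypothetical helper: MSE of the newest checkpoint in checkdir on the
% full training set, using the same orientation as the call
% trainnet(XTrain.', YTrain.', ...).
d = dir(fullfile(checkdir, '*.mat'));
if isempty(d)
    msee = NaN;                           % no checkpoint written yet
    return
end
[~, idx] = max([d.datenum]);              % newest file by modification time
s = load(fullfile(checkdir, d(idx).name), 'net');
preds = predict(s.net, XTrain.');         % one observation per row, as in training
if isa(preds, 'dlarray')
    preds = extractdata(preds);           % work with a plain numeric array
end
targets = YTrain.';                       % same orientation as the trainnet targets
assert(isequal(size(preds), size(targets)), ...
    'Predictions and targets must have the same size.');
msee = mean((preds - targets).^2, 'all');
end
```

Inside updatePlotAndStopTraining, the block that loads the checkpoint could then be replaced by msee = latestCheckpointMSE(directory__, XTrain, YTrain);. The size check makes an orientation mismatch fail loudly instead of silently expanding.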
Hope it helps!

Release

R2025b
