Same RNN model generates different loss function values

Hello,
I am working on an autoencoder RNN model where I extract features from the encoder and feed them into a GNN model. The problem arises when I train with the MATLAB built-in function 'trainnet' using the MSE loss function: it produces very small loss values that decrease with each iteration. However, when I use the same loss function with 'dlfeval', it produces very large values that get stuck at a fixed value after a few iterations. My code is below.
Could someone please explain how the built-in 'trainnet' works and how I can implement the same training manually?
layers = [
    sequenceInputLayer(1,MinLength=4096)
    % 1(C) x minibatch(B) x 2048(T)
    modwtLayer('Level',6,'IncludeLowpass',false,'SelectedLevels',1:6,"Wavelet","sym2")
    flattenLayer
    convolution1dLayer(256,16,Padding="same",Stride=8)
    batchNormalizationLayer()
    tanhLayer
    maxPooling1dLayer(2,Padding="same")
    convolution1dLayer(64,16,Padding="same",Stride=4)
    batchNormalizationLayer
    tanhLayer
    maxPooling1dLayer(2,Padding="same")
    transposedConv1dLayer(64,16,Cropping="same",Stride=4)
    tanhLayer
    transposedConv1dLayer(256,16,Cropping="same",Stride=8)
    tanhLayer
    bilstmLayer(8)
    fullyConnectedLayer(8)
    dropoutLayer(0.2)
    fullyConnectedLayer(4)
    dropoutLayer(0.2)
    fullyConnectedLayer(1)];
% dataRNN is a dlarray of size 1(C) x 1458(B) x 4096(T)
[loss,gradients] = dlfeval(@modelLoss,net,dataRNN);

function [loss,gradients] = modelLoss(net,data)
    % Forward pass through the autoencoder
    Y = forward(net,data);
    % coder = minibatchpredict(net,data,Outputs='maxpool1d_2');
    % Reconstruction loss: compare the output with the input itself
    loss = mse(Y,data);
    gradients = dlgradient(loss,net.Learnables);
end
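For reference, the trainnet training being compared against is set up roughly as follows; the trainingOptions values below are placeholders rather than the exact settings used:
% Illustrative only: the actual options used with trainnet are not shown here.
% For an autoencoder, the input sequence also serves as the regression target.
options = trainingOptions("adam", ...
    MaxEpochs=30, ...
    MiniBatchSize=32, ...
    Shuffle="every-epoch", ...
    Plots="training-progress");
net = trainnet(dataRNN, dataRNN, layers, "mse", options);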

Answers (1)

Ayush on 3 Sep 2024
I understand that you are experiencing discrepancies between MATLAB's built-in "trainnet" function and a manual implementation using "dlfeval". I’d like to clarify the key reasons behind these differences.
The "trainnet" function in MATLAB is a high-level utility designed to streamline the training process by managing several critical aspects, including:
1. Data Shuffling: Automatically shuffles data at the start of each epoch.
idx = randperm(size(dataRNN, 2));
2. Mini-Batch Processing: Divides data into mini-batches and processes them sequentially.
miniBatchSize = 32; % Example size
numObservations = size(dataRNN, 2);
numIterationsPerEpoch = floor(numObservations / miniBatchSize);
3. Learning Rate Scheduling: Adjusts the learning rate as training progresses.
initialLearnRate = 0.01;
learnRate = initialLearnRate; % Update this over epochs
4. Gradient Clipping: Prevents gradient explosion by clipping gradients to a specified threshold.
maxGradient = 1;
gradients = dlupdate(@(g) min(max(g, -maxGradient), maxGradient), gradients);
5. Optimization Algorithms: Utilizes optimizers like Adam or RMSProp with tuned hyperparameters. Note that "adamupdate" also needs the Adam moment estimates and an iteration counter (see the setup sketch before the training loop below).
[net, averageGrad, averageSqGrad] = adamupdate(net, gradients, averageGrad, averageSqGrad, iteration, learnRate);
For a manual implementation, your training loop needs to handle these aspects explicitly. Here is a refined approach:
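Before the loop itself, the network and the optimizer state need to be set up once; a minimal sketch (the numeric values are illustrative, not tuned for this model):
% One-time setup before the custom training loop (illustrative values)
net = dlnetwork(layers);      % convert the layer array into a trainable dlnetwork
numEpochs = 30;
miniBatchSize = 32;
numObservations = size(dataRNN, 2);
numIterationsPerEpoch = floor(numObservations / miniBatchSize);
initialLearnRate = 0.01;
learnRate = initialLearnRate;
averageGrad = [];             % Adam first-moment estimate
averageSqGrad = [];           % Adam second-moment estimate
iteration = 0;                % global step counter required by adamupdate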
for epoch = 1:numEpochs
    % Shuffle observations along the batch (B) dimension every epoch
    shuffleIdx = randperm(size(dataRNN, 2));
    dataRNN = dataRNN(:, shuffleIdx, :);
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize + 1 : i*miniBatchSize;
        miniBatchData = dataRNN(:, idx, :);
        % Evaluate the loss and gradients on the mini-batch
        [loss, gradients] = dlfeval(@modelLoss, net, miniBatchData);
        % Update the model parameters with Adam
        [net, averageGrad, averageSqGrad] = adamupdate(net, gradients, ...
            averageGrad, averageSqGrad, iteration, learnRate);
        % Optionally, apply a learning rate schedule
        % (a hypothetical updateLearningRate helper is sketched after this loop)
        % learnRate = updateLearningRate(epoch, initialLearnRate);
    end
end
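The updateLearningRate call commented out above is not a built-in function; if you want behaviour similar to the "piecewise" LearnRateSchedule of trainingOptions, a hypothetical helper could look like this (drop period and factor are illustrative):
% Hypothetical step-decay schedule (not a MATLAB built-in): reduces the
% learning rate every dropPeriod epochs, mimicking a piecewise schedule.
function learnRate = updateLearningRate(epoch, initialLearnRate)
    dropPeriod = 10;   % decay every 10 epochs (illustrative)
    dropFactor = 0.5;  % halve the rate at each decay step (illustrative)
    learnRate = initialLearnRate * dropFactor^floor((epoch - 1)/dropPeriod);
end
The element-wise clipping shown in point 4 would be applied to the gradients just before the adamupdate call inside the inner loop.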
For more information, you can refer to the documentation for the "trainnet" function: https://www.mathworks.com/help/deeplearning/ref/trainnet.html
If you want more information on the "dlfeval" function, you can refer to its documentation page as well.
Hope it helps!
