The adamupdate function in MATLAB R2024b incorrectly uses uint32 with sqrt and exhibits state corruption, causing errors even in minimal test cases.

% Test adamupdate function
clc;
clear;
% Define test parameters
learnable = dlarray(randn(5, 1)); % Example learnable parameter
gradient = dlarray(randn(5, 1)); % Example gradient
state = []; % Initial state (empty)
optimizer = trainingOptions('adam', 'InitialLearnRate', 0.01); % Example optimizer
timeStep = uint32(1); % Initial time step
try
% Perform a single adamupdate
updatedLearnable = adamupdate(learnable, gradient, state, optimizer, timeStep);
% Display results
disp('adamupdate test successful!');
disp('Updated Learnable:');
disp(updatedLearnable);
catch ME
% Display error message
disp('adamupdate test failed!');
disp(['Error: ', ME.message]);
% Display the stack trace to help MathWorks track the problem.
disp('Stack Trace:');
disp(ME.stack);
end
adamupdate test failed!
Error: Undefined function 'sqrt' for input arguments of type 'uint32'.
Stack Trace:
9x1 struct array with fields: file name line
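The first error is consistent with how MATLAB treats integer types: arithmetic that mixes a double scalar with a uint32 value returns a uint32 result, and sqrt has no method for integer inputs. The snippet below reproduces this in isolation. The expression 1 - 0.999^timeStep is an assumed stand-in for an Adam bias-correction term, not code taken from adamupdate. Passing the iteration number as a double avoids the problem.

```matlab
% Illustration only: 1 - 0.999^timeStep is an assumed stand-in for an
% Adam bias-correction term, not code taken from adamupdate.
timeStep = uint32(1);
biasTerm = 1 - 0.999^timeStep;   % uint32: 0.999^1 is rounded to 1, so this is 0
disp(class(biasTerm))            % uint32
try
    sqrt(biasTerm);              % fails: sqrt has no method for integer types
catch ME
    disp(ME.message)
end
% With a double iteration counter the same expression works
biasTerm = 1 - 0.999^double(timeStep);
disp(sqrt(biasTerm))             % about 0.0316
```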
% Perform a second adam update to test state persistence.
timeStep = uint32(2);
try
% Perform a second adamupdate
updatedLearnable2 = adamupdate(learnable, gradient, state, optimizer, double(timeStep));
% Display results
disp('adamupdate second test successful!');
disp('Updated Learnable:');
disp(updatedLearnable2);
catch ME
% Display error message
disp('adamupdate second test failed!');
disp(['Error: ', ME.message]);
% Display the stack trace to help MathWorks track the problem.
disp('Stack Trace:');
disp(ME.stack);
end
adamupdate second test failed!
Error: dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
Stack Trace:
12x1 struct array with fields: file name line
Chika on 18 Mar 2025
Error message:
adamupdate test failed!
Error: Undefined function 'sqrt' for input arguments of type 'uint32'.
Stack Trace:
7×1 struct array with fields:
file
name
line
adamupdate second test failed!
Error: dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
Stack Trace:
10×1 struct array with fields:
file
name
line


Accepted Answer

Joss Knight on 22 Mar 2025
Well, I admit the error messages aren't very helpful, but the basic problem is that passing a trainingOptions object as an argument to adamupdate is not supported. See the documentation for the correct syntax.
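Concretely, the hyperparameters the question tried to supply through trainingOptions('adam', 'InitialLearnRate', 0.01) are passed to adamupdate as trailing positional arguments (learnRate, gradDecay, sqGradDecay, epsilon), and the two moving averages go in their own argument slots. A minimal sketch of that call follows; the decay rates and epsilon are written out at values assumed to match the documented defaults, so they could also be omitted. It is worth confirming them against the adamupdate reference page for your release.

```matlab
% Sketch of the documented adamupdate call; decay rates and epsilon are
% assumed to be the defaults and could be omitted.
learnable = dlarray(randn(5, 1));
grad      = dlarray(randn(5, 1));
averageGrad   = [];      % empty moving averages before the first update
averageSqGrad = [];
iteration   = 1;         % a double, not uint32
learnRate   = 0.01;      % takes the place of 'InitialLearnRate', 0.01
gradDecay   = 0.9;
sqGradDecay = 0.999;
epsilon     = 1e-8;
[updated, averageGrad, averageSqGrad] = adamupdate(learnable, grad, ...
    averageGrad, averageSqGrad, iteration, ...
    learnRate, gradDecay, sqGradDecay, epsilon);
disp(updated)
```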
Chika on 22 Mar 2025
I am extremely grateful to Joss Knight for pointing out the error and for advising me to look at the documentation for the adamupdate function.


More Answers (1)

Chika on 22 Mar 2025
% Corrected code following the documentation, as advised by Joss Knight
% Test adamupdate function (Built-in)
clc;
clear;
% Define test parameters
learnable = dlarray(randn(5, 1)); % Example learnable parameter
gradient = dlarray(randn(5, 1)); % Example gradient
averageGrad = zeros(size(learnable)); % Initialize average gradient
averageSqGrad = zeros(size(learnable)); % Initialize average squared gradient
iteration = 1; % Initial iteration
try
% Perform a single adamupdate
[updatedLearnable, averageGrad, averageSqGrad] = adamupdate(learnable, gradient, averageGrad, averageSqGrad, iteration);
% Display results
disp('adamupdate test successful!');
disp('Updated Learnable:');
disp(updatedLearnable);
disp('Average Gradient:');
disp(averageGrad);
disp('Average Squared Gradient:');
disp(averageSqGrad);
catch ME
% Display error message
disp('adamupdate test failed!');
disp(['Error: ', ME.message]);
%Display the stack trace.
disp('Stack Trace:');
disp(ME.stack);
end
adamupdate test successful!
Updated Learnable:
5x1 dlarray -0.7648 -1.0165 -0.0125 -0.5996 0.4997
Average Gradient:
5x1 dlarray 0.1693 -0.0385 0.0958 -0.0383 0.0295
Average Squared Gradient:
5x1 dlarray 0.0029 0.0001 0.0009 0.0001 0.0001
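The state values above can be checked by hand against the standard Adam recurrences, which the adamupdate documentation describes. This worked check assumes the default decay rates gradDecay = \(\beta_1 = 0.9\) and sqGradDecay = \(\beta_2 = 0.999\) (the corrected call does not override them), zero initial state as set by the zeros(...) lines, and the default learning rate \(\alpha = 0.001\):

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad m_0 = 0,\\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2, \qquad v_0 = 0,\\
\hat m_t &= \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^{\,t}},\\
\theta_t &= \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}.
\end{aligned}
```

At \(t = 1\) this gives \(m_1 = 0.1\,g\) and \(v_1 = 0.001\,g^2\). The first printed entry \(m_1 = 0.1693\) therefore implies \(g \approx 1.693\), so \(v_1 \approx 0.001 \times 1.693^2 \approx 0.0029\), matching the output. Because \(\hat m_1 = g\) and \(\hat v_1 = g^2\), the first step is \(\alpha\,g/(|g|+\epsilon) \approx \alpha\,\mathrm{sign}(g)\).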
% Perform a second adam update to test state persistence.
iteration = 2;
try
% Perform a second adam update, passing in the updated state
[updatedLearnable2, averageGrad2, averageSqGrad2] = adamupdate(learnable, gradient, averageGrad, averageSqGrad, iteration);
% Display results
disp('adamupdate second test successful!');
disp('Updated Learnable:');
disp(updatedLearnable2);
disp('Average Gradient:');
disp(averageGrad2);
disp('Average Squared Gradient:');
disp(averageSqGrad2);
catch ME
% Display error message
disp('adamupdate second test failed!');
disp(['Error: ', ME.message]);
%Display the stack trace.
disp('Stack Trace:');
disp(ME.stack);
end
adamupdate second test successful!
Updated Learnable:
5x1 dlarray -0.7648 -1.0165 -0.0125 -0.5996 0.4997
Average Gradient:
5x1 dlarray 0.3217 -0.0732 0.1819 -0.0727 0.0560
Average Squared Gradient:
5x1 dlarray 0.0057 0.0003 0.0018 0.0003 0.0002
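Two things explain why Updated Learnable is identical in both runs. First, the second call passes the original learnable rather than updatedLearnable, so both updates start from the same point. Second, with the same gradient g at iterations 1 and 2, the bias-corrected moments are again 0.19g/(1 - 0.9^2) = g and 0.001999g^2/(1 - 0.999^2) = g^2, so the step is again about learnRate times sign(g). The state itself does persist: Average Gradient went from 0.1693 to 0.3217, which is 0.19 × 1.693. In a real training loop, the returned parameters and both moving averages are fed back into the next call. The sketch below does that on a hypothetical toy problem (minimize sum((w - target).^2), with the gradient written by hand instead of computed with dlgradient). The values of target and learnRate and the iteration count are illustrative choices, not values from this thread.

```matlab
% Hypothetical toy problem: minimize sum((w - target).^2) with Adam.
% The gradient 2*(w - target) is written by hand for brevity; a real model
% would compute it with dlfeval/dlgradient.
target = dlarray([1; -2; 3]);
w = dlarray(zeros(3, 1));
averageGrad = [];      % empty state before the first update
averageSqGrad = [];
learnRate = 0.1;       % illustrative value
for iteration = 1:300  % double loop counter, starting at 1
    grad = 2*(w - target);
    % Feed the updated parameters and both moving averages back in
    [w, averageGrad, averageSqGrad] = adamupdate(w, grad, ...
        averageGrad, averageSqGrad, iteration, learnRate);
end
disp(extractdata(w))   % should end up close to target
```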

Categories

Image Data Workflows

Release

R2024b