Implement Ridge Regression Equation for a Neural Network MATLAB
I am trying to replicate the following equation in MATLAB to find the optimal output weight matrix of a neural network from training using ridge regression.
Output Weight Matrix of a Neural Network after Training using Ridge Regression:

This equation comes from the echo state network guide provided by Mantas Lukosevicius and can be found at: https://www.researchgate.net/publication/319770153_A_practical_guide_to_applying_echo_state_networks
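The equation image is not reproduced here, so this is my transcription of it from the guide. Treat the notation as my reading: W^out is the N_y × N output weight matrix (N_y × N when, as in my code, the state matrix X holds only the N node states), w_i^out is its i-th row, y_i(n) = w_i^out x(n) is the i-th output at time n, and β is the ridge regression coefficient (reg in my code).

```latex
W^{\mathrm{out}} = \operatorname*{arg\,min}_{W^{\mathrm{out}}} \frac{1}{N_y} \sum_{i=1}^{N_y} \left( \sum_{n=1}^{T} \bigl( y_i(n) - y_i^{\mathrm{target}}(n) \bigr)^2 + \beta \bigl\lVert w_i^{\mathrm{out}} \bigr\rVert^2 \right)
```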
My attempt is below. I believe that the outer parentheses, which enclose both the inner sum over time and the regularization term, make this a non-standard double summation, so the nested-loop method presented by @Voss (see https://www.mathworks.com/matlabcentral/answers/1694960-nested-loops-for-double-summation) cannot be followed directly.

Note that y_i and y_i_target are both T by 1 vectors (in the code, y_i is computed as a 1 by T row and transposed before the comparison). Wout_i is an N by 1 vector, where N is the number of nodes in the neural network. I generate a Wout_i, y_i, and y_i_target for each i-th target training signal. The final output Wout is an N by 1 vector, where each element is the optimal weight for the corresponding node in the network.
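Writing out one bracketed term of the equation makes the structure of the double sum explicit. This is my own sketch, assuming y_i(n) = w_i x(n), where w_i is shorthand for w_i^out (a 1 × N row), x(n) is the n-th column of X, y_i^target is treated as a 1 × T row, and Y^target is the N_y × T matrix whose i-th row is y_i^target:

```latex
\begin{aligned}
J_i(w_i) &= \sum_{n=1}^{T} \bigl( w_i\,x(n) - y_i^{\mathrm{target}}(n) \bigr)^2 + \beta \lVert w_i \rVert^2
          = \lVert w_i X - y_i^{\mathrm{target}} \rVert^2 + \beta \lVert w_i \rVert^2 \\
\nabla_{w_i} J_i &= 2\,(w_i X - y_i^{\mathrm{target}})\,X^{\top} + 2\beta\,w_i = 0
  \;\Longrightarrow\; w_i = y_i^{\mathrm{target}} X^{\top} \bigl( X X^{\top} + \beta I \bigr)^{-1} \\
W^{\mathrm{out}} &= \begin{bmatrix} w_1 \\ \vdots \\ w_{N_y} \end{bmatrix}
  = Y^{\mathrm{target}} X^{\top} \bigl( X X^{\top} + \beta I \bigr)^{-1}
\end{aligned}
```

Because each J_i depends only on its own row w_i, minimizing (1/N_y) Σ_i J_i over the whole matrix amounts to minimizing each J_i separately. Transposing the middle line gives w_i^T = (X X^T + βI)^{-1} X (y_i^target)^T, which is what the Wout_i line in the code computes.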
close all;
clear all;
clc;
N = 100; % number of nodes in neural network
Ny = 200; % number of training signals
T = 50; % time length of each training signal
X = rand(N,T); % neural network state matrix
reg = 10^-4; % ridge regression coefficient
outer_sum = zeros(Ny,1);
Wouts = cell(Ny,1); % preallocate one cell per target training signal
for i = 1:Ny
y_i_target = rand(T,1); % training signal
Wout_i = ((X*X' + reg*eye(N)) \ (X*y_i_target));
Wouts{i} = Wout_i; % cell array collecting each Wout_i for each i-th target training signal
y_i = Wout_i'*X; % predicted signal
inner_sum = sum((y_i'-y_i_target).^2) + reg*norm(Wout_i)^2; % sum over time, then add the penalty once
outer_sum(i) = inner_sum;
end
outer_sum = outer_sum/Ny;
[minval, minidx] = min(outer_sum);
Wout = Wouts{minidx};
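As a cross-check on the loop, here is a sketch that solves for all N_y outputs in one step using the closed form as I understand it from the guide, W^out = Y^target X^T (X X^T + βI)^{-1}, and compares it row by row with the per-signal solution. The stacked target matrix Ytarget (N_y × T, row i = y_i_target') and the fixed random seed are my assumptions, since the original loop draws a fresh random target on each iteration.

```matlab
% Sketch: solve for all Ny output rows at once and compare with a per-signal loop.
% Assumption: targets are stacked as rows of Ytarget (Ny x T); this name is illustrative.
rng(0);                          % fixed seed so the comparison is repeatable
N = 100; Ny = 200; T = 50; reg = 1e-4;
X = rand(N, T);                  % state matrix, one column per time step
Ytarget = rand(Ny, T);           % row i plays the role of y_i_target'

% Closed form Wout = Ytarget*X'*(X*X' + reg*I)^-1, computed with mrdivide (/)
% rather than an explicit matrix inverse.
A = X*X' + reg*eye(N);
Wout = (Ytarget*X') / A;         % Ny x N, row i holds the weights for output i

% Per-signal solution, written the same way as the Wout_i line in the question.
Wout_loop = zeros(Ny, N);
for i = 1:Ny
    Wout_loop(i, :) = (A \ (X*Ytarget(i, :)'))';
end

maxdiff = max(abs(Wout(:) - Wout_loop(:)));
fprintf('max |Wout - Wout_loop| = %g\n', maxdiff);
```

The result is N_y × N, one row of output weights per target signal, and row i should agree with the corresponding per-signal solution up to rounding error.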
My final answer for Wout is N by 1, as it should be, but I am uncertain about it. In particular, I am unsure whether I have implemented the double summation and the arg min with respect to Wout correctly. Is there any way to validate my answer?
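One way to validate a single ridge solution, independently of the formula used to compute it, is sketched below for one target signal. It assumes the per-output objective is one bracketed term of the equation, written with a column weight vector w (N × 1) as in the code; the variable names, seed, and perturbation size are illustrative choices. It runs three checks: ridge regression should equal ordinary least squares on the augmented system [X'; sqrt(reg)·I] w = [y; 0], the gradient should vanish at the minimizer, and small random perturbations should never lower the objective.

```matlab
% Sketch: three independent checks on the ridge solution for ONE output.
% Assumed objective: J(w) = sum_n (w'*x(n) - y(n))^2 + reg*||w||^2, w is N x 1.
rng(1);
N = 100; T = 50; reg = 1e-4;
X = rand(N, T);                       % state matrix
y = rand(T, 1);                       % one target signal, T x 1

w = (X*X' + reg*eye(N)) \ (X*y);      % closed-form ridge solution, as in the question

% Objective for this output: squared error summed over time plus the penalty once.
J = @(v) sum((X'*v - y).^2) + reg*norm(v)^2;

% Check 1: ridge regression equals least squares on an augmented tall system.
% Backslash on a tall full-rank matrix returns the least-squares solution.
w_aug = [X'; sqrt(reg)*eye(N)] \ [y; zeros(N, 1)];

% Check 2: the gradient 2*X*(X'*w - y) + 2*reg*w should be ~0 at the minimum.
g = 2*X*(X'*w - y) + 2*reg*w;

% Check 3: small random perturbations should never decrease the objective.
J0 = J(w);
worse = true;
for k = 1:20
    if J(w + 1e-3*randn(N, 1)) < J0
        worse = false;
    end
end

fprintf('relative difference vs augmented LS: %g\n', norm(w - w_aug)/norm(w));
fprintf('gradient norm: %g\n', norm(g));
fprintf('all perturbations increased J: %d\n', worse);
```

The augmented-system check does not reuse the X*X' product, so agreement between w and w_aug is evidence that the normal-equation line is solving the intended problem rather than repeating the same mistake twice.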
1 Comment: Jonathan Frutschy, on 11 Jun 2024