Setting input formats for a nested Neural ODE
Hi all,
I am building a neural network (NN) that nests a Neural ODE. The network takes two inputs: (i) the initial values of the internal states (InitialValue), which feed the Neural ODE, and (ii) sequences (input_2), which enter the network after the Neural ODE. The output of the Neural ODE and input_2 must be summed.
I tried supplying the network inputs both as cell arrays and as dlarray objects. For the dlarray version I also labeled the dimensions as 'CBT' and/or 'CB' according to the structure of each data series; nevertheless, the problem persists.
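As a minimal sketch of what those labels mean (the batch size of 4 and the random values are placeholders, not the original data): a 'CB' array holds one value per channel and observation, while a 'CBT' array adds a time dimension. Because the addition layer sums its two inputs element-wise, the Neural ODE output and input_2 need matching formats and sizes, including the same number of time steps.

```matlab
% Placeholder mini-batch: 4 observations, 1 channel, 50 time steps
numObs = 4;
numSteps = 50;

% Initial values: one scalar state per observation -> channel x batch
X1 = dlarray(randn(1, numObs), "CB");

% Sequences: one channel per observation over time -> channel x batch x time
X2 = dlarray(randn(1, numObs, numSteps), "CBT");

disp(dims(X1))   % format labels of the initial-value array
disp(dims(X2))   % format labels of the sequence array
disp(size(X2))   % sizes in [channel batch time] order
```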
The error I got is the following:
%% Generate data
data_1 = randn(1,1000);
data_2 = randn(1,1000);
tspan = 1:1:50;
InitialValue = data_1(:,1:end-length(tspan))';
indices = 1:length(InitialValue);
targets = arrayfun(@(i) data_1(:, i + tspan), indices, 'UniformOutput', false)';
input_2 = arrayfun(@(i) data_2(:, i + tspan), indices, 'UniformOutput', false)';
%% Create neural network
% NeuralODE layers
OdeLayer = [fullyConnectedLayer(5)
tanhLayer
fullyConnectedLayer(1)];
OdeNet = dlnetwork(OdeLayer,Initialize=false);
% Main layer
net = dlnetwork;
Layers =[featureInputLayer(1,'Name','Input 1')
neuralODELayer(OdeNet,tspan,"Name",'OdeLayer','GradientMode','adjoint')];
% add a second input to be summed with the Neural ODE output
net = addLayers(net, Layers);
net = addLayers(net, sequenceInputLayer(1,'Name','Input 2'));
net = addLayers(net, additionLayer(2,'Name','adition'));
% connect layers
net = connectLayers(net,'Input 2','adition/in2');
net = connectLayers(net,'OdeLayer','adition/in1');
%% Train Network
% gather inputs and targets
input_1_ds = arrayDatastore(InitialValue,"OutputType","same");
input_2_ds = arrayDatastore(input_2,"OutputType","same");
target_ds = arrayDatastore(targets,"OutputType","same");
cds = combine(input_1_ds, input_2_ds, target_ds);
opt = trainingOptions("adam");
% training
net = trainnet(cds,net,"l2loss",opt);
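One way to narrow down where the formats disagree is to run the Neural ODE branch on its own and inspect what it returns before it reaches the addition layer. The sketch below rebuilds only that branch, using the same layer sizes and tspan as above; the batch of 4 random initial values is a placeholder, and the name `branch` is hypothetical. If the output has no 'T' dimension, or its time dimension differs from the 50 steps in each input_2 sequence, the addition layer cannot sum the two.

```matlab
% Rebuild only the Neural ODE branch to inspect its output
tspan = 1:1:50;
OdeLayers = [fullyConnectedLayer(5)
    tanhLayer
    fullyConnectedLayer(1)];
OdeNet = dlnetwork(OdeLayers, Initialize=false);

% Feature input followed by the Neural ODE layer (initialized by dlnetwork)
branch = dlnetwork([featureInputLayer(1)
    neuralODELayer(OdeNet, tspan)]);

% Placeholder mini-batch: 4 initial values, formatted as channel x batch
X1 = dlarray(randn(1, 4), "CB");

Y = predict(branch, X1);
disp(dims(Y))   % format of the Neural ODE output
disp(size(Y))   % compare the time dimension (if any) with the input_2 length
fprintf("Each input_2 sequence has %d time steps\n", numel(tspan));
```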
Thanks in advance for your feedback and comments.
Accepted Answer