機械学習の入力エラーについて
Show older comments
LSTMの学習方法について質問です.
最下部に示したコードを実行したとき,「予測子はシーケンスの N 行 1 列の cell 配列でなければなりません。」が表示されうまく学習できません.
入力データは1タイムステップに t-2, t-1, t のデータが含まれており,それに対応する出力データは t+1 のデータとなっています.
ここで学習に用いるデータを
net = trainNetwork(XTrain_C, YTrain_C, layers, options);
のように,Cのみを用いるようにすると上手く実行できるのですが,元コードのようにA,C,Dの3つの時系列データを学習させたモデルを作成しようとするとエラー文が表示されてしまいます.
下記のコードをどう修正すれば実行可能になりますでしょうか?
clear all
close all
%% Make dataset
A = zeros(1,100);
B = zeros(1,100);
C = zeros(1,100);
D = zeros(1,100);
% A
for i = 1:100
if i <= 40
A(:,i) = i / 40;
elseif (41 <= i) && (i <= 45)
A(:,i) = 1 - ((i - 40) / 5);
elseif 46 <= i
A(:,i) = 0;
end
end
% B
for i = 1:100
if i <= 60
B(:,i) = i / 60;
elseif (61 <= i) && (i <= 65)
B(:,i) = 1 - ((i - 60) / 5);
elseif 66 <= i
B(:,i) = 0;
end
end
% C
for i = 1:100
if i <= 80
C(:,i) = i / 80;
elseif (81 <= i) && (i <= 85)
C(:,i) = 1 - ((i - 80) / 5);
elseif 86 <= i
C(:,i) = 0;
end
end
% D
for i = 1:100
if i <= 40
D(:,i) = i / 20;
elseif (21 <= i) && (i <= 25)
D(:,i) = 1 - ((i - 20) / 5);
elseif 26 <= i
D(:,i) = 0;
end
end
%% Plot
plot(1:100, A(1,:),'LineWidth',2);hold on
plot(1:100, B(1,:),'LineWidth',2);hold on
plot(1:100, C(1,:),'LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('A','B','C','Location','northwest')
grid on
%% Preparing for ML
% A
for i = 1:97
XTrain_A{1,i} = A(:,i:i+2).';
YTrain_A{1,i} = A(:,i+3);
end
% C
for i = 1:97
XTrain_C{1,i} = C(:,i:i+2).';
YTrain_C{1,i} = C(:,i+3);
end
% D
for i = 1:97
XTrain_D{1,i} = D(:,i:i+2).';
YTrain_D{1,i} = D(:,i+3);
end
% Input
XTrain{1,1} = XTrain_D;
XTrain{2,1} = XTrain_A;
XTrain{3,1} = XTrain_C;
YTrain{1,1} = YTrain_D;
YTrain{2,1} = YTrain_A;
YTrain{3,1} = YTrain_C;
%% TrainNetwork
numFeatures = 3;
numResponses = 1;
numHiddenUnits = 300;
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(20)
fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('adam', ...
'MaxEpochs',200, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.0001, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',50, ...
'LearnRateDropFactor',0.2, ...
'Verbose',0, ...
'Plots','training-progress');
net = trainNetwork(XTrain, YTrain, layers, options);
%% Test
Result = zeros(1,100);
Result(:,1:3) = B(1,1:3);
for i = 1:97
[net,Result(1,i+1)] = predictAndUpdateState(net, Result(:,i:i+2).');
end
%% Plot result
plot(1:100, B(1,:),'k','LineWidth',2);hold on
plot(1:100, Result(1,:),'r','LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('B','Predection','Location','northwest')
grid on
7 Comments
Yuuki
on 23 Nov 2020
Naoya
on 27 Nov 2020
trainNetwork に与える XTrain は 数値シーケンスの N 行 1 列の cell 配列として与える必要があります。
ここで、N はシーケンスのパターン数となります。
また、各セルには、 FxS の行列として定義します。
F は入力の特徴量数、S はタイムステップ数となります。
同様に、YTrain は数値シーケンスの N 行 1 列の cell 配列として与えます。。
各セルには RxS の行列として定義します。
R は出力の応答数、Sはタイムステップ数となります。
キャプチャの XTrain, YTrain をみてみますと、少なくとも セル配列の中にセル配列となる定義になっているようですので、上記のように修正する必要があると思います。
Yuuki
on 27 Nov 2020
Naoya
on 30 Nov 2020
帰還路も含むリカレントネットワークの一部となりますので、入力は t, t-1, t-2 と別々に設ける必要はないかもしれません。
5つのセンサーを入力とする場合は、 5xS として与えてみてはいかがでしょうか?
Yuuki
on 30 Nov 2020
Naoya
on 30 Nov 2020
はい、その理解となります。 t-2, t-1 を含めずにまずはお試し頂ければと思います。
Yuuki
on 30 Nov 2020
Answers (0)
Categories
Find more on Deep Learning Toolbox in Help Center and File Exchange
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!

