機械学習の入力エラーについて

LSTMの学習方法について質問です.
最下部に示したコードを実行したとき,「予測子はシーケンスの N 行 1 列の cell 配列でなければなりません。」が表示されうまく学習できません.
入力データは1タイムステップに t-2, t-1, t のデータが含まれており,それに対応する出力データは t+1 のデータとなっています.
ここで学習に用いるデータを
net = trainNetwork(XTrain_C, YTrain_C, layers, options);
のように,Cのみを用いるようにすると上手く実行できるのですが,元コードのようにA,C,Dの3つの時系列データを学習させたモデルを作成しようとするとエラー文が表示されてしまいます.
下記のコードをどう修正すれば実行可能になりますでしょうか?
clear all
close all
%% Make dataset
A = zeros(1,100);
B = zeros(1,100);
C = zeros(1,100);
D = zeros(1,100);
% A
for i = 1:100
if i <= 40
A(:,i) = i / 40;
elseif (41 <= i) && (i <= 45)
A(:,i) = 1 - ((i - 40) / 5);
elseif 46 <= i
A(:,i) = 0;
end
end
% B
for i = 1:100
if i <= 60
B(:,i) = i / 60;
elseif (61 <= i) && (i <= 65)
B(:,i) = 1 - ((i - 60) / 5);
elseif 66 <= i
B(:,i) = 0;
end
end
% C
for i = 1:100
if i <= 80
C(:,i) = i / 80;
elseif (81 <= i) && (i <= 85)
C(:,i) = 1 - ((i - 80) / 5);
elseif 86 <= i
C(:,i) = 0;
end
end
% D
for i = 1:100
if i <= 40
D(:,i) = i / 20;
elseif (21 <= i) && (i <= 25)
D(:,i) = 1 - ((i - 20) / 5);
elseif 26 <= i
D(:,i) = 0;
end
end
%% Plot
plot(1:100, A(1,:),'LineWidth',2);hold on
plot(1:100, B(1,:),'LineWidth',2);hold on
plot(1:100, C(1,:),'LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('A','B','C','Location','northwest')
grid on
%% Preparing for ML
% A
for i = 1:97
XTrain_A{1,i} = A(:,i:i+2).';
YTrain_A{1,i} = A(:,i+3);
end
% C
for i = 1:97
XTrain_C{1,i} = C(:,i:i+2).';
YTrain_C{1,i} = C(:,i+3);
end
% D
for i = 1:97
XTrain_D{1,i} = D(:,i:i+2).';
YTrain_D{1,i} = D(:,i+3);
end
% Input
XTrain{1,1} = XTrain_D;
XTrain{2,1} = XTrain_A;
XTrain{3,1} = XTrain_C;
YTrain{1,1} = YTrain_D;
YTrain{2,1} = YTrain_A;
YTrain{3,1} = YTrain_C;
%% TrainNetwork
numFeatures = 3;
numResponses = 1;
numHiddenUnits = 300;
layers = [ ...
sequenceInputLayer(numFeatures)
flattenLayer('Name','flatten')
lstmLayer(numHiddenUnits,'OutputMode','sequence')
fullyConnectedLayer(20)
fullyConnectedLayer(numResponses)
regressionLayer];
options = trainingOptions('adam', ...
'MaxEpochs',200, ...
'GradientThreshold',1, ...
'InitialLearnRate',0.0001, ...
'LearnRateSchedule','piecewise', ...
'LearnRateDropPeriod',50, ...
'LearnRateDropFactor',0.2, ...
'Verbose',0, ...
'Plots','training-progress');
net = trainNetwork(XTrain, YTrain, layers, options);
%% Test
Result = zeros(1,100);
Result(:,1:3) = B(1,1:3);
for i = 1:97
[net,Result(1,i+1)] = predictAndUpdateState(net, Result(:,i:i+2).');
end
%% Plot result
plot(1:100, B(1,:),'k','LineWidth',2);hold on
plot(1:100, Result(1,:),'r','LineWidth',2);hold off
xlim([1 100])
ylim([-0.1 1.1])
legend('B','Predection','Location','northwest')
grid on

7 Comments

Yuuki
Yuuki on 23 Nov 2020
XTrainとYTrainの中身を張っておきます.
Naoya
Naoya on 27 Nov 2020
trainNetwork に与える XTrain は 数値シーケンスの N 行 1 列の cell 配列として与える必要があります。
ここで、N はシーケンスのパターン数となります。
また、各セルには、 FxS の行列として定義します。
F は入力の特徴量数、S はタイムステップ数となります。
同様に、YTrain は数値シーケンスの N 行 1 列の cell 配列として与えます。。
各セルには RxS の行列として定義します。
R は出力の応答数、Sはタイムステップ数となります。
キャプチャの XTrain, YTrain をみてみますと、少なくとも セル配列の中にセル配列となる定義になっているようですので、上記のように修正する必要があると思います。
Yuuki
Yuuki on 27 Nov 2020
Naoya様
コメントありがとうございます.
>また、各セルには、 FxS の行列として定義します。F は入力の特徴量数、S はタイムステップ数となります。
つまりキャプチャの例だと
A,C,Dの3つのシーケンスパターン(キャプチャ一枚目.N=3)は正しくできているが,各セルの中身はF*Sの行列でなければならず,1つのタイムステップにt-2. t-1, tを特徴量として与えたい場合は3行S列の行列としなければならない.
という理解で合っていますでしょうか?
最終的にはこちらのように1つのタイムステップは5行1列(5つのセンサーで取得したデータを各行に配置)であり,それを用いて各行に対応する次のタイムステップの値(5行1列)を予測することを実行したいと考えています.
上記のようにt-2, t-1, tのような複数の時間のデータを用いてt+1を予測したい場合,与えるXTrainの各セルの中のシーケンスデータは15*Sの行列データになるのでしょうか?
Naoya
Naoya on 30 Nov 2020
帰還路も含むリカレントネットワークの一部となりますので、入力は t, t-1, t-2 と別々に設ける必要はないかもしれません。
5つのセンサーを入力とする場合は、 5xS として与えてみてはいかがでしょうか?
Yuuki
Yuuki on 30 Nov 2020
Naoya様
勉強不足で申し訳ありません.
>入力は t, t-1, t-2 と別々に設ける必要はない
とは,
LSTMによる学習は時刻tのデータを入力すると,RNNの性質上それ以前の結果も考慮した出力結果(時刻t+1)を出してくれるからわざわざ時刻tの入力データにt-2,t-1のデータを含める必要はない,
ということでしょうか?
Naoya
Naoya on 30 Nov 2020
はい、その理解となります。 t-2, t-1 を含めずにまずはお試し頂ければと思います。
Yuuki
Yuuki on 30 Nov 2020
Naoya様
ご返信ありがとうございます.
何度か5×Sで試したもののあまり精度が出ず,他論文でt-5~tのデータを入力としt+1を出力する例を見たため,同様の方法で精度が上がらないかと思い上記の質問を設けた次第です.
もう少し他の方法で改善を試みようと思います.

Sign in to comment.

Answers (0)

Categories

Find more on Deep Learning Toolbox in Help Center and File Exchange

Products

Release

R2019b

Asked:

on 23 Nov 2020

Commented:

on 30 Nov 2020

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!