netTrained = trainnet(sequences,targets,net,lossFcn,options): this function cannot be used when sequences contains complex values
Question:
When calling netTrained = trainnet(sequences,targets,net,lossFcn,options),
how can this function be used when sequences contains complex values?
The function documentation states that complex input is supported: "This argument supports complex-valued predictors and targets."
Code:
XTrain = permute(dataTrain(:,1:end-1,:),[1,3,2]);
TTrain = permute(dataTrain(:,2:end,:),[1,3,2]);
numChannels = betalen;
layers = [
    sequenceInputLayer(numChannels)
    lstmLayer(128)
    fullyConnectedLayer(numChannels)];
options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    SequencePaddingDirection="left", ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false);
net = trainnet(XTrain,TTrain,layers,"mse",options);
Error message:
Error using trainnet (line 46)
Execution failed during layer 'lstm'.
Error in HDL (line 66)
net = trainnet(XTrain,TTrain,layers,"mse",options);
Caused by:
Error using dlarray/lstm (line 105)
Invalid argument at position 1. Value must be real.
Answers (1)
Paras Gupta on 18 Jul 2024 (edited 18 Jul 2024)
Hi Alexander,
I understand that you are trying to use the "trainnet" function on complex-valued sequences and complex-valued targets.
You are correct in noting that the documentation indicates that the "trainnet" function can support complex-valued predictors and targets. However, the built-in loss functions provided by "trainnet" do not inherently support complex-valued targets. To address this, you will need to define a custom loss function that can handle complex values for targets.
Moreover, the "sequenceInputLayer" in your model should be configured to handle complex-valued inputs. This can be done by setting the "SplitComplexInputs" argument to true.
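As a quick illustration of why this resolves the "Value must be real" error (a sketch using the same layer sizes as the example below; the network and data here are made up for demonstration): with "SplitComplexInputs" set to true, each complex channel is split into a real and an imaginary channel before reaching the LSTM, so every downstream layer only ever sees real-valued data.

```matlab
% Sketch: SplitComplexInputs splits each complex input channel into two
% real channels (real parts first, then imaginary parts), so dlarray/lstm
% never receives complex values.
numChannels = 2;
layers = [
    sequenceInputLayer(numChannels, SplitComplexInputs=true)
    lstmLayer(128)
    fullyConnectedLayer(numChannels)];
net = dlnetwork(layers);

% Feed one complex sequence (channels x time) and inspect the output.
X = dlarray(complex(randn(numChannels,10), randn(numChannels,10)), "CT");
Y = predict(net, X);
disp(isreal(extractdata(Y)))   % the forward pass succeeds and the output is real-valued
```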
Below is an example with a custom loss function for complex-valued targets, which you can pass to "trainnet" as a function handle:
% dummy data
numSamples = 100;
numTimesteps = 10;
numChannels = 2;
realPart = randn(numSamples, numTimesteps, numChannels);
imagPart = randn(numSamples, numTimesteps, numChannels);
dataTrain = realPart + 1i * imagPart;
XTrain = permute(dataTrain(:,1:end-1,:),[1,3,2]);
% complex target
TTrain = permute(dataTrain(:,2:end,:),[1,3,2]);
% real target
% TTrain = rand(numSamples, numChannels, numTimesteps-1);
layers = [
    sequenceInputLayer(numChannels, SplitComplexInputs=true) % split complex input into real and imaginary channels
    lstmLayer(128)
    fullyConnectedLayer(numChannels)];
options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    SequencePaddingDirection="left", ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false);
% the built-in "mse" loss fails here because it cannot handle complex targets:
% net = trainnet(XTrain, TTrain, layers, "mse", options);
% custom loss function passed as a function handle
net = trainnet(XTrain, TTrain, layers, @complexLoss, options);
function loss = complexLoss(Y, T)
    % Mean squared error over the complex difference. Using abs() keeps
    % both the real and imaginary parts of the error; real() alone would
    % silently discard the imaginary component.
    difference = Y - T;
    squaredMagnitude = abs(difference).^2;
    loss = mean(squaredMagnitude, 'all');
end
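To see what the custom loss produces (a standalone check, not part of the training script above; the input values here are made up), you can evaluate it on a small complex example. Whatever the inputs, the loss must come out as a real, nonnegative scalar, since that is what the optimizer requires.

```matlab
% Quick sanity check of the custom loss on a small complex example.
Y = [1+2i, 3-1i];   % hypothetical network output
T = [1-1i, 2];      % hypothetical complex target
lossValue = complexLoss(Y, T);
disp(isreal(lossValue) && lossValue >= 0)   % the loss is a real, nonnegative scalar
```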
You can refer to the following documentation links for more information on the code above:
- 'Train Network with Complex-Valued Data' example - https://www.mathworks.com/help/deeplearning/ug/train-network-with-complex-valued-data.html
- "SplitComplexInputs" argument - https://www.mathworks.com/help/deeplearning/ref/nnet.cnn.layer.sequenceinputlayer.html#mw_1f1ef68c-244a-4374-a846-2e24b71d384f_sep_mw_19bc7780-8e05-482a-b309-d24e230ab466
- function handles - https://www.mathworks.com/help/matlab/function-handles.html
- "trainnet" function - https://www.mathworks.com/help/deeplearning/ref/trainnet.html
Hope this helps with your work.