Conditional GAN training error in the trainGAN function
I am trying to get conditional GAN training to work with a 2-D matrix input of size 14*8.
I am adapting the "GenerateSyntheticPumpSignalsUsingCGANExample" example by changing its vector input to a 2-D matrix input.
The following error message appears:

It seems that there is a size mismatch in the modelGradients function, but since that function comes from an official example, I am not sure how to revise it. Can someone give a hint?
The input data is attached as test.mat.
The training script is attached as untitled3.m; I have also pasted it below.
clear;
%% Load the data
% LSTM_Reform_Data_SeriesData1_20210315_data001_for_GAN;
% load('LoadedData_20210315_data001_for_GAN.mat')
load('test.mat');
% load('test2.mat');
%% Generator Network
numFilters = 4;
numLatentInputs = 120;
projectionSize = [2 1 63];
numClasses = 2;
embeddingDimension = 120;
layersGenerator = [
    imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','Input')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'ProjReshape')
    concatenationLayer(3,2,'Name','Concate1')
    transposedConv2dLayer([3 2],8*numFilters,'Stride',1,'Name','TransConv1') % 4*2*32
    batchNormalizationLayer('Name','BN1','Epsilon',1e-5)
    reluLayer('Name','Relu1')
    transposedConv2dLayer([2 2],4*numFilters,'Stride',2,'Name','TransConv2') % 8*4*16
    batchNormalizationLayer('Name','BN2','Epsilon',1e-5)
    reluLayer('Name','Relu2')
    transposedConv2dLayer([2 2],2*numFilters,'Stride',2,'Cropping',[2 1],'Name','TransConv3') % 12*6*8
    batchNormalizationLayer('Name','BN3','Epsilon',1e-5)
    reluLayer('Name','Relu3')
    transposedConv2dLayer([3 3],2*numFilters,'Stride',1,'Name','TransConv4') % 14*8*1
    ];
lgraphGenerator = layerGraph(layersGenerator);
layers = [
    imageInputLayer([1 1],'Name','Labels','Normalization','none')
    embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'EmbedReshape1')];
lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,'EmbedReshape1','Concate1/in2');
subplot(1,2,1);
plot(lgraphGenerator);
dlnetGenerator = dlnetwork(lgraphGenerator);
%% Discriminator Network
scale = 0.2;
Input_Num_Feature = [14 8 1]; % The input data is [14 8 1]
layersDiscriminator = [
    imageInputLayer(Input_Num_Feature,'Normalization','none','Name','Input')
    concatenationLayer(3,2,'Name','Concate2')
    convolution2dLayer([2 2],4*numFilters,'Stride',1,'DilationFactor',2,'Padding',[0 0],'Name','Conv1')
    leakyReluLayer(scale,'Name','LeakyRelu1')
    convolution2dLayer([2 4],2*numFilters,'Stride',2,'DilationFactor',1,'Padding',[2 2],'Name','Conv2')
    leakyReluLayer(scale,'Name','LeakyRelu2')
    convolution2dLayer([2 2],numFilters,'Stride',2,'DilationFactor',1,'Padding',[0 0],'Name','Conv3')
    leakyReluLayer(scale,'Name','LeakyRelu3')
    convolution2dLayer([2 1],numFilters/2,'Stride',1,'DilationFactor',2,'Padding',[0 0],'Name','Conv4')
    leakyReluLayer(scale,'Name','LeakyRelu4')
    convolution2dLayer([2 2],numFilters/4,'Stride',1,'DilationFactor',1,'Padding',[0 0],'Name','Conv5')
    ];
lgraphDiscriminator = layerGraph(layersDiscriminator);
layers = [
    imageInputLayer([1 1],'Name','Labels','Normalization','none')
    embedAndReshapeLayer(Input_Num_Feature,embeddingDimension,numClasses,'EmbedReshape2')];
lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,'EmbedReshape2','Concate2/in2');
subplot(1,2,2);
plot(lgraphDiscriminator);
dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
%% Train model
params.numLatentInputs = numLatentInputs;
params.numClasses = numClasses;
params.sizeData = [Input_Num_Feature length(Series_Fused_Label)];
params.numEpochs = 50;
params.miniBatchSize = 512;
% Specify the options for Adam optimizer
params.learnRate = 0.0002;
params.gradientDecayFactor = 0.5;
params.squaredGradientDecayFactor = 0.999;
executionEnvironment = "cpu";
params.executionEnvironment = executionEnvironment;
% for test, 14*8*30779
[dlnetGenerator,dlnetDiscriminator] = ...
    trainGAN(dlnetGenerator,dlnetDiscriminator,Series_Fused_Expand_Norm_Input,Series_Fused_Label,params);
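One way to narrow down where a size mismatch comes from is to work out each layer's output size by hand. The snippet below is a minimal sketch and not part of the original script. It assumes the usual output-size formulas for transposed convolution and convolution layers with the symmetric [a b] cropping/padding form used above, and it assumes the concatenated inputs are 2x1x64 for the generator (projectionSize plus the label embedding) and 14x8x2 for the discriminator. The helper names tconvOut and convOut are hypothetical.

```matlab
% Hand-computed [height width] output sizes for the layers in the script.
% Assumed formulas:
%   transposed conv: out = (in - 1)*stride + filter - 2*crop
%   convolution:     out = floor((in + 2*pad - ((filter - 1)*dilation + 1))/stride) + 1
% crop and pad are the symmetric [a b] forms (a: top/bottom, b: left/right).
tconvOut = @(in, filt, stride, crop) (in - 1).*stride + filt - 2*crop;
convOut  = @(in, filt, stride, pad, dil) ...
    floor((in + 2*pad - ((filt - 1).*dil + 1))./stride) + 1;

numFilters = 4;

% Generator: projectAndReshapeLayer gives 2x1x63; concatenating the
% 2x1x1 label embedding gives 2x1x64.
g = [2 1];
g = tconvOut(g, [3 2], 1, [0 0]);   % TransConv1
g = tconvOut(g, [2 2], 2, [0 0]);   % TransConv2
g = tconvOut(g, [2 2], 2, [2 1]);   % TransConv3
g = tconvOut(g, [3 3], 1, [0 0]);   % TransConv4
genChannels = 2*numFilters;         % number of filters in TransConv4
fprintf('Generator output: %d x %d x %d\n', g(1), g(2), genChannels);

% Discriminator: 14x8x1 input concatenated with a 14x8x1 label embedding.
d = [14 8];
d = convOut(d, [2 2], 1, [0 0], 2); % Conv1
d = convOut(d, [2 4], 2, [2 2], 1); % Conv2
d = convOut(d, [2 2], 2, [0 0], 1); % Conv3
d = convOut(d, [2 1], 1, [0 0], 2); % Conv4
d = convOut(d, [2 2], 1, [0 0], 1); % Conv5
fprintf('Discriminator output: %d x %d x %d\n', d(1), d(2), numFilters/4);
```

Under these assumptions, the discriminator reduces 14x8 down to 1x1x1 as intended. The generator, however, ends at 14 x 8 x 8 rather than the 14*8*1 noted in the TransConv4 comment, because that layer has 2*numFilters filters. The discriminator's image input was built for a single channel, so this channel count is one plausible source of the mismatch inside modelGradients. Running analyzeNetwork(lgraphGenerator) would confirm the actual activation sizes.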