How to change output classes for image classification in alexnet?

3 views (last 30 days)
Hello everyone
I'm trying to implement the following example:
but I got this error:
Error using trainNetwork (line 170)
Invalid training data. The output size (28224) of the last layer does not match the number of classes (24).
and this is my code:
tic
clear;
clc;
outputFolder = fullfile('Spilt_Dataset');
rootFolder = fullfile(outputFolder , 'Categories');
categories = {'front_or_left' , 'front_or_right','hump' , 'left_turn','narrow_from_left' , 'narrows_from_right',...
'no_horron' , 'no_parking','no_u_turn' , 'overtaking_is_forbidden','parking' , 'pedestrian_crossing',...
'right_or_left' , 'right_turn','rotor' , 'slow','speed_30' , 'speed_40',...
'speed_50','speed_60','speed_80','speed_100','stop','u_turn'};
imds = imageDatastore(fullfile(rootFolder,categories),'LabelSource','Foldernames');
tbl = countEachLabel(imds);
minSetCount = min(tbl{:,2});
imds = splitEachLabel(imds, minSetCount, 'randomize');
countEachLabel(imds);
front_or_left = find(imds.Labels == 'front_or_left',1);
front_or_right = find(imds.Labels == 'front_or_right',1);
hump = find(imds.Labels == 'hump',1);
left_turn = find(imds.Labels == 'left_turn',1);
narrow_from_left = find(imds.Labels == 'narrow_from_left',1);
narrows_from_right = find(imds.Labels == 'narrows_from_right',1);
no_horron = find(imds.Labels == 'no_horron',1);
no_parking = find(imds.Labels == 'no_parking',1);
no_u_turn = find(imds.Labels == 'no_u_turn',1);
overtaking_is_forbidden = find(imds.Labels == 'overtaking_is_forbidden',1);
parking = find(imds.Labels == 'parking',1);
pedestrian_crossing = find(imds.Labels == 'pedestrian_crossing',1);
right_or_left = find(imds.Labels == 'right_or_left',1);
right_turn = find(imds.Labels == 'right_turn',1);
rotor = find(imds.Labels == 'rotor',1);
slow = find(imds.Labels == 'slow',1);
speed_30 = find(imds.Labels == 'speed_30',1);
speed_40 = find(imds.Labels == 'speed_40',1);
speed_50 = find(imds.Labels == 'speed_50',1);
speed_60 = find(imds.Labels == 'speed_60',1);
speed_80 = find(imds.Labels == 'speed_80',1);
speed_100 = find(imds.Labels == 'speed_100',1);
stop = find(imds.Labels == 'stop',1);
u_turn = find(imds.Labels == 'u_turn',1);
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
net = googlenet;
inputSize = net.Layers(1).InputSize;
lgraph = layerGraph(net);
numClasses = numel(categories(imdsTrain.Labels))
newLearnableLayer = fullyConnectedLayer(numClasses, ...
'Name','new_fc', ...
'WeightLearnRateFactor',10, ...
'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,'loss3-classifier',newLearnableLayer);
newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
'RandXReflection',true, ...
'RandXTranslation',pixelRange, ...
'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
options = trainingOptions('sgdm', ...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',false, ...
'Plots','training-progress');
netTransfer = trainNetwork(augimdsTrain,lgraph,options);

Accepted Answer

Srivardhan Gadila
Srivardhan Gadila on 16 Apr 2021
If the classes listed from line "front_or_left = find(imd...." in the above code are the only classes which are 24 in total then the issue might be that the value returned by
numClasses = numel(categories(imdsTrain.Labels))
is probably 28224 and not 24. If that's the case then specify the value 24 directly in the line
newLearnableLayer = fullyConnectedLayer(24, ...
and try running the code.
You can try debugging the issue by checking the layer activation size by using analyzeNetwork function.
If the network is fine then the issue might be w.r.t your data i.e., labels itself.
If you still face the issue then Contact MathWorks Support.

More Answers (0)

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!