Using fft to replace feature learning in CNN
Juuso Korhonen on 19 Jan 2021
Commented: fawad ahmad on 3 Aug 2021
Hello,

I read this interesting article: https://www.groundai.com/project/reducing-deep-network-complexity-with-fourier-transform-methods/1 , in which the authors got very good results by replacing the feature-learning part of a CNN with a basic FFT. I'm very interested in trying this out in MATLAB, because it suggests the approach could relax the requirements on the amount of training data (I'm currently working with medical data, where sample sizes are often small).

However, I can't get it to work: my accuracy stays at 10% on the MNIST digit data, which for 10 classes is chance level, so the network is basically not learning anything. There must be some major bug, but I can't find it. I suspect my implementation of the preprocessForTraining function. It is applied as a transform function to the imageDatastore: it computes the FFT of each image and then flattens the real and imaginary parts into a 1-D vector, which is fed to the featureInputLayer of my simple neural network. (That said, I think the transformation itself is correct, since I can read an image from dsTrain and transform it back to the original image.)
% data read
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% read the data in batches: each read returns miniBatchSize images to the transform
miniBatchSize = 128;
imds.ReadSize = miniBatchSize;
% split to training and validation data
numTrainFiles = 750;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');
% define a transform that is applied every time data is read
% the transform is the preprocessForTraining function (separate file),
% which converts to grayscale, resizes, applies the FFT and flattens the
% result into a 1-D vector
dsTrain = transform(imdsTrain, @preprocessForTraining,'IncludeInfo',true);
dsValidation = transform(imdsValidation, @preprocessForTraining,'IncludeInfo',true);
% Network structure (basic MLP)
% input size is twice the number of pixels because both the real and
% imaginary parts of the FFT are used
% one hidden layer whose width is half the input size
% ReLUs as activation functions
layers = [
    featureInputLayer(28*28*2)
    fullyConnectedLayer(28*28)
    reluLayer
    fullyConnectedLayer(10)
    reluLayer;
    softmaxLayer
    classificationLayer];
% training options
options = trainingOptions('adam', ...
    'Plots','training-progress', ...
    'MiniBatchSize',miniBatchSize);
% training
net = trainNetwork(dsTrain,layers,options);
function [dataOut,info] = preprocessForTraining(data,info)
    numRows = size(data,1);
    dataOut = cell(numRows,2);
    targetSize = [28,28]; 
    % since ReadSize is expected to be >1, data comes in cell form containing
    % multiple images
    for idx = 1:numRows
        % get the image out of the datacell
        img = data{idx,1};
        % if rgb image, turn to grayscale
        if size(img, 3) == 3
            img = rgb2gray(img);
        end    
        % resize and fft
        fft_img = fftshift(fft2(imresize(img, targetSize)));
        real_part = real(fft_img);
        imag_part = imag(fft_img);
        % flatten to vector
        imgOut = [real_part(:); imag_part(:)];   
        % Return the label from info struct as the 
        % second column in dataOut.
        dataOut(idx,:) = {imgOut,info.Label(idx)};
    end
end
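The round-trip check mentioned in the question can be sketched as follows. This is a minimal sketch, not the poster's code: it assumes a random 28x28 array (the hypothetical img) stands in for a resized digit, so it runs without the datastore or the Image Processing Toolbox. It applies the same fftshift(fft2(...)) flattening as preprocessForTraining and then inverts it with ifftshift and ifft2.

```matlab
% Round-trip check for the FFT feature vector built in preprocessForTraining.
% Assumption: a random 28x28 array stands in for a resized grayscale digit.
rng(0);                                   % reproducible stand-in image
img = rand(28, 28);                       % hypothetical test image
targetSize = size(img);

% Forward: centred 2-D FFT, then real and imaginary parts stacked into one column
F = fftshift(fft2(img));
featureVec = [real(F(:)); imag(F(:))];    % length 2*28*28, as in featureInputLayer(28*28*2)

% Inverse: split the vector, rebuild the complex spectrum, undo the shift
n = prod(targetSize);
Frec = reshape(featureVec(1:n) + 1i*featureVec(n+1:end), targetSize);
imgRec = real(ifft2(ifftshift(Frec)));    % imaginary residue is only rounding error

fprintf('feature length: %d\n', numel(featureVec));
fprintf('max reconstruction error: %g\n', max(abs(imgRec(:) - img(:))));
```

A reconstruction error at floating-point rounding level means the flattening loses no information, which is consistent with the poster's observation that the transform itself is not the problem.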
Accepted Answer
Hrishikesh Borate on 2 Feb 2021
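Hi,
I understand that you are using the FFT in place of CNN feature learning and that the accuracy stays at 10%. This is caused by the reluLayer placed between the last fullyConnectedLayer and the softmaxLayer (just before the classificationLayer). A ReLU there clips every negative class score to zero; when all ten scores are clipped, softmax returns 0.1 for every class and no gradient flows back through the clipped units, so training can stall at chance level. You can use the following layer definition, which removes that reluLayer, to improve the training results.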
layers = [
    featureInputLayer(28*28*2)
    fullyConnectedLayer(28*28)
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
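The mechanism behind this fix can be illustrated with a small numeric example. This is a sketch with made-up scores (the logits values are hypothetical, not taken from the trained network) and a hand-written softmax instead of the toolbox layers: if the last fullyConnectedLayer outputs only negative scores, the ReLU turns them all into zeros, softmax then assigns every digit probability 0.1, and the ReLU derivative is zero for every unit.

```matlab
% Hypothetical class scores (logits) from the last fullyConnectedLayer for one
% image, chosen so that all ten are negative.
logits = [-2.1; -0.4; -3.0; -1.2; -0.7; -5.5; -0.9; -2.8; -1.6; -0.3];

% Numerically stable softmax, written out so no toolbox is needed
softmaxFcn = @(z) exp(z - max(z)) ./ sum(exp(z - max(z)));

pNoRelu   = softmaxFcn(logits);           % corrected order: scores go straight to softmax
pWithRelu = softmaxFcn(max(logits, 0));   % ReLU first: every score becomes 0

% Derivative of ReLU w.r.t. its input: 1 where the input is positive, 0 otherwise
reluGrad = double(logits > 0);

[~, predNoRelu] = max(pNoRelu);           % class 10 has the largest logit (-0.3)
fprintf('predicted class without ReLU: %d\n', predNoRelu);
fprintf('probabilities with ReLU: %s\n', mat2str(pWithRelu', 3));
fprintf('nonzero ReLU gradients: %d of %d\n', nnz(reluGrad), numel(reluGrad));
```

With the corrected layer array the scores reach softmaxLayer unchanged, so negative scores keep their relative order and gradients reach every output unit.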

