- “X = reshape(X', 28, 28, 1, size(X,1));” reshapes each input back to the original [28, 28, 1] image format.
- “score = squeeze(YPred(targetClass, :))';” selects the logits for the target class (the class label of the ‘QueryPoint’) and returns a column vector.
How can I obtain Shapley values from a convolutional neural network?
clear all;
close all;
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
[XTest,YTest,anglesTest] = digitTest4DArrayData;
% Define the layers of the CNN
layers = [
imageInputLayer([28 28 1]) % Assuming grayscale 28x28 images
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10) % Assuming 10 classes
softmaxLayer
classificationLayer];
% Define the training options
options = trainingOptions('adam', ...
'MaxEpochs',10, ...
'InitialLearnRate',1e-4, ...
'Verbose',false, ...
'Plots','training-progress');
% Train the CNN on the training data
net = trainNetwork(XTrain, categorical(YTrain), layers, options);
% Predict on the test data
YPred = classify(net, XTest);
accuracy = mean(YPred == categorical(YTest));
% Display the confusion matrix
confusionchart(categorical(YTest), YPred);
explainer = shapley( ...
@(XTest)PredictCNN(net,XTest,YTest(1)), ...
reshape(XTest,[5000,28*28]), "QueryPoint", reshape(XTest(:,:,1,1),[1,28*28]) );
function score = PredictCNN(net,XTest,YTest)
YPred = predict(net,XTest);
score = YPred(:,double(YTest));
end
Accepted Answer
Rahul
on 25 Oct 2024 at 6:57
I understand that you are trying to get Shapley values from a convolutional neural network. A few adjustments to your code will ensure the Shapley values are computed correctly:
Flattening: The 4D images are reshaped to a 2D matrix for Shapley value calculation. The ‘shapley’ function expects a 2D input matrix ([NumImages, NumFeatures]), but CNNs typically take 4D input data ([Height, Width, Channels, NumSamples]).
XTestFlatSample = reshape(XTestSample, [28*28, numSamples])';
This line reshapes ‘XTestSample’ (e.g., 28x28x1x200 for grayscale images) into ‘XTestFlatSample’, a 2D matrix of size [200, 28*28], suitable for the ‘shapley’ function.
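The flattening step can be checked without a trained network. The following is a minimal sketch, assuming a random array (the hypothetical ‘XDemo’) stands in for ‘XTestSample’; it also shows why reshaping directly to [numSamples, 28*28] is not equivalent.

```matlab
% Stand-in for XTestSample: 5 random 28x28 grayscale images (illustrative data only)
numSamples = 5;
XDemo = rand(28, 28, 1, numSamples);

% Flatten: each image becomes one row of 784 pixel values
XFlat = reshape(XDemo, [28*28, numSamples])';

% Restore: transpose so each image is a column again, then reshape to 4D
XBack = reshape(XFlat', 28, 28, 1, size(XFlat, 1));

% Reshaping directly to [numSamples, 28*28] gives the same size,
% but mixes pixels from different images within each row
XWrong = reshape(XDemo, [numSamples, 28*28]);

disp(size(XFlat))             % 5 784
disp(isequal(XBack, XDemo))   % 1 (round trip recovers the images)
disp(isequal(XWrong, XFlat))  % 0 (direct reshape is not the same matrix)
```

Row k of ‘XFlat’ holds the pixels of image k in column-major order, which is what the reshape inside the prediction function relies on.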
QueryPoint: Ensure the ‘QueryPoint’ has the same format as the rows passed to ‘shapley’. Here it is specified as ‘XTestFlatSample(1,:)’, so the first flattened test image is the point being explained.
Parallelization: Set 'UseParallel' to true in the ‘shapley’ function call to distribute the computation across multiple cores if you have the Parallel Computing Toolbox, to improve performance.
explainer = shapley( ...
@(X)PredictCNN(net, X, YTestSample(1)), ...
XTestFlatSample, ...
"QueryPoint", XTestFlatSample(1,:), ...
'UseParallel', true);
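If it is not known in advance whether a parallel pool is available, the flag can be set from a check instead of being hard-coded. This is a minimal sketch, assuming a release that provides ‘canUseParallelPool’ (R2020b or later); the variable name ‘useParallel’ is only illustrative.

```matlab
% canUseParallelPool returns true only if Parallel Computing Toolbox is
% installed and licensed and a parallel pool can be used
useParallel = canUseParallelPool;
fprintf('UseParallel will be set to %d\n', useParallel);

% Hypothetical usage with the variables from the call above:
% explainer = shapley(@(X)PredictCNN(net, X, YTestSample(1)), XTestFlatSample, ...
%     "QueryPoint", XTestFlatSample(1,:), 'UseParallel', useParallel);
```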
Custom Prediction Function: The CNN expects a 4D input format, but the ‘shapley’ function provides a 2D matrix. Additionally, ‘shapley’ requires a column vector of scores for each image, not the logits for all classes.
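function score = PredictCNN(net, X, targetClass)
% Reshape the flattened input back to 4D
X = reshape(X', 28, 28, 1, size(X,1));
% Convert a categorical label to its category index so it can be used as a subscript
targetClass = double(targetClass);
% Obtain logits from CNN as a [NumClasses, NumSamples] matrix
YPred = activations(net, X, 'fc', 'OutputAs', 'columns');
% Extract the score for the target class
score = squeeze(YPred(targetClass, :))';
end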
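The indexing step can be tried on a stand-in array. This is a minimal sketch, assuming the 1-by-1-by-NumClasses-by-N layout that ‘activations’ returns by default for a fully connected layer, and assuming the label is categorical as in the digits data set. ‘YDemo’, ‘labels’ and ‘classIdx’ are illustrative names.

```matlab
% Stand-in for the default output of activations(net, X, 'fc'): 1x1x10xN
N = 4;
YDemo = reshape(1:10*N, 1, 1, 10, N);

% A categorical label cannot be used directly as a subscript;
% double() returns its category index
labels = categorical(0:9);
target = labels(4);          % the digit "3"
classIdx = double(target);   % 4, because "3" is the fourth category

% Remove singleton dimensions to get [NumClasses, NumSamples], then take one row
Y2D = squeeze(YDemo);
score = Y2D(classIdx, :)';   % N-by-1 column vector, one score per image
```

Note that the class index is one more than the digit value here, because the categories start at "0".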
Here’s what the final code would look like:
clear all;
close all;
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
[XTest,YTest,anglesTest] = digitTest4DArrayData;
% Define the layers of the CNN
layers = [
imageInputLayer([28 28 1]) % Assuming grayscale 28x28 images
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10) % Assuming 10 classes
softmaxLayer
classificationLayer];
% Define the training options
options = trainingOptions('adam', ...
'MaxEpochs',10, ...
'InitialLearnRate',1e-4, ...
'Verbose',false, ...
'Plots','training-progress');
% Train the CNN on the training data
net = trainNetwork(XTrain, categorical(YTrain), layers, options);
% Predict on the test data
YPred = classify(net, XTest);
accuracy = mean(YPred == categorical(YTest));
% Display the confusion matrix
confusionchart(categorical(YTest), YPred);
% Sample a smaller portion of XTest for Shapley computation
numSamples = 200; % Adjust to a lower sample size if necessary
XTestSample = XTest(:,:,:,1:numSamples);
YTestSample = YTest(1:numSamples);
% Flatten the sample images to a 2D matrix
XTestFlatSample = reshape(XTestSample, [28*28, numSamples])';
% Shapley value computation with parallel processing and smaller sample
explainer = shapley( ...
@(X)PredictCNN(net, X, YTestSample(1)), ... % Custom prediction function
XTestFlatSample, ... % Pass the flattened input sample images
"QueryPoint", XTestFlatSample(1,:), ...
'UseParallel', true); % Use parallel processing
% Custom function to reshape input and get raw predictions from CNN
function score = PredictCNN(net, X, targetClass)
% Reshape the flattened input back to 4D [Height, Width, Channels, NumSamples]
X = reshape(X', 28, 28, 1, size(X,1)); % Transpose X to match original dimensions
% Get raw scores (logits) before softmax
YPred = activations(net, X, 'fc'); % 'fc' refers to the fully connected layer
YPred = squeeze(YPred); % Remove singleton dimensions
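% Extract the score for the target class (row corresponding to class of interest)
% double() converts a categorical label to its category index so it can be used as a subscript
score = YPred(double(targetClass), :)'; % Return the score for the target class as a column vector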
end
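Once ‘explainer’ exists, the per-pixel values can be viewed as an image. This is a minimal sketch, assuming the ‘ShapleyValues’ property is a table with a ‘ShapleyValue’ column holding one row per flattened pixel. A random stand-in vector is used when no explainer is in the workspace, so the plotting part can be tried on its own.

```matlab
if exist('explainer', 'var')
    % Assumed layout: one Shapley value per predictor (flattened pixel)
    vals = explainer.ShapleyValues.ShapleyValue;
else
    vals = randn(28*28, 1);   % stand-in data, for illustration only
end

% Undo the flattening: same column-major order as the original image
shapMap = reshape(vals, 28, 28);

figure
imagesc(shapMap)
axis image
colorbar
title('Shapley values for the query image')
```

Pixels with large positive values push the score of the target class up, and pixels with large negative values push it down.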
For more information on using the ‘shapley’ function, refer to its documentation page.
Hope this helps!
More Answers (1)
Taylor
on 24 Oct 2024
The shapley function expects the input data to be in a format suitable for the model. You are reshaping XTest to a 2D matrix with dimensions [5000, 28*28], assuming 5000 samples of 28x28 images. Ensure that XTest indeed has 5000 samples. If not, adjust the reshape dimensions accordingly.
The PredictCNN function is used as a handle in the shapley function. It takes XTest and YTest as inputs. However, the shapley function only passes the reshaped XTest. You will need to modify PredictCNN to handle this correctly, possibly by removing YTest from its input arguments.
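An alternative to removing the argument is to capture it in an anonymous function, which is what the @(XTest)PredictCNN(net,XTest,YTest(1)) wrapper in the question does. This is a minimal sketch of that pattern with a hypothetical two-argument function ‘pickColumn’ standing in for PredictCNN.

```matlab
% Hypothetical two-argument scorer: picks one column of a score matrix
pickColumn = @(scores, classIdx) scores(:, classIdx);

% Wrap it so the caller supplies only the data; classIdx = 3 is captured here
classIdx = 3;
oneArg = @(scores) pickColumn(scores, classIdx);

% A caller such as shapley invokes the handle with the data alone
demoScores = magic(4);
col = oneArg(demoScores);   % same as demoScores(:, 3)
```

The captured value is fixed when the handle is created, so changing ‘classIdx’ afterwards does not affect ‘oneArg’.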
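Ensure that the XTest reshaping aligns with the actual number of samples, with one image per row:
numSamples = size(XTest, 4); % Adjust based on your dataset
XFlat = reshape(XTest, [28*28, numSamples])'; % [numSamples, 784]; the transpose keeps each image's pixels together
explainer = shapley( ...
@(X)PredictCNN(net, X), ...
XFlat, ...
"QueryPoint", XFlat(1,:));
Modify PredictCNN to accommodate the input format expected by shapley:
function score = PredictCNN(net, X)
X = reshape(X', 28, 28, 1, size(X,1)); % Back to the 4D format the network expects
YPred = predict(net, X); % [numSamples, 10] class probabilities
score = YPred(:, 1); % shapley needs a single column of scores; column 1 (digit 0) is an illustrative choice
end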