How can I obtain the Shapley values from a Neural Network Object?

I have created a neural network for pattern recognition with the 'patternnet' function and would like to calculate its Shapley values by executing this code:
[x,t] = iris_dataset; 
net = patternnet(7);  
net = train(net,x,t) 
queryPoint=x(:,1)'
explainer = shapley(net,x,'QueryPoint',queryPoint)
I get the following error:
Error using shapley
Blackbox model must be a classification model, regression model, or function handle
Is there a way to obtain the Shapley values from a 'network' object such as the one above?

Accepted Answer

MathWorks Support Team on 8 Jul 2021
For shapley, the blackbox input must be either a classification or regression model, which are the models you can obtain with the "fitcxxx" or "fitrxxx" functions in Statistics and Machine Learning Toolbox, or a function handle that accepts a matrix of predictor data (one observation per row) and returns a column vector with one prediction per observation, here the predicted class. A network object is not one of these classification or regression models, so we need to use a function handle. As long as the object can make predictions, it can be wrapped in a function, and a function handle can be created from that function. Here is an example.
Save this function in a file named predict.m and add its folder to your MATLAB path:
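function y = predict(net,x)
% x: n-by-p predictor matrix (one observation per row), as passed by shapley
probabilities = net(x');            % the network expects p-by-n and returns c-by-n class scores
[~,y] = max(probabilities,[],1);    % index of the most likely class for each observation
y = y(:);                           % shapley expects a column vector of predictions
end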
The transpose is needed because shapley passes predictor data with observations in rows and predictors in columns, whereas a network object expects observations in columns.
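To see the orientations involved, here is a quick shape check, assuming Deep Learning Toolbox is installed. It uses configure to size an untrained patternnet, so no training is needed; the weights are random and the scores are meaningless, only the array sizes matter. It repeats the body of the wrapper inline rather than relying on predict.m being on the path.

```matlab
% Shape check for the transposes in the wrapper (untrained network, random weights)
[x,t] = iris_dataset;                 % x: 4-by-150 (predictors in rows, observations in columns)
net = configure(patternnet(10),x,t);  % set input/output sizes without training
Xrows = x';                           % 150-by-4: the orientation shapley works with
scores = net(Xrows');                 % network takes 4-by-150 and returns 3-by-150 class scores
[~,labels] = max(scores,[],1);        % 1-by-150 row of class indices
labels = labels(:);                   % 150-by-1 column, one prediction per observation
disp(size(scores))
disp(size(labels))
```

The first display should show 3 by 150 and the second 150 by 1, which is the column-vector shape shapley expects from a function handle.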
To train the model:
[x,t] = iris_dataset; 
net = patternnet(10);  
net = train(net,x,t); 
queryPoint=x(:,1)';
We can then create our function handle:
blackbox = @(x)predict(net,x);
Note that when the blackbox input to shapley is a function handle, the predictor data must be passed when creating the explainer; here that is the transposed training data, with one observation per row:
explainer = shapley(blackbox,x');
Finally, we compute the Shapley values for the query point with the fit function and plot them:
explainer = fit(explainer,queryPoint);
plot(explainer);
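Because the wrapper above returns the class index (1, 2 or 3) as a number, the resulting Shapley values describe changes in that index, which has no natural ordering across classes. One possible alternative, not part of the original answer, is to explain the network's score for a single class instead. The sketch below assumes the same R2021a call pattern as above; classScore is a hypothetical helper name, defined as a local function at the end of the script (supported in scripts since R2016b).

```matlab
% Alternative sketch: explain the score of one class instead of the class index
[x,t] = iris_dataset;
net = patternnet(10);
net.trainParam.showWindow = false;      % suppress the training window
net = train(net,x,t);
queryPoint = x(:,1)';

[~,k] = max(net(queryPoint'));          % class predicted for the query point
scoreFcn = @(X) classScore(net,X,k);    % returns the class-k score for each row of X

explainer = shapley(scoreFcn,x');       % predictor data with one observation per row
explainer = fit(explainer,queryPoint);
plot(explainer);

function s = classScore(net,X,k)
% Hypothetical helper: score of class k for each row of X, as a column vector
p = net(X');     % c-by-n class scores
s = p(k,:)';     % n-by-1 column for shapley
end
```

With this version, each Shapley value indicates how much a predictor pushes the score of class k for the query point above or below its average over the data, which is usually easier to interpret than contributions to a class index.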



Release

R2021a
