How can I obtain the Shapley values from a Neural Network Object?
MathWorks Support Team on 8 Jul 2021
Answered: MathWorks Support Team on 16 Sep 2021
I have created a neural network for pattern recognition with the 'patternnet' function and would like to calculate its Shapley values by executing this code:
[x,t] = iris_dataset;
net = patternnet(7);
net = train(net,x,t)
queryPoint=x(:,1)'
explainer = shapley(net,x,'QueryPoint',queryPoint)
I get the following error:
Error using shapley
Blackbox model must be a classification model, regression model, or function handle
Is there a way to obtain the Shapley values from a 'network' object such as the one above?
Accepted Answer
MathWorks Support Team on 8 Jul 2021
For shapley, the blackbox input must be either a classification or regression model, that is, one of the model objects you can obtain with the "fitcxxx" or "fitrxxx" functions in Statistics and Machine Learning Toolbox, or a function handle that takes in predictor data (such as a query point), makes predictions, and returns class labels. A network object is neither a classification nor a regression model, so we need to use a function handle. As long as the object can make predictions, it can be wrapped in a function, and a handle to that function can be passed to shapley. Here is an example:
Save this function in a file named predict.m and add its folder to your MATLAB path:
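function y = predict(net,x)
% Wrapper that lets shapley call a shallow network like a model.
% x is an N-by-p matrix with one observation per row (the layout shapley uses).
probabilities = net(x');           % the network expects p-by-N and returns classes-by-N scores
[~,y] = max(probabilities,[],1);   % index of the most probable class for each observation
y = y(:);                          % return the labels as an N-by-1 column vector
end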
The input is transposed because shapley passes predictor data with one observation per row (N-by-p), whereas the network expects one observation per column (p-by-N).
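To see the two layouts concretely, the sketch below prints the relevant sizes for the iris data. It assumes configure (Deep Learning Toolbox) is enough to set up and initialize an untrained patternnet for a shape check; no training is needed just to inspect dimensions.

```matlab
% Load the iris data: x holds the features, t the one-hot class targets
[x,t] = iris_dataset;

% configure sizes the network to match x and t (weights are initialized, not trained)
net = configure(patternnet(7),x,t);

size(x)        % 4-by-150: one observation per column (network convention)
size(x')       % 150-by-4: one observation per row (shapley convention)
size(net(x))   % 3-by-150: one column of class scores per observation
```

This is why the wrapper transposes x before calling the network and reshapes the result into a column.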
To train the model and choose a query point (the first observation, as a row vector):
[x,t] = iris_dataset;
net = patternnet(10);
net = train(net,x,t);
queryPoint=x(:,1)';
We can then create our function handle:
blackbox = @(x)predict(net,x);
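Alternatively, if you would rather not save a separate function file, the same wrapper can be written as an anonymous function. This standalone sketch assumes vec2ind (Deep Learning Toolbox), which for network outputs returns the index of the largest score in each column, the same result the max call in predict produces.

```matlab
% Train the pattern-recognition network on the iris data
[x,t] = iris_dataset;
net = patternnet(10);
net.trainParam.showWindow = false;   % train without opening the training window
net = train(net,x,t);

% shapley passes an N-by-4 matrix (rows = observations); the network expects 4-by-N.
% vec2ind picks the most probable class per column; the final ' makes it N-by-1.
blackbox = @(X) vec2ind(net(X'))';

labels = blackbox(x');   % predicted class index for each of the 150 observations
size(labels)
```

Note that the anonymous function stores a copy of net taken when the handle is created, so if you retrain the network, create the handle again.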
Note that when the blackbox is a function handle, the predictor data must be passed when creating the explainer, with one observation per row (hence x'):
explainer = shapley(blackbox,x');
Afterwards, use the fit function to compute the Shapley values at the query point, and plot them:
explainer = fit(explainer,queryPoint);
plot(explainer);
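Once the network is wrapped in a function handle, the one-step 'QueryPoint' syntax from the question should also work. This standalone sketch assumes an anonymous wrapper built with vec2ind; using @(x)predict(net,x) with the saved predict.m should behave the same way.

```matlab
% Train the network and wrap it so shapley can call it
[x,t] = iris_dataset;
net = patternnet(10);
net.trainParam.showWindow = false;
net = train(net,x,t);
blackbox = @(X) vec2ind(net(X'))';

% Pass the predictor data (rows = observations) and the query point in one call;
% shapley computes the Shapley values right away, so no separate fit call is needed
queryPoint = x(:,1)';
explainer = shapley(blackbox,x','QueryPoint',queryPoint);

disp(explainer.ShapleyValues)   % table of Shapley values, one row per predictor
plot(explainer);
```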