How to add LDA axes to a scatter plot?
How can I add the LDA axes to the scatter plot shown in the link attached below?
Thanks a lot
Ameer Hamza
on 11 Apr 2020
The link does show an example of how to plot the line. Which part is confusing to you? Do you have a specific dataset?
Accepted Answer
Guillaume Erny
on 27 Jan 2021
Edited: Guillaume Erny
on 27 Jan 2021
I used the fisheriris dataset to test this.
1. Load fisheriris and run the discriminant analysis
load fisheriris
Mdl = fitcdiscr(meas,species, 'DiscrimType','quadratic')
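What you are looking for is Mdl.Mu, the positions of the group centers obtained by the discriminant analysis (one row per class, one column per variable). These are the class means, so they are the same whether DiscrimType is 'linear' or 'quadratic'.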
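To check which row of Mdl.Mu belongs to which group, the rows can be matched against Mdl.ClassNames and compared with class means computed directly. This is a minimal sketch assuming the Statistics and Machine Learning Toolbox (for fisheriris and fitcdiscr); the variable names centers and manual are illustrative.

```matlab
% Sketch: confirm that the rows of Mdl.Mu are the per-class means of meas
load fisheriris
Mdl = fitcdiscr(meas, species, 'DiscrimType', 'quadratic');

names = Mdl.ClassNames;          % row k of Mdl.Mu belongs to class names{k}
centers = Mdl.Mu;                % K-by-p: one row per group, one column per variable

manual = zeros(size(centers));   % class means computed directly, for comparison
for k = 1:numel(names)
    manual(k, :) = mean(meas(strcmp(species, names{k}), :), 1);
end

% Show the centers labelled by class name
disp(array2table(centers, 'RowNames', names))
```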
2. Visualisation
To see where the group centers lie within your data, plot them (black crosses) on top of each pair of variables:
figure, nexttile
gscatter(meas(:,1), meas(:,2), species);
hold on
scatter(Mdl.Mu(:, 1), Mdl.Mu(:, 2), 100, 'kx', 'LineWidth', 1.5)
hold off
nexttile
gscatter(meas(:,1), meas(:,3), species);
hold on
scatter(Mdl.Mu(:, 1), Mdl.Mu(:, 3), 100, 'kx', 'LineWidth', 1.5)
hold off
legend('off')
nexttile
gscatter(meas(:,1), meas(:,4), species);
hold on
scatter(Mdl.Mu(:, 1), Mdl.Mu(:, 4), 100, 'kx', 'LineWidth', 1.5)
hold off
legend('off')
nexttile
gscatter(meas(:,2), meas(:,3), species);
hold on
scatter(Mdl.Mu(:, 2), Mdl.Mu(:, 3), 100, 'kx', 'LineWidth', 1.5)
hold off
legend('off')
and so on for the remaining pairs (variables 2/4 and 3/4).
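All six panels can also be produced in a loop over every pair of variables. This sketch assumes R2019b or later (for tiledlayout('flow') and nexttile); the 'Group centers' legend label is an illustrative choice, not part of the original answer.

```matlab
load fisheriris
Mdl = fitcdiscr(meas, species, 'DiscrimType', 'quadratic');

pairs = nchoosek(1:size(meas, 2), 2);   % all 6 pairs of the 4 variables
figure
tiledlayout('flow')
for p = 1:size(pairs, 1)
    a = pairs(p, 1);
    b = pairs(p, 2);
    nexttile
    gscatter(meas(:, a), meas(:, b), species);
    hold on
    % group centers from the discriminant analysis, drawn as black crosses
    scatter(Mdl.Mu(:, a), Mdl.Mu(:, b), 100, 'kx', 'LineWidth', 1.5, ...
        'DisplayName', 'Group centers')
    hold off
    xlabel(sprintf('Variable %d', a))
    ylabel(sprintf('Variable %d', b))
    if p > 1
        legend('off')   % keep the legend only on the first panel
    end
end
```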
In this approach, the LD axes are the lines passing through each pair of group centers (the black crosses), and the data are projected onto each of them. With three groups there are three such axes, so the data go from four variables to three LD coordinates.
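For comparison, classical (Fisher) LDA defines its axes differently: they are the leading eigenvectors w of the generalized eigenproblem Sb*w = lambda*Sw*w, where Sw and Sb are the within-class and between-class scatter matrices. With K = 3 groups there are at most K-1 = 2 such axes. This is a different technique from the center-to-center lines used in this answer; the sketch below computes it directly with eig (variable names are illustrative), and the resulting axes are only defined up to sign and scale.

```matlab
load fisheriris
X = meas;
[g, names] = grp2idx(species);   % integer group labels 1..K
K = numel(names);
p = size(X, 2);

mu = mean(X, 1);                 % overall mean
Sw = zeros(p);                   % within-class scatter
Sb = zeros(p);                   % between-class scatter
for k = 1:K
    Xk = X(g == k, :);
    muk = mean(Xk, 1);
    Xc = Xk - muk;               % center the class (implicit expansion, R2016b+)
    Sw = Sw + Xc' * Xc;
    Sb = Sb + size(Xk, 1) * (muk - mu)' * (muk - mu);
end

% Generalized eigenproblem Sb*w = lambda*Sw*w; keep the K-1 largest eigenvalues
[W, D] = eig(Sb, Sw);
[~, order] = sort(diag(D), 'descend');
W = W(:, order(1:K-1));

Z = X * W;                       % Fisher LD scores
figure
gscatter(Z(:, 1), Z(:, 2), species)
xlabel('LD1')
ylabel('LD2')
```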
3. Projection
Here is the code I used to project the original data onto the LD axes (using the centers from the quadratic discriminant fitted above). The inner loop is not vectorized, so it could certainly be rewritten to run faster.
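V = Mdl.Mu; % positions of the group centers (one row per group)
Comb = [1, 2; 1, 3; 2, 3]; % all pairs of centers that define an axis (3 groups)
n = size(meas, 1);
LD = zeros(n, size(Comb, 1)); % preallocate: one column per LD axis
Q = zeros(size(meas)); % projected points for the current axis
for ii = 1:size(Comb, 1) % for each axis
    V1 = V(Comb(ii,2), :);
    V2 = V(Comb(ii,1), :);
    v = (V2-V1)./norm(V2-V1); % unit vector along the axis
    for jj = 1:n
        Q(jj, :) = dot(meas(jj, :)-V1, v)*v + V1; % orthogonal projection onto the line
        LD(jj, ii) = (Q(jj, 1)-V1(1))/(V2(1)-V1(1)); % position along the line: 0 at V1, 1 at V2
    end
end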
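Because the projected point is Q = V1 + t*(V2 - V1), the position along each axis can be written directly as t = dot(x - V1, V2 - V1) / norm(V2 - V1)^2, which gives the same values as the loop above and does not depend on the first coordinate alone. A vectorized sketch, assuming R2016b or later for implicit expansion (LDv and d are illustrative names):

```matlab
load fisheriris
Mdl = fitcdiscr(meas, species, 'DiscrimType', 'quadratic');

V = Mdl.Mu;                          % group centers, one row per class
Comb = nchoosek(1:size(V, 1), 2);    % [1 2; 1 3; 2 3] for three groups
LDv = zeros(size(meas, 1), size(Comb, 1));
for ii = 1:size(Comb, 1)             % loop only over the few axes
    V1 = V(Comb(ii, 2), :);          % LD = 0 at this center
    V2 = V(Comb(ii, 1), :);          % LD = 1 at this center
    d = V2 - V1;                     % direction of the axis
    % position along the axis for all observations at once
    LDv(:, ii) = (meas - V1) * d' / (d * d');
end
```

Since the projection is affine and each center is its class mean, the average LD value of a class equals 0 or 1 on the axes through its own center.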
4. Quick Visualisation of the results
figure
hold on
groups = {'setosa', 'versicolor', 'virginica'};
colors = 'rgb';
for k = 1:numel(groups) % one scatter3 call per species
    idx = strcmp(species, groups{k});
    scatter3(LD(idx, 1), LD(idx, 2), LD(idx, 3), colors(k))
end
hold off
legend({'setosa', 'versicolor', 'virginica' })
xlabel('LD1')
ylabel('LD2')
zlabel('LD3')
view(40,35)
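One property of these coordinates is worth checking: the three center-to-center directions satisfy (c1 - c3) = (c1 - c2) + (c2 - c3), so they span only the two-dimensional plane through the three centers. The points in the 3-D plot above should therefore lie on a plane, and a 2-D plot of two of the coordinates carries the same structure up to a linear change of coordinates. A sketch of that check using singular values (s is an illustrative name):

```matlab
load fisheriris
Mdl = fitcdiscr(meas, species, 'DiscrimType', 'quadratic');
V = Mdl.Mu;
Comb = [1, 2; 1, 3; 2, 3];
LD = zeros(size(meas, 1), 3);
for ii = 1:3
    V1 = V(Comb(ii, 2), :);
    d = V(Comb(ii, 1), :) - V1;
    LD(:, ii) = (meas - V1) * d' / (d * d');   % same coordinates as in step 3
end

% Singular values of the centered coordinates: a third value near zero
% means the points lie on a plane in the LD1-LD2-LD3 space
s = svd(LD - mean(LD, 1));
fprintf('Singular values: %g %g %g\n', s)

% A 2-D plot of two of the coordinates therefore shows the same structure
figure
gscatter(LD(:, 1), LD(:, 2), species)
xlabel('LD1')
ylabel('LD2')
```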