Skip to content

Commit df7942a

Browse files
qinhanmin2014rth
authored andcommitted
EXA Add an example about how to obtain the support vectors in… (scikit-learn#14355)
1 parent acd3510 commit df7942a

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""
2+
=====================================
3+
Plot the support vectors in LinearSVC
4+
=====================================
5+
6+
Unlike SVC (based on LIBSVM), LinearSVC (based on LIBLINEAR) does not provide
7+
the support vectors. This example demonstrates how to obtain the support
8+
vectors in LinearSVC.
9+
10+
"""
11+
12+
import numpy as np
13+
import matplotlib.pyplot as plt
14+
from sklearn.datasets import make_blobs
15+
from sklearn.svm import LinearSVC
16+
17+
X, y = make_blobs(n_samples=40, centers=2, random_state=0)
18+
19+
plt.figure(figsize=(10, 5))
20+
for i, C in enumerate([1, 100]):
21+
# "hinge" is the standard SVM loss
22+
clf = LinearSVC(C=C, loss="hinge", random_state=42).fit(X, y)
23+
# obtain the support vectors through the decision function
24+
decision_function = clf.decision_function(X)
25+
# we can also calculate the decision function manually
26+
# decision_function = np.dot(X, clf.coef_[0]) + clf.intercept_[0]
27+
support_vector_indices = np.where((2 * y - 1) * decision_function <= 1)[0]
28+
support_vectors = X[support_vector_indices]
29+
30+
plt.subplot(1, 2, i + 1)
31+
plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired)
32+
ax = plt.gca()
33+
xlim = ax.get_xlim()
34+
ylim = ax.get_ylim()
35+
xx, yy = np.meshgrid(np.linspace(xlim[0], xlim[1], 50),
36+
np.linspace(ylim[0], ylim[1], 50))
37+
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
38+
Z = Z.reshape(xx.shape)
39+
plt.contour(xx, yy, Z, colors='k', levels=[-1, 0, 1], alpha=0.5,
40+
linestyles=['--', '-', '--'])
41+
plt.scatter(support_vectors[:, 0], support_vectors[:, 1], s=100,
42+
linewidth=1, facecolors='none', edgecolors='k')
43+
plt.title("C=" + str(C))
44+
plt.tight_layout()
45+
plt.show()

0 commit comments

Comments
 (0)