Commit 6acfa18

Create RBF SVM parameters
1 parent c427ffa commit 6acfa18

File tree

1 file changed: +146 -0 lines changed


RBF SVM parameters

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
RBF SVM parameters

This example illustrates the effect of the parameters gamma and C of the Radial Basis Function (RBF) kernel SVM.

Intuitively, the gamma parameter defines how far the influence of a single training example reaches, with low values meaning ‘far’ and high values meaning ‘close’. The gamma parameter can be seen as the inverse of the radius of influence of samples selected by the model as support vectors.

The C parameter trades off misclassification of training examples against simplicity of the decision surface. A low C makes the decision surface smooth, while a high C aims at classifying all training examples correctly by giving the model freedom to select more samples as support vectors.
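As a rough sketch of the gamma intuition (not part of the original example; it assumes only numpy, and the variable names are illustrative), the snippet below evaluates the RBF kernel value exp(-gamma * d**2) between two points a distance d apart for several gamma values. Larger gamma makes the kernel value decay much faster with distance, which is exactly the smaller radius of influence described above.

import numpy as np

d = 2.0  # distance between a support vector and another point
for gamma in (0.01, 0.1, 1.0, 10.0):
    # RBF kernel value k(x, x') = exp(-gamma * ||x - x'||^2)
    print("gamma=%g -> k=%.4f" % (gamma, np.exp(-gamma * d ** 2)))
# prints roughly 0.96, 0.67, 0.018, 4e-18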
The first plot is a visualization of the decision function for a variety of parameter values on a simplified classification problem involving only 2 input features and 2 possible target classes (binary classification). Note that this kind of plot is not possible to do for problems with more features or target classes.

The second plot is a heatmap of the classifier’s cross-validation accuracy as a function of C and gamma. For this example we explore a relatively large grid for illustration purposes. In practice, a logarithmic grid from 10^-3 to 10^3 is usually sufficient. If the best parameters lie on the boundaries of the grid, it can be extended in that direction in a subsequent search.

Note that the heatmap plot has a special colorbar with a midpoint value close to the score values of the best performing models, so as to make it easy to tell them apart at a glance.
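For reference, the "usually sufficient" starting grid mentioned above could be built as in the minimal sketch below (assuming numpy; the names C_start and gamma_start are illustrative, and the example itself uses a wider grid further down):

import numpy as np

# Logarithmic grid from 10^-3 to 10^3, 7 values per parameter.
C_start = np.logspace(-3, 3, 7)      # 0.001, 0.01, ..., 1000
gamma_start = np.logspace(-3, 3, 7)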
The behavior of the model is very sensitive to the gamma parameter. If gamma is too large, the radius of the area of influence of the support vectors only includes the support vector itself and no amount of regularization with C will be able to prevent overfitting.

When gamma is very small, the model is too constrained and cannot capture the complexity or “shape” of the data. The region of influence of any selected support vector would include the whole training set. The resulting model will behave similarly to a linear model with a set of hyperplanes that separate the centers of high density of any pair of two classes.

For intermediate values, we can see on the second plot that good models can be found on a diagonal of C and gamma. Smooth models (lower gamma values) can be made more complex by selecting a larger number of support vectors (larger C values), hence the diagonal of well-performing models.

Finally, one can also observe that for some intermediate values of gamma we get equally performing models when C becomes very large: it is not necessary to regularize by limiting the number of support vectors. The radius of the RBF kernel alone acts as a good structural regularizer. In practice, though, it might still be interesting to limit the number of support vectors with a lower value of C so as to favor models that use less memory and that are faster to predict.

We should also note that small differences in scores result from the random splits of the cross-validation procedure. Those spurious variations can be smoothed out by increasing the number of CV iterations n_splits at the expense of compute time. Increasing the number of steps in C_range and gamma_range will increase the resolution of the hyper-parameter heatmap.
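As a hedged illustration of that last remark (not part of the example code; cv_smooth is an illustrative name), one way to reduce the split-to-split noise is simply to request more shuffle-split iterations, at a proportional increase in the number of fits:

from sklearn.model_selection import StratifiedShuffleSplit

# 20 random stratified splits instead of 5: smoother mean_test_score values,
# roughly 4x the fitting time for the same parameter grid.
cv_smooth = StratifiedShuffleSplit(n_splits=20, test_size=0.2, random_state=42)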
[Figures: the 3x3 grid of decision functions (sphx_glr_plot_rbf_parameters_001.png) and the validation-accuracy heatmap (sphx_glr_plot_rbf_parameters_002.png)]

Out:
The best parameters are {'C': 1.0, 'gamma': 0.10000000000000001} with a score of 0.97

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import GridSearchCV


# Utility function to move the midpoint of a colormap to be around
# the values of interest.

class MidpointNormalize(Normalize):

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y))

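# Illustrative aside (not in the original example): with vmin=0.2,
# midpoint=0.92 and vmax=1.0, this normalizer maps 0.92 to the middle of the
# colormap while compressing everything below it, e.g.
#   MidpointNormalize(vmin=0.2, vmax=1.0, midpoint=0.92)([0.2, 0.56, 0.92, 1.0])
#   -> roughly [0.0, 0.25, 0.5, 1.0]
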
# #############################################################################
# Load and prepare data set
#
# dataset for grid search

iris = load_iris()
X = iris.data
y = iris.target

# Dataset for decision function visualization: we only keep the first two
# features in X and sub-sample the dataset to keep only 2 classes and
# make it a binary classification problem.

X_2d = X[:, :2]
X_2d = X_2d[y > 0]
y_2d = y[y > 0]
y_2d -= 1

# It is usually a good idea to scale the data for SVM training.
# We are cheating a bit in this example in scaling all of the data,
# instead of fitting the transformation on the training set and
# just applying it on the test set.

scaler = StandardScaler()
X = scaler.fit_transform(X)
X_2d = scaler.fit_transform(X_2d)

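# A leak-free alternative (sketch only, not used below): chain the scaler and
# the SVM so the scaler is re-fit on each training split of the grid search:
#   from sklearn.pipeline import make_pipeline
#   pipe = make_pipeline(StandardScaler(), SVC())
#   # then search over 'svc__C' and 'svc__gamma' with GridSearchCV(pipe, ...)
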
# #############################################################################
# Train classifiers
#
# For an initial search, a logarithmic grid with basis
# 10 is often helpful. Using a basis of 2, a finer
# tuning can be achieved but at a much higher cost.

C_range = np.logspace(-2, 10, 13)
gamma_range = np.logspace(-9, 3, 13)
param_grid = dict(gamma=gamma_range, C=C_range)
cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
grid.fit(X, y)

print("The best parameters are %s with a score of %0.2f"
      % (grid.best_params_, grid.best_score_))

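# If the best parameters had landed on the boundary of the grid, a follow-up
# search could zoom in around them (illustrative sketch only, not run here):
#   best_C, best_gamma = grid.best_params_['C'], grid.best_params_['gamma']
#   C_refined = np.logspace(np.log10(best_C) - 1, np.log10(best_C) + 1, 5)
#   gamma_refined = np.logspace(np.log10(best_gamma) - 1,
#                               np.log10(best_gamma) + 1, 5)
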
# Now we need to fit a classifier for all parameters in the 2d version
# (we use a smaller set of parameters here because it takes a while to train)

C_2d_range = [1e-2, 1, 1e2]
gamma_2d_range = [1e-1, 1, 1e1]
classifiers = []
for C in C_2d_range:
    for gamma in gamma_2d_range:
        clf = SVC(C=C, gamma=gamma)
        clf.fit(X_2d, y_2d)
        classifiers.append((C, gamma, clf))

# #############################################################################
# Visualization
#
# draw visualization of parameter effects

plt.figure(figsize=(8, 6))
xx, yy = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))
for (k, (C, gamma, clf)) in enumerate(classifiers):
    # evaluate decision function in a grid
    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    # visualize decision function for these parameters
    plt.subplot(len(C_2d_range), len(gamma_2d_range), k + 1)
    plt.title("gamma=10^%d, C=10^%d" % (np.log10(gamma), np.log10(C)),
              size='medium')

    # visualize parameter's effect on decision function
    plt.pcolormesh(xx, yy, -Z, cmap=plt.cm.RdBu)
    plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y_2d, cmap=plt.cm.RdBu_r,
                edgecolors='k')
    plt.xticks(())
    plt.yticks(())
    plt.axis('tight')

scores = grid.cv_results_['mean_test_score'].reshape(len(C_range),
                                                     len(gamma_range))

# Draw heatmap of the validation accuracy as a function of gamma and C
#
# The scores are encoded as colors with the hot colormap which varies from dark
# red to bright yellow. As the most interesting scores are all located in the
# 0.92 to 0.97 range we use a custom normalizer to set the mid-point to 0.92 so
# as to make it easier to visualize the small variations of score values in the
# interesting range while not brutally collapsing all the low score values to
# the same color.

plt.figure(figsize=(8, 6))
plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot,
           norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
plt.xlabel('gamma')
plt.ylabel('C')
plt.colorbar()
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
plt.title('Validation accuracy')
plt.show()
