Skip to content

Commit e6f7ab3

Browse files
MechCoderagramfort
authored andcommitted
ENH Added multinomial logreg to plot_classification_probability.py
1 parent 4cd1777 commit e6f7ab3

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

examples/classification/plot_classification_probability.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
===============================
55
66
Plot the classification probability for different classifiers. We use a 3
7-
class dataset, and we classify it with a Support Vector classifier, as
8-
well as L1 and L2 penalized logistic regression.
7+
class dataset, and we classify it with a Support Vector classifier, L1
8+
and L2 penalized logistic regression, and L2 penalized logistic
9+
regression with a multinomial setting.
910
1011
The logistic regression is not a multiclass classifier out of the box. As
1112
a result it can identify only the first class.
@@ -35,13 +36,21 @@ class dataset, and we classify it with a Support Vector classifier, as
3536
classifiers = {'L1 logistic': LogisticRegression(C=C, penalty='l1'),
3637
'L2 logistic': LogisticRegression(C=C, penalty='l2'),
3738
'Linear SVC': SVC(kernel='linear', C=C, probability=True,
38-
random_state=0)}
39+
random_state=0),
40+
'Multinomial Logistic': LogisticRegression(
41+
C=C, solver='lbfgs', multi_class='multinomial'
42+
)}
3943

4044
n_classifiers = len(classifiers)
4145

4246
plt.figure(figsize=(3 * 2, n_classifiers * 2))
4347
plt.subplots_adjust(bottom=.2, top=.95)
4448

49+
xx = np.linspace(3, 9, 100)
50+
yy = np.linspace(1, 5, 100).T
51+
xx, yy = np.meshgrid(xx, yy)
52+
Xfull = np.c_[xx.ravel(), yy.ravel()]
53+
4554
for index, (name, classifier) in enumerate(classifiers.items()):
4655
classifier.fit(X, y)
4756

@@ -50,10 +59,6 @@ class dataset, and we classify it with a Support Vector classifier, as
5059
print("classif_rate for %s : %f " % (name, classif_rate))
5160

5261
# View probabilities=
53-
xx = np.linspace(3, 9, 100)
54-
yy = np.linspace(1, 5, 100).T
55-
xx, yy = np.meshgrid(xx, yy)
56-
Xfull = np.c_[xx.ravel(), yy.ravel()]
5762
probas = classifier.predict_proba(Xfull)
5863
n_classes = np.unique(y_pred).size
5964
for k in range(n_classes):

0 commit comments

Comments
 (0)