Commit 0180188

ENH multiclass probability estimates for SGDClassifier

Fixes scikit-learn#1814.

1 parent 5eb035c

File tree

5 files changed: +114 −28 lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions

@@ -51,6 +51,9 @@ Changelog
   converts their ``coef_`` into a sparse matrix, meaning stored models
   trained using these estimators can be made much more compact.
 
+- :class:`linear_model.SGDClassifier` now produces multiclass probability
+  estimates when trained under log loss or modified Huber loss.
+
 - Hyperlinks to documentation in example code on the website by
   `Martin Luessi`_.
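A minimal usage sketch of the behavior this changelog entry announces. The toy data is invented for illustration, and `n_iter` reflects the SGDClassifier API of this era:

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    X = np.array([[0., 0.], [1., 1.], [2., 2.],
                  [0., 1.], [2., 1.], [1., 0.]])
    y = np.array([0, 1, 2, 0, 2, 1])  # three classes

    clf = SGDClassifier(loss="log", alpha=0.01, n_iter=10).fit(X, y)

    proba = clf.predict_proba([[1.5, 1.5]])  # no longer raises for >2 classes
    print(proba.shape)        # (1, 3): one row of class probabilities per sample
    print(proba.sum(axis=1))  # each row sums to 1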

sklearn/linear_model/base.py

Lines changed: 19 additions & 0 deletions

@@ -222,6 +222,25 @@ def predict(self, X):
         indices = scores.argmax(axis=1)
         return self.classes_[indices]
 
+    def _predict_proba_lr(self, X):
+        """Probability estimation for OvR logistic regression.
+
+        Positive class probabilities are computed as
+        1. / (1. + np.exp(-self.decision_function(X)));
+        multiclass is handled by normalizing that over all classes.
+        """
+        prob = self.decision_function(X)
+        prob *= -1
+        np.exp(prob, prob)
+        prob += 1
+        np.reciprocal(prob, prob)
+        if len(prob.shape) == 1:
+            return np.vstack([1 - prob, prob]).T
+        else:
+            # OvR normalization, like LibLinear's predict_probability
+            prob /= prob.sum(axis=1).reshape((prob.shape[0], -1))
+            return prob
+
 
 class SparseCoefMixin(object):
     """Mixin for converting coef_ to and from CSR format.

sklearn/linear_model/logistic.py

Lines changed: 1 addition & 12 deletions

@@ -120,18 +120,7 @@ def predict_proba(self, X):
         Returns the probability of the sample for each class in the model,
         where classes are ordered as they are in ``self.classes_``.
         """
-        # 1. / (1. + np.exp(-scores)), computed in-place
-        prob = self.decision_function(X)
-        prob *= -1
-        np.exp(prob, prob)
-        prob += 1
-        np.reciprocal(prob, prob)
-        if len(prob.shape) == 1:
-            return np.vstack([1 - prob, prob]).T
-        else:
-            # OvR, not softmax, like Liblinear's predict_probability
-            prob /= prob.sum(axis=1).reshape((prob.shape[0], -1))
-            return prob
+        return self._predict_proba_lr(X)
 
     def predict_log_proba(self, X):
         """Log of probability estimates.

sklearn/linear_model/stochastic_gradient.py

Lines changed: 43 additions & 13 deletions

@@ -654,7 +654,12 @@ class SGDClassifier(BaseSGDClassifier, SelectorMixin):
     def predict_proba(self, X):
         """Probability estimates.
 
-        Probability estimates are only supported for binary classification.
+        Multiclass probability estimates are derived from binary (one-vs.-rest)
+        estimates by simple normalization, as recommended by Zadrozny and
+        Elkan.
+
+        Binary probability estimates for loss="modified_huber" are given by
+        (clip(decision_function(X), -1, 1) + 1) / 2.
 
         Parameters
         ----------
@@ -668,34 +673,59 @@ def predict_proba(self, X):
 
         References
         ----------
+        Zadrozny and Elkan, "Transforming classifier scores into multiclass
+        probability estimates", SIGKDD'02,
+        http://www.research.ibm.com/people/z/zadrozny/kdd2002-Transf.pdf
 
         The justification for the formula in the loss="modified_huber"
         case is in appendix B of:
         http://jmlr.csail.mit.edu/papers/volume2/zhang02c/zhang02c.pdf
         """
-        if len(self.classes_) != 2:
-            raise NotImplementedError("predict_(log_)proba only supported"
-                                      " for binary classification")
-
-        scores = self.decision_function(X)
-        proba = np.ones((scores.shape[0], 2), dtype=np.float64)
         if self.loss == "log":
-            proba[:, 1] = 1. / (1. + np.exp(-scores))
+            return self._predict_proba_lr(X)
 
         elif self.loss == "modified_huber":
-            proba[:, 1] = (np.clip(scores, -1, 1) + 1) / 2.
+            binary = (len(self.classes_) == 2)
+            scores = self.decision_function(X)
+
+            if binary:
+                prob2 = np.ones((scores.shape[0], 2))
+                prob = prob2[:, 1]
+            else:
+                prob = scores
+
+            np.clip(scores, -1, 1, prob)
+            prob += 1.
+            prob /= 2.
+
+            if binary:
+                prob2[:, 0] -= prob
+                prob = prob2
+            else:
+                # the above might assign zero to all classes, which doesn't
+                # normalize neatly; work around this to produce uniform
+                # probabilities
+                prob_sum = prob.sum(axis=1)
+                all_zero = (prob_sum == 0)
+                if np.any(all_zero):
+                    prob[all_zero, :] = 1
+                    prob_sum[all_zero] = len(self.classes_)
+
+                # normalize
+                prob /= prob_sum.reshape((prob.shape[0], -1))
+
+            return prob
 
         else:
             raise NotImplementedError("predict_(log_)proba only supported when"
                                       " loss='log' or loss='modified_huber' "
-                                      "(%s given)" % self.loss)
-        proba[:, 0] -= proba[:, 1]
-        return proba
+                                      "(%r given)" % self.loss)
 
     def predict_log_proba(self, X):
         """Log of probability estimates.
 
-        Log probability estimates are only supported for binary classification.
+        When loss="modified_huber", probability estimates may be hard zeros
+        and ones, so taking the logarithm is not possible.
 
         Parameters
         ----------
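Outside the estimator, the modified-Huber branch above reduces to a few lines of NumPy. A standalone sketch of the multiclass case, including the uniform fallback for rows where every score clips to -1 (score values invented; three classes):

    import numpy as np

    scores = np.array([[0.4, -0.2, -1.6],    # ordinary row
                       [-1.5, -2.0, -3.1]])  # all scores < -1: every class clips to 0

    prob = (np.clip(scores, -1, 1) + 1) / 2.  # per-class binary estimates in [0, 1]

    prob_sum = prob.sum(axis=1)
    all_zero = (prob_sum == 0)
    prob[all_zero, :] = 1                     # fall back to uniform probabilities
    prob_sum[all_zero] = scores.shape[1]
    prob /= prob_sum[:, np.newaxis]           # Zadrozny & Elkan OvR normalization

    # second row comes out as [1/3, 1/3, 1/3]; the first row sums to 1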

sklearn/linear_model/tests/test_sgd.py

Lines changed: 48 additions & 3 deletions

@@ -319,9 +319,11 @@ def test_sgd_proba(self):
         clf = self.factory(loss="hinge", alpha=0.01, n_iter=10).fit(X, Y)
         assert_raises(NotImplementedError, clf.predict_proba, [3, 2])
 
-        # the log and modified_huber losses can output "probability" estimates
-        for loss in ("log", "modified_huber"):
-            clf = self.factory(loss=loss, alpha=0.01, n_iter=10).fit(X, Y)
+        # log and modified_huber losses can output probability estimates
+        # binary case
+        for loss in ["log", "modified_huber"]:
+            clf = self.factory(loss=loss, alpha=0.01, n_iter=10)
+            clf.fit(X, Y)
             p = clf.predict_proba([3, 2])
             assert_true(p[0, 1] > 0.5)
             p = clf.predict_proba([-1, -1])
@@ -332,6 +334,49 @@ def test_sgd_proba(self):
             p = clf.predict_log_proba([-1, -1])
             assert_true(p[0, 1] < p[0, 0])
 
+        # log loss multiclass probability estimates
+        clf = self.factory(loss="log", alpha=0.01, n_iter=10).fit(X2, Y2)
+
+        d = clf.decision_function([[.1, -.1], [.3, .2]])
+        p = clf.predict_proba([[.1, -.1], [.3, .2]])
+        assert_array_equal(np.argmax(p, axis=1), np.argmax(d, axis=1))
+        assert_almost_equal(p[0].sum(), 1)
+        assert_true(np.all(p[0] >= 0))
+
+        p = clf.predict_proba([-1, -1])
+        d = clf.decision_function([-1, -1])
+        assert_array_equal(np.argsort(p[0]), np.argsort(d[0]))
+
+        l = clf.predict_log_proba([3, 2])
+        p = clf.predict_proba([3, 2])
+        assert_array_almost_equal(np.log(p), l)
+
+        l = clf.predict_log_proba([-1, -1])
+        p = clf.predict_proba([-1, -1])
+        assert_array_almost_equal(np.log(p), l)
+
+        # Modified Huber multiclass probability estimates; requires a separate
+        # test because the hard zero/one probabilities may destroy the
+        # ordering present in decision_function output.
+        clf = self.factory(loss="modified_huber", alpha=0.01, n_iter=10)
+        clf.fit(X2, Y2)
+        d = clf.decision_function([3, 2])
+        p = clf.predict_proba([3, 2])
+        if not isinstance(self, SparseSGDClassifierTestCase):
+            assert_equal(np.argmax(d, axis=1), np.argmax(p, axis=1))
+        else:  # XXX the sparse test gets a different X2 (?)
+            assert_equal(np.argmin(d, axis=1), np.argmin(p, axis=1))
+
+        # the following sample produces decision_function values < -1,
+        # which would cause naive normalization to fail (see comment
+        # in SGDClassifier.predict_proba)
+        x = X.mean(axis=0)
+        d = clf.decision_function(x)
+        if np.all(d < -1):  # XXX not true in sparse test case (why?)
+            p = clf.predict_proba(x)
+            assert_array_almost_equal(p[0], [1 / 3.] * 3)
+
+
     def test_sgd_l1(self):
         """Test L1 regularization"""
