Skip to content

Commit 7d564b3

Browse files
committed
FIX expit bug with out != None
1 parent bd41939 commit 7d564b3

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

sklearn/utils/fixes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def expit(x, out=None):
3838
See sklearn.utils.extmath.log_logistic for the log of this function.
3939
"""
4040
if out is None:
41-
out = np.copy(x)
41+
out = np.empty(np.atleast_1d(x).shape, dtype=np.float64)
42+
out[:] = x
4243

4344
# 1 / (1 + exp(-x)) = (1 + tanh(x / 2)) / 2
4445
# This way of computing the logistic is both fast and stable.
@@ -47,7 +48,7 @@ def expit(x, out=None):
4748
out += 1
4849
out *= .5
4950

50-
return out
51+
return out.reshape(np.shape(x))
5152

5253

5354
# little danse to see if np.copy has an 'order' keyword argument

sklearn/utils/tests/test_fixes.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import numpy as np
77

88
from nose.tools import assert_equal
9-
from numpy.testing import assert_almost_equal, assert_array_equal
9+
from numpy.testing import (assert_almost_equal,
10+
assert_array_almost_equal,
11+
assert_array_equal)
1012

1113
from ..fixes import divide, expit
1214

@@ -20,6 +22,10 @@ def test_expit():
2022
assert_almost_equal(expit(-1000.), np.exp(-1000.) / (1. + np.exp(-1000.)),
2123
decimal=16)
2224

25+
x = np.arange(10)
26+
out = np.zeros_like(x, dtype=np.float32)
27+
assert_array_almost_equal(expit(x), expit(x, out=out))
28+
2329

2430
def test_divide():
2531
assert_equal(divide(.6, 1), .600000000000)

0 commit comments

Comments
 (0)