Skip to content

Commit 86d7afb

Browse files
committed
FIX OrthogonalMatchingPursuit normalized twice
1 parent a4e53a6 commit 86d7afb

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

sklearn/linear_model/omp.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -679,13 +679,6 @@ def fit(self, X, y, Gram=None, Xy=None):
679679
self.tol, norms_sq,
680680
copy_Gram, True).T
681681

682-
if self.normalize:
683-
nonzeros = np.flatnonzero(X_std)
684-
scaling = X_std[nonzeros]
685-
if self.coef_.ndim == 2:
686-
scaling = scaling[np.newaxis, :]
687-
self.coef_[:, nonzeros] /= scaling
688-
689682
self._set_intercept(X_mean, y_mean, X_std)
690683
return self
691684

sklearn/linear_model/tests/test_omp.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
from sklearn.linear_model import (orthogonal_mp, orthogonal_mp_gram,
1616
OrthogonalMatchingPursuit,
17-
OrthogonalMatchingPursuitCV)
17+
OrthogonalMatchingPursuitCV,
18+
LinearRegression)
1819
from sklearn.utils.fixes import count_nonzero
20+
from sklearn.utils import check_random_state
1921
from sklearn.datasets import make_sparse_coded_signal
2022

2123
n_samples, n_features, n_nonzero_coefs, n_targets = 20, 30, 5, 3
@@ -93,7 +95,6 @@ def test_bad_input():
9395

9496

9597
def test_perfect_signal_recovery():
96-
# XXX: use signal generator
9798
idx, = gamma[:, 0].nonzero()
9899
gamma_rec = orthogonal_mp(X, y[:, 0], 5)
99100
gamma_gram = orthogonal_mp_gram(G, Xy[:, 0], 5)
@@ -218,3 +219,17 @@ def test_omp_cv():
218219
n_nonzero_coefs=ompcv.n_nonzero_coefs_)
219220
omp.fit(X, y_)
220221
assert_array_almost_equal(ompcv.coef_, omp.coef_)
222+
223+
224+
def test_omp_reaches_least_squares():
225+
# Use small simple data; it's a sanity check but OMP can stop early
226+
rng = check_random_state(0)
227+
n_samples, n_features = (10, 8)
228+
n_targets = 3
229+
X = rng.randn(n_samples, n_features)
230+
Y = rng.randn(n_samples, n_targets)
231+
omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_features)
232+
lstsq = LinearRegression()
233+
omp.fit(X, Y)
234+
lstsq.fit(X, Y)
235+
assert_array_almost_equal(omp.coef_, lstsq.coef_)

0 commit comments

Comments
 (0)