Skip to content

Commit 6c5caa3

Browse files
committed
Merge pull request scikit-learn#4434 from xuewei4d/catch_LinAlgError_gmm
[MRG + 2] Friendly error when lvmpdf has non positive definite covariance
2 parents c97ad1d + 9c57ee3 commit 6c5caa3

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

sklearn/mixture/gmm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,12 @@ def _log_multivariate_normal_density_full(X, means, covars, min_covar=1.e-7):
608608
except linalg.LinAlgError:
609609
# The model is most probably stuck in a component with too
610610
# few observations, we need to reinitialize this components
611-
cv_chol = linalg.cholesky(cv + min_covar * np.eye(n_dim),
612-
lower=True)
611+
try:
612+
cv_chol = linalg.cholesky(cv + min_covar * np.eye(n_dim),
613+
lower=True)
614+
except linalg.LinAlgError:
615+
raise ValueError("'covars' must be symmetric, positive-definite")
616+
613617
cv_log_det = 2 * np.sum(np.log(np.diagonal(cv_chol)))
614618
cv_sol = linalg.solve_triangular(cv_chol, (X - mu).T, lower=True).T
615619
log_prob[:, c] = - .5 * (np.sum(cv_sol ** 2, axis=1) +

sklearn/mixture/tests/test_gmm.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from sklearn import mixture
99
from sklearn.datasets.samples_generator import make_spd_matrix
1010
from sklearn.utils.testing import assert_greater
11+
from sklearn.utils.testing import assert_raise_message
12+
1113

1214
rng = np.random.RandomState(0)
1315

@@ -104,6 +106,18 @@ def test_lmvnpdf_full():
104106
assert_array_almost_equal(lpr, reference)
105107

106108

109+
def test_lvmpdf_full_cv_non_positive_definite():
110+
n_features, n_components, n_samples = 2, 1, 10
111+
rng = np.random.RandomState(0)
112+
X = rng.randint(10) * rng.rand(n_samples, n_features)
113+
mu = np.mean(X, 0)
114+
cv = np.array([[[-1, 0], [0, 1]]])
115+
expected_message = "'covars' must be symmetric, positive-definite"
116+
assert_raise_message(ValueError, expected_message,
117+
mixture.log_multivariate_normal_density,
118+
X, mu, cv, 'full')
119+
120+
107121
def test_GMM_attributes():
108122
n_components, n_features = 10, 4
109123
covariance_type = 'diag'

0 commit comments

Comments
 (0)