Commit 4618cb3

Merge pull request scikit-learn#4653 from ssaeger/issue_4633
[MRG + 1] Added verbose flag to GMM
2 parents f10d2f4 + d2382e1 commit 4618cb3

File tree

4 files changed (+147, -18 lines)

    sklearn/mixture/dpgmm.py
    sklearn/mixture/gmm.py
    sklearn/mixture/tests/test_dpgmm.py
    sklearn/mixture/tests/test_gmm.py

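For orientation, a minimal sketch of the flag this commit adds, against the pre-0.18 sklearn.mixture.GMM API; the data and hyperparameters here are illustrative, not part of the commit:

import numpy as np
from sklearn.mixture import GMM  # deprecated mixture API this PR touches

# Two well-separated 1-D clusters, mirroring the doctest data changed below.
np.random.seed(0)
X = np.concatenate([np.random.randn(20, 1), 10 + np.random.randn(20, 1)])

# verbose=0 (the new default) is silent; verbose=1 prints each
# initialization and EM iteration; verbose=2 additionally prints the
# per-iteration change and step timings.
g = GMM(n_components=2, n_init=2, verbose=2)
g.fit(X)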

sklearn/mixture/dpgmm.py

Lines changed: 7 additions & 8 deletions
@@ -158,7 +158,7 @@ class DPGMM(GMM):
         process. Can contain any combination of 'w' for weights,
         'm' for means, and 'c' for covars. Defaults to 'wmc'.
 
-    verbose : boolean, default False
+    verbose : int, default 0
         Controls output verbosity.
 
     Attributes
@@ -198,15 +198,14 @@ class DPGMM(GMM):
     """
 
     def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
-                 random_state=None, thresh=None, tol=1e-3, verbose=False,
+                 random_state=None, thresh=None, tol=1e-3, verbose=0,
                  min_covar=None, n_iter=10, params='wmc', init_params='wmc'):
         self.alpha = alpha
-        self.verbose = verbose
         super(DPGMM, self).__init__(n_components, covariance_type,
                                     random_state=random_state, thresh=thresh,
                                     tol=tol, min_covar=min_covar,
                                     n_iter=n_iter, params=params,
-                                    init_params=init_params)
+                                    init_params=init_params, verbose=verbose)
 
     def _get_precisions(self):
         """Return precisions as a full matrix."""
@@ -367,7 +366,7 @@ def _monitor(self, X, z, n, end=False):
         expected.
 
         Note: this is very expensive and should not be used by default."""
-        if self.verbose:
+        if self.verbose > 0:
             print("Bound after updating %8s: %f" % (n, self.lower_bound(X, z)))
             if end:
                 print("Cluster proportions:", self.gamma_.T[1])
@@ -653,7 +652,7 @@ class VBGMM(DPGMM):
         process. Can contain any combination of 'w' for weights,
         'm' for means, and 'c' for covars. Defaults to 'wmc'.
 
-    verbose : boolean, default False
+    verbose : int, default 0
         Controls output verbosity.
 
     Attributes
@@ -695,7 +694,7 @@ class VBGMM(DPGMM):
     """
 
     def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
-                 random_state=None, thresh=None, tol=1e-3, verbose=False,
+                 random_state=None, thresh=None, tol=1e-3, verbose=0,
                  min_covar=None, n_iter=10, params='wmc', init_params='wmc'):
         super(VBGMM, self).__init__(
             n_components, covariance_type, random_state=random_state,
@@ -779,7 +778,7 @@ def _monitor(self, X, z, n, end=False):
         expected.
 
         Note: this is very expensive and should not be used by default."""
-        if self.verbose:
+        if self.verbose > 0:
             print("Bound after updating %8s: %f" % (n, self.lower_bound(X, z)))
             if end:
                 print("Cluster proportions:", self.gamma_)

sklearn/mixture/gmm.py

Lines changed: 45 additions & 6 deletions
@@ -12,6 +12,7 @@
 import warnings
 import numpy as np
 from scipy import linalg
+from time import time
 
 from ..base import BaseEstimator
 from ..utils import check_random_state, check_array
@@ -156,6 +157,11 @@ class GMM(BaseEstimator):
         process. Can contain any combination of 'w' for weights,
         'm' for means, and 'c' for covars. Defaults to 'wmc'.
 
+    verbose : int, default: 0
+        Enable verbose output. If 1 then it always prints the current
+        initialization and iteration step. If greater than 1 then
+        it prints additionally the change and time needed for each step.
+
     Attributes
     ----------
     weights_ : array, shape (`n_components`,)
@@ -203,7 +209,7 @@ class GMM(BaseEstimator):
     >>> g.fit(obs)  # doctest: +NORMALIZE_WHITESPACE
     GMM(covariance_type='diag', init_params='wmc', min_covar=0.001,
             n_components=2, n_init=1, n_iter=100, params='wmc',
-            random_state=None, thresh=None, tol=0.001)
+            random_state=None, thresh=None, tol=0.001, verbose=0)
     >>> np.round(g.weights_, 2)
     array([ 0.75,  0.25])
     >>> np.round(g.means_, 2)
@@ -221,15 +227,16 @@ class GMM(BaseEstimator):
     >>> g.fit(20 * [[0]] + 20 * [[10]])  # doctest: +NORMALIZE_WHITESPACE
     GMM(covariance_type='diag', init_params='wmc', min_covar=0.001,
             n_components=2, n_init=1, n_iter=100, params='wmc',
-            random_state=None, thresh=None, tol=0.001)
+            random_state=None, thresh=None, tol=0.001, verbose=0)
     >>> np.round(g.weights_, 2)
     array([ 0.5,  0.5])
 
     """
 
     def __init__(self, n_components=1, covariance_type='diag',
                  random_state=None, thresh=None, tol=1e-3, min_covar=1e-3,
-                 n_iter=100, n_init=1, params='wmc', init_params='wmc'):
+                 n_iter=100, n_init=1, params='wmc', init_params='wmc',
+                 verbose=0):
         if thresh is not None:
             warnings.warn("'thresh' has been replaced by 'tol' in 0.16 "
                           " and will be removed in 0.18.",
@@ -244,6 +251,7 @@ def __init__(self, n_components=1, covariance_type='diag',
         self.n_init = n_init
         self.params = params
         self.init_params = init_params
+        self.verbose = verbose
 
         if covariance_type not in ['spherical', 'tied', 'diag', 'full']:
             raise ValueError('Invalid value for covariance_type: %s' %
@@ -458,15 +466,26 @@ def _fit(self, X, y=None, do_prediction=False):
 
         max_log_prob = -np.infty
 
-        for _ in range(self.n_init):
+        if self.verbose > 0:
+            print('Expectation-maximization algorithm started.')
+
+        for init in range(self.n_init):
+            if self.verbose > 0:
+                print('Initialization '+str(init+1))
+                start_init_time = time()
+
             if 'm' in self.init_params or not hasattr(self, 'means_'):
                 self.means_ = cluster.KMeans(
                     n_clusters=self.n_components,
                     random_state=self.random_state).fit(X).cluster_centers_
+                if self.verbose > 1:
+                    print('\tMeans have been initialized.')
 
             if 'w' in self.init_params or not hasattr(self, 'weights_'):
                 self.weights_ = np.tile(1.0 / self.n_components,
                                         self.n_components)
+                if self.verbose > 1:
+                    print('\tWeights have been initialized.')
 
             if 'c' in self.init_params or not hasattr(self, 'covars_'):
                 cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1])
@@ -475,6 +494,8 @@ def _fit(self, X, y=None, do_prediction=False):
                 self.covars_ = \
                     distribute_covar_matrix_to_match_covariance_type(
                         cv, self.covariance_type, self.n_components)
+                if self.verbose > 1:
+                    print('\tCovariance matrices have been initialized.')
 
             # EM algorithms
             current_log_likelihood = None
@@ -486,23 +507,33 @@ def _fit(self, X, y=None, do_prediction=False):
                    else self.thresh / float(X.shape[0]))
 
             for i in range(self.n_iter):
+                if self.verbose > 0:
+                    print('\tEM iteration '+str(i+1))
+                    start_iter_time = time()
                 prev_log_likelihood = current_log_likelihood
                 # Expectation step
                 log_likelihoods, responsibilities = self.score_samples(X)
                 current_log_likelihood = log_likelihoods.mean()
 
                 # Check for convergence.
-                # (should compare to self.tol when dreprecated 'thresh' is
+                # (should compare to self.tol when deprecated 'thresh' is
                 # removed in v0.18)
                 if prev_log_likelihood is not None:
                     change = abs(current_log_likelihood - prev_log_likelihood)
+                    if self.verbose > 1:
+                        print('\t\tChange: '+str(change))
                     if change < tol:
                         self.converged_ = True
+                        if self.verbose > 0:
+                            print('\t\tEM algorithm converged.')
                         break
 
                 # Maximization step
                 self._do_mstep(X, responsibilities, self.params,
                                self.min_covar)
+                if self.verbose > 1:
+                    print('\t\tEM iteration '+str(i+1)+' took {0:.5f}s'.format(
+                        time()-start_iter_time))
 
             # if the results are better, keep it
             if self.n_iter:
@@ -511,6 +542,13 @@ def _fit(self, X, y=None, do_prediction=False):
                     best_params = {'weights': self.weights_,
                                    'means': self.means_,
                                    'covars': self.covars_}
+                    if self.verbose > 1:
+                        print('\tBetter parameters were found.')
+
+            if self.verbose > 1:
+                print('\tInitialization '+str(init+1)+' took {0:.5f}s'.format(
+                    time()-start_init_time))
+
 
         # check the existence of an init param that was not subject to
         # likelihood computation issue.
@@ -661,7 +699,8 @@ def _log_multivariate_normal_density_full(X, means, covars, min_covar=1.e-7):
             cv_chol = linalg.cholesky(cv + min_covar * np.eye(n_dim),
                                       lower=True)
         except linalg.LinAlgError:
-            raise ValueError("'covars' must be symmetric, positive-definite")
+            raise ValueError("'covars' must be symmetric, "
+                             "positive-definite")
 
     cv_log_det = 2 * np.sum(np.log(np.diagonal(cv_chol)))
     cv_sol = linalg.solve_triangular(cv_chol, (X - mu).T, lower=True).T
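Tracing the print calls added above by hand, a verbose=2 fit with a single initialization emits output shaped roughly like this (tabs rendered as indentation, all numbers made up; note the converging iteration has no timing line because break fires before the M-step):

Expectation-maximization algorithm started.
Initialization 1
    Means have been initialized.
    Weights have been initialized.
    Covariance matrices have been initialized.
    EM iteration 1
        EM iteration 1 took 0.00215s
    EM iteration 2
        Change: 3.1e-05
        EM algorithm converged.
    Better parameters were found.
    Initialization 1 took 0.01342s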

sklearn/mixture/tests/test_dpgmm.py

Lines changed: 62 additions & 1 deletion
@@ -1,4 +1,5 @@
 import unittest
+import sys
 
 import nose
 
@@ -7,8 +8,9 @@
 from sklearn.mixture import DPGMM, VBGMM
 from sklearn.mixture.dpgmm import log_normalize
 from sklearn.datasets import make_blobs
-from sklearn.utils.testing import assert_array_less
+from sklearn.utils.testing import assert_array_less, assert_equal
 from sklearn.mixture.tests.test_gmm import GMMTester
+from sklearn.externals.six.moves import cStringIO as StringIO
 
 np.seterr(all='warn')
 
@@ -30,6 +32,65 @@ def test_class_weights():
     assert_array_less(dpgmm.weights_[~active], .05)
 
 
+def test_verbose_boolean():
+    # checks that the output for the verbose output is the same
+    # for the flag values '1' and 'True'
+    # simple 3 cluster dataset
+    X, y = make_blobs(random_state=1)
+    for Model in [DPGMM, VBGMM]:
+        dpgmm_bool = Model(n_components=10, random_state=1, alpha=20,
+                           n_iter=50, verbose=True)
+        dpgmm_int = Model(n_components=10, random_state=1, alpha=20,
+                          n_iter=50, verbose=1)
+
+        old_stdout = sys.stdout
+        sys.stdout = StringIO()
+        try:
+            # generate output with the boolean flag
+            dpgmm_bool.fit(X)
+            verbose_output = sys.stdout
+            verbose_output.seek(0)
+            bool_output = verbose_output.readline()
+            # generate output with the int flag
+            dpgmm_int.fit(X)
+            verbose_output = sys.stdout
+            verbose_output.seek(0)
+            int_output = verbose_output.readline()
+            assert_equal(bool_output, int_output)
+        finally:
+            sys.stdout = old_stdout
+
+
+def test_verbose_first_level():
+    # simple 3 cluster dataset
+    X, y = make_blobs(random_state=1)
+    for Model in [DPGMM, VBGMM]:
+        dpgmm = Model(n_components=10, random_state=1, alpha=20, n_iter=50,
+                      verbose=1)
+
+        old_stdout = sys.stdout
+        sys.stdout = StringIO()
+        try:
+            dpgmm.fit(X)
+        finally:
+            sys.stdout = old_stdout
+
+
+def test_verbose_second_level():
+    # simple 3 cluster dataset
+    X, y = make_blobs(random_state=1)
+    for Model in [DPGMM, VBGMM]:
+        dpgmm = Model(n_components=10, random_state=1, alpha=20, n_iter=50,
+                      verbose=2)
+
+        old_stdout = sys.stdout
+        sys.stdout = StringIO()
+        try:
+            dpgmm.fit(X)
+        finally:
+            sys.stdout = old_stdout
+
+
 def test_log_normalize():
     v = np.array([0.1, 0.8, 0.01, 0.09])
     a = np.log(2 * v)
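All three new tests lean on the same stdout-redirect idiom: swap sys.stdout for an in-memory buffer, run the fit, and restore the real stream in a finally block so a failing fit cannot leave stdout hijacked. A distilled sketch of the pattern; capture_verbose_output is a hypothetical helper, not part of the test suite:

import sys
from sklearn.externals.six.moves import cStringIO as StringIO

def capture_verbose_output(estimator, X):
    # Redirect stdout, fit, and return everything the fit printed.
    # The finally clause restores stdout even if fit() raises.
    old_stdout = sys.stdout
    sys.stdout = StringIO()
    try:
        estimator.fit(X)
        return sys.stdout.getvalue()
    finally:
        sys.stdout = old_stdout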

sklearn/mixture/tests/test_gmm.py

Lines changed: 33 additions & 3 deletions
@@ -1,5 +1,6 @@
 import unittest
 import copy
+import sys
 
 from nose.tools import assert_true
 import numpy as np
@@ -11,6 +12,7 @@
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_raise_message
 from sklearn.metrics.cluster import adjusted_rand_score
+from sklearn.externals.six.moves import cStringIO as StringIO
 
 rng = np.random.RandomState(0)
 
@@ -108,7 +110,7 @@ def test_lmvnpdf_full():
 
 
 def test_lvmpdf_full_cv_non_positive_definite():
-    n_features, n_components, n_samples = 2, 1, 10
+    n_features, n_samples = 2, 10
     rng = np.random.RandomState(0)
     X = rng.randint(10) * rng.rand(n_samples, n_features)
     mu = np.mean(X, 0)
@@ -263,7 +265,7 @@ def test_train_1d(self, params='wmc'):
         # Train on 1-D data
        # Create a training set by sampling from the predefined distribution.
         X = rng.randn(100, 1)
-        #X.T[1:] = 0
+        # X.T[1:] = 0
         g = self.model(n_components=2, covariance_type=self.covariance_type,
                        random_state=rng, min_covar=1e-7, n_iter=5,
                        init_params=params)
@@ -371,7 +373,7 @@ def test_fit_predict():
 
     model = mixture.GMM(n_components=n_comps, n_iter=0)
     z = model.fit_predict(X)
-    assert np.all(z==0), "Quick Initialization Failed!"
+    assert np.all(z == 0), "Quick Initialization Failed!"
 
 
 def test_aic():
@@ -443,6 +445,34 @@ def test_positive_definite_covars():
         yield check_positive_definite_covars, covariance_type
 
 
+def test_verbose_first_level():
+    # Create sample data
+    X = rng.randn(30, 5)
+    X[:10] += 2
+    g = mixture.GMM(n_components=2, n_init=2, verbose=1)
+
+    old_stdout = sys.stdout
+    sys.stdout = StringIO()
+    try:
+        g.fit(X)
+    finally:
+        sys.stdout = old_stdout
+
+
+def test_verbose_second_level():
+    # Create sample data
+    X = rng.randn(30, 5)
+    X[:10] += 2
+    g = mixture.GMM(n_components=2, n_init=2, verbose=2)
+
+    old_stdout = sys.stdout
+    sys.stdout = StringIO()
+    try:
+        g.fit(X)
+    finally:
+        sys.stdout = old_stdout
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
