Skip to content

Commit 53eaac8

Browse files
author
colin
committed
Refactor tests, fix geweke bug, set random seed
1 parent d709a70 commit 53eaac8

File tree

3 files changed

+117
-66
lines changed

3 files changed

+117
-66
lines changed

pymc3/diagnostics.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""Convergence diagnostics and model validation"""
22

33
import numpy as np
4-
from .stats import autocorr, autocov, statfunc
5-
from copy import copy
4+
from .stats import statfunc
65

76
__all__ = ['geweke', 'gelman_rubin', 'effective_n']
87

@@ -54,6 +53,12 @@ def geweke(x, first=.1, last=.5, intervals=20):
5453
return [geweke(y, first, last, intervals) for y in np.transpose(x)]
5554

5655
# Filter out invalid intervals
56+
for interval in (first, last):
57+
if interval <= 0 or interval >= 1:
58+
raise ValueError(
59+
"Invalid intervals for Geweke convergence analysis",
60+
(first,
61+
last))
5762
if first + last >= 1:
5863
raise ValueError(
5964
"Invalid intervals for Geweke convergence analysis",
@@ -66,18 +71,20 @@ def geweke(x, first=.1, last=.5, intervals=20):
6671
# Last index value
6772
end = len(x) - 1
6873

74+
# Start intervals going up to the <last>% of the chain
75+
last_start_idx = (1 - last) * end
76+
6977
# Calculate starting indices
70-
sindices = np.arange(0, end // 2, step=int((end / 2) / (intervals - 1)))
78+
start_indices = np.arange(0, int(last_start_idx), step=int((last_start_idx) / (intervals - 1)))
7179

7280
# Loop over start indices
73-
for start in sindices:
74-
81+
for start in start_indices:
7582
# Calculate slices
7683
first_slice = x[start: start + int(first * (end - start))]
7784
last_slice = x[int(end - last * (end - start)):]
7885

79-
z = (first_slice.mean() - last_slice.mean())
80-
z /= np.sqrt(first_slice.std() ** 2 + last_slice.std() ** 2)
86+
z = first_slice.mean() - last_slice.mean()
87+
z /= np.sqrt(first_slice.var() + last_slice.var())
8188

8289
zscores.append([start, z])
8390

@@ -177,7 +184,7 @@ def effective_n(mtrace):
177184
mtrace : MultiTrace
178185
A MultiTrace object containing parallel traces (minimum 2)
179186
of one or more stochastic parameters.
180-
187+
181188
Returns
182189
-------
183190
n_eff : float
@@ -191,13 +198,13 @@ def effective_n(mtrace):
191198
.. math:: \hat{n}_{eff} = \frac{mn}}{1 + 2 \sum_{t=1}^T \hat{\rho}_t}
192199
193200
where :math:`\hat{\rho}_t` is the estimated autocorrelation at lag t, and T
194-
is the first odd positive integer for which the sum :math:`\hat{\rho}_{T+1} + \hat{\rho}_{T+1}`
201+
is the first odd positive integer for which the sum :math:`\hat{\rho}_{T+1} + \hat{\rho}_{T+1}`
195202
is negative.
196203
197204
References
198205
----------
199206
Gelman et al. (2014)"""
200-
207+
201208
if mtrace.nchains < 2:
202209
raise ValueError(
203210
'Calculation of effective sample size requires multiple chains of the same length.')
@@ -226,32 +233,32 @@ def calc_vhat(x):
226233
rotated_indices = np.roll(np.arange(x.ndim), 1)
227234
# Now iterate over the dimension of the variable
228235
return np.squeeze([calc_vhat(xi) for xi in x.transpose(rotated_indices)])
229-
236+
230237
def calc_n_eff(x):
231-
238+
232239
m, n = x.shape
233-
240+
234241
negative_autocorr = False
235242
t = 1
236-
243+
237244
Vhat = calc_vhat(x)
238-
239-
variogram = lambda t: (sum(sum((x[j][i] - x[j][i-t])**2
245+
246+
variogram = lambda t: (sum(sum((x[j][i] - x[j][i-t])**2
240247
for i in range(t,n)) for j in range(m)) / (m*(n - t)))
241-
248+
242249
rho = np.ones(n)
243250
# Iterate until the sum of consecutive estimates of autocorrelation is negative
244251
while not negative_autocorr and (t < n):
245-
252+
246253
rho[t] = 1. - variogram(t)/(2.*Vhat)
247-
254+
248255
if not t % 2:
249256
negative_autocorr = sum(rho[t-1:t+1]) < 0
250-
257+
251258
t += 1
252-
259+
253260
return int(m*n / (1. + 2*rho[1:t].sum()))
254-
261+
255262
n_eff = {}
256263
for var in mtrace.varnames:
257264

pymc3/tests/test_diagnostics.py

Lines changed: 86 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,107 @@
1-
from ..theanof import inputvars
2-
from ..model import Model, modelcontext
1+
import unittest
2+
3+
from numpy.testing import assert_allclose, assert_array_less
4+
5+
from ..model import Model
36
from ..step_methods import Slice, Metropolis, NUTS
47
from ..distributions import Normal
58
from ..tuning import find_MAP
69
from ..sampling import sample
710
from ..diagnostics import effective_n, geweke, gelman_rubin
811
from pymc3.examples import disaster_model as dm
9-
from numpy import all, isclose
1012

11-
def test_gelman_rubin(n=1000):
1213

13-
with dm.model:
14-
# Run sampler
15-
step1 = Slice([dm.early_mean, dm.late_mean])
16-
step2 = Metropolis([dm.switchpoint])
17-
start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50}
18-
ptrace = sample(n, [step1, step2], start, njobs=2,
19-
random_seed=[1, 3])
14+
class TestGelmanRubin(unittest.TestCase):
15+
good_ratio = 1.1
16+
17+
def get_ptrace(self, n_samples):
18+
with dm.model:
19+
# Run sampler
20+
step1 = Slice([dm.early_mean, dm.late_mean])
21+
step2 = Metropolis([dm.switchpoint])
22+
start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50}
23+
ptrace = sample(n_samples, [step1, step2], start, njobs=2,
24+
random_seed=[1, 3])
25+
return ptrace
26+
27+
def test_good(self):
28+
"""Confirm Gelman-Rubin statistic is close to 1 for a reasonable number of samples."""
29+
n_samples = 1000
30+
rhat = gelman_rubin(self.get_ptrace(n_samples))
31+
self.assertTrue(all(1 / self.good_ratio < r < self.good_ratio for r in rhat.values()))
32+
33+
def test_bad(self):
34+
"""Confirm Gelman-Rubin statistic is far from 1 for a small number of samples."""
35+
n_samples = 10
36+
rhat = gelman_rubin(self.get_ptrace(n_samples))
37+
self.assertFalse(all(1 / self.good_ratio < r < self.good_ratio for r in rhat.values()))
38+
39+
40+
class TestDiagnostics(unittest.TestCase):
41+
def get_switchpoint(self, n_samples):
42+
with dm.model:
43+
# Run sampler
44+
step1 = Slice([dm.early_mean, dm.late_mean])
45+
step2 = Metropolis([dm.switchpoint])
46+
trace = sample(n_samples, [step1, step2], progressbar=False,
47+
random_seed=1)
48+
return trace['switchpoint']
2049

21-
rhat = gelman_rubin(ptrace)
50+
def test_geweke_negative(self):
51+
"""Confirm Geweke diagnostic is larger than 1 for a small number of samples."""
52+
n_samples = 200
53+
n_intervals = 20
54+
switchpoint = self.get_switchpoint(n_samples)
55+
first = 0.1
56+
last = 0.7
57+
# returns (intervalsx2) matrix, with first row start indexes, second z-scores
58+
z_switch = geweke(switchpoint, first=first, last=last, intervals=n_intervals)
2259

23-
assert all([r < 1.5 for r in rhat.values()])
60+
# These z-scores should be larger, since there are not many samples.
61+
self.assertGreater(max(abs(z_switch[:, 1])), 1)
2462

63+
def test_geweke_positive(self):
64+
"""Confirm Geweke diagnostic is smaller than 1 for a reasonable number of samples."""
65+
n_samples = 2000
66+
n_intervals = 20
67+
switchpoint = self.get_switchpoint(n_samples)
2568

26-
def test_geweke(n=3000):
69+
with self.assertRaises(ValueError):
70+
# first and last must be between 0 and 1
71+
geweke(switchpoint, first=-0.3, last=1.1, intervals=n_intervals)
2772

28-
with dm.model:
29-
# Run sampler
30-
step1 = Slice([dm.early_mean, dm.late_mean])
31-
step2 = Metropolis([dm.switchpoint])
32-
trace = sample(n, [step1, step2], progressbar=False,
33-
random_seed=1)
73+
with self.assertRaises(ValueError):
74+
# first and last must add to < 1
75+
geweke(switchpoint, first=0.3, last=0.7, intervals=n_intervals)
3476

35-
z_switch = geweke(trace['switchpoint'], last=.5, intervals=20)
77+
first = 0.1
78+
last = 0.7
79+
# returns (intervalsx2) matrix, with first row start indexes, second z-scores
80+
z_switch = geweke(switchpoint, first=first, last=last, intervals=n_intervals)
81+
start = z_switch[:, 0]
82+
z_scores = z_switch[:, 1]
3683

37-
# Ensure `intervals` argument is honored
38-
assert len(z_switch) == 20
84+
# Ensure `intervals` argument is honored
85+
self.assertEqual(z_switch.shape[0], n_intervals)
3986

40-
# Ensure `last` argument is honored
41-
assert z_switch[-1, 0] < (n / 2)
87+
# Start index should not be in the last <last>% of samples
88+
assert_array_less(start, (1 - last) * n_samples)
4289

43-
# These should all be z-scores
44-
print(max(abs(z_switch[:, 1])))
45-
assert max(abs(z_switch[:, 1])) < 1
90+
# These z-scores should be small, since there are more samples.
91+
self.assertLess(max(abs(z_scores)), 1)
4692

93+
def test_effective_n(self):
94+
"""Check effective sample size is equal to number of samples when initializing with MAP"""
95+
n_jobs = 3
96+
n_samples = 100
4797

48-
def test_effective_n(k=3, n=1000):
49-
"""Unit test for effective sample size"""
50-
51-
model = Model()
52-
with model:
53-
x = Normal('x', 0, 1., shape=5)
98+
with Model():
99+
Normal('x', 0, 1., shape=5)
54100

55-
# start sampling at the MAP
56-
start = find_MAP()
101+
# start sampling at the MAP
102+
start = find_MAP()
103+
step = NUTS(scaling=start)
104+
ptrace = sample(n_samples, step, start, njobs=n_jobs, random_seed=42)
57105

58-
step = NUTS(scaling=start)
59-
60-
ptrace = sample(n, step, start, njobs=k,
61-
random_seed=42)
62-
63-
n_eff = effective_n(ptrace)['x']
64-
65-
assert isclose(n_eff, k*n, 2).all()
106+
n_effective = effective_n(ptrace)['x']
107+
assert_allclose(n_effective, n_jobs * n_samples, 2)

pymc3/tests/test_distributions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
import numpy as np
1919
from numpy.testing import assert_almost_equal
2020
from numpy.linalg import inv
21+
import numpy.random as nr
2122

2223
from scipy import integrate
2324
import scipy.stats.distributions as sp
2425
import scipy.stats
2526

27+
nr.seed(20160905)
2628

2729
class Domain(object):
2830
def __init__(self, vals, dtype=None, edges=None, shape=None):

0 commit comments

Comments
 (0)