Skip to content

Commit a06f957

Browse files
committed
Merge pull request pymc-devs#902 from jonsedar/issue_899_fix
Issue pymc-devs#899 fix
2 parents c4b6ba0 + 9f22f1b commit a06f957

File tree

6 files changed

+57
-45
lines changed

6 files changed

+57
-45
lines changed

pymc3/examples/GHME 2013.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@
723723
"cell_type": "code",
724724
"collapsed": false,
725725
"input": [
726-
"autocorrplot(trace, vars = [coeff_sd,sd ])"
726+
"autocorrplot(trace, varnames = [coeff_sd,sd ])"
727727
],
728728
"language": "python",
729729
"metadata": {},
@@ -739,4 +739,4 @@
739739
"metadata": {}
740740
}
741741
]
742-
}
742+
}

pymc3/examples/GHME_2013.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def run(n=3000):
112112

113113
# <codecell>
114114

115-
autocorrplot(trace, vars = [coeff_sd,sd ])
115+
autocorrplot(trace, varnames = [coeff_sd,sd ])
116116

117117
if __name__ == '__main__':
118118
run()

pymc3/examples/gaussian_mixture_model.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@
242242
],
243243
"source": [
244244
"# I prefer autocorrelation plots for serious confirmation of MCMC convergence\n",
245-
"pm.autocorrplot(tr[5000::5], ['sd'])"
245+
"pm.autocorrplot(tr[5000::5], varnames=['sd'])"
246246
]
247247
},
248248
{

pymc3/examples/survival_analysis.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@
580580
}
581581
],
582582
"source": [
583-
"pm.autocorrplot(trace, vars=['beta']);"
583+
"pm.autocorrplot(trace, varnames=['beta']);"
584584
]
585585
},
586586
{

pymc3/plots.py

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def traceplot(trace, vars=None, figsize=None,
10-
lines=None, combined=False, grid=True,
10+
lines=None, combined=False, grid=True,
1111
alpha=0.35, ax=None):
1212
"""Plot samples histograms and values
1313
@@ -149,82 +149,94 @@ def kde2plot(x, y, grid=200, ax=None):
149149
return ax
150150

151151

152-
def autocorrplot(trace, vars=None, max_lag=100, burn=0, ax=None,
153-
symmetric_plot=False):
152+
def autocorrplot(trace, varnames=None, max_lag=100, burn=0,
153+
symmetric_plot=False, ax=None, figsize=None):
154154
"""Bar plot of the autocorrelation function for a trace
155+
155156
Parameters
156157
----------
157158
trace : result of MCMC run
158-
vars : list of variable names
159+
varnames : list of variable names
159160
Variables to be plotted, if None all variable are plotted.
160161
Vector-value stochastics are handled automatically.
161162
max_lag : int
162163
Maximum lag to calculate autocorrelation. Defaults to 100.
163164
burn : int
164-
Number of samples to discard from the beginning of the trace.
165+
Number of samples to discard from the beginning of the trace.
165166
Defaults to 0.
166-
ax : axes
167-
Matplotlib axes. Defaults to None.
168167
symmetric_plot : boolean
169168
Plot from either [0, +lag] or [-lag, lag]. Defaults to False, [-, +lag].
170-
169+
ax : axes
170+
Matplotlib axes. Defaults to None.
171+
figsize : figure size tuple
172+
If None, size is (12, num of variables * 2) inches.
173+
Note this is not used if ax is supplied.
174+
171175
Returns
172176
-------
173177
ax : matplotlib axes
174178
"""
175-
179+
176180
import matplotlib.pyplot as plt
177-
178-
def _handle_array_varnames(val):
179-
if trace[0][val].__class__ is np.ndarray:
180-
k = trace[val].shape[1]
181-
for i in xrange(k):
182-
yield val + '_{0}'.format(i)
181+
182+
def _handle_array_varnames(varname):
183+
if trace[0][varname].__class__ is np.ndarray:
184+
k = trace[varname].shape[1]
185+
for i in range(k):
186+
yield varname + '_{0}'.format(i)
183187
else:
184-
yield val
185-
186-
if vars is None:
187-
vars = [item for sub in [[i for i in _handle_array_varnames(var)]
188-
for var in trace.varnames] for item in sub]
188+
yield varname
189+
190+
if varnames is None:
191+
varnames = trace.varnames
189192
else:
190-
vars = [str(var) for var in vars]
191-
vars = [item for sub in [[i for i in _handle_array_varnames(var)]
192-
for var in vars] for item in sub]
193+
varnames = [str(v) for v in varnames]
193194

194-
chains = trace.nchains
195+
varnames = [item for sub in [[i for i in _handle_array_varnames(v)]
196+
for v in varnames] for item in sub]
195197

196-
fig, ax = plt.subplots(len(vars), chains, squeeze=False,
197-
sharex=True, sharey=True)
198+
nchains = trace.nchains
199+
200+
if figsize is None:
201+
figsize = (12, len(varnames)*2)
202+
203+
if ax is None:
204+
fig, ax = plt.subplots(len(varnames), nchains, squeeze=False,
205+
sharex=True, sharey=True, figsize=figsize)
206+
elif ax.shape != (len(varnames), nchains):
207+
raise ValueError('autocorrplot requires {}*{} subplots'.format(
208+
len(varnames), nchains))
209+
return None
198210

199211
max_lag = min(len(trace) - 1, max_lag)
200212

201-
for i, v in enumerate(vars):
202-
for j in range(chains):
213+
for i, v in enumerate(varnames):
214+
for j in range(nchains):
203215
try:
204216
d = np.squeeze(trace.get_values(v, chains=[j], burn=burn,
205217
combine=False))
206218
except KeyError:
207219
k = int(v.split('_')[-1])
208220
v_use = '_'.join(v.split('_')[:-1])
209-
d = np.squeeze(trace.get_values(v_use, chains=[j], burn=burn,
210-
combine=False)[:, k])
221+
d = np.squeeze(trace.get_values(v_use, chains=[j],
222+
burn=burn, combine=False)[:, k])
211223

212224
ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)
213225

214226
if not j:
215227
ax[i, j].set_ylabel("correlation")
216-
if i == len(vars) - 1:
228+
if i == len(varnames) - 1:
217229
ax[i, j].set_xlabel("lag")
218-
230+
219231
ax[i, j].set_title(v)
220-
232+
221233
if not symmetric_plot:
222234
ax[i, j].set_xlim(0, max_lag)
223-
224-
if chains > 1:
235+
236+
if nchains > 1:
225237
ax[i, j].set_title("chain {0}".format(j+1))
226-
227-
return (fig, ax)
238+
239+
return ax
228240

229241

230242
def var_str(name, shape):
@@ -294,10 +306,10 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
294306
295307
vline (optional): numeric
296308
Location of vertical reference line (defaults to 0).
297-
309+
298310
gs : GridSpec
299311
Matplotlib GridSpec object. Defaults to None.
300-
312+
301313
Returns
302314
-------
303315

pymc3/tests/test_plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_multichain_plots():
5656

5757
forestplot(ptrace, vars=['early_mean', 'late_mean'])
5858

59-
autocorrplot(ptrace, vars=['switchpoint'])
59+
autocorrplot(ptrace, varnames=['switchpoint'])
6060

6161
def test_make_2d():
6262

0 commit comments

Comments
 (0)