Skip to content

Commit 1b1caa6

Browse files
author
Junpeng Lao
authored
Merge pull request pymc-devs#2672 from aplavin/patch-1
support multidimensional arrays in forestplot
2 parents 66d21e1 + c4ce563 commit 1b1caa6

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

pymc3/plots/forestplot.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@
55
pass
66
import numpy as np
77
from pymc3.diagnostics import gelman_rubin
8-
from pymc3.stats import quantiles, hpd
8+
from pymc3.stats import quantiles, hpd, dict2pd
99
from .utils import identity_transform, get_default_varnames
1010

11-
1211
def _var_str(name, shape):
1312
"""Return a sequence of strings naming the element of the tallyable object.
1413
1514
:Example:
1615
>>> _var_str('theta', (4,))
17-
['theta[1]', 'theta[2]', 'theta[3]', 'theta[4]']
16+
['theta[0]', 'theta[1]', 'theta[2]', 'theta[3]']
1817
"""
1918
size = np.prod(shape)
20-
ind = (np.indices(shape) + 1).reshape(-1, size)
19+
ind = (np.indices(shape)).reshape(-1, size)
2120
names = ['[' + ','.join(map(str, i)) + ']' for i in zip(*ind)]
2221
names[0] = '%s %s' % (name, names[0])
2322
return names
@@ -45,9 +44,6 @@ def _make_rhat_plot(trace, ax, title, labels, varnames, include_transformed):
4544
if varnames is None:
4645
varnames = get_default_varnames(trace.varnames, include_transformed)
4746

48-
R = gelman_rubin(trace)
49-
R = {v: R[v] for v in varnames}
50-
5147
ax.set_title(title)
5248

5349
# Set x range
@@ -62,9 +58,11 @@ def _make_rhat_plot(trace, ax, title, labels, varnames, include_transformed):
6258
chain = trace.chains[0]
6359
value = trace.get_values(varname, chains=[chain])[0]
6460
k = np.size(value)
61+
R = gelman_rubin(trace, varnames=[varname])
6562

6663
if k > 1:
67-
ax.plot([min(r, 2) for r in R[varname]],
64+
Rval = dict2pd(R, 'rhat').values
65+
ax.plot([min(r, 2) for r in Rval],
6866
[-(j + i) for j in range(k)], 'bo', markersize=4)
6967
else:
7068
ax.plot(min(R[varname], 2), -i, 'bo', markersize=4)
@@ -261,7 +259,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform,
261259

262260
# Deal with multivariate nodes
263261
if k > 1:
264-
for q in np.transpose(quants).squeeze():
262+
for q in np.moveaxis(np.array(quants), 0, -1).squeeze().reshape(-1, len(quants)):
265263
# Multiple y values
266264
interval_plot = _plot_tree(interval_plot, y, q, quartiles,
267265
plot_kwargs)

pymc3/tests/test_plots.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_plots():
2323
start = model.test_point
2424
h = find_hessian(start)
2525
step = Metropolis(model.vars, h)
26-
trace = sample(3000, tune=0, step=step, start=start, njobs=1)
26+
trace = sample(3000, tune=0, step=step, start=start, chains=1)
2727

2828
traceplot(trace)
2929
forestplot(trace)
@@ -48,21 +48,22 @@ def test_plots_categorical():
4848
start = model.test_point
4949
h = find_hessian(start)
5050
step = Metropolis(model.vars, h)
51-
trace = sample(3000, tune=0, step=step, start=start, njobs=1)
51+
trace = sample(3000, tune=0, step=step, start=start, chains=1)
5252

53-
traceplot(trace)
53+
traceplot(trace)
5454

5555

5656
def test_plots_multidimensional():
57-
# Test single trace
57+
# Test multiple trace
5858
start, model, _ = multidimensional_model()
5959
with model:
6060
h = np.diag(find_hessian(start))
6161
step = Metropolis(model.vars, h)
6262
trace = sample(3000, tune=0, step=step, start=start)
63-
64-
traceplot(trace)
65-
plot_posterior(trace)
63+
64+
traceplot(trace)
65+
plot_posterior(trace)
66+
forestplot(trace)
6667

6768
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on GPU due to njobs=2")
6869
def test_multichain_plots():

0 commit comments

Comments
 (0)