5
5
pass
6
6
import numpy as np
7
7
from pymc3 .diagnostics import gelman_rubin
8
- from pymc3 .stats import quantiles , hpd
8
+ from pymc3 .stats import quantiles , hpd , dict2pd
9
9
from .utils import identity_transform , get_default_varnames
10
10
11
-
12
11
def _var_str (name , shape ):
13
12
"""Return a sequence of strings naming the element of the tallyable object.
14
13
15
14
:Example:
16
15
>>> _var_str('theta', (4,))
17
- ['theta[1 ]', 'theta[2 ]', 'theta[3 ]', 'theta[4 ]']
16
+ ['theta[0 ]', 'theta[1 ]', 'theta[2 ]', 'theta[3 ]']
18
17
"""
19
18
size = np .prod (shape )
20
- ind = (np .indices (shape ) + 1 ).reshape (- 1 , size )
19
+ ind = (np .indices (shape )).reshape (- 1 , size )
21
20
names = ['[' + ',' .join (map (str , i )) + ']' for i in zip (* ind )]
22
21
names [0 ] = '%s %s' % (name , names [0 ])
23
22
return names
@@ -45,9 +44,6 @@ def _make_rhat_plot(trace, ax, title, labels, varnames, include_transformed):
45
44
if varnames is None :
46
45
varnames = get_default_varnames (trace .varnames , include_transformed )
47
46
48
- R = gelman_rubin (trace )
49
- R = {v : R [v ] for v in varnames }
50
-
51
47
ax .set_title (title )
52
48
53
49
# Set x range
@@ -62,9 +58,11 @@ def _make_rhat_plot(trace, ax, title, labels, varnames, include_transformed):
62
58
chain = trace .chains [0 ]
63
59
value = trace .get_values (varname , chains = [chain ])[0 ]
64
60
k = np .size (value )
61
+ R = gelman_rubin (trace , varnames = [varname ])
65
62
66
63
if k > 1 :
67
- ax .plot ([min (r , 2 ) for r in R [varname ]],
64
+ Rval = dict2pd (R , 'rhat' ).values
65
+ ax .plot ([min (r , 2 ) for r in Rval ],
68
66
[- (j + i ) for j in range (k )], 'bo' , markersize = 4 )
69
67
else :
70
68
ax .plot (min (R [varname ], 2 ), - i , 'bo' , markersize = 4 )
@@ -261,7 +259,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform,
261
259
262
260
# Deal with multivariate nodes
263
261
if k > 1 :
264
- for q in np .transpose ( quants ).squeeze ():
262
+ for q in np .moveaxis ( np . array ( quants ), 0 , - 1 ) .squeeze (). reshape ( - 1 , len ( quants ) ):
265
263
# Multiple y values
266
264
interval_plot = _plot_tree (interval_plot , y , q , quartiles ,
267
265
plot_kwargs )
0 commit comments