Skip to content

Commit c645079

Browse files
author
Christopher Fonnesbeck
committed
Fix for r-hat calculation on multidimensional nodes
1 parent a0ea4dc commit c645079

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

pymc/diagnostics.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,24 @@ def gelman_rubin(mtrace):
133133

134134
def calc_rhat(x):
135135

136-
m, n = x.shape
136+
try:
137+
m, n = x.shape
138+
139+
# Calculate between-chain variance
140+
B = n * np.var(np.mean(x, axis=1), ddof=1)
137141

138-
# Calculate between-chain variance
139-
B = n * np.var(np.mean(x, axis=1), ddof=1)
142+
# Calculate within-chain variance
143+
W = np.mean(np.var(x, axis=1, ddof=1))
140144

141-
# Calculate within-chain variance
142-
W = np.mean(np.var(x, axis=1, ddof=1))
145+
# Estimate of marginal posterior variance
146+
Vhat = W*(n - 1)/n + B/n
143147

144-
# Estimate of marginal posterior variance
145-
Vhat = W*(n - 1)/n + B/n
148+
return np.sqrt(Vhat/W)
149+
150+
except ValueError:
146151

147-
return np.sqrt(Vhat/W)
152+
rotated_indices = np.roll(np.arange(x.ndim), 1)
153+
return np.squeeze([calc_rhat(xi) for xi in x.transpose(rotated_indices)])
148154

149155
Rhat = {}
150156
for var in mtrace.varnames:

0 commit comments

Comments
 (0)