Skip to content

Commit 9b50ea9

Browse files
committed
Rename property name in NUTS to better reflect the computation
1 parent 8f74fcb commit 9b50ea9

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

pymc3/step_methods/hmc/nuts.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,12 @@ def warnings(self):
206206

207207

208208
# A proposal for the next position
209-
Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept, logp")
209+
Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept_weighted, logp")
210210

211211
# A subtree of the binary tree built by nuts.
212212
Subtree = namedtuple(
213-
"Subtree", "left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals"
213+
"Subtree",
214+
"left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals",
214215
)
215216

216217

@@ -243,7 +244,7 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
243244
)
244245
self.depth = 0
245246
self.log_size = 0
246-
self.log_accept_sum = -np.inf
247+
self.log_weighted_accept_sum = -np.inf
247248
self.mean_tree_accept = 0.0
248249
self.n_proposals = 0
249250
self.p_sum = start.p.copy()
@@ -291,7 +292,9 @@ def extend(self, direction):
291292
self.proposal = tree.proposal
292293

293294
self.log_size = np.logaddexp(self.log_size, tree.log_size)
294-
self.log_accept_sum = np.logaddexp(self.log_accept_sum, tree.log_accept_sum)
295+
self.log_weighted_accept_sum = np.logaddexp(
296+
self.log_weighted_accept_sum, tree.log_weighted_accept_sum
297+
)
295298
self.p_sum[:] += tree.p_sum
296299

297300
# Additional turning check only when tree depth > 0 to avoid redundant work
@@ -331,13 +334,17 @@ def _single_step(self, left, epsilon):
331334
# e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
332335
# Saturated Metropolis accept probability with Boltzmann weight
333336
# if h - H0 < 0
334-
log_p_accept = -energy_change + min(0.0, -energy_change)
337+
log_p_accept_weighted = -energy_change + min(0.0, -energy_change)
335338
log_size = -energy_change
336339
proposal = Proposal(
337-
right.q, right.q_grad, right.energy, log_p_accept, right.model_logp
340+
right.q,
341+
right.q_grad,
342+
right.energy,
343+
log_p_accept_weighted,
344+
right.model_logp,
338345
)
339346
tree = Subtree(
340-
right, right, right.p, proposal, log_size, log_p_accept, 1
347+
right, right, right.p, proposal, log_size, log_p_accept_weighted, 1
341348
)
342349
return tree, None, False
343350
else:
@@ -377,21 +384,23 @@ def _build_subtree(self, left, depth, epsilon):
377384
turning = turning | turning1 | turning2
378385

379386
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
380-
log_accept_sum = np.logaddexp(tree1.log_accept_sum, tree2.log_accept_sum)
387+
log_weighted_accept_sum = np.logaddexp(
388+
tree1.log_weighted_accept_sum, tree2.log_weighted_accept_sum
389+
)
381390
if logbern(tree2.log_size - log_size):
382391
proposal = tree2.proposal
383392
else:
384393
proposal = tree1.proposal
385394
else:
386395
p_sum = tree1.p_sum
387396
log_size = tree1.log_size
388-
log_accept_sum = tree1.log_accept_sum
397+
log_weighted_accept_sum = tree1.log_weighted_accept_sum
389398
proposal = tree1.proposal
390399

391400
n_proposals = tree1.n_proposals + tree2.n_proposals
392401

393402
tree = Subtree(
394-
left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals
403+
left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals
395404
)
396405
return tree, diverging, turning
397406

@@ -401,7 +410,9 @@ def stats(self):
401410
# Remove contribution from initial state which is always a perfect
402411
# accept
403412
log_sum_weight = logdiffexp_numpy(self.log_size, 0.0)
404-
self.mean_tree_accept = np.exp(self.log_accept_sum - log_sum_weight)
413+
self.mean_tree_accept = np.exp(
414+
self.log_weighted_accept_sum - log_sum_weight
415+
)
405416
return {
406417
"depth": self.depth,
407418
"mean_tree_accept": self.mean_tree_accept,

0 commit comments

Comments
 (0)