@@ -206,11 +206,12 @@ def warnings(self):
206
206
207
207
208
208
# A proposal for the next position
209
- Proposal = namedtuple ("Proposal" , "q, q_grad, energy, log_p_accept , logp" )
209
+ Proposal = namedtuple ("Proposal" , "q, q_grad, energy, log_p_accept_weighted , logp" )
210
210
211
211
# A subtree of the binary tree built by nuts.
212
212
Subtree = namedtuple (
213
- "Subtree" , "left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals"
213
+ "Subtree" ,
214
+ "left, right, p_sum, proposal, log_size, log_weighted_accept_sum, n_proposals" ,
214
215
)
215
216
216
217
@@ -243,7 +244,7 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
243
244
)
244
245
self .depth = 0
245
246
self .log_size = 0
246
- self .log_accept_sum = - np .inf
247
+ self .log_weighted_accept_sum = - np .inf
247
248
self .mean_tree_accept = 0.0
248
249
self .n_proposals = 0
249
250
self .p_sum = start .p .copy ()
@@ -291,7 +292,9 @@ def extend(self, direction):
291
292
self .proposal = tree .proposal
292
293
293
294
self .log_size = np .logaddexp (self .log_size , tree .log_size )
294
- self .log_accept_sum = np .logaddexp (self .log_accept_sum , tree .log_accept_sum )
295
+ self .log_weighted_accept_sum = np .logaddexp (
296
+ self .log_weighted_accept_sum , tree .log_weighted_accept_sum
297
+ )
295
298
self .p_sum [:] += tree .p_sum
296
299
297
300
# Additional turning check only when tree depth > 0 to avoid redundant work
@@ -331,13 +334,17 @@ def _single_step(self, left, epsilon):
331
334
# e^{H(q_0, p_0) - H(q_n, p_n)} max(1, e^{H(q_0, p_0) - H(q_n, p_n)})
332
335
# Saturated Metropolis accept probability with Boltzmann weight
333
336
# if h - H0 < 0
334
- log_p_accept = - energy_change + min (0.0 , - energy_change )
337
+ log_p_accept_weighted = - energy_change + min (0.0 , - energy_change )
335
338
log_size = - energy_change
336
339
proposal = Proposal (
337
- right .q , right .q_grad , right .energy , log_p_accept , right .model_logp
340
+ right .q ,
341
+ right .q_grad ,
342
+ right .energy ,
343
+ log_p_accept_weighted ,
344
+ right .model_logp ,
338
345
)
339
346
tree = Subtree (
340
- right , right , right .p , proposal , log_size , log_p_accept , 1
347
+ right , right , right .p , proposal , log_size , log_p_accept_weighted , 1
341
348
)
342
349
return tree , None , False
343
350
else :
@@ -377,21 +384,23 @@ def _build_subtree(self, left, depth, epsilon):
377
384
turning = turning | turning1 | turning2
378
385
379
386
log_size = np .logaddexp (tree1 .log_size , tree2 .log_size )
380
- log_accept_sum = np .logaddexp (tree1 .log_accept_sum , tree2 .log_accept_sum )
387
+ log_weighted_accept_sum = np .logaddexp (
388
+ tree1 .log_weighted_accept_sum , tree2 .log_weighted_accept_sum
389
+ )
381
390
if logbern (tree2 .log_size - log_size ):
382
391
proposal = tree2 .proposal
383
392
else :
384
393
proposal = tree1 .proposal
385
394
else :
386
395
p_sum = tree1 .p_sum
387
396
log_size = tree1 .log_size
388
- log_accept_sum = tree1 .log_accept_sum
397
+ log_weighted_accept_sum = tree1 .log_weighted_accept_sum
389
398
proposal = tree1 .proposal
390
399
391
400
n_proposals = tree1 .n_proposals + tree2 .n_proposals
392
401
393
402
tree = Subtree (
394
- left , right , p_sum , proposal , log_size , log_accept_sum , n_proposals
403
+ left , right , p_sum , proposal , log_size , log_weighted_accept_sum , n_proposals
395
404
)
396
405
return tree , diverging , turning
397
406
@@ -401,7 +410,9 @@ def stats(self):
401
410
# Remove contribution from initial state which is always a perfect
402
411
# accept
403
412
log_sum_weight = logdiffexp_numpy (self .log_size , 0.0 )
404
- self .mean_tree_accept = np .exp (self .log_accept_sum - log_sum_weight )
413
+ self .mean_tree_accept = np .exp (
414
+ self .log_weighted_accept_sum - log_sum_weight
415
+ )
405
416
return {
406
417
"depth" : self .depth ,
407
418
"mean_tree_accept" : self .mean_tree_accept ,
0 commit comments