@@ -458,9 +458,16 @@ class BinaryGibbsMetropolis(ArrayStep):
458
458
459
459
name = "binary_gibbs_metropolis"
460
460
461
+ stats_dtypes_shapes = {
462
+ "tune" : (bool , []),
463
+ }
464
+
461
465
def __init__ (self , vars , order = "random" , transit_p = 0.8 , model = None ):
462
466
model = pm .modelcontext (model )
463
467
468
+ # Doesn't actually tune, but it's required to emit a sampler stat
469
+ # that indicates whether a draw was done in a tuning phase.
470
+ self .tune = True
464
471
# transition probabilities
465
472
self .transit_p = transit_p
466
473
@@ -483,6 +490,11 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
483
490
484
491
super ().__init__ (vars , [model .compile_logp ()])
485
492
493
+ def reset_tuning (self ):
494
+ # There are no tuning parameters in this step method.
495
+ self .tune = False
496
+ return
497
+
486
498
def astep (self , apoint : RaveledVars , * args ) -> Tuple [RaveledVars , StatsType ]:
487
499
logp : Callable [[RaveledVars ], np .ndarray ] = args [0 ]
488
500
order = self .order
@@ -503,7 +515,10 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
503
515
if accepted :
504
516
logp_curr = logp_prop
505
517
506
- return q , []
518
+ stats = {
519
+ "tune" : self .tune ,
520
+ }
521
+ return q , [stats ]
507
522
508
523
@staticmethod
509
524
def competence (var ):
@@ -543,6 +558,10 @@ class CategoricalGibbsMetropolis(ArrayStep):
543
558
544
559
name = "categorical_gibbs_metropolis"
545
560
561
+ stats_dtypes_shapes = {
562
+ "tune" : (bool , []),
563
+ }
564
+
546
565
def __init__ (self , vars , proposal = "uniform" , order = "random" , model = None ):
547
566
model = pm .modelcontext (model )
548
567
@@ -593,8 +612,17 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
593
612
else :
594
613
raise ValueError ("Argument 'proposal' should either be 'uniform' or 'proportional'" )
595
614
615
+ # Doesn't actually tune, but it's required to emit a sampler stat
616
+ # that indicates whether a draw was done in a tuning phase.
617
+ self .tune = True
618
+
596
619
super ().__init__ (vars , [model .compile_logp ()])
597
620
621
+ def reset_tuning (self ):
622
+ # There are no tuning parameters in this step method.
623
+ self .tune = False
624
+ return
625
+
598
626
def astep_unif (self , apoint : RaveledVars , * args ) -> Tuple [RaveledVars , StatsType ]:
599
627
logp = args [0 ]
600
628
point_map_info = apoint .point_map_info
@@ -614,7 +642,10 @@ def astep_unif(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType
614
642
if accepted :
615
643
logp_curr = logp_prop
616
644
617
- return q , []
645
+ stats = {
646
+ "tune" : self .tune ,
647
+ }
648
+ return q , [stats ]
618
649
619
650
def astep_prop (self , apoint : RaveledVars , * args ) -> Tuple [RaveledVars , StatsType ]:
620
651
logp = args [0 ]
0 commit comments