Skip to content

Commit 38e87e2

Browse files
Emit "tune" from all stepmethods
1 parent 63483b7 commit 38e87e2

File tree

3 files changed

+54
-2
lines changed

3 files changed

+54
-2
lines changed

pymc/step_methods/compound.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,16 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
262262
return steps
263263

264264

265+
def check_step_emits_tune(step: Union[CompoundStep, BlockedStep]):
266+
if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes:
267+
raise TypeError(f"{type(step)} does not emit the required 'tune' stat.")
268+
elif isinstance(step, CompoundStep):
269+
for sstep in step.methods:
270+
if "tune" not in sstep.stats_dtypes_shapes:
271+
raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.")
272+
return
273+
274+
265275
class StatsBijection:
266276
"""Map between a `list` of stats to `dict` of stats."""
267277

pymc/step_methods/metropolis.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,9 +458,16 @@ class BinaryGibbsMetropolis(ArrayStep):
458458

459459
name = "binary_gibbs_metropolis"
460460

461+
stats_dtypes_shapes = {
462+
"tune": (bool, []),
463+
}
464+
461465
def __init__(self, vars, order="random", transit_p=0.8, model=None):
462466
model = pm.modelcontext(model)
463467

468+
# Doesn't actually tune, but it's required to emit a sampler stat
469+
# that indicates whether a draw was done in a tuning phase.
470+
self.tune = True
464471
# transition probabilities
465472
self.transit_p = transit_p
466473

@@ -483,6 +490,11 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
483490

484491
super().__init__(vars, [model.compile_logp()])
485492

493+
def reset_tuning(self):
494+
# There are no tuning parameters in this step method.
495+
self.tune = False
496+
return
497+
486498
def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
487499
logp: Callable[[RaveledVars], np.ndarray] = args[0]
488500
order = self.order
@@ -503,7 +515,10 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
503515
if accepted:
504516
logp_curr = logp_prop
505517

506-
return q, []
518+
stats = {
519+
"tune": self.tune,
520+
}
521+
return q, [stats]
507522

508523
@staticmethod
509524
def competence(var):
@@ -543,6 +558,10 @@ class CategoricalGibbsMetropolis(ArrayStep):
543558

544559
name = "categorical_gibbs_metropolis"
545560

561+
stats_dtypes_shapes = {
562+
"tune": (bool, []),
563+
}
564+
546565
def __init__(self, vars, proposal="uniform", order="random", model=None):
547566
model = pm.modelcontext(model)
548567

@@ -593,8 +612,17 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
593612
else:
594613
raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
595614

615+
# Doesn't actually tune, but it's required to emit a sampler stat
616+
# that indicates whether a draw was done in a tuning phase.
617+
self.tune = True
618+
596619
super().__init__(vars, [model.compile_logp()])
597620

621+
def reset_tuning(self):
622+
# There are no tuning parameters in this step method.
623+
self.tune = False
624+
return
625+
598626
def astep_unif(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
599627
logp = args[0]
600628
point_map_info = apoint.point_map_info
@@ -614,7 +642,10 @@ def astep_unif(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType
614642
if accepted:
615643
logp_curr = logp_prop
616644

617-
return q, []
645+
stats = {
646+
"tune": self.tune,
647+
}
648+
return q, [stats]
618649

619650
def astep_prop(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
620651
logp = args[0]

tests/step_methods/test_compound.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Slice,
2727
)
2828
from pymc.step_methods.compound import (
29+
BlockedStep,
2930
StatsBijection,
3031
flatten_steps,
3132
get_stats_dtypes_shapes_from_steps,
@@ -36,6 +37,16 @@
3637
from tests.models import simple_2model_continuous
3738

3839

40+
def test_all_stepmethods_emit_tune_stat():
41+
attrs = [getattr(pm.step_methods, n) for n in dir(pm.step_methods)]
42+
step_types = [
43+
attr for attr in attrs if isinstance(attr, type) and issubclass(attr, BlockedStep)
44+
]
45+
assert len(step_types) > 5
46+
for cls in step_types:
47+
assert "tune" in cls.stats_dtypes_shapes
48+
49+
3950
class TestCompoundStep:
4051
samplers = (Metropolis, Slice, HamiltonianMC, NUTS, DEMetropolis)
4152

0 commit comments

Comments
 (0)