Skip to content

Commit 5d2767d

Browse files
aseyboldt and twiecki
authored and committed
ENH Save step size et al for hmc and nuts (pymc-devs#1687)
* Save step size et al for hmc and nuts * Fix tests for sampler stats * Improve notebook about sampler stats * Fix tests for sampler stats * Rename backend.get_stats * Fix nuts failure after rebase * Update sampler-stats notebook
1 parent 45c11f0 commit 5d2767d

File tree

13 files changed

+925
-57
lines changed

13 files changed

+925
-57
lines changed

docs/source/notebooks/sampler-stats.ipynb

Lines changed: 373 additions & 0 deletions
Large diffs are not rendered by default.

pymc3/backends/base.py

Lines changed: 166 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import numpy as np
77
from ..model import modelcontext
8+
import warnings
89

910

1011
class BackendError(Exception):
@@ -25,6 +26,8 @@ class BaseTrace(object):
2526
`model.unobserved_RVs` is used.
2627
"""
2728

29+
supports_sampler_stats = False
30+
2831
def __init__(self, name, model=None, vars=None):
2932
self.name = name
3033

@@ -44,10 +47,33 @@ def __init__(self, name, model=None, vars=None):
4447
self.var_dtypes = {var: value.dtype
4548
for var, value in var_values}
4649
self.chain = None
50+
self._is_base_setup = False
51+
self.sampler_vars = None
4752

4853
# Sampling methods
4954

50-
def setup(self, draws, chain):
55+
def _set_sampler_vars(self, sampler_vars):
56+
if sampler_vars is not None and not self.supports_sampler_stats:
57+
raise ValueError("Backend does not support sampler stats.")
58+
59+
if self._is_base_setup and self.sampler_vars != sampler_vars:
60+
raise ValueError("Can't change sampler_vars")
61+
62+
if sampler_vars is None:
63+
self.sampler_vars = None
64+
return
65+
66+
dtypes = {}
67+
for stats in sampler_vars:
68+
for key, dtype in stats.items():
69+
if dtypes.setdefault(key, dtype) != dtype:
70+
raise ValueError("Sampler statistic %s appears with "
71+
"different types." % key)
72+
73+
self.sampler_vars = sampler_vars
74+
75+
76+
def setup(self, draws, chain, sampler_vars=None):
5177
"""Perform chain-specific setup.
5278
5379
Parameters
@@ -56,16 +82,23 @@ def setup(self, draws, chain):
5682
Expected number of draws
5783
chain : int
5884
Chain number
85+
sampler_vars : list of dictionaries (name -> dtype), optional
86+
Diagnostics / statistics for each sampler. Before passing this
87+
to a backend, you should check that the `supports_sampler_stats`
88+
flag is set.
5989
"""
60-
pass
90+
self._set_sampler_vars(sampler_vars)
91+
self._is_base_setup = True
6192

62-
def record(self, point):
93+
def record(self, point, sampler_states=None):
6394
"""Record results of a sampling iteration.
6495
6596
Parameters
6697
----------
6798
point : dict
6899
Values mapped to variable names
100+
sampler_states : list of dicts
101+
The diagnostic values for each sampler
69102
"""
70103
raise NotImplementedError
71104

@@ -105,6 +138,47 @@ def get_values(self, varname, burn=0, thin=1):
105138
"""
106139
raise NotImplementedError
107140

141+
def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
142+
"""Get sampler statistics from the trace.
143+
144+
Parameters
145+
----------
146+
varname : str
147+
sampler_idx : int or None
148+
burn : int
149+
thin : int
150+
151+
Returns
152+
-------
153+
If the `sampler_idx` is specified, return the statistic with
154+
the given name in a numpy array. If it is not specified and there
155+
is more than one sampler that provides this statistic, return
156+
a numpy array of shape (m, n), where `m` is the number of
157+
such samplers, and `n` is the number of samples.
158+
"""
159+
if not self.supports_sampler_stats:
160+
raise ValueError("This backend does not support sampler stats")
161+
162+
if sampler_idx is not None:
163+
return self._get_sampler_stats(varname, sampler_idx, burn, thin)
164+
165+
sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
166+
if varname in s]
167+
if not sampler_idxs:
168+
raise KeyError("Unknown sampler stat %s" % varname)
169+
170+
vals = np.stack([self._get_sampler_stats(varname, i, burn, thin)
171+
for i in sampler_idxs], axis=-1)
172+
if vals.shape[-1] == 1:
173+
return vals[..., 0]
174+
else:
175+
return vals
176+
177+
178+
def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
179+
"""Get sampler statistics."""
180+
raise NotImplementedError()
181+
108182
def _slice(self, idx):
109183
"""Slice trace object."""
110184
raise NotImplementedError
@@ -115,13 +189,24 @@ def point(self, idx):
115189
"""
116190
raise NotImplementedError
117191

192+
@property
193+
def stat_names(self):
194+
if self.supports_sampler_stats:
195+
names = set()
196+
for vars in self.sampler_vars or []:
197+
names.update(vars.keys())
198+
return names
199+
else:
200+
return set()
201+
118202

119203
class MultiTrace(object):
120204
"""Main interface for accessing values from MCMC results
121205
122-
The core method to select values is `get_values`. Values can also be
123-
accessed by indexing the MultiTrace object. Indexing can behave in
124-
three ways:
206+
The core method to select values is `get_values`. The method
207+
to select sampler statistics is `get_sampler_stats`. Both kinds of
208+
values can also be accessed by indexing the MultiTrace object.
209+
Indexing can behave in four ways:
125210
126211
1. Indexing with a variable or variable name (str) returns all
127212
values for that variable, combining values for all chains.
@@ -134,7 +219,7 @@ class MultiTrace(object):
134219
>>> trace[varname, 1000:]
135220
136221
For convenience during interactive use, values can also be
137-
accessed using the variable an attribute.
222+
accessed using the variable as an attribute.
138223
139224
>>> trace.varname
140225
@@ -145,6 +230,11 @@ class MultiTrace(object):
145230
3. Slicing with a range returns a new trace with the number of draws
146231
corresponding to the range.
147232
233+
4. Indexing with the name of a sampler statistic that is not also
234+
the name of a variable returns those values from all chains.
235+
If there is more than one sampler that provides that statistic,
236+
the values are concatenated along a new axis.
237+
148238
For any methods that require a single trace (e.g., taking the length
149239
of the MultiTrace instance, which returns the number of draws), the
150240
trace with the highest chain number is always used.
@@ -189,18 +279,36 @@ def __getitem__(self, idx):
189279
else:
190280
var = idx
191281
burn, thin = 0, 1
192-
return self.get_values(var, burn=burn, thin=thin)
193282

194-
_attrs = set(['_straces', 'varnames', 'chains'])
283+
var = str(var)
284+
if var in self.varnames:
285+
if var in self.stat_names:
286+
warnings.warn("Attribute access on a trace object is ambigous. "
287+
"Sampler statistic and model variable share a name. Use "
288+
"trace.get_values or trace.get_sampler_stats.")
289+
return self.get_values(var, burn=burn, thin=thin)
290+
if var in self.stat_names:
291+
return self.get_sampler_stats(var, burn=burn, thin=thin)
292+
raise KeyError("Unknown variable %s" % var)
293+
294+
_attrs = set(['_straces', 'varnames', 'chains', 'stat_names',
295+
'supports_sampler_stats'])
195296

196297
def __getattr__(self, name):
    """Resolve unknown attributes as variable or sampler-statistic names.

    A name found in `varnames` is looked up with `get_values`; otherwise
    a name found in `stat_names` is looked up with `get_sampler_stats`.
    If the name is in both, the variable wins and a warning is issued.

    Raises
    ------
    AttributeError
        If `name` is one of the names in `_attrs`, or is neither a
        variable name nor a sampler statistic name.
    """
    # Avoid infinite recursion when called before __init__
    # variables are set up (e.g., when pickling).
    if name in self._attrs:
        raise AttributeError(name)

    name = str(name)
    if name in self.varnames:
        if name in self.stat_names:
            # stacklevel=2 attributes the warning to the code performing
            # the attribute access rather than to this method.
            warnings.warn("Attribute access on a trace object is ambigous. "
                          "Sampler statistic and model variable share a name. Use "
                          "trace.get_values or trace.get_sampler_stats.",
                          stacklevel=2)
        return self.get_values(name)
    if name in self.stat_names:
        return self.get_sampler_stats(name)
    raise AttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))
206314

@@ -213,6 +321,21 @@ def varnames(self):
213321
chain = self.chains[-1]
214322
return self._straces[chain].varnames
215323

324+
@property
325+
def stat_names(self):
326+
if not self._straces:
327+
return set()
328+
sampler_vars = [s.sampler_vars for s in self._straces.values()]
329+
if not all(svars == sampler_vars[0] for svars in sampler_vars):
330+
raise ValueError("Inividual chains contain different sampler stats")
331+
names = set()
332+
for trace in self._straces.values():
333+
if trace.sampler_vars is None:
334+
continue
335+
for vars in trace.sampler_vars:
336+
names.update(vars.keys())
337+
return names
338+
216339
def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
217340
squeeze=True):
218341
"""Get values from traces.
@@ -247,6 +370,39 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
247370
results = [self._straces[chains].get_values(varname, burn, thin)]
248371
return _squeeze_cat(results, combine, squeeze)
249372

373+
def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
374+
chains=None, squeeze=True):
375+
"""Get sampler statistics from the trace.
376+
377+
Parameters
378+
----------
379+
varname : str
380+
sampler_idx : int or None
381+
burn : int
382+
thin : int
383+
384+
Returns
385+
-------
386+
If the `sampler_idx` is specified, return the statistic with
387+
the given name in a numpy array. If it is not specified and there
388+
is more than one sampler that provides this statistic, return
389+
a numpy array of shape (m, n), where `m` is the number of
390+
such samplers, and `n` is the number of samples.
391+
"""
392+
if varname not in self.stat_names:
393+
raise KeyError("Unknown sampler statistic %s" % varname)
394+
395+
if chains is None:
396+
chains = self.chains
397+
try:
398+
chains = iter(chains)
399+
except TypeError:
400+
chains = [chains]
401+
402+
results = [self._straces[chain].get_sampler_stats(varname, None, burn, thin)
403+
for chain in chains]
404+
return _squeeze_cat(results, combine, squeeze)
405+
250406
def _slice(self, idx):
251407
"""Return a new MultiTrace object sliced according to `idx`."""
252408
new_traces = [trace._slice(idx) for trace in self._straces.values()]

0 commit comments

Comments
 (0)