5
5
"""
6
6
import numpy as np
7
7
from ..model import modelcontext
8
+ import warnings
8
9
9
10
10
11
class BackendError (Exception ):
@@ -25,6 +26,8 @@ class BaseTrace(object):
25
26
`model.unobserved_RVs` is used.
26
27
"""
27
28
29
+ supports_sampler_stats = False
30
+
28
31
def __init__ (self , name , model = None , vars = None ):
29
32
self .name = name
30
33
@@ -44,10 +47,33 @@ def __init__(self, name, model=None, vars=None):
44
47
self .var_dtypes = {var : value .dtype
45
48
for var , value in var_values }
46
49
self .chain = None
50
+ self ._is_base_setup = False
51
+ self .sampler_vars = None
47
52
48
53
# Sampling methods
49
54
50
- def setup (self , draws , chain ):
55
def _set_sampler_vars(self, sampler_vars):
    """Validate the sampler statistic declarations and store them.

    Parameters
    ----------
    sampler_vars : list of dictionaries (name -> dtype) or None
        One dictionary per sampler. `None` stores `None`.

    Raises
    ------
    ValueError
        If declarations are passed but `supports_sampler_stats` is
        false, if the base setup already ran and the new value differs
        from the stored one, or if one statistic name is declared with
        two different dtypes.
    """
    if not (sampler_vars is None or self.supports_sampler_stats):
        raise ValueError("Backend does not support sampler stats.")

    # Once the base setup has run the declarations are frozen.
    if self._is_base_setup:
        if self.sampler_vars != sampler_vars:
            raise ValueError("Can't change sampler_vars")

    if sampler_vars is not None:
        # Remember the first dtype seen for every statistic name and
        # compare every later declaration of that name against it.
        first_seen = {}
        for declared in sampler_vars:
            for stat_name, stat_dtype in declared.items():
                known_dtype = first_seen.setdefault(stat_name, stat_dtype)
                if known_dtype != stat_dtype:
                    raise ValueError("Sampler statistic %s appears with "
                                     "different types." % stat_name)

    self.sampler_vars = sampler_vars
74
+
75
+
76
+ def setup (self , draws , chain , sampler_vars = None ):
51
77
"""Perform chain-specific setup.
52
78
53
79
Parameters
@@ -56,16 +82,23 @@ def setup(self, draws, chain):
56
82
Expected number of draws
57
83
chain : int
58
84
Chain number
85
+ sampler_vars : list of dictionaries (name -> dtype), optional
86
+ Diagnostics / statistics for each sampler. Before passing this
87
+ to a backend, you should check that the `supports_sampler_stats`
88
+ flag is set.
59
89
"""
60
- pass
90
+ self ._set_sampler_vars (sampler_vars )
91
+ self ._is_base_setup = True
61
92
62
- def record (self , point ):
93
def record(self, point, sampler_states=None):
    """Store the outcome of one sampling iteration.

    This base implementation only raises; it does no recording itself.

    Parameters
    ----------
    point : dict
        Values mapped to variable names
    sampler_states : list of dicts
        The diagnostic values for each sampler

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError()
71
104
@@ -105,6 +138,47 @@ def get_values(self, varname, burn=0, thin=1):
105
138
"""
106
139
raise NotImplementedError
107
140
141
def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
    """Get sampler statistics from the trace.

    Parameters
    ----------
    varname : str
        Name of the sampler statistic.
    sampler_idx : int or None
        Index of the sampler whose statistic is requested. If `None`,
        every sampler that declares `varname` is used.
    burn : int
    thin : int
        Both are forwarded unchanged to `_get_sampler_stats`.

    Returns
    -------
    If the `sampler_idx` is specified, return whatever
    `_get_sampler_stats` returns for that sampler. If it is not
    specified, the values of all samplers that provide the statistic
    are stacked along a new last axis: for `n` scalar samples and `m`
    such samplers the result has shape (n, m). If only one sampler
    provides the statistic, that last axis is dropped.

    Raises
    ------
    ValueError
        If the backend does not support sampler stats.
    KeyError
        If no sampler declares a statistic called `varname`.
    """
    if not self.supports_sampler_stats:
        raise ValueError("This backend does not support sampler stats")

    if sampler_idx is not None:
        return self._get_sampler_stats(varname, sampler_idx, burn, thin)

    # `sampler_vars` is None when no declarations were stored; treat
    # that as "no samplers" so an unknown name raises KeyError below
    # instead of a TypeError from iterating None.
    declared = self.sampler_vars or []
    sampler_idxs = [i for i, stats in enumerate(declared)
                    if varname in stats]
    if not sampler_idxs:
        raise KeyError("Unknown sampler stat %s" % varname)

    vals = np.stack([self._get_sampler_stats(varname, i, burn, thin)
                     for i in sampler_idxs], axis=-1)
    # A single provider does not need the extra sampler axis.
    if vals.shape[-1] == 1:
        return vals[..., 0]
    return vals
176
+
177
+
178
def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
    """Return the statistic `varname` of the sampler `sampler_idx`.

    This base implementation only raises `NotImplementedError`.
    """
    raise NotImplementedError
181
+
108
182
def _slice(self, idx):
    """Slice the trace object according to `idx`.

    This base implementation only raises `NotImplementedError`.
    """
    raise NotImplementedError()
@@ -115,13 +189,24 @@ def point(self, idx):
115
189
"""
116
190
raise NotImplementedError
117
191
192
@property
def stat_names(self):
    """Set of all sampler statistic names declared for this trace.

    Returns an empty set when the backend does not support sampler
    stats or when no declarations are stored.
    """
    if not self.supports_sampler_stats:
        return set()

    collected = set()
    for declared in self.sampler_vars or []:
        for stat_name in declared.keys():
            collected.add(stat_name)
    return collected
201
+
118
202
119
203
class MultiTrace (object ):
120
204
"""Main interface for accessing values from MCMC results
121
205
122
- The core method to select values is `get_values`. Values can also be
123
- accessed by indexing the MultiTrace object. Indexing can behave in
124
- three ways:
206
+ The core method to select values is `get_values`. The method
207
+ to select sampler statistics is `get_sampler_stats`. Both kinds of
208
+ values can also be accessed by indexing the MultiTrace object.
209
+ Indexing can behave in four ways:
125
210
126
211
1. Indexing with a variable or variable name (str) returns all
127
212
values for that variable, combining values for all chains.
@@ -134,7 +219,7 @@ class MultiTrace(object):
134
219
>>> trace[varname, 1000:]
135
220
136
221
For convenience during interactive use, values can also be
137
- accessed using the variable an attribute.
222
+ accessed using the variable as an attribute.
138
223
139
224
>>> trace.varname
140
225
@@ -145,6 +230,11 @@ class MultiTrace(object):
145
230
3. Slicing with a range returns a new trace with the number of draws
146
231
corresponding to the range.
147
232
233
+ 4. Indexing with the name of a sampler statistic that is not also
234
+ the name of a variable returns those values from all chains.
235
+ If there is more than one sampler that provides that statistic,
236
+ the values are concatenated along a new axis.
237
+
148
238
For any methods that require a single trace (e.g., taking the length
149
239
of the MultiTrace instance, which returns the number of draws), the
150
240
trace with the highest chain number is always used.
@@ -189,18 +279,36 @@ def __getitem__(self, idx):
189
279
else :
190
280
var = idx
191
281
burn , thin = 0 , 1
192
- return self .get_values (var , burn = burn , thin = thin )
193
282
194
- _attrs = set (['_straces' , 'varnames' , 'chains' ])
283
+ var = str (var )
284
+ if var in self .varnames :
285
+ if var in self .stat_names :
286
+ warnings .warn ("Attribute access on a trace object is ambigous. "
287
+ "Sampler statistic and model variable share a name. Use "
288
+ "trace.get_values or trace.get_sampler_stats." )
289
+ return self .get_values (var , burn = burn , thin = thin )
290
+ if var in self .stat_names :
291
+ return self .get_sampler_stats (var , burn = burn , thin = thin )
292
+ raise KeyError ("Unknown variable %s" % var )
293
+
294
# Names for which `__getattr__` raises AttributeError immediately.
_attrs = {
    '_straces',
    'varnames',
    'chains',
    'stat_names',
    'supports_sampler_stats',
}
195
296
196
297
def __getattr__(self, name):
    """Expose model variables and sampler statistics as attributes.

    A name found in `varnames` is answered with `get_values(name)`; a
    warning is issued first if the same name is also a sampler
    statistic. A name found only in `stat_names` is answered with
    `get_sampler_stats(name)`. Anything else raises AttributeError.
    """
    # Avoid infinite recursion when called before __init__
    # variables are set up (e.g., when pickling).
    if name in self._attrs:
        raise AttributeError

    name = str(name)
    is_variable = name in self.varnames
    is_statistic = name in self.stat_names

    if is_variable:
        if is_statistic:
            warnings.warn("Attribute access on a trace object is ambigous. "
                          "Sampler statistic and model variable share a name. Use "
                          "trace.get_values or trace.get_sampler_stats.")
        return self.get_values(name)

    if is_statistic:
        return self.get_sampler_stats(name)

    message = "'{}' object has no attribute '{}'".format(
        type(self).__name__, name)
    raise AttributeError(message)
206
314
@@ -213,6 +321,21 @@ def varnames(self):
213
321
chain = self .chains [- 1 ]
214
322
return self ._straces [chain ].varnames
215
323
324
@property
def stat_names(self):
    """Set of sampler statistic names shared by all chains.

    Returns
    -------
    set
        Names declared in the chains' `sampler_vars`. Empty if there
        are no chains or if the chains store no declarations.

    Raises
    ------
    ValueError
        If the chains do not all have equal `sampler_vars`.
    """
    if not self._straces:
        return set()

    per_chain = [trace.sampler_vars for trace in self._straces.values()]
    reference = per_chain[0]
    if not all(declared == reference for declared in per_chain):
        raise ValueError("Individual chains contain different sampler stats")

    # Every chain holds the same declarations at this point, so the
    # first chain is representative. `None` means nothing was declared.
    names = set()
    for stats in reference or []:
        names.update(stats.keys())
    return names
338
+
216
339
def get_values (self , varname , burn = 0 , thin = 1 , combine = True , chains = None ,
217
340
squeeze = True ):
218
341
"""Get values from traces.
@@ -247,6 +370,39 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
247
370
results = [self ._straces [chains ].get_values (varname , burn , thin )]
248
371
return _squeeze_cat (results , combine , squeeze )
249
372
373
def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
                      chains=None, squeeze=True):
    """Get sampler statistics from the traces of one or more chains.

    Parameters
    ----------
    varname : str
        Name of the sampler statistic.
    burn : int
    thin : int
        Both are forwarded unchanged to each chain's
        `get_sampler_stats`.
    combine : bool
        Forwarded to `_squeeze_cat`.
        NOTE(review): presumably controls whether the per-chain
        results are concatenated; confirm against `_squeeze_cat`.
    chains : int, iterable of ints, or None
        Chain number(s) to read from. If `None`, `self.chains` is
        used. A single non-iterable value is treated as one chain.
    squeeze : bool
        Forwarded to `_squeeze_cat`.
        NOTE(review): presumably controls whether a single result is
        unwrapped from its list; confirm against `_squeeze_cat`.

    Returns
    -------
    The result of `_squeeze_cat` applied to the list holding one
    entry per requested chain. Each entry comes from that chain's
    `get_sampler_stats` called without a sampler index.

    Raises
    ------
    KeyError
        If `varname` is not in `self.stat_names`.
    """
    if varname not in self.stat_names:
        raise KeyError("Unknown sampler statistic %s" % varname)

    if chains is None:
        chains = self.chains
    try:
        chains = iter(chains)
    except TypeError:
        # A single chain number was given.
        chains = [chains]

    results = [self._straces[chain].get_sampler_stats(varname, None, burn, thin)
               for chain in chains]
    return _squeeze_cat(results, combine, squeeze)
405
+
250
406
def _slice (self , idx ):
251
407
"""Return a new MultiTrace object sliced according to `idx`."""
252
408
new_traces = [trace ._slice (idx ) for trace in self ._straces .values ()]
0 commit comments