Skip to content

Commit c73f97d

Browse files
lucianopaz authored and twiecki committed
Various fixes to random sampling
* Partial fix. Got stumped with LKJCholeskyCov not having a random method
* Fixed the basic defects. Still must do new tests.
* Added test for sampling prior and posterior predictives from a mixture based on issue #3270
* Fixed bugs and failed tests. Still must write a test for LKJCholeskyCov.random method.
* Fixed failed test
* Added tests for DrawValuesContext and also for the context blocker that was introduced in this PR.
* Changed six metaclass to metaclass keyword
1 parent ee331c0 commit c73f97d

File tree

7 files changed

+414
-73
lines changed

7 files changed

+414
-73
lines changed

pymc3/distributions/distribution.py

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __new__(cls, *args, **kwargs):
221221
potential_parent = cls.get_contexts()[-1]
222222
# We have to make sure that the context is a _DrawValuesContext
223223
# and not a Model
224-
if isinstance(potential_parent, cls):
224+
if isinstance(potential_parent, _DrawValuesContext):
225225
instance._parent = potential_parent
226226
else:
227227
instance._parent = None
@@ -235,7 +235,8 @@ def __init__(self):
235235
# another _DrawValuesContext will share the reference to the
236236
# drawn_vars dictionary. This means that separate branches
237237
# in the nested _DrawValuesContext context tree will see the
238-
# same drawn values
238+
# same drawn values.
239+
# The drawn_vars keys shall be (RV, size) tuples
239240
self.drawn_vars = self.parent.drawn_vars
240241
else:
241242
self.drawn_vars = dict()
@@ -245,6 +246,22 @@ def parent(self):
245246
return self._parent
246247

247248

249+
class _DrawValuesContextBlocker(_DrawValuesContext, metaclass=InitContextMeta):
250+
"""
251+
Context manager that starts a new drawn variables context disregarding all
252+
parent contexts. This can be used inside a random method to ensure that
253+
the drawn values wont be the ones cached by previous calls
254+
"""
255+
def __new__(cls, *args, **kwargs):
256+
# resolves the parent instance
257+
instance = super(_DrawValuesContextBlocker, cls).__new__(cls)
258+
instance._parent = None
259+
return instance
260+
261+
def __init__(self):
262+
self.drawn_vars = dict()
263+
264+
248265
def is_fast_drawable(var):
249266
return isinstance(var, (numbers.Number,
250267
np.ndarray,
@@ -288,14 +305,14 @@ def draw_values(params, point=None, size=None):
288305
continue
289306

290307
name = getattr(p, 'name', None)
291-
if p in drawn:
308+
if (p, size) in drawn:
292309
# param was drawn in related contexts
293-
v = drawn[p]
310+
v = drawn[(p, size)]
294311
evaluated[i] = v
295312
elif name is not None and name in point:
296313
# param.name is in point
297314
v = point[name]
298-
evaluated[i] = drawn[p] = v
315+
evaluated[i] = drawn[(p, size)] = v
299316
else:
300317
# param still needs to be drawn
301318
symbolic_params.append((i, p))
@@ -330,12 +347,12 @@ def draw_values(params, point=None, size=None):
330347
named_nodes_children[k].update(nnc[k])
331348

332349
# Init givens and the stack of nodes to try to `_draw_value` from
333-
givens = {p.name: (p, v) for p, v in drawn.items()
350+
givens = {p.name: (p, v) for (p, size), v in drawn.items()
334351
if getattr(p, 'name', None) is not None}
335352
stack = list(leaf_nodes.values()) # A queue would be more appropriate
336353
while stack:
337354
next_ = stack.pop(0)
338-
if next_ in drawn:
355+
if (next_, size) in drawn:
339356
# If the node already has a givens value, skip it
340357
continue
341358
elif isinstance(next_, (tt.TensorConstant,
@@ -364,14 +381,14 @@ def draw_values(params, point=None, size=None):
364381
givens=temp_givens,
365382
size=size)
366383
givens[next_.name] = (next_, value)
367-
drawn[next_] = value
384+
drawn[(next_, size)] = value
368385
except theano.gof.fg.MissingInputError:
369386
# The node failed, so we must add the node's parents to
370387
# the stack of nodes to try to draw from. We exclude the
371388
# nodes in the `params` list.
372389
stack.extend([node for node in named_nodes_parents[next_]
373390
if node is not None and
374-
node.name not in drawn and
391+
(node, size) not in drawn and
375392
node not in params])
376393

377394
# the below makes sure the graph is evaluated in order
@@ -386,15 +403,15 @@ def draw_values(params, point=None, size=None):
386403
missing_inputs = set()
387404
for param_idx in to_eval:
388405
param = params[param_idx]
389-
if param in drawn:
390-
evaluated[param_idx] = drawn[param]
406+
if (param, size) in drawn:
407+
evaluated[param_idx] = drawn[(param, size)]
391408
else:
392409
try: # might evaluate in a bad order,
393410
value = _draw_value(param,
394411
point=point,
395412
givens=givens.values(),
396413
size=size)
397-
evaluated[param_idx] = drawn[param] = value
414+
evaluated[param_idx] = drawn[(param, size)] = value
398415
givens[param.name] = (param, value)
399416
except theano.gof.fg.MissingInputError:
400417
missing_inputs.add(param_idx)
@@ -475,8 +492,11 @@ def _draw_value(param, point=None, givens=None, size=None):
475492
# reset shape to account for shape changes
476493
# with theano.shared inputs
477494
dist_tmp.shape = np.array([])
478-
val = np.atleast_1d(dist_tmp.random(point=point,
479-
size=None))
495+
# We want to draw values to infer the dist_shape,
496+
# we don't want to store these drawn values to the context
497+
with _DrawValuesContextBlocker():
498+
val = np.atleast_1d(dist_tmp.random(point=point,
499+
size=None))
480500
# Sometimes point may change the size of val but not the
481501
# distribution's shape
482502
if point and size is not None:
@@ -493,20 +513,34 @@ def _draw_value(param, point=None, givens=None, size=None):
493513
variables, values = list(zip(*givens))
494514
else:
495515
variables = values = []
496-
func = _compile_theano_function(param, variables)
516+
# We only truly care if the ancestors of param that were given
517+
# value have the matching dshape and val.shape
518+
param_ancestors = \
519+
set(theano.gof.graph.ancestors([param],
520+
blockers=list(variables))
521+
)
522+
inputs = [(var, val) for var, val in
523+
zip(variables, values)
524+
if var in param_ancestors]
525+
if inputs:
526+
input_vars, input_vals = list(zip(*inputs))
527+
else:
528+
input_vars = []
529+
input_vals = []
530+
func = _compile_theano_function(param, input_vars)
497531
if size is not None:
498532
size = np.atleast_1d(size)
499533
dshaped_variables = all((hasattr(var, 'dshape')
500-
for var in variables))
534+
for var in input_vars))
501535
if (values and dshaped_variables and
502536
not all(var.dshape == getattr(val, 'shape', tuple())
503-
for var, val in zip(variables, values))):
504-
output = np.array([func(*v) for v in zip(*values)])
537+
for var, val in zip(input_vars, input_vals))):
538+
output = np.array([func(*v) for v in zip(*input_vals)])
505539
elif (size is not None and any((val.ndim > var.ndim)
506-
for var, val in zip(variables, values))):
507-
output = np.array([func(*v) for v in zip(*values)])
540+
for var, val in zip(input_vars, input_vals))):
541+
output = np.array([func(*v) for v in zip(*input_vals)])
508542
else:
509-
output = func(*values)
543+
output = func(*input_vals)
510544
return output
511545
raise ValueError('Unexpected type in draw_value: %s' % type(param))
512546

@@ -515,7 +549,11 @@ def to_tuple(shape):
515549
"""Convert ints, arrays, and Nones to tuples"""
516550
if shape is None:
517551
return tuple()
518-
return tuple(np.atleast_1d(shape))
552+
temp = np.atleast_1d(shape)
553+
if temp.size == 0:
554+
return tuple()
555+
else:
556+
return tuple(temp)
519557

520558
def _is_one_d(dist_shape):
521559
if hasattr(dist_shape, 'dshape') and dist_shape.dshape in ((), (0,), (1,)):

pymc3/distributions/mixture.py

Lines changed: 142 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from ..math import logsumexp
66
from .dist_math import bound, random_choice
77
from .distribution import (Discrete, Distribution, draw_values,
8-
generate_samples, _DrawValuesContext)
8+
generate_samples, _DrawValuesContext,
9+
_DrawValuesContextBlocker, to_tuple)
910
from .continuous import get_tau_sigma, Normal
1011

1112

@@ -102,6 +103,35 @@ def __init__(self, w, comp_dists, *args, **kwargs):
102103

103104
super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
104105

106+
@property
107+
def comp_dists(self):
108+
return self._comp_dists
109+
110+
@comp_dists.setter
111+
def comp_dists(self, _comp_dists):
112+
self._comp_dists = _comp_dists
113+
# Tests if the comp_dists can call random with non None size
114+
with _DrawValuesContextBlocker():
115+
if isinstance(self.comp_dists, (list, tuple)):
116+
try:
117+
[comp_dist.random(size=23)
118+
for comp_dist in self.comp_dists]
119+
self._comp_dists_vect = True
120+
except Exception:
121+
# The comp_dists cannot call random with non None size or
122+
# without knowledge of the point so we assume that we will
123+
# have to iterate calls to random to get the correct size
124+
self._comp_dists_vect = False
125+
else:
126+
try:
127+
self.comp_dists.random(size=23)
128+
self._comp_dists_vect = True
129+
except Exception:
130+
# The comp_dists cannot call random with non None size or
131+
# without knowledge of the point so we assume that we will
132+
# have to iterate calls to random to get the correct size
133+
self._comp_dists_vect = False
134+
105135
def _comp_logp(self, value):
106136
comp_dists = self.comp_dists
107137

@@ -131,13 +161,33 @@ def _comp_modes(self):
131161
axis=1))
132162

133163
def _comp_samples(self, point=None, size=None):
134-
try:
135-
samples = self.comp_dists.random(point=point, size=size)
136-
except AttributeError:
137-
samples = np.column_stack([comp_dist.random(point=point, size=size)
138-
for comp_dist in self.comp_dists])
139-
140-
return np.squeeze(samples)
164+
if self._comp_dists_vect or size is None:
165+
try:
166+
return self.comp_dists.random(point=point, size=size)
167+
except AttributeError:
168+
samples = np.array([comp_dist.random(point=point, size=size)
169+
for comp_dist in self.comp_dists])
170+
samples = np.moveaxis(samples, 0, samples.ndim - 1)
171+
else:
172+
# We must iterate the calls to random manually
173+
size = to_tuple(size)
174+
_size = int(np.prod(size))
175+
try:
176+
samples = np.array([self.comp_dists.random(point=point,
177+
size=None)
178+
for _ in range(_size)])
179+
samples = np.reshape(samples, size + samples.shape[1:])
180+
except AttributeError:
181+
samples = np.array([[comp_dist.random(point=point, size=None)
182+
for _ in range(_size)]
183+
for comp_dist in self.comp_dists])
184+
samples = np.moveaxis(samples, 0, samples.ndim - 1)
185+
samples = np.reshape(samples, size + samples[1:])
186+
187+
if samples.shape[-1] == 1:
188+
return samples[..., 0]
189+
else:
190+
return samples
141191

142192
def logp(self, value):
143193
w = self.w
@@ -147,42 +197,99 @@ def logp(self, value):
147197
broadcast_conditions=False)
148198

149199
def random(self, point=None, size=None):
200+
# Convert size to tuple
201+
size = to_tuple(size)
202+
# Draw mixture weights and a sample from each mixture to infer shape
150203
with _DrawValuesContext() as draw_context:
151-
w = draw_values([self.w], point=point)[0]
204+
# We first need to check w and comp_tmp shapes and re compute size
205+
w = draw_values([self.w], point=point, size=size)[0]
206+
with _DrawValuesContextBlocker():
207+
# We don't want to store the values drawn here in the context
208+
# because they wont have the correct size
152209
comp_tmp = self._comp_samples(point=point, size=None)
153-
if np.asarray(self.shape).size == 0:
154-
distshape = np.asarray(np.broadcast(w, comp_tmp).shape)[..., :-1]
210+
211+
# When size is not None, it's hard to tell the w parameter shape
212+
if size is not None and w.shape[:len(size)] == size:
213+
w_shape = w.shape[len(size):]
214+
else:
215+
w_shape = w.shape
216+
217+
# Try to determine parameter shape and dist_shape
218+
param_shape = np.broadcast(np.empty(w_shape),
219+
comp_tmp).shape
220+
if np.asarray(self.shape).size != 0:
221+
dist_shape = np.broadcast(np.empty(self.shape),
222+
np.empty(param_shape[:-1])).shape
223+
else:
224+
dist_shape = param_shape[:-1]
225+
226+
# When size is not None, maybe dist_shape partially overlaps with size
227+
if size is not None:
228+
if size == dist_shape:
229+
size = None
230+
elif size[-len(dist_shape):] == dist_shape:
231+
size = size[:len(size) - len(dist_shape)]
232+
233+
# We get an integer _size instead of a tuple size for drawing the
234+
# mixture, then we just reshape the output
235+
if size is None:
236+
_size = None
155237
else:
156-
distshape = np.asarray(self.shape)
238+
_size = int(np.prod(size))
239+
240+
# Now we must broadcast w to the shape that considers size, dist_shape
241+
# and param_shape. However, we must take care with the cases in which
242+
# dist_shape and param_shape overlap
243+
if size is not None and w.shape[:len(size)] == size:
244+
if w.shape[:len(size + dist_shape)] != (size + dist_shape):
245+
# To allow w to broadcast, we insert new axis in between the
246+
# "size" axis and the "mixture" axis
247+
_w = w[(slice(None),) * len(size) + # Index the size axis
248+
(np.newaxis,) * len(dist_shape) + # Add new axis for the dist_shape
249+
(slice(None),)] # Close with the slice of mixture components
250+
w = np.broadcast_to(_w, size + dist_shape + (param_shape[-1],))
251+
elif size is not None:
252+
w = np.broadcast_to(w, size + dist_shape + (param_shape[-1],))
253+
else:
254+
w = np.broadcast_to(w, dist_shape + (param_shape[-1],))
157255

158-
# Normalize inputs
159-
w /= w.sum(axis=-1, keepdims=True)
256+
# Compute the total size of the mixture's random call with size
257+
if _size is not None:
258+
output_size = int(_size * np.prod(dist_shape) * param_shape[-1])
259+
else:
260+
output_size = int(np.prod(dist_shape) * param_shape[-1])
261+
# Get the size we need for the mixture's random call
262+
mixture_size = int(output_size // np.prod(comp_tmp.shape))
263+
if mixture_size == 1 and _size is None:
264+
mixture_size = None
265+
266+
# Semiflatten the mixture weights. The last axis is the number of
267+
# mixture mixture components, and the rest is all about size,
268+
# dist_shape and broadcasting
269+
w = np.reshape(w, (-1, w.shape[-1]))
270+
# Normalize mixture weights
271+
w = w / w.sum(axis=-1, keepdims=True)
160272

161273
w_samples = generate_samples(random_choice,
162274
p=w,
163275
broadcast_shape=w.shape[:-1] or (1,),
164-
dist_shape=distshape,
165-
size=size).squeeze()
166-
if (size is None) or (distshape.size == 0):
167-
with draw_context:
168-
comp_samples = self._comp_samples(point=point, size=size)
169-
if comp_samples.ndim > 1:
170-
samples = np.squeeze(comp_samples[np.arange(w_samples.size), ..., w_samples])
171-
else:
172-
samples = np.squeeze(comp_samples[w_samples])
276+
dist_shape=w.shape[:-1] or (1,),
277+
size=size)
278+
# Sample from the mixture
279+
with draw_context:
280+
mixed_samples = self._comp_samples(point=point,
281+
size=mixture_size)
282+
w_samples = w_samples.flatten()
283+
# Semiflatten the mixture to be able to zip it with w_samples
284+
mixed_samples = np.reshape(mixed_samples, (-1, comp_tmp.shape[-1]))
285+
# Select the samples from the mixture
286+
samples = np.array([mixed[choice] for choice, mixed in
287+
zip(w_samples, mixed_samples)])
288+
# Reshape the samples to the correct output shape
289+
if size is None:
290+
samples = np.reshape(samples, dist_shape)
173291
else:
174-
if w_samples.ndim == 1:
175-
w_samples = np.reshape(np.tile(w_samples, size), (size,) + w_samples.shape)
176-
samples = np.zeros((size,)+tuple(distshape))
177-
with draw_context:
178-
for i in range(size):
179-
w_tmp = w_samples[i, :]
180-
comp_tmp = self._comp_samples(point=point, size=None)
181-
if comp_tmp.ndim > 1:
182-
samples[i, :] = np.squeeze(comp_tmp[np.arange(w_tmp.size), ..., w_tmp])
183-
else:
184-
samples[i, :] = np.squeeze(comp_tmp[w_tmp])
185-
292+
samples = np.reshape(samples, size + dist_shape)
186293
return samples
187294

188295

0 commit comments

Comments (0)