Skip to content

Commit 686b81d

Browse files
lucianopazjunpenglao
authored andcommitted
Fix for #3210 without computing the Bayes network (#3273)
* Fix for #3225. Made Triangular `c` attribute be handled consistently with scipy.stats. Added test and updated example code. * Fix for #3210 which uses a completely different approach than PR #3214. It uses a context manager inside `draw_values` that makes all the values drawn from `TensorVariables` or `MultiObservedRV`s available to nested calls of the original call to `draw_values`. It is partly inspired by how Edward2 approaches the problem of forward sampling. Ed2 tensors fix a `_values` attribute after they first call `sample` and then only return that. They can do it because of their functional scheme, where the entire graph is recreated each time the generative function is called. Our object oriented paradigm cannot set a fixed _values, it has to know it is in the context of a single `draw_values` call. That is why I opted for context managers to store the drawn values. * Removed leftover print statement * Added release notes and draw values context managers to mixture and multivariate distributions that make many calls to draw_values or other distributions random methods within their own random.
1 parent 589aee1 commit 686b81d

File tree

7 files changed

+518
-320
lines changed

7 files changed

+518
-320
lines changed

RELEASE-NOTES.md

+11
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,24 @@
1010
- Add log CDF functions to continuous distributions: `Beta`, `Cauchy`, `ExGaussian`, `Exponential`, `Flat`, `Gumbel`, `HalfCauchy`, `HalfFlat`, `HalfNormal`, `Laplace`, `Logistic`, `Lognormal`, `Normal`, `Pareto`, `StudentT`, `Triangular`, `Uniform`, `Wald`, `Weibull`.
1111
- Behavior of `sample_posterior_predictive` is now to produce posterior predictive samples, in order, from all values of the `trace`. Previously, by default it would produce 1 chain worth of samples, using a random selection from the `trace` (#3212)
1212
- Show diagnostics for initial energy errors in HMC and NUTS.
13+
- PR #3273 has added the `distributions.distribution._DrawValuesContext` context
14+
manager. This is used to store the values already drawn in nested `random`
15+
and `draw_values` calls, enabling `draw_values` to draw samples from the
16+
joint probability distribution of RVs and not the marginals. Custom
17+
distributions that must call `draw_values` several times in their `random`
18+
method, or that invoke many calls to other distribution's `random` methods
19+
(e.g. mixtures) must do all of these calls under the same `_DrawValuesContext`
20+
context manager instance. If they do not, the conditional relations between
21+
the distribution's parameters could be broken, and `random` could return
22+
values drawn from an incorrect distribution.
1323

1424
### Maintenance
1525

1626
- Big rewrite of documentation (#3275)
1727
- Fixed Triangular distribution `c` attribute handling in `random` and updated sample codes for consistency (#3225)
1828
- Refactor SMC and properly compute marginal likelihood (#3124)
1929
- Removed use of deprecated `ymin` keyword in matplotlib's `Axes.set_ylim` (#3279)
30+
- Fix for #3210. Now `distribution.draw_values(params)`, will draw the `params` values from their joint probability distribution and not from combinations of their marginals (Refer to PR #3273).
2031

2132
### Deprecations
2233

pymc3/distributions/distribution.py

+210-97
Large diffs are not rendered by default.

pymc3/distributions/mixture.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from pymc3.util import get_variable_name
55
from ..math import logsumexp
66
from .dist_math import bound, random_choice
7-
from .distribution import Discrete, Distribution, draw_values, generate_samples
7+
from .distribution import (Discrete, Distribution, draw_values,
8+
generate_samples, _DrawValuesContext)
89
from .continuous import get_tau_sd, Normal
910

1011

@@ -147,8 +148,9 @@ def logp(self, value):
147148
broadcast_conditions=False)
148149

149150
def random(self, point=None, size=None):
150-
w = draw_values([self.w], point=point)[0]
151-
comp_tmp = self._comp_samples(point=point, size=None)
151+
with _DrawValuesContext() as draw_context:
152+
w = draw_values([self.w], point=point)[0]
153+
comp_tmp = self._comp_samples(point=point, size=None)
152154
if np.asarray(self.shape).size == 0:
153155
distshape = np.asarray(np.broadcast(w, comp_tmp).shape)[..., :-1]
154156
else:
@@ -163,7 +165,8 @@ def random(self, point=None, size=None):
163165
dist_shape=distshape,
164166
size=size).squeeze()
165167
if (size is None) or (distshape.size == 0):
166-
comp_samples = self._comp_samples(point=point, size=size)
168+
with draw_context:
169+
comp_samples = self._comp_samples(point=point, size=size)
167170
if comp_samples.ndim > 1:
168171
samples = np.squeeze(comp_samples[np.arange(w_samples.size), ..., w_samples])
169172
else:
@@ -172,13 +175,14 @@ def random(self, point=None, size=None):
172175
if w_samples.ndim == 1:
173176
w_samples = np.reshape(np.tile(w_samples, size), (size,) + w_samples.shape)
174177
samples = np.zeros((size,)+tuple(distshape))
175-
for i in range(size):
176-
w_tmp = w_samples[i, :]
177-
comp_tmp = self._comp_samples(point=point, size=None)
178-
if comp_tmp.ndim > 1:
179-
samples[i, :] = np.squeeze(comp_tmp[np.arange(w_tmp.size), ..., w_tmp])
180-
else:
181-
samples[i, :] = np.squeeze(comp_tmp[w_tmp])
178+
with draw_context:
179+
for i in range(size):
180+
w_tmp = w_samples[i, :]
181+
comp_tmp = self._comp_samples(point=point, size=None)
182+
if comp_tmp.ndim > 1:
183+
samples[i, :] = np.squeeze(comp_tmp[np.arange(w_tmp.size), ..., w_tmp])
184+
else:
185+
samples[i, :] = np.squeeze(comp_tmp[w_tmp])
182186

183187
return samples
184188

pymc3/distributions/multivariate.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from pymc3.theanof import floatX
1717
from . import transforms
1818
from pymc3.util import get_variable_name
19-
from .distribution import Continuous, Discrete, draw_values, generate_samples
19+
from .distribution import (Continuous, Discrete, draw_values, generate_samples,
20+
_DrawValuesContext)
2021
from ..model import Deterministic
2122
from .continuous import ChiSquared, Normal
2223
from .special import gammaln, multigammaln
@@ -338,18 +339,19 @@ def __init__(self, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None,
338339
self.mean = self.median = self.mode = self.mu = self.mu
339340

340341
def random(self, point=None, size=None):
341-
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
342-
if self._cov_type == 'cov':
343-
cov, = draw_values([self.cov], point=point, size=size)
344-
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
345-
elif self._cov_type == 'tau':
346-
tau, = draw_values([self.tau], point=point, size=size)
347-
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
348-
else:
349-
chol, = draw_values([self.chol_cov], point=point, size=size)
350-
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
342+
with _DrawValuesContext():
343+
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
344+
if self._cov_type == 'cov':
345+
cov, = draw_values([self.cov], point=point, size=size)
346+
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
347+
elif self._cov_type == 'tau':
348+
tau, = draw_values([self.tau], point=point, size=size)
349+
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
350+
else:
351+
chol, = draw_values([self.chol_cov], point=point, size=size)
352+
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
351353

352-
samples = dist.random(point, size)
354+
samples = dist.random(point, size)
353355

354356
chi2 = np.random.chisquare
355357
return (np.sqrt(nu) * samples.T / chi2(nu, size)).T + mu

pymc3/model.py

+10
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,16 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
13861386
self.distribution = distribution
13871387
self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)
13881388

1389+
# Make hashable by id for draw_values
1390+
def __hash__(self):
1391+
return id(self)
1392+
1393+
def __eq__(self, other):
1394+
return self.id == other.id
1395+
1396+
def __ne__(self, other):
1397+
return not self == other
1398+
13891399

13901400
def _walk_up_rv(rv):
13911401
"""Walk up theano graph to get inputs for deterministic RV."""

pymc3/tests/test_random.py

+58
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import pymc3 as pm
22
import numpy as np
3+
from numpy import random as nr
34
import numpy.testing as npt
45
import pytest
56
import theano.tensor as tt
67
import theano
78

89
from pymc3.distributions.distribution import _draw_value, draw_values
10+
from .helpers import SeededTest
911

1012

1113
def test_draw_value():
@@ -88,3 +90,59 @@ def test_dep_vars(self):
8890
assert all([np.all(val1 != val2), np.all(val1 != val3),
8991
np.all(val1 != val4), np.all(val2 != val3),
9092
np.all(val2 != val4), np.all(val3 != val4)])
93+
94+
95+
class TestJointDistributionDrawValues(SeededTest):
96+
def test_joint_distribution(self):
97+
with pm.Model() as model:
98+
a = pm.Normal('a', mu=0, sd=100)
99+
b = pm.Normal('b', mu=a, sd=1e-8)
100+
c = pm.Normal('c', mu=a, sd=1e-8)
101+
d = pm.Deterministic('d', b + c)
102+
103+
# Expected RVs
104+
N = 1000
105+
norm = np.random.randn(3, N)
106+
eA = norm[0] * 100
107+
eB = eA + norm[1] * 1e-8
108+
eC = eA + norm[2] * 1e-8
109+
eD = eB + eC
110+
111+
# Drawn RVs
112+
nr.seed(self.random_seed)
113+
# A, B, C, D = list(zip(*[draw_values([a, b, c, d]) for i in range(N)]))
114+
A, B, C, D = draw_values([a, b, c, d], size=N)
115+
A = np.array(A).flatten()
116+
B = np.array(B).flatten()
117+
C = np.array(C).flatten()
118+
D = np.array(D).flatten()
119+
120+
# Assert that the drawn samples match the expected values
121+
assert np.allclose(eA, A)
122+
assert np.allclose(eB, B)
123+
assert np.allclose(eC, C)
124+
assert np.allclose(eD, D)
125+
126+
# Assert that A, B and C have the expected difference
127+
assert np.all(np.abs(A - B) < 1e-6)
128+
assert np.all(np.abs(A - C) < 1e-6)
129+
assert np.all(np.abs(B - C) < 1e-6)
130+
131+
# Marginal draws
132+
mA = np.array([draw_values([a]) for i in range(N)]).flatten()
133+
mB = np.array([draw_values([b]) for i in range(N)]).flatten()
134+
mC = np.array([draw_values([c]) for i in range(N)]).flatten()
135+
# Also test the with model context of draw_values
136+
with model:
137+
mD = np.array([draw_values([d]) for i in range(N)]).flatten()
138+
139+
# Assert that the marginal distributions have different sample values
140+
assert not np.all(np.abs(B - mB) < 1e-2)
141+
assert not np.all(np.abs(C - mC) < 1e-2)
142+
assert not np.all(np.abs(D - mD) < 1e-2)
143+
144+
# Assert that the marginal distributions do not have high cross
145+
# correlation
146+
assert np.abs(np.corrcoef(mA, mB)[0, 1]) < 0.1
147+
assert np.abs(np.corrcoef(mA, mC)[0, 1]) < 0.1
148+
assert np.abs(np.corrcoef(mB, mC)[0, 1]) < 0.1

0 commit comments

Comments
 (0)