Skip to content

Commit 1c5ca44

Browse files
taku-y authored and twiecki committed
Add check for discrete variables.
1 parent 8513e8c commit 1c5ca44

File tree

2 files changed

+84
-6
lines changed

2 files changed

+84
-6
lines changed

pymc3/tests/test_advi.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import numpy as np
2-
from pymc3 import Model, Normal
2+
from pymc3 import Model, Normal, HalfNormal, DiscreteUniform, Poisson, switch, Exponential
33
from pymc3.theanof import inputvars
4-
from pymc3.variational.advi import variational_gradient_estimate
4+
from pymc3.variational.advi import variational_gradient_estimate, advi, advi_minibatch
55
from theano import function
6+
import theano.tensor as tt
67

7-
def test_advi():
8+
from nose.tools import assert_raises
9+
10+
def test_elbo():
811
mu0 = 1.5
912
sigma = 1.0
1013
y_obs = np.array([1.6, 1.4])
@@ -18,7 +21,7 @@ def test_advi():
1821

1922
# Create variational gradient tensor
2023
grad, elbo, shared, uw = variational_gradient_estimate(
21-
vars, model, n_mcsamples=1000000)
24+
vars, model, n_mcsamples=10000)
2225

2326
# Variational posterior parameters
2427
uw_ = np.array([1.88, np.log(1)])
@@ -34,3 +37,57 @@ def test_advi():
3437
0.5 * (np.log(2 * np.pi) + 1))
3538

3639
np.testing.assert_allclose(elbo_mc, elbo_true, rtol=0, atol=1e-1)
40+
41+
# Observed disaster counts shared by the switch-point tests below; the series
# has one entry per element of `year` (111 values).
# -999 marks a missing observation: np.ma.masked_values masks every entry
# equal to `value`, so two elements of the result are masked.
disaster_data = np.ma.masked_values(
    [4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
     3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
     2, 2, 3, 4, 2, 1, 3, -999, 2, 1, 1, 1, 1, 3, 0, 0,
     1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
     0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
     3, 3, 1, -999, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
     0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1],
    value=-999)

# Integer years 1851..1961 inclusive (np.arange excludes the stop value).
year = np.arange(1851, 1962)
49+
50+
def test_check_discrete():
    """Check that advi() rejects a model with a free discrete variable.

    Builds a switch-point model whose `switchpoint` is a DiscreteUniform
    free variable and asserts that running `advi` on it raises ValueError.
    """
    with Model() as disaster_model:
        switchpoint = DiscreteUniform(
            'switchpoint', lower=year.min(), upper=year.max(), testval=1900)

        # Priors for the disaster rates before and after the switch point
        early_rate = Exponential('early_rate', 1)
        late_rate = Exponential('late_rate', 1)

        # Pick the Poisson rate for each year depending on which side of
        # the switch point it falls
        rate = switch(switchpoint >= year, early_rate, late_rate)

        # Created inside the model context for its side effect of adding
        # the observed variable to the model; the handle itself is unused.
        Poisson('disasters', rate, observed=disaster_data)

    # The model is passed explicitly, so no model context is needed here.
    # This should raise ValueError
    assert_raises(ValueError, advi, model=disaster_model, n=10)
65+
66+
def test_check_discrete_minibatch():
    """Check that advi_minibatch() rejects a free discrete variable.

    Same switch-point model as in test_check_discrete, except that the
    observations are fed through a symbolic vector so the observed RV can
    be used as a minibatch variable. The call is expected to raise
    ValueError.
    """
    # Symbolic stand-in for the observations; the test value gives it a
    # concrete array of the same length as the data.
    disaster_data_t = tt.vector()
    disaster_data_t.tag.test_value = np.zeros(len(disaster_data))

    with Model() as disaster_model:
        switchpoint = DiscreteUniform(
            'switchpoint', lower=year.min(), upper=year.max(), testval=1900)

        # Priors for the disaster rates before and after the switch point
        early_rate = Exponential('early_rate', 1)
        late_rate = Exponential('late_rate', 1)

        # Pick the Poisson rate for each year depending on which side of
        # the switch point it falls
        rate = switch(switchpoint >= year, early_rate, late_rate)

        disasters = Poisson('disasters', rate, observed=disaster_data_t)

    def create_minibatch():
        """Yield the full data set as an endless stream of minibatches."""
        # `yield` (not `return`) so that calling this produces a generator
        # that can be advanced repeatedly; with `return` the loop never
        # iterated and the raw array was passed instead.
        # NOTE(review): presumably advi_minibatch expects generators in
        # `minibatches` -- confirm against its docstring.
        while True:
            yield disaster_data

    # This should raise ValueError
    assert_raises(
        ValueError, advi_minibatch, model=disaster_model, n=10,
        minibatch_RVs=[disasters], minibatch_tensors=[disaster_data_t],
        minibatches=[create_minibatch()], verbose=False)
93+

pymc3/variational/advi.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from scipy import optimize
88
import numpy as np
99
from ..core import *
10+
from ..distributions import Discrete
11+
from ..distributions.transforms import TransformedDistribution
12+
from ..model import ObservedRV
1013

1114
import theano
1215
from ..theanof import make_shared_replacements, join_nonshared_inputs, CallableTensor, gradient
@@ -19,6 +22,22 @@
1922

2023
ADVIFit = namedtuple('ADVIFit', 'means, stds, elbo_vals')
2124

25+
def is_discreteRV(var):
    """Return True if ``var`` has a discrete, untransformed distribution.

    Parameters
    ----------
    var : random variable
        Object exposing a ``distribution`` attribute.

    Returns
    -------
    bool
        False when the distribution is a TransformedDistribution;
        otherwise whether it is an instance of Discrete.
    """
    dist = var.distribution

    # Transformed distributions are treated as continuous.
    # NOTE(review): the original author flagged this assumption as
    # unverified ("is it true?") -- confirm before relying on it.
    if isinstance(dist, TransformedDistribution):
        return False

    return isinstance(dist, Discrete)
33+
34+
def check_discrete_rvs(vars):
    """Check that ``vars`` contains no discrete variables.

    ObservedRV instances are exempt from the check and are skipped.

    Parameters
    ----------
    vars : iterable of random variables
        Variables to inspect.

    Raises
    ------
    ValueError
        If any variable that is not an ObservedRV is discrete according
        to ``is_discreteRV``.
    """
    # Generator expression: stops at the first discrete variable instead of
    # building an intermediate list.
    if any(is_discreteRV(var) for var in vars
           if not isinstance(var, ObservedRV)):
        raise ValueError('Model should not include discrete RVs for ADVI.')
40+
2241
def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
2342
learning_rate=.001, epsilon=.1, verbose=1):
2443
"""Run ADVI.
@@ -50,13 +69,13 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
5069
model = modelcontext(model)
5170
if start is None:
5271
start = model.test_point
53-
import pdb
54-
pdb.set_trace()
5572

5673
if vars is None:
5774
vars = model.vars
5875
vars = inputvars(vars)
5976

77+
check_discrete_rvs(vars)
78+
6079
n_mcsamples = 100 if accurate_elbo else 1
6180

6281
# Create variational gradient tensor
@@ -130,6 +149,8 @@ def advi_minibatch(vars=None, start=None, model=None, n=5000, n_mcsamples=1,
130149

131150
vars = set(inputvars(vars)) - set(minibatch_RVs)
132151

152+
check_discrete_rvs(vars)
153+
133154
# Create variational gradient tensor
134155
grad, elbo, shared, uw = variational_gradient_estimate(
135156
vars, model, minibatch_RVs, minibatch_tensors, total_size,

0 commit comments

Comments
 (0)