Skip to content

Commit 538d0ba

Browse files
taku-ytwiecki
authored andcommitted
ENH Add support for autoencoding variational Bayes (pymc-devs#1236)
1 parent cb9dda8 commit 538d0ba

File tree

8 files changed

+1327
-171
lines changed

8 files changed

+1327
-171
lines changed

docs/source/notebooks/gaussian-mixture-model-advi.ipynb

Lines changed: 62 additions & 47 deletions
Large diffs are not rendered by default.

docs/source/notebooks/lda-advi-aevb.ipynb

Lines changed: 882 additions & 0 deletions
Large diffs are not rendered by default.

pymc3/distributions/transforms.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from ..model import FreeRV
44
from ..theanof import gradient
55
from .distribution import Distribution
6-
from ..math import logit
6+
from ..math import logit, invlogit
77
import numpy as np
88

9-
__all__ = ['transform', 'stick_breaking', 'logodds', 'log', 'sum_to_1']
9+
__all__ = ['transform', 'stick_breaking', 'logodds', 'log', 'sum_to_1', 't_stick_breaking']
1010

1111

1212
class Transform(object):
@@ -85,17 +85,14 @@ def forward(self, x):
8585

8686
log = Log()
8787

88-
inverse_logit = tt.nnet.sigmoid
89-
90-
9188
class LogOdds(ElemwiseTransform):
9289
name = "logodds"
9390

9491
def __init__(self):
9592
pass
9693

9794
def backward(self, x):
98-
return inverse_logit(x)
95+
return invlogit(x, 0.0)
9996

10097
def forward(self, x):
10198
return logit(x)
@@ -185,10 +182,18 @@ def jacobian_det(self, x):
185182
class StickBreaking(Transform):
186183
"""Transforms K dimensional simplex space (values in [0,1] and sum to 1) to K - 1 vector of real values.
187184
Primarily borrowed from the STAN implementation.
185+
186+
Parameters
187+
----------
188+
eps : float, positive value
189+
A small value for numerical stability in invlogit.
188190
"""
189191

190192
name = "stickbreaking"
191193

194+
def __init__(self, eps=0.0):
195+
self.eps = eps
196+
192197
def forward(self, x_):
193198
x = x_.T
194199
# reverse cumsum
@@ -206,7 +211,7 @@ def backward(self, y_):
206211
Km1 = y.shape[0]
207212
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
208213
eq_share = logit(1./(Km1 + 1 - k)) #- tt.log(Km1 - k)
209-
z = inverse_logit(y + eq_share)
214+
z = invlogit(y + eq_share, self.eps)
210215
yl = tt.concatenate([z, tt.ones(y[:1].shape)])
211216
yu = tt.concatenate([tt.ones(y[:1].shape), 1-z])
212217
S = tt.extra_ops.cumprod(yu, 0)
@@ -219,12 +224,13 @@ def jacobian_det(self, y_):
219224
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (y.ndim - 1)]
220225
eq_share = logit(1./(Km1 + 1 - k)) # -tt.log(Km1 - k)
221226
yl = y + eq_share
222-
yu = tt.concatenate([tt.ones(y[:1].shape), 1-inverse_logit(yl)])
227+
yu = tt.concatenate([tt.ones(y[:1].shape), 1-invlogit(yl, self.eps)])
223228
S = tt.extra_ops.cumprod(yu, 0)
224229
return tt.sum(tt.log(S[:-1]) - tt.log1p(tt.exp(yl)) - tt.log1p(tt.exp(-yl)), 0).T
225230

226231
stick_breaking = StickBreaking()
227232

233+
t_stick_breaking = lambda eps: StickBreaking(eps)
228234

229235
class Circular(ElemwiseTransform):
230236
"""Transforms a linear space into a circular one.

pymc3/math.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@ def logsumexp(x, axis=None):
1515
x_max = tt.max(x, axis=axis, keepdims=True)
1616
return tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
1717

18-
def invlogit(x):
19-
p_min = sys.float_info.epsilon
20-
21-
return (1 - 2 * p_min) / (1 + tt.exp(-x)) + p_min
18+
def invlogit(x, eps=sys.float_info.epsilon):
19+
return (1 - 2 * eps) / (1 + tt.exp(-x)) + eps
2220

2321
def logit(p):
2422
return tt.log(p/(1-p))

pymc3/tests/test_advi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import pymc3 as pm
33
from pymc3 import Model, Normal, DiscreteUniform, Poisson, switch, Exponential
44
from pymc3.theanof import inputvars
5-
from pymc3.variational.advi import variational_gradient_estimate, advi, advi_minibatch, sample_vp
5+
from pymc3.variational import advi, advi_minibatch, sample_vp
6+
from pymc3.variational.advi import variational_gradient_estimate
67
from theano import function, shared
78
import theano.tensor as tt
89

pymc3/variational/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .advi import advi, advi_minibatch, sample_vp
1+
from .advi import advi, sample_vp
2+
from .advi_minibatch import advi_minibatch

pymc3/variational/advi.py

Lines changed: 25 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -98,108 +98,6 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
9898
w[var] = np.exp(w[var])
9999
return ADVIFit(u, w, elbos)
100100

101-
def advi_minibatch(vars=None, start=None, model=None, n=5000, n_mcsamples=1,
102-
minibatch_RVs=None, minibatch_tensors=None, minibatches=None, total_size=None,
103-
learning_rate=.001, epsilon=.1, random_seed=20090425, verbose=1):
104-
"""Run mini-batch ADVI.
105-
106-
minibatch_RVs, minibatch_tensors and minibatches should be in the
107-
same order.
108-
109-
Parameters
110-
----------
111-
vars : object
112-
Random variables.
113-
start : Dict or None
114-
Initial values of parameters (variational means).
115-
model : Model
116-
Probabilistic model.
117-
n : int
118-
Number of interations updating parameters.
119-
n_mcsamples : int
120-
Number of Monte Carlo samples to approximate ELBO.
121-
minibatch_RVs : list of ObservedRVs
122-
Random variables for mini-batch.
123-
minibatch_tensors : list of tensors
124-
Tensors used to create ObservedRVs in minibatch_RVs.
125-
minibatches : list of generators
126-
Generates minibatches when calling next().
127-
totalsize : int
128-
Total size of training samples.
129-
learning_rate: float
130-
Adagrad base learning rate.
131-
epsilon : float
132-
Offset in denominator of the scale of learning rate in Adagrad.
133-
random_seed : int
134-
Seed to initialize random state.
135-
136-
Returns
137-
-------
138-
ADVIFit
139-
Named tuple, which includes 'means', 'stds', and 'elbo_vals'.
140-
"""
141-
142-
model = modelcontext(model)
143-
if start is None:
144-
start = model.test_point
145-
146-
if vars is None:
147-
vars = model.vars
148-
149-
vars = set(inputvars(vars)) - set(minibatch_RVs)
150-
151-
check_discrete_rvs(vars)
152-
153-
# Create variational gradient tensor
154-
grad, elbo, shared, uw = variational_gradient_estimate(
155-
vars, model, minibatch_RVs, minibatch_tensors, total_size,
156-
n_mcsamples=n_mcsamples, random_seed=random_seed)
157-
158-
# Set starting values
159-
for var, share in shared.items():
160-
share.set_value(start[str(var)])
161-
162-
order = ArrayOrdering(vars)
163-
bij = DictToArrayBijection(order, start)
164-
u_start = bij.map(start)
165-
w_start = np.zeros_like(u_start)
166-
uw_start = np.concatenate([u_start, w_start])
167-
168-
shared_inarray = theano.shared(uw_start, 'uw_shared')
169-
grad = theano.clone(grad, { uw : shared_inarray }, strict=False)
170-
elbo = theano.clone(elbo, { uw : shared_inarray }, strict=False)
171-
updates = adagrad(grad, shared_inarray, learning_rate=learning_rate, epsilon=epsilon, n=10)
172-
173-
# Create in-place update function
174-
tensors, givens = replace_shared_minibatch_tensors(minibatch_tensors)
175-
f = theano.function(tensors, [shared_inarray, grad, elbo],
176-
updates=updates, givens=givens)
177-
178-
# Run adagrad steps
179-
elbos = np.empty(n)
180-
for i in range(n):
181-
uw_i, g, e = f(*next(minibatches))
182-
elbos[i] = e
183-
if verbose and not i % (n//10):
184-
if not i:
185-
print('Iteration {0} [{1}%]: ELBO = {2}'.format(i, 100*i//n, e.round(2)))
186-
else:
187-
avg_elbo = elbos[i-n//10:i].mean()
188-
print('Iteration {0} [{1}%]: Average ELBO = {2}'.format(i, 100*i//n, avg_elbo.round(2)))
189-
190-
if verbose:
191-
avg_elbo = elbos[i-n//10:i].mean()
192-
print('Finished [100%]: Average ELBO = {}'.format(avg_elbo.round(2)))
193-
194-
l = int(uw_i.size / 2)
195-
196-
u = bij.rmap(uw_i[:l])
197-
w = bij.rmap(uw_i[l:])
198-
# w is in log space
199-
for var in w.keys():
200-
w[var] = np.exp(w[var])
201-
return ADVIFit(u, w, elbos)
202-
203101
def replace_shared_minibatch_tensors(minibatch_tensors):
204102
"""Replace shared variables in minibatch tensors with normal tensors.
205103
"""
@@ -326,9 +224,10 @@ def adagrad(grad, param, learning_rate, epsilon, n):
326224
tt.sqrt(accu_sum + epsilon))
327225
return updates
328226

329-
def sample_vp(vparams, draws=1000, model=None, random_seed=20090425,
330-
hide_transformed=True):
331-
"""Draw samples from variational posterior.
227+
def sample_vp(
228+
vparams, draws=1000, model=None, local_RVs=None, random_seed=20090425,
229+
hide_transformed=True):
230+
"""Draw samples from variational posterior.
332231
333232
Parameters
334233
----------
@@ -356,17 +255,33 @@ def sample_vp(vparams, draws=1000, model=None, random_seed=20090425,
356255
'stds': vparams.stds
357256
}
358257

258+
ds = model.deterministics
259+
get_transformed = lambda v: v if v not in ds else v.transformed
260+
rvs = lambda x: [get_transformed(v) for v in x] if x is not None else []
261+
262+
global_RVs = list(set(model.free_RVs) - set(rvs(local_RVs)))
263+
359264
# Make dict for replacements of random variables
360265
r = MRG_RandomStreams(seed=random_seed)
361266
updates = {}
362-
for var in model.free_RVs:
363-
u = theano.shared(vparams['means'][str(var)]).ravel()
364-
w = theano.shared(vparams['stds'][str(var)]).ravel()
267+
for v in global_RVs:
268+
u = theano.shared(vparams['means'][str(v)]).ravel()
269+
w = theano.shared(vparams['stds'][str(v)]).ravel()
365270
n = r.normal(size=u.tag.test_value.shape)
366-
updates.update({var: (n * w + u).reshape(var.tag.test_value.shape)})
367-
vars = model.free_RVs
271+
updates.update({v: (n * w + u).reshape(v.tag.test_value.shape)})
272+
273+
if local_RVs is not None:
274+
ds = model.deterministics
275+
get_transformed = lambda v: v if v not in ds else v.transformed
276+
for v_, (uw, _) in local_RVs.items():
277+
v = get_transformed(v_)
278+
u = uw[0].ravel()
279+
w = uw[1].ravel()
280+
n = r.normal(size=u.tag.test_value.shape)
281+
updates.update({v: (n * tt.exp(w) + u).reshape(v.tag.test_value.shape)})
368282

369283
# Replace some nodes of the graph with variational distributions
284+
vars = model.free_RVs
370285
samples = theano.clone(vars, updates)
371286
f = theano.function([], samples)
372287

0 commit comments

Comments
 (0)