
Commit 9d890fc

neerajprad authored and fritzo committed
Batch 1 - RandomPrimitive wrappers for a few distributions (pyro-ppl#351)
1 parent a618b94 · commit 9d890fc

File tree: 7 files changed (+120, -174 lines)

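The diffs below replace stateful distribution singletons (e.g. `beta = Beta()`) with `RandomPrimitive` wrappers that take parameters at call time. The `RandomPrimitive` helper itself is defined elsewhere and is not part of this excerpt; the following is only a minimal sketch of how such a wrapper could plausibly behave, not the actual implementation:

    # Hypothetical sketch of a RandomPrimitive-style wrapper. The real
    # pyro.distributions.RandomPrimitive is not shown in this diff, so the
    # names and behavior here are assumptions for illustration only.
    class RandomPrimitive(object):
        def __init__(self, dist_class):
            self.dist_class = dist_class  # e.g. Beta, Cauchy, HalfCauchy, Dirichlet

        def __call__(self, *args, **kwargs):
            # Construct a stateful instance from call-time parameters and sample.
            return self.dist_class(*args, **kwargs).sample()

        def log_pdf(self, x, *args, **kwargs):
            # Score x under a distribution built from the same parameters.
            return self.dist_class(*args, **kwargs).log_pdf(x)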

pyro/distributions/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@
 lognormal = LogNormal()
 categorical = Categorical()
 bernoulli = RandomPrimitive(Bernoulli)
-beta = Beta()
+beta = RandomPrimitive(Beta)
 delta = Delta()
 exponential = Exponential()
 gamma = Gamma()
@@ -33,6 +33,6 @@
 normalchol = NormalChol()
 poisson = Poisson()
 uniform = Uniform()
-dirichlet = Dirichlet()
-cauchy = Cauchy()
-halfcauchy = HalfCauchy()
+dirichlet = RandomPrimitive(Dirichlet)
+cauchy = RandomPrimitive(Cauchy)
+halfcauchy = RandomPrimitive(HalfCauchy)
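With this change, `dist.beta` and friends behave like functions of their parameters rather than parameterless singletons. A usage sketch, assuming the wrapper forwards parameters as in the sketch above (values illustrative; `log_pdf` on the wrapper is an assumption, not confirmed by this diff):

    import torch
    from torch.autograd import Variable  # this era of PyTorch still used Variable
    import pyro.distributions as dist

    alpha = Variable(torch.Tensor([2.0]))
    beta = Variable(torch.Tensor([3.0]))

    x = dist.beta(alpha, beta)                 # sample from Beta(2, 3)
    log_p = dist.beta.log_pdf(x, alpha, beta)  # score it under the same parameters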

pyro/distributions/bernoulli.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ class Bernoulli(Distribution):
     """
     enumerable = True
 
-    def __init__(self, ps=None, batch_size=None, log_pdf_mask=None):
+    def __init__(self, ps, batch_size=None, log_pdf_mask=None, *args, **kwargs):
         """
         :param ps: tensor of probabilities
         """
@@ -22,8 +22,8 @@ def __init__(self, ps=None, batch_size=None, log_pdf_mask=None):
         if ps.dim() == 1 and batch_size is not None:
             self.ps = ps.expand(batch_size, ps.size(0))
             if log_pdf_mask is not None and log_pdf_mask.dim() == 1:
-                self.log_pdf_mask = log_pdf_mask.expand(batch_size, ps.size(0))
-        super(Bernoulli, self).__init__()
+                self.log_pdf_mask = log_pdf_mask.expand(batch_size, log_pdf_mask.size(0))
+        super(Bernoulli, self).__init__(*args, **kwargs)
 
     def batch_shape(self, x=None):
         event_dim = 1
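One behavioral fix rides along here: the mask used to be expanded with `ps.size(0)`, which tied the mask's expanded width to `ps` instead of to the mask itself. A small sketch of the intended use (import path and values are illustrative):

    import torch
    from torch.autograd import Variable
    from pyro.distributions import Bernoulli  # import path assumed for this revision

    ps = Variable(torch.Tensor([0.2, 0.8]))
    mask = Variable(torch.Tensor([1.0, 0.0]))  # zero out the second site's log-pdf
    d = Bernoulli(ps, batch_size=5, log_pdf_mask=mask)
    print(d.ps.size())            # (5, 2): ps expanded along the batch dimension
    print(d.log_pdf_mask.size())  # (5, 2): mask now expanded with its own size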

pyro/distributions/beta.py

Lines changed: 30 additions & 41 deletions
@@ -16,75 +16,64 @@ class Beta(Distribution):
     Univariate beta distribution parameterized by alpha and beta
     """
 
-    def _sanitize_input(self, alpha, beta):
-        if alpha is not None:
-            # stateless distribution
-            return alpha, beta
-        elif self.alpha is not None:
-            # stateful distribution
-            return self.alpha, self.beta
-        else:
-            raise ValueError("Parameter(s) were None")
-
-    def __init__(self, alpha=None, beta=None, batch_size=None, *args, **kwargs):
+    def __init__(self, alpha, beta, batch_size=None, *args, **kwargs):
         """
         Params:
           `alpha` - alpha
           `beta` - beta
         """
         self.alpha = alpha
         self.beta = beta
-        if alpha is not None:
-            if alpha.dim() != beta.dim():
-                raise ValueError("Alpha and beta need to have the same dimensions.")
-            if alpha.dim() == 1 and beta.dim() == 1 and batch_size is not None:
-                self.alpha = alpha.expand(batch_size, alpha.size(0))
-                self.beta = beta.expand(batch_size, beta.size(0))
+        if alpha.size() != beta.size():
+            raise ValueError("Expected alpha.size() == beta.size(), but got {} vs {}"
+                             .format(alpha.size(), beta.size()))
+        if alpha.dim() == 1 and beta.dim() == 1 and batch_size is not None:
+            self.alpha = alpha.expand(batch_size, alpha.size(0))
+            self.beta = beta.expand(batch_size, beta.size(0))
         super(Beta, self).__init__(*args, **kwargs)
 
-    def batch_shape(self, alpha=None, beta=None):
-        alpha, beta = self._sanitize_input(alpha, beta)
+    def batch_shape(self, x=None):
         event_dim = 1
+        alpha = self.alpha
+        if x is not None and x.size() != alpha.size():
+            alpha = self.alpha.expand_as(x)
         return alpha.size()[:-event_dim]
 
-    def event_shape(self, alpha=None, beta=None):
-        alpha, beta = self._sanitize_input(alpha, beta)
+    def event_shape(self):
         event_dim = 1
-        return alpha.size()[-event_dim:]
+        return self.alpha.size()[-event_dim:]
+
+    def shape(self, x=None):
+        return self.batch_shape(x) + self.event_shape()
 
-    def sample(self, alpha=None, beta=None):
+    def sample(self):
         """
         Un-reparameterizable sampler.
         """
-        alpha, beta = self._sanitize_input(alpha, beta)
-        np_sample = spr.beta.rvs(alpha.data.cpu().numpy(), beta.data.cpu().numpy())
+        np_sample = spr.beta.rvs(self.alpha.data.cpu().numpy(),
+                                 self.beta.data.cpu().numpy())
         if isinstance(np_sample, numbers.Number):
            np_sample = [np_sample]
-        x = Variable(torch.Tensor(np_sample).type_as(alpha.data))
-        x = x.expand(self.shape(alpha, beta))
+        x = Variable(torch.Tensor(np_sample).type_as(self.alpha.data))
+        x = x.expand(self.shape())
         return x
 
-    def batch_log_pdf(self, x, alpha=None, beta=None):
-        alpha, beta = self._sanitize_input(alpha, beta)
-        assert alpha.dim() == beta.dim()
-        if alpha.size() != x.size():
-            alpha = alpha.expand_as(x)
-            beta = beta.expand_as(x)
+    def batch_log_pdf(self, x):
+        alpha = self.alpha.expand(self.shape(x))
+        beta = self.beta.expand(self.shape(x))
         one = Variable(torch.ones(x.size()).type_as(alpha.data))
         ll_1 = (alpha - one) * torch.log(x)
         ll_2 = (beta - one) * torch.log(one - x)
         ll_3 = log_gamma(alpha + beta)
         ll_4 = -log_gamma(alpha)
         ll_5 = -log_gamma(beta)
         batch_log_pdf = torch.sum(ll_1 + ll_2 + ll_3 + ll_4 + ll_5, -1)
-        batch_log_pdf_shape = self.batch_shape(alpha, beta) + (1,)
+        batch_log_pdf_shape = self.batch_shape(x) + (1,)
         return batch_log_pdf.contiguous().view(batch_log_pdf_shape)
 
-    def analytic_mean(self, alpha=None, beta=None):
-        alpha, beta = self._sanitize_input(alpha, beta)
-        return alpha / (alpha + beta)
+    def analytic_mean(self):
+        return self.alpha / (self.alpha + self.beta)
 
-    def analytic_var(self, alpha=None, beta=None):
-        alpha, beta = self._sanitize_input(alpha, beta)
-        return torch.pow(self.analytic_mean(alpha, beta), 2.0) * beta / \
-            (alpha * (alpha + beta + Variable(torch.ones([1]))))
+    def analytic_var(self):
+        return torch.pow(self.analytic_mean(), 2.0) * self.beta / \
+            (self.alpha * (self.alpha + self.beta + Variable(torch.ones([1]))))
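After the refactor, `Beta` is always constructed with its parameters, and `sample`, `batch_log_pdf`, `analytic_mean`, and `analytic_var` read them from `self`. A short usage sketch (import path and values illustrative):

    import torch
    from torch.autograd import Variable
    from pyro.distributions import Beta  # import path assumed for this revision

    d = Beta(Variable(torch.Tensor([2.0])), Variable(torch.Tensor([3.0])))
    x = d.sample()              # shape d.shape() == batch_shape + event_shape
    print(d.batch_log_pdf(x))   # no alpha/beta arguments needed any more
    print(d.analytic_mean())    # alpha / (alpha + beta) = 0.4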

pyro/distributions/cauchy.py

Lines changed: 26 additions & 34 deletions
@@ -18,70 +18,62 @@ class Cauchy(Distribution):
     the same shape as each other.
     """
 
-    def _sanitize_input(self, mu, gamma):
-        if mu is not None:
-            # stateless distribution
-            return mu, gamma
-        elif self.mu is not None:
-            # stateful distribution
-            return self.mu, self.gamma
-        else:
-            raise ValueError("Parameter(s) were None")
-
-    def __init__(self, mu=None, gamma=None, batch_size=None, *args, **kwargs):
+    def __init__(self, mu, gamma, batch_size=None, *args, **kwargs):
         """
         Params:
           `mu` - mean
          `gamma` - scale
         """
         self.mu = mu
         self.gamma = gamma
-        if mu is not None:
-            # this will be deprecated in a future PR
-            if mu.dim() == 1 and batch_size is not None:
-                self.mu = mu.expand(batch_size, mu.size(0))
-                self.gamma = gamma.expand(batch_size, gamma.size(0))
+        if mu.size() != gamma.size():
+            raise ValueError("Expected mu.size() == gamma.size(), but got {} vs {}"
+                             .format(mu.size(), gamma.size()))
+        if mu.dim() == 1 and batch_size is not None:
+            self.mu = mu.expand(batch_size, mu.size(0))
+            self.gamma = gamma.expand(batch_size, gamma.size(0))
         super(Cauchy, self).__init__(*args, **kwargs)
 
-    def batch_shape(self, mu=None, gamma=None):
-        mu, gamma = self._sanitize_input(mu, gamma)
+    def batch_shape(self, x=None):
         event_dim = 1
+        mu = self.mu
+        if x is not None and x.size() != mu.size():
+            mu = self.mu.expand_as(x)
         return mu.size()[:-event_dim]
 
-    def event_shape(self, mu=None, gamma=None):
-        mu, gamma = self._sanitize_input(mu, gamma)
+    def event_shape(self):
         event_dim = 1
-        return mu.size()[-event_dim:]
+        return self.mu.size()[-event_dim:]
+
+    def shape(self, x=None):
+        return self.batch_shape(x) + self.event_shape()
 
-    def sample(self, mu=None, gamma=None):
+    def sample(self):
         """
         Cauchy sampler.
         """
-        mu, gamma = self._sanitize_input(mu, gamma)
-        assert mu.dim() == gamma.dim()
-        np_sample = spr.cauchy.rvs(mu.data.cpu().numpy(), gamma.data.cpu().numpy())
+        np_sample = spr.cauchy.rvs(self.mu.data.cpu().numpy(),
+                                   self.gamma.data.cpu().numpy())
         if isinstance(np_sample, numbers.Number):
             np_sample = [np_sample]
-        sample = Variable(torch.Tensor(np_sample).type_as(mu.data))
+        sample = Variable(torch.Tensor(np_sample).type_as(self.mu.data))
         return sample
 
-    def batch_log_pdf(self, x, mu=None, gamma=None):
+    def batch_log_pdf(self, x):
         """
         Cauchy log-likelihood
         """
-        # expand to match size of input
-        mu, gamma = self._sanitize_input(mu, gamma)
-        if x.size() != mu.size():
-            mu = mu.expand_as(x)
-            gamma = gamma.expand_as(x)
+        # expand to match size of input
+        mu = self.mu.expand(self.shape(x))
+        gamma = self.gamma.expand(self.shape(x))
         x_0 = torch.pow((x - mu)/gamma, 2)
         px = np.pi * gamma * (1 + x_0)
         log_pdf = -1 * torch.sum(torch.log(px), -1)
-        batch_log_pdf_shape = self.batch_shape(mu, gamma) + (1,)
+        batch_log_pdf_shape = self.batch_shape(x) + (1,)
         return log_pdf.contiguous().view(batch_log_pdf_shape)
 
-    def analytic_mean(self, mu=None, gamma=None):
+    def analytic_mean(self):
         raise ValueError("Cauchy has no defined mean")
 
-    def analytic_var(self, mu=None, gamma=None):
+    def analytic_var(self):
         raise ValueError("Cauchy has no defined variance")

pyro/distributions/dirichlet.py

Lines changed: 30 additions & 57 deletions
@@ -16,74 +16,51 @@ class Dirichlet(Distribution):
     :param alpha: *(real (0, Infinity))*
     """
 
-    def _sanitize_input(self, alpha):
-        if alpha is not None:
-            # stateless distribution
-            return alpha
-        if self.alpha is not None:
-            # stateful distribution
-            return self.alpha
-        raise ValueError("Parameter(s) were None")
-
-    def _expand_dims(self, x, alpha):
-        """
-        Expand to 2-dimensional tensors of the same shape.
-        """
-        if not isinstance(x, (torch.Tensor, Variable)):
-            raise TypeError('Expected x a Tensor or Variable, got a {}'.format(type(x)))
-        if not isinstance(alpha, Variable):
-            raise TypeError('Expected alpha a Variable, got a {}'.format(type(alpha)))
-        if x.dim() not in (1, 2):
-            raise ValueError('Expected x.dim() in (1,2), actual: {}'.format(x.dim()))
-        if alpha.dim() not in (1, 2):
-            raise ValueError('Expected alpha.dim() in (1,2), actual: {}'.format(alpha.dim()))
-        if x.size() != alpha.size():
-            alpha = alpha.expand_as(x)
-        return x, alpha
-
-    def __init__(self, alpha=None, batch_size=None, *args, **kwargs):
+    def __init__(self, alpha, batch_size=None, *args, **kwargs):
         """
         :param alpha: A vector of concentration parameters.
         :type alpha: None or a torch.autograd.Variable of a torch.Tensor of dimension 1 or 2.
         :param int batch_size: DEPRECATED.
         """
-        if alpha is None:
-            self.alpha = None
-        else:
-            assert alpha.dim() in (1, 2)
-            self.alpha = alpha
+        self.alpha = alpha
+        if alpha.dim() not in (1, 2):
+            raise ValueError("Parameter alpha must be either 1 or 2 dimensional.")
+        if alpha.dim() == 1 and batch_size is not None:
+            self.alpha = alpha.expand(batch_size, alpha.size(0))
         super(Dirichlet, self).__init__(*args, **kwargs)
 
-    def batch_shape(self, alpha=None):
-        alpha = self._sanitize_input(alpha)
-        return alpha.size()[:-1]
+    def batch_shape(self, x=None):
+        event_dim = 1
+        alpha = self.alpha
+        if x is not None and x.size() != alpha.size():
+            alpha = self.alpha.expand_as(x)
+        return alpha.size()[:-event_dim]
 
-    def event_shape(self, alpha=None):
-        alpha = self._sanitize_input(alpha)
-        return alpha.size()[-1:]
+    def event_shape(self):
+        return self.alpha.size()[-1:]
 
-    def sample(self, alpha=None):
+    def shape(self, x=None):
+        return self.batch_shape(x) + self.event_shape()
+
+    def sample(self):
         """
         Draws either a single sample (if alpha.dim() == 1), or one sample per param (if alpha.dim() == 2).
 
         (Un-reparameterized).
 
         :param torch.autograd.Variable alpha:
         """
-        alpha = self._sanitize_input(alpha)
-        if alpha.dim() not in (1, 2):
-            raise ValueError('Expected alpha.dim() in (1,2), actual: {}'.format(alpha.dim()))
-        alpha_np = alpha.data.cpu().numpy()
-        if alpha.dim() == 1:
+        alpha_np = self.alpha.data.cpu().numpy()
+        if self.alpha.dim() == 1:
             x_np = spr.dirichlet.rvs(alpha_np)[0]
         else:
             x_np = np.empty_like(alpha_np)
             for i in range(alpha_np.shape[0]):
                 x_np[i, :] = spr.dirichlet.rvs(alpha_np[i, :])[0]
-        x = Variable(type(alpha.data)(x_np))
+        x = Variable(type(self.alpha.data)(x_np))
        return x
 
-    def batch_log_pdf(self, x, alpha=None):
+    def batch_log_pdf(self, x):
         """
         Evaluates log probability density over one or a batch of samples.
 
@@ -97,24 +74,20 @@ def batch_log_pdf(self, x, alpha=None):
         :return: log probability densities of each element in the batch.
         :rtype: torch.autograd.Variable of torch.Tensor of dimension 1.
         """
-        alpha = self._sanitize_input(alpha)
-        x, alpha = self._expand_dims(x, alpha)
-        assert x.size() == alpha.size()
+        alpha = self.alpha.expand(self.shape(x))
         x_sum = torch.sum(torch.mul(alpha - 1, torch.log(x)), -1)
         beta = log_beta(alpha)
-        batch_log_pdf_shape = self.batch_shape(alpha) + (1,)
+        batch_log_pdf_shape = self.batch_shape(x) + (1,)
         return (x_sum - beta).contiguous().view(batch_log_pdf_shape)
 
-    def analytic_mean(self, alpha):
-        alpha = self._sanitize_input(alpha)
-        sum_alpha = torch.sum(alpha)
-        return alpha / sum_alpha
+    def analytic_mean(self):
+        sum_alpha = torch.sum(self.alpha)
+        return self.alpha / sum_alpha
 
-    def analytic_var(self, alpha):
+    def analytic_var(self):
         """
         :return: Analytic variance of the dirichlet distribution, with parameter alpha.
         :rtype: torch.autograd.Variable (Vector of the same size as alpha).
         """
-        alpha = self._sanitize_input(alpha)
-        sum_alpha = torch.sum(alpha)
-        return alpha * (sum_alpha - alpha) / (torch.pow(sum_alpha, 2) * (1 + sum_alpha))
+        sum_alpha = torch.sum(self.alpha)
+        return self.alpha * (sum_alpha - self.alpha) / (torch.pow(sum_alpha, 2) * (1 + sum_alpha))
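`Dirichlet.sample()` draws a single sample for a 1-dimensional `alpha` and one sample per row for a 2-dimensional `alpha` (the scipy sampler runs row by row). A usage sketch (import path and values illustrative):

    import torch
    from torch.autograd import Variable
    from pyro.distributions import Dirichlet  # import path assumed for this revision

    alpha = Variable(torch.Tensor([[1.0, 2.0, 3.0],
                                   [10.0, 10.0, 10.0]]))
    d = Dirichlet(alpha)
    x = d.sample()                    # (2, 3): each row lies on the simplex
    print(d.batch_log_pdf(x).size())  # (2, 1)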
