Skip to content

Commit 01a7000

Browse files
Chris Fonnesbeck authored and twiecki committed
Re-implemented zero-inflated Poisson to use standard parameterization (pymc-devs#1300)
* Re-implemented zero-inflated Poisson to use standard parameterization, rather than using indicators
* Fixed typo in latent occupancy
* Fix to ZIP test
1 parent 2c9fb26 commit 01a7000

File tree

4 files changed

+48
-35
lines changed

4 files changed

+48
-35
lines changed

pymc3/distributions/discrete.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -432,20 +432,49 @@ def logp(self, value):
432432

433433

434434
class ZeroInflatedPoisson(Discrete):
435-
def __init__(self, theta, z, *args, **kwargs):
435+
R"""
436+
Zero-inflated Poisson log-likelihood.
437+
438+
Often used to model the number of events occurring in a fixed period
439+
of time when the times at which events occur are independent.
440+
441+
.. math::
442+
443+
f(x \mid \theta, \psi) = \left\{ \begin{array}{l}
444+
(1-\psi) + \psi e^{-\theta}, \text{if } x = 0 \\
445+
\psi \frac{e^{-\theta}\theta^x}{x!}, \text{if } x=1,2,3,\ldots
446+
\end{array} \right.
447+
448+
======== ==========================
449+
Support :math:`x \in \mathbb{N}_0`
450+
Mean :math:`\psi\theta`
451+
Variance :math:`\psi\theta + \psi(1-\psi)\theta^2`
452+
======== ==========================
453+
454+
Parameters
455+
----------
456+
theta : float
457+
Expected number of occurrences during the given interval
458+
(theta >= 0).
459+
psi : float
460+
Expected proportion of Poisson variates (0 < psi < 1)
461+
462+
"""
463+
def __init__(self, theta, psi, *args, **kwargs):
436464
super(ZeroInflatedPoisson, self).__init__(*args, **kwargs)
437465
self.theta = theta
438-
self.z = z
466+
self.psi = psi
439467
self.pois = Poisson.dist(theta)
440-
self.const = ConstantDist.dist(0)
441468
self.mode = self.pois.mode
442469

443470
def random(self, point=None, size=None, repeat=None):
444-
theta = draw_values([self.theta], point=point)
445-
# To do: Finish me
446-
return None
471+
theta, psi = draw_values([self.theta, self.psi], point=point)
472+
g = generate_samples(stats.poisson.rvs, theta,
473+
dist_shape=self.shape,
474+
size=size)
475+
return g * (np.random.random(np.squeeze(g.shape)) < psi)
447476

448477
def logp(self, value):
449-
return tt.switch(self.z,
450-
self.pois.logp(value),
451-
self.const.logp(value))
478+
return tt.switch(value > 0,
479+
tt.log(self.psi) + self.pois.logp(value),
480+
tt.log((1. - self.psi) + self.psi * tt.exp(-self.theta)))

pymc3/examples/latent_occupancy.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,13 @@
4747
with model:
4848
# Estimated occupancy
4949

50-
p = Beta('p', 1, 1)
51-
52-
# Latent variable for occupancy
53-
z = Bernoulli('z', p, y.shape)
50+
psi = Beta('psi', 1, 1)
5451

5552
# Estimated mean count
5653
theta = Uniform('theta', 0, 100)
5754

5855
# Poisson likelihood
59-
yd = ZeroInflatedPoisson('y', theta, z, observed=y)
56+
yd = ZeroInflatedPoisson('y', theta, psi, observed=y)
6057

6158

6259
point = model.test_point
@@ -77,11 +74,9 @@ def run(n=5000):
7774
if n == "short":
7875
n = 50
7976
with model:
80-
start = {'p': 0.5, 'z': (y > 0).astype(int), 'theta': 5}
81-
82-
step1 = Metropolis([theta, p])
77+
start = {'psi': 0.5, 'z': (y > 0).astype(int), 'theta': 5}
8378

84-
step2 = BinaryMetropolis([z])
79+
step1 = Metropolis([theta, psi])
8580

8681
trace = sample(n, [step1, step2], start)
8782

pymc3/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def test_constantdist():
380380
)
381381

382382
def test_zeroinflatedpoisson():
383-
checkd(ZeroInflatedPoisson, I, {'theta': Rplus, 'z': Bool})
383+
checkd(ZeroInflatedPoisson, Nat, {'theta': Rplus, 'psi': Unit})
384384

385385
def test_mvnormal():
386386
for n in [1, 2]:

pymc3/tests/test_distributions_random.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,8 @@ def test_constant_dist(self):
201201
self.check(ConstantDist, c=3)
202202

203203
def test_zero_inflated_poisson(self):
204-
# To do: implement ZIP random
205-
#self.check(ZeroInflatedPoisson)
206-
raise SkipTest(
207-
'ZeroInflatedPoisson random sampling not yet implemented.')
208-
204+
self.check(ZeroInflatedPoisson, theta=1, psi=0.3)
205+
209206
def test_discrete_uniform(self):
210207
self.check(DiscreteUniform, lower=0., upper=10)
211208

@@ -297,9 +294,7 @@ def test_constant_dist(self):
297294
self.check(ConstantDist, c=3)
298295

299296
def test_zero_inflated_poisson(self):
300-
# To do: implement ZIP random
301-
raise SkipTest(
302-
'ZeroInflatedPoisson random sampling not yet implemented.')
297+
self.check(ZeroInflatedPoisson, theta=1, psi=0.3)
303298

304299
def test_discrete_uniform(self):
305300
self.check(DiscreteUniform, lower=0., upper=10)
@@ -402,10 +397,7 @@ def test_constantDist(self):
402397
self.check(ConstantDist, c=(self.ones * 3).astype(int))
403398

404399
def test_zero_inflated_poisson(self):
405-
# To do: implement ZIP random
406-
raise SkipTest(
407-
'ZeroInflatedPoisson random sampling not yet implemented.')
408-
self.check(ZeroInflatedPoisson, {}, SkipTest)
400+
self.check(ZeroInflatedPoisson, theta=self.ones, psi=self.ones/2)
409401

410402
def test_discrete_uniform(self):
411403
self.check(DiscreteUniform,
@@ -515,10 +507,7 @@ def test_constantDist(self):
515507
self.check(ConstantDist, c=(self.ones * 3).astype(int))
516508

517509
def test_zero_inflated_poisson(self):
518-
# To do: implement ZIP random
519-
raise SkipTest(
520-
'ZeroInflatedPoisson random sampling not yet implemented.')
521-
self.check(ZeroInflatedPoisson, {})
510+
self.check(ZeroInflatedPoisson, theta=self.ones*2, psi=self.ones/3)
522511

523512
def test_discrete_uniform(self):
524513
self.check(DiscreteUniform, lower=self.zeros.astype(int),

0 commit comments

Comments
 (0)