Skip to content

Extend Rice distribution #3289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix Rice distribution and add new parametrization (#3286)
  • Loading branch information
nbud committed Dec 5, 2018
commit de17badf5a736c105c9ae83f2101378f259b033b
43 changes: 36 additions & 7 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3505,29 +3505,57 @@ class Rice(PositiveContinuous):
======== ==============================================================
Support :math:`x \in (0, \infty)`
Mean :math:`\sigma {\sqrt {\pi /2}}\,\,L_{{1/2}}(-\nu ^{2}/2\sigma ^{2})`
Variance :math:`2\sigma ^{2}+\nu ^{2}-{\frac {\pi \sigma ^{2}}{2}}L_{{1/2}}^{2}
\left({\frac {-\nu ^{2}}{2\sigma ^{2}}}\right)`
Variance :math:`2\sigma ^{2}+\nu ^{2}-{\frac {\pi \sigma ^{2}}{2}}L_{{1/2}}^{2}\left({\frac {-\nu ^{2}}{2\sigma ^{2}}}\right)`
======== ==============================================================


Parameters
----------
nu : float
shape parameter.
noncentrality parameter.
sd : float
standard deviation.
scale parameter.
b : float
shape parameter (alternative to nu).

Notes
-----
The distribution :math:`\mathrm{Rice}\left(|\nu|,\sigma\right)` is the
distribution of :math:`R=\sqrt{X^2+Y^2}` where :math:`X\sim N(\nu, \sigma^2)`,
:math:`Y\sim N(\nu, \sigma^2)`, and :math:`X` and :math:`Y` are independent.

The distribution is defined with either nu or b.
The link between the two parametrizations is given by

.. math::

b = \dfrac{\nu}{\sigma}

"""

def __init__(self, nu=None, sd=None, *args, **kwargs):
def __init__(self, nu=None, sd=None, b=None, *args, **kwargs):
super(Rice, self).__init__(*args, **kwargs)
nu, b, sd = self.get_nu_b(nu, b, sd)
self.nu = nu = tt.as_tensor_variable(nu)
self.sd = sd = tt.as_tensor_variable(sd)
self.b = b = tt.as_tensor_variable(b)
self.mean = sd * np.sqrt(np.pi / 2) * tt.exp((-nu**2 / (2 * sd**2)) / 2) * ((1 - (-nu**2 / (2 * sd**2)))
* i0(-(-nu**2 / (2 * sd**2)) / 2) - (-nu**2 / (2 * sd**2)) * i1(-(-nu**2 / (2 * sd**2)) / 2))
self.variance = 2 * sd**2 + nu**2 - (np.pi * sd**2 / 2) * (tt.exp((-nu**2 / (2 * sd**2)) / 2) * ((1 - (-nu**2 / (
2 * sd**2))) * i0(-(-nu**2 / (2 * sd**2)) / 2) - (-nu**2 / (2 * sd**2)) * i1(-(-nu**2 / (2 * sd**2)) / 2)))**2

def get_nu_b(self, nu, b, sd):
if sd is None:
sd = 1.
if nu is None and b is not None:
nu = b * sd
return nu, b, sd
elif nu is not None and b is None:
b = nu / sd
return nu, b, sd
raise ValueError('Rice distribution must specify either nu'
' or b.')

def random(self, point=None, size=None):
"""
Draw random values from Rice distribution.
Expand All @@ -3547,7 +3575,7 @@ def random(self, point=None, size=None):
"""
nu, sd = draw_values([self.nu, self.sd],
point=point, size=size)
return generate_samples(stats.rice.rvs, b=nu, scale=sd, loc=0,
return generate_samples(stats.rice.rvs, b=nu/sd, scale=sd, loc=0,
dist_shape=self.shape, size=size)

def logp(self, value):
Expand All @@ -3566,8 +3594,9 @@ def logp(self, value):
"""
nu = self.nu
sd = self.sd
b = self.b
x = value / sd
return bound(tt.log(x * tt.exp((-(x - nu) * (x - nu)) / 2) * i0e(x * nu) / sd),
return bound(tt.log(x * tt.exp((-(x - b) * (x - b)) / 2) * i0e(x * b) / sd),
sd >= 0,
nu >= 0,
value > 0,
Expand Down
4 changes: 3 additions & 1 deletion pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,9 @@ def test_multidimensional_beta_construction(self):

def test_rice(self):
self.pymc3_matches_scipy(Rice, Rplus, {'nu': Rplus, 'sd': Rplusbig},
lambda value, nu, sd: sp.rice.logpdf(value, b=nu, loc=0, scale=sd))
lambda value, nu, sd: sp.rice.logpdf(value, b=nu/sd, loc=0, scale=sd))
self.pymc3_matches_scipy(Rice, Rplus, {'b': Rplus, 'sd': Rplusbig},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this test work before your last commit with the elementwise fixes? If not, we might want to add one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is actually part of #3287 and successfully passes. The current PR (#3289) initially broke it when I tried to use elemwise but now passes. If PR #3287 is merged, I'll rebase the current PR so it will be clearer to see the differences.

lambda value, b, sd: sp.rice.logpdf(value, b=b, loc=0, scale=sd))

@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
def test_interpolated(self):
Expand Down