Skip to content

Commit ffae3f2

Browse files
authored
Update multivariate.py
1 parent 7bb2ccd commit ffae3f2

File tree

1 file changed

+160
-0
lines changed

1 file changed

+160
-0
lines changed

pymc/distributions/multivariate.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,48 @@ def logp(value, mu, cov):
275275
ok,
276276
msg="posdef",
277277
)
278+
279+
class MvSkewNormalRV(RandomVariable):
    """Random-variable op drawing from a multivariate skew-normal distribution.

    The op takes three parameters, in this order:

    * ``alpha`` -- skewness (shape) vector of length ``p``;
    * ``mu`` -- location vector of length ``p``;
    * ``cov`` -- ``p x p`` scale matrix.
    """

    name = "multivariate_skewnormal"
    ndim_supp = 1
    # alpha and mu are vectors, cov is a matrix.
    ndims_params = [1, 1, 2]
    dtype = "floatX"
    _print_name = ("MvSkewNormal", "\\operatorname{MvSkewNormal}")

    def _supp_shape_from_params(self, dist_params, param_shapes=None):
        """Infer the support shape from ``mu`` (parameter index 1)."""
        return supp_shape_from_ref_param_shape(
            ndim_supp=self.ndim_supp,
            dist_params=dist_params,
            param_shapes=param_shapes,
            ref_param_idx=1,
        )

    @classmethod
    def rng_fn(cls, rng, alpha, mu, cov, size):
        """Draw skew-normal samples through the hidden-truncation representation.

        With ``omega = sqrt(diag(cov))`` and ``corr`` the correlation matrix
        of ``cov``, let ``delta = corr @ alpha / sqrt(1 + alpha' corr alpha)``
        and draw ``(x0, x) ~ N(0, [[1, delta'], [delta, corr]])``.  Then
        ``z = x`` where ``x0 > 0`` and ``z = -x`` otherwise is a standardised
        skew-normal draw, and ``mu + omega * z`` is the returned sample.

        Parameters
        ----------
        rng : numpy random generator
            Source of randomness, forwarded to ``multivariate_normal.rng_fn``.
        alpha : array_like, shape (..., p)
            Skewness vector.
        mu : array_like, shape (..., p)
            Location vector.
        cov : array_like, shape (..., p, p)
            Scale matrix.
        size : tuple of int or None
            Batch shape of the draw, forwarded to
            ``multivariate_normal.rng_fn``.

        Returns
        -------
        numpy.ndarray
            Samples whose last axis has length ``p``.
        """
        alpha = np.asarray(alpha)
        mu = np.asarray(mu)
        cov = np.asarray(cov)

        # Split the scale matrix into standard deviations and a correlation matrix.
        omega = np.sqrt(np.diagonal(cov, axis1=-2, axis2=-1))
        corr = cov / (omega[..., :, None] * omega[..., None, :])

        # delta = corr @ alpha / sqrt(1 + alpha' corr alpha), batched over leading axes.
        corr_alpha = (corr @ alpha[..., None])[..., 0]
        aCa = np.sum(alpha * corr_alpha, axis=-1)
        delta = corr_alpha / np.sqrt(1 + aCa)[..., None]

        # Use the joint batch shape of all parameters so that an implicit size
        # still yields one independent draw per batch entry.
        batch_shape = np.broadcast_shapes(delta.shape[:-1], mu.shape[:-1])
        p = delta.shape[-1]

        # Covariance of the augmented (p + 1)-dimensional latent normal vector.
        cov_star = np.empty((*batch_shape, p + 1, p + 1), dtype=delta.dtype)
        cov_star[..., 0, 0] = 1.0
        cov_star[..., 0, 1:] = delta
        cov_star[..., 1:, 0] = delta
        cov_star[..., 1:, 1:] = corr

        mv_samples = multivariate_normal.rng_fn(
            rng=rng,
            mean=np.zeros((*batch_shape, p + 1)),
            cov=cov_star,
            size=size,
        )

        x0 = mv_samples[..., 0]
        x1 = mv_samples[..., 1:]
        # Reflect the draws whose latent component is non-positive.
        z = np.where((x0 <= 0)[..., None], -x1, x1)
        return mu + omega * z
319+
278320

279321

280322
class MvStudentTRV(RandomVariable):
@@ -310,6 +352,124 @@ def rng_fn(cls, rng, nu, mu, cov, size):
310352
chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]
311353

312354
return (mv_samples / chi2_samples) + mu
355+
356+
# Shared op instance; MvSkewNormal refers to it through its `rv_op` attribute.
mv_skewnormal = MvSkewNormalRV()
357+
358+
class MvSkewNormal(Continuous):
    r"""
    Multivariate skew-normal log-likelihood.

    .. math::
        f(\mathbf{x} \mid \alpha, \mu, \Sigma) =
            2 \, \phi_p(\mathbf{x} - \mu; \Sigma) \,
            \Phi\left(\alpha^T \omega^{-1} (\mathbf{x} - \mu)\right)

    where :math:`\phi_p(\cdot; \Sigma)` is the density of a zero-mean
    :math:`p`-variate normal with covariance :math:`\Sigma`, :math:`\Phi` is
    the standard normal CDF and
    :math:`\omega = \operatorname{diag}(\Sigma)^{1/2}`.

    With :math:`\bar\Sigma = \omega^{-1} \Sigma \omega^{-1}` and
    :math:`\delta = \bar\Sigma \alpha / \sqrt{1 + \alpha^T \bar\Sigma \alpha}`:

    ========  =============================================
    Support   :math:`x \in \mathbb{R}^p`
    Mean      :math:`\mu + \sqrt{2/\pi}\, \omega \delta`
    Variance  :math:`\Sigma - \frac{2}{\pi}\, \omega \delta \delta^T \omega`
    ========  =============================================

    Parameters
    ----------
    alpha : tensor_like of float
        Skewness (shape) vector of length :math:`p`. ``alpha = 0`` gives the
        multivariate normal density.
    Sigma : tensor_like of float, optional
        Scale matrix. Use `scale` in new code.
    mu : tensor_like of float, optional
        Location vector.
    scale : tensor_like of float, optional
        The scale matrix.
    tau : tensor_like of float, optional
        The precision matrix.
    chol : tensor_like of float, optional
        The cholesky factor of the scale matrix.
    lower : bool, default=True
        Whether the cholesky factor is given as a lower triangular matrix.
    """

    rv_op = mv_skewnormal

    @classmethod
    def dist(cls, alpha, *, Sigma=None, mu=0, scale=None, tau=None, chol=None, lower=True, **kwargs):
        """Build the skew-normal random variable from user-facing parameters.

        Raises
        ------
        ValueError
            If both `scale` (or the deprecated `cov`) and `Sigma` are given.
        """
        cov = kwargs.pop("cov", None)
        if cov is not None:
            warnings.warn(
                "Use the scale argument to specify the scale matrix. "
                "cov will be removed in future versions.",
                FutureWarning,
            )
            scale = cov
        if Sigma is not None:
            if scale is not None:
                raise ValueError("Specify only one of scale and Sigma")
            scale = Sigma
        alpha = pt.as_tensor_variable(floatX(alpha))
        mu = pt.as_tensor_variable(floatX(mu))
        # Resolve whichever of scale / chol / tau was supplied into a dense
        # scale matrix.
        # NOTE(review): `quaddist_matrix` is expected to be the module-level
        # helper defined elsewhere in this file -- confirm it is in scope.
        scale = quaddist_matrix(scale, chol, tau, lower)
        # PyTensor is stricter about the shape of mu, than PyMC used to be
        mu, _ = pt.broadcast_arrays(mu, scale[..., -1])
        # Same for alpha, so that a scalar skewness is expanded to a vector.
        alpha, _ = pt.broadcast_arrays(alpha, scale[..., -1])

        return super().dist([alpha, mu, scale], **kwargs)

    def moment(rv, size, alpha, mu, scale):
        """Return `mu`, broadcast to the requested size, as a representative point."""
        # mu is broadcasted to the potential length of scale in `dist`;
        # alpha is a vector as well, hence one core dimension for each.
        mu, _ = pt.random.utils.broadcast_params([mu, alpha], ndims_params=[1, 1])
        moment = mu
        if not rv_size_is_none(size):
            moment_size = pt.concatenate([size, [mu.shape[-1]]])
            moment = pt.full(moment_size, moment)
        return moment

    def logp(value, alpha, mu, scale):
        """
        Calculate log-probability of the multivariate skew-normal distribution
        at specified value.

        Parameters
        ----------
        value: numeric
            Value for which log-probability is calculated.

        Returns
        -------
        TensorVariable
        """
        # Log-density of the symmetric base distribution N(mu, scale) at `value`.
        mv_normal = pm.MvNormal.dist(mu=mu, cov=scale)
        log_prob_mv_normal = pm.logp(mv_normal, value)

        # Marginal standard deviations, i.e. the diagonal of omega.
        std_devs = pt.sqrt(pt.diagonal(scale, axis1=-2, axis2=-1))

        # Argument of the skewing factor: alpha' omega^{-1} (value - mu).
        arg_cdf = pt.sum(alpha * (value - mu) / std_devs, axis=-1)

        # Log CDF of the standard normal at that argument.
        std_normal = pm.Normal.dist(mu=0, sigma=1)
        log_cdf_std_normal = pm.logcdf(std_normal, arg_cdf)

        # log f(x) = log 2 + log phi_p(x - mu; scale) + log Phi(arg_cdf)
        return np.log(2.0) + log_prob_mv_normal + log_cdf_std_normal
313473

314474

315475
# Module-level instance of the MvStudentTRV op.
mv_studentt = MvStudentTRV()

0 commit comments

Comments
 (0)