@@ -16,74 +16,51 @@ class Dirichlet(Distribution):
     :param alpha: *(real (0, Infinity))*
     """
 
-    def _sanitize_input(self, alpha):
-        if alpha is not None:
-            # stateless distribution
-            return alpha
-        if self.alpha is not None:
-            # stateful distribution
-            return self.alpha
-        raise ValueError("Parameter(s) were None")
-
-    def _expand_dims(self, x, alpha):
-        """
-        Expand to 2-dimensional tensors of the same shape.
-        """
-        if not isinstance(x, (torch.Tensor, Variable)):
-            raise TypeError('Expected x a Tensor or Variable, got a {}'.format(type(x)))
-        if not isinstance(alpha, Variable):
-            raise TypeError('Expected alpha a Variable, got a {}'.format(type(alpha)))
-        if x.dim() not in (1, 2):
-            raise ValueError('Expected x.dim() in (1,2), actual: {}'.format(x.dim()))
-        if alpha.dim() not in (1, 2):
-            raise ValueError('Expected alpha.dim() in (1,2), actual: {}'.format(alpha.dim()))
-        if x.size() != alpha.size():
-            alpha = alpha.expand_as(x)
-        return x, alpha
-
-    def __init__(self, alpha=None, batch_size=None, *args, **kwargs):
+    def __init__(self, alpha, batch_size=None, *args, **kwargs):
         """
         :param alpha: A vector of concentration parameters.
         :type alpha: torch.autograd.Variable of a torch.Tensor of dimension 1 or 2.
         :param int batch_size: DEPRECATED.
         """
-        if alpha is None:
-            self.alpha = None
-        else:
-            assert alpha.dim() in (1, 2)
-            self.alpha = alpha
+        self.alpha = alpha
+        if alpha.dim() not in (1, 2):
+            raise ValueError("Parameter alpha must be either 1 or 2 dimensional.")
+        if alpha.dim() == 1 and batch_size is not None:
+            self.alpha = alpha.expand(batch_size, alpha.size(0))
         super(Dirichlet, self).__init__(*args, **kwargs)
 
-    def batch_shape(self, alpha=None):
-        alpha = self._sanitize_input(alpha)
-        return alpha.size()[:-1]
+    def batch_shape(self, x=None):
+        event_dim = 1
+        alpha = self.alpha
+        if x is not None and x.size() != alpha.size():
+            alpha = self.alpha.expand_as(x)
+        return alpha.size()[:-event_dim]
 
-    def event_shape(self, alpha=None):
-        alpha = self._sanitize_input(alpha)
-        return alpha.size()[-1:]
+    def event_shape(self):
+        return self.alpha.size()[-1:]
 
-    def sample(self, alpha=None):
+    def shape(self, x=None):
+        return self.batch_shape(x) + self.event_shape()
+
+    def sample(self):
         """
         Draws either a single sample (if alpha.dim() == 1), or one sample per param (if alpha.dim() == 2).
 
         (Un-reparameterized).
 
         :param torch.autograd.Variable alpha:
         """
-        alpha = self._sanitize_input(alpha)
-        if alpha.dim() not in (1, 2):
-            raise ValueError('Expected alpha.dim() in (1,2), actual: {}'.format(alpha.dim()))
-        alpha_np = alpha.data.cpu().numpy()
-        if alpha.dim() == 1:
+        alpha_np = self.alpha.data.cpu().numpy()
+        if self.alpha.dim() == 1:
             x_np = spr.dirichlet.rvs(alpha_np)[0]
         else:
             x_np = np.empty_like(alpha_np)
             for i in range(alpha_np.shape[0]):
                 x_np[i, :] = spr.dirichlet.rvs(alpha_np[i, :])[0]
-        x = Variable(type(alpha.data)(x_np))
+        x = Variable(type(self.alpha.data)(x_np))
         return x
 
-    def batch_log_pdf(self, x, alpha=None):
+    def batch_log_pdf(self, x):
8764 """
8865 Evaluates log probabity density over one or a batch of samples.
8966
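
For readers skimming the refactor: the distribution is now stateful, with `alpha` fixed at construction, `batch_shape`/`event_shape` derived from it (optionally broadcast against a sample `x`), and a new `shape()` helper combining the two. Below is a minimal standalone sketch of these shape and sampling semantics, not the repo's code: it assumes the `spr` alias in this file is `scipy.stats` and stands in NumPy arrays for the `Variable`/Tensor machinery.

```python
# Minimal sketch (not the repo's code): mirrors the new stateful shape
# semantics with NumPy in place of torch Variables/Tensors.
import numpy as np
import scipy.stats as spr  # assumed to match the file's `spr` alias

alpha = np.array([0.5, 1.0, 2.0])         # a single concentration vector
batched = np.broadcast_to(alpha, (4, 3))  # what batch_size=4 expansion yields

# batch_shape strips the rightmost (event) dimension; event_shape keeps it.
assert batched.shape[:-1] == (4,)   # batch_shape
assert batched.shape[-1:] == (3,)   # event_shape
assert batched.shape[:-1] + batched.shape[-1:] == (4, 3)  # shape()

# sample(): one Dirichlet draw per row of a 2-D alpha, as in the loop above.
x = np.empty_like(batched)
for i in range(batched.shape[0]):
    x[i, :] = spr.dirichlet.rvs(batched[i, :])[0]
assert x.shape == batched.shape and np.allclose(x.sum(-1), 1.0)
```
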
@@ -97,24 +74,20 @@ def batch_log_pdf(self, x, alpha=None):
         :return: log probability densities of each element in the batch.
         :rtype: torch.autograd.Variable of torch.Tensor of dimension 1.
         """
-        alpha = self._sanitize_input(alpha)
-        x, alpha = self._expand_dims(x, alpha)
-        assert x.size() == alpha.size()
+        alpha = self.alpha.expand(self.shape(x))
         x_sum = torch.sum(torch.mul(alpha - 1, torch.log(x)), -1)
         beta = log_beta(alpha)
-        batch_log_pdf_shape = self.batch_shape(alpha) + (1,)
+        batch_log_pdf_shape = self.batch_shape(x) + (1,)
         return (x_sum - beta).contiguous().view(batch_log_pdf_shape)
 
-    def analytic_mean(self, alpha):
-        alpha = self._sanitize_input(alpha)
-        sum_alpha = torch.sum(alpha)
-        return alpha / sum_alpha
+    def analytic_mean(self):
+        sum_alpha = torch.sum(self.alpha)
+        return self.alpha / sum_alpha
 
-    def analytic_var(self, alpha):
+    def analytic_var(self):
11488 """
11589 :return: Analytic variance of the dirichlet distribution, with parameter alpha.
11690 :rtype: torch.autograd.Variable (Vector of the same size as alpha).
11791 """
-        alpha = self._sanitize_input(alpha)
-        sum_alpha = torch.sum(alpha)
-        return alpha * (sum_alpha - alpha) / (torch.pow(sum_alpha, 2) * (1 + sum_alpha))
+        sum_alpha = torch.sum(self.alpha)
+        return self.alpha * (sum_alpha - self.alpha) / (torch.pow(sum_alpha, 2) * (1 + sum_alpha))
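
Taken together, `batch_log_pdf` computes log p(x) = Σᵢ (αᵢ − 1) log xᵢ − log B(α), where log B(α) = Σᵢ lgamma(αᵢ) − lgamma(Σᵢ αᵢ), and `analytic_mean`/`analytic_var` are the standard Dirichlet moments αᵢ/α₀ and αᵢ(α₀ − αᵢ)/(α₀²(1 + α₀)) with α₀ = Σᵢ αᵢ. Here is a quick sketch cross-checking those formulas against scipy for a single 1-D `alpha` (the only case where `torch.sum(self.alpha)` over the whole tensor coincides with a per-row sum).

```python
# Sketch verifying the formulas above against scipy (1-D alpha only).
import numpy as np
import scipy.stats as spr
from scipy.special import gammaln

alpha = np.array([0.5, 1.0, 2.0])
x = spr.dirichlet.rvs(alpha, random_state=0)[0]

log_beta = gammaln(alpha).sum() - gammaln(alpha.sum())  # log B(alpha)
log_pdf = ((alpha - 1) * np.log(x)).sum() - log_beta    # x_sum - beta above
assert np.isclose(log_pdf, spr.dirichlet.logpdf(x, alpha))

a0 = alpha.sum()
assert np.allclose(alpha / a0, spr.dirichlet.mean(alpha))      # analytic_mean
assert np.allclose(alpha * (a0 - alpha) / (a0**2 * (1 + a0)),
                   spr.dirichlet.var(alpha))                   # analytic_var
```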