Skip to content

Commit c761c5d

Browse files
author
Junpeng Lao
authored
Merge pull request pymc-devs#2299 from ferrine/fix_2298
trying to fix pymc-devs#2298
2 parents dfd0b8c + 3aa2605 commit c761c5d

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

pymc3/model.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -783,13 +783,17 @@ def __call__(self, *args, **kwargs):
783783
compilef = fastfn
784784

785785

786-
def _get_scaling(total_size, data):
786+
def _get_scaling(total_size, shape, ndim):
787787
"""
788+
Gets scaling constant for logp
788789
789790
Parameters
790791
----------
791792
total_size : int or list[int]
792-
data : n-dimentional tensor
793+
shape : shape
794+
shape to scale
795+
ndim : int
796+
ndim hint
793797
794798
Returns
795799
-------
@@ -798,16 +802,15 @@ def _get_scaling(total_size, data):
798802
if total_size is None:
799803
coef = pm.floatX(1)
800804
elif isinstance(total_size, int):
801-
if data.ndim >= 1:
802-
denom = data.shape[0]
805+
if ndim >= 1:
806+
denom = shape[0]
803807
else:
804808
denom = 1
805809
coef = pm.floatX(total_size) / pm.floatX(denom)
806810
elif isinstance(total_size, (list, tuple)):
807811
if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
808812
raise TypeError('Unrecognized `total_size` type, expected '
809813
'int or list of ints, got %r' % total_size)
810-
shape = data.shape
811814
if Ellipsis in total_size:
812815
sep = total_size.index(Ellipsis)
813816
begin = total_size[:sep]
@@ -817,7 +820,7 @@ def _get_scaling(total_size, data):
817820
else:
818821
begin = total_size
819822
end = []
820-
if (len(begin) + len(end)) > data.ndim:
823+
if (len(begin) + len(end)) > ndim:
821824
raise ValueError('Length of `total_size` is too big, '
822825
'number of scalings is bigger that ndim, got %r' % total_size)
823826
elif (len(begin) + len(end)) == 0:
@@ -866,7 +869,7 @@ def __init__(self, type=None, owner=None, index=None, name=None,
866869
self.logp_elemwiset = distribution.logp(self)
867870
self.total_size = total_size
868871
self.model = model
869-
self.scaling = _get_scaling(total_size, self)
872+
self.scaling = _get_scaling(total_size, self.shape, self.ndim)
870873

871874
incorporate_methods(source=distribution, destination=self,
872875
methods=['random'],
@@ -972,7 +975,7 @@ def __init__(self, type=None, owner=None, index=None, name=None, data=None,
972975
theano.gof.Apply(theano.compile.view_op,
973976
inputs=[data], outputs=[self])
974977
self.tag.test_value = theano.compile.view_op(data).tag.test_value
975-
self.scaling = _get_scaling(total_size, data)
978+
self.scaling = _get_scaling(total_size, data.shape, data.ndim)
976979

977980
def _repr_latex_(self, name=None, dist=None):
978981
if self.distribution is None:
@@ -1016,6 +1019,7 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
10161019
self.total_size = total_size
10171020
self.model = model
10181021
self.distribution = distribution
1022+
self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)
10191023

10201024

10211025
def Deterministic(name, var, model=None):
@@ -1093,7 +1097,7 @@ def __init__(self, type=None, owner=None, index=None, name=None,
10931097
theano.Apply(theano.compile.view_op, inputs=[
10941098
normalRV], outputs=[self])
10951099
self.tag.test_value = normalRV.tag.test_value
1096-
1100+
self.scaling = _get_scaling(total_size, self.shape, self.ndim)
10971101
incorporate_methods(source=distribution, destination=self,
10981102
methods=['random'],
10991103
wrapper=InstanceMethod)

0 commit comments

Comments
 (0)