Skip to content

Commit 532b1aa

Browse files
aseyboldttwiecki
authored andcommitted
Improve performance of transformations
1 parent f4d9a7c commit 532b1aa

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

pymc3/distributions/transforms.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ def backward(self, x):
8686
def forward(self, x):
8787
return tt.log(x)
8888

89+
def jacobian_det(self, x):
90+
return x
91+
8992
log = Log()
9093

9194

@@ -109,20 +112,22 @@ class Interval(ElemwiseTransform):
109112

110113
name = "interval"
111114

112-
def __init__(self, a, b, eps=1e-6):
115+
def __init__(self, a, b):
113116
self.a = a
114117
self.b = b
115-
self.eps = eps
116118

117119
def backward(self, x):
118120
a, b = self.a, self.b
119-
r = (b - a) / (1 + tt.exp(-x)) + a
121+
r = (b - a) * tt.nnet.sigmoid(x) + a
120122
return r
121123

122124
def forward(self, x):
123-
a, b, e = self.a, self.b, self.eps
124-
r = tt.log(tt.maximum((x - a) / tt.maximum(b - x, e), e))
125-
return r
125+
a, b = self.a, self.b
126+
return tt.log(x - a) - tt.log(b - x)
127+
128+
def jacobian_det(self, x):
129+
s = tt.nnet.softplus(-x)
130+
return tt.log(self.b - self.a) - 2 * s - x
126131

127132
interval = Interval
128133

@@ -145,6 +150,9 @@ def forward(self, x):
145150
r = tt.log(x - a)
146151
return r
147152

153+
def jacobian_det(self, x):
154+
return x
155+
148156
lowerbound = LowerBound
149157

150158

@@ -166,6 +174,9 @@ def forward(self, x):
166174
r = tt.log(b - x)
167175
return r
168176

177+
def jacobian_det(self, x):
178+
return x
179+
169180
upperbound = UpperBound
170181

171182

0 commit comments

Comments
 (0)