Skip to content

Commit eb55106

Browse files
committed
Associate transforms with random variables
1 parent c64e12f commit eb55106

File tree

6 files changed

+105
-76
lines changed

6 files changed

+105
-76
lines changed

aeppl/joint_logprob.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,16 @@ def conditional_logprob(
110110

111111
fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
112112

113+
# The interface for transformations assumes that the value variables are in
114+
# the transformed space. To get the correct `shape` and `dtype` for the
115+
# value variables we return we need to apply the forward transformation to
116+
# our RV copies, and return the type of the resulting variable as a value
117+
# variable.
118+
vv_remapper = {}
113119
if extra_rewrites is not None:
114-
extra_rewrites.rewrite(fgraph)
120+
extra_rewrites.add_requirements(fgraph, {**original_rv_values, **realized})
121+
extra_rewrites.apply(fgraph)
122+
vv_remapper = fgraph.values_to_untransformed
115123

116124
rv_remapper = fgraph.preserve_rv_mappings
117125

@@ -145,6 +153,7 @@ def conditional_logprob(
145153
q = deque(fgraph.toposort())
146154

147155
logprob_vars = {}
156+
value_variables = {}
148157

149158
while q:
150159
node = q.popleft()
@@ -201,6 +210,9 @@ def conditional_logprob(
201210

202211
logprob_vars[q_rv] = q_logprob_var
203212

213+
q_value_var = vv_remapper.get(q_value_var, q_value_var)
214+
value_variables[q_rv] = q_value_var
215+
204216
# Recompute test values for the changes introduced by the
205217
# replacements above.
206218
if config.compute_test_value != "off":
@@ -213,7 +225,7 @@ def conditional_logprob(
213225
f"The logprob terms of the following random variables could not be derived: {missing_value_terms}"
214226
)
215227

216-
return logprob_vars, list(original_rv_values.values())
228+
return logprob_vars, [value_variables[rv] for rv in original_rv_values.keys()]
217229

218230

219231
def joint_logprob(

aeppl/transforms.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,14 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
122122
"""
123123

124124
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
125+
values_to_untransformed = getattr(fgraph, "values_to_untransformed", None)
125126
values_to_transforms = getattr(fgraph, "values_to_transforms", None)
126127

127-
if rv_map_feature is None or values_to_transforms is None:
128+
if (
129+
rv_map_feature is None
130+
or values_to_transforms is None
131+
or values_to_untransformed is None
132+
):
128133
return None # pragma: no cover
129134

130135
try:
@@ -133,6 +138,7 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
133138
except ValueError:
134139
return None
135140

141+
value_var: TensorVariable
136142
value_var = rv_map_feature.rv_values.get(rv_var, None)
137143
if value_var is None:
138144
return None
@@ -154,10 +160,21 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
154160
trans_node.outputs[rv_var_out_idx].name = node.outputs[rv_var_out_idx].name
155161

156162
# We now assume that the old value variable represents the *transformed space*.
157-
# This means that we need to replace all instance of the old value variable
163+
164+
# Since we initialize value variables as copies of the random variables,
165+
# thus in the untransformed space, we need to apply the forward
166+
# transformation to get value variables in the transformed space.
167+
old_value_var: TensorVariable = transform.forward(
168+
value_var, *trans_node.inputs
169+
).type()
170+
if value_var.name:
171+
old_value_var.name = value_var.name
172+
values_to_untransformed[value_var] = old_value_var
173+
174+
# We need to replace all instance of the old value variable
158175
# with "inversely/un-" transformed versions of itself.
159176
new_value_var = transformed_variable(
160-
transform.backward(value_var, *trans_node.inputs), value_var
177+
transform.backward(old_value_var, *trans_node.inputs), old_value_var
161178
)
162179
if value_var.name and getattr(transform, "name", None):
163180
new_value_var.name = f"{value_var.name}_{transform.name}"
@@ -170,16 +187,24 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
170187

171188

172189
class TransformValuesMapping(Feature):
173-
r"""A `Feature` that maintains a map between value variables and their transforms."""
190+
r"""A `Feature` that maintains a map between value variables and their transforms as
191+
well as between value variables and their transformed counterpart.
192+
193+
"""
174194

175195
def __init__(self, values_to_transforms):
176196
self.values_to_transforms = values_to_transforms
197+
self.values_to_untransformed: Dict[TensorVariable, TensorVariable] = {}
177198

178199
def on_attach(self, fgraph):
179200
if hasattr(fgraph, "values_to_transforms"):
180201
raise AlreadyThere()
181202

182203
fgraph.values_to_transforms = self.values_to_transforms
204+
fgraph.values_to_untransformed = self.values_to_untransformed
205+
206+
def update_untransformed_value(self, value, untransformed_value):
207+
self.values_to_untransformed[value] = untransformed_value
183208

184209

185210
class TransformValuesRewrite(GraphRewriter):
@@ -189,25 +214,31 @@ class TransformValuesRewrite(GraphRewriter):
189214

190215
def __init__(
191216
self,
192-
values_to_transforms: Dict[
217+
rvs_to_transforms: Dict[
193218
TensorVariable, Union[RVTransform, DefaultTransformSentinel, None]
194219
],
195220
):
196221
"""
197222
Parameters
198223
==========
199224
values_to_transforms
200-
Mapping between value variables and their transformations. Each
201-
value variable can be assigned one of `RVTransform`,
202-
``DEFAULT_TRANSFORM``, or ``None``. If a transform is not specified
203-
for a specific value variable it will not be transformed.
225+
Mapping between random variables and their transformations. Each
226+
random variable can be assigned one of `RVTransform`,
227+
``DEFAULT_TRANSFORM``, or ``None``. Random variables with no
228+
transform specified remain unchanged.
204229
205230
"""
206231

207-
self.values_to_transforms = values_to_transforms
232+
self.rvs_to_transforms = rvs_to_transforms
208233

209-
def add_requirements(self, fgraph):
210-
values_transforms_feature = TransformValuesMapping(self.values_to_transforms)
234+
def add_requirements(
235+
self, fgraph, rv_to_values: Dict[TensorVariable, TensorVariable]
236+
):
237+
values_to_transforms = {
238+
rv_to_values[rv]: transform
239+
for rv, transform in self.rvs_to_transforms.items()
240+
}
241+
values_transforms_feature = TransformValuesMapping(values_to_transforms)
211242
fgraph.attach_feature(values_transforms_feature)
212243

213244
def apply(self, fgraph: FunctionGraph):

tests/test_censoring.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,8 @@ def test_clip_transform():
154154
x_rv = at.random.normal(0.5, 1)
155155
cens_x_rv = at.clip(x_rv, 0, x_rv)
156156

157-
cens_x_vv = cens_x_rv.clone()
158-
159-
transform = TransformValuesRewrite({cens_x_vv: LogTransform()})
160-
logp, _ = joint_logprob(realized={cens_x_rv: cens_x_vv}, extra_rewrites=transform)
157+
transform = TransformValuesRewrite({cens_x_rv: LogTransform()})
158+
logp, (cens_x_vv,) = joint_logprob(cens_x_rv, extra_rewrites=transform)
161159

162160
cens_x_vv_testval = -1
163161
obs_logp = logp.eval({cens_x_vv: cens_x_vv_testval})

tests/test_joint_logprob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,6 @@ def test_multiple_rvs_to_same_value_raises():
247247
x = x_rv1.type()
248248
x.name = "x"
249249

250-
msg = "More than one logprob factor was assigned to the value variable x"
250+
msg = "More than one logprob factor was assigned to the random variable x"
251251
with pytest.raises(ValueError, match=msg):
252252
joint_logprob(realized={x_rv1: x, x_rv2: x})

tests/test_mixture.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def create_mix_model(size, axis):
4141
I_rv = env["I_rv"]
4242
M_rv = env["M_rv"]
4343

44-
with pytest.raises(RuntimeError, match="could not be derived: {m}"):
44+
with pytest.raises(
45+
RuntimeError,
46+
match="The logprob terms of the following random variables could not be derived: {M}",
47+
):
4548
conditional_logprob(M_rv, I_rv, X_rv)
4649

4750
with pytest.raises(NotImplementedError):

tests/test_transforms.py

Lines changed: 41 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pytest
55
import scipy as sp
66
import scipy.special
7-
from aesara.graph.basic import equal_computations
87
from aesara.graph.fg import FunctionGraph
98
from numdifftools import Jacobian
109

@@ -22,7 +21,6 @@
2221
TransformValuesMapping,
2322
TransformValuesRewrite,
2423
_default_transformed_rv,
25-
transformed_variable,
2624
)
2725
from tests.utils import assert_no_rvs
2826

@@ -176,15 +174,13 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
176174

177175
a = at_dist(*dist_params, size=size)
178176
a.name = "a"
179-
a_value_var = at.tensor(a.dtype, shape=(None,) * a.ndim)
180-
a_value_var.name = "a_value"
181177

182178
b = at.random.normal(a, 1.0)
183179
b.name = "b"
184180

185-
transform_rewrite = TransformValuesRewrite({a_value_var: DEFAULT_TRANSFORM})
186-
res, (b_value_var,) = joint_logprob(
187-
b, realized={a: a_value_var}, extra_rewrites=transform_rewrite
181+
transform_rewrite = TransformValuesRewrite({a: DEFAULT_TRANSFORM})
182+
res, (b_value_var, a_value_var) = joint_logprob(
183+
b, a, extra_rewrites=transform_rewrite
188184
)
189185

190186
test_val_rng = np.random.RandomState(3238)
@@ -268,12 +264,10 @@ def a_backward_fn_(x):
268264
@pytest.mark.parametrize("use_jacobian", [True, False])
269265
def test_simple_transformed_logprob_nojac(use_jacobian):
270266
X_rv = at.random.halfnormal(0, 3, name="X")
271-
x_vv = X_rv.clone()
272-
x_vv.name = "x"
273267

274-
transform_rewrite = TransformValuesRewrite({x_vv: DEFAULT_TRANSFORM})
275-
tr_logp, _ = joint_logprob(
276-
realized={X_rv: x_vv},
268+
transform_rewrite = TransformValuesRewrite({X_rv: DEFAULT_TRANSFORM})
269+
tr_logp, (x_vv,) = joint_logprob(
270+
X_rv,
277271
extra_rewrites=transform_rewrite,
278272
use_jacobian=use_jacobian,
279273
)
@@ -321,19 +315,17 @@ def test_hierarchical_uniform_transform():
321315
upper_rv = at.random.uniform(9, 10, name="upper")
322316
x_rv = at.random.uniform(lower_rv, upper_rv, name="x")
323317

324-
lower = lower_rv.clone()
325-
upper = upper_rv.clone()
326-
x = x_rv.clone()
327-
328318
transform_rewrite = TransformValuesRewrite(
329319
{
330-
lower: DEFAULT_TRANSFORM,
331-
upper: DEFAULT_TRANSFORM,
332-
x: DEFAULT_TRANSFORM,
320+
lower_rv: DEFAULT_TRANSFORM,
321+
upper_rv: DEFAULT_TRANSFORM,
322+
x_rv: DEFAULT_TRANSFORM,
333323
}
334324
)
335-
logp, _ = joint_logprob(
336-
realized={lower_rv: lower, upper_rv: upper, x_rv: x},
325+
logp, (lower, upper, x) = joint_logprob(
326+
lower_rv,
327+
upper_rv,
328+
x_rv,
337329
extra_rewrites=transform_rewrite,
338330
)
339331

@@ -346,20 +338,18 @@ def test_nondefault_transforms():
346338
scale_rv = at.random.uniform(-1, 1, name="scale")
347339
x_rv = at.random.normal(loc_rv, scale_rv, name="x")
348340

349-
loc = loc_rv.clone()
350-
scale = scale_rv.clone()
351-
x = x_rv.clone()
352-
353341
transform_rewrite = TransformValuesRewrite(
354342
{
355-
loc: None,
356-
scale: LogOddsTransform(),
357-
x: LogTransform(),
343+
loc_rv: None,
344+
scale_rv: LogOddsTransform(),
345+
x_rv: LogTransform(),
358346
}
359347
)
360348

361-
logp, _ = joint_logprob(
362-
realized={loc_rv: loc, scale_rv: scale, x_rv: x},
349+
logp, (loc, scale, x) = joint_logprob(
350+
loc_rv,
351+
scale_rv,
352+
x_rv,
363353
extra_rewrites=transform_rewrite,
364354
)
365355

@@ -391,12 +381,11 @@ def test_default_transform_multiout():
391381
# multiple outputs and no default output.
392382
sd = at.linalg.svd(at.eye(1))[1][0]
393383
x_rv = at.random.normal(0, sd, name="x")
394-
x = x_rv.clone()
395384

396-
transform_rewrite = TransformValuesRewrite({x: DEFAULT_TRANSFORM})
385+
transform_rewrite = TransformValuesRewrite({x_rv: DEFAULT_TRANSFORM})
397386

398-
logp, _ = joint_logprob(
399-
realized={x_rv: x},
387+
logp, (x,) = joint_logprob(
388+
x_rv,
400389
extra_rewrites=transform_rewrite,
401390
)
402391

@@ -412,12 +401,11 @@ def test_nonexistent_default_transform():
412401
transform does not fail
413402
"""
414403
x_rv = at.random.normal(name="x")
415-
x = x_rv.clone()
416404

417-
transform_rewrite = TransformValuesRewrite({x: DEFAULT_TRANSFORM})
405+
transform_rewrite = TransformValuesRewrite({x_rv: DEFAULT_TRANSFORM})
418406

419-
logp, _ = joint_logprob(
420-
realized={x_rv: x},
407+
logp, (x,) = joint_logprob(
408+
x_rv,
421409
extra_rewrites=transform_rewrite,
422410
)
423411

@@ -446,9 +434,8 @@ def test_original_values_output_dict():
446434
the logprob factor
447435
"""
448436
p_rv = at.random.beta(1, 1, name="p")
449-
p_vv = p_rv.clone()
450437

451-
tr = TransformValuesRewrite({p_vv: DEFAULT_TRANSFORM})
438+
tr = TransformValuesRewrite({p_rv: DEFAULT_TRANSFORM})
452439
logp_dict, _ = conditional_logprob(p_rv, extra_rewrites=tr)
453440

454441
assert p_rv in logp_dict
@@ -469,29 +456,28 @@ def test_mixture_transform():
469456
Y_rv = at.stack([Y_1_rv, Y_2_rv])[I_rv]
470457
Y_rv.name = "Y"
471458

472-
logp_no_trans, (y_vv, i_vv) = joint_logprob(Y_rv, I_rv)
459+
logp, (y_vv, i_vv) = joint_logprob(
460+
Y_rv,
461+
I_rv,
462+
)
473463

474-
transform_rewrite = TransformValuesRewrite({y_vv: LogTransform()})
464+
transform_rewrite = TransformValuesRewrite({Y_rv: LogOddsTransform()})
475465

476466
with pytest.warns(None) as record:
477467
# This shouldn't raise any warnings
478-
logp_trans, _ = joint_logprob(
479-
realized={Y_rv: y_vv, I_rv: i_vv},
468+
logp_trans, (y_vv_trans, i_vv_trans) = joint_logprob(
469+
Y_rv,
470+
I_rv,
480471
extra_rewrites=transform_rewrite,
481472
use_jacobian=False,
482473
)
483474

484475
assert not record.list
485476

486-
# The untransformed graph should be the same as the transformed graph after
487-
# replacing the `Y_rv` value variable with a transformed version of itself
488-
logp_nt_fg = FunctionGraph(outputs=[logp_no_trans], clone=False)
489-
y_trans = transformed_variable(at.exp(y_vv), y_vv)
490-
y_trans.name = "y_log"
491-
logp_nt_fg.replace(y_vv, y_trans)
492-
logp_nt = logp_nt_fg.outputs[0]
493-
494-
assert equal_computations([logp_nt], [logp_trans])
477+
logp_fn = aesara.function((i_vv, y_vv), logp)
478+
logp_trans_fn = aesara.function((i_vv_trans, y_vv_trans), logp_trans)
479+
np.isclose(logp_trans_fn(0, np.log(0.1 / 0.9)), logp_fn(0, 0.1))
480+
np.isclose(logp_trans_fn(1, np.log(0.1 / 0.9)), logp_fn(1, 0.1))
495481

496482

497483
def test_invalid_interval_transform():
@@ -642,11 +628,10 @@ def test_scale_transform_rv(rv_size, scale_type):
642628
def test_transformed_rv_and_value():
643629
y_rv = at.random.halfnormal(-1, 1, name="base_rv") + 1
644630
y_rv.name = "y"
645-
y_vv = y_rv.clone()
646631

647-
transform_rewrite = TransformValuesRewrite({y_vv: LogTransform()})
632+
transform_rewrite = TransformValuesRewrite({y_rv: LogTransform()})
648633

649-
logp, _ = joint_logprob(realized={y_rv: y_vv}, extra_rewrites=transform_rewrite)
634+
logp, (y_vv,) = joint_logprob(y_rv, extra_rewrites=transform_rewrite)
650635
assert_no_rvs(logp)
651636
logp_fn = aesara.function([y_vv], logp)
652637

0 commit comments

Comments
 (0)