4
4
import pytest
5
5
import scipy as sp
6
6
import scipy .special
7
- from aesara .graph .basic import equal_computations
8
7
from aesara .graph .fg import FunctionGraph
9
8
from numdifftools import Jacobian
10
9
22
21
TransformValuesMapping ,
23
22
TransformValuesRewrite ,
24
23
_default_transformed_rv ,
25
- transformed_variable ,
26
24
)
27
25
from tests .utils import assert_no_rvs
28
26
@@ -176,15 +174,13 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
176
174
177
175
a = at_dist (* dist_params , size = size )
178
176
a .name = "a"
179
- a_value_var = at .tensor (a .dtype , shape = (None ,) * a .ndim )
180
- a_value_var .name = "a_value"
181
177
182
178
b = at .random .normal (a , 1.0 )
183
179
b .name = "b"
184
180
185
- transform_rewrite = TransformValuesRewrite ({a_value_var : DEFAULT_TRANSFORM })
186
- res , (b_value_var ,) = joint_logprob (
187
- b , realized = { a : a_value_var } , extra_rewrites = transform_rewrite
181
+ transform_rewrite = TransformValuesRewrite ({a : DEFAULT_TRANSFORM })
182
+ res , (b_value_var , a_value_var ) = joint_logprob (
183
+ b , a , extra_rewrites = transform_rewrite
188
184
)
189
185
190
186
test_val_rng = np .random .RandomState (3238 )
@@ -268,12 +264,10 @@ def a_backward_fn_(x):
268
264
@pytest .mark .parametrize ("use_jacobian" , [True , False ])
269
265
def test_simple_transformed_logprob_nojac (use_jacobian ):
270
266
X_rv = at .random .halfnormal (0 , 3 , name = "X" )
271
- x_vv = X_rv .clone ()
272
- x_vv .name = "x"
273
267
274
- transform_rewrite = TransformValuesRewrite ({x_vv : DEFAULT_TRANSFORM })
275
- tr_logp , _ = joint_logprob (
276
- realized = { X_rv : x_vv } ,
268
+ transform_rewrite = TransformValuesRewrite ({X_rv : DEFAULT_TRANSFORM })
269
+ tr_logp , ( x_vv ,) = joint_logprob (
270
+ X_rv ,
277
271
extra_rewrites = transform_rewrite ,
278
272
use_jacobian = use_jacobian ,
279
273
)
@@ -321,19 +315,17 @@ def test_hierarchical_uniform_transform():
321
315
upper_rv = at .random .uniform (9 , 10 , name = "upper" )
322
316
x_rv = at .random .uniform (lower_rv , upper_rv , name = "x" )
323
317
324
- lower = lower_rv .clone ()
325
- upper = upper_rv .clone ()
326
- x = x_rv .clone ()
327
-
328
318
transform_rewrite = TransformValuesRewrite (
329
319
{
330
- lower : DEFAULT_TRANSFORM ,
331
- upper : DEFAULT_TRANSFORM ,
332
- x : DEFAULT_TRANSFORM ,
320
+ lower_rv : DEFAULT_TRANSFORM ,
321
+ upper_rv : DEFAULT_TRANSFORM ,
322
+ x_rv : DEFAULT_TRANSFORM ,
333
323
}
334
324
)
335
- logp , _ = joint_logprob (
336
- realized = {lower_rv : lower , upper_rv : upper , x_rv : x },
325
+ logp , (lower , upper , x ) = joint_logprob (
326
+ lower_rv ,
327
+ upper_rv ,
328
+ x_rv ,
337
329
extra_rewrites = transform_rewrite ,
338
330
)
339
331
@@ -346,20 +338,18 @@ def test_nondefault_transforms():
346
338
scale_rv = at .random .uniform (- 1 , 1 , name = "scale" )
347
339
x_rv = at .random .normal (loc_rv , scale_rv , name = "x" )
348
340
349
- loc = loc_rv .clone ()
350
- scale = scale_rv .clone ()
351
- x = x_rv .clone ()
352
-
353
341
transform_rewrite = TransformValuesRewrite (
354
342
{
355
- loc : None ,
356
- scale : LogOddsTransform (),
357
- x : LogTransform (),
343
+ loc_rv : None ,
344
+ scale_rv : LogOddsTransform (),
345
+ x_rv : LogTransform (),
358
346
}
359
347
)
360
348
361
- logp , _ = joint_logprob (
362
- realized = {loc_rv : loc , scale_rv : scale , x_rv : x },
349
+ logp , (loc , scale , x ) = joint_logprob (
350
+ loc_rv ,
351
+ scale_rv ,
352
+ x_rv ,
363
353
extra_rewrites = transform_rewrite ,
364
354
)
365
355
@@ -391,12 +381,11 @@ def test_default_transform_multiout():
391
381
# multiple outputs and no default output.
392
382
sd = at .linalg .svd (at .eye (1 ))[1 ][0 ]
393
383
x_rv = at .random .normal (0 , sd , name = "x" )
394
- x = x_rv .clone ()
395
384
396
- transform_rewrite = TransformValuesRewrite ({x : DEFAULT_TRANSFORM })
385
+ transform_rewrite = TransformValuesRewrite ({x_rv : DEFAULT_TRANSFORM })
397
386
398
- logp , _ = joint_logprob (
399
- realized = { x_rv : x } ,
387
+ logp , ( x ,) = joint_logprob (
388
+ x_rv ,
400
389
extra_rewrites = transform_rewrite ,
401
390
)
402
391
@@ -412,12 +401,11 @@ def test_nonexistent_default_transform():
412
401
transform does not fail
413
402
"""
414
403
x_rv = at .random .normal (name = "x" )
415
- x = x_rv .clone ()
416
404
417
- transform_rewrite = TransformValuesRewrite ({x : DEFAULT_TRANSFORM })
405
+ transform_rewrite = TransformValuesRewrite ({x_rv : DEFAULT_TRANSFORM })
418
406
419
- logp , _ = joint_logprob (
420
- realized = { x_rv : x } ,
407
+ logp , ( x ,) = joint_logprob (
408
+ x_rv ,
421
409
extra_rewrites = transform_rewrite ,
422
410
)
423
411
@@ -446,9 +434,8 @@ def test_original_values_output_dict():
446
434
the logprob factor
447
435
"""
448
436
p_rv = at .random .beta (1 , 1 , name = "p" )
449
- p_vv = p_rv .clone ()
450
437
451
- tr = TransformValuesRewrite ({p_vv : DEFAULT_TRANSFORM })
438
+ tr = TransformValuesRewrite ({p_rv : DEFAULT_TRANSFORM })
452
439
logp_dict , _ = conditional_logprob (p_rv , extra_rewrites = tr )
453
440
454
441
assert p_rv in logp_dict
@@ -469,29 +456,28 @@ def test_mixture_transform():
469
456
Y_rv = at .stack ([Y_1_rv , Y_2_rv ])[I_rv ]
470
457
Y_rv .name = "Y"
471
458
472
- logp_no_trans , (y_vv , i_vv ) = joint_logprob (Y_rv , I_rv )
459
+ logp , (y_vv , i_vv ) = joint_logprob (
460
+ Y_rv ,
461
+ I_rv ,
462
+ )
473
463
474
- transform_rewrite = TransformValuesRewrite ({y_vv : LogTransform ()})
464
+ transform_rewrite = TransformValuesRewrite ({Y_rv : LogOddsTransform ()})
475
465
476
466
with pytest .warns (None ) as record :
477
467
# This shouldn't raise any warnings
478
- logp_trans , _ = joint_logprob (
479
- realized = {Y_rv : y_vv , I_rv : i_vv },
468
+ logp_trans , (y_vv_trans , i_vv_trans ) = joint_logprob (
469
+ Y_rv ,
470
+ I_rv ,
480
471
extra_rewrites = transform_rewrite ,
481
472
use_jacobian = False ,
482
473
)
483
474
484
475
assert not record .list
485
476
486
- # The untransformed graph should be the same as the transformed graph after
487
- # replacing the `Y_rv` value variable with a transformed version of itself
488
- logp_nt_fg = FunctionGraph (outputs = [logp_no_trans ], clone = False )
489
- y_trans = transformed_variable (at .exp (y_vv ), y_vv )
490
- y_trans .name = "y_log"
491
- logp_nt_fg .replace (y_vv , y_trans )
492
- logp_nt = logp_nt_fg .outputs [0 ]
493
-
494
- assert equal_computations ([logp_nt ], [logp_trans ])
477
+ logp_fn = aesara .function ((i_vv , y_vv ), logp )
478
+ logp_trans_fn = aesara .function ((i_vv_trans , y_vv_trans ), logp_trans )
479
+ np .isclose (logp_trans_fn (0 , np .log (0.1 / 0.9 )), logp_fn (0 , 0.1 ))
480
+ np .isclose (logp_trans_fn (1 , np .log (0.1 / 0.9 )), logp_fn (1 , 0.1 ))
495
481
496
482
497
483
def test_invalid_interval_transform ():
@@ -642,11 +628,10 @@ def test_scale_transform_rv(rv_size, scale_type):
642
628
def test_transformed_rv_and_value ():
643
629
y_rv = at .random .halfnormal (- 1 , 1 , name = "base_rv" ) + 1
644
630
y_rv .name = "y"
645
- y_vv = y_rv .clone ()
646
631
647
- transform_rewrite = TransformValuesRewrite ({y_vv : LogTransform ()})
632
+ transform_rewrite = TransformValuesRewrite ({y_rv : LogTransform ()})
648
633
649
- logp , _ = joint_logprob (realized = { y_rv : y_vv } , extra_rewrites = transform_rewrite )
634
+ logp , ( y_vv ,) = joint_logprob (y_rv , extra_rewrites = transform_rewrite )
650
635
assert_no_rvs (logp )
651
636
logp_fn = aesara .function ([y_vv ], logp )
652
637
0 commit comments