Skip to content

Commit 2ac88af

Browse files
committed
Rename basic "joint_logprob" functions to "conditional_logp"
1 parent a32c5e7 commit 2ac88af

19 files changed

+159
-137
lines changed

pymc/logprob/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@
3434
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3535
# SOFTWARE.
3636

37-
from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp
37+
from pymc.logprob.basic import (
38+
conditional_logp,
39+
icdf,
40+
logcdf,
41+
logp,
42+
transformed_conditional_logp,
43+
)
3844

3945
# isort: off
4046
# Add rewrites to the DBs

pymc/logprob/basic.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,15 @@ def icdf(
163163
)
164164

165165

166-
def factorized_joint_logprob(
166+
def conditional_logp(
167167
rv_values: Dict[TensorVariable, TensorVariable],
168168
warn_missing_rvs: bool = True,
169169
ir_rewriter: Optional[GraphRewriter] = None,
170170
extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None,
171171
**kwargs,
172172
) -> Dict[TensorVariable, TensorVariable]:
173-
r"""Create a map between variables and their log-probabilities such that the
174-
sum is their joint log-probability.
173+
r"""Create a map between variables and conditional log-probabilities
174+
such that the sum is their joint log-probability.
175175
176176
The `rv_values` dictionary specifies a joint probability graph defined by
177177
pairs of random variables and respective measure-space input parameters
@@ -189,20 +189,21 @@ def factorized_joint_logprob(
189189
190190
.. math::
191191
192-
\sigma^2 \sim& \operatorname{InvGamma}(0.5, 0.5) \\
193-
Y \sim& \operatorname{N}(0, \sigma^2)
192+
\Sigma^2 \sim& \operatorname{InvGamma}(0.5, 0.5) \\
193+
Y \sim& \operatorname{N}(0, \Sigma)
194194
195195
If we create a value variable for ``Y_rv``, i.e. ``y_vv = pt.scalar("y")``,
196-
the graph of ``factorized_joint_logprob({Y_rv: y_vv})`` is equivalent to the
197-
conditional probability :math:`\log p(Y = y \mid \sigma^2)`, with a stochastic
196+
the graph of ``conditional_logp({Y_rv: y_vv})`` is equivalent to the
197+
conditional log-probability :math:`\log p(Y = y \mid \Sigma^2)`, with a stochastic
198198
``sigma2_rv``. If we specify a value variable for ``sigma2_rv``, i.e.
199-
``s_vv = pt.scalar("s2")``, then ``factorized_joint_logprob({Y_rv: y_vv, sigma2_rv: s_vv})``
200-
yields the joint log-probability of the two variables.
199+
``s_vv = pt.scalar("s2")``, then ``conditional_logp({Y_rv: y_vv, sigma2_rv: s_vv})``
200+
yields the conditional log-probabilities of the two variables.
201+
The sum of the two terms gives their joint log-probability.
201202
202203
.. math::
203204
204-
\log p(Y = y, \sigma^2 = s) =
205-
\log p(Y = y \mid \sigma^2 = s) + \log p(\sigma^2 = s)
205+
\log p(Y = y, \Sigma^2 = \sigma^2) =
206+
\log p(Y = y \mid \Sigma^2 = \sigma^2) + \log p(\Sigma^2 = \sigma^2)
206207
207208
208209
Parameters
@@ -223,7 +224,7 @@ def factorized_joint_logprob(
223224
224225
Returns
225226
-------
226-
A ``dict`` that maps each value variable to the log-probability factor derived
227+
A ``dict`` that maps each value variable to the conditional log-probability term derived
227228
from the respective `RandomVariable`.
228229
229230
"""
@@ -309,7 +310,7 @@ def factorized_joint_logprob(
309310

310311
if q_value_var in logprob_vars:
311312
raise ValueError(
312-
f"More than one logprob factor was assigned to the value var {q_value_var}"
313+
f"More than one logprob term was assigned to the value var {q_value_var}"
313314
)
314315

315316
logprob_vars[q_value_var] = q_logprob_var
@@ -337,16 +338,19 @@ def factorized_joint_logprob(
337338
return logprob_vars
338339

339340

340-
def joint_logp(
341+
def transformed_conditional_logp(
341342
rvs: Sequence[TensorVariable],
342343
*,
343344
rvs_to_values: Dict[TensorVariable, TensorVariable],
344345
rvs_to_transforms: Dict[TensorVariable, RVTransform],
345346
jacobian: bool = True,
346347
**kwargs,
347348
) -> List[TensorVariable]:
348-
"""Thin wrapper around pymc.logprob.factorized_joint_logprob, extended with Model
349-
specific concerns such as transforms, jacobian, and scaling"""
349+
"""Thin wrapper around conditional_logprob, which creates a value transform rewrite.
350+
351+
This helper will only return the subset of logprob terms corresponding to `rvs`.
352+
All rvs_to_values and rvs_to_transforms mappings are required.
353+
"""
350354

351355
transform_rewrite = None
352356
values_to_transforms = {
@@ -359,7 +363,7 @@ def joint_logp(
359363
transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore
360364

361365
kwargs.setdefault("warn_missing_rvs", False)
362-
temp_logp_terms = factorized_joint_logprob(
366+
temp_logp_terms = conditional_logp(
363367
rvs_to_values,
364368
extra_rewrites=transform_rewrite,
365369
use_jacobian=jacobian,
@@ -381,3 +385,21 @@ def joint_logp(
381385
raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
382386

383387
return logp_terms_list
388+
389+
390+
def factorized_joint_logprob(*args, **kwargs):
391+
warnings.warn(
392+
"`factorized_joint_logprob` was renamed to `conditional_logp`. "
393+
"The function will be removed in a future release",
394+
FutureWarning,
395+
)
396+
return conditional_logp(*args, **kwargs)
397+
398+
399+
def joint_logp(*args, **kwargs):
400+
warnings.warn(
401+
"`joint_logp` was renamed to `transformed_conditional_logp`. "
402+
"The function will be removed in a future release",
403+
FutureWarning,
404+
)
405+
return transformed_conditional_logp(*args, **kwargs)

pymc/logprob/scan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from pytensor.updates import OrderedUpdates
5555

5656
from pymc.logprob.abstract import MeasurableVariable, _logprob
57-
from pymc.logprob.basic import factorized_joint_logprob
57+
from pymc.logprob.basic import conditional_logp
5858
from pymc.logprob.rewriting import (
5959
PreserveRVMappings,
6060
construct_ir_fgraph,
@@ -310,7 +310,7 @@ def logprob_ScanRV(op, values, *inputs, name=None, **kwargs):
310310

311311
def create_inner_out_logp(value_map: Dict[TensorVariable, TensorVariable]) -> TensorVariable:
312312
"""Create a log-likelihood inner-output for a `Scan`."""
313-
logp_parts = factorized_joint_logprob(value_map, warn_rvs=False)
313+
logp_parts = conditional_logp(value_map, warn_rvs=False)
314314
return logp_parts.values()
315315

316316
logp_scan_args = convert_outer_out_to_in(

pymc/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
ShapeWarning,
6565
)
6666
from pymc.initial_point import make_initial_point_fn
67-
from pymc.logprob.basic import joint_logp
67+
from pymc.logprob.basic import transformed_conditional_logp
6868
from pymc.logprob.utils import ParameterValueError
6969
from pymc.pytensorf import (
7070
PointFunc,
@@ -761,7 +761,7 @@ def logp(
761761

762762
rv_logps: List[TensorVariable] = []
763763
if rvs:
764-
rv_logps = joint_logp(
764+
rv_logps = transformed_conditional_logp(
765765
rvs=rvs,
766766
rvs_to_values=self.rvs_to_values,
767767
rvs_to_transforms=self.rvs_to_transforms,

pymc/testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from pymc.distributions.distribution import Distribution
3636
from pymc.distributions.shape_utils import change_dist_size
3737
from pymc.initial_point import make_initial_point_fn
38-
from pymc.logprob.basic import icdf, joint_logp, logcdf, logp
38+
from pymc.logprob.basic import icdf, logcdf, logp, transformed_conditional_logp
3939
from pymc.logprob.utils import ParameterValueError, find_rvs_in_graph
4040
from pymc.pytensorf import (
4141
compile_pymc,
@@ -673,7 +673,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
673673

674674
if check_finite_logp:
675675
logp_moment = (
676-
joint_logp(
676+
transformed_conditional_logp(
677677
(model["x"],),
678678
rvs_to_values={model["x"]: pt.constant(moment)},
679679
rvs_to_transforms={},

tests/distributions/test_transform.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import pymc as pm
2626
import pymc.distributions.transforms as tr
2727

28-
from pymc.logprob.basic import joint_logp
28+
from pymc.logprob.basic import transformed_conditional_logp
2929
from pymc.pytensorf import floatX, jacobian
3030
from pymc.testing import (
3131
Circ,
@@ -308,7 +308,7 @@ def check_transform_elementwise_logp(self, model):
308308
assert model.logp(x, sum=False)[0].ndim == x.ndim == jacob_det.ndim
309309

310310
v1 = (
311-
joint_logp(
311+
transformed_conditional_logp(
312312
(x,),
313313
rvs_to_values={x: x_val_transf},
314314
rvs_to_transforms={x: transform},
@@ -318,7 +318,7 @@ def check_transform_elementwise_logp(self, model):
318318
.eval({x_val_transf: test_array_transf})
319319
)
320320
v2 = (
321-
joint_logp(
321+
transformed_conditional_logp(
322322
(x,),
323323
rvs_to_values={x: x_val_untransf},
324324
rvs_to_transforms={},
@@ -356,7 +356,7 @@ def check_vectortransform_elementwise_logp(self, model):
356356
assert model.logp(x, sum=False)[0].ndim == (x.ndim - 1) == jacob_det.ndim
357357

358358
a = (
359-
joint_logp(
359+
transformed_conditional_logp(
360360
(x,),
361361
rvs_to_values={x: x_val_transf},
362362
rvs_to_transforms={x: transform},
@@ -366,7 +366,7 @@ def check_vectortransform_elementwise_logp(self, model):
366366
.eval({x_val_transf: test_array_transf})
367367
)
368368
b = (
369-
joint_logp(
369+
transformed_conditional_logp(
370370
(x,),
371371
rvs_to_values={x: x_val_untransf},
372372
rvs_to_transforms={},

tests/logprob/test_basic.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@
5555

5656
import pymc as pm
5757

58-
from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp
58+
from pymc.logprob.basic import (
59+
conditional_logp,
60+
icdf,
61+
logcdf,
62+
logp,
63+
transformed_conditional_logp,
64+
)
5965
from pymc.logprob.transforms import LogTransform
6066
from pymc.logprob.utils import rvs_to_value_vars, walk_model
6167
from pymc.pytensorf import replace_rvs_by_values
@@ -68,7 +74,7 @@ def test_factorized_joint_logprob_basic():
6874
a.name = "a"
6975
a_value_var = a.clone()
7076

71-
a_logp = factorized_joint_logprob({a: a_value_var})
77+
a_logp = conditional_logp({a: a_value_var})
7278
a_logp_comb = tuple(a_logp.values())[0]
7379
a_logp_exp = logp(a, a_value_var)
7480

@@ -81,7 +87,7 @@ def test_factorized_joint_logprob_basic():
8187
sigma_value_var = sigma.clone()
8288
y_value_var = Y.clone()
8389

84-
total_ll = factorized_joint_logprob({Y: y_value_var, sigma: sigma_value_var})
90+
total_ll = conditional_logp({Y: y_value_var, sigma: sigma_value_var})
8591
total_ll_combined = pt.add(*total_ll.values())
8692

8793
# We need to replace the reference to `sigma` in `Y` with its value
@@ -106,7 +112,7 @@ def test_factorized_joint_logprob_basic():
106112
b_value_var = b.clone()
107113
c_value_var = c.clone()
108114

109-
b_logp = factorized_joint_logprob({a: a_value_var, b: b_value_var, c: c_value_var})
115+
b_logp = conditional_logp({a: a_value_var, b: b_value_var, c: c_value_var})
110116
b_logp_combined = pt.sum([pt.sum(factor) for factor in b_logp.values()])
111117

112118
# There shouldn't be any `RandomVariable`s in the resulting graph
@@ -125,7 +131,7 @@ def test_factorized_joint_logprob_multi_obs():
125131
a_val = a.clone()
126132
b_val = b.clone()
127133

128-
logp_res = factorized_joint_logprob({a: a_val, b: b_val})
134+
logp_res = conditional_logp({a: a_val, b: b_val})
129135
logp_res_combined = pt.add(*logp_res.values())
130136
logp_exp = logp(a, a_val) + logp(b, b_val)
131137

@@ -137,8 +143,8 @@ def test_factorized_joint_logprob_multi_obs():
137143
x_val = x.clone()
138144
y_val = y.clone()
139145

140-
logp_res = factorized_joint_logprob({x: x_val, y: y_val})
141-
exp_logp = factorized_joint_logprob({x: x_val, y: y_val})
146+
logp_res = conditional_logp({x: x_val, y: y_val})
147+
exp_logp = conditional_logp({x: x_val, y: y_val})
142148
logp_res_comb = pt.sum([pt.sum(factor) for factor in logp_res.values()])
143149
exp_logp_comb = pt.sum([pt.sum(factor) for factor in exp_logp.values()])
144150

@@ -155,7 +161,7 @@ def test_factorized_joint_logprob_diff_dims():
155161
y_vv = y.clone()
156162
y_vv.name = "y"
157163

158-
logp = factorized_joint_logprob({x: x_vv, y: y_vv})
164+
logp = conditional_logp({x: x_vv, y: y_vv})
159165
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
160166

161167
M_val = np.random.normal(size=(10, 3))
@@ -181,7 +187,7 @@ def test_incsubtensor_original_values_output_dict():
181187
rv = pt.set_subtensor(base_rv[0], 5)
182188
vv = rv.clone()
183189

184-
logp_dict = factorized_joint_logprob({rv: vv})
190+
logp_dict = conditional_logp({rv: vv})
185191
assert vv in logp_dict
186192

187193

@@ -194,14 +200,14 @@ def test_persist_inputs():
194200
beta_vv = beta_rv.type()
195201
y_vv = Y_rv.clone()
196202

197-
logp = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv})
203+
logp = conditional_logp({beta_rv: beta_vv, Y_rv: y_vv})
198204
logp_combined = pt.sum([pt.sum(factor) for factor in logp.values()])
199205

200206
assert x in ancestors([logp_combined])
201207

202208
# Make sure we don't clone value variables when they're graphs.
203209
y_vv_2 = y_vv * 2
204-
logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
210+
logp_2 = conditional_logp({beta_rv: beta_vv, Y_rv: y_vv_2})
205211
logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()])
206212

207213
assert y_vv in ancestors([logp_2_combined])
@@ -210,7 +216,7 @@ def test_persist_inputs():
210216
# Even when they are random
211217
y_vv = pt.random.normal(name="y_vv2")
212218
y_vv_2 = y_vv * 2
213-
logp_2 = factorized_joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})
219+
logp_2 = conditional_logp({beta_rv: beta_vv, Y_rv: y_vv_2})
214220
logp_2_combined = pt.sum([pt.sum(factor) for factor in logp_2.values()])
215221

216222
assert y_vv in ancestors([logp_2_combined])
@@ -224,11 +230,11 @@ def test_warn_random_found_factorized_joint_logprob():
224230
y_vv = y_rv.clone()
225231

226232
with pytest.warns(UserWarning, match="Random variables detected in the logp graph: {x}"):
227-
factorized_joint_logprob({y_rv: y_vv})
233+
conditional_logp({y_rv: y_vv})
228234

229235
with warnings.catch_warnings():
230236
warnings.simplefilter("error")
231-
factorized_joint_logprob({y_rv: y_vv}, warn_missing_rvs=False)
237+
conditional_logp({y_rv: y_vv}, warn_missing_rvs=False)
232238

233239

234240
def test_multiple_rvs_to_same_value_raises():
@@ -237,9 +243,9 @@ def test_multiple_rvs_to_same_value_raises():
237243
x = x_rv1.type()
238244
x.name = "x"
239245

240-
msg = "More than one logprob factor was assigned to the value var x"
246+
msg = "More than one logprob term was assigned to the value var x"
241247
with pytest.raises(ValueError, match=msg):
242-
factorized_joint_logprob({x_rv1: x, x_rv2: x})
248+
conditional_logp({x_rv1: x, x_rv2: x})
243249

244250

245251
def test_joint_logp_basic():
@@ -259,7 +265,7 @@ def test_joint_logp_basic():
259265

260266
c_value_var = m.rvs_to_values[c]
261267

262-
(b_logp,) = joint_logp(
268+
(b_logp,) = transformed_conditional_logp(
263269
(b,),
264270
rvs_to_values=m.rvs_to_values,
265271
rvs_to_transforms=m.rvs_to_transforms,
@@ -304,7 +310,7 @@ def test_joint_logp_incsubtensor(indices, size):
304310
a_idx_value_var = a_idx.type()
305311
a_idx_value_var.name = "a_idx_value"
306312

307-
a_idx_logp = joint_logp(
313+
a_idx_logp = transformed_conditional_logp(
308314
(a_idx,),
309315
rvs_to_values={a_idx: a_value_var},
310316
rvs_to_transforms={},

0 commit comments

Comments
 (0)