Skip to content

Commit 999661c

Browse files
ferrineJunpeng Lao
authored and
Junpeng Lao
committed
fix PyMC3 variable is not replaced if provided in more_replacements (VI) (#2891)
* fixes #2890 * float32 y * update release notes * use floatX
1 parent b385791 commit 999661c

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
- `VonMises` does not overflow for large values of kappa. i0 and i1 have been removed and we now use log_i0 to compute the logp.
2525
- The bandwidth for KDE plots is computed using a modified version of Scott's rule. The new version uses entropy instead of standard deviation. This works better for multimodal distributions. Functions using KDE plots has a new argument `bw` controlling the bandwidth.
26+
- fix PyMC3 variable is not replaced if provided in more_replacements (#2890)
2627

2728
### Deprecations
2829

pymc3/tests/test_variational_inference.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,20 @@ def test_sample_replacements(binomial_model_inference):
835835
assert sampled.shape[0] == 101
836836

837837

838+
def test_var_replacement():
839+
X_mean = pm.floatX(np.linspace(0, 10, 10))
840+
y = pm.floatX(np.random.normal(X_mean*4, .05))
841+
with pm.Model():
842+
inp = pm.Normal('X', X_mean, shape=X_mean.shape)
843+
coef = pm.Normal('b', 4.)
844+
mean = inp * coef
845+
pm.Normal('y', mean, .1, observed=y)
846+
advi = pm.fit(100)
847+
assert advi.sample_node(mean).eval().shape == (10, )
848+
x_new = pm.floatX(np.linspace(0, 10, 11))
849+
assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11, )
850+
851+
838852
def test_empirical_from_trace(another_simple_model):
839853
with another_simple_model:
840854
step = pm.Metropolis()

pymc3/variational/opvi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,6 +1459,7 @@ def sample_node(self, node, size=None,
14591459
sampled node(s) with replacements
14601460
"""
14611461
node_in = node
1462+
node = theano.clone(node, more_replacements)
14621463
if size is None:
14631464
node_out = self.symbolic_single_sample(node)
14641465
else:

0 commit comments

Comments
 (0)