Skip to content

Commit a395e97

Browse files
Junpeng Laoaloctavodia
Junpeng Lao
authored andcommitted
sample_ppc bug fix (#2748)
* sample_ppc bug fix My fix in #2725 to use all chains breaks the `sample_ppc([point]...)` and also the progress bar. These issue should be fix here now and also add test for sample_ppc from a list. * fixe sample_ppc_w, edited release note * Add docstring, remove space
1 parent dee6575 commit a395e97

File tree

3 files changed

+38
-16
lines changed

3 files changed

+38
-16
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- Update loo, new improved algorithm (#2730)
1313
- New CSG (Constant Stochastic Gradient) approximate posterior sampling
1414
algorithm (#2544)
15+
- Michael Osthege added support for population-samplers and implemented differential evolution metropolis (`DEMetropolis`). For models with correlated dimensions that can not use gradient-based samplers, the `DEMetropolis` sampler can give higher effective sampling rates. (also see [PR#2735](https://github.com/pymc-devs/pymc3/pull/2735))
1516

1617
### Fixes
1718

@@ -20,12 +21,9 @@
2021
- `sample_ppc_w` now broadcasts
2122
- `df_summary` function renamed to `summary`
2223
- Add test for `model.logp_array` and `model.bijection` (#2724)
23-
- Fixed `sample_ppc` and `sample_ppc_w` to iterate all chains(#2633)
24+
- Fixed `sample_ppc` and `sample_ppc_w` to iterate all chains(#2633, #2748)
2425
- Add Bayesian R2 score (for GLMs) `stats.r2_score` (#2696) and test (#2729).
2526

26-
### New Features
27-
- Michael Osthege added support for population-samplers and implemented differential evolution metropolis (`DEMetropolis`). For models with correlated dimensions that can not use gradient-based samplers, the `DEMetropolis` sampler can give higher effective sampling rates. (also see [PR#2735](https://github.com/pymc-devs/pymc3/pull/2735))
28-
2927

3028
## PyMC3 3.2 (October 10, 2017)
3129

pymc3/sampling.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,8 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
944944
Parameters
945945
----------
946946
trace : backend, list, or MultiTrace
947-
Trace generated from MCMC sampling.
947+
Trace generated from MCMC sampling. Or a list containing dicts from
948+
find_MAP() or points
948949
samples : int
949950
Number of posterior predictive samples to generate. Defaults to the
950951
length of `trace`
@@ -971,7 +972,10 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
971972
posterior predictive samples.
972973
"""
973974
len_trace = len(trace)
974-
nchain = trace.nchains
975+
try:
976+
nchain = trace.nchains
977+
except AttributeError:
978+
nchain = 1
975979

976980
if samples is None:
977981
samples = len(trace)
@@ -984,14 +988,19 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
984988
np.random.seed(random_seed)
985989

986990
indices = np.random.randint(0, nchain*len_trace, samples)
987-
chain_idx, point_idx = np.divmod(indices, len_trace)
991+
988992
if progressbar:
989993
indices = tqdm(indices, total=samples)
990994

991995
try:
992996
ppc = defaultdict(list)
993-
for idx in zip(chain_idx, point_idx):
994-
param = trace._straces[idx[0]].point(idx[1])
997+
for idx in indices:
998+
if nchain > 1:
999+
chain_idx, point_idx = np.divmod(idx, len_trace)
1000+
param = trace._straces[chain_idx].point(point_idx)
1001+
else:
1002+
param = trace[idx]
1003+
9951004
for var in vars:
9961005
ppc[var.name].append(var.distribution.random(point=param,
9971006
size=size))
@@ -1013,8 +1022,9 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
10131022
10141023
Parameters
10151024
----------
1016-
traces : list
1017-
List of traces generated from MCMC sampling. The number of traces should
1025+
traces : list or list of lists
1026+
List of traces generated from MCMC sampling, or a list of list
1027+
containing dicts from find_MAP() or points. The number of traces should
10181028
be equal to the number of weights.
10191029
samples : int
10201030
Number of posterior predictive samples to generate. Defaults to the
@@ -1073,12 +1083,20 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
10731083
trace = []
10741084
for i, j in enumerate(n):
10751085
tr = traces[i]
1076-
len_trace = len(tr)
1077-
nchain = tr.nchains
1086+
len_trace = len(tr)
1087+
try:
1088+
nchain = tr.nchains
1089+
except AttributeError:
1090+
nchain = 1
1091+
10781092
indices = np.random.randint(0, nchain*len_trace, j)
1079-
chain_idx, point_idx = np.divmod(indices, len_trace)
1080-
for idx in zip(chain_idx, point_idx):
1081-
trace.append(tr._straces[idx[0]].point(idx[1]))
1093+
if nchain > 1:
1094+
chain_idx, point_idx = np.divmod(indices, len_trace)
1095+
for idx in zip(chain_idx, point_idx):
1096+
trace.append(tr._straces[idx[0]].point(idx[1]))
1097+
else:
1098+
for idx in indices:
1099+
trace.append(tr[idx])
10821100

10831101
obs = [x for m in models for x in m.observed_RVs]
10841102
variables = np.repeat(obs, n)

pymc3/tests/test_sampling.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ def test_normal_scalar(self):
210210
trace = pm.sample()
211211

212212
with model:
213+
# test list input
214+
ppc0 = pm.sample_ppc([model.test_point], samples=10)
213215
ppc = pm.sample_ppc(trace, samples=1000, vars=[])
214216
assert len(ppc) == 0
215217
ppc = pm.sample_ppc(trace, samples=1000, vars=[a])
@@ -228,6 +230,8 @@ def test_normal_vector(self):
228230
trace = pm.sample()
229231

230232
with model:
233+
# test list input
234+
ppc0 = pm.sample_ppc([model.test_point], samples=10)
231235
ppc = pm.sample_ppc(trace, samples=10, vars=[])
232236
assert len(ppc) == 0
233237
ppc = pm.sample_ppc(trace, samples=10, vars=[a])
@@ -245,6 +249,8 @@ def test_sum_normal(self):
245249
trace = pm.sample()
246250

247251
with model:
252+
# test list input
253+
ppc0 = pm.sample_ppc([model.test_point], samples=10)
248254
ppc = pm.sample_ppc(trace, samples=1000, vars=[b])
249255
assert len(ppc) == 1
250256
assert ppc['b'].shape == (1000,)

0 commit comments

Comments
 (0)