Skip to content

Commit 66d21e1

Browse files
author
Junpeng Lao
authored
Fix sample_ppc to use all chains (pymc-devs#2725)
* Fix sample_ppc to use all chains close pymc-devs#2633 Index to samples in all chains for sample_ppc and sample_ppc_w * add release note
1 parent a1976a8 commit 66d21e1

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

RELEASE-NOTES.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
# Release Notes
22

3-
## PyMC 3.3. (Unreleased)
43

5-
### New features
6-
7-
- Improve NUTS initialization `advi+adapt_diag_grad` and add `jitter+adapt_diag_grad` (#2643)
4+
## PyMC 3.3. (Unreleased)
5+
6+
### New features
7+
8+
- Improve NUTS initialization `advi+adapt_diag_grad` and add `jitter+adapt_diag_grad` (#2643)
9+
10+
### Fixes
11+
- Fixed `compareplot` to use `loo` output.
12+
- Add test for `model.logp_array` and `model.bijection` (#2724)
13+
- Fixed `sample_ppc` and `sample_ppc_w` to iterate all chains(#2633)
814

9-
### Fixes
10-
- Fixed `compareplot` to use `loo` output.
11-
- Add test for `model.logp_array` and `model.bijection` (#2724)
1215

1316

1417
## PyMC3 3.2 (October 10, 2017)

pymc3/sampling.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,9 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
661661
Dictionary with the variables as keys. The values corresponding to the
662662
posterior predictive samples.
663663
"""
664+
len_trace = len(trace)
665+
nchain = trace.nchains
666+
664667
if samples is None:
665668
samples = len(trace)
666669

@@ -671,14 +674,15 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
671674

672675
np.random.seed(random_seed)
673676

674-
indices = np.random.randint(0, len(trace), samples)
677+
indices = np.random.randint(0, nchain*len_trace, samples)
678+
chain_idx, point_idx = np.divmod(indices, len_trace)
675679
if progressbar:
676680
indices = tqdm(indices, total=samples)
677681

678682
try:
679683
ppc = defaultdict(list)
680-
for idx in indices:
681-
param = trace[idx]
684+
for idx in zip(chain_idx, point_idx):
685+
param = trace._straces[idx[0]].point(idx[1])
682686
for var in vars:
683687
ppc[var.name].append(var.distribution.random(point=param,
684688
size=size))
@@ -751,14 +755,21 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
751755
weights = np.asarray(weights)
752756
p = weights / np.sum(weights)
753757

754-
min_tr = min([len(i) for i in traces])
758+
min_tr = min([len(i)*i.nchains for i in traces])
755759

756760
n = (min_tr * p).astype('int')
757761
# ensure n sum up to min_tr
758762
idx = np.argmax(n)
759763
n[idx] = n[idx] + min_tr - np.sum(n)
760-
trace = np.concatenate([np.random.choice(traces[i], j)
761-
for i, j in enumerate(n)])
764+
trace = []
765+
for i, j in enumerate(n):
766+
tr = traces[i]
767+
len_trace = len(tr)
768+
nchain = tr.nchains
769+
indices = np.random.randint(0, nchain*len_trace, j)
770+
chain_idx, point_idx = np.divmod(indices, len_trace)
771+
for idx in zip(chain_idx, point_idx):
772+
trace.append(tr._straces[idx[0]].point(idx[1]))
762773

763774
obs = [x for m in models for x in m.observed_RVs]
764775
variables = np.repeat(obs, n)

0 commit comments

Comments
 (0)