@@ -661,6 +661,9 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
661
661
Dictionary with the variables as keys. The values corresponding to the
662
662
posterior predictive samples.
663
663
"""
664
+ len_trace = len (trace )
665
+ nchain = trace .nchains
666
+
664
667
if samples is None :
665
668
samples = len (trace )
666
669
@@ -671,14 +674,15 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
671
674
672
675
np .random .seed (random_seed )
673
676
674
- indices = np .random .randint (0 , len (trace ), samples )
677
+ indices = np .random .randint (0 , nchain * len_trace , samples )
678
+ chain_idx , point_idx = np .divmod (indices , len_trace )
675
679
if progressbar :
676
680
indices = tqdm (indices , total = samples )
677
681
678
682
try :
679
683
ppc = defaultdict (list )
680
- for idx in indices :
681
- param = trace [idx ]
684
+ for idx in zip ( chain_idx , point_idx ) :
685
+ param = trace . _straces [idx [ 0 ]]. point ( idx [ 1 ])
682
686
for var in vars :
683
687
ppc [var .name ].append (var .distribution .random (point = param ,
684
688
size = size ))
@@ -751,14 +755,21 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
751
755
weights = np .asarray (weights )
752
756
p = weights / np .sum (weights )
753
757
754
- min_tr = min ([len (i ) for i in traces ])
758
+ min_tr = min ([len (i )* i . nchains for i in traces ])
755
759
756
760
n = (min_tr * p ).astype ('int' )
757
761
# ensure n sum up to min_tr
758
762
idx = np .argmax (n )
759
763
n [idx ] = n [idx ] + min_tr - np .sum (n )
760
- trace = np .concatenate ([np .random .choice (traces [i ], j )
761
- for i , j in enumerate (n )])
764
+ trace = []
765
+ for i , j in enumerate (n ):
766
+ tr = traces [i ]
767
+ len_trace = len (tr )
768
+ nchain = tr .nchains
769
+ indices = np .random .randint (0 , nchain * len_trace , j )
770
+ chain_idx , point_idx = np .divmod (indices , len_trace )
771
+ for idx in zip (chain_idx , point_idx ):
772
+ trace .append (tr ._straces [idx [0 ]].point (idx [1 ]))
762
773
763
774
obs = [x for m in models for x in m .observed_RVs ]
764
775
variables = np .repeat (obs , n )
0 commit comments