|
7 | 7 |
|
8 | 8 |
|
9 | 9 | def traceplot(trace, vars=None, figsize=None,
|
10 |
| - lines=None, combined=False, grid=True, |
| 10 | + lines=None, combined=False, grid=True, |
11 | 11 | alpha=0.35, ax=None):
|
12 | 12 | """Plot samples histograms and values
|
13 | 13 |
|
@@ -149,82 +149,94 @@ def kde2plot(x, y, grid=200, ax=None):
|
149 | 149 | return ax
|
150 | 150 |
|
151 | 151 |
|
152 |
| -def autocorrplot(trace, vars=None, max_lag=100, burn=0, ax=None, |
153 |
| - symmetric_plot=False): |
| 152 | +def autocorrplot(trace, varnames=None, max_lag=100, burn=0, |
| 153 | + symmetric_plot=False, ax=None, figsize=None): |
154 | 154 | """Bar plot of the autocorrelation function for a trace
|
| 155 | +
|
155 | 156 | Parameters
|
156 | 157 | ----------
|
157 | 158 | trace : result of MCMC run
|
158 |
| - vars : list of variable names |
| 159 | + varnames : list of variable names |
159 | 160 | Variables to be plotted, if None all variable are plotted.
|
160 | 161 | Vector-value stochastics are handled automatically.
|
161 | 162 | max_lag : int
|
162 | 163 | Maximum lag to calculate autocorrelation. Defaults to 100.
|
163 | 164 | burn : int
|
164 |
| - Number of samples to discard from the beginning of the trace. |
| 165 | + Number of samples to discard from the beginning of the trace. |
165 | 166 | Defaults to 0.
|
166 |
| - ax : axes |
167 |
| - Matplotlib axes. Defaults to None. |
168 | 167 | symmetric_plot : boolean
|
169 | 168 | Plot from either [0, +lag] or [-lag, lag]. Defaults to False, [-, +lag].
|
170 |
| - |
| 169 | + ax : axes |
| 170 | + Matplotlib axes. Defaults to None. |
| 171 | + figsize : figure size tuple |
| 172 | + If None, size is (12, num of variables * 2) inches. |
| 173 | + Note this is not used if ax is supplied. |
| 174 | +
|
171 | 175 | Returns
|
172 | 176 | -------
|
173 | 177 | ax : matplotlib axes
|
174 | 178 | """
|
175 |
| - |
| 179 | + |
176 | 180 | import matplotlib.pyplot as plt
|
177 |
| - |
178 |
| - def _handle_array_varnames(val): |
179 |
| - if trace[0][val].__class__ is np.ndarray: |
180 |
| - k = trace[val].shape[1] |
181 |
| - for i in xrange(k): |
182 |
| - yield val + '_{0}'.format(i) |
| 181 | + |
| 182 | + def _handle_array_varnames(varname): |
| 183 | + if trace[0][varname].__class__ is np.ndarray: |
| 184 | + k = trace[varname].shape[1] |
| 185 | + for i in range(k): |
| 186 | + yield varname + '_{0}'.format(i) |
183 | 187 | else:
|
184 |
| - yield val |
185 |
| - |
186 |
| - if vars is None: |
187 |
| - vars = [item for sub in [[i for i in _handle_array_varnames(var)] |
188 |
| - for var in trace.varnames] for item in sub] |
| 188 | + yield varname |
| 189 | + |
| 190 | + if varnames is None: |
| 191 | + varnames = trace.varnames |
189 | 192 | else:
|
190 |
| - vars = [str(var) for var in vars] |
191 |
| - vars = [item for sub in [[i for i in _handle_array_varnames(var)] |
192 |
| - for var in vars] for item in sub] |
| 193 | + varnames = [str(v) for v in varnames] |
193 | 194 |
|
194 |
| - chains = trace.nchains |
| 195 | + varnames = [item for sub in [[i for i in _handle_array_varnames(v)] |
| 196 | + for v in varnames] for item in sub] |
195 | 197 |
|
196 |
| - fig, ax = plt.subplots(len(vars), chains, squeeze=False, |
197 |
| - sharex=True, sharey=True) |
| 198 | + nchains = trace.nchains |
| 199 | + |
| 200 | + if figsize is None: |
| 201 | + figsize = (12, len(varnames)*2) |
| 202 | + |
| 203 | + if ax is None: |
| 204 | + fig, ax = plt.subplots(len(varnames), nchains, squeeze=False, |
| 205 | + sharex=True, sharey=True, figsize=figsize) |
| 206 | + elif ax.shape != (len(varnames), nchains): |
| 207 | + raise ValueError('autocorrplot requires {}*{} subplots'.format( |
| 208 | + len(varnames), nchains)) |
| 209 | + return None |
198 | 210 |
|
199 | 211 | max_lag = min(len(trace) - 1, max_lag)
|
200 | 212 |
|
201 |
| - for i, v in enumerate(vars): |
202 |
| - for j in range(chains): |
| 213 | + for i, v in enumerate(varnames): |
| 214 | + for j in range(nchains): |
203 | 215 | try:
|
204 | 216 | d = np.squeeze(trace.get_values(v, chains=[j], burn=burn,
|
205 | 217 | combine=False))
|
206 | 218 | except KeyError:
|
207 | 219 | k = int(v.split('_')[-1])
|
208 | 220 | v_use = '_'.join(v.split('_')[:-1])
|
209 |
| - d = np.squeeze(trace.get_values(v_use, chains=[j], burn=burn, |
210 |
| - combine=False)[:, k]) |
| 221 | + d = np.squeeze(trace.get_values(v_use, chains=[j], |
| 222 | + burn=burn, combine=False)[:, k]) |
211 | 223 |
|
212 | 224 | ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)
|
213 | 225 |
|
214 | 226 | if not j:
|
215 | 227 | ax[i, j].set_ylabel("correlation")
|
216 |
| - if i == len(vars) - 1: |
| 228 | + if i == len(varnames) - 1: |
217 | 229 | ax[i, j].set_xlabel("lag")
|
218 |
| - |
| 230 | + |
219 | 231 | ax[i, j].set_title(v)
|
220 |
| - |
| 232 | + |
221 | 233 | if not symmetric_plot:
|
222 | 234 | ax[i, j].set_xlim(0, max_lag)
|
223 |
| - |
224 |
| - if chains > 1: |
| 235 | + |
| 236 | + if nchains > 1: |
225 | 237 | ax[i, j].set_title("chain {0}".format(j+1))
|
226 |
| - |
227 |
| - return (fig, ax) |
| 238 | + |
| 239 | + return ax |
228 | 240 |
|
229 | 241 |
|
230 | 242 | def var_str(name, shape):
|
@@ -294,10 +306,10 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
|
294 | 306 |
|
295 | 307 | vline (optional): numeric
|
296 | 308 | Location of vertical reference line (defaults to 0).
|
297 |
| - |
| 309 | +
|
298 | 310 | gs : GridSpec
|
299 | 311 | Matplotlib GridSpec object. Defaults to None.
|
300 |
| - |
| 312 | +
|
301 | 313 | Returns
|
302 | 314 | -------
|
303 | 315 |
|
|
0 commit comments