@sethaxen told me on Slack, and I am recording it here so it doesn't get eaten by history before I act on it, the following:
Great start! The plots from the final 3 sampling results (conditioned on c3, c8, and c1 or any subset of these) looked fishy, which made me think you were getting divergences, which happen when HMC encounters regions of high curvature where it can't reliably sample. So let's do a few diagnostic checks. First, let's sample multiple chains, as this allows more reliable convergence diagnostics:
```julia
julia> chain = sample(model, NUTS(), MCMCThreads(), 1_000, 4)
...
Chains MCMC chain (1000×23×4 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 25.09 seconds
Compute duration  = 99.61 seconds
parameters        = c_max, t_max, halflife, err, c2, c4, c6, c10, c12, c16, c24
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters       mean      std  naive_se    mcse        ess    rhat  ess_per_sec
      Symbol    Float64  Float64   Float64  Float64    Float64  Float64    Float64

       c_max   108.6282   1.9417    0.0307   0.0731   683.8163   1.0067     6.8649
       t_max     2.1786   0.0673    0.0011   0.0019  1278.5006   1.0022    12.8349
    halflife     6.8477   0.8309    0.0131   0.0409   412.5109   1.0126     4.1412
         err     1.1108   1.0635    0.0168   0.0827    48.9063   1.0935     0.4910
          c2    99.7609   3.3208    0.0525   0.0976  1299.0791   1.0029    13.0415
          c4    90.2151   2.0221    0.0320   0.0432  2331.7233   1.0008    23.4083
          c6    73.6348   1.9702    0.0312   0.0505  1622.1679   1.0020    16.2850
         c10    49.0056   2.5075    0.0396   0.0819   913.6478   1.0080     9.1722
         c12    39.9438   2.7244    0.0431   0.0925   823.6327   1.0085     8.2685
         c16    26.6574   2.6566    0.0420   0.0985   699.1499   1.0102     7.0188
         c24    11.8861   2.7374    0.0433   0.1261   485.6644   1.0110     4.8756

Quantiles
  parameters       2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol    Float64   Float64   Float64   Float64   Float64

       c_max   104.3847  108.2221  108.8569  109.1745  112.0284
       t_max     2.0441    2.1585    2.1771    2.1943    2.3212
    halflife     6.0577    6.6845    6.7670    6.9133    7.8893
         err     0.2000    0.3568    0.7444    1.5416    3.8401
          c2    91.7489   98.9999  100.1175  100.9388  105.9018
          c4    85.8345   89.7698   90.3332   90.7347   94.1405
          c6    69.5013   73.0164   73.6187   74.1465   77.9393
         c10    44.0264   48.1991   49.0064   49.6336   54.9456
         c12    34.5891   39.2227   39.7270   40.6457   45.4985
         c16    21.5447   25.9203   26.3998   27.3007   32.2124
         c24     8.0231   11.1646   11.5504   12.3348   17.1209

julia> mean(chain[:numerical_error])
0.183
```
The columns to check here are `ess` and `rhat`. The first estimates how many truly independent draws would give an estimate of a given parameter's mean with the same standard error as these non-independent draws. The ~49 effective draws for `err` are far too low and indicate something is wrong. `rhat` is a convergence diagnostic that should be below 1.01 for all parameters, but here several exceed that threshold. In the final check, we see that 18% of transitions failed due to numerical error (usually divergences), so there are geometric issues preventing some regions of the posterior from being sampled. We can at least say MCMC didn't work well, and I wouldn't do much downstream analysis with these results except to try to figure out why sampling failed. Sampling problems often indicate problems with the model.
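To make concrete what these two diagnostics measure, here is a minimal, self-contained sketch of split-R̂ and the divergence rate in plain Julia. This is a simplified version of the computation, not the exact code MCMCChains uses (which additionally rank-normalizes), and the names `split_rhat` and `divergence_rate` are my own:

```julia
using Statistics

# Split-R̂ (Gelman–Rubin): split each chain in half, then compare the
# between-sub-chain variance B to the within-sub-chain variance W.
# `draws` is an (iterations × chains) matrix for one parameter.
function split_rhat(draws::AbstractMatrix)
    n, m = size(draws)
    half = n ÷ 2
    # Split each chain into two halves -> 2m sub-chains of length `half`.
    subs = [draws[1:half, j] for j in 1:m]
    append!(subs, [draws[half+1:2half, j] for j in 1:m])
    means = mean.(subs)
    vars  = var.(subs)
    W = mean(vars)          # within-chain variance
    B = half * var(means)   # between-chain variance
    varplus = (half - 1) / half * W + B / half
    return sqrt(varplus / W)
end

# Divergence rate: fraction of transitions that hit numerical error,
# exactly what `mean(chain[:numerical_error])` reports above.
divergence_rate(numerical_error) = mean(numerical_error)

# For well-mixed (here: iid) chains, split-R̂ should be very close to 1.
println(split_rhat(randn(1000, 4)))       # ≈ 1.0
println(divergence_rate([0, 0, 1, 0]))    # 0.25
```

If the chains had not mixed (say, one chain stuck in a different mode), `B` would dominate `W` and the ratio would climb well above 1, which is why values above ~1.01 are treated as a warning.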
Sometimes we can increase the adapt delta to a very large value and re-run sampling. This causes HMC to adapt a smaller step size, making it better able to handle high curvature:
```julia
julia> chain2 = sample(model, NUTS(0.999), MCMCThreads(), 1_000, 4);

julia> mean(chain2[:numerical_error])
0.0025
```
Even though most of the divergences went away, at such a high adapt delta we should see none at all, so I think it's worth looking into this further to see if the model can be improved.
One thing we can check is whether divergences cluster in parameter space; we can use ArviZ for this:
```julia
julia> using ArviZ

julia> idata = convert_to_inference_data(chain)
InferenceData with groups:
  > posterior
  > sample_stats

julia> plot_pair(idata; var_names=[:c_max, :t_max, :halflife, :err], divergences=true)
```
This plot (attached) shows that divergent transitions occur when `err` is low. In fact, you may have a funnel geometry, which tends to pose problems for HMC-based MCMC methods.
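Neal's funnel is the textbook example of this geometry, and the standard fix is a non-centered parameterization. The sketch below illustrates the idea in plain Julia (it is not the issue's actual model): instead of sampling a quantity whose scale depends on another parameter, sample a unit-scale variable and multiply by the scale afterwards, so the sampler never sees the varying-scale geometry. Analogously, quantities whose scale depends on `err` could be expressed as `err` times a standardized residual:

```julia
using Random, Statistics

Random.seed!(1)

# Neal's funnel:  v ~ Normal(0, 3);  x | v ~ Normal(0, exp(v/2)).
# Centered form: sample x directly -> its scale varies over orders of
# magnitude with v, so no single HMC step size works everywhere.
# Non-centered form: sample z ~ Normal(0, 1) and set x = exp(v/2) * z,
# so the sampler only ever works with unit-scale variables.
v = 3 .* randn(10_000)
z = randn(10_000)
x = exp.(v ./ 2) .* z    # non-centered draw of x

# The implied x spans many orders of magnitude...
println(extrema(abs.(x)))
# ...but the variable the sampler actually explores, z, stays unit scale:
println(std(z))          # ≈ 1.0
```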
I wasn't able to put more time into this, but if you come back to the model, I'd suggest simulating dose curves and observations from the prior, fixing sigma (`err`) to low and high values, and seeing whether it becomes clear why a low sigma could be problematic.
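As a starting point for that check, here is a hedged sketch of such a simulation. The `curve` function is a stand-in dose curve (a linear rise to `c_max` at `t_max`, then exponential decay with the given `halflife`), not the issue's actual model; the default parameter values are just the posterior means from the table above:

```julia
using Random, Statistics

Random.seed!(2)

# Stand-in dose curve (an assumption, not the model from this issue):
# linear rise to c_max at t_max, then halflife-based exponential decay.
curve(t; c_max=108.0, t_max=2.2, halflife=6.8) =
    t <= t_max ? c_max * t / t_max : c_max * 2.0^(-(t - t_max) / halflife)

times = [2, 4, 6, 10, 12, 16, 24]
simulate(err) = [curve(t) + err * randn() for t in times]

low  = simulate(0.1)   # err fixed low: observations sit almost exactly on the curve
high = simulate(5.0)   # err fixed high: visibly scattered observations

# With err near zero the likelihood becomes extremely sharp around the
# curve, which is one way a low-err funnel neck can arise.
println(maximum(abs.(low  .- curve.(times))))
println(maximum(abs.(high .- curve.(times))))
```

Plotting a handful of these simulated data sets at each fixed `err` should make it visually obvious how differently the low- and high-noise regimes constrain the curve parameters.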