Commit 88acc7c

Deprecated nuts_kwargs and step_kwargs
1 parent e526c1e commit 88acc7c

12 files changed, +53 -53 lines changed
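
In practice, step-method options that previously had to be wrapped in a nuts_kwargs (or step_kwargs) dict are now passed directly to pm.sample as keyword arguments. A minimal before/after sketch; the `model` context here is hypothetical and for illustration only:

    import pymc3 as pm

    with model:
        # Old spelling (now deprecated): NUTS options wrapped in nuts_kwargs
        trace = pm.sample(1000, tune=1000, nuts_kwargs=dict(target_accept=0.9))

        # New spelling: pass the option straight to pm.sample
        trace = pm.sample(1000, tune=1000, target_accept=0.9)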

docs/source/notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb

Lines changed: 6 additions & 6 deletions
@@ -1002,7 +1002,7 @@
 "source": [
 "with Centered_eight:\n",
 " fit_cp85 = pm.sample(5000, chains=2, tune=2000,\n",
-" nuts_kwargs=dict(target_accept=.85))"
+" target_accept=.85)"
 ]
 },
 {
@@ -1029,7 +1029,7 @@
 "source": [
 "with Centered_eight:\n",
 " fit_cp90 = pm.sample(5000, chains=2, tune=2000,\n",
-" nuts_kwargs=dict(target_accept=.90))"
+" target_accept=.90)"
 ]
 },
 {
@@ -1056,7 +1056,7 @@
 "source": [
 "with Centered_eight:\n",
 " fit_cp95 = pm.sample(5000, chains=2, tune=2000,\n",
-" nuts_kwargs=dict(target_accept=.95))"
+" target_accept=.95)"
 ]
 },
 {
@@ -1083,7 +1083,7 @@
 "source": [
 "with Centered_eight:\n",
 " fit_cp99 = pm.sample(5000, chains=2, tune=2000,\n",
-" nuts_kwargs=dict(target_accept=.99))"
+" target_accept=.99)"
 ]
 },
 {
@@ -1350,7 +1350,7 @@
 "source": [
 "with NonCentered_eight:\n",
 " fit_ncp80 = pm.sample(5000, chains=2, tune=1000, random_seed=SEED,\n",
-" nuts_kwargs=dict(target_accept=.80))"
+" target_accept=.80)"
 ]
 },
 {
@@ -1708,7 +1708,7 @@
 "source": [
 "with NonCentered_eight:\n",
 " fit_ncp90 = pm.sample(5000, chains=2, tune=1000, random_seed=SEED,\n",
-" nuts_kwargs=dict(target_accept=.90))\n",
+" target_accept=.90)\n",
 " \n",
 "# display the total number and percentage of divergent\n",
 "divergent = fit_ncp90['diverging']\n",

docs/source/notebooks/GLM-hierarchical-binominal-model.ipynb

Lines changed: 1 addition & 1 deletion
@@ -309,7 +309,7 @@
 " theta = pm.Beta('theta', alpha=ab[0], beta=ab[1], shape=N)\n",
 "\n",
 " p = pm.Binomial('y', p=theta, observed=y, n=n)\n",
-" trace = pm.sample(1000, tune=2000, nuts_kwargs={'target_accept': .95})\n",
+" trace = pm.sample(1000, tune=2000, target_accept=0.95)\n",
 " "
 ]
 },

docs/source/notebooks/GLM-rolling-regression.ipynb

Lines changed: 1 addition & 1 deletion
@@ -328,7 +328,7 @@
 "source": [
 "with model_randomwalk:\n",
 " trace_rw = pm.sample(tune=2000, cores=4, samples=200, \n",
-" nuts_kwargs=dict(target_accept=.9))"
+" target_accept=0.9)"
 ]
 },
 {

docs/source/notebooks/GP-MaunaLoa2.ipynb

Lines changed: 3 additions & 3 deletions
@@ -260,7 +260,7 @@
 ],
 "source": [
 "with model:\n",
-" tr = pm.sample(1000, tune=1000, chains=2, cores=1, nuts_kwargs={\"target_accept\":0.95})"
+" tr = pm.sample(1000, tune=1000, chains=2, cores=1, target_accept=0.95)"
 ]
 },
 {
@@ -595,7 +595,7 @@
 ],
 "source": [
 "with model:\n",
-" tr = pm.sample(1000, tune=1000, chains=2, cores=1, nuts_kwargs={\"target_accept\":0.95})"
+" tr = pm.sample(1000, tune=1000, chains=2, cores=1, target_accept=0.95)"
 ]
 },
 {
@@ -1084,7 +1084,7 @@
 ],
 "source": [
 "with model:\n",
-" tr = pm.sample(500, chains=2, cores=1, nuts_kwargs={\"target_accept\": 0.95})"
+" tr = pm.sample(500, chains=2, cores=1, target_accept=0.95)"
 ]
 },
 {

docs/source/notebooks/PyMC3_tips_and_heuristic.ipynb

Lines changed: 2 additions & 2 deletions
@@ -484,7 +484,7 @@
 " # Proportion sptial variance\n",
 " alpha = pm.Deterministic('alpha', sd_c/(sd_h+sd_c))\n",
 "\n",
-" trace1 = pm.sample(3e3, cores=2, tune=1000, nuts_kwargs={'max_treedepth': 15})"
+" trace1 = pm.sample(3e3, cores=2, tune=1000, max_treedepth=15)"
 ]
 },
 {
@@ -702,7 +702,7 @@
 " # Proportion sptial variance\n",
 " alpha = pm.Deterministic('alpha', sd_c/(sd_h+sd_c))\n",
 "\n",
-" trace2 = pm.sample(3e3, cores=2, tune=1000, nuts_kwargs={'max_treedepth': 15})"
+" trace2 = pm.sample(3e3, cores=2, tune=1000, max_treedepth=15)"
 ]
 },
 {

docs/source/notebooks/hierarchical_partial_pooling.ipynb

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@
 "source": [
 "with baseball_model:\n",
 " trace = pm.sample(2000, tune=1000, chains=2,\n",
-" nuts_kwargs={'target_accept': 0.95})"
+" target_accept=0.95)"
 ]
 },
 {

docs/source/notebooks/stochastic_volatility.ipynb

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@
 ],
 "source": [
 "with model:\n",
-" trace = pm.sample(tune=2000, nuts_kwargs=dict(target_accept=.9))"
+" trace = pm.sample(tune=2000, target_accept=0.9)"
 ]
 },
 {

docs/source/notebooks/weibull_aft.ipynb

Lines changed: 2 additions & 2 deletions
@@ -184,7 +184,7 @@
 "with model_1:\n",
 " # Increase tune and change init to avoid divergences\n",
 " trace_1 = pm.sample(draws=1000, tune=1000,\n",
-" nuts_kwargs={'target_accept': 0.9},\n",
+" target_accept=0.9,\n",
 " init='adapt_diag')"
 ]
 },
@@ -337,7 +337,7 @@
 "with model_2:\n",
 " # Increase tune and target_accept to avoid divergences\n",
 " trace_2 = pm.sample(draws=1000, tune=1000,\n",
-" nuts_kwargs={'target_accept': 0.9})"
+" target_accept=0.9)"
 ]
 },
 {

pymc3/examples/arma_example.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def run(n_samples=1000):
     with model:
         trace = pm.sample(draws=n_samples,
                           tune=1000,
-                          nuts_kwargs=dict(target_accept=.99))
+                          target_accept=.99)
 
     pm.plots.traceplot(trace)
     pm.plots.forestplot(trace)

pymc3/examples/baseball.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def build_model():
 def run(n=2000):
     model = build_model()
     with model:
-        trace = pm.sample(n, nuts_kwargs={'target_accept':.99})
+        trace = pm.sample(n, target_accept=0.99)
 
     pm.traceplot(trace)

pymc3/sampling.py

Lines changed: 33 additions & 29 deletions
@@ -188,7 +188,7 @@ def _cpu_count():
 
 
 def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=None, chain_idx=0,
-           chains=None, cores=None, tune=500, nuts_kwargs=None, step_kwargs=None, progressbar=True,
+           chains=None, cores=None, tune=500, progressbar=True,
            model=None, random_seed=None, live_plot=False, discard_tuned_samples=True,
            live_plot_kwargs=None, compute_convergence_checks=True, **kwargs):
     """Draw samples from the posterior using the given step methods.
@@ -255,22 +255,6 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
         the step sizes, scalings or similar during tuning. Tuning samples will be drawn in addition
         to the number specified in the `draws` argument, and will be discarded unless
         `discard_tuned_samples` is set to False.
-    nuts_kwargs : dict
-        Options for the NUTS sampler. See the docstring of NUTS for a complete list of options.
-        Common options are:
-
-        * target_accept: float in [0, 1]. The step size is tuned such that we approximate this
-          acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic
-          posteriors.
-        * max_treedepth: The maximum depth of the trajectory tree.
-        * step_scale: float, default 0.25
-          The initial guess for the step size scaled down by `1/n**(1/4)`.
-
-        If you want to pass options to other step methods, please use `step_kwargs`.
-    step_kwargs : dict
-        Options for step methods. Keys are the lower case names of the step method, values are
-        dicts of keyword arguments. You can find a full list of arguments in the docstring of the
-        step methods. If you want to pass arguments only to nuts, you can use `nuts_kwargs`.
     progressbar : bool
         Whether or not to display a progress bar in the command line. The bar shows the percentage
         of completion, the sampling speed in samples per second (SPS), and the estimated remaining
@@ -294,6 +278,22 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
     trace : pymc3.backends.base.MultiTrace
         A `MultiTrace` object that contains the samples.
 
+    Notes
+    -----
+
+    Optional keyword arguments can be passed to `sample` to be delivered to the
+    `step_method`s used during sampling. In particular, the NUTS step method accepts
+    a number of arguments. Common options are:
+
+    * target_accept: float in [0, 1]. The step size is tuned such that we approximate this
+      acceptance rate. Higher values like 0.9 or 0.95 often work better for problematic
+      posteriors.
+    * max_treedepth: The maximum depth of the trajectory tree.
+    * step_scale: float, default 0.25
+      The initial guess for the step size scaled down by `1/n**(1/4)`.
+
+    You can find a full list of arguments in the docstring of the step methods.
+
     Examples
     --------
     .. code:: ipython
@@ -316,9 +316,20 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
     """
     model = modelcontext(model)
 
+    nuts_kwargs = kwargs.pop('nuts_kwargs', None)
+    if nuts_kwargs is not None:
+        warnings.warn("The nuts_kwargs argument has been deprecated. Pass step "
+                      "method arguments directly to sample instead",
+                      DeprecationWarning)
+        kwargs.update(nuts_kwargs)
+    step_kwargs = kwargs.pop('step_kwargs', None)
+    if step_kwargs is not None:
+        warnings.warn("The step_kwargs argument has been deprecated. Pass step "
+                      "method arguments directly to sample instead",
+                      DeprecationWarning)
+        kwargs.update(step_kwargs)
+
     if isinstance(step, pm.step_methods.smc.SMC):
-        if step_kwargs is None:
-            step_kwargs = {}
         trace = smc.sample_smc(draws=draws,
                                step=step,
                                progressbar=progressbar,
@@ -372,33 +383,26 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
 
     draws += tune
 
-    if nuts_kwargs is not None:
-        if step_kwargs is not None:
-            raise ValueError("Specify only one of step_kwargs and nuts_kwargs")
-        step_kwargs = {'nuts': nuts_kwargs}
-
     if model.ndim == 0:
         raise ValueError('The model does not contain any free variables.')
 
     if step is None and init is not None and all_continuous(model.vars):
         try:
             # By default, try to use NUTS
             _log.info('Auto-assigning NUTS sampler...')
-            args = step_kwargs if step_kwargs is not None else {}
-            args = args.get('nuts', {})
             start_, step = init_nuts(init=init, chains=chains, n_init=n_init,
                                      model=model, random_seed=random_seed,
-                                     progressbar=progressbar, **args)
+                                     progressbar=progressbar, **kwargs)
             if start is None:
                 start = start_
         except (AttributeError, NotImplementedError, tg.NullTypeGradError):
             # gradient computation failed
             _log.info("Initializing NUTS failed. "
                       "Falling back to elementwise auto-assignment.")
             _log.debug('Exception in init nuts', exec_info=True)
-            step = assign_step_methods(model, step, step_kwargs=step_kwargs)
+            step = assign_step_methods(model, step, step_kwargs=kwargs)
     else:
-        step = assign_step_methods(model, step, step_kwargs=step_kwargs)
+        step = assign_step_methods(model, step, step_kwargs=kwargs)
 
     if isinstance(step, list):
         step = CompoundStep(step)
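
Per the shim above, the deprecated arguments are still accepted for now: nuts_kwargs and step_kwargs are popped off **kwargs, a DeprecationWarning is emitted, and their contents are merged back into kwargs, which is then forwarded to init_nuts and assign_step_methods. A rough sketch of what a caller can expect to see (the `model` context is hypothetical; illustration only):

    import warnings
    import pymc3 as pm

    with model:
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            # Deprecated spelling: still works, but warns and is forwarded
            # internally as target_accept=0.9.
            trace = pm.sample(500, tune=500, nuts_kwargs={'target_accept': 0.9})

        assert any(issubclass(w.category, DeprecationWarning) for w in caught)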

pymc3/tests/test_sampling.py

Lines changed: 1 addition & 5 deletions
@@ -77,11 +77,7 @@ def test_sample_args(self):
             pm.sample(50, tune=0, init=None, step_kwargs={'foo': {}})
         assert 'foo' in str(excinfo.value)
 
-        pm.sample(10, tune=0, init=None, nuts_kwargs={'target_accept': 0.9})
-
-        with pytest.raises(ValueError) as excinfo:
-            pm.sample(5, tune=0, init=None, step_kwargs={}, nuts_kwargs={})
-        assert 'Specify only one' in str(excinfo.value)
+        pm.sample(10, tune=0, init=None, target_accept=0.9)
 
     def test_iter_sample(self):
         with self.model:
