@@ -309,11 +309,9 @@ def _sample_external_nuts(
309
309
nuts_sampler_kwargs : dict | None ,
310
310
** kwargs ,
311
311
):
312
- import copy
313
-
314
- nuts_sampler_kwargs_copy = copy .deepcopy (nuts_sampler_kwargs )
315
- if nuts_sampler_kwargs_copy is None :
316
- nuts_sampler_kwargs_copy = {}
312
+ nuts_sampler_kwargs = nuts_sampler_kwargs .copy ()
313
+ if nuts_sampler_kwargs is None :
314
+ nuts_sampler_kwargs = {}
317
315
318
316
if sampler == "nutpie" :
319
317
try :
@@ -342,8 +340,8 @@ def _sample_external_nuts(
342
340
)
343
341
compile_kwargs = {}
344
342
for kwarg in ("backend" , "gradient_backend" ):
345
- if kwarg in nuts_sampler_kwargs_copy :
346
- compile_kwargs [kwarg ] = nuts_sampler_kwargs_copy .pop (kwarg )
343
+ if kwarg in nuts_sampler_kwargs :
344
+ compile_kwargs [kwarg ] = nuts_sampler_kwargs .pop (kwarg )
347
345
compiled_model = nutpie .compile_pymc_model (
348
346
model ,
349
347
** compile_kwargs ,
@@ -357,7 +355,7 @@ def _sample_external_nuts(
357
355
target_accept = target_accept ,
358
356
seed = _get_seeds_per_chain (random_seed , 1 )[0 ],
359
357
progress_bar = progressbar ,
360
- ** nuts_sampler_kwargs_copy ,
358
+ ** nuts_sampler_kwargs ,
361
359
)
362
360
t_sample = time .time () - t_start
363
361
# Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
@@ -409,7 +407,7 @@ def _sample_external_nuts(
409
407
nuts_sampler = sampler ,
410
408
idata_kwargs = idata_kwargs ,
411
409
compute_convergence_checks = compute_convergence_checks ,
412
- ** nuts_sampler_kwargs_copy ,
410
+ ** nuts_sampler_kwargs ,
413
411
)
414
412
return idata
415
413
@@ -689,9 +687,7 @@ def sample(
689
687
mean sd hdi_3% hdi_97%
690
688
p 0.609 0.047 0.528 0.699
691
689
"""
692
- import copy
693
-
694
- nuts_sampler_kwargs_copy = copy .deepcopy (nuts_sampler_kwargs )
690
+ nuts_sampler_kwargs = nuts_sampler_kwargs .copy ()
695
691
if "start" in kwargs :
696
692
if initvals is not None :
697
693
raise ValueError ("Passing both `start` and `initvals` is not supported." )
@@ -701,8 +697,8 @@ def sample(
701
697
stacklevel = 2 ,
702
698
)
703
699
initvals = kwargs .pop ("start" )
704
- if nuts_sampler_kwargs_copy is None :
705
- nuts_sampler_kwargs_copy = {}
700
+ if nuts_sampler_kwargs is None :
701
+ nuts_sampler_kwargs = {}
706
702
if "target_accept" in kwargs :
707
703
if "nuts" in kwargs and "target_accept" in kwargs ["nuts" ]:
708
704
raise ValueError (
@@ -814,7 +810,7 @@ def joined_blas_limiter():
814
810
progressbar = progressbar ,
815
811
idata_kwargs = idata_kwargs ,
816
812
compute_convergence_checks = compute_convergence_checks ,
817
- nuts_sampler_kwargs = nuts_sampler_kwargs_copy ,
813
+ nuts_sampler_kwargs = nuts_sampler_kwargs ,
818
814
** kwargs ,
819
815
)
820
816
0 commit comments