Skip to content

Commit e6c5b1a

Browse files
fix: replace deep copy with shalow copy
1 parent caa9501 commit e6c5b1a

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

pymc/sampling/mcmc.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,9 @@ def _sample_external_nuts(
309309
nuts_sampler_kwargs: dict | None,
310310
**kwargs,
311311
):
312-
import copy
313-
314-
nuts_sampler_kwargs_copy = copy.deepcopy(nuts_sampler_kwargs)
315-
if nuts_sampler_kwargs_copy is None:
316-
nuts_sampler_kwargs_copy = {}
312+
nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
313+
if nuts_sampler_kwargs is None:
314+
nuts_sampler_kwargs = {}
317315

318316
if sampler == "nutpie":
319317
try:
@@ -342,8 +340,8 @@ def _sample_external_nuts(
342340
)
343341
compile_kwargs = {}
344342
for kwarg in ("backend", "gradient_backend"):
345-
if kwarg in nuts_sampler_kwargs_copy:
346-
compile_kwargs[kwarg] = nuts_sampler_kwargs_copy.pop(kwarg)
343+
if kwarg in nuts_sampler_kwargs:
344+
compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
347345
compiled_model = nutpie.compile_pymc_model(
348346
model,
349347
**compile_kwargs,
@@ -357,7 +355,7 @@ def _sample_external_nuts(
357355
target_accept=target_accept,
358356
seed=_get_seeds_per_chain(random_seed, 1)[0],
359357
progress_bar=progressbar,
360-
**nuts_sampler_kwargs_copy,
358+
**nuts_sampler_kwargs,
361359
)
362360
t_sample = time.time() - t_start
363361
# Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
@@ -409,7 +407,7 @@ def _sample_external_nuts(
409407
nuts_sampler=sampler,
410408
idata_kwargs=idata_kwargs,
411409
compute_convergence_checks=compute_convergence_checks,
412-
**nuts_sampler_kwargs_copy,
410+
**nuts_sampler_kwargs,
413411
)
414412
return idata
415413

@@ -689,9 +687,7 @@ def sample(
689687
mean sd hdi_3% hdi_97%
690688
p 0.609 0.047 0.528 0.699
691689
"""
692-
import copy
693-
694-
nuts_sampler_kwargs_copy = copy.deepcopy(nuts_sampler_kwargs)
690+
nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
695691
if "start" in kwargs:
696692
if initvals is not None:
697693
raise ValueError("Passing both `start` and `initvals` is not supported.")
@@ -701,8 +697,8 @@ def sample(
701697
stacklevel=2,
702698
)
703699
initvals = kwargs.pop("start")
704-
if nuts_sampler_kwargs_copy is None:
705-
nuts_sampler_kwargs_copy = {}
700+
if nuts_sampler_kwargs is None:
701+
nuts_sampler_kwargs = {}
706702
if "target_accept" in kwargs:
707703
if "nuts" in kwargs and "target_accept" in kwargs["nuts"]:
708704
raise ValueError(
@@ -814,7 +810,7 @@ def joined_blas_limiter():
814810
progressbar=progressbar,
815811
idata_kwargs=idata_kwargs,
816812
compute_convergence_checks=compute_convergence_checks,
817-
nuts_sampler_kwargs=nuts_sampler_kwargs_copy,
813+
nuts_sampler_kwargs=nuts_sampler_kwargs,
818814
**kwargs,
819815
)
820816

0 commit comments

Comments
 (0)