Skip to content

Commit 35c2508

Browse files
authored
fix: use stanio for creating Stan's data JSON (#205)
1 parent e435d60 commit 35c2508

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ Homepage = "https://pymc-devs.github.io/nutpie/"
2828
Repository = "https://github.com/pymc-devs/nutpie"
2929

3030
[project.optional-dependencies]
31-
stan = ["bridgestan >= 2.6.1"]
31+
stan = ["bridgestan >= 2.6.1", "stanio >= 0.5.1"]
3232
pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"]
3333
pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"]
3434
nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"]
3535
dev = [
3636
"bridgestan >= 2.6.1",
37+
"stanio >= 0.5.1",
3738
"pymc >= 5.20.1",
3839
"numba >= 0.60.0",
3940
"jax >= 0.4.27",
@@ -43,6 +44,7 @@ dev = [
4344
]
4445
all = [
4546
"bridgestan >= 2.6.1",
47+
"stanio >= 0.5.1",
4648
"pymc >= 5.20.1",
4749
"numba >= 0.60.0",
4850
"jax >= 0.4.27",

python/nutpie/compile_stan.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,16 @@
1-
import json
21
import tempfile
32
from dataclasses import dataclass, replace
43
from importlib.util import find_spec
54
from pathlib import Path
65
from typing import Any, Optional
76

8-
import numpy as np
97
import pandas as pd
108
from numpy.typing import NDArray
119

1210
from nutpie import _lib
1311
from nutpie.sample import CompiledModel
1412

1513

16-
class _NumpyArrayEncoder(json.JSONEncoder):
17-
def default(self, obj):
18-
if isinstance(obj, np.ndarray):
19-
return obj.tolist()
20-
return json.JSONEncoder.default(self, obj)
21-
22-
2314
@dataclass(frozen=True)
2415
class CompiledStanModel(CompiledModel):
2516
_coords: Optional[dict[str, Any]]
@@ -39,7 +30,16 @@ def with_data(self, *, seed=None, **updates):
3930
data.update(updates)
4031

4132
if data is not None:
42-
data_json = json.dumps(data, cls=_NumpyArrayEncoder)
33+
if find_spec("stanio") is None:
34+
raise ImportError(
35+
"stanio is not installed in the current environment. "
36+
"Please install it with something like "
37+
"'pip install stanio' or 'pip install nutpie[stan]'."
38+
)
39+
40+
import stanio
41+
42+
data_json = stanio.dump_stan_json(data)
4343
else:
4444
data_json = None
4545

@@ -136,7 +136,7 @@ def compile_stan_model(
136136
raise ImportError(
137137
"BridgeStan is not installed in the current environment. "
138138
"Please install it with something like "
139-
"'pip install bridgestan'."
139+
"'pip install bridgestan' or 'pip install nutpie[stan]'."
140140
)
141141

142142
import bridgestan

tests/test_stan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_stan_model():
3131
def test_stan_model_data():
3232
model = """
3333
data {
34-
real x;
34+
complex x;
3535
}
3636
parameters {
3737
real a;
@@ -44,7 +44,7 @@ def test_stan_model_data():
4444
compiled_model = nutpie.compile_stan_model(code=model)
4545
with pytest.raises(RuntimeError):
4646
trace = nutpie.sample(compiled_model)
47-
trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0)))
47+
trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0j)))
4848
trace.posterior.a # noqa: B018
4949

5050

0 commit comments

Comments
 (0)