Commit 06bf9df

aerubanov authored and ricardoV94 committed
remove SeededTest class
1 parent 50d056d commit 06bf9df

13 files changed (+84 / -107 lines)

pymc/testing.py

Lines changed: 8 additions & 19 deletions
@@ -775,24 +775,7 @@ def discrete_random_tester(
     assert p > alpha, str(point)
 
 
-class SeededTest:
-    random_seed = 20160911
-    random_state = None
-
-    @classmethod
-    def setup_class(cls):
-        nr.seed(cls.random_seed)
-
-    def setup_method(self):
-        nr.seed(self.random_seed)
-
-    def get_random_state(self, reset=False):
-        if self.random_state is None or reset:
-            self.random_state = nr.RandomState(self.random_seed)
-        return self.random_state
-
-
-class BaseTestDistributionRandom(SeededTest):
+class BaseTestDistributionRandom:
     """
     Base class for tests that new RandomVariables are correctly
     implemented, and that the mapping of parameters between the PyMC
@@ -863,8 +846,9 @@ class BaseTestDistributionRandom(SeededTest):
     sizes_to_check: Optional[List] = None
     sizes_expected: Optional[List] = None
     repeated_params_shape = 5
+    random_state = None
 
-    def test_distribution(self):
+    def test_distribution(self, seeded_test):
         self.validate_tests_list()
         if self.pymc_dist == pm.Wishart:
             with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
@@ -886,6 +870,11 @@ def test_distribution(self):
         else:
             getattr(self, check_name)()
 
+    def get_random_state(self, reset=False):
+        if self.random_state is None or reset:
+            self.random_state = nr.RandomState(20160911)
+        return self.random_state
+
     def _instantiate_pymc_rv(self, dist_params=None):
         params = dist_params if dist_params else self.pymc_dist_params
         self.pymc_rv = self.pymc_dist.dist(
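The seeding logic that BaseTestDistributionRandom keeps is the same caching pattern the removed SeededTest used, just with the seed hardcoded. A minimal stand-alone sketch of that behaviour (the Demo class is hypothetical; pymc/testing.py imports numpy.random as nr):

import numpy.random as nr

class Demo:  # hypothetical stand-in for BaseTestDistributionRandom
    random_state = None

    def get_random_state(self, reset=False):
        # Create the RandomState lazily with the fixed seed and cache it;
        # reset=True discards the cached state and builds a fresh one.
        if self.random_state is None or reset:
            self.random_state = nr.RandomState(20160911)
        return self.random_state

demo = Demo()
first = demo.get_random_state()
assert demo.get_random_state() is first                 # cached between calls
assert demo.get_random_state(reset=True) is not first   # reset replaces it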

tests/conftest.py

Lines changed: 1 addition & 2 deletions
@@ -44,8 +44,7 @@ def strict_float32():
 
 @pytest.fixture(scope="function", autouse=False)
 def seeded_test():
-    # TODO: use this instead of SeededTest
-    np.random.seed(42)
+    np.random.seed(20160911)
 
 
 @pytest.fixture
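Since the fixture is declared with autouse=False, a test only gets the deterministic seeding when it lists seeded_test as a parameter, which is exactly what the hunks below do. A hedged sketch of that usage (the test name here is hypothetical):

import numpy as np
import pytest

@pytest.fixture(scope="function", autouse=False)
def seeded_test():
    np.random.seed(20160911)

def test_uses_global_rng(seeded_test):  # hypothetical test
    # The fixture has already seeded NumPy's global RNG, so this draw
    # is identical on every run of the test.
    draw = np.random.rand(3)
    assert draw.shape == (3,)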

tests/distributions/test_mixture.py

Lines changed: 19 additions & 15 deletions
@@ -79,7 +79,6 @@
     R,
     Rplus,
     Rplusbig,
-    SeededTest,
     Simplex,
     Unit,
     assert_moment_is_expected,
@@ -115,7 +114,10 @@ def generate_poisson_mixture_data(w, mu, size=1000):
     return x
 
 
-class TestMixture(SeededTest):
+class TestMixture:
+    def get_random_state(self):
+        return np.random.RandomState(20160911)
+
     def get_initial_point(self, model):
         """Get initial point with untransformed variables for posterior predictive sampling"""
         return {
@@ -477,7 +479,7 @@ def test_single_poisson_sampling(self):
            trace = sample(
                5000,
                step=step,
-               random_seed=self.random_seed,
+               random_seed=45354,
                progressbar=False,
                chains=1,
                return_inferencedata=False,
@@ -502,7 +504,7 @@ def test_list_poissons_sampling(self):
                5000,
                chains=1,
                step=Metropolis(),
-               random_seed=self.random_seed,
+               random_seed=5363567,
                progressbar=False,
                return_inferencedata=False,
            )
@@ -533,7 +535,7 @@ def test_list_normals_sampling(self):
                5000,
                chains=1,
                step=Metropolis(),
-               random_seed=self.random_seed,
+               random_seed=645334,
                progressbar=False,
                return_inferencedata=False,
            )
@@ -785,8 +787,8 @@ def test_preventing_mixing_cont_and_discrete(self):
         )
 
 
-class TestNormalMixture(SeededTest):
-    def test_normal_mixture_sampling(self):
+class TestNormalMixture:
+    def test_normal_mixture_sampling(self, seeded_test):
         norm_w = np.array([0.75, 0.25])
         norm_mu = np.array([0.0, 5.0])
         norm_sigma = np.ones_like(norm_mu)
@@ -804,7 +806,7 @@ def test_normal_mixture_sampling(self):
            trace = sample(
                5000,
                step=step,
-               random_seed=self.random_seed,
+               random_seed=20160911,
                progressbar=False,
                chains=1,
                return_inferencedata=False,
@@ -816,7 +818,7 @@ def test_normal_mixture_sampling(self):
     @pytest.mark.parametrize(
         "nd, ncomp", [(tuple(), 5), (1, 5), (3, 5), ((3, 3), 5), (3, 3), ((3, 3), 3)], ids=str
     )
-    def test_normal_mixture_nd(self, nd, ncomp):
+    def test_normal_mixture_nd(self, seeded_test, nd, ncomp):
         nd = to_tuple(nd)
         ncomp = int(ncomp)
         comp_shape = nd + (ncomp,)
@@ -865,7 +867,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
         assert_allclose(logp0, logp1)
         assert_allclose(logp0, logp2)
 
-    def test_random(self):
+    def test_random(self, seeded_test):
         def ref_rand(size, w, mu, sigma):
             component = np.random.choice(w.size, size=size, p=w)
             return np.random.normal(mu[component], sigma[component], size=size)
@@ -894,9 +896,12 @@ def ref_rand(size, w, mu, sigma):
         )
 
 
-class TestMixtureVsLatent(SeededTest):
+class TestMixtureVsLatent:
     """This class contains tests that compare a marginal Mixture with a latent indexed Mixture"""
 
+    def get_random_state(self):
+        return np.random.RandomState(20160911)
+
     def test_scalar_components(self):
         nd = 3
         npop = 4
@@ -1013,21 +1018,20 @@ def loose_logp(model, vars):
         assert_allclose(mix_logp, latent_mix_logp, rtol=rtol)
 
 
-class TestMixtureSameFamily(SeededTest):
+class TestMixtureSameFamily:
     """Tests that used to belong to deprecated `TestMixtureSameFamily`.
 
     The functionality is now expected to be provided by `Mixture`
     """
 
     @classmethod
     def setup_class(cls):
-        super().setup_class()
         cls.size = 50
         cls.n_samples = 1000
         cls.mixture_comps = 10
 
     @pytest.mark.parametrize("batch_shape", [(3, 4), (20,)], ids=str)
-    def test_with_multinomial(self, batch_shape):
+    def test_with_multinomial(self, seeded_test, batch_shape):
         p = np.random.uniform(size=(*batch_shape, self.mixture_comps, 3))
         p /= p.sum(axis=-1, keepdims=True)
         n = 100 * np.ones((*batch_shape, 1))
@@ -1062,7 +1066,7 @@ def test_with_multinomial(self, batch_shape):
             rtol,
         )
 
-    def test_with_mvnormal(self):
+    def test_with_mvnormal(self, seeded_test):
         # 10 batch, 3-variate Gaussian
         mu = np.random.randn(self.mixture_comps, 3)
         mat = np.random.randn(3, 3)

tests/distributions/test_simulator.py

Lines changed: 16 additions & 18 deletions
@@ -33,10 +33,9 @@
 from pymc.initial_point import make_initial_point_fn
 from pymc.pytensorf import compile_pymc
 from pymc.smc.kernels import IMH
-from pymc.testing import SeededTest
 
 
-class TestSimulator(SeededTest):
+class TestSimulator:
     @staticmethod
     def count_rvs(end_node):
         return len(
@@ -60,7 +59,6 @@ def quantiles(x):
         return np.quantile(x, [0.25, 0.5, 0.75])
 
     def setup_class(self):
-        super().setup_class()
         self.data = np.random.normal(loc=0, scale=1, size=1000)
 
         with pm.Model() as self.SMABC_test:
@@ -75,7 +73,7 @@ def setup_class(self):
             c = pm.Potential("c", pm.math.switch(a > 0, 0, -np.inf))
             s = pm.Simulator("s", self.normal_sim, a, b, observed=self.data)
 
-    def test_one_gaussian(self):
+    def test_one_gaussian(self, seeded_test):
         assert self.count_rvs(self.SMABC_test.logp()) == 1
 
         with self.SMABC_test:
@@ -95,7 +93,7 @@ def test_one_gaussian(self):
         assert abs(self.data.std() - po_p["s"].std()) < 0.10
 
     @pytest.mark.parametrize("floatX", ["float32", "float64"])
-    def test_custom_dist_sum_stat(self, floatX):
+    def test_custom_dist_sum_stat(self, seeded_test, floatX):
         with pytensor.config.change_flags(floatX=floatX):
             with pm.Model() as m:
                 a = pm.Normal("a", mu=0, sigma=1)
@@ -118,7 +116,7 @@ def test_custom_dist_sum_stat(self, floatX):
                pm.sample_smc(draws=100)
 
     @pytest.mark.parametrize("floatX", ["float32", "float64"])
-    def test_custom_dist_sum_stat_scalar(self, floatX):
+    def test_custom_dist_sum_stat_scalar(self, seeded_test, floatX):
         """
         Test that automatically wrapped functions cope well with scalar inputs
         """
@@ -149,22 +147,22 @@ def test_custom_dist_sum_stat_scalar(self, floatX):
            )
        assert self.count_rvs(m.logp()) == 1
 
-    def test_model_with_potential(self):
+    def test_model_with_potential(self, seeded_test):
         assert self.count_rvs(self.SMABC_potential.logp()) == 1
 
         with self.SMABC_potential:
             trace = pm.sample_smc(draws=100, chains=1, return_inferencedata=False)
             assert np.all(trace["a"] >= 0)
 
-    def test_simulator_metropolis_mcmc(self):
+    def test_simulator_metropolis_mcmc(self, seeded_test):
         with self.SMABC_test as m:
             step = pm.Metropolis([m.rvs_to_values[m["a"]], m.rvs_to_values[m["b"]]])
             trace = pm.sample(step=step, return_inferencedata=False)
 
         assert abs(self.data.mean() - trace["a"].mean()) < 0.05
         assert abs(self.data.std() - trace["b"].mean()) < 0.05
 
-    def test_multiple_simulators(self):
+    def test_multiple_simulators(self, seeded_test):
         true_a = 2
         true_b = -2
 
@@ -214,9 +212,9 @@ def test_multiple_simulators(self):
         assert abs(true_a - trace["a"].mean()) < 0.05
         assert abs(true_b - trace["b"].mean()) < 0.05
 
-    def test_nested_simulators(self):
+    def test_nested_simulators(self, seeded_test):
         true_a = 2
-        rng = self.get_random_state()
+        rng = np.random.RandomState(20160911)
         data = rng.normal(true_a, 0.1, size=1000)
 
         with pm.Model() as m:
@@ -244,7 +242,7 @@ def test_nested_simulators(self):
 
         assert np.abs(true_a - trace["sim1"].mean()) < 0.1
 
-    def test_upstream_rngs_not_in_compiled_logp(self):
+    def test_upstream_rngs_not_in_compiled_logp(self, seeded_test):
         smc = IMH(model=self.SMABC_test)
         smc.initialize_population()
         smc._initialize_kernel()
@@ -263,7 +261,7 @@ def test_upstream_rngs_not_in_compiled_logp(self):
        ]
        assert len(shared_rng_vars) == 1
 
-    def test_simulator_error_msg(self):
+    def test_simulator_error_msg(self, seeded_test):
         msg = "The distance metric not_real is not implemented"
         with pytest.raises(ValueError, match=msg):
             with pm.Model() as m:
@@ -280,7 +278,7 @@ def test_simulator_error_msg(self):
                sim = pm.Simulator("sim", self.normal_sim, 0, params=(1))
 
     @pytest.mark.xfail(reason="KL not refactored")
-    def test_automatic_use_of_sort(self):
+    def test_automatic_use_of_sort(self, seeded_test):
         with pm.Model() as model:
             s_k = pm.Simulator(
                 "s_k",
@@ -292,7 +290,7 @@ def test_automatic_use_of_sort(self):
            )
        assert s_k.distribution.sum_stat is pm.distributions.simulator.identity
 
-    def test_name_is_string_type(self):
+    def test_name_is_string_type(self, seeded_test):
         with self.SMABC_potential:
             assert not self.SMABC_potential.name
             with warnings.catch_warnings():
@@ -303,7 +301,7 @@ def test_name_is_string_type(self):
                trace = pm.sample_smc(draws=10, chains=1, return_inferencedata=False)
                assert isinstance(trace._straces[0].name, str)
 
-    def test_named_model(self):
+    def test_named_model(self, seeded_test):
         # Named models used to fail with Simulator because the arguments to the
         # random fn used to be passed by name. This is no longer true.
         # https://github.com/pymc-devs/pymc/pull/4365#issuecomment-761221146
@@ -323,7 +321,7 @@ def test_named_model(self):
     @pytest.mark.parametrize("mu", [0, np.arange(3)], ids=str)
     @pytest.mark.parametrize("sigma", [1, np.array([1, 2, 5])], ids=str)
     @pytest.mark.parametrize("size", [None, 3, (5, 3)], ids=str)
-    def test_simulator_moment(self, mu, sigma, size):
+    def test_simulator_moment(self, seeded_test, mu, sigma, size):
         def normal_sim(rng, mu, sigma, size):
             return rng.normal(mu, sigma, size=size)
 
@@ -357,7 +355,7 @@ def normal_sim(rng, mu, sigma, size):
 
         assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)
 
-    def test_dist(self):
+    def test_dist(self, seeded_test):
         x = pm.Simulator.dist(self.normal_sim, 0, 1, sum_stat="sort", shape=(3,))
         x = cloudpickle.loads(cloudpickle.dumps(x))

tests/distributions/test_transform.py

Lines changed: 1 addition & 2 deletions
@@ -33,7 +33,6 @@
     R,
     Rminusbig,
     Rplusbig,
-    SeededTest,
     Simplex,
     SortedVector,
     Unit,
@@ -301,7 +300,7 @@ def test_chain_jacob_det():
     check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False)
 
 
-class TestElementWiseLogp(SeededTest):
+class TestElementWiseLogp:
     def build_model(self, distfam, params, size, transform, initval=None):
         if initval is not None:
             initval = pm.floatX(initval)

tests/sampler_fixtures.py

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,6 @@
 import pymc as pm
 
 from pymc.backends.arviz import to_inference_data
-from pymc.testing import SeededTest
 from pymc.util import get_var_name
 
 
@@ -135,10 +134,11 @@ def make_model(cls):
         return model
 
 
-class BaseSampler(SeededTest):
+class BaseSampler:
     @classmethod
     def setup_class(cls):
-        super().setup_class()
+        cls.random_seed = 20160911
+        np.random.seed(cls.random_seed)
         cls.model = cls.make_model()
         with cls.model:
             cls.step = cls.make_step()
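BaseSampler builds its model once per class, so it seeds NumPy's global RNG in setup_class instead of relying on a per-test fixture. A small sketch of that pattern, with a placeholder in place of cls.make_model():

import numpy as np

class ExampleSampler:  # hypothetical stand-in for BaseSampler
    @classmethod
    def setup_class(cls):
        # Seed once before any class-level state that depends on the RNG
        # is created, mirroring the replacement for SeededTest.setup_class.
        cls.random_seed = 20160911
        np.random.seed(cls.random_seed)
        cls.shared_draws = np.random.randn(10)  # placeholder for cls.make_model()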
