Skip to content

Commit d4c8742

Browse files
ColCarrolltwiecki
authored andcommitted
GLM tests probably pass (pymc-devs#1352)
1 parent a54438d commit d4c8742

File tree

2 files changed

+39
-55
lines changed

2 files changed

+39
-55
lines changed

pymc3/tests/helpers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33

44

55
class SeededTest(unittest.TestCase):
6-
random_seed = 20160907
6+
random_seed = 20160911
7+
8+
@classmethod
9+
def setUpClass(cls):
10+
nr.seed(cls.random_seed)
711

812
def setUp(self):
913
nr.seed(self.random_seed)

pymc3/tests/test_glm.py

Lines changed: 34 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,61 @@
1-
import unittest
2-
from nose import SkipTest
31
import numpy as np
4-
import sys
5-
try:
6-
import statsmodels.api as sm
7-
except ImportError:
8-
raise SkipTest("Test requires statsmodels.")
92

10-
from pymc3.examples import glm_linear, glm_robust
3+
from .helpers import SeededTest
4+
from pymc3 import glm, Model, Uniform, Normal, find_MAP, Slice, sample
115

126

13-
np.random.seed(1)
147
# Generate data
15-
true_intercept = 0
16-
true_slope = 3
17-
18-
19-
def generate_data(size=700):
8+
def generate_data(intercept, slope, size=700):
209
x = np.linspace(-1, 1, size)
21-
y = true_intercept + x * true_slope
10+
y = intercept + x * slope
2211
return x, y
2312

24-
true_sd = .05
25-
x_linear, y_linear = generate_data(size=1000)
26-
y_linear += np.random.normal(size=1000, scale=true_sd)
27-
data_linear = dict(x=x_linear, y=y_linear)
2813

29-
x_logistic, y_logistic = generate_data(size=3000)
30-
y_logistic = 1 / (1 + np.exp(-y_logistic))
31-
bern_trials = [np.random.binomial(1, i) for i in y_logistic]
32-
data_logistic = dict(x=x_logistic, y=bern_trials)
14+
class TestGLM(SeededTest):
15+
@classmethod
16+
def setUpClass(cls):
17+
super(TestGLM, cls).setUpClass()
18+
cls.intercept = 1
19+
cls.slope = 3
20+
cls.sd = .05
21+
x_linear, cls.y_linear = generate_data(cls.intercept, cls.slope, size=1000)
22+
cls.y_linear += np.random.normal(size=1000, scale=cls.sd)
23+
cls.data_linear = dict(x=x_linear, y=cls.y_linear)
3324

25+
x_logistic, y_logistic = generate_data(cls.intercept, cls.slope, size=3000)
26+
y_logistic = 1 / (1 + np.exp(-y_logistic))
27+
bern_trials = [np.random.binomial(1, i) for i in y_logistic]
28+
cls.data_logistic = dict(x=x_logistic, y=bern_trials)
3429

35-
class TestGLM(unittest.TestCase):
36-
37-
@unittest.skip("Fails only on travis. Investigate")
3830
def test_linear_component(self):
3931
with Model() as model:
40-
y_est, coeffs = glm.linear_component('y ~ x', data_linear)
41-
for coeff, true_val in zip(coeffs, [true_intercept, true_slope]):
42-
self.assertAlmostEqual(coeff.tag.test_value, true_val, 1)
32+
y_est, _ = glm.linear_component('y ~ x', self.data_linear)
4333
sigma = Uniform('sigma', 0, 20)
44-
y_obs = Normal('y_obs', mu=y_est, sd=sigma, observed=y_linear)
34+
Normal('y_obs', mu=y_est, sd=sigma, observed=self.y_linear)
4535
start = find_MAP(vars=[sigma])
4636
step = Slice(model.vars)
47-
trace = sample(2000, step, start, progressbar=False)
37+
trace = sample(500, step, start, progressbar=False, random_seed=self.random_seed)
4838

49-
self.assertAlmostEqual(
50-
np.mean(trace['Intercept']), true_intercept, 1)
51-
self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1)
52-
self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1)
39+
self.assertAlmostEqual(np.mean(trace['Intercept']), self.intercept, 1)
40+
self.assertAlmostEqual(np.mean(trace['x']), self.slope, 1)
41+
self.assertAlmostEqual(np.mean(trace['sigma']), self.sd, 1)
5342

54-
@unittest.skip("Fails only on travis. Investigate")
5543
def test_glm(self):
5644
with Model() as model:
57-
vars = glm.glm('y ~ x', data_linear)
58-
for coeff, true_val in zip(vars[1:], [true_intercept, true_slope, true_sd]):
59-
self.assertAlmostEqual(coeff.tag.test_value, true_val, 1)
45+
glm.glm('y ~ x', self.data_linear)
6046
step = Slice(model.vars)
61-
trace = sample(2000, step, progressbar=False)
47+
trace = sample(500, step, progressbar=False, random_seed=self.random_seed)
6248

63-
self.assertAlmostEqual(
64-
np.mean(trace['Intercept']), true_intercept, 1)
65-
self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1)
66-
self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1)
49+
self.assertAlmostEqual(np.mean(trace['Intercept']), self.intercept, 1)
50+
self.assertAlmostEqual(np.mean(trace['x']), self.slope, 1)
51+
self.assertAlmostEqual(np.mean(trace['sd']), self.sd, 1)
6752

68-
@unittest.skip("Was an error, then a fail, now a skip.")
6953
def test_glm_link_func(self):
7054
with Model() as model:
71-
vars = glm.glm('y ~ x', data_logistic,
72-
family=glm.families.Binomial(link=glm.families.logit))
73-
74-
for coeff, true_val in zip(vars[1:], [true_intercept, true_slope]):
75-
self.assertAlmostEqual(coeff.tag.test_value, true_val, 0)
55+
glm.glm('y ~ x', self.data_logistic,
56+
family=glm.families.Binomial(link=glm.families.logit))
7657
step = Slice(model.vars)
77-
trace = sample(2000, step, progressbar=False)
58+
trace = sample(1000, step, progressbar=False, random_seed=self.random_seed)
7859

79-
self.assertAlmostEqual(
80-
np.mean(trace['Intercept']), true_intercept, 1)
81-
self.assertAlmostEqual(np.mean(trace['x']), true_slope, 0)
60+
self.assertAlmostEqual(np.mean(trace['Intercept']), self.intercept, 1)
61+
self.assertAlmostEqual(np.mean(trace['x']), self.slope, 1)

0 commit comments

Comments
 (0)