Skip to content

Commit 80a8fd8

Browse files
committed
made test_all invariant to R forecast changes
1 parent 869f0b8 commit 80a8fd8

File tree

1 file changed

+131
-102
lines changed

1 file changed

+131
-102
lines changed

test/test_all.py

Lines changed: 131 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,135 +1,164 @@
11
import unittest
22
from rforecast import wrappers
3+
from rforecast import converters
34
from rforecast import ts_io
5+
from rpy2 import robjects
6+
from rpy2.robjects.packages import importr
47

58

69
class EndToEndTestCase(unittest.TestCase):
710

811
def setUp(self):
9-
self.oil = ts_io.read_series('data/oil.csv')
10-
self.aus = ts_io.read_series('data/aus.csv')
11-
self.austa = ts_io.read_ts('austa', 'fpp')
12-
12+
self.oil_r = ts_io.read_ts('oil', 'fpp', as_pandas=False)
13+
self.oil_py = converters.ts_as_series(self.oil_r)
14+
self.aus_r = ts_io.read_ts('austourists', 'fpp', as_pandas=False)
15+
self.aus_py = converters.ts_as_series(self.aus_r)
16+
self.austa_r = ts_io.read_ts('austa', 'fpp', as_pandas=False)
17+
self.austa_py = converters.ts_as_series(self.austa_r)
18+
self.fc = importr('forecast')
19+
20+
def _check_points(self, fc_py, fc_r):
21+
'''
22+
Checks that the R and python forecasts are the same at select points
23+
for both the mean forecast and the prediction intervals. Compares the
24+
first and last values of the mean forecast, and the first value of the
25+
80% confidence lower PI and the last value of the 95% upper PI.
26+
27+
Args:
28+
fc_py: the python forecast
29+
fc_r : the R forecast
30+
31+
Return:
32+
Nothing, but makes tests assertions which can fail.
33+
'''
34+
lower = fc_r.rx2('lower')
35+
upper = fc_r.rx2('upper')
36+
mean = fc_r.rx2('mean')
37+
self.assertAlmostEqual(fc_py.point_fc.iloc[0], mean[0], places=3)
38+
self.assertAlmostEqual(fc_py.point_fc.iloc[-1], mean[-1], places=3)
39+
self.assertAlmostEqual(fc_py.lower80.iloc[0], lower[0], places=3)
40+
self.assertAlmostEqual(fc_py.upper95.iloc[-1], upper[-1], places=3)
41+
1342
def test_naive(self):
14-
fc = wrappers.naive(self.oil)
15-
self.assertAlmostEqual(fc.point_fc[2011], 467.7724, places=3)
16-
self.assertAlmostEqual(fc.point_fc[2020], 467.7724, places=3)
17-
self.assertAlmostEqual(fc.lower80[2011], 404.6370, places=3)
18-
self.assertAlmostEqual(fc.upper95[2020], 773.1130, places=3)
43+
fc_py = wrappers.naive(self.oil_py)
44+
fc_r = self.fc.naive(self.oil_r)
45+
self._check_points(fc_py, fc_r)
1946

2047
def test_thetaf(self):
21-
fc = wrappers.thetaf(self.oil)
22-
self.assertAlmostEqual(fc.point_fc[2011], 470.9975, places=3)
23-
self.assertAlmostEqual(fc.point_fc[2020], 500.0231, places=3)
24-
self.assertAlmostEqual(fc.lower80[2011], 408.5509, places=3)
25-
self.assertAlmostEqual(fc.upper95[2020], 802.0053, places=3)
48+
fc_py = wrappers.thetaf(self.oil_py)
49+
fc_r = self.fc.thetaf(self.oil_r)
50+
self._check_points(fc_py, fc_r)
2651

2752
def test_snaive(self):
28-
fc = wrappers.snaive(self.aus)
29-
self.assertAlmostEqual(fc.point_fc[(2011, 1)], 59.76678, places=3)
30-
self.assertAlmostEqual(fc.point_fc[(2012, 4)], 47.91374, places=3)
31-
self.assertAlmostEqual(fc.lower80[(2011, 1)], 55.37882, places=3)
32-
self.assertAlmostEqual(fc.upper95[(2012, 4)], 57.40424, places=3)
53+
fc_py = wrappers.snaive(self.aus_py)
54+
fc_r = self.fc.snaive(self.aus_r)
55+
self._check_points(fc_py, fc_r)
3356

3457
def test_rwf(self):
35-
fc = wrappers.rwf(self.oil)
36-
self.assertAlmostEqual(fc.point_fc[2011], 467.7724, places=3)
37-
self.assertAlmostEqual(fc.point_fc[2020], 467.7724, places=3)
38-
self.assertAlmostEqual(fc.lower80[2011], 404.7558, places=3)
39-
self.assertAlmostEqual(fc.upper95[2020], 772.5385, places=3)
40-
41-
def test_forecast(self):
42-
fc = wrappers.forecast(self.oil)
43-
self.assertAlmostEqual(fc.point_fc[2011], 467.7721, places=3)
44-
self.assertAlmostEqual(fc.point_fc[2020], 467.7721, places=3)
45-
self.assertAlmostEqual(fc.lower80[2011], 405.3255, places=3)
46-
self.assertAlmostEqual(fc.upper95[2020], 769.7543, places=3)
47-
fc = wrappers.forecast(self.aus)
48-
self.assertAlmostEqual(fc.point_fc[(2011, 1)], 57.87294, places=3)
49-
self.assertAlmostEqual(fc.point_fc[(2012, 4)], 52.84327, places=3)
50-
self.assertAlmostEqual(fc.lower80[(2011, 1)], 53.30794, places=3)
51-
self.assertAlmostEqual(fc.upper95[(2012, 4)], 62.60852, places=3)
58+
fc_py = wrappers.rwf(self.oil_py)
59+
fc_r = self.fc.rwf(self.oil_r)
60+
self._check_points(fc_py, fc_r)
61+
62+
def test_forecast_nonseasonal(self):
63+
fc_py = wrappers.forecast(self.oil_py)
64+
fc_r = self.fc.forecast(self.oil_r)
65+
self._check_points(fc_py, fc_r)
5266

53-
def test_auto_arima(self):
54-
fc = wrappers.auto_arima(self.oil)
55-
self.assertAlmostEqual(fc.point_fc[2011], 475.7004, places=3)
56-
self.assertAlmostEqual(fc.point_fc[2020], 547.0531, places=3)
57-
self.assertAlmostEqual(fc.lower80[2011], 412.6839, places=3)
58-
self.assertAlmostEqual(fc.upper95[2020], 851.8193, places=3)
59-
fc = wrappers.auto_arima(self.aus)
60-
self.assertAlmostEqual(fc.point_fc[(2011, 1)], 60.64208, places=3)
61-
self.assertAlmostEqual(fc.point_fc[(2012, 4)], 51.55356, places=3)
62-
self.assertAlmostEqual(fc.lower80[(2011, 1)], 57.57010, places=3)
63-
self.assertAlmostEqual(fc.upper95[(2012, 4)], 57.52426, places=3)
64-
self.assertRaises(ValueError, wrappers.auto_arima, self.oil ,
65-
xreg=range(len(self.oil)))
66-
self.assertRaises(ValueError, wrappers.auto_arima, self.oil, h=10,
67+
def test_forecast_seasonal(self):
68+
fc_py = wrappers.forecast(self.aus_py)
69+
fc_r = self.fc.forecast(self.aus_r)
70+
self._check_points(fc_py, fc_r)
71+
72+
def test_auto_arima_nonseasonal(self):
73+
fc_py = wrappers.auto_arima(self.oil_py)
74+
model = self.fc.auto_arima(self.oil_r)
75+
fc_r = self.fc.forecast(model)
76+
self._check_points(fc_py, fc_r)
77+
78+
def test_auto_arima_seasonal(self):
79+
fc_py = wrappers.auto_arima(self.aus_py)
80+
model = self.fc.auto_arima(self.aus_r)
81+
fc_r = self.fc.forecast(model)
82+
self._check_points(fc_py, fc_r)
83+
84+
def test_auto_arima_raises(self):
85+
self.assertRaises(ValueError, wrappers.auto_arima, self.oil_py ,
86+
xreg=range(len(self.oil_py)))
87+
self.assertRaises(ValueError, wrappers.auto_arima, self.oil_py, h=10,
6788
newxreg=range(10))
6889

6990
def test_stlf(self):
70-
fc = wrappers.stlf(self.aus)
71-
self.assertAlmostEqual(fc.point_fc[(2011, 1)], 58.88951, places=3)
72-
self.assertAlmostEqual(fc.point_fc[(2012, 4)], 51.68499, places=3)
73-
self.assertAlmostEqual(fc.lower80[(2011, 1)], 56.69656, places=3)
74-
self.assertAlmostEqual(fc.upper95[(2012, 4)], 57.59713, places=3)
91+
fc_py = wrappers.stlf(self.aus_py)
92+
fc_r = self.fc.stlf(self.aus_r)
93+
self._check_points(fc_py, fc_r)
7594

7695
def test_acf(self):
77-
acf = wrappers.acf(self.oil, lag_max=10)
78-
self.assertEqual(acf.name, 'Acf')
79-
self.assertEqual(len(acf), 10)
80-
self.assertAlmostEqual(acf[1], 0.8708, places=3)
81-
self.assertAlmostEqual(acf[10], -0.2016, places=3)
96+
acf_py = wrappers.acf(self.oil_py, lag_max=10)
97+
self.assertEqual(acf_py.name, 'Acf')
98+
self.assertEqual(len(acf_py), 10)
99+
acf_r = self.fc.Acf(self.oil_r, plot=False, lag_max=10)
100+
self.assertAlmostEqual(acf_py[1], acf_r.rx2('acf')[1], places=3)
101+
self.assertAlmostEqual(acf_py[10], acf_r.rx2('acf')[10], places=3)
82102

83103
def test_pacf(self):
84-
acf = wrappers.pacf(self.oil, lag_max=10)
85-
self.assertEqual(acf.name, 'Pacf')
86-
self.assertEqual(len(acf), 10)
87-
self.assertAlmostEqual(acf[1], 0.8708, places=3)
88-
self.assertAlmostEqual(acf[10], 0.1104, places=3)
104+
pacf_py = wrappers.pacf(self.oil_py, lag_max=10)
105+
self.assertEqual(pacf_py.name, 'Pacf')
106+
self.assertEqual(len(pacf_py), 10)
107+
pacf_r = self.fc.Pacf(self.oil_r, plot=False, **{'lag.max':10})
108+
self.assertAlmostEqual(pacf_py.values[0], pacf_r.rx2('acf')[0], places=3)
109+
self.assertAlmostEqual(pacf_py.values[-1], pacf_r.rx2('acf')[-1], places=3)
89110

90111
def test_ses(self):
91-
fc = wrappers.ses(self.oil, level=80)
92-
self.assertAlmostEqual(fc.point_fc[2011], 467.7724, places=3)
93-
self.assertAlmostEqual(fc.point_fc[2020], 467.7724, places=3)
94-
self.assertAlmostEqual(fc.lower80[2011], 405.3270, places=3)
95-
self.assertAlmostEqual(fc.upper80[2020], 665.2418, places=3)
96-
self.assertRaises(ValueError, wrappers.ses, self.oil, alpha=0)
97-
self.assertRaises(ValueError, wrappers.ses, self.oil, alpha=1.0)
112+
fc_py = wrappers.ses(self.oil_py)
113+
fc_r = self.fc.ses(self.oil_r, initial='simple')
114+
self._check_points(fc_py, fc_r)
115+
116+
def test_ses_raises(self):
117+
self.assertRaises(ValueError, wrappers.ses, self.oil_py, alpha=0)
118+
self.assertRaises(ValueError, wrappers.ses, self.oil_py, alpha=1.0)
98119

99120
def test_holt(self):
100-
fc = wrappers.holt(self.austa)
101-
self.assertAlmostEqual(fc.point_fc[2011], 5.5690, places=3)
102-
self.assertAlmostEqual(fc.point_fc[2020], 6.7219, places=3)
103-
self.assertAlmostEqual(fc.lower80[2011], 5.3316, places=3)
104-
self.assertAlmostEqual(fc.upper95[2020], 8.8653, places=3)
105-
self.assertRaises(ValueError, wrappers.holt, self.austa, alpha=0)
106-
self.assertRaises(ValueError, wrappers.holt, self.austa, alpha=1.0)
107-
self.assertRaises(ValueError, wrappers.holt, self.austa, beta=0)
108-
self.assertRaises(ValueError, wrappers.holt, self.austa, beta=1.0)
121+
fc_py = wrappers.holt(self.austa_py)
122+
fc_r = self.fc.holt(self.austa_r, initial='simple')
123+
self._check_points(fc_py, fc_r)
124+
125+
def test_holt_raises(self):
126+
self.assertRaises(ValueError, wrappers.holt, self.austa_py, alpha=0)
127+
self.assertRaises(ValueError, wrappers.holt, self.austa_py, alpha=1.0)
128+
self.assertRaises(ValueError, wrappers.holt, self.austa_py, beta=0)
129+
self.assertRaises(ValueError, wrappers.holt, self.austa_py, beta=1.0)
109130

110131
def test_hw(self):
111-
fc = wrappers.hw(self.aus)
112-
self.assertAlmostEqual(fc.point_fc[(2011, 1)], 60.0507, places=3)
113-
self.assertAlmostEqual(fc.point_fc[(2012, 4)], 52.2922, places=3)
114-
self.assertAlmostEqual(fc.lower80[(2011, 1)], 56.9353, places=3)
115-
self.assertAlmostEqual(fc.upper95[(2012, 4)], 64.7920, places=3)
116-
self.assertRaises(ValueError, wrappers.hw, self.aus, alpha=0)
117-
self.assertRaises(ValueError, wrappers.hw, self.aus, alpha=1.0)
118-
self.assertRaises(ValueError, wrappers.hw, self.aus, beta=0)
119-
self.assertRaises(ValueError, wrappers.hw, self.aus, beta=1.0)
120-
self.assertRaises(ValueError, wrappers.hw, self.aus, gamma=0)
121-
self.assertRaises(ValueError, wrappers.hw, self.aus, gamma=1.0)
122-
123-
def test_arima(self):
124-
fc = wrappers.arima(self.aus, order=(1,0,0), seasonal=(1,1,0),
132+
fc_py = wrappers.hw(self.aus_py)
133+
fc_r = self.fc.hw(self.aus_r, initial='simple')
134+
self._check_points(fc_py, fc_r)
135+
136+
def test_hw_raises(self):
137+
self.assertRaises(ValueError, wrappers.hw, self.aus_py, alpha=0)
138+
self.assertRaises(ValueError, wrappers.hw, self.aus_py, alpha=1.0)
139+
self.assertRaises(ValueError, wrappers.hw, self.aus_py, beta=0)
140+
self.assertRaises(ValueError, wrappers.hw, self.aus_py, beta=1.0)
141+
self.assertRaises(ValueError, wrappers.hw, self.aus_py, gamma=0)
142+
self.assertRaises(ValueError, wrappers.hw, self.aus_py, gamma=1.0)
143+
144+
def test_arima_seasonal(self):
145+
fc_py = wrappers.arima(self.aus_py, order=(1,0,0), seasonal=(1,1,0),
125146
include_constant=True)
126-
self.assertAlmostEqual(fc.point_fc[(2011, 1)], 60.6420, places=3)
127-
self.assertAlmostEqual(fc.point_fc[(2012, 4)], 51.5535, places=3)
128-
self.assertAlmostEqual(fc.lower80[(2011, 1)], 57.5701, places=3)
129-
self.assertAlmostEqual(fc.upper95[(2012, 4)], 57.5242, places=3)
130-
self.assertEqual(fc.shape[0], 8)
131-
fc = wrappers.arima(self.oil, order=(0,1,0))
132-
self.assertEqual(fc.shape[0], 10)
147+
order = robjects.r.c(1., 0., 0.)
148+
seasonal = robjects.r.c(1., 1., 0.)
149+
model = self.fc.Arima(self.aus_r, order=order, seasonal=seasonal,
150+
include_constant=True)
151+
fc_r = self.fc.forecast(model)
152+
self._check_points(fc_py, fc_r)
153+
self.assertEqual(fc_py.shape[0], 8)
154+
155+
def test_arima_nonseasonal(self):
156+
fc_py = wrappers.arima(self.oil_py, order=(0,1,0))
157+
order = robjects.r.c(0., 1., 0.)
158+
model = self.fc.Arima(self.oil_r, order=order)
159+
fc_r = self.fc.forecast(model)
160+
self._check_points(fc_py, fc_r)
161+
self.assertEqual(fc_py.shape[0], 10)
133162

134163

135164

0 commit comments

Comments
 (0)