|
1 | 1 | import unittest
|
2 | 2 | from rforecast import wrappers
|
| 3 | +from rforecast import converters |
3 | 4 | from rforecast import ts_io
|
| 5 | +from rpy2 import robjects |
| 6 | +from rpy2.robjects.packages import importr |
4 | 7 |
|
5 | 8 |
|
6 | 9 | class EndToEndTestCase(unittest.TestCase):
|
7 | 10 |
|
8 | 11 | def setUp(self):
|
9 |
| - self.oil = ts_io.read_series('data/oil.csv') |
10 |
| - self.aus = ts_io.read_series('data/aus.csv') |
11 |
| - self.austa = ts_io.read_ts('austa', 'fpp') |
12 |
| - |
| 12 | + self.oil_r = ts_io.read_ts('oil', 'fpp', as_pandas=False) |
| 13 | + self.oil_py = converters.ts_as_series(self.oil_r) |
| 14 | + self.aus_r = ts_io.read_ts('austourists', 'fpp', as_pandas=False) |
| 15 | + self.aus_py = converters.ts_as_series(self.aus_r) |
| 16 | + self.austa_r = ts_io.read_ts('austa', 'fpp', as_pandas=False) |
| 17 | + self.austa_py = converters.ts_as_series(self.austa_r) |
| 18 | + self.fc = importr('forecast') |
| 19 | + |
| 20 | + def _check_points(self, fc_py, fc_r): |
| 21 | + ''' |
| 22 | + Checks that the R and python forecasts are the same at select points |
| 23 | + for both the mean forecast and the prediction intervals. Compares the |
| 24 | + first and last values of the mean forecast, and the first value of the |
| 25 | + 80% confidence lower PI and the last value of the 95% upper PI. |
| 26 | + |
| 27 | + Args: |
| 28 | + fc_py: the python forecast |
| 29 | + fc_r : the R forecast |
| 30 | + |
| 31 | + Return: |
| 32 | + Nothing, but makes tests assertions which can fail. |
| 33 | + ''' |
| 34 | + lower = fc_r.rx2('lower') |
| 35 | + upper = fc_r.rx2('upper') |
| 36 | + mean = fc_r.rx2('mean') |
| 37 | + self.assertAlmostEqual(fc_py.point_fc.iloc[0], mean[0], places=3) |
| 38 | + self.assertAlmostEqual(fc_py.point_fc.iloc[-1], mean[-1], places=3) |
| 39 | + self.assertAlmostEqual(fc_py.lower80.iloc[0], lower[0], places=3) |
| 40 | + self.assertAlmostEqual(fc_py.upper95.iloc[-1], upper[-1], places=3) |
| 41 | + |
13 | 42 | def test_naive(self):
|
14 |
| - fc = wrappers.naive(self.oil) |
15 |
| - self.assertAlmostEqual(fc.point_fc[2011], 467.7724, places=3) |
16 |
| - self.assertAlmostEqual(fc.point_fc[2020], 467.7724, places=3) |
17 |
| - self.assertAlmostEqual(fc.lower80[2011], 404.6370, places=3) |
18 |
| - self.assertAlmostEqual(fc.upper95[2020], 773.1130, places=3) |
| 43 | + fc_py = wrappers.naive(self.oil_py) |
| 44 | + fc_r = self.fc.naive(self.oil_r) |
| 45 | + self._check_points(fc_py, fc_r) |
19 | 46 |
|
20 | 47 | def test_thetaf(self):
|
21 |
| - fc = wrappers.thetaf(self.oil) |
22 |
| - self.assertAlmostEqual(fc.point_fc[2011], 470.9975, places=3) |
23 |
| - self.assertAlmostEqual(fc.point_fc[2020], 500.0231, places=3) |
24 |
| - self.assertAlmostEqual(fc.lower80[2011], 408.5509, places=3) |
25 |
| - self.assertAlmostEqual(fc.upper95[2020], 802.0053, places=3) |
| 48 | + fc_py = wrappers.thetaf(self.oil_py) |
| 49 | + fc_r = self.fc.thetaf(self.oil_r) |
| 50 | + self._check_points(fc_py, fc_r) |
26 | 51 |
|
27 | 52 | def test_snaive(self):
|
28 |
| - fc = wrappers.snaive(self.aus) |
29 |
| - self.assertAlmostEqual(fc.point_fc[(2011, 1)], 59.76678, places=3) |
30 |
| - self.assertAlmostEqual(fc.point_fc[(2012, 4)], 47.91374, places=3) |
31 |
| - self.assertAlmostEqual(fc.lower80[(2011, 1)], 55.37882, places=3) |
32 |
| - self.assertAlmostEqual(fc.upper95[(2012, 4)], 57.40424, places=3) |
| 53 | + fc_py = wrappers.snaive(self.aus_py) |
| 54 | + fc_r = self.fc.snaive(self.aus_r) |
| 55 | + self._check_points(fc_py, fc_r) |
33 | 56 |
|
34 | 57 | def test_rwf(self):
|
35 |
| - fc = wrappers.rwf(self.oil) |
36 |
| - self.assertAlmostEqual(fc.point_fc[2011], 467.7724, places=3) |
37 |
| - self.assertAlmostEqual(fc.point_fc[2020], 467.7724, places=3) |
38 |
| - self.assertAlmostEqual(fc.lower80[2011], 404.7558, places=3) |
39 |
| - self.assertAlmostEqual(fc.upper95[2020], 772.5385, places=3) |
40 |
| - |
41 |
| - def test_forecast(self): |
42 |
| - fc = wrappers.forecast(self.oil) |
43 |
| - self.assertAlmostEqual(fc.point_fc[2011], 467.7721, places=3) |
44 |
| - self.assertAlmostEqual(fc.point_fc[2020], 467.7721, places=3) |
45 |
| - self.assertAlmostEqual(fc.lower80[2011], 405.3255, places=3) |
46 |
| - self.assertAlmostEqual(fc.upper95[2020], 769.7543, places=3) |
47 |
| - fc = wrappers.forecast(self.aus) |
48 |
| - self.assertAlmostEqual(fc.point_fc[(2011, 1)], 57.87294, places=3) |
49 |
| - self.assertAlmostEqual(fc.point_fc[(2012, 4)], 52.84327, places=3) |
50 |
| - self.assertAlmostEqual(fc.lower80[(2011, 1)], 53.30794, places=3) |
51 |
| - self.assertAlmostEqual(fc.upper95[(2012, 4)], 62.60852, places=3) |
| 58 | + fc_py = wrappers.rwf(self.oil_py) |
| 59 | + fc_r = self.fc.rwf(self.oil_r) |
| 60 | + self._check_points(fc_py, fc_r) |
| 61 | + |
| 62 | + def test_forecast_nonseasonal(self): |
| 63 | + fc_py = wrappers.forecast(self.oil_py) |
| 64 | + fc_r = self.fc.forecast(self.oil_r) |
| 65 | + self._check_points(fc_py, fc_r) |
52 | 66 |
|
53 |
| - def test_auto_arima(self): |
54 |
| - fc = wrappers.auto_arima(self.oil) |
55 |
| - self.assertAlmostEqual(fc.point_fc[2011], 475.7004, places=3) |
56 |
| - self.assertAlmostEqual(fc.point_fc[2020], 547.0531, places=3) |
57 |
| - self.assertAlmostEqual(fc.lower80[2011], 412.6839, places=3) |
58 |
| - self.assertAlmostEqual(fc.upper95[2020], 851.8193, places=3) |
59 |
| - fc = wrappers.auto_arima(self.aus) |
60 |
| - self.assertAlmostEqual(fc.point_fc[(2011, 1)], 60.64208, places=3) |
61 |
| - self.assertAlmostEqual(fc.point_fc[(2012, 4)], 51.55356, places=3) |
62 |
| - self.assertAlmostEqual(fc.lower80[(2011, 1)], 57.57010, places=3) |
63 |
| - self.assertAlmostEqual(fc.upper95[(2012, 4)], 57.52426, places=3) |
64 |
| - self.assertRaises(ValueError, wrappers.auto_arima, self.oil , |
65 |
| - xreg=range(len(self.oil))) |
66 |
| - self.assertRaises(ValueError, wrappers.auto_arima, self.oil, h=10, |
| 67 | + def test_forecast_seasonal(self): |
| 68 | + fc_py = wrappers.forecast(self.aus_py) |
| 69 | + fc_r = self.fc.forecast(self.aus_r) |
| 70 | + self._check_points(fc_py, fc_r) |
| 71 | + |
| 72 | + def test_auto_arima_nonseasonal(self): |
| 73 | + fc_py = wrappers.auto_arima(self.oil_py) |
| 74 | + model = self.fc.auto_arima(self.oil_r) |
| 75 | + fc_r = self.fc.forecast(model) |
| 76 | + self._check_points(fc_py, fc_r) |
| 77 | + |
| 78 | + def test_auto_arima_seasonal(self): |
| 79 | + fc_py = wrappers.auto_arima(self.aus_py) |
| 80 | + model = self.fc.auto_arima(self.aus_r) |
| 81 | + fc_r = self.fc.forecast(model) |
| 82 | + self._check_points(fc_py, fc_r) |
| 83 | + |
| 84 | + def test_auto_arima_raises(self): |
| 85 | + self.assertRaises(ValueError, wrappers.auto_arima, self.oil_py , |
| 86 | + xreg=range(len(self.oil_py))) |
| 87 | + self.assertRaises(ValueError, wrappers.auto_arima, self.oil_py, h=10, |
67 | 88 | newxreg=range(10))
|
68 | 89 |
|
69 | 90 | def test_stlf(self):
|
70 |
| - fc = wrappers.stlf(self.aus) |
71 |
| - self.assertAlmostEqual(fc.point_fc[(2011, 1)], 58.88951, places=3) |
72 |
| - self.assertAlmostEqual(fc.point_fc[(2012, 4)], 51.68499, places=3) |
73 |
| - self.assertAlmostEqual(fc.lower80[(2011, 1)], 56.69656, places=3) |
74 |
| - self.assertAlmostEqual(fc.upper95[(2012, 4)], 57.59713, places=3) |
| 91 | + fc_py = wrappers.stlf(self.aus_py) |
| 92 | + fc_r = self.fc.stlf(self.aus_r) |
| 93 | + self._check_points(fc_py, fc_r) |
75 | 94 |
|
76 | 95 | def test_acf(self):
|
77 |
| - acf = wrappers.acf(self.oil, lag_max=10) |
78 |
| - self.assertEqual(acf.name, 'Acf') |
79 |
| - self.assertEqual(len(acf), 10) |
80 |
| - self.assertAlmostEqual(acf[1], 0.8708, places=3) |
81 |
| - self.assertAlmostEqual(acf[10], -0.2016, places=3) |
| 96 | + acf_py = wrappers.acf(self.oil_py, lag_max=10) |
| 97 | + self.assertEqual(acf_py.name, 'Acf') |
| 98 | + self.assertEqual(len(acf_py), 10) |
| 99 | + acf_r = self.fc.Acf(self.oil_r, plot=False, lag_max=10) |
| 100 | + self.assertAlmostEqual(acf_py[1], acf_r.rx2('acf')[1], places=3) |
| 101 | + self.assertAlmostEqual(acf_py[10], acf_r.rx2('acf')[10], places=3) |
82 | 102 |
|
83 | 103 | def test_pacf(self):
|
84 |
| - acf = wrappers.pacf(self.oil, lag_max=10) |
85 |
| - self.assertEqual(acf.name, 'Pacf') |
86 |
| - self.assertEqual(len(acf), 10) |
87 |
| - self.assertAlmostEqual(acf[1], 0.8708, places=3) |
88 |
| - self.assertAlmostEqual(acf[10], 0.1104, places=3) |
| 104 | + pacf_py = wrappers.pacf(self.oil_py, lag_max=10) |
| 105 | + self.assertEqual(pacf_py.name, 'Pacf') |
| 106 | + self.assertEqual(len(pacf_py), 10) |
| 107 | + pacf_r = self.fc.Pacf(self.oil_r, plot=False, **{'lag.max':10}) |
| 108 | + self.assertAlmostEqual(pacf_py.values[0], pacf_r.rx2('acf')[0], places=3) |
| 109 | + self.assertAlmostEqual(pacf_py.values[-1], pacf_r.rx2('acf')[-1], places=3) |
89 | 110 |
|
90 | 111 | def test_ses(self):
|
91 |
| - fc = wrappers.ses(self.oil, level=80) |
92 |
| - self.assertAlmostEqual(fc.point_fc[2011], 467.7724, places=3) |
93 |
| - self.assertAlmostEqual(fc.point_fc[2020], 467.7724, places=3) |
94 |
| - self.assertAlmostEqual(fc.lower80[2011], 405.3270, places=3) |
95 |
| - self.assertAlmostEqual(fc.upper80[2020], 665.2418, places=3) |
96 |
| - self.assertRaises(ValueError, wrappers.ses, self.oil, alpha=0) |
97 |
| - self.assertRaises(ValueError, wrappers.ses, self.oil, alpha=1.0) |
| 112 | + fc_py = wrappers.ses(self.oil_py) |
| 113 | + fc_r = self.fc.ses(self.oil_r, initial='simple') |
| 114 | + self._check_points(fc_py, fc_r) |
| 115 | + |
| 116 | + def test_ses_raises(self): |
| 117 | + self.assertRaises(ValueError, wrappers.ses, self.oil_py, alpha=0) |
| 118 | + self.assertRaises(ValueError, wrappers.ses, self.oil_py, alpha=1.0) |
98 | 119 |
|
99 | 120 | def test_holt(self):
|
100 |
| - fc = wrappers.holt(self.austa) |
101 |
| - self.assertAlmostEqual(fc.point_fc[2011], 5.5690, places=3) |
102 |
| - self.assertAlmostEqual(fc.point_fc[2020], 6.7219, places=3) |
103 |
| - self.assertAlmostEqual(fc.lower80[2011], 5.3316, places=3) |
104 |
| - self.assertAlmostEqual(fc.upper95[2020], 8.8653, places=3) |
105 |
| - self.assertRaises(ValueError, wrappers.holt, self.austa, alpha=0) |
106 |
| - self.assertRaises(ValueError, wrappers.holt, self.austa, alpha=1.0) |
107 |
| - self.assertRaises(ValueError, wrappers.holt, self.austa, beta=0) |
108 |
| - self.assertRaises(ValueError, wrappers.holt, self.austa, beta=1.0) |
| 121 | + fc_py = wrappers.holt(self.austa_py) |
| 122 | + fc_r = self.fc.holt(self.austa_r, initial='simple') |
| 123 | + self._check_points(fc_py, fc_r) |
| 124 | + |
| 125 | + def test_holt_raises(self): |
| 126 | + self.assertRaises(ValueError, wrappers.holt, self.austa_py, alpha=0) |
| 127 | + self.assertRaises(ValueError, wrappers.holt, self.austa_py, alpha=1.0) |
| 128 | + self.assertRaises(ValueError, wrappers.holt, self.austa_py, beta=0) |
| 129 | + self.assertRaises(ValueError, wrappers.holt, self.austa_py, beta=1.0) |
109 | 130 |
|
110 | 131 | def test_hw(self):
|
111 |
| - fc = wrappers.hw(self.aus) |
112 |
| - self.assertAlmostEqual(fc.point_fc[(2011, 1)], 60.0507, places=3) |
113 |
| - self.assertAlmostEqual(fc.point_fc[(2012, 4)], 52.2922, places=3) |
114 |
| - self.assertAlmostEqual(fc.lower80[(2011, 1)], 56.9353, places=3) |
115 |
| - self.assertAlmostEqual(fc.upper95[(2012, 4)], 64.7920, places=3) |
116 |
| - self.assertRaises(ValueError, wrappers.hw, self.aus, alpha=0) |
117 |
| - self.assertRaises(ValueError, wrappers.hw, self.aus, alpha=1.0) |
118 |
| - self.assertRaises(ValueError, wrappers.hw, self.aus, beta=0) |
119 |
| - self.assertRaises(ValueError, wrappers.hw, self.aus, beta=1.0) |
120 |
| - self.assertRaises(ValueError, wrappers.hw, self.aus, gamma=0) |
121 |
| - self.assertRaises(ValueError, wrappers.hw, self.aus, gamma=1.0) |
122 |
| - |
123 |
| - def test_arima(self): |
124 |
| - fc = wrappers.arima(self.aus, order=(1,0,0), seasonal=(1,1,0), |
| 132 | + fc_py = wrappers.hw(self.aus_py) |
| 133 | + fc_r = self.fc.hw(self.aus_r, initial='simple') |
| 134 | + self._check_points(fc_py, fc_r) |
| 135 | + |
| 136 | + def test_hw_raises(self): |
| 137 | + self.assertRaises(ValueError, wrappers.hw, self.aus_py, alpha=0) |
| 138 | + self.assertRaises(ValueError, wrappers.hw, self.aus_py, alpha=1.0) |
| 139 | + self.assertRaises(ValueError, wrappers.hw, self.aus_py, beta=0) |
| 140 | + self.assertRaises(ValueError, wrappers.hw, self.aus_py, beta=1.0) |
| 141 | + self.assertRaises(ValueError, wrappers.hw, self.aus_py, gamma=0) |
| 142 | + self.assertRaises(ValueError, wrappers.hw, self.aus_py, gamma=1.0) |
| 143 | + |
| 144 | + def test_arima_seasonal(self): |
| 145 | + fc_py = wrappers.arima(self.aus_py, order=(1,0,0), seasonal=(1,1,0), |
125 | 146 | include_constant=True)
|
126 |
| - self.assertAlmostEqual(fc.point_fc[(2011, 1)], 60.6420, places=3) |
127 |
| - self.assertAlmostEqual(fc.point_fc[(2012, 4)], 51.5535, places=3) |
128 |
| - self.assertAlmostEqual(fc.lower80[(2011, 1)], 57.5701, places=3) |
129 |
| - self.assertAlmostEqual(fc.upper95[(2012, 4)], 57.5242, places=3) |
130 |
| - self.assertEqual(fc.shape[0], 8) |
131 |
| - fc = wrappers.arima(self.oil, order=(0,1,0)) |
132 |
| - self.assertEqual(fc.shape[0], 10) |
| 147 | + order = robjects.r.c(1., 0., 0.) |
| 148 | + seasonal = robjects.r.c(1., 1., 0.) |
| 149 | + model = self.fc.Arima(self.aus_r, order=order, seasonal=seasonal, |
| 150 | + include_constant=True) |
| 151 | + fc_r = self.fc.forecast(model) |
| 152 | + self._check_points(fc_py, fc_r) |
| 153 | + self.assertEqual(fc_py.shape[0], 8) |
| 154 | + |
| 155 | + def test_arima_nonseasonal(self): |
| 156 | + fc_py = wrappers.arima(self.oil_py, order=(0,1,0)) |
| 157 | + order = robjects.r.c(0., 1., 0.) |
| 158 | + model = self.fc.Arima(self.oil_r, order=order) |
| 159 | + fc_r = self.fc.forecast(model) |
| 160 | + self._check_points(fc_py, fc_r) |
| 161 | + self.assertEqual(fc_py.shape[0], 10) |
133 | 162 |
|
134 | 163 |
|
135 | 164 |
|
|
0 commit comments