Skip to content

Commit 5d30236

Browse files
aloctavodiaJunpeng Lao
authored and
Junpeng Lao
committed
add test for r2_score (#2729)
* add test for r2_score * add change to release-notes
1 parent 1b1caa6 commit 5d30236

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- Fixed `compareplot` to use `loo` output.
1212
- Add test for `model.logp_array` and `model.bijection` (#2724)
1313
- Fixed `sample_ppc` and `sample_ppc_w` to iterate all chains(#2633)
14+
- Add test for `stats.r2_score` (#2729)
1415

1516

1617

pymc3/stats.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -995,14 +995,13 @@ def r2_score(y_true, y_pred, round_to=2):
995995
if y_true.ndim > 1:
996996
dimension = 1
997997

998-
e = y_true - y_pred
999-
var_y_est = np.var(y_pred, dimension)
1000-
var_e = np.var(e, dimension)
998+
var_y_est = np.var(y_pred, axis=dimension)
999+
var_e = np.var(y_true - y_pred, axis=dimension)
10011000

10021001
r2 = var_y_est / (var_y_est + var_e)
10031002
r2_median = np.around(np.median(r2), round_to)
10041003
r2_mean = np.around(np.mean(r2), round_to)
10051004
r2_std = np.around(np.std(r2), round_to)
1006-
R2_r = namedtuple('R2_r', 'R2_median, R2_mean, R2_std')
1007-
return R2_r(r2_median, r2_mean, r2_std)
1005+
r2_r = namedtuple('r2_r', 'r2_median, r2_mean, r2_std')
1006+
return r2_r(r2_median, r2_mean, r2_std)
10081007

pymc3/tests/test_stats.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from .helpers import SeededTest
77
from ..tests import backend_fixtures as bf
88
from ..backends import ndarray
9-
from ..stats import summary, autocorr, hpd, mc_error, quantiles, make_indices, bfmi
9+
from ..stats import (summary, autocorr, hpd, mc_error, quantiles, make_indices,
10+
bfmi, r2_score)
1011
from ..theanof import floatX_array
1112
import pymc3.stats as pmstats
1213
from numpy.random import random, normal
@@ -276,6 +277,14 @@ def test_bfmi(self):
276277

277278
assert_almost_equal(bfmi(trace), 0.8)
278279

280+
def test_r2_score(self):
281+
x = np.linspace(0, 1, 100)
282+
y = np.random.normal(x, 1)
283+
res = st.linregress(x, y)
284+
assert_almost_equal(res.rvalue ** 2,
285+
r2_score(y, res.intercept +
286+
res.slope * x).r2_median,
287+
2)
279288

280289
class TestDfSummary(bf.ModelBackendSampledTestCase):
281290
backend = ndarray.NDArray

0 commit comments

Comments
 (0)