1
+ # flake8: noqa
1
2
import numpy as np
2
3
from sklearn import datasets
3
4
from sklearn .model_selection import train_test_split
@@ -13,6 +14,7 @@ def test_GaussianNB(N=10):
13
14
N = np .inf if N is None else N
14
15
15
16
i = 1
17
+ eps = np .finfo (float ).eps
16
18
while i < N + 1 :
17
19
n_ex = np .random .randint (1 , 300 )
18
20
n_feats = np .random .randint (1 , 100 )
@@ -33,29 +35,29 @@ def test_GaussianNB(N=10):
33
35
34
36
sk_preds = sklearn_NB .predict (X_test )
35
37
36
- for i in range (len (NB .labels )):
38
+ for j in range (len (NB .labels )):
37
39
P = NB .parameters
38
- jointi = np .log (sklearn_NB .class_prior_ [i ])
39
- jointi_mine = np .log (P ["prior" ][i ])
40
+ jointi = np .log (sklearn_NB .class_prior_ [j ])
41
+ jointi_mine = np .log (P ["prior" ][j ])
40
42
41
43
np .testing .assert_almost_equal (jointi , jointi_mine )
42
44
43
- n_ij = - 0.5 * np .sum (np .log (2.0 * np .pi * sklearn_NB .sigma_ [i , :]))
44
- n_ij_mine = - 0.5 * np .sum (np .log (2.0 * np .pi * P ["sigma" ][i ] ))
45
+ n_jk = - 0.5 * np .sum (np .log (2.0 * np .pi * sklearn_NB .sigma_ [j , :] + eps ))
46
+ n_jk_mine = - 0.5 * np .sum (np .log (2.0 * np .pi * P ["sigma" ][j ] + eps ))
45
47
46
- np .testing .assert_almost_equal (n_ij_mine , n_ij )
48
+ np .testing .assert_almost_equal (n_jk_mine , n_jk )
47
49
48
- n_ij2 = n_ij - 0.5 * np .sum (
49
- ((X_test - sklearn_NB .theta_ [i , :]) ** 2 ) / (sklearn_NB .sigma_ [i , :]), 1
50
+ n_jk2 = n_jk - 0.5 * np .sum (
51
+ ((X_test - sklearn_NB .theta_ [j , :]) ** 2 ) / (sklearn_NB .sigma_ [j , :]), 1
50
52
)
51
53
52
- n_ij2_mine = n_ij_mine - 0.5 * np .sum (
53
- ((X_test - P ["mean" ][i ]) ** 2 ) / (P ["sigma" ][i ]), 1
54
+ n_jk2_mine = n_jk_mine - 0.5 * np .sum (
55
+ ((X_test - P ["mean" ][j ]) ** 2 ) / (P ["sigma" ][j ]), 1
54
56
)
55
- np .testing .assert_almost_equal (n_ij2_mine , n_ij2 , decimal = 4 )
57
+ np .testing .assert_almost_equal (n_jk2_mine , n_jk2 , decimal = 4 )
56
58
57
- llh = jointi + n_ij2
58
- llh_mine = jointi_mine + n_ij2_mine
59
+ llh = jointi + n_jk2
60
+ llh_mine = jointi_mine + n_jk2_mine
59
61
60
62
np .testing .assert_almost_equal (llh_mine , llh , decimal = 4 )
61
63
0 commit comments