@@ -32,7 +32,7 @@ def _normalise(data: np.ndarray):
32
32
return data
33
33
34
34
35
- def test_mnist_raw ():
35
+ def test_fashion_mnist_raw ():
36
36
#####################################
37
37
# SET THE FOLLOWING PARAMETERS
38
38
# MNIST FASHION DATABASE
@@ -42,7 +42,7 @@ def test_mnist_raw():
42
42
#####################################
43
43
44
44
d = FashionMnistRaw ()
45
- d .run ( )
45
+ d .test_as_classifier ( 8 )
46
46
47
47
assert True
48
48
@@ -56,19 +56,27 @@ def main():
56
56
print (f"np.unique(test_target): { np .unique (test_target )} " )
57
57
58
58
from decision_trees import dataset_tester
59
-
60
- dataset_tester .perform_gridsearch (train_data [:60000 ], train_target [:60000 ],
61
- test_data [:10000 ], test_target [:10000 ],
62
- 10 - 1 ,
63
- dataset_tester .ClassifierType .DECISION_TREE ,
64
- dataset_tester .GridSearchType .NONE ,
65
- "./../../data/gridsearch_results/"
66
- )
67
-
68
- # dataset_tester.test_dataset(4,
69
- # train_data[:60000], train_target[:60000], test_data[:10000], test_target[:10000],
70
- # dataset_tester.ClassifierType.DECISION_TREE,
71
- # )
59
+ from decision_trees .gridsearch import perform_gridsearch
60
+ from decision_trees .utils .constants import ClassifierType , GridSearchType
61
+
62
+ perform_gridsearch (
63
+ train_data [:60000 ], train_target [:60000 ],
64
+ test_data [:10000 ], test_target [:10000 ],
65
+ [16 , 12 , 8 , 6 , 4 , 2 , 1 ],
66
+ ClassifierType .DECISION_TREE ,
67
+ GridSearchType .NONE ,
68
+ "./../../data/gridsearch_results/" ,
69
+ d .__class__ .__name__
70
+ )
71
+
72
+ # this is the same as the code below, but on the whole dataset
73
+ d .test_as_classifier (8 )
74
+
75
+ dataset_tester .test_dataset (
76
+ 8 ,
77
+ train_data [:60000 ], train_target [:60000 ], test_data [:10000 ], test_target [:10000 ],
78
+ ClassifierType .DECISION_TREE
79
+ )
72
80
73
81
74
82
if __name__ == "__main__" :
0 commit comments