Skip to content

Commit 3cdde86

Browse files
committed
Fixed base dataset class and adjusted the rest to use new gridsearch function
1 parent bbe50aa commit 3cdde86

File tree

5 files changed

+50
-47
lines changed

5 files changed

+50
-47
lines changed

decision_trees/datasets/dataset_base.py

+16-14
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import numpy as np
55

6-
from decision_trees import dataset_tester
6+
from decision_trees.dataset_tester import test_dataset
7+
from decision_trees.utils.constants import ClassifierType, GridSearchType
78

89

910
class DatasetBase(metaclass=abc.ABCMeta):
@@ -16,18 +17,19 @@ def load_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
1617
def _normalise(data: np.ndarray) -> np.ndarray:
1718
pass
1819

19-
def run(self):
20+
def test_as_classifier(self, number_of_bits_per_feature: int):
2021
train_data, train_target, test_data, test_target = self.load_data()
2122

22-
dataset_tester.test_dataset(8,
23-
train_data, train_target, test_data, test_target,
24-
dataset_tester.ClassifierType.DECISION_TREE,
25-
)
26-
27-
def run_grid_search(self):
28-
train_data, train_target, test_data, test_target = self.load_data()
29-
30-
dataset_tester.grid_search(train_data, train_target,
31-
test_data, test_target,
32-
dataset_tester.ClassifierType.DECISION_TREE
33-
)
23+
print('Testing decision tree classifier')
24+
test_dataset(
25+
number_of_bits_per_feature,
26+
train_data, train_target, test_data, test_target,
27+
ClassifierType.DECISION_TREE
28+
)
29+
30+
print('Testing random forest classifier')
31+
test_dataset(
32+
number_of_bits_per_feature,
33+
train_data, train_target, test_data, test_target,
34+
ClassifierType.RANDOM_FOREST
35+
)

decision_trees/datasets/emg_raw.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55

66
from decision_trees.datasets.dataset_base import DatasetBase
7-
from decision_trees.dataset_tester import test_dataset
87
from decision_trees.gridsearch import perform_gridsearch
98
from decision_trees.utils.constants import ClassifierType, GridSearchType
109

@@ -95,16 +94,13 @@ def _normalise(self, data: np.ndarray):
9594
print(f"np.unique(train_target): {np.unique(train_target)}")
9695
print(f"np.unique(test_target): {np.unique(test_target)}")
9796

98-
test_dataset(8,
99-
train_data, train_target,
100-
test_data, test_target,
101-
ClassifierType.RANDOM_FOREST,
102-
)
97+
d.test_as_classifier(16)
10398

10499
perform_gridsearch(train_data, train_target,
105100
test_data, test_target,
106-
10 - 1,
101+
[16, 12, 8, 6, 4, 2, 1],
107102
ClassifierType.RANDOM_FOREST,
108-
GridSearchType.NONE,
109-
'./../../data/gridsearch_results/'
103+
GridSearchType.PARFIT,
104+
'./../../data/gridsearch_results/',
105+
d.__class__.__name__
110106
)

decision_trees/datasets/fashion_mnist_raw.py

+23-15
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _normalise(data: np.ndarray):
3232
return data
3333

3434

35-
def test_mnist_raw():
35+
def test_fashion_mnist_raw():
3636
#####################################
3737
# SET THE FOLLOWING PARAMETERS
3838
# MNIST FASHION DATABASE
@@ -42,7 +42,7 @@ def test_mnist_raw():
4242
#####################################
4343

4444
d = FashionMnistRaw()
45-
d.run()
45+
d.test_as_classifier(8)
4646

4747
assert True
4848

@@ -56,19 +56,27 @@ def main():
5656
print(f"np.unique(test_target): {np.unique(test_target)}")
5757

5858
from decision_trees import dataset_tester
59-
60-
dataset_tester.perform_gridsearch(train_data[:60000], train_target[:60000],
61-
test_data[:10000], test_target[:10000],
62-
10 - 1,
63-
dataset_tester.ClassifierType.DECISION_TREE,
64-
dataset_tester.GridSearchType.NONE,
65-
"./../../data/gridsearch_results/"
66-
)
67-
68-
# dataset_tester.test_dataset(4,
69-
# train_data[:60000], train_target[:60000], test_data[:10000], test_target[:10000],
70-
# dataset_tester.ClassifierType.DECISION_TREE,
71-
# )
59+
from decision_trees.gridsearch import perform_gridsearch
60+
from decision_trees.utils.constants import ClassifierType, GridSearchType
61+
62+
perform_gridsearch(
63+
train_data[:60000], train_target[:60000],
64+
test_data[:10000], test_target[:10000],
65+
[16, 12, 8, 6, 4, 2, 1],
66+
ClassifierType.DECISION_TREE,
67+
GridSearchType.NONE,
68+
"./../../data/gridsearch_results/",
69+
d.__class__.__name__
70+
)
71+
72+
# this is the same as the code below, but on the whole dataset
73+
d.test_as_classifier(8)
74+
75+
dataset_tester.test_dataset(
76+
8,
77+
train_data[:60000], train_target[:60000], test_data[:10000], test_target[:10000],
78+
ClassifierType.DECISION_TREE
79+
)
7280

7381

7482
if __name__ == "__main__":

decision_trees/datasets/mnist_raw.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def test_mnist_raw():
4949
number_of_test_samples = 100 # 70000 - number_of_train_samples
5050
# END OF PARAMETERS SETTING
5151
if (number_of_train_samples + number_of_test_samples) > 70000:
52-
print("ERROR, too much samples set!")
52+
print("ERROR, too many samples set!")
5353
#####################################
5454

5555
d = MnistRaw(number_of_train_samples, number_of_test_samples)
56-
d.run()
56+
d.test_as_classifier(8)
5757

5858
assert True

decision_trees/datasets/terrain.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import numpy as np
77

88
from decision_trees.datasets.dataset_base import DatasetBase
9-
from decision_trees.dataset_tester import test_dataset
109
from decision_trees.gridsearch import perform_gridsearch
1110
from decision_trees.utils.constants import ClassifierType, GridSearchType
1211

@@ -118,16 +117,14 @@ def main():
118117
print(f'np.shape(train_data): {np.shape(train_data)}')
119118
print(f'np.unique(test_target): {np.unique(test_target)}')
120119

121-
test_dataset(32,
122-
train_data, train_target, test_data, test_target,
123-
ClassifierType.RANDOM_FOREST
124-
)
120+
d.test_as_classifier(16)
125121

126122
perform_gridsearch(train_data, train_target, test_data, test_target,
127-
10 - 1,
123+
[16, 12, 8, 6, 4, 2, 1],
128124
ClassifierType.RANDOM_FOREST,
129125
GridSearchType.NONE,
130-
'./../../data/gridsearch_results/'
126+
'./../../data/gridsearch_results/',
127+
d.__class__.__name__
131128
)
132129

133130

0 commit comments

Comments
 (0)