+from typing import Tuple, List
 import csv
 import os
-from typing import Tuple, List
-
 import numpy as np

 from decision_trees.datasets.dataset_base import DatasetBase
@@ -50,7 +49,7 @@ def _load_files(self, files_paths: List[str], is_output: bool) -> np.ndarray:

         return data_array

-    def _load_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    def load_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         input_train_files = []
         output_train_files = []
         input_test_files = []
@@ -71,6 +70,9 @@ def _load_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         input_test_data = self._load_files(input_test_files, is_output=False)
         output_test_data = self._load_files(output_test_files, is_output=True)

+        input_train_data = self._normalise(input_train_data)
+        input_test_data = self._normalise(input_test_data)
+
         return input_train_data, output_train_data, input_test_data, output_test_data

     def _normalise(self, data: np.ndarray):
@@ -83,16 +85,13 @@ def _normalise(self, data: np.ndarray):
 if __name__ == "__main__":
     d = EMGRaw("./../../data/EMG/")

-    train_data, train_target, test_data, test_target = d._load_data()
+    train_data, train_target, test_data, test_target = d.load_data()

     print(f"train_data.shape: {train_data.shape}")
     print(f"test_data.shape: {test_data.shape}")
     print(f"np.unique(train_target): {np.unique(train_target)}")
     print(f"np.unique(test_target): {np.unique(test_target)}")

-    train_data = d._normalise(train_data)
-    test_data = d._normalise(test_data)
-
     from decision_trees import dataset_tester

     dataset_tester.perform_gridsearch(train_data[:19000], train_target[:19000],
@@ -106,4 +105,4 @@ def _normalise(self, data: np.ndarray):
 # train_data[:19000], train_target[:19000],
 # test_data[:10000], test_target[:10000],
 # dataset_tester.ClassifierType.RANDOM_FOREST,
-# )
+# )
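Below is a minimal caller-side sketch of the API after this change, for quick reference. It assumes EMGRaw can be imported from decision_trees.datasets.emg_raw (the module path is not visible in this diff); everything else mirrors the __main__ block above. Since load_data() now applies _normalise() to the input arrays internally, the caller no longer normalises them explicitly.

import numpy as np

# Assumed module path for EMGRaw; not shown in this diff, adjust to the repo layout.
from decision_trees.datasets.emg_raw import EMGRaw

# Data directory taken from the __main__ block in the diff above.
d = EMGRaw("./../../data/EMG/")

# load_data() (renamed from _load_data()) already returns normalised inputs,
# so there are no separate d._normalise(...) calls on the caller side anymore.
train_data, train_target, test_data, test_target = d.load_data()

print(f"train_data.shape: {train_data.shape}")
print(f"test_data.shape: {test_data.shape}")
print(f"np.unique(train_target): {np.unique(train_target)}")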