@@ -24,9 +24,10 @@ def perform_gridsearch(train_data: np.ndarray, train_target: np.ndarray,
24
24
gridsearch_type : GridSearchType ,
25
25
path : str
26
26
):
27
- filename = datetime .datetime .now ().strftime ("%Y_%m_%d_%H_%M_%S" ) + "_gridsearch_results.txt"
27
+ filename_with_path = path + '/' + datetime .datetime .now ().strftime ('%Y_%m_%d_%H_%M_%S' ) \
28
+ + '_' + clf_type .name + '_gridsearch_results.txt'
28
29
29
- # first train on the non-qunatized data
30
+ # first train on the non-quantized data
30
31
if gridsearch_type == GridSearchType .SCIKIT :
31
32
best_model , best_score = _scikit_gridsearch (train_data , train_target , test_data , test_target , clf_type )
32
33
elif gridsearch_type == GridSearchType .PARFIT :
@@ -37,11 +38,13 @@ def perform_gridsearch(train_data: np.ndarray, train_target: np.ndarray,
37
38
raise ValueError ('Requested GridSearchType is not available' )
38
39
39
40
print ('No quantization - full resolution' )
40
- _save_score_and_model_to_file (best_score , best_model , filename )
41
+ with open (filename_with_path , 'a' ) as f :
42
+ print ('No quantization - full resolution' , file = f )
43
+ _save_score_and_model_to_file (best_score , best_model , filename_with_path )
41
44
42
45
# repeat on quantized data with different number of bits
43
46
for i in range (number_of_bits_per_feature_max , 0 , - 1 ):
44
- train_data_quantized , test_data_quantized = quantize_data (train_data , test_data , i , False , " ./../../data/" )
47
+ train_data_quantized , test_data_quantized = quantize_data (train_data , test_data , i , False , ' ./../../data/' )
45
48
46
49
if gridsearch_type == GridSearchType .SCIKIT :
47
50
best_model , best_score = _scikit_gridsearch (train_data_quantized , train_target , test_data , test_target , clf_type )
@@ -56,12 +59,14 @@ def perform_gridsearch(train_data: np.ndarray, train_target: np.ndarray,
56
59
else :
57
60
raise ValueError ('Requested GridSearchType is not available' )
58
61
print (f'number of bits: { i } ' )
59
- _save_score_and_model_to_file (best_score , best_model , path + "/" + filename )
62
+ with open (filename_with_path , 'a' ) as f :
63
+ print (f'number of bits: { i } ' , file = f )
64
+ _save_score_and_model_to_file (best_score , best_model , filename_with_path )
60
65
61
66
62
- def _save_score_and_model_to_file (score , model , fileaname : str ):
67
+ def _save_score_and_model_to_file (score , model , filename : str ):
63
68
print (f"f1: { score :{1 }.{5 }} : { model } " )
64
- with open (fileaname , "a" ) as f :
69
+ with open (filename , "a" ) as f :
65
70
print (f"f1: { score :{1 }.{5 }} : { model } " , file = f )
66
71
67
72
@@ -93,7 +98,7 @@ def _scikit_gridsearch(
93
98
elif clf_type == ClassifierType .RANDOM_FOREST :
94
99
clf = GridSearchCV (RandomForestClassifier (), tuned_parameters , cv = 5 , scoring = f'{ score } ' , n_jobs = 3 )
95
100
else :
96
- raise ValueError (" Unknown classifier type specified" )
101
+ raise ValueError (' Unknown classifier type specified' )
97
102
98
103
clf = clf .fit (train_data , train_target )
99
104
@@ -142,10 +147,17 @@ def _parfit_gridsearch(
142
147
):
143
148
grid = get_tuned_parameters (clf_type )
144
149
145
- best_model , best_score , all_models , all_scores = bestFit (RandomForestClassifier , ParameterGrid (grid ),
150
+ if clf_type == ClassifierType .DECISION_TREE :
151
+ model = DecisionTreeClassifier
152
+ elif clf_type == ClassifierType .RANDOM_FOREST :
153
+ model = RandomForestClassifier
154
+ else :
155
+ raise ValueError ("Unknown classifier type specified" )
156
+
157
+ best_model , best_score , all_models , all_scores = bestFit (model , ParameterGrid (grid ),
146
158
train_data , train_target , test_data , test_target ,
147
159
predictType = 'predict' ,
148
- metric = f1_score , bestScore = 'max' ,
160
+ metric = metrics . f1_score , bestScore = 'max' ,
149
161
scoreLabel = 'f1_weighted' , showPlot = show_plot )
150
162
151
- return best_model , best_score
163
+ return best_model , best_score
0 commit comments