Skip to content

Commit 72b4819

Browse files
Michal-Fularz authored and dpieczynski committed
Added some extra info into gridsearch result filename and changed its name. Fixed parfit code.
1 parent 37d5bef commit 72b4819

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

decision_trees/gridsearch.py

+23 -11
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ def perform_gridsearch(train_data: np.ndarray, train_target: np.ndarray,
2424
gridsearch_type: GridSearchType,
2525
path: str
2626
):
27-
filename = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + "_gridsearch_results.txt"
27+
filename_with_path = path + '/' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') \
28+
+ '_' + clf_type.name + '_gridsearch_results.txt'
2829

29-
# first train on the non-qunatized data
30+
# first train on the non-quantized data
3031
if gridsearch_type == GridSearchType.SCIKIT:
3132
best_model, best_score = _scikit_gridsearch(train_data, train_target, test_data, test_target, clf_type)
3233
elif gridsearch_type == GridSearchType.PARFIT:
@@ -37,11 +38,13 @@ def perform_gridsearch(train_data: np.ndarray, train_target: np.ndarray,
3738
raise ValueError('Requested GridSearchType is not available')
3839

3940
print('No quantization - full resolution')
40-
_save_score_and_model_to_file(best_score, best_model, filename)
41+
with open(filename_with_path, 'a') as f:
42+
print('No quantization - full resolution', file=f)
43+
_save_score_and_model_to_file(best_score, best_model, filename_with_path)
4144

4245
# repeat on quantized data with different number of bits
4346
for i in range(number_of_bits_per_feature_max, 0, -1):
44-
train_data_quantized, test_data_quantized = quantize_data(train_data, test_data, i, False, "./../../data/")
47+
train_data_quantized, test_data_quantized = quantize_data(train_data, test_data, i, False, './../../data/')
4548

4649
if gridsearch_type == GridSearchType.SCIKIT:
4750
best_model, best_score = _scikit_gridsearch(train_data_quantized, train_target, test_data, test_target, clf_type)
@@ -56,12 +59,14 @@ def perform_gridsearch(train_data: np.ndarray, train_target: np.ndarray,
5659
else:
5760
raise ValueError('Requested GridSearchType is not available')
5861
print(f'number of bits: {i}')
59-
_save_score_and_model_to_file(best_score, best_model, path + "/" + filename)
62+
with open(filename_with_path, 'a') as f:
63+
print(f'number of bits: {i}', file=f)
64+
_save_score_and_model_to_file(best_score, best_model, filename_with_path)
6065

6166

62-
def _save_score_and_model_to_file(score, model, fileaname: str):
67+
def _save_score_and_model_to_file(score, model, filename: str):
6368
print(f"f1: {score:{1}.{5}}: {model}")
64-
with open(fileaname, "a") as f:
69+
with open(filename, "a") as f:
6570
print(f"f1: {score:{1}.{5}}: {model}", file=f)
6671

6772

@@ -93,7 +98,7 @@ def _scikit_gridsearch(
9398
elif clf_type == ClassifierType.RANDOM_FOREST:
9499
clf = GridSearchCV(RandomForestClassifier(), tuned_parameters, cv=5, scoring=f'{score}', n_jobs=3)
95100
else:
96-
raise ValueError("Unknown classifier type specified")
101+
raise ValueError('Unknown classifier type specified')
97102

98103
clf = clf.fit(train_data, train_target)
99104

@@ -142,10 +147,17 @@ def _parfit_gridsearch(
142147
):
143148
grid = get_tuned_parameters(clf_type)
144149

145-
best_model, best_score, all_models, all_scores = bestFit(RandomForestClassifier, ParameterGrid(grid),
150+
if clf_type == ClassifierType.DECISION_TREE:
151+
model = DecisionTreeClassifier
152+
elif clf_type == ClassifierType.RANDOM_FOREST:
153+
model = RandomForestClassifier
154+
else:
155+
raise ValueError("Unknown classifier type specified")
156+
157+
best_model, best_score, all_models, all_scores = bestFit(model, ParameterGrid(grid),
146158
train_data, train_target, test_data, test_target,
147159
predictType='predict',
148-
metric=f1_score, bestScore='max',
160+
metric=metrics.f1_score, bestScore='max',
149161
scoreLabel='f1_weighted', showPlot=show_plot)
150162

151-
return best_model, best_score
163+
return best_model, best_score

0 commit comments

Comments
 (0)