Skip to content

Commit 7e8631a

Browse files
committed
cleanup bench script
1 parent 5fc1364 commit 7e8631a

File tree

1 file changed

+32
-30
lines changed

1 file changed

+32
-30
lines changed

benchmarks/bench_isolation_forest.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919

2020
np.random.seed(1)
2121

22-
# set to True to obtain histograms of the decision functions
23-
decision_function = True
24-
25-
datasets = ['http'] #, 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
22+
datasets = ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
23+
# datasets = ['http']
2624

2725
for dat in datasets:
2826
# loading and vectorization
@@ -78,9 +76,8 @@
7876
if dat == 'http' or dat == 'smtp':
7977
y = (y != 'normal.').astype(int)
8078

81-
n_samples, n_features = np.shape(X)
79+
n_samples, n_features = X.shape
8280
n_samples_train = n_samples // 2
83-
n_samples_test = n_samples - n_samples_train
8481

8582
X = X.astype(float)
8683
X_train = X[:n_samples_train, :]
@@ -97,29 +94,34 @@
9794

9895
scoring = - model.decision_function(X_test) # the lower, the more normal
9996

100-
if decision_function==True:
101-
f, ax = plt.subplots(3, sharex=True, sharey=True)
102-
ax[0].hist(scoring, np.linspace(-0.5, 0.5, 200), color='black')
103-
ax[0].set_title('decision function for %s dataset' % dat, size=20)
104-
ax[0].legend(loc="lower right")
105-
ax[1].hist(scoring[y_test == 0], np.linspace(-0.5, 0.5, 200), color='b',
106-
label='normal data')
107-
ax[1].legend(loc="lower right")
108-
ax[2].hist(scoring[y_test == 1], np.linspace(-0.5, 0.5, 200), color='r',
109-
label='outliers')
110-
ax[2].legend(loc="lower right")
111-
else:
112-
predict_time = time() - tstart
113-
fpr, tpr, thresholds = roc_curve(y_test, scoring)
114-
AUC = auc(fpr, tpr)
115-
plt.plot(fpr, tpr, lw=1, label='ROC for %s (area = %0.3f, train-time: %0.2fs, test-time: %0.2fs)' % (dat, AUC, fit_time, predict_time))
116-
117-
if decision_function==False:
118-
plt.xlim([-0.05, 1.05])
119-
plt.ylim([-0.05, 1.05])
120-
plt.xlabel('False Positive Rate')
121-
plt.ylabel('True Positive Rate')
122-
plt.title('Receiver operating characteristic')
123-
plt.legend(loc="lower right")
97+
# Show score histograms
98+
f, ax = plt.subplots(3, sharex=True, sharey=True)
99+
bins = np.linspace(-0.5, 0.5, 200)
100+
ax[0].hist(scoring, bins, color='black')
101+
ax[0].set_title('decision function for %s dataset' % dat)
102+
ax[0].legend(loc="lower right")
103+
ax[1].hist(scoring[y_test == 0], bins, color='b',
104+
label='normal data')
105+
ax[1].legend(loc="lower right")
106+
ax[2].hist(scoring[y_test == 1], bins, color='r',
107+
label='outliers')
108+
ax[2].legend(loc="lower right")
109+
110+
# Show ROC Curves
111+
plt.figure(0)
112+
predict_time = time() - tstart
113+
fpr, tpr, thresholds = roc_curve(y_test, scoring)
114+
AUC = auc(fpr, tpr)
115+
label = ('%s (area: %0.3f, train-time: %0.2fs, '
116+
'test-time: %0.2fs)' % (dat, AUC, fit_time, predict_time))
117+
plt.plot(fpr, tpr, lw=1, label=label)
118+
119+
plt.figure(0) # for ROC curves
120+
plt.xlim([-0.05, 1.05])
121+
plt.ylim([-0.05, 1.05])
122+
plt.xlabel('False Positive Rate')
123+
plt.ylabel('True Positive Rate')
124+
plt.title('Receiver operating characteristic (ROC) curves')
125+
plt.legend(loc="lower right")
124126

125127
plt.show()

0 commit comments

Comments
 (0)