Commit 2a03979

Add sklearn example for cancer dataset
1 parent 27d7a9f commit 2a03979

6 files changed, with 72 additions and 0 deletions.

data/README.md

Lines changed: 4 additions & 0 deletions
@@ -2,8 +2,12 @@
 
 This is the general tool to convert CSV files to TFRecords files.
 
+## Cancer
+
 The example data in [cancer.csv](cancer.csv) looks like this.
 
+From [mark-watson/cancer-deep-learning-model](https://github.com/mark-watson/cancer-deep-learning-model)
+
 ```
 3,7,7,4,4,9,4,8,1,1
 1,1,1,1,2,1,2,1,1,0

data/a8a_test.libsvm

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+0 5:1 6:1 17:1 21:1 35:1 40:1 53:1 63:1 71:1 73:1 74:1 76:1 80:1 83:1
+1 5:1 7:1 17:1 22:1 36:1 40:1 51:1 63:1 67:1 73:1 74:1 76:1 81:1 83:1
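
Each line of `a8a_test.libsvm` is a label followed by sparse `index:value` pairs. As a minimal sketch of how one such line could be parsed with only the standard library (the function name `parse_libsvm_line` is hypothetical and not part of this commit):

```python
def parse_libsvm_line(line):
    """Split one LIBSVM-format line into (label, {feature_index: value})."""
    tokens = line.split()
    label = int(tokens[0])  # first token is the class label
    features = {}
    for token in tokens[1:]:
        index, value = token.split(":")
        features[int(index)] = float(value)  # only non-zero features are listed
    return label, features


# First row added in data/a8a_test.libsvm
label, features = parse_libsvm_line(
    "0 5:1 6:1 17:1 21:1 35:1 40:1 53:1 63:1 71:1 73:1 74:1 76:1 80:1 83:1"
)
print(label, len(features), features[5])
```

A dictionary keeps the sparse structure explicit; a real pipeline would more likely convert to a sparse matrix, but that is outside what this commit shows.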

data/cancer_test.csv

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+1,2,3,4,5,6,7,8,9,1
+1,1,1,1,1,1,1,1,1,1
+9,8,7,6,5,4,3,2,1,1
+9,9,9,9,9,9,9,9,9,1
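
Each row of `cancer_test.csv` holds nine feature values followed by the label in the tenth column, which is how the classifier script slices it with `FEATURE_NUMBER = 9`. A small standard-library sketch of that split (the helper `split_rows` is hypothetical; the script itself uses `np.loadtxt`):

```python
import csv
import io

FEATURE_NUMBER = 9  # same constant the classifier script uses

# The four rows added in data/cancer_test.csv
CANCER_TEST_CSV = """1,2,3,4,5,6,7,8,9,1
1,1,1,1,1,1,1,1,1,1
9,8,7,6,5,4,3,2,1,1
9,9,9,9,9,9,9,9,9,1
"""


def split_rows(text, feature_number=FEATURE_NUMBER):
    """Return (features, labels) with the label taken from column `feature_number`."""
    features, labels = [], []
    for row in csv.reader(io.StringIO(text)):
        values = [float(v) for v in row]
        features.append(values[:feature_number])
        labels.append(values[feature_number])
    return features, labels


features, labels = split_rows(CANCER_TEST_CSV)
print(len(features), len(features[0]), sorted(set(labels)))
```

Note that all four rows carry the label 1, so this file on its own contains a single class. Metrics that need both classes, such as ROC AUC, are undefined on it.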

setup.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+import setuptools
+setuptools.setup(name='trainer', version='1.0', packages=['trainer'])

sklearn_exmaples/cancer_classifier.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+
+import sys
+import numpy as np
+from sklearn import metrics
+from sklearn.svm import SVC
+from sklearn.neural_network import MLPClassifier
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.tree import DecisionTreeClassifier
+from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
+from sklearn.naive_bayes import GaussianNB
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
+
+FEATURE_NUMBER = 9
+
+# Read train and test data; the last column is the label
+with open("../data/cancer_train.csv", "r") as f:
+    train_dataset = np.loadtxt(f, delimiter=",")
+    train_labels = train_dataset[:, FEATURE_NUMBER]
+    train_features = train_dataset[:, 0:FEATURE_NUMBER]
+
+with open("../data/cancer_test.csv", "r") as f:
+    test_dataset = np.loadtxt(f, delimiter=",")
+    test_labels = test_dataset[:, FEATURE_NUMBER]
+    test_features = test_dataset[:, 0:FEATURE_NUMBER]
+
+# Define the models (MLPClassifier takes `solver`; `algorithm` was a pre-release name)
+classifiers = [
+    DecisionTreeClassifier(max_depth=5),
+    MLPClassifier(solver='sgd', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1, learning_rate_init=0.001, batch_size=64, max_iter=100, verbose=False),
+    MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1),
+    MLPClassifier(solver='adam', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1),
+    KNeighborsClassifier(2),
+    SVC(kernel="linear", C=0.025),
+    SVC(gamma=2, C=1),
+    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
+    AdaBoostClassifier(),
+    GaussianNB(),
+    LinearDiscriminantAnalysis(),
+    QuadraticDiscriminantAnalysis()]
+
+if len(sys.argv) > 1:
+    classifier_index = int(sys.argv[1])
+else:
+    classifier_index = 0
+classifier = classifiers[classifier_index]
+print("Use the classifier: {}".format(classifier))
+
+# Train the model
+print("Start to train")
+model = classifier.fit(train_features, train_labels)
+
+print("Start to validate")
+predict_labels = model.predict(test_features)
+auc = metrics.roc_auc_score(test_labels, predict_labels)
+accuracy = metrics.accuracy_score(test_labels, predict_labels)
+
+# Print the metrics
+print("Accuracy: {}, auc: {}".format(accuracy, auc))
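
The script computes ROC AUC from hard `predict` labels, and the `cancer_test.csv` added in this commit contains only label 1. ROC AUC compares positive examples against negative ones, so it is undefined when one class is missing. The following is a pure-Python sketch of a rank-based AUC that returns `None` in that case. The name `safe_auc` is hypothetical, and this is an illustration of the metric, not scikit-learn's implementation:

```python
def safe_auc(labels, scores):
    """Rank-based ROC AUC; returns None if only one class is present.

    AUC equals the probability that a randomly chosen positive example
    scores higher than a randomly chosen negative one (ties count 0.5).
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y != 1]
    if not positives or not negatives:
        return None  # undefined: no positive/negative pair to compare
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Single-class labels, as in data/cancer_test.csv: AUC is undefined
print(safe_auc([1, 1, 1, 1], [1, 1, 0, 1]))
# Mixed labels with continuous scores
print(safe_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

With scikit-learn itself, passing continuous scores (for example from `predict_proba` or `decision_function`, where the chosen classifier provides them) to `roc_auc_score` generally gives a more informative AUC than passing hard labels.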

trainer/__init__.py

Whitespace-only changes.
