Skip to content

Commit 6157989

Browse files
Tony ShenTony Shen
authored and committed
Add example project
Each ML project is expected to be in its own folder, with train.py as the entry point and a requirements.txt. This is a simple sklearn model on MNIST data.
1 parent 080da8f commit 6157989

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

example_project/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
matplotlib==3.3.2
scikit-learn==0.23.2

example_project/train.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>
# License: BSD 3 clause

"""Train an SVM classifier on the scikit-learn digits dataset.

Shows the first four training images, fits an SVC on the first half of
the data, predicts on the second half, and reports the classification
report and confusion matrix.
"""

# Standard scientific Python imports
import matplotlib.pyplot as plt

# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split

# Load the 8x8 digit images bundled with scikit-learn. Each entry in
# `digits.images` is an 8x8 grayscale array; `digits.target` holds the
# digit each image represents.
digits = datasets.load_digits()

# Top row of the figure: the first four training images with their labels.
# zip stops at the shorter sequence, so only four panels are drawn.
_, axes = plt.subplots(2, 4)
for panel, image, label in zip(axes[0], digits.images, digits.target):
    panel.set_axis_off()
    panel.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    panel.set_title('Training: %i' % label)

# Flatten each 8x8 image into a 64-element feature vector so the data
# forms a (samples, features) matrix, as the classifier expects.
n_samples = len(digits.images)
feature_matrix = digits.images.reshape((n_samples, -1))

# Support vector classifier with a fixed RBF kernel coefficient.
classifier = svm.SVC(gamma=0.001)

# Deterministic 50/50 split: shuffle=False keeps the original order, so
# the test set is exactly the second half of the dataset.
X_train, X_test, y_train, y_test = train_test_split(
    feature_matrix, digits.target, test_size=0.5, shuffle=False)

# Fit on the first half, then predict the second half.
classifier.fit(X_train, y_train)
predicted = classifier.predict(X_test)

# Bottom row of the figure: the first four test images with the
# classifier's predictions. Because the split was not shuffled, the test
# images start at index n_samples // 2.
for panel, image, prediction in zip(
        axes[1], digits.images[n_samples // 2:], predicted):
    panel.set_axis_off()
    panel.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    panel.set_title('Prediction: %i' % prediction)

# Text summary of per-class precision/recall/F1.
print("Classification report for classifier %s:\n%s\n"
      % (classifier, metrics.classification_report(y_test, predicted)))

# Confusion matrix, both as a plot and printed to stdout.
# plot_confusion_matrix is the scikit-learn 0.23 API (pinned in
# requirements); it was removed in scikit-learn 1.2.
disp = metrics.plot_confusion_matrix(classifier, X_test, y_test)
disp.figure_.suptitle("Confusion Matrix")
print("Confusion matrix:\n%s" % disp.confusion_matrix)

plt.show()

0 commit comments

Comments
 (0)