Skip to content

Commit b1c1841

Browse files
committed
Use train / test split in the classification introduction to avoid teaching bad habits
1 parent 522c972 commit b1c1841

File tree

1 file changed

+24
-28
lines changed

1 file changed

+24
-28
lines changed

notebooks/04A_supervised_classification.ipynb

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@
209209
"collapsed": false,
210210
"input": [
211211
"from sklearn.naive_bayes import GaussianNB\n",
212-
"from sklearn import cross_validation"
212+
"from sklearn.cross_validation import train_test_split"
213213
],
214214
"language": "python",
215215
"metadata": {},
@@ -220,21 +220,27 @@
220220
"collapsed": false,
221221
"input": [
222222
"# split the data into training and validation sets\n",
223-
"X = digits.data\n",
224-
"y = digits.target\n",
223+
"X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)\n",
225224
"\n",
226225
"# train the model\n",
227226
"clf = GaussianNB()\n",
228-
"clf.fit(X, y)\n",
227+
"clf.fit(X_train, y_train)\n",
229228
"\n",
230229
"# use the model to predict the labels of the test data\n",
231-
"predicted = clf.predict(X)\n",
232-
"expected = y"
230+
"predicted = clf.predict(X_test)\n",
231+
"expected = y_test"
233232
],
234233
"language": "python",
235234
"metadata": {},
236235
"outputs": []
237236
},
237+
{
238+
"cell_type": "markdown",
239+
"metadata": {},
240+
"source": [
241+
"**Question**: why did we split the data into training and validation sets?"
242+
]
243+
},
238244
{
239245
"cell_type": "markdown",
240246
"metadata": {},
@@ -253,7 +259,8 @@
253259
"# plot the digits: each image is 8x8 pixels\n",
254260
"for i in range(64):\n",
255261
" ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])\n",
256-
" ax.imshow(digits.images[i], cmap=plt.cm.binary)\n",
262+
" ax.imshow(X_test.reshape(-1, 8, 8)[i], cmap=plt.cm.binary,\n",
263+
" interpolation='nearest')\n",
257264
" \n",
258265
" # label the image with the target value\n",
259266
" if predicted[i] == expected[i]:\n",
@@ -265,13 +272,6 @@
265272
"metadata": {},
266273
"outputs": []
267274
},
268-
{
269-
"cell_type": "markdown",
270-
"metadata": {},
271-
"source": [
272-
"**Question: what might be a problem with judging performance based on these predictions?**"
273-
]
274-
},
275275
{
276276
"cell_type": "heading",
277277
"level": 2,
@@ -301,6 +301,16 @@
301301
"metadata": {},
302302
"outputs": []
303303
},
304+
{
305+
"cell_type": "code",
306+
"collapsed": false,
307+
"input": [
308+
"matches.sum() / float(len(matches))"
309+
],
310+
"language": "python",
311+
"metadata": {},
312+
"outputs": []
313+
},
304314
{
305315
"cell_type": "markdown",
306316
"metadata": {},
@@ -349,20 +359,6 @@
349359
"source": [
350360
"We see here that in particular, the numbers 1, 2, 3, and 9 are often being labeled 8."
351361
]
352-
},
353-
{
354-
"cell_type": "markdown",
355-
"metadata": {},
356-
"source": [
357-
"As alluded to above, however, this is not a very good way to measure performance.\n",
358-
"Why? Because we are using the same data for **training** and **validation**.\n",
359-
"With this metric, a classifier could be perfect by simply storing all the training\n",
360-
"samples, and checking whether the \"unknown\" sample matches any exactly. Things are\n",
361-
"rarely this easy in real problems.\n",
362-
"\n",
363-
"In a later notebook, we'll learn how **validation sets** can be used\n",
364-
"to get around this difficulty."
365-
]
366362
}
367363
],
368364
"metadata": {}

0 commit comments

Comments
 (0)