|
209 | 209 | "collapsed": false, |
210 | 210 | "input": [ |
211 | 211 | "from sklearn.naive_bayes import GaussianNB\n", |
212 | | - "from sklearn import cross_validation" |
| 212 | + "from sklearn.cross_validation import train_test_split" |
213 | 213 | ], |
214 | 214 | "language": "python", |
215 | 215 | "metadata": {}, |
|
220 | 220 | "collapsed": false, |
221 | 221 | "input": [ |
222 | 222 | "# split the data into training and validation sets\n", |
223 | | - "X = digits.data\n", |
224 | | - "y = digits.target\n", |
| 223 | + "X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target)\n", |
225 | 224 | "\n", |
226 | 225 | "# train the model\n", |
227 | 226 | "clf = GaussianNB()\n", |
228 | | - "clf.fit(X, y)\n", |
| 227 | + "clf.fit(X_train, y_train)\n", |
229 | 228 | "\n", |
230 | 229 | "# use the model to predict the labels of the test data\n", |
231 | | - "predicted = clf.predict(X)\n", |
232 | | - "expected = y" |
| 230 | + "predicted = clf.predict(X_test)\n", |
| 231 | + "expected = y_test" |
233 | 232 | ], |
234 | 233 | "language": "python", |
235 | 234 | "metadata": {}, |
236 | 235 | "outputs": [] |
237 | 236 | }, |
| 237 | + { |
| 238 | + "cell_type": "markdown", |
| 239 | + "metadata": {}, |
| 240 | + "source": [ |
| 241 | + "**Question**: why did we split the data into training and validation sets?" |
| 242 | + ] |
| 243 | + }, |
238 | 244 | { |
239 | 245 | "cell_type": "markdown", |
240 | 246 | "metadata": {}, |
|
253 | 259 | "# plot the digits: each image is 8x8 pixels\n", |
254 | 260 | "for i in range(64):\n", |
255 | 261 | " ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])\n", |
256 | | - " ax.imshow(digits.images[i], cmap=plt.cm.binary)\n", |
| 262 | + " ax.imshow(X_test.reshape(-1, 8, 8)[i], cmap=plt.cm.binary,\n", |
| 263 | + " interpolation='nearest')\n", |
257 | 264 | " \n", |
258 | 265 | " # label the image with the target value\n", |
259 | 266 | " if predicted[i] == expected[i]:\n", |
|
265 | 272 | "metadata": {}, |
266 | 273 | "outputs": [] |
267 | 274 | }, |
268 | | - { |
269 | | - "cell_type": "markdown", |
270 | | - "metadata": {}, |
271 | | - "source": [ |
272 | | - "**Question: what might be a problem with judging performance based on these predictions?**" |
273 | | - ] |
274 | | - }, |
275 | 275 | { |
276 | 276 | "cell_type": "heading", |
277 | 277 | "level": 2, |
|
301 | 301 | "metadata": {}, |
302 | 302 | "outputs": [] |
303 | 303 | }, |
| 304 | + { |
| 305 | + "cell_type": "code", |
| 306 | + "collapsed": false, |
| 307 | + "input": [ |
| 308 | + "matches.sum() / float(len(matches))" |
| 309 | + ], |
| 310 | + "language": "python", |
| 311 | + "metadata": {}, |
| 312 | + "outputs": [] |
| 313 | + }, |
304 | 314 | { |
305 | 315 | "cell_type": "markdown", |
306 | 316 | "metadata": {}, |
|
349 | 359 | "source": [ |
350 | 360 | "We see here that in particular, the numbers 1, 2, 3, and 9 are often being labeled 8." |
351 | 361 | ] |
352 | | - }, |
353 | | - { |
354 | | - "cell_type": "markdown", |
355 | | - "metadata": {}, |
356 | | - "source": [ |
357 | | - "As alluded to above, however, this is not a very good way to measure performance.\n", |
358 | | - "Why? Because we are using the same data for **training** and **validation**.\n", |
359 | | - "With this metric, a classifier could be perfect by simply storing all the training\n", |
360 | | - "samples, and checking whether the \"unknown\" sample matches any exactly. Things are\n", |
361 | | - "rarely this easy in real problems.\n", |
362 | | - "\n", |
363 | | - "In a later notebook, we'll learn how **validation sets** can be used\n", |
364 | | - "to get around this difficulty." |
365 | | - ] |
366 | 362 | } |
367 | 363 | ], |
368 | 364 | "metadata": {} |
|
0 commit comments