2424from sklearn .learning_curve import learning_curve
2525
2626
27- def plot_learning_curve (estimator , title , X , y , ylim = ( 0.7 , 1.01 ), cv = None ,
28- n_jobs = 1 , train_sizes = np .linspace (.1 , 1.0 , 5 )):
27+ def plot_learning_curve (estimator , title , X , y , ylim = None , cv = None ,
28+ n_jobs = 1 , train_sizes = np .linspace (.1 , 1.0 , 5 )):
2929 """
3030 Generate a simple plot of the test and traning learning curve.
3131
@@ -46,8 +46,7 @@ def plot_learning_curve(estimator, title, X, y, ylim=(0.7,1.01), cv=None,
4646 None for unsupervised learning.
4747
4848 ylim : tuple, shape (ymin, ymax), optional
49- Defines minimum and maximum yvalues plotted. Defaults to (0.7, 1.01)
50- for easy comparison of plots.
49+ Defines minimum and maximum yvalues plotted.
5150
5251 cv : integer, cross-validation generator, optional
5352 If an integer is passed, it is the number of folds (defaults to 3).
@@ -59,7 +58,8 @@ def plot_learning_curve(estimator, title, X, y, ylim=(0.7,1.01), cv=None,
5958 """
6059 plt .figure ()
6160 plt .title (title )
62- plt .ylim ( * ylim )
61+ if ylim is not None :
62+ plt .ylim (* ylim )
6363 plt .xlabel ("Training examples" )
6464 plt .ylabel ("Score" )
6565 train_sizes , train_scores , test_scores = learning_curve (
@@ -71,7 +71,8 @@ def plot_learning_curve(estimator, title, X, y, ylim=(0.7,1.01), cv=None,
7171 plt .grid ()
7272
7373 plt .fill_between (train_sizes , train_scores_mean - train_scores_std ,
74- train_scores_mean + train_scores_std , alpha = 0.1 , color = "r" )
74+ train_scores_mean + train_scores_std , alpha = 0.1 ,
75+ color = "r" )
7576 plt .fill_between (train_sizes , test_scores_mean - test_scores_std ,
7677 test_scores_mean + test_scores_std , alpha = 0.1 , color = "g" )
7778 plt .plot (train_sizes , train_scores_mean , 'o-' , color = "r" ,
@@ -87,20 +88,20 @@ def plot_learning_curve(estimator, title, X, y, ylim=(0.7,1.01), cv=None,
8788X , y = digits .data , digits .target
8889
8990
90- title = "Learning Curve (Naive Bayes)"
91- # Cross validation with 100 iterations to get smoother mean test and train
92- # score curves, each time with 20% data randomly selected as the validation set.
91+ title = "Learning Curve (Naive Bayes)"
92+ # Cross validation with 100 iterations to get smoother mean test and train
93+ # score curves, each time with 20% data randomly selected as a validation set.
9394cv = cross_validation .ShuffleSplit (digits .data .shape [0 ], n_iter = 100 ,
9495 test_size = 0.2 , random_state = 0 )
9596
96- estimator = GaussianNB ()
97- plot_learning_curve (estimator , title , X , y , cv = cv , n_jobs = 4 )
97+ estimator = GaussianNB ()
98+ plot_learning_curve (estimator , title , X , y , ylim = ( 0.7 , 1.01 ), cv = cv , n_jobs = 4 )
9899
99- title = "Learning Curve (SVM, RBF kernel, $\gamma=0.001$)"
100+ title = "Learning Curve (SVM, RBF kernel, $\gamma=0.001$)"
100101# SVC is more expensive so we do a lower number of CV iterations:
101102cv = cross_validation .ShuffleSplit (digits .data .shape [0 ], n_iter = 10 ,
102103 test_size = 0.2 , random_state = 0 )
103- estimator = SVC (gamma = 0.001 )
104- plot_learning_curve (estimator , title , X , y , cv = cv , n_jobs = 4 )
104+ estimator = SVC (gamma = 0.001 )
105+ plot_learning_curve (estimator , title , X , y , ( 0.7 , 1.01 ), cv = cv , n_jobs = 4 )
105106
106107plt .show ()
0 commit comments