Skip to content

Commit af6ab92

Browse files
author
Fabian Pedregosa
committed
Add optional parameter n_class to load_digits.
1 parent 1409e01 commit af6ab92

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

scikits/learn/datasets/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,15 @@ def load_iris():
205205
DESCR=fdescr.read())
206206

207207

208-
def load_digits():
208+
def load_digits(n_class=10):
209209
"""load the digits dataset and returns it.
210210
211+
212+
Parameters
213+
----------
214+
n_class : integer, between 0 and 10
215+
Number of classes to return, defaults to 10
216+
211217
Returns
212218
-------
213219
data : Bunch
@@ -237,6 +243,12 @@ def load_digits():
237243
flat_data = data[:, :-1]
238244
images = flat_data.view()
239245
images.shape = (-1, 8, 8)
246+
247+
if n_class < 10:
248+
idx = target < n_class
249+
flat_data, target = flat_data[idx], target[idx]
250+
images = images[idx]
251+
240252
return Bunch(data=flat_data, target=target.astype(np.int),
241253
target_names=np.arange(10),
242254
images=images,

0 commit comments

Comments
 (0)