Skip to content

Commit 89a58cd

Browse files
committed
FIX support random state in libsvm
1 parent f2643c8 commit 89a58cd

File tree

13 files changed

+2099
-1632
lines changed

13 files changed

+2099
-1632
lines changed

sklearn/svm/base.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class BaseLibSVM(six.with_metaclass(ABCMeta, BaseEstimator)):
6868
@abstractmethod
6969
def __init__(self, impl, kernel, degree, gamma, coef0,
7070
tol, C, nu, epsilon, shrinking, probability, cache_size,
71-
class_weight, verbose, max_iter):
71+
class_weight, verbose, max_iter, random_state):
7272

7373
if not impl in LIBSVM_IMPL: # pragma: no cover
7474
raise ValueError("impl should be one of %s, %s was given" % (
@@ -89,6 +89,7 @@ def __init__(self, impl, kernel, degree, gamma, coef0,
8989
self.class_weight = class_weight
9090
self.verbose = verbose
9191
self.max_iter = max_iter
92+
self.random_state = random_state
9293

9394
@property
9495
def _pairwise(self):
@@ -126,6 +127,8 @@ def fit(self, X, y, sample_weight=None):
126127
matrices as input.
127128
"""
128129

130+
rnd = check_random_state(self.random_state)
131+
129132
self._sparse = sp.isspmatrix(X) and not self._pairwise
130133

131134
if self._sparse and self._pairwise:
@@ -170,7 +173,10 @@ def fit(self, X, y, sample_weight=None):
170173
fit = self._sparse_fit if self._sparse else self._dense_fit
171174
if self.verbose: # pragma: no cover
172175
print('[LibSVM]', end='')
173-
fit(X, y, sample_weight, solver_type, kernel)
176+
177+
seed = rnd.randint(np.iinfo('i').max)
178+
fit(X, y, sample_weight, solver_type, kernel, random_seed=seed)
179+
# see comment on the other call to np.iinfo in this file
174180

175181
self.shape_fit_ = X.shape
176182

@@ -199,7 +205,8 @@ def _warn_from_fit_status(self):
199205
' StandardScaler or MinMaxScaler.'
200206
% self.max_iter, ConvergenceWarning)
201207

202-
def _dense_fit(self, X, y, sample_weight, solver_type, kernel):
208+
def _dense_fit(self, X, y, sample_weight, solver_type, kernel,
209+
random_seed):
203210
if callable(self.kernel):
204211
# you must store a reference to X to compute the kernel in predict
205212
# TODO: add keyword copy to copy on demand
@@ -223,11 +230,12 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel):
223230
shrinking=self.shrinking, tol=self.tol,
224231
cache_size=self.cache_size, coef0=self.coef0,
225232
gamma=self._gamma, epsilon=self.epsilon,
226-
max_iter=self.max_iter)
233+
max_iter=self.max_iter, random_seed=random_seed)
227234

228235
self._warn_from_fit_status()
229236

230-
def _sparse_fit(self, X, y, sample_weight, solver_type, kernel):
237+
def _sparse_fit(self, X, y, sample_weight, solver_type, kernel,
238+
random_seed):
231239
X.data = np.asarray(X.data, dtype=np.float64, order='C')
232240
X.sort_indices()
233241

@@ -243,7 +251,8 @@ def _sparse_fit(self, X, y, sample_weight, solver_type, kernel):
243251
kernel_type, self.degree, self._gamma, self.coef0, self.tol,
244252
self.C, self.class_weight_,
245253
sample_weight, self.nu, self.cache_size, self.epsilon,
246-
int(self.shrinking), int(self.probability), self.max_iter)
254+
int(self.shrinking), int(self.probability), self.max_iter,
255+
random_seed)
247256

248257
self._warn_from_fit_status()
249258

sklearn/svm/classes.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class frequencies.
7676
per-process runtime setting in liblinear that, if enabled, may not work
7777
properly in a multithreaded context.
7878
79-
random_state: int seed, RandomState instance, or None (default)
79+
random_state : int seed, RandomState instance, or None (default)
8080
The seed of the pseudo random number generator to use when
8181
shuffling the data.
8282
@@ -206,6 +206,10 @@ class frequencies.
206206
max_iter : int, optional (default=-1)
207207
Hard limit on iterations within solver, or -1 for no limit.
208208
209+
random_state : int seed, RandomState instance, or None (default)
210+
The seed of the pseudo random number generator to use when
211+
shuffling the data for probability estimation.
212+
209213
Attributes
210214
----------
211215
`support_` : array-like, shape = [n_SV]
@@ -263,11 +267,12 @@ class frequencies.
263267
def __init__(self, C=1.0, kernel='rbf', degree=3, gamma=0.0,
264268
coef0=0.0, shrinking=True, probability=False,
265269
tol=1e-3, cache_size=200, class_weight=None,
266-
verbose=False, max_iter=-1):
270+
verbose=False, max_iter=-1, random_state=None):
267271

268272
super(SVC, self).__init__(
269273
'c_svc', kernel, degree, gamma, coef0, tol, C, 0., 0., shrinking,
270-
probability, cache_size, class_weight, verbose, max_iter)
274+
probability, cache_size, class_weight, verbose, max_iter,
275+
random_state)
271276

272277

273278
class NuSVC(BaseSVC):
@@ -325,6 +330,10 @@ class NuSVC(BaseSVC):
325330
max_iter : int, optional (default=-1)
326331
Hard limit on iterations within solver, or -1 for no limit.
327332
333+
random_state : int seed, RandomState instance, or None (default)
334+
The seed of the pseudo random number generator to use when
335+
shuffling the data for probability estimation.
336+
328337
Attributes
329338
----------
330339
`support_` : array-like, shape = [n_SV]
@@ -379,11 +388,12 @@ class NuSVC(BaseSVC):
379388

380389
def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma=0.0,
381390
coef0=0.0, shrinking=True, probability=False,
382-
tol=1e-3, cache_size=200, verbose=False, max_iter=-1):
391+
tol=1e-3, cache_size=200, verbose=False, max_iter=-1,
392+
random_state=None):
383393

384394
super(NuSVC, self).__init__(
385395
'nu_svc', kernel, degree, gamma, coef0, tol, 0., nu, 0., shrinking,
386-
probability, cache_size, None, verbose, max_iter)
396+
probability, cache_size, None, verbose, max_iter, random_state)
387397

388398

389399
class SVR(BaseLibSVM, RegressorMixin):
@@ -444,6 +454,10 @@ class SVR(BaseLibSVM, RegressorMixin):
444454
max_iter : int, optional (default=-1)
445455
Hard limit on iterations within solver, or -1 for no limit.
446456
457+
random_state : int seed, RandomState instance, or None (default)
458+
The seed of the pseudo random number generator to use when
459+
shuffling the data for probability estimation.
460+
447461
Attributes
448462
----------
449463
`support_` : array-like, shape = [n_SV]
@@ -488,12 +502,13 @@ class SVR(BaseLibSVM, RegressorMixin):
488502
"""
489503
def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0, tol=1e-3,
490504
C=1.0, epsilon=0.1, shrinking=True, probability=False,
491-
cache_size=200, verbose=False, max_iter=-1):
505+
cache_size=200, verbose=False, max_iter=-1,
506+
random_state=None):
492507

493508
super(SVR, self).__init__(
494509
'epsilon_svr', kernel, degree, gamma, coef0, tol, C, 0., epsilon,
495510
shrinking, probability, cache_size, None, verbose,
496-
max_iter)
511+
max_iter, random_state)
497512

498513

499514
class NuSVR(BaseLibSVM, RegressorMixin):
@@ -555,6 +570,10 @@ class NuSVR(BaseLibSVM, RegressorMixin):
555570
max_iter : int, optional (default=-1)
556571
Hard limit on iterations within solver, or -1 for no limit.
557572
573+
random_state : int seed, RandomState instance, or None (default)
574+
The seed of the pseudo random number generator to use when
575+
shuffling the data for probability estimation.
576+
558577
Attributes
559578
----------
560579
`support_` : array-like, shape = [n_SV]
@@ -603,11 +622,11 @@ class NuSVR(BaseLibSVM, RegressorMixin):
603622
def __init__(self, nu=0.5, C=1.0, kernel='rbf', degree=3,
604623
gamma=0.0, coef0=0.0, shrinking=True,
605624
probability=False, tol=1e-3, cache_size=200,
606-
verbose=False, max_iter=-1):
625+
verbose=False, max_iter=-1, random_state=None):
607626

608627
super(NuSVR, self).__init__(
609628
'nu_svr', kernel, degree, gamma, coef0, tol, C, nu, 0., shrinking,
610-
probability, cache_size, None, verbose, max_iter)
629+
probability, cache_size, None, verbose, max_iter, random_state)
611630

612631

613632
class OneClassSVM(BaseLibSVM):
@@ -660,6 +679,10 @@ class OneClassSVM(BaseLibSVM):
660679
max_iter : int, optional (default=-1)
661680
Hard limit on iterations within solver, or -1 for no limit.
662681
682+
random_state : int seed, RandomState instance, or None (default)
683+
The seed of the pseudo random number generator to use when
684+
shuffling the data for probability estimation.
685+
663686
Attributes
664687
----------
665688
`support_` : array-like, shape = [n_SV]
@@ -684,11 +707,12 @@ class OneClassSVM(BaseLibSVM):
684707
"""
685708
def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0, tol=1e-3,
686709
nu=0.5, shrinking=True, cache_size=200, verbose=False,
687-
max_iter=-1):
710+
max_iter=-1, random_state=None):
688711

689712
super(OneClassSVM, self).__init__(
690713
'one_class', kernel, degree, gamma, coef0, tol, 0., nu, 0.,
691-
shrinking, False, cache_size, None, verbose, max_iter)
714+
shrinking, False, cache_size, None, verbose, max_iter,
715+
random_state)
692716

693717
def fit(self, X, sample_weight=None, **params):
694718
"""

0 commit comments

Comments
 (0)