Skip to content

Commit 9e913c0

Browse files
Merge pull request scikit-learn#7064 from Erotemic/quickfix_py2_libsvm_kernel_unicode
Fixed python2 libsvm call with unicode kernel
2 parents 9539c0c + d8db108 commit 9e913c0

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

sklearn/svm/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,15 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel,
232232

233233
libsvm.set_verbosity_wrap(self.verbose)
234234

235+
if six.PY2:
236+
# In python2 ensure kernel is ascii bytes to prevent a TypeError
237+
if isinstance(kernel, six.types.UnicodeType):
238+
kernel = str(kernel)
239+
if six.PY3:
240+
# In python3 ensure kernel is utf8 unicode to prevent a TypeError
241+
if isinstance(kernel, bytes):
242+
kernel = str(kernel, 'utf8')
243+
235244
# we don't pass **self.get_params() to allow subclasses to
236245
# add other parameters to __init__
237246
self.support_, self.support_vectors_, self.n_support_, \

sklearn/svm/tests/test_svm.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
44
TODO: remove hard coded numerical results when possible
55
"""
6-
76
import numpy as np
87
import itertools
98
from numpy.testing import assert_array_equal, assert_array_almost_equal
@@ -25,6 +24,7 @@
2524
from sklearn.exceptions import ConvergenceWarning
2625
from sklearn.exceptions import NotFittedError
2726
from sklearn.multiclass import OneVsRestClassifier
27+
from sklearn.externals import six
2828

2929
# toy sample
3030
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
@@ -521,6 +521,30 @@ def test_bad_input():
521521
assert_raises(ValueError, clf.predict, Xt)
522522

523523

524+
def test_unicode_kernel():
525+
# Test that a unicode kernel name does not cause a TypeError on clf.fit
526+
if six.PY2:
527+
# Test unicode (same as str on python3)
528+
clf = svm.SVC(kernel=unicode('linear'))
529+
clf.fit(X, Y)
530+
531+
# Test ascii bytes (str is bytes in python2)
532+
clf = svm.SVC(kernel=str('linear'))
533+
clf.fit(X, Y)
534+
else:
535+
# Test unicode (str is unicode in python3)
536+
clf = svm.SVC(kernel=str('linear'))
537+
clf.fit(X, Y)
538+
539+
# Test ascii bytes (same as str on python2)
540+
clf = svm.SVC(kernel=bytes('linear', 'ascii'))
541+
clf.fit(X, Y)
542+
543+
# Test default behavior on both versions
544+
clf = svm.SVC(kernel='linear')
545+
clf.fit(X, Y)
546+
547+
524548
def test_sparse_precomputed():
525549
clf = svm.SVC(kernel='precomputed')
526550
sparse_gram = sparse.csr_matrix([[1, 0], [0, 1]])

0 commit comments

Comments
 (0)