Skip to content

Commit fc634f9

Browse files
committed
Made the following changes
1. Minor changes to docs 2. Replaced swap with inplace_swap
1 parent 31d438e commit fc634f9

File tree

2 files changed

+62
-31
lines changed

2 files changed

+62
-31
lines changed

sklearn/utils/sparsefuncs.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,25 @@ def inplace_column_scale(X, scale):
5959
"Unsupported type; expected a CSR or CSC sparse matrix.")
6060

6161

62-
def swap_row_csc(X, m, n):
62+
def inplace_swap_row_csc(X, m, n):
6363
"""
6464
Swaps two rows of a CSC matrix in-place.
6565
6666
Parameters
6767
----------
68-
X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
69-
m : int, index of first sample
70-
n : int, index of second sample
68+
X: scipy.sparse.csc_matrix, shape=(n_samples, n_features)
69+
Matrix whose two rows are to be swapped.
70+
71+
m: int
72+
Index of the row of X to be swapped.
73+
74+
n: int
75+
Index of the row of X to be swapped.
7176
"""
77+
for t in [m, n]:
78+
if isinstance(t, np.ndarray):
79+
raise TypeError("m and n should be valid integers")
80+
7281
if m < 0:
7382
m += X.shape[0]
7483
if n < 0:
@@ -79,20 +88,32 @@ def swap_row_csc(X, m, n):
7988
X.indices[m_mask] = n
8089

8190

82-
def swap_row_csr(X, m, n):
91+
def inplace_swap_row_csr(X, m, n):
8392
"""
8493
Swaps two rows of a CSR matrix in-place.
8594
8695
Parameters
8796
----------
88-
X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
89-
m : int, index of first sample
90-
n : int, index of second sample
97+
X: scipy.sparse.csr_matrix, shape=(n_samples, n_features)
98+
Matrix whose two rows are to be swapped.
99+
100+
m: int
101+
Index of the row of X to be swapped.
102+
103+
n: int
104+
Index of the row of X to be swapped.
91105
"""
106+
for t in [m, n]:
107+
if isinstance(t, np.ndarray):
108+
raise TypeError("m and n should be valid integers")
109+
92110
if m < 0:
93111
m += X.shape[0]
94112
if n < 0:
95113
n += X.shape[0]
114+
115+
# The following swapping makes life easier since m is assumed to be the
116+
# smaller integer below.
96117
if m > n:
97118
m, n = n, m
98119

@@ -123,43 +144,53 @@ def swap_row_csr(X, m, n):
123144
X.data[n_stop:]])
124145

125146

126-
def swap_row(X, m, n):
147+
def inplace_swap_row(X, m, n):
127148
"""
128149
Swaps two rows of a CSC/CSR matrix in-place.
129150
130151
Parameters
131152
----------
132-
X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
133-
m : int, index of first sample
134-
n : int, index of second sample
153+
X : CSR or CSC sparse matrix, shape=(n_samples, n_features)
154+
Matrix whose two rows are to be swapped.
155+
156+
m: int
157+
Index of the row of X to be swapped.
158+
159+
n: int
160+
Index of the row of X to be swapped.
135161
"""
136162
if isinstance(X, sp.csc_matrix):
137-
return swap_row_csc(X, m, n)
163+
return inplace_swap_row_csc(X, m, n)
138164
elif isinstance(X, sp.csr_matrix):
139-
return swap_row_csr(X, m, n)
165+
return inplace_swap_row_csr(X, m, n)
140166
else:
141167
raise TypeError(
142168
"Unsupported type; expected a CSR or CSC sparse matrix.")
143169

144170

145-
def swap_column(X, m, n):
171+
def inplace_swap_column(X, m, n):
146172
"""
147173
Swaps two columns of a CSC/CSR matrix in-place.
148174
149175
Parameters
150176
----------
151-
X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
152-
m : int, index of first sample
153-
n : int, index of second sample
177+
X : CSR or CSC sparse matrix, shape=(n_samples, n_features)
178+
Matrix whose two columns are to be swapped.
179+
180+
m: int
181+
Index of the column of X to be swapped.
182+
183+
n : int
184+
Index of the column of X to be swapped.
154185
"""
155186
if m < 0:
156187
m += X.shape[1]
157188
if n < 0:
158189
n += X.shape[1]
159190
if isinstance(X, sp.csc_matrix):
160-
return swap_row_csr(X, m, n)
191+
return inplace_swap_row_csr(X, m, n)
161192
elif isinstance(X, sp.csr_matrix):
162-
return swap_row_csc(X, m, n)
193+
return inplace_swap_row_csc(X, m, n)
163194
else:
164195
raise TypeError(
165196
"Unsupported type; expected a CSR or CSC sparse matrix.")

sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sklearn.datasets import make_classification
88
from sklearn.utils.sparsefuncs import (mean_variance_axis0,
99
inplace_column_scale,
10-
swap_row, swap_column)
10+
inplace_swap_row, inplace_swap_column)
1111
from sklearn.utils.sparsefuncs_fast import assign_rows_csr
1212
from sklearn.utils.testing import assert_raises
1313

@@ -65,7 +65,7 @@ def test_inplace_column_scale():
6565
assert_raises(TypeError, inplace_column_scale, X.tolil(), scale)
6666

6767

68-
def test_swap_row():
68+
def test_inplace_swap_row():
6969
X = np.array([[0, 3, 0],
7070
[2, 4, 0],
7171
[0, 0, 0],
@@ -77,21 +77,21 @@ def test_swap_row():
7777
swap = linalg.get_blas_funcs(('swap',), (X,))
7878
swap = swap[0]
7979
X[0], X[-1] = swap(X[0], X[-1])
80-
swap_row(X_csr, 0, -1)
81-
swap_row(X_csc, 0, -1)
80+
inplace_swap_row(X_csr, 0, -1)
81+
inplace_swap_row(X_csc, 0, -1)
8282
assert_array_equal(X_csr.toarray(), X_csc.toarray())
8383
assert_array_equal(X, X_csc.toarray())
8484
assert_array_equal(X, X_csr.toarray())
8585

8686
X[2], X[3] = swap(X[2], X[3])
87-
swap_row(X_csr, 2, 3)
88-
swap_row(X_csc, 2, 3)
87+
inplace_swap_row(X_csr, 2, 3)
88+
inplace_swap_row(X_csc, 2, 3)
8989
assert_array_equal(X_csr.toarray(), X_csc.toarray())
9090
assert_array_equal(X, X_csc.toarray())
9191
assert_array_equal(X, X_csr.toarray())
9292

9393

94-
def test_swap_column():
94+
def test_inplace_swap_column():
9595
X = np.array([[0, 3, 0],
9696
[2, 4, 0],
9797
[0, 0, 0],
@@ -103,15 +103,15 @@ def test_swap_column():
103103
swap = linalg.get_blas_funcs(('swap',), (X,))
104104
swap = swap[0]
105105
X[:, 0], X[:, -1] = swap(X[:, 0], X[:, -1])
106-
swap_column(X_csr, 0, -1)
107-
swap_column(X_csc, 0, -1)
106+
inplace_swap_column(X_csr, 0, -1)
107+
inplace_swap_column(X_csc, 0, -1)
108108
assert_array_equal(X_csr.toarray(), X_csc.toarray())
109109
assert_array_equal(X, X_csc.toarray())
110110
assert_array_equal(X, X_csr.toarray())
111111

112112
X[:, 0], X[:, 1] = swap(X[:, 0], X[:, 1])
113-
swap_column(X_csr, 0, 1)
114-
swap_column(X_csc, 0, 1)
113+
inplace_swap_column(X_csr, 0, 1)
114+
inplace_swap_column(X_csc, 0, 1)
115115
assert_array_equal(X_csr.toarray(), X_csc.toarray())
116116
assert_array_equal(X, X_csc.toarray())
117117
assert_array_equal(X, X_csr.toarray())

0 commit comments

Comments
 (0)