Skip to content

Commit 924db9b

Browse files
committed
Made the following changes
a] Replaced numpy slicing with concatanate b] Added swap_sparse_column
1 parent 5119d68 commit 924db9b

File tree

2 files changed

+65
-31
lines changed

2 files changed

+65
-31
lines changed

sklearn/utils/sparsefuncs.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# License: BSD 3 clause
44
import scipy.sparse as sp
5+
import numpy as np
56

67
from .sparsefuncs_fast import (csr_mean_variance_axis0,
78
csc_mean_variance_axis0,
@@ -66,7 +67,7 @@ def swap_row_csc(X, m, n):
6667
----------
6768
X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
6869
m : int, index of first sample
69-
m : int, index of second sample
70+
n : int, index of second sample
7071
"""
7172
if m < 0:
7273
m += X.shape[0]
@@ -86,7 +87,7 @@ def swap_row_csr(X, m, n):
8687
----------
8788
X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
8889
m : int, index of first sample
89-
m : int, index of second sample
90+
n : int, index of second sample
9091
"""
9192
if m < 0:
9293
m += X.shape[0]
@@ -96,43 +97,27 @@ def swap_row_csr(X, m, n):
9697
m, n = n, m
9798

9899
indptr = X.indptr
99-
indices = X.indices.copy()
100-
data = X.data.copy()
101-
102-
nz_m = indptr[m + 1] - indptr[m]
103-
nz_n = indptr[n + 1] - indptr[n]
104100
m_ptr1 = indptr[m]
105101
m_ptr2 = indptr[m + 1]
106102
n_ptr1 = indptr[n]
107103
n_ptr2 = indptr[n + 1]
104+
nz_m = m_ptr2 - m_ptr1
105+
nz_n = n_ptr2 - n_ptr1
108106

109-
# If non zero rows are equal in mth and nth row, then swapping becomes
110-
# easy.
111-
if nz_m == nz_n:
112-
mask = X.indices[m_ptr1: m_ptr2].copy()
113-
X.indices[m_ptr1: m_ptr2] = X.indices[n_ptr1: n_ptr2]
114-
X.indices[n_ptr1: n_ptr2] = mask
115-
mask = X.data[m_ptr1: m_ptr2].copy()
116-
X.data[m_ptr1: m_ptr2] = X.data[n_ptr1: n_ptr2]
117-
X.data[n_ptr1: n_ptr2] = mask
118107

119-
else:
108+
if nz_m != nz_n:
120109
# Modify indptr first
121-
X.indptr[m + 2: n] += nz_n - nz_m
110+
X.indptr[m + 2:n] += nz_n - nz_m
122111
X.indptr[m + 1] = X.indptr[m] + nz_n
123112
X.indptr[n] = X.indptr[n + 1] - nz_m
124113

125-
mask1 = X.indices[m_ptr1: m_ptr2].copy()
126-
mask2 = X.indices[n_ptr1: n_ptr2].copy()
127-
X.indices[m_ptr1: m_ptr1 + nz_n] = mask2
128-
X.indices[n_ptr2 - nz_m: n_ptr2] = mask1
129-
X.indices[m_ptr1 + nz_n: n_ptr2 - nz_m] = indices[m_ptr2: n_ptr1]
130-
131-
mask1 = X.data[m_ptr1: m_ptr2].copy()
132-
mask2 = X.data[n_ptr1: n_ptr2].copy()
133-
X.data[m_ptr1: m_ptr1 + nz_n] = mask2
134-
X.data[n_ptr2 - nz_m: n_ptr2] = mask1
135-
X.data[m_ptr1 + nz_n: n_ptr2 - nz_m] = data[m_ptr2: n_ptr1]
114+
X.indices = np.concatenate([X.indices[:m_ptr1], X.indices[n_ptr1:n_ptr2],
115+
X.indices[m_ptr2:n_ptr1],
116+
X.indices[m_ptr1:m_ptr2],
117+
X.indices[n_ptr2:]])
118+
X.data = np.concatenate([X.data[:m_ptr1], X.data[n_ptr1:n_ptr2],
119+
X.data[m_ptr2:n_ptr1], X.data[m_ptr1:m_ptr2],
120+
X.data[n_ptr2:]])
136121

137122

138123
def swap_row(X, m, n):
@@ -143,7 +128,7 @@ def swap_row(X, m, n):
143128
----------
144129
X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
145130
m : int, index of first sample
146-
m : int, index of second sample
131+
n : int, index of second sample
147132
"""
148133
if isinstance(X, sp.csc_matrix):
149134
return swap_row_csc(X, m, n)
@@ -152,3 +137,26 @@ def swap_row(X, m, n):
152137
else:
153138
raise TypeError(
154139
"Unsupported type; expected a CSR or CSC sparse matrix.")
140+
141+
142+
def swap_column(X, m, n):
143+
"""
144+
Swaps two columns of a CSC/CSR matrix in-place.
145+
146+
Parameters
147+
----------
148+
X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
149+
m : int, index of first sample
150+
n : int, index of second sample
151+
"""
152+
if m < 0:
153+
m += X.shape[1]
154+
if n < 0:
155+
n += X.shape[1]
156+
if isinstance(X, sp.csc_matrix):
157+
return swap_row_csr(X, m, n)
158+
elif isinstance(X, sp.csr_matrix):
159+
return swap_row_csc(X, m, n)
160+
else:
161+
raise TypeError(
162+
"Unsupported type; expected a CSR or CSC sparse matrix.")

sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sklearn.datasets import make_classification
88
from sklearn.utils.sparsefuncs import (mean_variance_axis0,
99
inplace_column_scale,
10-
swap_row)
10+
swap_row, swap_column)
1111
from sklearn.utils.sparsefuncs_fast import assign_rows_csr
1212
from sklearn.utils.testing import assert_raises
1313

@@ -89,3 +89,29 @@ def test_swap_row():
8989
assert_array_equal(X_csr.toarray(), X_csc.toarray())
9090
assert_array_equal(X, X_csc.toarray())
9191
assert_array_equal(X, X_csr.toarray())
92+
93+
94+
def test_swap_column():
95+
X = np.array([[0, 3, 0],
96+
[2, 4, 0],
97+
[0, 0, 0],
98+
[9, 8, 7],
99+
[4, 0, 5]], dtype=np.float64)
100+
X_csr = sp.csr_matrix(X)
101+
X_csc = sp.csc_matrix(X)
102+
103+
swap = linalg.get_blas_funcs(('swap',), (X,))
104+
swap = swap[0]
105+
X[:, 0], X[:, -1] = swap(X[:, 0], X[:, -1])
106+
swap_column(X_csr, 0, -1)
107+
swap_column(X_csc, 0, -1)
108+
assert_array_equal(X_csr.toarray(), X_csc.toarray())
109+
assert_array_equal(X, X_csc.toarray())
110+
assert_array_equal(X, X_csr.toarray())
111+
112+
X[:, 0], X[:, 1] = swap(X[:, 0], X[:, 1])
113+
swap_column(X_csr, 0, 1)
114+
swap_column(X_csc, 0, 1)
115+
assert_array_equal(X_csr.toarray(), X_csc.toarray())
116+
assert_array_equal(X, X_csc.toarray())
117+
assert_array_equal(X, X_csr.toarray())

0 commit comments

Comments
 (0)