22
33# License: BSD 3 clause
44import scipy .sparse as sp
5+ import numpy as np
56
67from .sparsefuncs_fast import (csr_mean_variance_axis0 ,
78 csc_mean_variance_axis0 ,
@@ -66,7 +67,7 @@ def swap_row_csc(X, m, n):
6667 ----------
6768 X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
6869 m : int, index of first sample
69- m : int, index of second sample
70+ n : int, index of second sample
7071 """
7172 if m < 0 :
7273 m += X .shape [0 ]
@@ -86,7 +87,7 @@ def swap_row_csr(X, m, n):
8687 ----------
8788 X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
8889 m : int, index of first sample
89- m : int, index of second sample
90+ n : int, index of second sample
9091 """
9192 if m < 0 :
9293 m += X .shape [0 ]
@@ -96,43 +97,27 @@ def swap_row_csr(X, m, n):
9697 m , n = n , m
9798
9899 indptr = X .indptr
99- indices = X .indices .copy ()
100- data = X .data .copy ()
101-
102- nz_m = indptr [m + 1 ] - indptr [m ]
103- nz_n = indptr [n + 1 ] - indptr [n ]
104100 m_ptr1 = indptr [m ]
105101 m_ptr2 = indptr [m + 1 ]
106102 n_ptr1 = indptr [n ]
107103 n_ptr2 = indptr [n + 1 ]
104+ nz_m = m_ptr2 - m_ptr1
105+ nz_n = n_ptr2 - n_ptr1
108106
109- # If non zero rows are equal in mth and nth row, then swapping becomes
110- # easy.
111- if nz_m == nz_n :
112- mask = X .indices [m_ptr1 : m_ptr2 ].copy ()
113- X .indices [m_ptr1 : m_ptr2 ] = X .indices [n_ptr1 : n_ptr2 ]
114- X .indices [n_ptr1 : n_ptr2 ] = mask
115- mask = X .data [m_ptr1 : m_ptr2 ].copy ()
116- X .data [m_ptr1 : m_ptr2 ] = X .data [n_ptr1 : n_ptr2 ]
117- X .data [n_ptr1 : n_ptr2 ] = mask
118107
119- else :
108+ if nz_m != nz_n :
120109 # Modify indptr first
121- X .indptr [m + 2 : n ] += nz_n - nz_m
110+ X .indptr [m + 2 :n ] += nz_n - nz_m
122111 X .indptr [m + 1 ] = X .indptr [m ] + nz_n
123112 X .indptr [n ] = X .indptr [n + 1 ] - nz_m
124113
125- mask1 = X .indices [m_ptr1 : m_ptr2 ].copy ()
126- mask2 = X .indices [n_ptr1 : n_ptr2 ].copy ()
127- X .indices [m_ptr1 : m_ptr1 + nz_n ] = mask2
128- X .indices [n_ptr2 - nz_m : n_ptr2 ] = mask1
129- X .indices [m_ptr1 + nz_n : n_ptr2 - nz_m ] = indices [m_ptr2 : n_ptr1 ]
130-
131- mask1 = X .data [m_ptr1 : m_ptr2 ].copy ()
132- mask2 = X .data [n_ptr1 : n_ptr2 ].copy ()
133- X .data [m_ptr1 : m_ptr1 + nz_n ] = mask2
134- X .data [n_ptr2 - nz_m : n_ptr2 ] = mask1
135- X .data [m_ptr1 + nz_n : n_ptr2 - nz_m ] = data [m_ptr2 : n_ptr1 ]
114+ X .indices = np .concatenate ([X .indices [:m_ptr1 ], X .indices [n_ptr1 :n_ptr2 ],
115+ X .indices [m_ptr2 :n_ptr1 ],
116+ X .indices [m_ptr1 :m_ptr2 ],
117+ X .indices [n_ptr2 :]])
118+ X .data = np .concatenate ([X .data [:m_ptr1 ], X .data [n_ptr1 :n_ptr2 ],
119+ X .data [m_ptr2 :n_ptr1 ], X .data [m_ptr1 :m_ptr2 ],
120+ X .data [n_ptr2 :]])
136121
137122
138123def swap_row (X , m , n ):
@@ -143,7 +128,7 @@ def swap_row(X, m, n):
143128 ----------
144129 X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
145130 m : int, index of first sample
146- m : int, index of second sample
131+ n : int, index of second sample
147132 """
148133 if isinstance (X , sp .csc_matrix ):
149134 return swap_row_csc (X , m , n )
@@ -152,3 +137,26 @@ def swap_row(X, m, n):
152137 else :
153138 raise TypeError (
154139 "Unsupported type; expected a CSR or CSC sparse matrix." )
140+
141+
142+ def swap_column (X , m , n ):
143+ """
144+ Swaps two columns of a CSC/CSR matrix in-place.
145+
146+ Parameters
147+ ----------
148+ X : scipy.sparse.csc_matrix, shape=(n_samples, n_features)
149+ m : int, index of first sample
150+ n : int, index of second sample
151+ """
152+ if m < 0 :
153+ m += X .shape [1 ]
154+ if n < 0 :
155+ n += X .shape [1 ]
156+ if isinstance (X , sp .csc_matrix ):
157+ return swap_row_csr (X , m , n )
158+ elif isinstance (X , sp .csr_matrix ):
159+ return swap_row_csc (X , m , n )
160+ else :
161+ raise TypeError (
162+ "Unsupported type; expected a CSR or CSC sparse matrix." )
0 commit comments