Skip to content

Commit 7dd1ade

Browse files
committed
ENH speed up RBFSampler by ~10%
Applied the same to SkewedChi2Sampler, without measuring. I was actually aiming for reduced memory usage, but unfortunately that didn't matter much.
1 parent 5cfd45e commit 7dd1ade

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

sklearn/kernel_approximation.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ class RBFSampler(BaseEstimator, TransformerMixin):
2828
Parameters
2929
----------
3030
gamma: float
31-
parameter of RBF kernel: exp(-gamma * x**2)
31+
Parameter of RBF kernel: exp(-γ × x²)
3232
3333
n_components: int
34-
number of Monte Carlo samples per original feature.
34+
Number of Monte Carlo samples per original feature.
3535
Equals the dimensionality of the computed feature space.
3636
3737
random_state : {int, RandomState}, optional
@@ -44,7 +44,7 @@ class RBFSampler(BaseEstimator, TransformerMixin):
4444
Benjamin Recht.
4545
"""
4646

47-
def __init__(self, gamma=1., n_components=100., random_state=None):
47+
def __init__(self, gamma=1., n_components=100, random_state=None):
4848
self.gamma = gamma
4949
self.n_components = n_components
5050
self.random_state = random_state
@@ -90,10 +90,11 @@ def transform(self, X, y=None):
9090
-------
9191
X_new: array-like, shape (n_samples, n_components)
9292
"""
93-
X = atleast2d_or_csr(X)
9493
projection = safe_sparse_dot(X, self.random_weights_)
95-
return (np.sqrt(2.) / np.sqrt(self.n_components)
96-
* np.cos(projection + self.random_offset_))
94+
projection += self.random_offset_
95+
np.cos(projection, projection)
96+
projection *= np.sqrt(2.) / np.sqrt(self.n_components)
97+
return projection
9798

9899

99100
class SkewedChi2Sampler(BaseEstimator, TransformerMixin):
@@ -172,15 +173,17 @@ def transform(self, X, y=None):
172173
-------
173174
X_new: array-like, shape (n_samples, n_components)
174175
"""
175-
X = array2d(X)
176+
X = array2d(X, copy=True)
176177
if (X < 0).any():
177178
raise ValueError("X may not contain entries smaller than zero.")
178179

179-
projection = safe_sparse_dot(np.log(X + self.skewedness),
180-
self.random_weights_)
181-
182-
return (np.sqrt(2.) / np.sqrt(self.n_components)
183-
* np.cos(projection + self.random_offset_))
180+
X += self.skewedness
181+
np.log(X, X)
182+
projection = safe_sparse_dot(X, self.random_weights_)
183+
projection += self.random_offset_
184+
np.cos(projection, projection)
185+
projection *= np.sqrt(2.) / np.sqrt(self.n_components)
186+
return projection
184187

185188

186189
class AdditiveChi2Sampler(BaseEstimator, TransformerMixin):

0 commit comments

Comments
 (0)