Skip to content

Commit ff11845

Browse files
committed
BUG: fix tophat sampling in KDE
1 parent a42e4f1 commit ff11845

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

sklearn/neighbors/kde.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# Author: Jake Vanderplas <[email protected]>
66

77
import numpy as np
8+
from scipy.special import gammainc
89
from ..base import BaseEstimator
910
from ..utils import array2d, check_random_state
1011
from .ball_tree import BallTree, DTYPE
@@ -195,5 +196,14 @@ def sample(self, n_samples=1, random_state=None):
195196

196197
if self.kernel == 'gaussian':
197198
return rng.normal(data[i], self.bandwidth)
199+
198200
elif self.kernel == 'tophat':
199-
return data[i] - 1 + 2 * rng.random_sample(i.shape)[:, None]
201+
# we first draw points from a d-dimensional normal distribution,
202+
# then use an incomplete gamma function to map them to a uniform
203+
# d-dimensional tophat distribution.
204+
dim = data.shape[1]
205+
X = rng.normal(size=(n_samples, dim))
206+
s_sq = (X ** 2).sum(1)
207+
correction = (gammainc(0.5 * dim, 0.5 * s_sq) ** (1. / dim)
208+
* self.bandwidth / np.sqrt(s_sq))
209+
return data[i] + X * correction[:, np.newaxis]

0 commit comments

Comments
 (0)