Skip to content

Commit a9ae693

Browse files
authored
FIX Approximate nearest neighbors in TSNE example (scikit-learn#19809)
1 parent bc7cd31 commit a9ae693

File tree

1 file changed

+28
-36
lines changed

1 file changed

+28
-36
lines changed

examples/neighbors/approximate_nearest_neighbors.py

Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@
88
replace KNeighborsTransformer and perform approximate nearest neighbors.
99
These packages can be installed with `pip install annoy nmslib`.
1010
11-
Note: Currently `TSNE(metric='precomputed')` does not modify the precomputed
12-
distances, and thus assumes that precomputed euclidean distances are squared.
13-
In future versions, a parameter in TSNE will control the optional squaring of
14-
precomputed distances (see #12401).
15-
1611
Note: In KNeighborsTransformer we use the definition which includes each
1712
training point as its own neighbor in the count of `n_neighbors`, and for
1813
compatibility reasons, one extra neighbor is computed when
@@ -91,7 +86,6 @@ def fit(self, X):
9186
# see more metric in the manual
9287
# https://github.com/nmslib/nmslib/tree/master/manual
9388
space = {
94-
'sqeuclidean': 'l2',
9589
'euclidean': 'l2',
9690
'cosine': 'cosinesimil',
9791
'l1': 'l1',
@@ -115,9 +109,6 @@ def transform(self, X):
115109
indices, distances = zip(*results)
116110
indices, distances = np.vstack(indices), np.vstack(distances)
117111

118-
if self.metric == 'sqeuclidean':
119-
distances **= 2
120-
121112
indptr = np.arange(0, n_samples_transform * n_neighbors + 1,
122113
n_neighbors)
123114
kneighbors_graph = csr_matrix((distances.ravel(), indices.ravel(),
@@ -139,8 +130,7 @@ def __init__(self, n_neighbors=5, metric='euclidean', n_trees=10,
139130

140131
def fit(self, X):
141132
self.n_samples_fit_ = X.shape[0]
142-
metric = self.metric if self.metric != 'sqeuclidean' else 'euclidean'
143-
self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric)
133+
self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=self.metric)
144134
for i, x in enumerate(X):
145135
self.annoy_.add_item(i, x.tolist())
146136
self.annoy_.build(self.n_trees)
@@ -177,9 +167,6 @@ def _transform(self, X):
177167
x.tolist(), n_neighbors, self.search_k,
178168
include_distances=True)
179169

180-
if self.metric == 'sqeuclidean':
181-
distances **= 2
182-
183170
indptr = np.arange(0, n_samples_transform * n_neighbors + 1,
184171
n_neighbors)
185172
kneighbors_graph = csr_matrix((distances.ravel(), indices.ravel(),
@@ -209,7 +196,7 @@ def test_transformers():
209196

210197
def load_mnist(n_samples):
211198
"""Load MNIST, shuffle the data, and return only n_samples."""
212-
mnist = fetch_openml("mnist_784")
199+
mnist = fetch_openml("mnist_784", as_frame=False)
213200
X, y = shuffle(mnist.data, mnist.target, random_state=2)
214201
return X[:n_samples] / 255, y[:n_samples]
215202

@@ -222,34 +209,39 @@ def run_benchmark():
222209

223210
n_iter = 500
224211
perplexity = 30
212+
metric = "euclidean"
225213
# TSNE requires a certain number of neighbors which depends on the
226214
# perplexity parameter.
227215
# Add one since we include each sample as its own neighbor.
228216
n_neighbors = int(3. * perplexity + 1) + 1
229217

218+
tsne_params = dict(perplexity=perplexity, method="barnes_hut",
219+
random_state=42, n_iter=n_iter,
220+
square_distances=True)
221+
230222
transformers = [
231-
('AnnoyTransformer', AnnoyTransformer(n_neighbors=n_neighbors,
232-
metric='sqeuclidean')),
233-
('NMSlibTransformer', NMSlibTransformer(n_neighbors=n_neighbors,
234-
metric='sqeuclidean')),
235-
('KNeighborsTransformer', KNeighborsTransformer(
236-
n_neighbors=n_neighbors, mode='distance', metric='sqeuclidean')),
237-
('TSNE with AnnoyTransformer', make_pipeline(
238-
AnnoyTransformer(n_neighbors=n_neighbors, metric='sqeuclidean'),
239-
TSNE(metric='precomputed', perplexity=perplexity,
240-
method="barnes_hut", random_state=42, n_iter=n_iter), )),
241-
('TSNE with NMSlibTransformer', make_pipeline(
242-
NMSlibTransformer(n_neighbors=n_neighbors, metric='sqeuclidean'),
243-
TSNE(metric='precomputed', perplexity=perplexity,
244-
method="barnes_hut", random_state=42, n_iter=n_iter), )),
245-
('TSNE with KNeighborsTransformer', make_pipeline(
246-
KNeighborsTransformer(n_neighbors=n_neighbors, mode='distance',
247-
metric='sqeuclidean'),
248-
TSNE(metric='precomputed', perplexity=perplexity,
249-
method="barnes_hut", random_state=42, n_iter=n_iter), )),
223+
('AnnoyTransformer',
224+
AnnoyTransformer(n_neighbors=n_neighbors, metric=metric)),
225+
('NMSlibTransformer',
226+
NMSlibTransformer(n_neighbors=n_neighbors, metric=metric)),
227+
('KNeighborsTransformer',
228+
KNeighborsTransformer(n_neighbors=n_neighbors, mode='distance',
229+
metric=metric)),
230+
('TSNE with AnnoyTransformer',
231+
make_pipeline(
232+
AnnoyTransformer(n_neighbors=n_neighbors, metric=metric),
233+
TSNE(metric='precomputed', **tsne_params))),
234+
('TSNE with NMSlibTransformer',
235+
make_pipeline(
236+
NMSlibTransformer(n_neighbors=n_neighbors, metric=metric),
237+
TSNE(metric='precomputed', **tsne_params))),
238+
('TSNE with KNeighborsTransformer',
239+
make_pipeline(
240+
KNeighborsTransformer(n_neighbors=n_neighbors, mode='distance',
241+
metric=metric),
242+
TSNE(metric='precomputed', **tsne_params))),
250243
('TSNE with internal NearestNeighbors',
251-
TSNE(metric='sqeuclidean', perplexity=perplexity, method="barnes_hut",
252-
random_state=42, n_iter=n_iter)),
244+
TSNE(metric=metric, **tsne_params)),
253245
]
254246

255247
# init the plot

0 commit comments

Comments
 (0)