Commit f16227b

jeremiedbb authored and rth committed
FIX euclidean_distances float32 numerical instabilities (scikit-learn#13554)
1 parent f8af325 commit f16227b

File tree

4 files changed (+203, -33 lines)


doc/whats_new/v0.21.rst

Lines changed: 8 additions & 3 deletions
@@ -543,9 +543,14 @@ Support for Python 3.4 and below has been officially dropped.
   :pr:`13447` by :user:`Dan Ellis <dpwe>`.

 - |API| The parameter ``labels`` in :func:`metrics.hamming_loss` is deprecated
-  in version 0.21 and will be removed in version 0.23.
-  :pr:`10580` by :user:`Reshama Shaikh <reshamas>` and :user:`Sandra
-  Mitrovic <SandraMNE>`.
+  in version 0.21 and will be removed in version 0.23. :pr:`10580` by
+  :user:`Reshama Shaikh <reshamas>` and :user:`Sandra Mitrovic <SandraMNE>`.
+
+- |Fix| The function :func:`euclidean_distances`, and therefore
+  several estimators with ``metric='euclidean'``, suffered from numerical
+  precision issues with ``float32`` features. Precision has been increased at the
+  cost of a small drop of performance. :pr:`13554` by :user:`Celelibi` and
+  :user:`Jérémie du Boisberranger <jeremiedbb>`.

 - |API| :func:`metrics.jaccard_similarity_score` is deprecated in favour of
   the more consistent :func:`metrics.jaccard_score`. The former behavior for
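
The changelog entry above is about catastrophic cancellation: expanding ||x - y||^2 as ||x||^2 - 2 x.y + ||y||^2 in float32 subtracts large, nearly equal terms and loses the small true distance. A minimal sketch of the effect and of the upcast workaround, using assumed toy values (not part of the commit):

import numpy as np
from scipy.spatial.distance import cdist

# Two nearly identical float32 vectors; the true distance is
# sqrt(dim) * 1e-4, i.e. about 3.2e-3 for dim = 1000.
dim = 1000
x = np.full((1, dim), 1.0, dtype=np.float32)
y = np.full((1, dim), 1.0 + 1e-4, dtype=np.float32)

# Reference computed in float64.
reference = cdist(x.astype(np.float64), y.astype(np.float64))

# Expanded formulation evaluated directly in float32: the squared norms
# (~1000) dwarf the squared distance (~1e-5), so the subtraction loses it.
xx = (x ** 2).sum(axis=1)[:, np.newaxis]
yy = (y ** 2).sum(axis=1)[np.newaxis, :]
d32 = np.sqrt(np.maximum(xx - 2 * (x @ y.T) + yy, 0))

# Same formulation after upcasting to float64, as the fix does chunk-wise.
x64, y64 = x.astype(np.float64), y.astype(np.float64)
d64 = np.sqrt(np.maximum((x64 ** 2).sum(axis=1)[:, np.newaxis]
                         - 2 * (x64 @ y64.T)
                         + (y64 ** 2).sum(axis=1)[np.newaxis, :], 0))

print(reference, d32, d64)  # d32 is typically far off; d64 matches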

sklearn/metrics/pairwise.py

Lines changed: 100 additions & 9 deletions
@@ -193,17 +193,24 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
     Y_norm_squared : array-like, shape (n_samples_2, ), optional
         Pre-computed dot-products of vectors in Y (e.g.,
         ``(Y**2).sum(axis=1)``)
+        May be ignored in some cases, see the note below.

     squared : boolean, optional
         Return squared Euclidean distances.

     X_norm_squared : array-like, shape = [n_samples_1], optional
         Pre-computed dot-products of vectors in X (e.g.,
         ``(X**2).sum(axis=1)``)
+        May be ignored in some cases, see the note below.
+
+    Notes
+    -----
+    To achieve better accuracy, `X_norm_squared` and `Y_norm_squared` may be
+    unused if they are passed as ``float32``.

     Returns
     -------
-    distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2)
+    distances : array, shape (n_samples_1, n_samples_2)

     Examples
     --------
@@ -224,41 +231,125 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False,
     """
     X, Y = check_pairwise_arrays(X, Y)

+    # If norms are passed as float32, they are unused. If arrays are passed as
+    # float32, norms needs to be recomputed on upcast chunks.
+    # TODO: use a float64 accumulator in row_norms to avoid the latter.
     if X_norm_squared is not None:
         XX = check_array(X_norm_squared)
         if XX.shape == (1, X.shape[0]):
             XX = XX.T
         elif XX.shape != (X.shape[0], 1):
             raise ValueError(
                 "Incompatible dimensions for X and X_norm_squared")
+        if XX.dtype == np.float32:
+            XX = None
+    elif X.dtype == np.float32:
+        XX = None
     else:
         XX = row_norms(X, squared=True)[:, np.newaxis]

-    if X is Y:  # shortcut in the common case euclidean_distances(X, X)
+    if X is Y and XX is not None:
+        # shortcut in the common case euclidean_distances(X, X)
         YY = XX.T
     elif Y_norm_squared is not None:
         YY = np.atleast_2d(Y_norm_squared)

         if YY.shape != (1, Y.shape[0]):
             raise ValueError(
                 "Incompatible dimensions for Y and Y_norm_squared")
+        if YY.dtype == np.float32:
+            YY = None
+    elif Y.dtype == np.float32:
+        YY = None
     else:
         YY = row_norms(Y, squared=True)[np.newaxis, :]

-    distances = safe_sparse_dot(X, Y.T, dense_output=True)
-    distances *= -2
-    distances += XX
-    distances += YY
+    if X.dtype == np.float32:
+        # To minimize precision issues with float32, we compute the distance
+        # matrix on chunks of X and Y upcast to float64
+        distances = _euclidean_distances_upcast(X, XX, Y, YY)
+    else:
+        # if dtype is already float64, no need to chunk and upcast
+        distances = - 2 * safe_sparse_dot(X, Y.T, dense_output=True)
+        distances += XX
+        distances += YY
     np.maximum(distances, 0, out=distances)

+    # Ensure that distances between vectors and themselves are set to 0.0.
+    # This may not be the case due to floating point rounding errors.
     if X is Y:
-        # Ensure that distances between vectors and themselves are set to 0.0.
-        # This may not be the case due to floating point rounding errors.
-        distances.flat[::distances.shape[0] + 1] = 0.0
+        np.fill_diagonal(distances, 0)

     return distances if squared else np.sqrt(distances, out=distances)


+def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None):
+    """Euclidean distances between X and Y
+
+    Assumes X and Y have float32 dtype.
+    Assumes XX and YY have float64 dtype or are None.
+
+    X and Y are upcast to float64 by chunks, which size is chosen to limit
+    memory increase by approximately 10% (at least 10MiB).
+    """
+    n_samples_X = X.shape[0]
+    n_samples_Y = Y.shape[0]
+    n_features = X.shape[1]
+
+    distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)
+
+    x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1
+    y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1
+
+    # Allow 10% more memory than X, Y and the distance matrix take (at least
+    # 10MiB)
+    maxmem = max(
+        ((x_density * n_samples_X + y_density * n_samples_Y) * n_features
+         + (x_density * n_samples_X * y_density * n_samples_Y)) / 10,
+        10 * 2**17)
+
+    # The increase amount of memory in 8-byte blocks is:
+    # - x_density * batch_size * n_features (copy of chunk of X)
+    # - y_density * batch_size * n_features (copy of chunk of Y)
+    # - batch_size * batch_size (chunk of distance matrix)
+    # Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
+    #                                 xd=x_density and yd=y_density
+    tmp = (x_density + y_density) * n_features
+    batch_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2
+    batch_size = max(int(batch_size), 1)
+
+    x_batches = gen_batches(X.shape[0], batch_size)
+    y_batches = gen_batches(Y.shape[0], batch_size)
+
+    for i, x_slice in enumerate(x_batches):
+        X_chunk = X[x_slice].astype(np.float64)
+        if XX is None:
+            XX_chunk = row_norms(X_chunk, squared=True)[:, np.newaxis]
+        else:
+            XX_chunk = XX[x_slice]
+
+        for j, y_slice in enumerate(y_batches):
+            if X is Y and j < i:
+                # when X is Y the distance matrix is symmetric so we only need
+                # to compute half of it.
+                d = distances[y_slice, x_slice].T
+
+            else:
+                Y_chunk = Y[y_slice].astype(np.float64)
+                if YY is None:
+                    YY_chunk = row_norms(Y_chunk, squared=True)[np.newaxis, :]
+                else:
+                    YY_chunk = YY[:, y_slice]
+
+                d = -2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True)
+                d += XX_chunk
+                d += YY_chunk
+
+            distances[x_slice, y_slice] = d.astype(np.float32, copy=False)
+
+    return distances
+
+
 def _argmin_min_reduce(dist, start):
     indices = dist.argmin(axis=1)
     values = dist[np.arange(dist.shape[0]), indices]
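
A small worked example of the batch size formula used in _euclidean_distances_upcast above: the memory budget M is counted in 8-byte blocks, and the chunk edge length x solves x^2 + (xd + yd) * k * x = M. The sizes below are assumptions for illustration, not values from the commit:

import numpy as np

# Assumed sizes: dense float32 inputs.
n_samples_X = n_samples_Y = 100000
n_features = 100
x_density = y_density = 1  # dense, so density is 1

# 10% of the memory taken by X, Y and the distance matrix, floored at
# 10 MiB (10 * 2**17 blocks of 8 bytes).
maxmem = max(
    ((x_density * n_samples_X + y_density * n_samples_Y) * n_features
     + (x_density * n_samples_X * y_density * n_samples_Y)) / 10,
    10 * 2**17)

# Positive root of x**2 + (xd + yd) * k * x = M.
tmp = (x_density + y_density) * n_features
batch_size = max(int((-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2), 1)
print(batch_size)  # about 3.2e4 rows per float64 chunk for these sizes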

sklearn/metrics/pairwise_fast.pyx

Lines changed: 1 addition & 1 deletion
@@ -7,10 +7,10 @@
 #
 # License: BSD 3 clause

-from libc.string cimport memset
 import numpy as np
 cimport numpy as np
 from cython cimport floating
+from libc.string cimport memset

 from ..utils._cython_blas cimport _asum

sklearn/metrics/tests/test_pairwise.py

Lines changed: 94 additions & 20 deletions
@@ -584,41 +584,115 @@ def test_pairwise_distances_chunked():
     assert_raises(StopIteration, next, gen)


-def test_euclidean_distances():
-    # Check the pairwise Euclidean distances computation
-    X = [[0]]
-    Y = [[1], [2]]
+@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
+                         ids=["dense", "sparse"])
+@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
+                         ids=["dense", "sparse"])
+def test_euclidean_distances_known_result(x_array_constr, y_array_constr):
+    # Check the pairwise Euclidean distances computation on known result
+    X = x_array_constr([[0]])
+    Y = y_array_constr([[1], [2]])
     D = euclidean_distances(X, Y)
-    assert_array_almost_equal(D, [[1., 2.]])
+    assert_allclose(D, [[1., 2.]])

-    X = csr_matrix(X)
-    Y = csr_matrix(Y)
-    D = euclidean_distances(X, Y)
-    assert_array_almost_equal(D, [[1., 2.]])

+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
+                         ids=["dense", "sparse"])
+def test_euclidean_distances_with_norms(dtype, y_array_constr):
+    # check that we still get the right answers with {X,Y}_norm_squared
+    # and that we get a wrong answer with wrong {X,Y}_norm_squared
     rng = np.random.RandomState(0)
-    X = rng.random_sample((10, 4))
-    Y = rng.random_sample((20, 4))
-    X_norm_sq = (X ** 2).sum(axis=1).reshape(1, -1)
-    Y_norm_sq = (Y ** 2).sum(axis=1).reshape(1, -1)
+    X = rng.random_sample((10, 10)).astype(dtype, copy=False)
+    Y = rng.random_sample((20, 10)).astype(dtype, copy=False)
+
+    # norms will only be used if their dtype is float64
+    X_norm_sq = (X.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)
+    Y_norm_sq = (Y.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)
+
+    Y = y_array_constr(Y)

-    # check that we still get the right answers with {X,Y}_norm_squared
     D1 = euclidean_distances(X, Y)
     D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
     D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
     D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq,
                              Y_norm_squared=Y_norm_sq)
-    assert_array_almost_equal(D2, D1)
-    assert_array_almost_equal(D3, D1)
-    assert_array_almost_equal(D4, D1)
+    assert_allclose(D2, D1)
+    assert_allclose(D3, D1)
+    assert_allclose(D4, D1)

     # check we get the wrong answer with wrong {X,Y}_norm_squared
-    X_norm_sq *= 0.5
-    Y_norm_sq *= 0.5
     wrong_D = euclidean_distances(X, Y,
                                   X_norm_squared=np.zeros_like(X_norm_sq),
                                   Y_norm_squared=np.zeros_like(Y_norm_sq))
-    assert_greater(np.max(np.abs(wrong_D - D1)), .01)
+    with pytest.raises(AssertionError):
+        assert_allclose(wrong_D, D1)
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
+                         ids=["dense", "sparse"])
+@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
+                         ids=["dense", "sparse"])
+def test_euclidean_distances(dtype, x_array_constr, y_array_constr):
+    # check that euclidean distances gives same result as scipy cdist
+    # when X and Y != X are provided
+    rng = np.random.RandomState(0)
+    X = rng.random_sample((100, 10)).astype(dtype, copy=False)
+    X[X < 0.8] = 0
+    Y = rng.random_sample((10, 10)).astype(dtype, copy=False)
+    Y[Y < 0.8] = 0
+
+    expected = cdist(X, Y)
+
+    X = x_array_constr(X)
+    Y = y_array_constr(Y)
+    distances = euclidean_distances(X, Y)
+
+    # the default rtol=1e-7 is too close to the float32 precision
+    # and fails due too rounding errors.
+    assert_allclose(distances, expected, rtol=1e-6)
+    assert distances.dtype == dtype
+
+
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
+                         ids=["dense", "sparse"])
+def test_euclidean_distances_sym(dtype, x_array_constr):
+    # check that euclidean distances gives same result as scipy pdist
+    # when only X is provided
+    rng = np.random.RandomState(0)
+    X = rng.random_sample((100, 10)).astype(dtype, copy=False)
+    X[X < 0.8] = 0
+
+    expected = squareform(pdist(X))
+
+    X = x_array_constr(X)
+    distances = euclidean_distances(X)
+
+    # the default rtol=1e-7 is too close to the float32 precision
+    # and fails due too rounding errors.
+    assert_allclose(distances, expected, rtol=1e-6)
+    assert distances.dtype == dtype
+
+
+@pytest.mark.parametrize(
+    "dtype, eps, rtol",
+    [(np.float32, 1e-4, 1e-5),
+     pytest.param(
+         np.float64, 1e-8, 0.99,
+         marks=pytest.mark.xfail(reason='failing due to lack of precision'))])
+@pytest.mark.parametrize("dim", [1, 1000000])
+def test_euclidean_distances_extreme_values(dtype, eps, rtol, dim):
+    # check that euclidean distances is correct with float32 input thanks to
+    # upcasting. On float64 there are still precision issues.
+    X = np.array([[1.] * dim], dtype=dtype)
+    Y = np.array([[1. + eps] * dim], dtype=dtype)
+
+    distances = euclidean_distances(X, Y)
+    expected = cdist(X, Y)
+
+    assert_allclose(distances, expected, rtol=1e-5)


 def test_cosine_distances():
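
A short usage sketch tying the tests above to the new Notes section in the euclidean_distances docstring: with float32 inputs (or float32 norms), precomputed squared norms are ignored and recomputed internally on float64 chunks, while the returned matrix keeps the input dtype. Shapes below are assumed toy values:

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

rng = np.random.RandomState(0)
X = rng.random_sample((10, 10)).astype(np.float32)
Y = rng.random_sample((20, 10)).astype(np.float32)

# Precomputed norms must be float64 to be honoured; with float32 data they
# are ignored anyway and recomputed on upcast chunks inside the function.
Y_norm_sq = (Y.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)

D = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
print(D.shape, D.dtype)  # (10, 20) float32: output dtype follows the input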
