|
| 1 | +import pickle |
1 | 2 | import numpy as np |
2 | 3 | from numpy.testing import assert_array_almost_equal |
3 | 4 | from sklearn.neighbors.ball_tree import (BallTree, NeighborsHeap, |
|
29 | 30 | 'sokalsneath'] |
30 | 31 |
|
31 | 32 |
|
| 33 | +def dist_func(x1, x2, p): |
| 34 | + return np.sum((x1 - x2) ** p) ** (1. / p) |
| 35 | + |
| 36 | + |
32 | 37 | def brute_force_neighbors(X, Y, k, metric, **kwargs): |
33 | 38 | D = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X) |
34 | 39 | ind = np.argsort(D, axis=1)[:, :k] |
@@ -216,19 +221,32 @@ def check_two_point(r, dualtree): |
216 | 221 |
|
217 | 222 |
|
218 | 223 | def test_ball_tree_pickle(): |
219 | | - import pickle |
220 | 224 | np.random.seed(0) |
221 | 225 | X = np.random.random((10, 3)) |
| 226 | + |
222 | 227 | bt1 = BallTree(X, leaf_size=1) |
| 228 | + # Test if BallTree with callable metric is picklable |
| 229 | + bt1_pyfunc = BallTree(X, metric=dist_func, leaf_size=1, p=2) |
| 230 | + |
223 | 231 | ind1, dist1 = bt1.query(X) |
| 232 | + ind1_pyfunc, dist1_pyfunc = bt1_pyfunc.query(X) |
224 | 233 |
|
225 | 234 | def check_pickle_protocol(protocol): |
226 | 235 | s = pickle.dumps(bt1, protocol=protocol) |
227 | 236 | bt2 = pickle.loads(s) |
| 237 | + |
| 238 | + s_pyfunc = pickle.dumps(bt1_pyfunc, protocol=protocol) |
| 239 | + bt2_pyfunc = pickle.loads(s_pyfunc) |
| 240 | + |
228 | 241 | ind2, dist2 = bt2.query(X) |
| 242 | + ind2_pyfunc, dist2_pyfunc = bt2_pyfunc.query(X) |
| 243 | + |
229 | 244 | assert_array_almost_equal(ind1, ind2) |
230 | 245 | assert_array_almost_equal(dist1, dist2) |
231 | 246 |
|
| 247 | + assert_array_almost_equal(ind1_pyfunc, ind2_pyfunc) |
| 248 | + assert_array_almost_equal(dist1_pyfunc, dist2_pyfunc) |
| 249 | + |
232 | 250 | for protocol in (0, 1, 2): |
233 | 251 | yield check_pickle_protocol, protocol |
234 | 252 |
|
|
0 commit comments