Skip to content

Commit 1150d28

Browse files
committed
Merge branch 'master' of github.com:scikit-learn/scikit-learn
2 parents 632364c + 3ef468a commit 1150d28

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

scikits/learn/neighbors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,17 @@ class from an array representing our data set and ask who's
120120
>>> neigh = NeighborsClassifier(n_neighbors=1)
121121
>>> neigh.fit(samples, labels)
122122
NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
123-
>>> print neigh.kneighbors([1., 1., 1.])
124-
(array([[ 0.5]]), array([[2]]))
123+
>>> print neigh.kneighbors([1., 1., 1.]) # doctest: +ELLIPSIS
124+
(array([[ 0.5]]), array([[2]]...))
125125
126126
As you can see, it returns [[0.5]], and [[2]], which means that the
127127
element is at distance 0.5 and is the third element of samples
128128
(indexes start at 0). You can also query for multiple points:
129129
130130
>>> X = [[0., 1., 0.], [1., 0., 1.]]
131-
>>> neigh.kneighbors(X, return_distance=False)
131+
>>> neigh.kneighbors(X, return_distance=False) # doctest: +ELLIPSIS
132132
array([[1],
133-
[2]])
133+
[2]]...)
134134
135135
"""
136136
self._set_params(**params)

scikits/learn/src/ball_tree.pyx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
""" Cython bindings for the C++ BallTree code.
2+
"""
3+
# Author: Thouis Jones
4+
# License: BSD
5+
16
from libcpp.vector cimport vector
27

38
import numpy as np
@@ -12,25 +17,30 @@ cdef extern from "BallTreePoint.h":
1217

1318
ctypedef Point *Point_p
1419

20+
1521
cdef extern from "BallTree.h":
1622
cdef cppclass cBallTree "BallTree<Point>":
1723
cBallTree(vector[Point_p] *, int)
1824
double query(Point *, vector[long int] &) except +
1925
double Euclidean_Dist(Point *, Point *) except +
2026

27+
2128
cdef Point *make_point(vals):
2229
pt = new Point(vals.size)
2330
for idx, v in enumerate(vals.flat):
2431
SET(pt, idx, v)
2532
return pt
2633

34+
35+
################################################################################
2736
# Cython wrapper
2837
cdef class BallTree:
2938
cdef cBallTree *bt_ptr
3039
cdef vector[Point_p] *ptdata
3140
cdef int num_points
3241
cdef int num_dims
3342
cdef public object data
43+
3444
def __cinit__(self, arr, leafsize=20):
3545
# copy points into ptdata
3646
num_points, num_dims = self.num_points, self.num_dims = arr.shape
@@ -39,6 +49,7 @@ cdef class BallTree:
3949
self.ptdata.push_back(make_point(arr[i, :]))
4050
self.bt_ptr = new cBallTree(self.ptdata, leafsize)
4151
self.data = arr.copy()
52+
4253
def __dealloc__(self):
4354
cdef Point *temp
4455
for idx in range(self.ptdata.size()):
@@ -47,6 +58,7 @@ cdef class BallTree:
4758
del temp
4859
del self.ptdata
4960
del self.bt_ptr
61+
5062
def query(self, x, k=1, return_distance=True):
5163
x = np.atleast_2d(x)
5264
assert x.shape[-1] == self.num_dims

0 commit comments

Comments
 (0)