1+ """ Cython bindings for the C++ BallTree code.
2+ """
3+ # Author: Thouis Jones
4+ # License: BSD
5+
16from libcpp.vector cimport vector
27
38import numpy as np
@@ -12,25 +17,30 @@ cdef extern from "BallTreePoint.h":
1217
1318ctypedef Point * Point_p
1419
20+
1521cdef extern from " BallTree.h" :
1622 cdef cppclass cBallTree " BallTree<Point>" :
1723 cBallTree(vector[Point_p] * , int )
1824 double query(Point * , vector[long int ] & ) except +
1925 double Euclidean_Dist(Point * , Point * ) except +
2026
27+
2128cdef Point * make_point(vals):
2229 pt = new Point(vals.size)
2330 for idx, v in enumerate (vals.flat):
2431 SET(pt, idx, v)
2532 return pt
2633
34+
35+ # ###############################################################################
2736# Cython wrapper
2837cdef class BallTree:
2938 cdef cBallTree * bt_ptr
3039 cdef vector[Point_p] * ptdata
3140 cdef int num_points
3241 cdef int num_dims
3342 cdef public object data
43+
3444 def __cinit__ (self , arr , leafsize = 20 ):
3545 # copy points into ptdata
3646 num_points, num_dims = self .num_points, self .num_dims = arr.shape
@@ -39,6 +49,7 @@ cdef class BallTree:
3949 self .ptdata.push_back(make_point(arr[i, :]))
4050 self .bt_ptr = new cBallTree(self .ptdata, leafsize)
4151 self .data = arr.copy()
52+
4253 def __dealloc__ (self ):
4354 cdef Point * temp
4455 for idx in range (self .ptdata.size()):
@@ -47,6 +58,7 @@ cdef class BallTree:
4758 del temp
4859 del self .ptdata
4960 del self .bt_ptr
61+
5062 def query (self , x , k = 1 , return_distance = True ):
5163 x = np.atleast_2d(x)
5264 assert x.shape[- 1 ] == self .num_dims
0 commit comments