Skip to content

Commit 5bdc5a4

Browse files
thouisFabian Pedregosa
authored andcommitted
Renamed for backwards compatibility, fixed C++ Exceptions to propagate to python
1 parent fa3f402 commit 5bdc5a4

File tree

4 files changed

+111
-110
lines changed

4 files changed

+111
-110
lines changed

scikits/learn/neighbors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
from .base import BaseEstimator, ClassifierMixin, RegressorMixin
11-
from .ball_tree import PyBallTree as BallTree
11+
from .ball_tree import BallTree
1212
from .metrics import euclidean_distances
1313

1414

scikits/learn/src/BallTree.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <algorithm>
99
#include <cstdlib>
1010
#include <sstream>
11+
#include <exception>
1112

1213
/************************************************************
1314
* templated Ball Tree class
@@ -42,13 +43,13 @@
4243

4344

4445
/* Custom exception to allow Python to catch C++ exceptions */
45-
class BallTreeException {
46-
public:
47-
BallTreeException(const char* str = "There's a problem") : message(str) {}
48-
const char* what() const {return message;}
49-
50-
private:
51-
const char* message;
46+
class BallTreeException : public std::exception {
47+
public:
48+
BallTreeException(std::string msg) { message = msg; }
49+
~BallTreeException() throw() {};
50+
virtual const char* what() const throw() { return message.c_str(); }
51+
private:
52+
std::string message;
5253
};
5354

5455

@@ -65,7 +66,7 @@ template<class P1_Type,class P2_Type>
6566
if(p2.size() != D){
6667
std::stringstream oss;
6768
oss << "Euclidean_Dist : point sizes must match (" << D << " != " << p2.size() << ").\n";
68-
throw BallTreeException(oss.str().c_str());
69+
throw BallTreeException(oss.str());
6970
}
7071
typename P1_Type::value_type dist = 0;
7172
typename P1_Type::value_type diff;
@@ -536,7 +537,7 @@ class BallTree{
536537
if(num_nbrs > (int)(Points_.size()) ){
537538
std::stringstream oss;
538539
oss << "query: k must be less than or equal to N Points (" << num_nbrs << " > " << (int)(Points_.size()) << ")\n";
539-
throw BallTreeException(oss.str().c_str());
540+
throw BallTreeException(oss.str());
540541
}
541542

542543
std::vector<pd_tuple<value_type> > PointSet;

0 commit comments

Comments
 (0)