
Commit 1d2baf8

ahojnnes authored and larsmans committed
ENH: Add dynamic maximum trial determination to RANSACRegressor
1 parent adc24fd commit 1d2baf8

File tree: 2 files changed (+103 -5 lines)

  sklearn/linear_model/ransac.py
  sklearn/linear_model/tests/test_ransac.py

sklearn/linear_model/ransac.py

Lines changed: 61 additions & 3 deletions
@@ -12,6 +12,43 @@
 from .base import LinearRegression


+_EPSILON = np.spacing(1)
+
+
+def _dynamic_max_trials(n_inliers, n_samples, min_samples, probability):
+    """Determine number of trials such that at least one outlier-free subset
+    is sampled for the given inlier/outlier ratio.
+
+    Parameters
+    ----------
+    n_inliers : int
+        Number of inliers in the data.
+
+    n_samples : int
+        Total number of samples in the data.
+
+    min_samples : int
+        Minimum number of samples chosen randomly from original data.
+
+    probability : float
+        Probability (confidence) that one outlier-free sample is generated.
+
+    Returns
+    -------
+    trials : int
+        Number of trials.
+
+    """
+    inlier_ratio = n_inliers / float(n_samples)
+    nom = max(_EPSILON, 1 - probability)
+    denom = max(_EPSILON, 1 - inlier_ratio ** min_samples)
+    if nom == 1:
+        return 0
+    if denom == 1:
+        return float('inf')
+    return abs(float(np.ceil(np.log(nom) / np.log(denom))))
+
+
 class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
     """RANSAC (RANdom SAmple Consensus) algorithm.
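
For a quick sanity check of the helper added above, it can be exercised directly; a minimal sketch, assuming a build that includes this commit (``_dynamic_max_trials`` lives in the private module ``sklearn.linear_model.ransac`` and is not public API)::

    from sklearn.linear_model.ransac import _dynamic_max_trials

    # Trials needed for 99% confidence with 2-sample subsets, as the
    # inlier fraction of 100 samples drops from 90% to 50%
    # (values match the hand-checked cases in the new tests below).
    trials = [_dynamic_max_trials(n_inliers, 100, 2, 0.99)
              for n_inliers in (90, 70, 50)]
    assert trials == [3, 7, 17]
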
@@ -44,7 +81,8 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
         Minimum number of samples chosen randomly from original data. Treated
         as an absolute number of samples for `min_samples >= 1`, treated as a
         relative number `ceil(min_samples * X.shape[0]`) for
-        `min_samples < 1`. By default a
+        `min_samples < 1`. This is typically chosen as the minimal number of
+        samples necessary to estimate the given `base_estimator`. By default a
         ``sklearn.linear_model.LinearRegression()`` estimator is assumed and
         `min_samples` is chosen as ``X.shape[1] + 1``.

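
To make the absolute-versus-relative rule above concrete, here is a small illustrative helper that mirrors the docstring (``resolve_min_samples`` is a hypothetical name, not part of this commit or of the sklearn API)::

    import numpy as np

    def resolve_min_samples(min_samples, n_samples):
        # min_samples >= 1: absolute subset size; otherwise a fraction
        # of the data, rounded up, as described in the docstring.
        if min_samples >= 1:
            return int(min_samples)
        return int(np.ceil(min_samples * n_samples))

    assert resolve_min_samples(2, 100) == 2      # absolute
    assert resolve_min_samples(0.25, 100) == 25  # ceil(0.25 * 100)
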
@@ -75,6 +113,17 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
     stop_score : float, optional
         Stop iteration if score is greater equal than this threshold.

+    stop_probability : float in range [0, 1], optional
+        RANSAC iteration stops if at least one outlier-free set of the
+        training data is sampled in RANSAC. This requires generating at
+        least N samples (iterations)::
+
+            N >= log(1 - probability) / log(1 - e**m)
+
+        where the probability (confidence) is typically set to a high
+        value such as 0.99 (the default) and e is the current fraction of
+        inliers w.r.t. the total number of samples.
+
     residual_metric : callable, optional
         Metric to reduce the dimensionality of the residuals to 1 for
         multi-dimensional target values ``y.shape[1] > 1``. By default the sum
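
A worked instance of the bound above, using the docstring's convention that ``e`` is the inlier fraction and ``m`` is ``min_samples`` (the numbers are illustrative and match one of the tests added below)::

    import numpy as np

    p, e, m = 0.99, 0.5, 2      # 99% confidence, 50% inliers, 2-point subsets
    N = np.log(1 - p) / np.log(1 - e ** m)
    assert np.ceil(N) == 17     # same as _dynamic_max_trials(50, 100, 2, 0.99)
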
@@ -110,7 +159,8 @@ def __init__(self, base_estimator=None, min_samples=None,
                  residual_threshold=None, is_data_valid=None,
                  is_model_valid=None, max_trials=100,
                  stop_n_inliers=np.inf, stop_score=np.inf,
-                 residual_metric=None, random_state=None):
+                 stop_probability=0.99, residual_metric=None,
+                 random_state=None):

         self.base_estimator = base_estimator
         self.min_samples = min_samples
@@ -120,6 +170,7 @@ def __init__(self, base_estimator=None, min_samples=None,
         self.max_trials = max_trials
         self.stop_n_inliers = stop_n_inliers
         self.stop_score = stop_score
+        self.stop_probability = stop_probability
         self.residual_metric = residual_metric
         self.random_state = random_state

@@ -164,6 +215,9 @@ def fit(self, X, y):
             raise ValueError("`min_samples` may not be larger than number "
                              "of samples ``X.shape[0]``.")

+        if self.stop_probability < 0 or self.stop_probability > 1:
+            raise ValueError("`stop_probability` must be in range [0, 1].")
+
         if self.residual_threshold is None:
             # MAD (median absolute deviation)
             residual_threshold = np.median(np.abs(y - np.median(y)))
@@ -258,7 +312,11 @@ def fit(self, X, y):

             # break if sufficient number of inliers or score is reached
             if (n_inliers_best >= self.stop_n_inliers
-                    or score_best >= self.stop_score):
+                    or score_best >= self.stop_score
+                    or self.n_trials_
+                       >= _dynamic_max_trials(n_inliers_best, n_samples,
+                                              min_samples,
+                                              self.stop_probability)):
                 break

         # if none of the iterations met the required criteria
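
An end-to-end sketch of the new stopping behaviour on synthetic data (the data and parameter values here are illustrative only, not taken from this commit); ``max_trials`` remains the hard upper bound, while ``stop_probability`` lets the loop break early once enough inliers have been found::

    import numpy as np
    from sklearn.linear_model import RANSACRegressor

    rng = np.random.RandomState(0)

    # A noisy line plus a block of gross outliers.
    X = np.arange(100, dtype=float)[:, np.newaxis]
    y = 3. * X.ravel() + rng.normal(scale=0.5, size=100)
    X = np.vstack([X, rng.uniform(0, 100, size=(25, 1))])
    y = np.hstack([y, rng.uniform(-500, 500, size=25)])

    ransac = RANSACRegressor(min_samples=2, stop_probability=0.99,
                             max_trials=100, random_state=0)
    ransac.fit(X, y)

    print(ransac.n_trials_)            # iterations actually run
    print(ransac.inlier_mask_.sum())   # samples flagged as inliers
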

sklearn/linear_model/tests/test_ransac.py

Lines changed: 42 additions & 2 deletions
@@ -4,6 +4,7 @@

 from sklearn.utils.testing import assert_less
 from sklearn.linear_model import LinearRegression, RANSACRegressor
+from sklearn.linear_model.ransac import _dynamic_max_trials


 # Generate coordinates of line
@@ -84,7 +85,7 @@ def test_ransac_max_trials():
                                        random_state=0)
     assert getattr(ransac_estimator, 'n_trials_', None) is None
     ransac_estimator.fit(X, y)
-    assert_equal(ransac_estimator.n_trials_, 11)
+    assert_equal(ransac_estimator.n_trials_, 2)


 def test_ransac_stop_n_inliers():
@@ -277,7 +278,6 @@ def test_ransac_residual_metric():


 def test_ransac_default_residual_threshold():
-
     base_estimator = LinearRegression()
     ransac_estimator = RANSACRegressor(base_estimator, min_samples=2,
                                        random_state=0)
@@ -293,5 +293,45 @@ def test_ransac_default_residual_threshold():
     assert_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)


+def test_ransac_dynamic_max_trials():
+    # Numbers hand-calculated and confirmed on page 119 (Table 4.3) in
+    #   Hartley, R.~I. and Zisserman, A., 2004,
+    #   Multiple View Geometry in Computer Vision, Second Edition,
+    #   Cambridge University Press, ISBN: 0521540518
+
+    # e = 0%, min_samples = X
+    assert_equal(_dynamic_max_trials(100, 100, 2, 0.99), 1)
+
+    # e = 5%, min_samples = 2
+    assert_equal(_dynamic_max_trials(95, 100, 2, 0.99), 2)
+    # e = 10%, min_samples = 2
+    assert_equal(_dynamic_max_trials(90, 100, 2, 0.99), 3)
+    # e = 30%, min_samples = 2
+    assert_equal(_dynamic_max_trials(70, 100, 2, 0.99), 7)
+    # e = 50%, min_samples = 2
+    assert_equal(_dynamic_max_trials(50, 100, 2, 0.99), 17)
+
+    # e = 5%, min_samples = 8
+    assert_equal(_dynamic_max_trials(95, 100, 8, 0.99), 5)
+    # e = 10%, min_samples = 8
+    assert_equal(_dynamic_max_trials(90, 100, 8, 0.99), 9)
+    # e = 30%, min_samples = 8
+    assert_equal(_dynamic_max_trials(70, 100, 8, 0.99), 78)
+    # e = 50%, min_samples = 8
+    assert_equal(_dynamic_max_trials(50, 100, 8, 0.99), 1177)
+
+    # e = 0%, min_samples = 10
+    assert_equal(_dynamic_max_trials(1, 100, 10, 0), 0)
+    assert_equal(_dynamic_max_trials(1, 100, 10, 1), float('inf'))
+
+    base_estimator = LinearRegression()
+    ransac_estimator = RANSACRegressor(base_estimator, min_samples=2,
+                                       stop_probability=-0.1)
+    assert_raises(ValueError, ransac_estimator.fit, X, y)
+    ransac_estimator = RANSACRegressor(base_estimator, min_samples=2,
+                                       stop_probability=1.1)
+    assert_raises(ValueError, ransac_estimator.fit, X, y)
+
+
 if __name__ == "__main__":
     np.testing.run_module_suite()
