1212from .base import LinearRegression
1313
1414
# Distance between 1.0 and the next representable float; used as a strictly
# positive floor in `_dynamic_max_trials` so `np.log` never receives 0.
15+ _EPSILON = np .spacing (1 )
16+
17+
def _dynamic_max_trials(n_inliers, n_samples, min_samples, probability):
    """Estimate the number of trials needed to draw one outlier-free subset.

    The estimate is::

        ceil(log(1 - probability) / log(1 - w ** min_samples))

    with ``w = n_inliers / n_samples``. Both logarithm arguments are clamped
    from below by ``_EPSILON`` so they stay strictly positive.

    Parameters
    ----------
    n_inliers : int
        Number of inliers in the data.

    n_samples : int
        Total number of samples in the data.

    min_samples : int
        Minimum number of samples chosen randomly from original data.

    probability : float
        Probability (confidence) that one outlier-free sample is generated.

    Returns
    -------
    trials : int or float
        ``0`` if ``1 - probability`` evaluates to 1, ``float('inf')`` if
        ``1 - w ** min_samples`` evaluates to 1, otherwise the non-negative
        trial count as a float.

    """
    w = n_inliers / float(n_samples)
    p_miss = max(_EPSILON, 1 - probability)
    p_tainted = max(_EPSILON, 1 - w ** min_samples)

    if p_miss == 1:
        # No confidence is requested, so no trial is required.
        return 0
    if p_tainted == 1:
        # An all-inlier subset is (numerically) impossible to draw.
        return float('inf')

    ratio = np.log(p_miss) / np.log(p_tainted)
    return abs(float(np.ceil(ratio)))
50+
51+
1552class RANSACRegressor (BaseEstimator , MetaEstimatorMixin , RegressorMixin ):
1653 """RANSAC (RANdom SAmple Consensus) algorithm.
1754
@@ -44,7 +81,8 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
4481 Minimum number of samples chosen randomly from original data. Treated
4582 as an absolute number of samples for `min_samples >= 1`, treated as a
4683 relative number `ceil(min_samples * X.shape[0])` for
47- `min_samples < 1`. By default a
84+ `min_samples < 1`. This is typically chosen as the minimal number of
85+ samples necessary to estimate the given `base_estimator`. By default a
4886 ``sklearn.linear_model.LinearRegression()`` estimator is assumed and
4987 `min_samples` is chosen as ``X.shape[1] + 1``.
5088
@@ -75,6 +113,17 @@ class RANSACRegressor(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
75113 stop_score : float, optional
76114 Stop iteration if score is greater than or equal to this threshold.
77115
116+ stop_probability : float in range [0, 1], optional
117+ RANSAC iteration stops if at least one outlier-free set of the training
118+ data is sampled in RANSAC. This requires generating at least N
119+ samples (iterations)::
120+
121+ N >= log(1 - probability) / log(1 - e**m)
122+
123+ where the probability (confidence) is typically set to a high value such
124+ as 0.99 (the default), e is the current fraction of inliers w.r.t.
125+ the total number of samples, and m is `min_samples`.
126+
78127 residual_metric : callable, optional
79128 Metric to reduce the dimensionality of the residuals to 1 for
80129 multi-dimensional target values ``y.shape[1] > 1``. By default the sum
@@ -110,7 +159,8 @@ def __init__(self, base_estimator=None, min_samples=None,
110159 residual_threshold = None , is_data_valid = None ,
111160 is_model_valid = None , max_trials = 100 ,
112161 stop_n_inliers = np .inf , stop_score = np .inf ,
113- residual_metric = None , random_state = None ):
162+ stop_probability = 0.99 , residual_metric = None ,
163+ random_state = None ):
114164
115165 self .base_estimator = base_estimator
116166 self .min_samples = min_samples
@@ -120,6 +170,7 @@ def __init__(self, base_estimator=None, min_samples=None,
120170 self .max_trials = max_trials
121171 self .stop_n_inliers = stop_n_inliers
122172 self .stop_score = stop_score
173+ self .stop_probability = stop_probability
123174 self .residual_metric = residual_metric
124175 self .random_state = random_state
125176
@@ -164,6 +215,9 @@ def fit(self, X, y):
164215 raise ValueError ("`min_samples` may not be larger than number "
165216 "of samples ``X.shape[0]``." )
166217
218+ if self .stop_probability < 0 or self .stop_probability > 1 :
219+ raise ValueError ("`stop_probability` must be in range [0, 1]." )
220+
167221 if self .residual_threshold is None :
168222 # MAD (median absolute deviation)
169223 residual_threshold = np .median (np .abs (y - np .median (y )))
@@ -258,7 +312,11 @@ def fit(self, X, y):
258312
259313 # break if sufficient number of inliers or score is reached
260314 if (n_inliers_best >= self .stop_n_inliers
261- or score_best >= self .stop_score ):
315+ or score_best >= self .stop_score
316+ or self .n_trials_
317+ >= _dynamic_max_trials (n_inliers_best , n_samples ,
318+ min_samples ,
319+ self .stop_probability )):
262320 break
263321
264322 # if none of the iterations met the required criteria
0 commit comments