
Commit a19d398

MAINT: deprecate return_indices in favor of attribute sample_indices_ (#474)
1 parent 7f93dfc commit a19d398

30 files changed: +337 −176 lines

doc/whats_new/v0.0.4.rst

Lines changed: 5 additions & 0 deletions
@@ -147,3 +147,8 @@ Deprecation
 
 - Deprecate :class:`imblearn.ensemble.BalanceCascade`.
   :issue:`472` by :user:`Guillaume Lemaitre <glemaitre>`.
+
+- Deprecate ``return_indices`` in all samplers. Instead, an attribute
+  ``sample_indices_`` is created whenever the sampler is selecting a subset of
+  the original samples.
+  :issue:`474` by :user:`Guillaume Lemaitre <glemaitre>`.
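In user code, the migration this entry describes looks roughly as follows. This is a minimal sketch, not part of the commit, assuming a version of imbalanced-learn that includes this change; the toy data is illustrative only.

# Before (deprecated): ask the sampler to hand back the indices explicitly.
# X_res, y_res, idx = RandomUnderSampler(return_indices=True).fit_resample(X, y)
import numpy as np
from imblearn.under_sampling import RandomUnderSampler

X = np.array([[1.], [2.], [3.], [4.], [5.], [6.]])
y = np.array([0, 0, 0, 0, 1, 1])

# After: fit_resample returns only (X_res, y_res); the indices of the rows
# kept from the original X are stored on the fitted sampler.
rus = RandomUnderSampler(random_state=0)
X_res, y_res = rus.fit_resample(X, y)
idx = rus.sample_indices_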

imblearn/ensemble/_easy_ensemble.py

Lines changed: 2 additions & 3 deletions
@@ -121,14 +121,13 @@ def _fit_resample(self, X, y):
         for _ in range(self.n_subsets):
             rus = RandomUnderSampler(
                 sampling_strategy=self.sampling_strategy_,
-                return_indices=True,
                 random_state=random_state.randint(MAX_INT),
                 replacement=self.replacement)
-            sel_x, sel_y, sel_idx = rus.fit_resample(X, y)
+            sel_x, sel_y = rus.fit_resample(X, y)
             X_resampled.append(sel_x)
             y_resampled.append(sel_y)
             if self.return_indices:
-                idx_under.append(sel_idx)
+                idx_under.append(rus.sample_indices_)
 
         if self.return_indices:
             return (np.array(X_resampled), np.array(y_resampled),

imblearn/ensemble/_forest.py

Lines changed: 3 additions & 7 deletions
@@ -36,9 +36,9 @@ def _local_parallel_build_trees(sampler, tree, forest, X, y, sample_weight,
                                 tree_idx, n_trees, verbose=0,
                                 class_weight=None):
     # resample before to fit the tree
-    X_resampled, y_resampled, selected_idx = sampler.fit_sample(X, y)
+    X_resampled, y_resampled = sampler.fit_sample(X, y)
     if sample_weight is not None:
-        sample_weight = safe_indexing(sample_weight, selected_idx)
+        sample_weight = safe_indexing(sample_weight, sampler.sample_indices_)
     tree = _parallel_build_trees(tree, forest, X_resampled, y_resampled,
                                  sample_weight, tree_idx, n_trees,
                                  verbose=verbose, class_weight=class_weight)
@@ -306,8 +306,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
 
         self.base_sampler_ = RandomUnderSampler(
             sampling_strategy=self.sampling_strategy,
-            replacement=self.replacement,
-            return_indices=True)
+            replacement=self.replacement)
 
     def _make_sampler_estimator(self, random_state=None):
         """Make and configure a copy of the `base_estimator_` attribute.
@@ -450,9 +449,6 @@ def fit(self, X, y, sample_weight=None):
         # Create pipeline with the fitted samplers and trees
         self.pipelines_.extend([make_pipeline(deepcopy(s), deepcopy(t))
                                 for s, t in zip(samplers, trees)])
-        for idx in range(len(self.pipelines_)):
-            self.pipelines_[idx].named_steps[
-                'randomundersampler'].set_params(return_indices=False)
 
         if self.oob_score:
             self._set_oob_score(X, y)
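The pattern used in _local_parallel_build_trees above, aligning a per-sample array such as sample_weight with the resampled data through sample_indices_, can be sketched in isolation. This is a hedged illustration rather than the commit's code: the data, the weights, and the safe_indexing import from scikit-learn (public in the versions targeted here) are assumptions.

import numpy as np
from sklearn.utils import safe_indexing
from imblearn.under_sampling import RandomUnderSampler

X = np.arange(20, dtype=float).reshape(10, 2)
y = np.array([0] * 8 + [1] * 2)
sample_weight = np.linspace(0.1, 1.0, num=10)

sampler = RandomUnderSampler(random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
# keep the weights of exactly the rows that survived the under-sampling
sample_weight_res = safe_indexing(sample_weight, sampler.sample_indices_)
assert len(sample_weight_res) == len(y_res)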

imblearn/ensemble/_weight_boosting.py

Lines changed: 7 additions & 9 deletions
@@ -169,8 +169,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
 
         self.base_sampler_ = RandomUnderSampler(
             sampling_strategy=self.sampling_strategy,
-            replacement=self.replacement,
-            return_indices=True)
+            replacement=self.replacement)
 
     def _make_sampler_estimator(self, append=True, random_state=None):
         """Make and configure a copy of the `base_estimator_` attribute.
@@ -191,9 +190,6 @@ def _make_sampler_estimator(self, append=True, random_state=None):
             self.samplers_.append(sampler)
             self.pipelines_.append(make_pipeline(deepcopy(sampler),
                                                  deepcopy(estimator)))
-            # do not return the indices within a pipeline
-            self.pipelines_[-1].named_steps['randomundersampler'].set_params(
-                return_indices=False)
 
         return estimator, sampler
 
@@ -202,8 +198,9 @@ def _boost_real(self, iboost, X, y, sample_weight, random_state):
         estimator, sampler = self._make_sampler_estimator(
             random_state=random_state)
 
-        X_res, y_res, idx_res = sampler.fit_resample(X, y)
-        sample_weight_res = safe_indexing(sample_weight, idx_res)
+        X_res, y_res = sampler.fit_resample(X, y)
+        sample_weight_res = safe_indexing(sample_weight,
+                                          sampler.sample_indices_)
         estimator.fit(X_res, y_res, sample_weight=sample_weight_res)
 
         y_predict_proba = estimator.predict_proba(X)
@@ -263,8 +260,9 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
         estimator, sampler = self._make_sampler_estimator(
             random_state=random_state)
 
-        X_res, y_res, idx_res = sampler.fit_resample(X, y)
-        sample_weight_res = safe_indexing(sample_weight, idx_res)
+        X_res, y_res = sampler.fit_resample(X, y)
+        sample_weight_res = safe_indexing(sample_weight,
+                                          sampler.sample_indices_)
         estimator.fit(X_res, y_res, sample_weight=sample_weight_res)
 
         y_predict = estimator.predict(X)

imblearn/ensemble/tests/test_forest.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def test_balanced_random_forest_attributes(imbalanced_dataset):
     brf.fit(X, y)
 
     for idx in range(n_estimators):
-        X_res, y_res, _ = brf.samplers_[idx].fit_resample(X, y)
+        X_res, y_res = brf.samplers_[idx].fit_resample(X, y)
         X_res_2, y_res_2 = brf.pipelines_[idx].named_steps[
             'randomundersampler'].fit_resample(X, y)
         assert_allclose(X_res, X_res_2)

imblearn/keras/_generator.py

Lines changed: 10 additions & 12 deletions
@@ -35,7 +35,7 @@ class BalancedBatchGenerator(ParentClass):
     Create a keras ``Sequence`` which is given to ``fit_generator``. The
     sampler defines the sampling strategy used to balance the dataset ahead of
     creating the batch. The sampler should have an attribute
-    ``return_indices``.
+    ``sample_indices_``.
 
     Parameters
     ----------
@@ -49,7 +49,7 @@ class BalancedBatchGenerator(ParentClass):
         Sample weight.
 
     sampler : object or None, optional (default=RandomUnderSampler)
-        A sampler instance which has an attribute ``return_indices``.
+        A sampler instance which has an attribute ``sample_indices_``.
         By default, the sampler used is a
        :class:`imblearn.under_sampling.RandomUnderSampler`.
 
@@ -118,20 +118,18 @@ def __init__(self, X, y, sample_weight=None, sampler=None, batch_size=32,
     def _sample(self):
         random_state = check_random_state(self.random_state)
         if self.sampler is None:
-            self.sampler_ = RandomUnderSampler(return_indices=True,
-                                               random_state=random_state)
+            self.sampler_ = RandomUnderSampler(random_state=random_state)
         else:
-            if not hasattr(self.sampler, 'return_indices'):
-                raise ValueError("'sampler' needs to return the indices of "
-                                 "the samples selected. Provide a sampler "
-                                 "which has an attribute 'return_indices'.")
             self.sampler_ = clone(self.sampler)
-            self.sampler_.set_params(return_indices=True)
             # FIXME: Remove in 0.6
             if self.sampler_.__class__.__name__ not in DONT_HAVE_RANDOM_STATE:
                 set_random_state(self.sampler_, random_state)
 
-        _, _, self.indices_ = self.sampler_.fit_resample(self.X, self.y)
+        self.sampler_.fit_resample(self.X, self.y)
+        if not hasattr(self.sampler_, 'sample_indices_'):
+            raise ValueError("'sampler' needs to have an attribute "
+                             "'sample_indices_'.")
+        self.indices_ = self.sampler_.sample_indices_
         # shuffle the indices since the sampler are packing them by class
         random_state.shuffle(self.indices_)
 
@@ -168,7 +166,7 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
     Returns a generator --- as well as the number of step per epoch --- which
     is given to ``fit_generator``. The sampler defines the sampling strategy
     used to balance the dataset ahead of creating the batch. The sampler should
-    have an attribute ``return_indices``.
+    have an attribute ``sample_indices_``.
 
     Parameters
     ----------
@@ -182,7 +180,7 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
         Sample weight.
 
    sampler : object or None, optional (default=RandomUnderSampler)
-        A sampler instance which has an attribute ``return_indices``.
+        A sampler instance which has an attribute ``sample_indices_``.
         By default, the sampler used is a
        :class:`imblearn.under_sampling.RandomUnderSampler`.
 
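With the new check above, any sampler that exposes sample_indices_ after fitting (e.g. NearMiss) is accepted, while samplers that generate new samples rather than select existing ones (e.g. ClusterCentroids, as in the tests below) are rejected. A rough usage sketch, assuming keras is installed; the data is illustrative and the model/training code is omitted.

import numpy as np
from imblearn.keras import BalancedBatchGenerator
from imblearn.under_sampling import NearMiss

X = np.random.RandomState(0).uniform(size=(100, 5))
y = np.array([0] * 90 + [1] * 10)

# each epoch iterates over a balanced subset selected by the sampler
training_generator = BalancedBatchGenerator(
    X, y, sampler=NearMiss(), batch_size=10, random_state=42)
# model.fit_generator(generator=training_generator, epochs=10)  # with a keras model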

imblearn/keras/tests/test_generator.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ def _build_keras_model(n_classes, n_features):
 
 
 def test_balanced_batch_generator_class_no_return_indices(data):
-    with pytest.raises(ValueError, match='needs to return the indices'):
+    with pytest.raises(ValueError, match='needs to have an attribute'):
         BalancedBatchGenerator(*data, sampler=ClusterCentroids(), batch_size=10)
 
 
@@ -75,7 +75,7 @@ def test_balanced_batch_generator_class_sparse(data, keep_sparse):
 
 
 def test_balanced_batch_generator_function_no_return_indices(data):
-    with pytest.raises(ValueError, match='needs to return the indices'):
+    with pytest.raises(ValueError, match='needs to have an attribute'):
         balanced_batch_generator(
             *data, sampler=ClusterCentroids(), batch_size=10, random_state=42)
 

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 22 additions & 5 deletions
@@ -13,6 +13,7 @@
 from .base import BaseOverSampler
 from ..utils import check_target_type
 from ..utils import Substitution
+from ..utils.deprecation import deprecate_parameter
 from ..utils._docstring import _random_state_docstring
 
 
@@ -37,11 +38,23 @@ class RandomOverSampler(BaseOverSampler):
         Whether or not to return the indices of the samples randomly selected
         in the corresponding classes.
 
+        .. deprecated:: 0.4
+           ``return_indices`` is deprecated. Use the attribute
+           ``sample_indices_`` instead.
+
     ratio : str, dict, or callable
         .. deprecated:: 0.4
           Use the parameter ``sampling_strategy`` instead. It will be removed
          in 0.6.
 
+    Attributes
+    ----------
+    sample_indices_ : ndarray, shape (n_new_samples)
+        Indices of the samples selected.
+
+        .. versionadded:: 0.4
+           ``sample_indices_`` used instead of ``return_indices=True``.
+
     Notes
     -----
     Supports multi-class resampling by sampling each class independently.
@@ -83,6 +96,10 @@ def _check_X_y(X, y):
         return X, y, binarize_y
 
     def _fit_resample(self, X, y):
+        if self.return_indices:
+            deprecate_parameter(self, '0.4', 'return_indices',
+                                'sample_indices_')
+
         random_state = check_random_state(self.random_state)
         target_stats = Counter(y)
 
@@ -95,10 +112,10 @@ def _fit_resample(self, X, y):
 
             sample_indices = np.append(sample_indices,
                                        target_class_indices[indices])
+        self.sample_indices_ = np.array(sample_indices)
 
         if self.return_indices:
-            return (safe_indexing(X, sample_indices), safe_indexing(
-                y, sample_indices), sample_indices)
-        else:
-            return (safe_indexing(X, sample_indices), safe_indexing(
-                y, sample_indices))
+            return (safe_indexing(X, sample_indices),
+                    safe_indexing(y, sample_indices), sample_indices)
+            return (safe_indexing(X, sample_indices),
+                safe_indexing(y, sample_indices))
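A sketch of the deprecation path added above, not part of the commit itself: passing return_indices=True still returns the old 3-tuple but now emits a deprecation warning (via deprecate_parameter) when fit_resample is called; reading sample_indices_ is the forward-compatible route. The toy data below is an assumption for illustration.

import warnings
import numpy as np
from imblearn.over_sampling import RandomOverSampler

X = np.array([[1.], [2.], [3.], [4.], [5.]])
y = np.array([0, 0, 0, 0, 1])

with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")
    ros = RandomOverSampler(return_indices=True, random_state=0)
    X_res, y_res, idx = ros.fit_resample(X, y)  # old 3-tuple still returned
assert any("return_indices" in str(r.message) for r in records)

# Equivalent without the deprecated flag:
ros = RandomOverSampler(random_state=0)
X_res, y_res = ros.fit_resample(X, y)
idx = ros.sample_indices_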

imblearn/over_sampling/tests/test_random_over_sampler.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 from collections import Counter
 
+import pytest
 import numpy as np
 from sklearn.utils.testing import assert_allclose
 from sklearn.utils.testing import assert_array_equal
@@ -59,6 +60,7 @@ def test_ros_fit_resample_half():
     assert_array_equal(y_resampled, y_gt)
 
 
+@pytest.mark.filterwarnings("ignore:'return_indices' is deprecated from 0.4")
 def test_random_over_sampling_return_indices():
     ros = RandomOverSampler(return_indices=True, random_state=RND_SEED)
     X_resampled, y_resampled, sample_indices = ros.fit_resample(X, Y)

imblearn/tensorflow/_generator.py

Lines changed: 8 additions & 11 deletions
@@ -27,7 +27,7 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
     Returns a generator --- as well as the number of step per epoch --- which
     is given to ``fit_generator``. The sampler defines the sampling strategy
     used to balance the dataset ahead of creating the batch. The sampler should
-    have an attribute ``return_indices``.
+    have an attribute ``sample_indices_``.
 
     Parameters
     ----------
@@ -41,7 +41,7 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
         Sample weight.
 
     sampler : object or None, optional (default=RandomUnderSampler)
-        A sampler instance which has an attribute ``return_indices``.
+        A sampler instance which has an attribute ``sample_indices_``.
         By default, the sampler used is a
        :class:`imblearn.under_sampling.RandomUnderSampler`.
 
@@ -122,20 +122,17 @@ def balanced_batch_generator(X, y, sample_weight=None, sampler=None,
 
     random_state = check_random_state(random_state)
     if sampler is None:
-        sampler_ = RandomUnderSampler(return_indices=True,
-                                      random_state=random_state)
+        sampler_ = RandomUnderSampler(random_state=random_state)
     else:
-        if not hasattr(sampler, 'return_indices'):
-            raise ValueError("'sampler' needs to return the indices of "
-                             "the samples selected. Provide a sampler "
-                             "which has an attribute 'return_indices'.")
         sampler_ = clone(sampler)
-        sampler_.set_params(return_indices=True)
         # FIXME: Remove in 0.6
         if sampler_.__class__.__name__ not in DONT_HAVE_RANDOM_STATE:
             set_random_state(sampler_, random_state)
-
-    _, _, indices = sampler_.fit_resample(X, y)
+    sampler_.fit_resample(X, y)
+    if not hasattr(sampler_, 'sample_indices_'):
+        raise ValueError("'sampler' needs to have an attribute "
+                         "'sample_indices_'.")
+    indices = sampler_.sample_indices_
     # shuffle the indices since the sampler are packing them by class
     random_state.shuffle(indices)
 
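For the function-based generator touched above, a rough usage sketch: it returns a (generator, steps_per_epoch) pair and, with this change, accepts any sampler that exposes sample_indices_ after fitting. The data below is an illustrative assumption, and the TensorFlow training loop is omitted.

import numpy as np
from imblearn.tensorflow import balanced_batch_generator
from imblearn.under_sampling import RandomUnderSampler

X = np.random.RandomState(42).uniform(size=(100, 5)).astype(np.float32)
y = np.array([0] * 90 + [1] * 10)

training_generator, steps_per_epoch = balanced_batch_generator(
    X, y, sampler=RandomUnderSampler(), batch_size=10, random_state=42)

for _ in range(steps_per_epoch):
    X_batch, y_batch = next(training_generator)  # each batch is class-balanced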

imblearn/tests/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def _generate_checks_per_estimator(check_generator, estimators):
 @pytest.mark.filterwarnings('ignore:"out_step" is deprecated in 0.4 and')
 @pytest.mark.filterwarnings('ignore:"m_neighbors" is deprecated in 0.4 and')
 @pytest.mark.filterwarnings("ignore:'y' should be of types")
+@pytest.mark.filterwarnings("ignore:'return_indices' is deprecated from 0.4")
 @pytest.mark.parametrize(
     'name, Estimator, check',
     _generate_checks_per_estimator(_yield_all_checks,

imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py

Lines changed: 20 additions & 3 deletions
@@ -19,6 +19,7 @@
 
 from ..base import BaseCleaningSampler
 from ...utils import Substitution
+from ...utils.deprecation import deprecate_parameter
 from ...utils._docstring import _random_state_docstring
 
 
@@ -37,7 +38,11 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
 
     return_indices : bool, optional (default=False)
         Whether or not to return the indices of the samples randomly
-        selected from the majority class.
+        selected.
+
+        .. deprecated:: 0.4
+           ``return_indices`` is deprecated. Use the attribute
+           ``sample_indices_`` instead.
 
     {random_state}
 
@@ -59,6 +64,14 @@ class CondensedNearestNeighbour(BaseCleaningSampler):
           Use the parameter ``sampling_strategy`` instead. It will be removed
          in 0.6.
 
+    Attributes
+    ----------
+    sample_indices_ : ndarray, shape (n_new_samples)
+        Indices of the samples selected.
+
+        .. versionadded:: 0.4
+           ``sample_indices_`` used instead of ``return_indices=True``.
+
     Notes
     -----
     The method is based on [1]_.
@@ -126,6 +139,9 @@ def _validate_estimator(self):
                 ' Got {} instead.'.format(type(self.n_neighbors)))
 
     def _fit_resample(self, X, y):
+        if self.return_indices:
+            deprecate_parameter(self, '0.4', 'return_indices',
+                                'sample_indices_')
        self._validate_estimator()
 
        random_state = check_random_state(self.random_state)
@@ -198,8 +214,9 @@ def _fit_resample(self, X, y):
             idx_under = np.concatenate(
                 (idx_under, np.flatnonzero(y == target_class)), axis=0)
 
+        self.sample_indices_ = idx_under
+
         if self.return_indices:
             return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
                     idx_under)
-        else:
-            return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
+        return safe_indexing(X, idx_under), safe_indexing(y, idx_under)
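A short sketch of what the new sample_indices_ attribute enables for a cleaning sampler such as CondensedNearestNeighbour: recovering which original rows were kept or discarded without requesting return_indices. The dataset is a toy assumption for illustration.

import numpy as np
from sklearn.datasets import make_classification
from imblearn.under_sampling import CondensedNearestNeighbour

X, y = make_classification(n_samples=200, weights=[0.9, 0.1], random_state=0)

cnn = CondensedNearestNeighbour(random_state=0)
X_res, y_res = cnn.fit_resample(X, y)

kept = cnn.sample_indices_                          # rows retained by CNN
removed = np.setdiff1d(np.arange(X.shape[0]), kept)  # rows cleaned away
print('kept %d samples, removed %d' % (kept.size, removed.size))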
