33from numpy .testing import assert_allclose , assert_array_equal
44from pytest import approx
55
6+ from sklearn ._config import config_context
7+ from sklearn .utils ._array_api import (
8+ _convert_to_numpy ,
9+ get_namespace ,
10+ yield_namespace_device_dtype_combinations ,
11+ )
12+ from sklearn .utils ._array_api import device as array_device
13+ from sklearn .utils .estimator_checks import _array_api_for_tests
614from sklearn .utils .fixes import np_version , parse_version
715from sklearn .utils .stats import _averaged_weighted_percentile , _weighted_percentile
816
@@ -39,6 +47,7 @@ def test_averaged_and_weighted_percentile():
3947
4048
4149def test_weighted_percentile ():
50+ """Check `_weighted_percentile` on artificial data with obvious median."""
4251 y = np .empty (102 , dtype = np .float64 )
4352 y [:50 ] = 0
4453 y [- 51 :] = 2
@@ -51,15 +60,16 @@ def test_weighted_percentile():
5160
5261
5362def test_weighted_percentile_equal ():
63+ """Check `_weighted_percentile` with all weights equal to 1."""
5464 y = np .empty (102 , dtype = np .float64 )
5565 y .fill (0.0 )
5666 sw = np .ones (102 , dtype = np .float64 )
57- sw [- 1 ] = 0.0
58- value = _weighted_percentile (y , sw , 50 )
59- assert value == 0
67+ score = _weighted_percentile (y , sw , 50 )
68+ assert approx (score ) == 0
6069
6170
6271def test_weighted_percentile_zero_weight ():
72+ """Check `_weighted_percentile` with all weights equal to 0."""
6373 y = np .empty (102 , dtype = np .float64 )
6474 y .fill (1.0 )
6575 sw = np .ones (102 , dtype = np .float64 )
@@ -69,6 +79,11 @@ def test_weighted_percentile_zero_weight():
6979
7080
7181def test_weighted_percentile_zero_weight_zero_percentile ():
82+ """Check `_weighted_percentile(percentile_rank=0)` behaves correctly.
83+
84+ Ensures that (leading) zero-weight observations are ignored when `percentile_rank=0`.
85+ See #20528 for details.
86+ """
7287 y = np .array ([0 , 1 , 2 , 3 , 4 , 5 ])
7388 sw = np .array ([0 , 0 , 1 , 1 , 1 , 0 ])
7489 value = _weighted_percentile (y , sw , 0 )
@@ -82,18 +97,18 @@ def test_weighted_percentile_zero_weight_zero_percentile():
8297
8398
8499def test_weighted_median_equal_weights ():
85- # Checks that `_weighted_percentile` and `np.median` (both at probability level=0.5
86- # and with `sample_weights` being all 1s) return the same percentiles if the number
87- # of the samples in the data is odd. In this special case, `_weighted_percentile`
88- # always falls on a precise value (not on the next lower value) and is thus equal to
89- # `np.median`.
90- # As discussed in #17370, a similar check with an even number of samples does not
91- # consistently hold, since then the lower of two percentiles might be selected,
92- # while the median might lie in between.
100+ """Check `_weighted_percentile(percentile_rank=50)` is the same as `np.median`.
101+
102+ `sample_weights` are all 1s and the number of samples is odd.
103+ When number of samples is odd, `_weighted_percentile` always falls on a single
104+ observation (not between 2 values, in which case the lower value would be taken)
105+ and is thus equal to `np.median`.
106+ For an even number of samples this check may not hold, as the lower of two values
107+ may be taken (for some other percentile methods it always holds). See #17370.
108+ """
93109 rng = np .random .RandomState (0 )
94110 x = rng .randint (10 , size = 11 )
95111 weights = np .ones (x .shape )
96-
97112 median = np .median (x )
98113 w_median = _weighted_percentile (x , weights )
99114 assert median == approx (w_median )
@@ -106,10 +121,8 @@ def test_weighted_median_integer_weights():
106121 x = rng .randint (20 , size = 10 )
107122 weights = rng .choice (5 , size = 10 )
108123 x_manual = np .repeat (x , weights )
109-
110124 median = np .median (x_manual )
111125 w_median = _weighted_percentile (x , weights )
112-
113126 assert median == approx (w_median )
114127
115128
@@ -125,8 +138,7 @@ def test_weighted_percentile_2d():
125138 w_median = _weighted_percentile (x_2d , w1 )
126139 p_axis_0 = [_weighted_percentile (x_2d [:, i ], w1 ) for i in range (x_2d .shape [1 ])]
127140 assert_allclose (w_median , p_axis_0 )
128-
129- # Check when array and sample_weight boht 2D
141+ # Check when array and sample_weight both 2D
130142 w2 = rng .choice (5 , size = 10 )
131143 w_2d = np .vstack ((w1 , w2 )).T
132144
@@ -137,6 +149,91 @@ def test_weighted_percentile_2d():
137149 assert_allclose (w_median , p_axis_0 )
138150
139151
152+ @pytest .mark .parametrize (
153+ "array_namespace, device, dtype_name" , yield_namespace_device_dtype_combinations ()
154+ )
155+ @pytest .mark .parametrize (
156+ "data, weights, percentile" ,
157+ [
158+ # NumPy scalars input (handled as 0D arrays on array API)
159+ (np .float32 (42 ), np .int32 (1 ), 50 ),
160+ # Random 1D array, constant weights
161+ (lambda rng : rng .rand (50 ), np .ones (50 ).astype (np .int32 ), 50 ),
162+ # Random 2D array and random 1D weights
163+ (lambda rng : rng .rand (50 , 3 ), lambda rng : rng .rand (50 ).astype (np .float32 ), 75 ),
164+ # Random 2D array and random 2D weights
165+ (
166+ lambda rng : rng .rand (20 , 3 ),
167+ lambda rng : rng .rand (20 , 3 ).astype (np .float32 ),
168+ 25 ,
169+ ),
170+ # zero-weights and `percentile_rank=0` (#20528) (`sample_weight` dtype: int64)
171+ (np .array ([0 , 1 , 2 , 3 , 4 , 5 ]), np .array ([0 , 0 , 1 , 1 , 1 , 0 ]), 0 ),
172+ # np.nan's in data and some zero-weights (`sample_weight` dtype: int64)
173+ (np .array ([np .nan , np .nan , 0 , 3 , 4 , 5 ]), np .array ([0 , 1 , 1 , 1 , 1 , 0 ]), 0 ),
174+ # `sample_weight` dtype: int32
175+ (
176+ np .array ([0 , 1 , 2 , 3 , 4 , 5 ]),
177+ np .array ([0 , 1 , 1 , 1 , 1 , 0 ], dtype = np .int32 ),
178+ 25 ,
179+ ),
180+ ],
181+ )
182+ def test_weighted_percentile_array_api_consistency (
183+ global_random_seed , array_namespace , device , dtype_name , data , weights , percentile
184+ ):
185+ """Check `_weighted_percentile` gives consistent results with array API."""
186+ if array_namespace == "array_api_strict" :
187+ try :
188+ import array_api_strict
189+ except ImportError :
190+ pass
191+ else :
192+ if device == array_api_strict .Device ("device1" ):
193+ # See https://github.com/data-apis/array-api-strict/issues/134
194+ pytest .xfail (
195+ "array_api_strict has bug when indexing with tuple of arrays "
196+ "on non-'CPU_DEVICE' devices."
197+ )
198+
199+ xp = _array_api_for_tests (array_namespace , device )
200+
201+ # Skip test for percentile=0 edge case (#20528) on namespace/device where
202+ # xp.nextafter is broken. This is the case for torch with MPS device:
203+ # https://github.com/pytorch/pytorch/issues/150027
204+ zero = xp .zeros (1 , device = device )
205+ one = xp .ones (1 , device = device )
206+ if percentile == 0 and xp .all (xp .nextafter (zero , one ) == zero ):
207+ pytest .xfail (f"xp.nextafter is broken on { device } " )
208+
209+ rng = np .random .RandomState (global_random_seed )
210+ X_np = data (rng ) if callable (data ) else data
211+ weights_np = weights (rng ) if callable (weights ) else weights
212+ # Ensure `data` of correct dtype
213+ X_np = X_np .astype (dtype_name )
214+
215+ result_np = _weighted_percentile (X_np , weights_np , percentile )
216+ # Convert to Array API arrays
217+ X_xp = xp .asarray (X_np , device = device )
218+ weights_xp = xp .asarray (weights_np , device = device )
219+
220+ with config_context (array_api_dispatch = True ):
221+ result_xp = _weighted_percentile (X_xp , weights_xp , percentile )
222+ assert array_device (result_xp ) == array_device (X_xp )
223+ assert get_namespace (result_xp )[0 ] == get_namespace (X_xp )[0 ]
224+ result_xp_np = _convert_to_numpy (result_xp , xp = xp )
225+
226+ assert result_xp_np .dtype == result_np .dtype
227+ assert result_xp_np .shape == result_np .shape
228+ assert_allclose (result_np , result_xp_np )
229+
230+ # Check dtype correct (`sample_weight` should follow `array`)
231+ if dtype_name == "float32" :
232+ assert result_xp_np .dtype == result_np .dtype == np .float32
233+ else :
234+ assert result_xp_np .dtype == np .float64
235+
236+
140237@pytest .mark .parametrize ("sample_weight_ndim" , [1 , 2 ])
141238def test_weighted_percentile_nan_filtered (sample_weight_ndim ):
142239 """Test that calling _weighted_percentile on an array with nan values returns
0 commit comments