Skip to content

Commit 6bd43f9

Browse files
authored
Type ExtensionArray.{dropna,unique,repeat,take,ravel} (#1423)
* Type ``ExtensionArray.{dropna,unique,repeat,take,ravel}`` * introduce and use AnyArrayLikeInt * add test file * reorder * reorder * remove redundant * remove redundant * include Sequence, add comment about integerarray, reduce diff
1 parent 9d3eb9f commit 6bd43f9

File tree

3 files changed

+46
-7
lines changed

3 files changed

+46
-7
lines changed

pandas-stubs/_typing.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ from typing import (
2828
import numpy as np
2929
from numpy import typing as npt
3030
import pandas as pd
31-
from pandas.core.arrays import ExtensionArray
31+
from pandas.core.arrays import (
32+
ExtensionArray,
33+
IntegerArray,
34+
)
3235
from pandas.core.frame import DataFrame
3336
from pandas.core.generic import NDFrame
3437
from pandas.core.groupby.grouper import Grouper
@@ -861,6 +864,10 @@ np_ndarray: TypeAlias = np.ndarray[ShapeT, np.dtype[GenericT]]
861864
np_1darray: TypeAlias = np.ndarray[tuple[int], np.dtype[GenericT]]
862865
np_2darray: TypeAlias = np.ndarray[tuple[int, int], np.dtype[GenericT]]
863866

867+
AnyArrayLikeInt: TypeAlias = (
868+
IntegerArray | Index[int] | Series[int] | np_1darray[np.integer] | Sequence[int]
869+
)
870+
864871
class SupportsDType(Protocol[GenericT_co]):
865872
@property
866873
def dtype(self) -> np.dtype[GenericT_co]: ...

pandas-stubs/core/arrays/base.pyi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import numpy as np
99
from typing_extensions import Self
1010

1111
from pandas._typing import (
12+
AnyArrayLikeInt,
1213
ArrayLike,
1314
Scalar,
1415
ScalarIndexer,
@@ -52,22 +53,22 @@ class ExtensionArray:
5253
self, *, ascending: bool = ..., kind: str = ..., **kwargs: Any
5354
) -> np_1darray: ...
5455
def fillna(self, value=..., method=None, limit=None): ...
55-
def dropna(self): ...
56+
def dropna(self) -> Self: ...
5657
def shift(self, periods: int = 1, fill_value: object = ...) -> Self: ...
57-
def unique(self): ...
58+
def unique(self) -> Self: ...
5859
def searchsorted(self, value, side: str = ..., sorter=...): ...
5960
def factorize(self, use_na_sentinel: bool = True) -> tuple[np_1darray, Self]: ...
60-
def repeat(self, repeats, axis=...): ...
61+
def repeat(self, repeats: int | AnyArrayLikeInt, axis: None = None) -> Self: ...
6162
def take(
6263
self,
6364
indexer: TakeIndexer,
6465
*,
65-
allow_fill: bool = ...,
66-
fill_value=...,
66+
allow_fill: bool = False,
67+
fill_value: Any = None,
6768
) -> Self: ...
6869
def copy(self) -> Self: ...
6970
def view(self, dtype=...) -> Self | np_1darray: ...
70-
def ravel(self, order="C") -> Self: ...
71+
def ravel(self, order: Literal["C", "F", "A", "K"] | None = "C") -> Self: ...
7172
def tolist(self) -> list: ...
7273
def _reduce(
7374
self, name: str, *, skipna: bool = ..., keepdims: bool = ..., **kwargs: Any
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Test common ExtensionArray methods
2+
3+
import pandas as pd
4+
from pandas.core.arrays.integer import IntegerArray
5+
from pandas.core.construction import array
6+
from typing_extensions import assert_type
7+
8+
from tests import check
9+
10+
11+
def test_ea_common() -> None:
12+
# Note: `ExtensionArray` is abstract, so we use `IntegerArray` for the tests.
13+
arr = array([1, 2, 3])
14+
15+
check(assert_type(arr.repeat(1), IntegerArray), IntegerArray)
16+
check(assert_type(arr.repeat(arr), IntegerArray), IntegerArray)
17+
check(
18+
assert_type(arr.repeat(repeats=pd.Series([1, 2, 3])), IntegerArray),
19+
IntegerArray,
20+
)
21+
check(assert_type(arr.repeat(pd.Index([1, 2, 3])), IntegerArray), IntegerArray)
22+
check(assert_type(arr.repeat([1, 2, 3]), IntegerArray), IntegerArray)
23+
24+
check(assert_type(arr.unique(), IntegerArray), IntegerArray)
25+
check(assert_type(arr.dropna(), IntegerArray), IntegerArray)
26+
check(assert_type(arr.take([1, 0, 2]), IntegerArray), IntegerArray)
27+
check(
28+
assert_type(arr.take([1, 0, 2], allow_fill=True, fill_value=-1), IntegerArray),
29+
IntegerArray,
30+
)
31+
check(assert_type(arr.ravel(), IntegerArray), IntegerArray)

0 commit comments

Comments
 (0)