Skip to content

TYP: Type annotations overhaul, episode 2 #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 17, 2025
Merged
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b9b0206
TYP: annotate `_internal.get_xp` (and curse at `ParamSpec` for being …
jorenham Mar 22, 2025
6a17007
TYP: fix (or ignore) typing errors in `common._helpers` (and curse at…
jorenham Mar 22, 2025
3b134b0
TYP: fix typing errors in `common._fft`
jorenham Mar 22, 2025
344ac1e
TYP: fix typing errors in `common._aliases`
jorenham Mar 22, 2025
cbec5f3
TYP: fix typing errors in `common._linalg`
jorenham Mar 22, 2025
dc79e3f
TYP: fix/ignore typing errors in `numpy.__init__`
jorenham Mar 22, 2025
9643256
TYP: fix typing errors in `numpy._typing`
jorenham Mar 22, 2025
18870dc
TYP: fix typing errors in `numpy._aliases`
jorenham Mar 22, 2025
1fb929b
TYP: fix typing errors in `numpy._info`
jorenham Mar 22, 2025
014385f
TYP: fix typing errors in `numpy._fft`
jorenham Mar 22, 2025
ec72825
TYP: it's a bad idea to import `TypeAlias` from `typing` on `python<3…
jorenham Mar 22, 2025
ccd9bc6
TYP: it's also a bad idea to import `TypeGuard` from `typing` on `pyt…
jorenham Mar 22, 2025
b8c7883
TYP: don't scare the prehistoric `dtype` from numpy 1.21
jorenham Mar 22, 2025
ef066d1
TYP: dust off the DeLorean
jorenham Mar 22, 2025
a522dbc
TYP: figure out how to drive a DeLorean
jorenham Mar 22, 2025
bca9c0c
TYP: apply review suggestions
jorenham Apr 15, 2025
0dd925f
TYP: sprinkle some `TypeAlias`es and `Final`s around
jorenham Apr 15, 2025
953d7c0
TYP: `__dir__`
jorenham Apr 15, 2025
c66b750
TYP: fix typing errors in `numpy.linalg`
jorenham Apr 15, 2025
9acba46
TYP: add a `common._typing.Capabilities` typed dict type
jorenham Apr 15, 2025
ba0b4e5
TYP: `__array_namespace_info__` helper types
jorenham Apr 15, 2025
4278dfb
TYP: `dask.array` typing fixes and improvements
jorenham Apr 15, 2025
4f6ef6d
STY: give the `=` some breathing room
jorenham Apr 17, 2025
d758d6f
STY: apply review suggestions
jorenham Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
TYP: fix typing errors in common._fft
  • Loading branch information
jorenham committed Apr 15, 2025
commit 3b134b0c5d835f816b05729de4485ec2575520df
68 changes: 34 additions & 34 deletions array_api_compat/common/_fft.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Union, Optional, Literal
from typing import Literal, TypeAlias

from ._typing import Array, Device, DType, Namespace

from ._typing import Device, Array, DType, Namespace
_Norm: TypeAlias = Literal["backward", "ortho", "forward"]

# Note: NumPy fft functions improperly upcast float32 and complex64 to
# complex128, which is why we require wrapping them all here.
Expand All @@ -13,9 +13,9 @@ def fft(
/,
xp: Namespace,
*,
n: Optional[int] = None,
n: int | None = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
norm: _Norm = "backward",
) -> Array:
res = xp.fft.fft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
Expand All @@ -27,9 +27,9 @@ def ifft(
/,
xp: Namespace,
*,
n: Optional[int] = None,
n: int | None = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ifft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
Expand All @@ -41,9 +41,9 @@ def fftn(
/,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.fftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
Expand All @@ -55,9 +55,9 @@ def ifftn(
/,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
Expand All @@ -69,9 +69,9 @@ def rfft(
/,
xp: Namespace,
*,
n: Optional[int] = None,
n: int | None = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
norm: _Norm = "backward",
) -> Array:
res = xp.fft.rfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.float32:
Expand All @@ -83,9 +83,9 @@ def irfft(
/,
xp: Namespace,
*,
n: Optional[int] = None,
n: int | None = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
norm: _Norm = "backward",
) -> Array:
res = xp.fft.irfft(x, n=n, axis=axis, norm=norm)
if x.dtype == xp.complex64:
Expand All @@ -97,9 +97,9 @@ def rfftn(
/,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.float32:
Expand All @@ -111,9 +111,9 @@ def irfftn(
/,
xp: Namespace,
*,
s: Sequence[int] = None,
axes: Sequence[int] = None,
norm: Literal["backward", "ortho", "forward"] = "backward",
s: Sequence[int] | None = None,
axes: Sequence[int] | None = None,
norm: _Norm = "backward",
) -> Array:
res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm)
if x.dtype == xp.complex64:
Expand All @@ -125,9 +125,9 @@ def hfft(
/,
xp: Namespace,
*,
n: Optional[int] = None,
n: int | None = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
norm: _Norm = "backward",
) -> Array:
res = xp.fft.hfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
Expand All @@ -139,9 +139,9 @@ def ihfft(
/,
xp: Namespace,
*,
n: Optional[int] = None,
n: int | None = None,
axis: int = -1,
norm: Literal["backward", "ortho", "forward"] = "backward",
norm: _Norm = "backward",
) -> Array:
res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm)
if x.dtype in [xp.float32, xp.complex64]:
Expand All @@ -154,8 +154,8 @@ def fftfreq(
xp: Namespace,
*,
d: float = 1.0,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
dtype: DType | None = None,
device: Device | None = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
Expand All @@ -170,8 +170,8 @@ def rfftfreq(
xp: Namespace,
*,
d: float = 1.0,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
dtype: DType | None = None,
device: Device | None = None,
) -> Array:
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
Expand All @@ -181,12 +181,12 @@ def rfftfreq(
return res

def fftshift(
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
) -> Array:
return xp.fft.fftshift(x, axes=axes)

def ifftshift(
x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None
x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None
) -> Array:
return xp.fft.ifftshift(x, axes=axes)

Expand Down