2323 SupportsIndex ,
2424 TypeAlias ,
2525 TypeGuard ,
26- TypeVar ,
2726 cast ,
2827 overload ,
2928)
3029
3130from ._typing import Array , Device , HasShape , Namespace , SupportsArrayNamespace
3231
3332if TYPE_CHECKING :
34-
33+ import cupy as cp
3534 import dask .array as da
3635 import jax
3736 import ndonnx as ndx
3837 import numpy as np
3938 import numpy .typing as npt
40- import sparse # pyright: ignore[reportMissingTypeStubs]
39+ import sparse
4140 import torch
4241
4342 # TODO: import from typing (requires Python >=3.13)
44- from typing_extensions import TypeIs , TypeVar
45-
46- _SizeT = TypeVar ("_SizeT" , bound = int | None )
43+ from typing_extensions import TypeIs
4744
4845 _ZeroGradientArray : TypeAlias = npt .NDArray [np .void ]
49- _CupyArray : TypeAlias = Any # cupy has no py.typed
5046
5147 _ArrayApiObj : TypeAlias = (
5248 npt .NDArray [Any ]
49+ | cp .ndarray
5350 | da .Array
5451 | jax .Array
5552 | ndx .Array
5653 | sparse .SparseArray
5754 | torch .Tensor
5855 | SupportsArrayNamespace [Any ]
59- | _CupyArray
6056 )
6157
6258_API_VERSIONS_OLD : Final = frozenset ({"2021.12" , "2022.12" , "2023.12" })
@@ -96,7 +92,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
9692 return dtype == jax .float0
9793
9894
99- def is_numpy_array (x : object ) -> TypeGuard [npt .NDArray [Any ]]:
95+ def is_numpy_array (x : object ) -> TypeIs [npt .NDArray [Any ]]:
10096 """
10197 Return True if `x` is a NumPy array.
10298
@@ -267,7 +263,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
267263 return _issubclass_fast (cls , "sparse" , "SparseArray" )
268264
269265
270- def is_array_api_obj (x : object ) -> TypeIs [_ArrayApiObj ]: # pyright: ignore[reportUnknownParameterType]
266+ def is_array_api_obj (x : object ) -> TypeGuard [_ArrayApiObj ]:
271267 """
272268 Return True if `x` is an array API compatible array object.
273269
@@ -748,7 +744,7 @@ def device(x: _ArrayApiObj, /) -> Device:
748744 return "cpu"
749745 elif is_dask_array (x ):
750746 # Peek at the metadata of the Dask array to determine type
751- if is_numpy_array (x ._meta ): # pyright: ignore
747+ if is_numpy_array (x ._meta ):
752748 # Must be on CPU since backed by numpy
753749 return "cpu"
754750 return _DASK_DEVICE
@@ -777,7 +773,7 @@ def device(x: _ArrayApiObj, /) -> Device:
777773 return "cpu"
778774 # Return the device of the constituent array
779775 return device (inner ) # pyright: ignore
780- return x .device # pyright: ignore
776+ return x .device # type: ignore # pyright: ignore
781777
782778
783779# Prevent shadowing, used below
@@ -786,11 +782,11 @@ def device(x: _ArrayApiObj, /) -> Device:
786782
787783# Based on cupy.array_api.Array.to_device
788784def _cupy_to_device (
789- x : _CupyArray ,
785+ x : cp . ndarray ,
790786 device : Device ,
791787 / ,
792788 stream : int | Any | None = None ,
793- ) -> _CupyArray :
789+ ) -> cp . ndarray :
794790 import cupy as cp
795791
796792 if device == "cpu" :
@@ -819,7 +815,7 @@ def _torch_to_device(
819815 x : torch .Tensor ,
820816 device : torch .device | str | int ,
821817 / ,
822- stream : None = None ,
818+ stream : int | Any | None = None ,
823819) -> torch .Tensor :
824820 if stream is not None :
825821 raise NotImplementedError
@@ -885,7 +881,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
885881 # cupy does not yet have to_device
886882 return _cupy_to_device (x , device , stream = stream )
887883 elif is_torch_array (x ):
888- return _torch_to_device (x , device , stream = stream ) # pyright: ignore[reportArgumentType]
884+ return _torch_to_device (x , device , stream = stream )
889885 elif is_dask_array (x ):
890886 if stream is not None :
891887 raise ValueError ("The stream argument to to_device() is not supported" )
@@ -912,8 +908,6 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
912908@overload
913909def size (x : HasShape [Collection [SupportsIndex ]]) -> int : ...
914910@overload
915- def size (x : HasShape [Collection [None ]]) -> None : ...
916- @overload
917911def size (x : HasShape [Collection [SupportsIndex | None ]]) -> int | None : ...
918912def size (x : HasShape [Collection [SupportsIndex | None ]]) -> int | None :
919913 """
@@ -948,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
948942 return None
949943
950944
951- def is_writeable_array (x : object ) -> bool :
945+ def is_writeable_array (x : object ) -> TypeGuard [ _ArrayApiObj ] :
952946 """
953947 Return False if ``x.__setitem__`` is expected to raise; True otherwise.
954948 Return False if `x` is not an array API compatible object.
@@ -986,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
986980 return None
987981
988982
989- def is_lazy_array (x : object ) -> bool :
983+ def is_lazy_array (x : object ) -> TypeGuard [ _ArrayApiObj ] :
990984 """Return True if x is potentially a future or it may be otherwise impossible or
991985 expensive to eagerly read its contents, regardless of their size, e.g. by
992986 calling ``bool(x)`` or ``float(x)``.
0 commit comments