Skip to content

Commit 24facbe

Browse files
ENH: setdiff1d: add delegation (#456)
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent 13dcc24 commit 24facbe

File tree

4 files changed

+121
-92
lines changed

4 files changed

+121
-92
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
one_hot,
1212
pad,
1313
partition,
14+
setdiff1d,
1415
sinc,
1516
)
1617
from ._lib._at import at
@@ -21,7 +22,6 @@
2122
default_dtype,
2223
kron,
2324
nunique,
24-
setdiff1d,
2525
)
2626
from ._lib._lazy import lazy_apply
2727

src/array_api_extra/_delegation.py

Lines changed: 103 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ._lib._utils._typing import Array, DType
2020

2121
__all__ = [
22+
"atleast_nd",
2223
"cov",
2324
"expand_dims",
2425
"isclose",
@@ -29,6 +30,55 @@
2930
]
3031

3132

33+
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
34+
"""
35+
Recursively expand the dimension of an array to at least `ndim`.
36+
37+
Parameters
38+
----------
39+
x : array
40+
Input array.
41+
ndim : int
42+
The minimum number of dimensions for the result.
43+
xp : array_namespace, optional
44+
The standard-compatible namespace for `x`. Default: infer.
45+
46+
Returns
47+
-------
48+
array
49+
An array with ``res.ndim`` >= `ndim`.
50+
If ``x.ndim`` >= `ndim`, `x` is returned.
51+
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
52+
until ``res.ndim`` equals `ndim`.
53+
54+
Examples
55+
--------
56+
>>> import array_api_strict as xp
57+
>>> import array_api_extra as xpx
58+
>>> x = xp.asarray([1])
59+
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
60+
Array([[[1]]], dtype=array_api_strict.int64)
61+
62+
>>> x = xp.asarray([[[1, 2],
63+
... [3, 4]]])
64+
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
65+
True
66+
"""
67+
if xp is None:
68+
xp = array_namespace(x)
69+
70+
if 1 <= ndim <= 3 and (
71+
is_numpy_namespace(xp)
72+
or is_jax_namespace(xp)
73+
or is_dask_namespace(xp)
74+
or is_cupy_namespace(xp)
75+
or is_torch_namespace(xp)
76+
):
77+
return getattr(xp, f"atleast_{ndim}d")(x)
78+
79+
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
80+
81+
3282
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
3383
"""
3484
Estimate a covariance matrix.
@@ -197,55 +247,6 @@ def expand_dims(
197247
return _funcs.expand_dims(a, axis=axis, xp=xp)
198248

199249

200-
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
201-
"""
202-
Recursively expand the dimension of an array to at least `ndim`.
203-
204-
Parameters
205-
----------
206-
x : array
207-
Input array.
208-
ndim : int
209-
The minimum number of dimensions for the result.
210-
xp : array_namespace, optional
211-
The standard-compatible namespace for `x`. Default: infer.
212-
213-
Returns
214-
-------
215-
array
216-
An array with ``res.ndim`` >= `ndim`.
217-
If ``x.ndim`` >= `ndim`, `x` is returned.
218-
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
219-
until ``res.ndim`` equals `ndim`.
220-
221-
Examples
222-
--------
223-
>>> import array_api_strict as xp
224-
>>> import array_api_extra as xpx
225-
>>> x = xp.asarray([1])
226-
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
227-
Array([[[1]]], dtype=array_api_strict.int64)
228-
229-
>>> x = xp.asarray([[[1, 2],
230-
... [3, 4]]])
231-
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
232-
True
233-
"""
234-
if xp is None:
235-
xp = array_namespace(x)
236-
237-
if 1 <= ndim <= 3 and (
238-
is_numpy_namespace(xp)
239-
or is_jax_namespace(xp)
240-
or is_dask_namespace(xp)
241-
or is_cupy_namespace(xp)
242-
or is_torch_namespace(xp)
243-
):
244-
return getattr(xp, f"atleast_{ndim}d")(x)
245-
246-
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
247-
248-
249250
def isclose(
250251
a: Array | complex,
251252
b: Array | complex,
@@ -553,6 +554,59 @@ def pad(
553554
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
554555

555556

557+
def setdiff1d(
558+
x1: Array | complex,
559+
x2: Array | complex,
560+
/,
561+
*,
562+
assume_unique: bool = False,
563+
xp: ModuleType | None = None,
564+
) -> Array:
565+
"""
566+
Find the set difference of two arrays.
567+
568+
Return the unique values in `x1` that are not in `x2`.
569+
570+
Parameters
571+
----------
572+
x1 : array | int | float | complex | bool
573+
Input array.
574+
x2 : array
575+
Input comparison array.
576+
assume_unique : bool
577+
If ``True``, the input arrays are both assumed to be unique, which
578+
can speed up the calculation. Default is ``False``.
579+
xp : array_namespace, optional
580+
The standard-compatible namespace for `x1` and `x2`. Default: infer.
581+
582+
Returns
583+
-------
584+
array
585+
1D array of values in `x1` that are not in `x2`. The result
586+
is sorted when `assume_unique` is ``False``, but otherwise only sorted
587+
if the input is sorted.
588+
589+
Examples
590+
--------
591+
>>> import array_api_strict as xp
592+
>>> import array_api_extra as xpx
593+
594+
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
595+
>>> x2 = xp.asarray([3, 4, 5, 6])
596+
>>> xpx.setdiff1d(x1, x2, xp=xp)
597+
Array([1, 2], dtype=array_api_strict.int64)
598+
"""
599+
600+
if xp is None:
601+
xp = array_namespace(x1, x2)
602+
603+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
604+
x1, x2 = asarrays(x1, x2, xp=xp)
605+
return xp.setdiff1d(x1, x2, assume_unique=assume_unique)
606+
607+
return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp)
608+
609+
556610
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
557611
r"""
558612
Return the normalized sinc function.

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -715,44 +715,10 @@ def setdiff1d(
715715
/,
716716
*,
717717
assume_unique: bool = False,
718-
xp: ModuleType | None = None,
719-
) -> Array:
720-
"""
721-
Find the set difference of two arrays.
722-
723-
Return the unique values in `x1` that are not in `x2`.
724-
725-
Parameters
726-
----------
727-
x1 : array | int | float | complex | bool
728-
Input array.
729-
x2 : array
730-
Input comparison array.
731-
assume_unique : bool
732-
If ``True``, the input arrays are both assumed to be unique, which
733-
can speed up the calculation. Default is ``False``.
734-
xp : array_namespace, optional
735-
The standard-compatible namespace for `x1` and `x2`. Default: infer.
736-
737-
Returns
738-
-------
739-
array
740-
1D array of values in `x1` that are not in `x2`. The result
741-
is sorted when `assume_unique` is ``False``, but otherwise only sorted
742-
if the input is sorted.
743-
744-
Examples
745-
--------
746-
>>> import array_api_strict as xp
747-
>>> import array_api_extra as xpx
718+
xp: ModuleType,
719+
) -> Array: # numpydoc ignore=PR01,RT01
720+
"""See docstring in `array_api_extra._delegation.py`."""
748721

749-
>>> x1 = xp.asarray([1, 2, 3, 2, 4, 1])
750-
>>> x2 = xp.asarray([3, 4, 5, 6])
751-
>>> xpx.setdiff1d(x1, x2, xp=xp)
752-
Array([1, 2], dtype=array_api_strict.int64)
753-
"""
754-
if xp is None:
755-
xp = array_namespace(x1, x2)
756722
# https://github.com/microsoft/pyright/issues/10103
757723
x1_, x2_ = asarrays(x1, x2, xp=xp)
758724

tests/test_funcs.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@
3333
sinc,
3434
)
3535
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
36-
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
37-
from array_api_extra._lib._utils._compat import (
38-
device as get_device,
39-
)
36+
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
37+
from array_api_extra._lib._utils._compat import device as get_device
38+
from array_api_extra._lib._utils._compat import is_jax_namespace
4039
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
4140
from array_api_extra._lib._utils._typing import Array, Device
4241
from array_api_extra.testing import lazy_xp_function
@@ -1264,25 +1263,35 @@ def test_assume_unique(self, xp: ModuleType):
12641263
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
12651264
def test_shapes(
12661265
self,
1266+
request: pytest.FixtureRequest,
12671267
assume_unique: bool,
12681268
shape1: tuple[int, ...],
12691269
shape2: tuple[int, ...],
12701270
xp: ModuleType,
12711271
):
12721272
x1 = xp.zeros(shape1)
12731273
x2 = xp.zeros(shape2)
1274+
1275+
if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
1276+
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")
1277+
12741278
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12751279
xp_assert_equal(actual, xp.empty((0,)))
12761280

12771281
@assume_unique
12781282
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
1279-
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
1283+
def test_python_scalar(
1284+
self, request: pytest.FixtureRequest, xp: ModuleType, assume_unique: bool
1285+
):
12801286
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
12811287
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
12821288
x2 = 3
12831289
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
12841290
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
12851291

1292+
if is_jax_namespace(xp) and assume_unique:
1293+
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")
1294+
12861295
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
12871296
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))
12881297

0 commit comments

Comments
 (0)