|
19 | 19 | from ._lib._utils._typing import Array, DType |
20 | 20 |
|
21 | 21 | __all__ = [ |
| 22 | + "atleast_nd", |
22 | 23 | "cov", |
23 | 24 | "expand_dims", |
24 | 25 | "isclose", |
|
29 | 30 | ] |
30 | 31 |
|
31 | 32 |
|
| 33 | +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: |
| 34 | + """ |
| 35 | + Recursively expand the dimension of an array to at least `ndim`. |
| 36 | +
|
| 37 | + Parameters |
| 38 | + ---------- |
| 39 | + x : array |
| 40 | + Input array. |
| 41 | + ndim : int |
| 42 | + The minimum number of dimensions for the result. |
| 43 | + xp : array_namespace, optional |
| 44 | + The standard-compatible namespace for `x`. Default: infer. |
| 45 | +
|
| 46 | + Returns |
| 47 | + ------- |
| 48 | + array |
| 49 | + An array with ``res.ndim`` >= `ndim`. |
| 50 | + If ``x.ndim`` >= `ndim`, `x` is returned. |
| 51 | + If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes |
| 52 | + until ``res.ndim`` equals `ndim`. |
| 53 | +
|
| 54 | + Examples |
| 55 | + -------- |
| 56 | + >>> import array_api_strict as xp |
| 57 | + >>> import array_api_extra as xpx |
| 58 | + >>> x = xp.asarray([1]) |
| 59 | + >>> xpx.atleast_nd(x, ndim=3, xp=xp) |
| 60 | + Array([[[1]]], dtype=array_api_strict.int64) |
| 61 | +
|
| 62 | + >>> x = xp.asarray([[[1, 2], |
| 63 | + ... [3, 4]]]) |
| 64 | + >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x |
| 65 | + True |
| 66 | + """ |
| 67 | + if xp is None: |
| 68 | + xp = array_namespace(x) |
| 69 | + |
| 70 | + if 1 <= ndim <= 3 and ( |
| 71 | + is_numpy_namespace(xp) |
| 72 | + or is_jax_namespace(xp) |
| 73 | + or is_dask_namespace(xp) |
| 74 | + or is_cupy_namespace(xp) |
| 75 | + or is_torch_namespace(xp) |
| 76 | + ): |
| 77 | + return getattr(xp, f"atleast_{ndim}d")(x) |
| 78 | + |
| 79 | + return _funcs.atleast_nd(x, ndim=ndim, xp=xp) |
| 80 | + |
| 81 | + |
32 | 82 | def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: |
33 | 83 | """ |
34 | 84 | Estimate a covariance matrix. |
@@ -197,55 +247,6 @@ def expand_dims( |
197 | 247 | return _funcs.expand_dims(a, axis=axis, xp=xp) |
198 | 248 |
|
199 | 249 |
|
200 | | -def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: |
201 | | - """ |
202 | | - Recursively expand the dimension of an array to at least `ndim`. |
203 | | -
|
204 | | - Parameters |
205 | | - ---------- |
206 | | - x : array |
207 | | - Input array. |
208 | | - ndim : int |
209 | | - The minimum number of dimensions for the result. |
210 | | - xp : array_namespace, optional |
211 | | - The standard-compatible namespace for `x`. Default: infer. |
212 | | -
|
213 | | - Returns |
214 | | - ------- |
215 | | - array |
216 | | - An array with ``res.ndim`` >= `ndim`. |
217 | | - If ``x.ndim`` >= `ndim`, `x` is returned. |
218 | | - If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes |
219 | | - until ``res.ndim`` equals `ndim`. |
220 | | -
|
221 | | - Examples |
222 | | - -------- |
223 | | - >>> import array_api_strict as xp |
224 | | - >>> import array_api_extra as xpx |
225 | | - >>> x = xp.asarray([1]) |
226 | | - >>> xpx.atleast_nd(x, ndim=3, xp=xp) |
227 | | - Array([[[1]]], dtype=array_api_strict.int64) |
228 | | -
|
229 | | - >>> x = xp.asarray([[[1, 2], |
230 | | - ... [3, 4]]]) |
231 | | - >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x |
232 | | - True |
233 | | - """ |
234 | | - if xp is None: |
235 | | - xp = array_namespace(x) |
236 | | - |
237 | | - if 1 <= ndim <= 3 and ( |
238 | | - is_numpy_namespace(xp) |
239 | | - or is_jax_namespace(xp) |
240 | | - or is_dask_namespace(xp) |
241 | | - or is_cupy_namespace(xp) |
242 | | - or is_torch_namespace(xp) |
243 | | - ): |
244 | | - return getattr(xp, f"atleast_{ndim}d")(x) |
245 | | - |
246 | | - return _funcs.atleast_nd(x, ndim=ndim, xp=xp) |
247 | | - |
248 | | - |
249 | 250 | def isclose( |
250 | 251 | a: Array | complex, |
251 | 252 | b: Array | complex, |
@@ -553,6 +554,59 @@ def pad( |
553 | 554 | return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) |
554 | 555 |
|
555 | 556 |
|
| 557 | +def setdiff1d( |
| 558 | + x1: Array | complex, |
| 559 | + x2: Array | complex, |
| 560 | + /, |
| 561 | + *, |
| 562 | + assume_unique: bool = False, |
| 563 | + xp: ModuleType | None = None, |
| 564 | +) -> Array: |
| 565 | + """ |
| 566 | + Find the set difference of two arrays. |
| 567 | +
|
| 568 | + Return the unique values in `x1` that are not in `x2`. |
| 569 | +
|
| 570 | + Parameters |
| 571 | + ---------- |
| 572 | + x1 : array | int | float | complex | bool |
| 573 | + Input array. |
| 574 | + x2 : array |
| 575 | + Input comparison array. |
| 576 | + assume_unique : bool |
| 577 | + If ``True``, the input arrays are both assumed to be unique, which |
| 578 | + can speed up the calculation. Default is ``False``. |
| 579 | + xp : array_namespace, optional |
| 580 | + The standard-compatible namespace for `x1` and `x2`. Default: infer. |
| 581 | +
|
| 582 | + Returns |
| 583 | + ------- |
| 584 | + array |
| 585 | + 1D array of values in `x1` that are not in `x2`. The result |
| 586 | + is sorted when `assume_unique` is ``False``, but otherwise only sorted |
| 587 | + if the input is sorted. |
| 588 | +
|
| 589 | + Examples |
| 590 | + -------- |
| 591 | + >>> import array_api_strict as xp |
| 592 | + >>> import array_api_extra as xpx |
| 593 | +
|
| 594 | + >>> x1 = xp.asarray([1, 2, 3, 2, 4, 1]) |
| 595 | + >>> x2 = xp.asarray([3, 4, 5, 6]) |
| 596 | + >>> xpx.setdiff1d(x1, x2, xp=xp) |
| 597 | + Array([1, 2], dtype=array_api_strict.int64) |
| 598 | + """ |
| 599 | + |
| 600 | + if xp is None: |
| 601 | + xp = array_namespace(x1, x2) |
| 602 | + |
| 603 | + if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp): |
| 604 | + x1, x2 = asarrays(x1, x2, xp=xp) |
| 605 | + return xp.setdiff1d(x1, x2, assume_unique=assume_unique) |
| 606 | + |
| 607 | + return _funcs.setdiff1d(x1, x2, assume_unique=assume_unique, xp=xp) |
| 608 | + |
| 609 | + |
556 | 610 | def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
557 | 611 | r""" |
558 | 612 | Return the normalized sinc function. |
|
0 commit comments