From a6f6c93ac1e6135aeab9c09055455becc0d86e09 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 21 Oct 2025 11:58:10 +0200 Subject: [PATCH 01/25] Draft --- src/array_api_extra/_delegation.py | 37 ++++++++ src/array_api_extra/_lib/_quantile.py | 126 ++++++++++++++++++++++++++ 2 files changed, 163 insertions(+) create mode 100644 src/array_api_extra/_lib/_quantile.py diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 7f467366..045f9e5b 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -836,3 +836,40 @@ def argpartition( # kth is not small compared to x.size return _funcs.argpartition(a, kth, axis=axis, xp=xp) + + +def quantile( + a: Array, + q: float | Array, + /, + axis: int | None = None, + method: str = "linear", + keepdims: bool = False, + *, + xp: ModuleType | None = None, +) -> Array: + """ + TODO + """ + + methods = {"linear"} + if method not in methods: + message = f"`method` must be one of {methods}" + raise ValueError(message) + if xp is None: + xp = array_namespace(a) + if a.ndim < 1: + msg = "`a` must be at least 1-dimensional" + raise TypeError(msg) + + # Delegate where possible. 
+ if is_numpy_namespace(xp) or is_dask_namespace(xp): + return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims) + is_linear = method == "linear" + if is_linear and is_jax_namespace(xp) or is_cupy_namespace(xp): + return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims) + if is_linear and is_torch_namespace(xp): + return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims) + + # Otherwise call our implementation (will sort data) + return _funcs.quantile(a, q, axis=axis, method=method, keepdims=keepdims, xp=xp) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py new file mode 100644 index 00000000..a710ca40 --- /dev/null +++ b/src/array_api_extra/_lib/_quantile.py @@ -0,0 +1,126 @@ +from types import ModuleType + +import numpy as np +from scipy.stats._axis_nan_policy import _broadcast_arrays + +from ._at import at +from ._utils._compat import device as get_device +from ._utils._helpers import eager_shape +from ._utils._typing import Array + + +def quantile( # numpydoc ignore=PR01,RT01 + x, + p, + /, + method: str = 'linear', # noqa: ARG001 + axis: int | None = None, + keepdims: bool = False, + *, + xp: ModuleType, +): + """See docstring in `array_api_extra._delegation.py`.""" + # Input validation / standardization + temp = _quantile_iv(x, p, axis, keepdims) + y, p, axis, keepdims, n, axis_none, ndim = temp + + res = _quantile_hf(y, p, n, xp) + + # Reshape per axis/keepdims + if axis_none and keepdims: + shape = (1,)*(ndim - 1) + res.shape + res = xp.reshape(res, shape) + axis = -1 + + res = xp.moveaxis(res, -1, axis) + + if not keepdims: + res = xp.squeeze(res, axis=axis) + + return res[()] if res.ndim == 0 else res + + +def _quantile_iv( + x: Array, + p: Array, + axis: int | None, + keepdims: bool, + xp: ModuleType +): + + if not xp.isdtype(xp.asarray(x).dtype, ('integral', 'real floating')): + raise ValueError("`x` must have real dtype.") + + if not xp.isdtype(xp.asarray(p).dtype, 
'real floating'): + raise ValueError("`p` must have real floating dtype.") + + p_mask = (p > 1) | (p < 0) | xp.isnan(p) + if xp.any(p_mask): + raise ValueError("`p` values must be in the range [0, 1]") + + device = get_device(x) + floating_dtype = xp.result_type(x, p) + x = xp.asarray(x, dtype=floating_dtype, device=device) + p = xp.asarray(p, dtype=floating_dtype, device=device) + dtype = x.dtype + + axis_none = axis is None + ndim = max(x.ndim, p.ndim) + if axis_none: + x = xp.reshape(x, (-1,)) + p = xp.reshape(p, (-1,)) + axis = 0 + elif np.iterable(axis) or int(axis) != axis: + message = "`axis` must be an integer or None." + raise ValueError(message) + elif (axis >= ndim) or (axis < -ndim): + message = "`axis` is not compatible with the shapes of the inputs." + raise ValueError(message) + axis = int(axis) + + if keepdims not in {None, True, False}: + message = "If specified, `keepdims` must be True or False." + raise ValueError(message) + + # If data has length zero along `axis`, the result will be an array of NaNs just + # as if the data had length 1 along axis and were filled with NaNs. + n = eager_shape(x, axis) + if n == 0: + shape = eager_shape(x) + shape[axis] = 1 + n = 1 + x = xp.full(shape, xp.nan, dtype=dtype, device=device) + + y = xp.sort(x, axis=axis, stable=False) + # FIXME: I still need to look into the broadcasting: + y, p = _broadcast_arrays((y, p), axis=axis) + + p_shape = eager_shape(p) + if (keepdims is False) and (p_shape[axis] != 1): + message = "`keepdims` may be False only if the length of `p` along `axis` is 1." 
+ raise ValueError(message) + keepdims = (p_shape[axis] != 1) if keepdims is None else keepdims + + y = xp.moveaxis(y, axis, -1) + p = xp.moveaxis(p, axis, -1) + + nans = xp.isnan(y) + nan_out = xp.any(nans, axis=-1) + if xp.any(nan_out): + y = xp.asarray(y, copy=True) # ensure writable + y = at(y, nan_out).set(xp.nan) + + return y, p, axis, keepdims, n, axis_none, ndim, xp + + +def _quantile_hf(y, p, n, xp): + m = 1 - p + jg = p*n + m - 1 + j = jg // 1 + g = jg % 1 + g[j < 0] = 0 + j = xp.clip(j, 0., n - 1) + jp1 = xp.clip(j + 1, 0., n - 1) + + return ((1 - g) * xp.take_along_axis(y, xp.astype(j, xp.int64), axis=-1) + + g * xp.take_along_axis(y, xp.astype(jp1, xp.int64), axis=-1)) From dc236da95e4d8c248d99c9ac1c481bdb6afd5cd7 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 21 Oct 2025 12:06:50 +0200 Subject: [PATCH 02/25] revert changes to renovate.json From f92fc4b57073854ac410ad3c3e0e118163b77e4f Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 21 Oct 2025 12:07:10 +0200 Subject: [PATCH 03/25] revert changes to renovate.json --- renovate.json | 102 ++++++++++++++++++++++++++++---------------------- 1 file changed, 57 insertions(+), 45 deletions(-) diff --git a/renovate.json b/renovate.json index c2ca1641..0810fe8b 100644 --- a/renovate.json +++ b/renovate.json @@ -1,6 +1,10 @@ { "$schema": "https://docs.renovatebot.com/renovate-schema.json", - "extends": ["config:recommended", "helpers:pinGitHubActionDigests", ":automergeMinor"], + "extends": [ + "config:recommended", + "helpers:pinGitHubActionDigests", + ":automergeMinor" + ], "dependencyDashboardTitle": "META: Dependency Dashboard", "commitMessagePrefix": "deps:", "labels": ["dependencies"], @@ -9,48 +13,56 @@ "pixi": ">=v0.45.0" }, "minimumReleaseAge": "14 days", - "packageRules": [{ - "description": "Do not bump deps pinned with '~=' or '='.", - "matchManagers": ["pixi"], - "matchCurrentValue": "/^~?=/", - "enabled": false - }, { - "description": "Do not bump requires-python.", - "matchManagers": 
["pep621"], - "matchPackageNames": ["python"], - "enabled": false - }, { - "description": "Schedule automerged GHA updates for the 15th of each month.", - "matchManagers": ["github-actions"], - "groupName": "gha", - "schedule": ["* * 15 * *"], - "automerge": true - }, { - "description": "Block PRs for updates blocked on dropping Python 3.10.", - "matchManagers": ["pixi"], - "matchUpdateTypes": ["major", "minor"], - "matchPackageNames": [ - "numpy", - "jax", - "jaxlib", - "sphinx", - "ipython", - "sphinx-autodoc-typehints", - "pytorch" - ], - "enabled": false - }, { - "description": "Group Dask packages.", - "matchPackageNames": ["dask", "dask-core"], - "groupName": "dask" - }, { - "description": "Group JAX packages.", - "matchPackageNames": ["jax", "jaxlib"], - "groupName": "jax" - }, { - "description": "Schedule hypothesis monthly as releases are frequent.", - "matchManagers": ["pixi"], - "matchPackageNames": ["hypothesis"], - "schedule": ["* * 10 * *"] - }] + "packageRules": [ + { + "description": "Do not bump deps pinned with '~=' or '='.", + "matchManagers": ["pixi"], + "matchCurrentValue": "/^~?=/", + "enabled": false + }, + { + "description": "Do not bump requires-python.", + "matchManagers": ["pep621"], + "matchPackageNames": ["python"], + "enabled": false + }, + { + "description": "Schedule automerged GHA updates for the 15th of each month.", + "matchManagers": ["github-actions"], + "groupName": "gha", + "schedule": ["* * 15 * *"], + "automerge": true + }, + { + "description": "Block PRs for updates blocked on dropping Python 3.10.", + "matchManagers": ["pixi"], + "matchUpdateTypes": ["major", "minor"], + "matchPackageNames": [ + "numpy", + "jax", + "jaxlib", + "sphinx", + "ipython", + "sphinx-autodoc-typehints", + "pytorch" + ], + "enabled": false + }, + { + "description": "Group Dask packages.", + "matchPackageNames": ["dask", "dask-core"], + "groupName": "dask" + }, + { + "description": "Group JAX packages.", + "matchPackageNames": ["jax", "jaxlib"], + 
"groupName": "jax" + }, + { + "description": "Schedule hypothesis monthly as releases are frequent.", + "matchManagers": ["pixi"], + "matchPackageNames": ["hypothesis"], + "schedule": ["* * 10 * *"] + } + ] } From 06e370ae355f95c94a1100292226c7c2c4c21e2c Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 21 Oct 2025 18:39:48 +0200 Subject: [PATCH 04/25] untested implem; limited to method="linear"; trying to mimic numpy behavior --- src/array_api_extra/_delegation.py | 17 +++- src/array_api_extra/_lib/_quantile.py | 118 +++++++++----------------- 2 files changed, 54 insertions(+), 81 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index be149cad..9343a472 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -910,16 +910,29 @@ def quantile( """ TODO """ - methods = {"linear"} + if method not in methods: message = f"`method` must be one of {methods}" raise ValueError(message) + if keepdims not in {True, False}: + message = "If specified, `keepdims` must be True or False." + raise ValueError(message) if xp is None: xp = array_namespace(a) - if a.ndim < 1: + + a = xp.asarray(a) + if not xp.isdtype(a.dtype, ('integral', 'real floating')): + raise ValueError("`a` must have real dtype.") + if not xp.isdtype(xp.asarray(q).dtype, 'real floating'): + raise ValueError("`q` must have real floating dtype.") + ndim = a.ndim + if ndim < 1: msg = "`a` must be at least 1-dimensional" raise TypeError(msg) + if (axis >= ndim) or (axis < -ndim): + message = "`axis` is not compatible with the dimension of `a`." + raise ValueError(message) # Delegate where possible. 
if is_numpy_namespace(xp) or is_dask_namespace(xp): diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index ae4bd5da..44544d16 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -10,8 +10,8 @@ def quantile( # numpydoc ignore=PR01,RT01 - x, - p, + a: Array, + q: Array | float, /, method: str = 'linear', # noqa: ARG001 axis: int | None = None, @@ -20,100 +20,57 @@ def quantile( # numpydoc ignore=PR01,RT01 xp: ModuleType, ): """See docstring in `array_api_extra._delegation.py`.""" - # Input validation / standardization - temp = _quantile_iv(x, p, axis, keepdims) - y, p, axis, keepdims, n, axis_none, ndim = temp + device = get_device(a) + floating_dtype = xp.result_type(a, xp.asarray(q)) + a = xp.asarray(a, dtype=floating_dtype, device=device) + q = xp.asarray(q, dtype=floating_dtype, device=device) - res = _quantile_hf(y, p, n, xp) + if xp.any((q > 1) | (q < 0) | xp.isnan(q)): + raise ValueError("`q` values must be in the range [0, 1]") - # Reshape per axis/keepdims - if axis_none and keepdims: - shape = (1,)*(ndim - 1) + res.shape - res = xp.reshape(res, shape) - axis = -1 - - res = xp.moveaxis(res, -1, axis) - - if not keepdims: - res = xp.squeeze(res, axis=axis) - - return res[()] if res.ndim == 0 else res - - -def _quantile_iv( - x: Array, - p: Array, - axis: int | None, - keepdims: bool, - xp: ModuleType -): - - if not xp.isdtype(xp.asarray(x).dtype, ('integral', 'real floating')): - raise ValueError("`x` must have real dtype.") - - if not xp.isdtype(xp.asarray(p).dtype, 'real floating'): - raise ValueError("`p` must have real floating dtype.") - - p_mask = (p > 1) | (p < 0) | xp.isnan(p) - if xp.any(p_mask): - raise ValueError("`p` values must be in the range [0, 1]") - - device = get_device(x) - floating_dtype = xp.result_type(x, p) - x = xp.asarray(x, dtype=floating_dtype, device=device) - p = xp.asarray(p, dtype=floating_dtype, device=device) - dtype = x.dtype + q_scalar = 
q.ndim == 0 + if q_scalar: + q = xp.reshape(q, (1,)) axis_none = axis is None - ndim = max(x.ndim, p.ndim) if axis_none: - x = xp.reshape(x, (-1,)) - p = xp.reshape(p, (-1,)) + a = xp.reshape(a, (-1,)) axis = 0 - elif np.iterable(axis) or int(axis) != axis: - message = "`axis` must be an integer or None." - raise ValueError(message) - elif (axis >= ndim) or (axis < -ndim): - message = "`axis` is not compatible with the shapes of the inputs." - raise ValueError(message) axis = int(axis) - if keepdims not in {None, True, False}: - message = "If specified, `keepdims` must be True or False." - raise ValueError(message) - + n = eager_shape(a, axis) # If data has length zero along `axis`, the result will be an array of NaNs just # as if the data had length 1 along axis and were filled with NaNs. - n = eager_shape(x, axis) if n == 0: - shape = eager_shape(x) + shape = list(eager_shape(a)) shape[axis] = 1 n = 1 - x = xp.full(shape, xp.nan, dtype=dtype, device=device) - - y = xp.sort(x, axis=axis, stable=False) - # FIXME: I still need to look into the broadcasting: - y, p = _broadcast_arrays((y, p), axis=axis) + a = xp.full(shape, xp.nan, dtype=floating_dtype, device=device) - p_shape = eager_shape(p) - if (keepdims is False) and (p_shape[axis] != 1): - message = "`keepdims` may be False only if the length of `p` along `axis` is 1." - raise ValueError(message) - keepdims = (p_shape[axis] != 1) if keepdims is None else keepdims + a = xp.sort(a, axis=axis, stable=False) + # to support weights, the main thing would be to + # argsort a, and then use it to sort a and w. + # The hard part will be dealing with 0-weights and NaNs + # But maybe a proper use of searchsorted + left/right side will work? 
- y = xp.moveaxis(y, axis, -1) - p = xp.moveaxis(p, axis, -1) + res = _quantile_hf(a, q, n, axis, xp) - nans = xp.isnan(y) - nan_out = xp.any(nans, axis=-1) - if xp.any(nan_out): - y = xp.asarray(y, copy=True) # ensure writable - y = at(y, nan_out).set(xp.nan) + # reshaping to conform to doc/other libs' behavior + if axis_none: + if keepdims: + res = xp.reshape(res, q.shape + (1,) * a.ndim) + else: + res = xp.moveaxis(res, axis, 0) + if keepdims: + shape = list(a.shape) + shape[axis] = 1 + shape = q.shape + tuple(shape) + res = xp.reshape(res, shape) - return y, p, axis, keepdims, n, axis_none, ndim, xp + return res[0, ...] if q_scalar else res -def _quantile_hf(y, p, n, xp): +def _quantile_hf(y: Array, p: Array, n: int, axis: int, xp: ModuleType): m = 1 - p jg = p*n + m - 1 j = jg // 1 @@ -121,6 +78,9 @@ def _quantile_hf(y, p, n, xp): g[j < 0] = 0 j = xp.clip(j, 0., n - 1) jp1 = xp.clip(j + 1, 0., n - 1) + # `̀j` and `jp1` are 1d arrays - return ((1 - g) * xp.take_along_axis(y, xp.astype(j, xp.int64), axis=-1) - + g * xp.take_along_axis(y, xp.astype(jp1, xp.int64), axis=-1)) + return ( + (1 - g) * xp.take(y, xp.astype(j, xp.int64), axis=axis) + + g * xp.take(y, xp.astype(jp1, xp.int64), axis=axis) + ) From dc7a1e5e7f338120c59a2552bb6ff4b4dae40426 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 21 Oct 2025 18:41:03 +0200 Subject: [PATCH 05/25] remove unused imports --- src/array_api_extra/_lib/_quantile.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 44544d16..a8a14178 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -1,9 +1,5 @@ from types import ModuleType -import numpy as np -from scipy.stats._axis_nan_policy import _broadcast_arrays - -from ._at import at from ._utils._compat import device as get_device from ._utils._helpers import eager_shape from ._utils._typing import Array From 98fe39f8da33543f80a2c0577191b5dde4a83e43 
Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 22 Oct 2025 15:44:02 +0200 Subject: [PATCH 06/25] draft version with some tests that are passing --- src/array_api_extra/__init__.py | 2 + src/array_api_extra/_delegation.py | 23 +++++-- src/array_api_extra/_lib/_quantile.py | 18 +++-- tests/test_funcs.py | 99 +++++++++++++++++++++++++++ 4 files changed, 132 insertions(+), 10 deletions(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 7c05552a..a3d6021a 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -11,6 +11,7 @@ one_hot, pad, partition, + quantile, sinc, ) from ._lib._at import at @@ -48,6 +49,7 @@ "one_hot", "pad", "partition", + "quantile", "setdiff1d", "sinc", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 9343a472..eff99d1e 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -4,7 +4,7 @@ from types import ModuleType from typing import Literal -from ._lib import _funcs +from ._lib import _funcs, _quantile from ._lib._utils._compat import ( array_namespace, is_cupy_namespace, @@ -930,13 +930,28 @@ def quantile( if ndim < 1: msg = "`a` must be at least 1-dimensional" raise TypeError(msg) - if (axis >= ndim) or (axis < -ndim): + if axis is not None and ((axis >= ndim) or (axis < -ndim)): message = "`axis` is not compatible with the dimension of `a`." raise ValueError(message) + # Array API states: Mixed integer and floating-point type promotion rules + # are not specified because behavior varies between implementations. 
+ # => We choose to do: + dtype = ( + xp.float64 if xp.isdtype(a.dtype, 'integral') + else xp.result_type(a, xp.asarray(q)) # both a and q are floats + ) + device = get_device(a) + a = xp.asarray(a, dtype=dtype, device=device) + q = xp.asarray(q, dtype=dtype, device=device) + + if xp.any((q > 1) | (q < 0) | xp.isnan(q)): + raise ValueError("`q` values must be in the range [0, 1]") + # Delegate where possible. - if is_numpy_namespace(xp) or is_dask_namespace(xp): + if is_numpy_namespace(xp): return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims) + # No delegating for dask: I couldn't make it work is_linear = method == "linear" if (is_linear and is_jax_namespace(xp)) or is_cupy_namespace(xp): return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims) @@ -944,4 +959,4 @@ def quantile( return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims) # Otherwise call our implementation (will sort data) - return _funcs.quantile(a, q, axis=axis, method=method, keepdims=keepdims, xp=xp) + return _quantile.quantile(a, q, axis=axis, method=method, keepdims=keepdims, xp=xp) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index a8a14178..0eb06d58 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -17,7 +17,7 @@ def quantile( # numpydoc ignore=PR01,RT01 ): """See docstring in `array_api_extra._delegation.py`.""" device = get_device(a) - floating_dtype = xp.result_type(a, xp.asarray(q)) + floating_dtype = xp.float64 #xp.result_type(a, xp.asarray(q)) a = xp.asarray(a, dtype=floating_dtype, device=device) q = xp.asarray(q, dtype=floating_dtype, device=device) @@ -29,12 +29,13 @@ def quantile( # numpydoc ignore=PR01,RT01 q = xp.reshape(q, (1,)) axis_none = axis is None + a_ndim = a.ndim if axis_none: a = xp.reshape(a, (-1,)) axis = 0 axis = int(axis) - n = eager_shape(a, axis) + n, = eager_shape(a, axis) # If data has length zero along `axis`, the result 
will be an array of NaNs just # as if the data had length 1 along axis and were filled with NaNs. if n == 0: @@ -49,12 +50,12 @@ def quantile( # numpydoc ignore=PR01,RT01 # The hard part will be dealing with 0-weights and NaNs # But maybe a proper use of searchsorted + left/right side will work? - res = _quantile_hf(a, q, n, axis, xp) + res = _quantile_hf(a, q, float(n), axis, xp) # reshaping to conform to doc/other libs' behavior if axis_none: if keepdims: - res = xp.reshape(res, q.shape + (1,) * a.ndim) + res = xp.reshape(res, q.shape + (1,) * a_ndim) else: res = xp.moveaxis(res, axis, 0) if keepdims: @@ -69,13 +70,18 @@ def quantile( # numpydoc ignore=PR01,RT01 def _quantile_hf(y: Array, p: Array, n: int, axis: int, xp: ModuleType): m = 1 - p jg = p*n + m - 1 + j = jg // 1 - g = jg % 1 - g[j < 0] = 0 j = xp.clip(j, 0., n - 1) jp1 = xp.clip(j + 1, 0., n - 1) # `̀j` and `jp1` are 1d arrays + g = jg % 1 + g = xp.where(j < 0, 0, g) # equiv to g[j < 0] = 0, but work with strictest + new_g_shape = [1] * y.ndim + new_g_shape[axis] = g.shape[0] + g = xp.reshape(g, tuple(new_g_shape)) + return ( (1 - g) * xp.take(y, xp.astype(j, xp.int64), axis=axis) + g * xp.take(y, xp.astype(jp1, xp.int64), axis=axis) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 92e794ed..4cacfb21 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -29,6 +29,7 @@ one_hot, pad, partition, + quantile, setdiff1d, sinc, ) @@ -1529,3 +1530,101 @@ def test_kind(self, xp: ModuleType, library: Backend): expected = xp.asarray([False, True, False, True]) res = isin(a, b, kind="sort") xp_assert_equal(res, expected) + + +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no xp.take") +class TestQuantile: + def test_basic(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, 0.5) + expect = xp.asarray(3.0, dtype=xp.float64) + xp_assert_close(actual, expect) + + def test_multiple_quantiles(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = 
quantile(x, xp.asarray([0.25, 0.5, 0.75])) + expect = xp.asarray([2.0, 3.0, 4.0], dtype=xp.float64) + xp_assert_close(actual, expect) + + def test_shape(self, xp: ModuleType): + a = xp.asarray(np.random.rand(3, 4, 5)) + q = xp.asarray(np.random.rand(2)) + assert quantile(a, q, axis=0).shape == (2, 4, 5) + assert quantile(a, q, axis=1).shape == (2, 3, 5) + assert quantile(a, q, axis=2).shape == (2, 3, 4) + + assert quantile(a, q, axis=0, keepdims=True).shape == (2, 1, 4, 5) + assert quantile(a, q, axis=1, keepdims=True).shape == (2, 3, 1, 5) + assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1) + + def test_against_numpy(self, xp: ModuleType): + a_np = np.random.rand(3, 4, 5) + q_np = np.random.rand(2) + a = xp.asarray(a_np) + q = xp.asarray(q_np) + for keepdims in [False, True]: + for axis in [None, *range(a.ndim)]: + actual = quantile(a, q, axis=axis, keepdims=keepdims) + expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims) + expected = xp.asarray(expected, dtype=xp.float64) + xp_assert_close(actual, expected, atol=1e-12) + + def test_2d_axis(self, xp: ModuleType): + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) + actual = quantile(x, 0.5, axis=0) + expect = xp.asarray([2.5, 3.5, 4.5], dtype=xp.float64) + xp_assert_close(actual, expect) + + def test_2d_axis_keepdims(self, xp: ModuleType): + x = xp.asarray([[1, 2, 3], [4, 5, 6]]) + actual = quantile(x, 0.5, axis=0, keepdims=True) + expect = xp.asarray([[2.5, 3.5, 4.5]], dtype=xp.float64) + xp_assert_close(actual, expect) + + def test_methods(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + methods = ["linear"] #"hazen", "weibull"] + for method in methods: + actual = quantile(x, 0.5, method=method) + # All methods should give reasonable results + assert 2.5 <= float(actual) <= 3.5 + + def test_edge_cases(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + # q = 0 should give minimum + actual = quantile(x, 0.0) + expect = xp.asarray(1.0, dtype=xp.float64) + xp_assert_close(actual, 
expect) + + # q = 1 should give maximum + actual = quantile(x, 1.0) + expect = xp.asarray(5.0, dtype=xp.float64) + xp_assert_close(actual, expect) + + def test_invalid_q(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + _ = quantile(x, 1.0) + # ^ FIXME: here just to make this test fail for sparse backend + # q > 1 should raise + with pytest.raises( + ValueError, match=r"`q` values must be in the range \[0, 1\]" + ): + _ = quantile(x, 1.5) + # q < 0 should raise + with pytest.raises( + ValueError, match=r"`q` values must be in the range \[0, 1\]" + ): + _ = quantile(x, -0.5) + + def test_device(self, xp: ModuleType, device: Device): + if hasattr(device, 'type') and device.type == "meta": + pytest.xfail("No Tensor.item() on meta device") + x = xp.asarray([1, 2, 3, 4, 5], device=device) + actual = quantile(x, 0.5) + assert get_device(actual) == device + + def test_xp(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, 0.5, xp=xp) + expect = xp.asarray(3.0, dtype=xp.float64) + xp_assert_close(actual, expect) From 034c064938d83df19a3964aa539b00daec43b788 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 22 Oct 2025 16:32:52 +0200 Subject: [PATCH 07/25] linting: fix pyright --- src/array_api_extra/_lib/_quantile.py | 26 +++++++++++++------------- tests/test_funcs.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 0eb06d58..bd9bf642 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -19,14 +19,14 @@ def quantile( # numpydoc ignore=PR01,RT01 device = get_device(a) floating_dtype = xp.float64 #xp.result_type(a, xp.asarray(q)) a = xp.asarray(a, dtype=floating_dtype, device=device) - q = xp.asarray(q, dtype=floating_dtype, device=device) + p: Array = xp.asarray(q, dtype=floating_dtype, device=device) - if xp.any((q > 1) | (q < 0) | xp.isnan(q)): + if xp.any((p > 1) | (p < 0) | 
xp.isnan(p)): raise ValueError("`q` values must be in the range [0, 1]") - q_scalar = q.ndim == 0 + q_scalar = p.ndim == 0 if q_scalar: - q = xp.reshape(q, (1,)) + p = xp.reshape(p, (1,)) axis_none = axis is None a_ndim = a.ndim @@ -50,26 +50,26 @@ def quantile( # numpydoc ignore=PR01,RT01 # The hard part will be dealing with 0-weights and NaNs # But maybe a proper use of searchsorted + left/right side will work? - res = _quantile_hf(a, q, float(n), axis, xp) + res = _quantile_hf(a, p, float(n), axis, xp) # reshaping to conform to doc/other libs' behavior if axis_none: if keepdims: - res = xp.reshape(res, q.shape + (1,) * a_ndim) + res = xp.reshape(res, p.shape + (1,) * a_ndim) else: res = xp.moveaxis(res, axis, 0) if keepdims: shape = list(a.shape) shape[axis] = 1 - shape = q.shape + tuple(shape) + shape = p.shape + tuple(shape) res = xp.reshape(res, shape) return res[0, ...] if q_scalar else res -def _quantile_hf(y: Array, p: Array, n: int, axis: int, xp: ModuleType): - m = 1 - p - jg = p*n + m - 1 +def _quantile_hf(a: Array, q: Array, n: float, axis: int, xp: ModuleType): + m = 1 - q + jg = q*n + m - 1 j = jg // 1 j = xp.clip(j, 0., n - 1) @@ -78,11 +78,11 @@ def _quantile_hf(y: Array, p: Array, n: int, axis: int, xp: ModuleType): g = jg % 1 g = xp.where(j < 0, 0, g) # equiv to g[j < 0] = 0, but work with strictest - new_g_shape = [1] * y.ndim + new_g_shape = [1] * a.ndim new_g_shape[axis] = g.shape[0] g = xp.reshape(g, tuple(new_g_shape)) return ( - (1 - g) * xp.take(y, xp.astype(j, xp.int64), axis=axis) - + g * xp.take(y, xp.astype(jp1, xp.int64), axis=axis) + (1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis) + + g * xp.take(a, xp.astype(jp1, xp.int64), axis=axis) ) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 4cacfb21..78ca9518 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1617,7 +1617,7 @@ def test_invalid_q(self, xp: ModuleType): _ = quantile(x, -0.5) def test_device(self, xp: ModuleType, device: Device): - if 
hasattr(device, 'type') and device.type == "meta": + if hasattr(device, 'type') and getattr(device, 'type') == "meta": pytest.xfail("No Tensor.item() on meta device") x = xp.asarray([1, 2, 3, 4, 5], device=device) actual = quantile(x, 0.5) From 05ffb7b183cf84823d06b5c155e028a7cde845a4 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 22 Oct 2025 16:42:53 +0200 Subject: [PATCH 08/25] linting: fix mypy --- src/array_api_extra/_lib/_quantile.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index bd9bf642..e0436015 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -14,11 +14,12 @@ def quantile( # numpydoc ignore=PR01,RT01 keepdims: bool = False, *, xp: ModuleType, -): +) -> Array: """See docstring in `array_api_extra._delegation.py`.""" device = get_device(a) floating_dtype = xp.float64 #xp.result_type(a, xp.asarray(q)) a = xp.asarray(a, dtype=floating_dtype, device=device) + a_shape = list(a.shape) p: Array = xp.asarray(q, dtype=floating_dtype, device=device) if xp.any((p > 1) | (p < 0) | xp.isnan(p)): @@ -30,19 +31,19 @@ def quantile( # numpydoc ignore=PR01,RT01 axis_none = axis is None a_ndim = a.ndim - if axis_none: + if axis is None: a = xp.reshape(a, (-1,)) axis = 0 - axis = int(axis) + else: + axis = int(axis) n, = eager_shape(a, axis) # If data has length zero along `axis`, the result will be an array of NaNs just # as if the data had length 1 along axis and were filled with NaNs. 
if n == 0: - shape = list(eager_shape(a)) - shape[axis] = 1 + a_shape[axis] = 1 n = 1 - a = xp.full(shape, xp.nan, dtype=floating_dtype, device=device) + a = xp.full(tuple(a_shape), xp.nan, dtype=floating_dtype, device=device) a = xp.sort(a, axis=axis, stable=False) # to support weights, the main thing would be to @@ -59,15 +60,13 @@ def quantile( # numpydoc ignore=PR01,RT01 else: res = xp.moveaxis(res, axis, 0) if keepdims: - shape = list(a.shape) - shape[axis] = 1 - shape = p.shape + tuple(shape) - res = xp.reshape(res, shape) + a_shape[axis] = 1 + res = xp.reshape(res, p.shape + tuple(a_shape)) return res[0, ...] if q_scalar else res -def _quantile_hf(a: Array, q: Array, n: float, axis: int, xp: ModuleType): +def _quantile_hf(a: Array, q: Array, n: float, axis: int, xp: ModuleType) -> Array: m = 1 - q jg = q*n + m - 1 From 89d84109e17c4f8379d16fac2c3468f61554a6fc Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 22 Oct 2025 17:09:39 +0200 Subject: [PATCH 09/25] fixed linting --- src/array_api_extra/_delegation.py | 168 +++++++++++++++++++++++--- src/array_api_extra/_lib/_quantile.py | 28 ++--- tests/test_funcs.py | 14 ++- 3 files changed, 173 insertions(+), 37 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index eff99d1e..1eea8215 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -768,7 +768,7 @@ def argpartition( Axis along which to partition. The default is ``-1`` (the last axis). If ``None``, the flattened array is used. xp : array_namespace, optional - The standard-compatible namespace for `x`. Default: infer. + The standard-compatible namespace for `a`. Default: infer. Returns ------- @@ -908,45 +908,179 @@ def quantile( xp: ModuleType | None = None, ) -> Array: """ - TODO + Compute the q-th quantile of the data along the specified axis. + + Parameters + ---------- + a : array_like of real numbers + Input array or object that can be converted to an array. 
+ q : array_like of float + Probability or sequence of probabilities of the quantiles to compute. + Values must be between 0 and 1 inclusive. + axis : {int, tuple of int, None}, optional + Axis or axes along which the quantiles are computed. The default is + to compute the quantile(s) along a flattened version of the array. + method : str, optional + This parameter specifies the method to use for estimating the + quantile. There are many different methods. + The recommended options, numbered as they appear in [1]_, are: + + 1. 'inverted_cdf' + 2. 'averaged_inverted_cdf' + 3. 'closest_observation' + 4. 'interpolated_inverted_cdf' + 5. 'hazen' + 6. 'weibull' + 7. 'linear' (default) + 8. 'median_unbiased' + 9. 'normal_unbiased' + + The first three methods are discontinuous. + Only 'linear' is implemented for now. + + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in + the result as dimensions with size one. With this option, the + result will broadcast correctly against the original array `a`. + + xp : array_namespace, optional + The standard-compatible namespace for `a` and `q`. Default: infer. + + Returns + ------- + scalar or ndarray + If `q` is a single probability and `axis=None`, then the result + is a scalar. If multiple probability levels are given, first axis + of the result corresponds to the quantiles. The other axes are + the axes that remain after the reduction of `a`. If the input + contains integers or floats smaller than ``float64``, the output + data-type is ``float64``. Otherwise, the output data-type is the + same as that of the input. If `out` is specified, that array is + returned instead. + + Notes + ----- + Given a sample `a` from an underlying distribution, `quantile` provides a + nonparametric estimate of the inverse cumulative distribution function. 
+ + By default, this is done by interpolating between adjacent elements in + ``y``, a sorted copy of `a`:: + + (1-g)*y[j] + g*y[j+1] + + where the index ``j`` and coefficient ``g`` are the integral and + fractional components of ``q * (n-1)``, and ``n`` is the number of + elements in the sample. + + This is a special case of Equation 1 of H&F [1]_. More generally, + + - ``j = (q*n + m - 1) // 1``, and + - ``g = (q*n + m - 1) % 1``, + + where ``m`` may be defined according to several different conventions. + The preferred convention may be selected using the ``method`` parameter: + + =============================== =============== =============== + ``method`` number in H&F ``m`` + =============================== =============== =============== + ``interpolated_inverted_cdf`` 4 ``0`` + ``hazen`` 5 ``1/2`` + ``weibull`` 6 ``q`` + ``linear`` (default) 7 ``1 - q`` + ``median_unbiased`` 8 ``q/3 + 1/3`` + ``normal_unbiased`` 9 ``q/4 + 3/8`` + =============================== =============== =============== + + Note that indices ``j`` and ``j + 1`` are clipped to the range ``0`` to + ``n - 1`` when the results of the formula would be outside the allowed + range of non-negative indices. The ``- 1`` in the formulas for ``j`` and + ``g`` accounts for Python's 0-based indexing. + + The table above includes only the estimators from H&F that are continuous + functions of probability `q` (estimators 4-9). NumPy also provides the + three discontinuous estimators from H&F (estimators 1-3), where ``j`` is + defined as above, ``m`` is defined as follows, and ``g`` is a function + of the real-valued ``index = q*n + m - 1`` and ``j``. + + 1. ``inverted_cdf``: ``m = 0`` and ``g = int(index - j > 0)`` + 2. ``averaged_inverted_cdf``: ``m = 0`` and + ``g = (1 + int(index - j > 0)) / 2`` + 3. 
``closest_observation``: ``m = -1/2`` and + ``g = 1 - int((index == j) & (j%2 == 1))`` + + **Weighted quantiles:** + More formally, the quantile at probability level :math:`q` of a cumulative + distribution function :math:`F(y)=P(Y \\leq y)` with probability measure + :math:`P` is defined as any number :math:`x` that fulfills the + *coverage conditions* + + .. math:: P(Y < x) \\leq q \\quad\\text{and}\\quad P(Y \\leq x) \\geq q + + with random variable :math:`Y\\sim P`. + Sample quantiles, the result of `quantile`, provide nonparametric + estimation of the underlying population counterparts, represented by the + unknown :math:`F`, given a data vector `a` of length ``n``. + + Some of the estimators above arise when one considers :math:`F` as the + empirical distribution function of the data, i.e. + :math:`F(y) = \\frac{1}{n} \\sum_i 1_{a_i \\leq y}`. + Then, different methods correspond to different choices of :math:`x` that + fulfill the above coverage conditions. Methods that follow this approach + are ``inverted_cdf`` and ``averaged_inverted_cdf``. + + For weighted quantiles, the coverage conditions still hold. The + empirical cumulative distribution is simply replaced by its weighted + version, i.e. + :math:`P(Y \\leq t) = \\frac{1}{\\sum_i w_i} \\sum_i w_i 1_{x_i \\leq t}`. + Only ``method="inverted_cdf"`` supports weights. + + References + ---------- + .. [1] R. J. Hyndman and Y. Fan, + "Sample quantiles in statistical packages," + The American Statistician, 50(4), pp. 361-365, 1996 """ methods = {"linear"} if method not in methods: - message = f"`method` must be one of {methods}" - raise ValueError(message) + msg = f"`method` must be one of {methods}" + raise ValueError(msg) if keepdims not in {True, False}: - message = "If specified, `keepdims` must be True or False." - raise ValueError(message) + msg = "If specified, `keepdims` must be True or False." 
+ raise ValueError(msg) if xp is None: xp = array_namespace(a) a = xp.asarray(a) - if not xp.isdtype(a.dtype, ('integral', 'real floating')): - raise ValueError("`a` must have real dtype.") - if not xp.isdtype(xp.asarray(q).dtype, 'real floating'): - raise ValueError("`q` must have real floating dtype.") + if not xp.isdtype(a.dtype, ("integral", "real floating")): + msg = "`a` must have real dtype." + raise ValueError(msg) + if not xp.isdtype(xp.asarray(q).dtype, "real floating"): + msg = "`q` must have real floating dtype." + raise ValueError(msg) ndim = a.ndim if ndim < 1: msg = "`a` must be at least 1-dimensional" raise TypeError(msg) if axis is not None and ((axis >= ndim) or (axis < -ndim)): - message = "`axis` is not compatible with the dimension of `a`." - raise ValueError(message) + msg = "`axis` is not compatible with the dimension of `a`." + raise ValueError(msg) # Array API states: Mixed integer and floating-point type promotion rules # are not specified because behavior varies between implementations. - # => We choose to do: - dtype = ( - xp.float64 if xp.isdtype(a.dtype, 'integral') - else xp.result_type(a, xp.asarray(q)) # both a and q are floats + # We chose to align with numpy (see docstring): + dtype = xp.result_type( + xp.float64 if xp.isdtype(a.dtype, "integral") else a, + xp.asarray(q), + xp.float64, # at least float64 ) device = get_device(a) a = xp.asarray(a, dtype=dtype, device=device) q = xp.asarray(q, dtype=dtype, device=device) if xp.any((q > 1) | (q < 0) | xp.isnan(q)): - raise ValueError("`q` values must be in the range [0, 1]") + msg = "`q` values must be in the range [0, 1]" + raise ValueError(msg) # Delegate where possible. 
if is_numpy_namespace(xp): diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index e0436015..b9ba2158 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -1,3 +1,5 @@ +"""Implementations of the quantile function.""" + from types import ModuleType from ._utils._compat import device as get_device @@ -9,7 +11,7 @@ def quantile( # numpydoc ignore=PR01,RT01 a: Array, q: Array | float, /, - method: str = 'linear', # noqa: ARG001 + method: str = "linear", # noqa: ARG001 axis: int | None = None, keepdims: bool = False, *, @@ -17,14 +19,11 @@ def quantile( # numpydoc ignore=PR01,RT01 ) -> Array: """See docstring in `array_api_extra._delegation.py`.""" device = get_device(a) - floating_dtype = xp.float64 #xp.result_type(a, xp.asarray(q)) + floating_dtype = xp.float64 # xp.result_type(a, xp.asarray(q)) a = xp.asarray(a, dtype=floating_dtype, device=device) a_shape = list(a.shape) p: Array = xp.asarray(q, dtype=floating_dtype, device=device) - if xp.any((p > 1) | (p < 0) | xp.isnan(p)): - raise ValueError("`q` values must be in the range [0, 1]") - q_scalar = p.ndim == 0 if q_scalar: p = xp.reshape(p, (1,)) @@ -37,7 +36,7 @@ def quantile( # numpydoc ignore=PR01,RT01 else: axis = int(axis) - n, = eager_shape(a, axis) + (n,) = eager_shape(a, axis) # If data has length zero along `axis`, the result will be an array of NaNs just # as if the data had length 1 along axis and were filled with NaNs. if n == 0: @@ -66,22 +65,23 @@ def quantile( # numpydoc ignore=PR01,RT01 return res[0, ...] 
if q_scalar else res -def _quantile_hf(a: Array, q: Array, n: float, axis: int, xp: ModuleType) -> Array: +def _quantile_hf( # numpydoc ignore=GL08 + a: Array, q: Array, n: float, axis: int, xp: ModuleType +) -> Array: m = 1 - q - jg = q*n + m - 1 + jg = q * n + m - 1 j = jg // 1 - j = xp.clip(j, 0., n - 1) - jp1 = xp.clip(j + 1, 0., n - 1) + j = xp.clip(j, 0.0, n - 1) + jp1 = xp.clip(j + 1, 0.0, n - 1) # `̀j` and `jp1` are 1d arrays g = jg % 1 - g = xp.where(j < 0, 0, g) # equiv to g[j < 0] = 0, but work with strictest + g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with strictest new_g_shape = [1] * a.ndim new_g_shape[axis] = g.shape[0] g = xp.reshape(g, tuple(new_g_shape)) - return ( - (1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis) - + g * xp.take(a, xp.astype(jp1, xp.int64), axis=axis) + return (1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis) + g * xp.take( + a, xp.astype(jp1, xp.int64), axis=axis ) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 78ca9518..200a0a74 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1547,8 +1547,9 @@ def test_multiple_quantiles(self, xp: ModuleType): xp_assert_close(actual, expect) def test_shape(self, xp: ModuleType): - a = xp.asarray(np.random.rand(3, 4, 5)) - q = xp.asarray(np.random.rand(2)) + rng = np.random.default_rng() + a = xp.asarray(rng.random((3, 4, 5))) + q = xp.asarray(rng.random(2)) assert quantile(a, q, axis=0).shape == (2, 4, 5) assert quantile(a, q, axis=1).shape == (2, 3, 5) assert quantile(a, q, axis=2).shape == (2, 3, 4) @@ -1558,8 +1559,9 @@ def test_shape(self, xp: ModuleType): assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1) def test_against_numpy(self, xp: ModuleType): - a_np = np.random.rand(3, 4, 5) - q_np = np.random.rand(2) + rng = np.random.default_rng() + a_np = rng.random((3, 4, 5)) + q_np = rng.random(2) a = xp.asarray(a_np) q = xp.asarray(q_np) for keepdims in [False, True]: @@ -1583,7 +1585,7 @@ def 
test_2d_axis_keepdims(self, xp: ModuleType): def test_methods(self, xp: ModuleType): x = xp.asarray([1, 2, 3, 4, 5]) - methods = ["linear"] #"hazen", "weibull"] + methods = ["linear"] # "hazen", "weibull"] for method in methods: actual = quantile(x, 0.5, method=method) # All methods should give reasonable results @@ -1617,7 +1619,7 @@ def test_invalid_q(self, xp: ModuleType): _ = quantile(x, -0.5) def test_device(self, xp: ModuleType, device: Device): - if hasattr(device, 'type') and getattr(device, 'type') == "meta": + if hasattr(device, "type") and device.type == "meta": # pyright: ignore[reportAttributeAccessIssue] pytest.xfail("No Tensor.item() on meta device") x = xp.asarray([1, 2, 3, 4, 5], device=device) actual = quantile(x, 0.5) From 19fa6ead7f498f5ff1c4a51a82067f4645b08722 Mon Sep 17 00:00:00 2001 From: Arthur Date: Wed, 22 Oct 2025 18:50:50 +0200 Subject: [PATCH 10/25] WIP: adding support for weights --- src/array_api_extra/_delegation.py | 23 ++++++-- src/array_api_extra/_lib/_quantile.py | 79 +++++++++++++++++++++------ tests/test_funcs.py | 2 +- 3 files changed, 81 insertions(+), 23 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 1eea8215..3fb8dee1 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -905,6 +905,7 @@ def quantile( method: str = "linear", keepdims: bool = False, *, + weights: Array | None = None, xp: ModuleType | None = None, ) -> Array: """ @@ -943,6 +944,16 @@ def quantile( the result as dimensions with size one. With this option, the result will broadcast correctly against the original array `a`. + weights : array_like, optional + An array of weights associated with the values in `a`. Each value in + `a` contributes to the quantile according to its associated weight. + The weights array can either be 1-D (in which case its length must be + the size of `a` along the given axis) or of the same shape as `a`. 
+ If `weights=None`, then all data in `a` are assumed to have a + weight equal to one. + Only `method="inverted_cdf"` or `method="averaged_inverted_cdf"` + support weights. See the notes for more details. + xp : array_namespace, optional The standard-compatible namespace for `a` and `q`. Default: infer. @@ -1040,7 +1051,7 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - methods = {"linear"} + methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"} if method not in methods: msg = f"`method` must be one of {methods}" @@ -1084,12 +1095,12 @@ def quantile( # Delegate where possible. if is_numpy_namespace(xp): + return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims, weights=weights) + # No delegation for dask: I couldn't make it work + basic_case = method == "linear" and weights is None + if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp): return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims) - # No delegating for dask: I couldn't make it work - is_linear = method == "linear" - if (is_linear and is_jax_namespace(xp)) or is_cupy_namespace(xp): - return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims) - if is_linear and is_torch_namespace(xp): + if basic_case and is_torch_namespace(xp): return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims) # Otherwise call our implementation (will sort data) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index b9ba2158..838a3d3f 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -9,24 +9,22 @@ def quantile( # numpydoc ignore=PR01,RT01 a: Array, - q: Array | float, + q: Array, /, - method: str = "linear", # noqa: ARG001 + method: str = "linear", axis: int | None = None, keepdims: bool = False, *, + weights: Array | None = None, xp: ModuleType, ) -> Array: """See docstring in 
`array_api_extra._delegation.py`.""" device = get_device(a) - floating_dtype = xp.float64 # xp.result_type(a, xp.asarray(q)) - a = xp.asarray(a, dtype=floating_dtype, device=device) a_shape = list(a.shape) - p: Array = xp.asarray(q, dtype=floating_dtype, device=device) - q_scalar = p.ndim == 0 + q_scalar = q.ndim == 0 if q_scalar: - p = xp.reshape(p, (1,)) + q = xp.reshape(q, (1,)) axis_none = axis is None a_ndim = a.ndim @@ -42,33 +40,41 @@ def quantile( # numpydoc ignore=PR01,RT01 if n == 0: a_shape[axis] = 1 n = 1 - a = xp.full(tuple(a_shape), xp.nan, dtype=floating_dtype, device=device) + a = xp.full(tuple(a_shape), xp.nan, dtype=a.dtype, device=device) - a = xp.sort(a, axis=axis, stable=False) + if weights is None: + res = _quantile(a, q, float(n), axis, method, xp) + else: + average = method == 'averaged_inverted_cdf' + res = _weighted_quantile(a, q, weights, n, axis, average, xp) # to support weights, the main thing would be to # argsort a, and then use it to sort a and w. # The hard part will be dealing with 0-weights and NaNs # But maybe a proper use of searchsorted + left/right side will work? - res = _quantile_hf(a, p, float(n), axis, xp) - # reshaping to conform to doc/other libs' behavior if axis_none: if keepdims: - res = xp.reshape(res, p.shape + (1,) * a_ndim) + res = xp.reshape(res, q.shape + (1,) * a_ndim) else: res = xp.moveaxis(res, axis, 0) if keepdims: a_shape[axis] = 1 - res = xp.reshape(res, p.shape + tuple(a_shape)) + res = xp.reshape(res, q.shape + tuple(a_shape)) return res[0, ...] 
if q_scalar else res -def _quantile_hf( # numpydoc ignore=GL08 - a: Array, q: Array, n: float, axis: int, xp: ModuleType +def _quantile( # numpydoc ignore=GL08 + a: Array, q: Array, n: float, axis: int, method: str, xp: ModuleType ) -> Array: - m = 1 - q + a = xp.sort(a, axis=axis, stable=False) + + if method == "linear": + m = 1 - q + else: # method is "inverted_cdf" or "averaged_inverted_cdf" + m = 0 + jg = q * n + m - 1 j = jg // 1 @@ -77,6 +83,11 @@ def _quantile_hf( # numpydoc ignore=GL08 # `̀j` and `jp1` are 1d arrays g = jg % 1 + if method == 'inverted_cdf': + g = xp.astype((g > 0), jg.dtype) + elif method == 'averaged_inverted_cdf': + g = (1 + xp.astype((g > 0), jg.dtype)) / 2 + g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with strictest new_g_shape = [1] * a.ndim new_g_shape[axis] = g.shape[0] @@ -85,3 +96,39 @@ def _quantile_hf( # numpydoc ignore=GL08 return (1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis) + g * xp.take( a, xp.astype(jp1, xp.int64), axis=axis ) + + +def _weighted_quantile(a: Array, q: Array, weights: Array, n: int, axis, average: bool, xp: ModuleType): + a = xp.moveaxis(a, axis, -1) + sorter = xp.argsort(a, axis=-1, stable=False) + a = xp.take_along_axis(a, sorter, axis=-1) + + if a.ndim == 1: + return _weighted_quantile_sorted_1d(a, q, weights, n, ) + + d, = eager_shape(a, axis=0) + res = xp.empty((q.shape[0], d)) + for idx in range(d): + w = weights if weights.ndim == 1 else weights[idx, ...] 
+ w = xp.take(w, sorter[idx, ...]) + res[..., idx] = _weighted_quantile_sorted_1d(a[idx, ...], q, w, n, average) + return res + + +def _weighted_quantile_sorted_1d(a, q, w, n, average: bool, xp: ModuleType): + cw = xp.cumsum(w) + t = cw[-1] * q + i = xp.searchsorted(cw, t) + j = xp.searchsorted(cw, t, side='right') + i = xp.minimum(i, float(n - 1)) + j = xp.minimum(j, float(n - 1)) + + # Ignore leading `weights=0` observations when `q=0` + # see https://github.com/scikit-learn/scikit-learn/pull/20528 + i = xp.where(q == 0., j, i) + if average: + # Ignore trailing `weights=0` observations when `q=1` + j = xp.where(q == 1., i, j) + return (xp.take(a, i) + xp.take(a, j)) / 2 + else: + return xp.take(a, i) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 200a0a74..2a0bba5e 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1585,7 +1585,7 @@ def test_2d_axis_keepdims(self, xp: ModuleType): def test_methods(self, xp: ModuleType): x = xp.asarray([1, 2, 3, 4, 5]) - methods = ["linear"] # "hazen", "weibull"] + methods = ["linear", "inverted_cdf", "averaged_inverted_cdf"] for method in methods: actual = quantile(x, 0.5, method=method) # All methods should give reasonable results From fa789fc2d199d9d1fee0be2772dfc733710b6ee3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 07:37:19 +0200 Subject: [PATCH 11/25] Weighted quantile; nan-policy; everything mostly works --- src/array_api_extra/_delegation.py | 50 +++++++++++---- src/array_api_extra/_lib/_quantile.py | 89 +++++++++++++++++---------- tests/test_funcs.py | 68 +++++++++++++++++--- 3 files changed, 155 insertions(+), 52 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 3fb8dee1..bed61553 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -904,6 +904,7 @@ def quantile( axis: int | None = None, method: str = "linear", keepdims: bool = False, + nan_policy: str = "propagate", *, weights: Array | 
None = None, xp: ModuleType | None = None, @@ -1051,16 +1052,22 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"} + if xp is None: + xp = array_namespace(a) + if is_pydata_sparse_namespace(xp): + raise ValueError('no supported') + methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"} if method not in methods: msg = f"`method` must be one of {methods}" raise ValueError(msg) + nan_policies = {"propagate", "omit"} + if nan_policy not in nan_policies: + msg = f"`nan_policy` must be one of {nan_policies}" + raise ValueError(msg) if keepdims not in {True, False}: msg = "If specified, `keepdims` must be True or False." raise ValueError(msg) - if xp is None: - xp = array_namespace(a) a = xp.asarray(a) if not xp.isdtype(a.dtype, ("integral", "real floating")): @@ -1071,15 +1078,31 @@ def quantile( raise ValueError(msg) ndim = a.ndim if ndim < 1: - msg = "`a` must be at least 1-dimensional" + msg = "`a` must be at least 1-dimensional." raise TypeError(msg) if axis is not None and ((axis >= ndim) or (axis < -ndim)): msg = "`axis` is not compatible with the dimension of `a`." raise ValueError(msg) - - # Array API states: Mixed integer and floating-point type promotion rules - # are not specified because behavior varies between implementations. - # We chose to align with numpy (see docstring): + if weights is None: + if nan_policy != "propagate": + msg = "" + raise ValueError(msg) + else: + if ndim > 2: + msg = "When weights are provided, dimension of `a` must be 1 or 2." + raise ValueError(msg) + if a.shape != weights.shape: + if axis is None: + msg = "Axis must be specified when shapes of `a` and ̀ weights` differ." + raise TypeError(msg) + if weights.shape != eager_shape(a, axis): + msg = "Shape of weights must be consistent with shape of a along specified axis." 
+ raise ValueError(msg) + if axis is None and ndim == 2: + msg = "When weights are provided, axis must be specified when `a` is 2d" + raise ValueError(msg) + + # Align result dtype with what numpy does: dtype = xp.result_type( xp.float64 if xp.isdtype(a.dtype, "integral") else a, xp.asarray(q), @@ -1088,20 +1111,25 @@ def quantile( device = get_device(a) a = xp.asarray(a, dtype=dtype, device=device) q = xp.asarray(q, dtype=dtype, device=device) + # TODO: cast weights here? Assert weights are on the same device as `a`? if xp.any((q > 1) | (q < 0) | xp.isnan(q)): msg = "`q` values must be in the range [0, 1]" raise ValueError(msg) # Delegate where possible. - if is_numpy_namespace(xp): + if is_numpy_namespace(xp) and nan_policy == "propagate": return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims, weights=weights) # No delegation for dask: I couldn't make it work - basic_case = method == "linear" and weights is None + basic_case = method == "linear" and weights is None and nan_policy == "propagate" if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp): return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims) if basic_case and is_torch_namespace(xp): return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims) + # XXX: I'm not sure we want to support dask, it seems uterly slow... 
# Otherwise call our implementation (will sort data) - return _quantile.quantile(a, q, axis=axis, method=method, keepdims=keepdims, xp=xp) + return _quantile.quantile( + a, q, axis=axis, method=method, keepdims=keepdims, + nan_policy=nan_policy, weights=weights, xp=xp + ) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 838a3d3f..9e8613bd 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -4,7 +4,7 @@ from ._utils._compat import device as get_device from ._utils._helpers import eager_shape -from ._utils._typing import Array +from ._utils._typing import Array, Device def quantile( # numpydoc ignore=PR01,RT01 @@ -14,6 +14,7 @@ def quantile( # numpydoc ignore=PR01,RT01 method: str = "linear", axis: int | None = None, keepdims: bool = False, + nan_policy: str = "propagate", *, weights: Array | None = None, xp: ModuleType, @@ -43,43 +44,49 @@ def quantile( # numpydoc ignore=PR01,RT01 a = xp.full(tuple(a_shape), xp.nan, dtype=a.dtype, device=device) if weights is None: - res = _quantile(a, q, float(n), axis, method, xp) + res = _quantile(a, q, n, axis, method, xp) + if not axis_none: + res = xp.moveaxis(res, axis, 0) else: + weights = xp.asarray(weights, dtype=xp.float64, device=device) average = method == 'averaged_inverted_cdf' - res = _weighted_quantile(a, q, weights, n, axis, average, xp) - # to support weights, the main thing would be to - # argsort a, and then use it to sort a and w. - # The hard part will be dealing with 0-weights and NaNs - # But maybe a proper use of searchsorted + left/right side will work? 
+ res = _weighted_quantile( + a, q, weights, n, axis, average, + nan_policy=nan_policy, xp=xp, device=device + ) # reshaping to conform to doc/other libs' behavior if axis_none: if keepdims: res = xp.reshape(res, q.shape + (1,) * a_ndim) - else: - res = xp.moveaxis(res, axis, 0) - if keepdims: - a_shape[axis] = 1 - res = xp.reshape(res, q.shape + tuple(a_shape)) + elif keepdims: + a_shape[axis] = 1 + res = xp.reshape(res, q.shape + tuple(a_shape)) return res[0, ...] if q_scalar else res def _quantile( # numpydoc ignore=GL08 - a: Array, q: Array, n: float, axis: int, method: str, xp: ModuleType + a: Array, q: Array, n: int, axis: int, method: str, xp: ModuleType ) -> Array: a = xp.sort(a, axis=axis, stable=False) + mask_nan = xp.any(xp.isnan(a), axis=axis, keepdims=True) + if xp.any(mask_nan): + # propogate NaNs: + mask = xp.repeat(mask_nan, n, axis=axis) + a = xp.where(mask, xp.nan, a) + del mask if method == "linear": - m = 1 - q + m = 1 - q else: # method is "inverted_cdf" or "averaged_inverted_cdf" m = 0 - jg = q * n + m - 1 + jg = q * float(n) + m - 1 j = jg // 1 - j = xp.clip(j, 0.0, n - 1) - jp1 = xp.clip(j + 1, 0.0, n - 1) + j = xp.clip(j, 0.0, float(n - 1)) + jp1 = xp.clip(j + 1, 0.0, float(n - 1)) # `̀j` and `jp1` are 1d arrays g = jg % 1 @@ -88,7 +95,7 @@ def _quantile( # numpydoc ignore=GL08 elif method == 'averaged_inverted_cdf': g = (1 + xp.astype((g > 0), jg.dtype)) / 2 - g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with strictest + g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with readonly new_g_shape = [1] * a.ndim new_g_shape[axis] = g.shape[0] g = xp.reshape(g, tuple(new_g_shape)) @@ -98,37 +105,55 @@ def _quantile( # numpydoc ignore=GL08 ) -def _weighted_quantile(a: Array, q: Array, weights: Array, n: int, axis, average: bool, xp: ModuleType): +def _weighted_quantile( + a: Array, q: Array, weights: Array, n: int, axis: int, average: bool, nan_policy: str, + xp: ModuleType, device: Device +) -> Array: + 
""" + a is expected to be 1d or 2d. + """ + kwargs = dict(n=n, average=average, nan_policy=nan_policy, xp=xp, device=device) a = xp.moveaxis(a, axis, -1) + if weights.ndim > 1: + weights = xp.moveaxis(weights, axis, -1) sorter = xp.argsort(a, axis=-1, stable=False) - a = xp.take_along_axis(a, sorter, axis=-1) if a.ndim == 1: - return _weighted_quantile_sorted_1d(a, q, weights, n, ) + x = xp.take(a, sorter) + w = xp.take(weights, sorter) + return _weighted_quantile_sorted_1d(x, q, w, **kwargs) d, = eager_shape(a, axis=0) - res = xp.empty((q.shape[0], d)) + res = [] for idx in range(d): w = weights if weights.ndim == 1 else weights[idx, ...] w = xp.take(w, sorter[idx, ...]) - res[..., idx] = _weighted_quantile_sorted_1d(a[idx, ...], q, w, n, average) + x = xp.take(a[idx, ...], sorter[idx, ...]) + res.append(_weighted_quantile_sorted_1d(x, q, w, **kwargs)) + res = xp.stack(res, axis=1) return res -def _weighted_quantile_sorted_1d(a, q, w, n, average: bool, xp: ModuleType): - cw = xp.cumsum(w) +def _weighted_quantile_sorted_1d( + x: Array, q: Array, w: Array, n: int, average: bool, nan_policy: str, + xp: ModuleType, device: Device +) -> Array: + if nan_policy == "omit": + w = xp.where(xp.isnan(x), 0., w) + elif xp.any(xp.isnan(x)): + return xp.full(q.shape, xp.nan, dtype=x.dtype, device=device) + cw = xp.cumulative_sum(w) t = cw[-1] * q - i = xp.searchsorted(cw, t) + i = xp.searchsorted(cw, t, side='left') j = xp.searchsorted(cw, t, side='right') - i = xp.minimum(i, float(n - 1)) - j = xp.minimum(j, float(n - 1)) + i = xp.clip(i, 0, n - 1) + j = xp.clip(j, 0, n - 1) # Ignore leading `weights=0` observations when `q=0` # see https://github.com/scikit-learn/scikit-learn/pull/20528 - i = xp.where(q == 0., j, i) + i = xp.where(q == 0., j, i) if average: # Ignore trailing `weights=0` observations when `q=1` j = xp.where(q == 1., i, j) - return (xp.take(a, i) + xp.take(a, j)) / 2 - else: - return xp.take(a, i) + return (xp.take(x, i) + xp.take(x, j)) / 2 + return xp.take(x, 
i) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 2a0bba5e..8a12bb12 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1558,18 +1558,70 @@ def test_shape(self, xp: ModuleType): assert quantile(a, q, axis=1, keepdims=True).shape == (2, 3, 1, 5) assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1) - def test_against_numpy(self, xp: ModuleType): + @pytest.mark.parametrize("keepdims", [True, False]) + def test_against_numpy(self, xp: ModuleType, keepdims: bool): rng = np.random.default_rng() a_np = rng.random((3, 4, 5)) q_np = rng.random(2) a = xp.asarray(a_np) q = xp.asarray(q_np) - for keepdims in [False, True]: - for axis in [None, *range(a.ndim)]: - actual = quantile(a, q, axis=axis, keepdims=keepdims) - expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims) - expected = xp.asarray(expected, dtype=xp.float64) - xp_assert_close(actual, expected, atol=1e-12) + for axis in [None, *range(a.ndim)]: + actual = quantile(a, q, axis=axis, keepdims=keepdims) + expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims) + expected = xp.asarray(expected) + xp_assert_close(actual, expected, atol=1e-12) + + @pytest.mark.parametrize("keepdims", [True, False]) + @pytest.mark.parametrize("nan_policy", ["omit", "no_nans", "propagate"])#, #["omit"])#["no_nans", "propagate"]) + @pytest.mark.parametrize("q_np", [0.5, 0., 1., np.linspace(0, 1, num=11)]) + def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str): + rng = np.random.default_rng() + n, d = 10, 20 + a_np = rng.random((n, d)) + kwargs = dict(keepdims=keepdims) + mask_nan = np.zeros((n, d), dtype=bool) + if nan_policy != "no_nans": + # from 0% to 100% of NaNs: + mask_nan = rng.random((n, d)) < rng.random((n, 1)) + # don't put nans in the first row: + mask_nan[:] = False + a_np[mask_nan] = np.nan + kwargs['nan_policy'] = nan_policy + + a = xp.asarray(a_np) + q = xp.asarray(np.copy(q_np)) + m = 'inverted_cdf' + + 
np_quantile = np.quantile + if nan_policy == "omit": + np_quantile = np.nanquantile + + for w_np, axis in [ + (rng.random(n), 0), + (rng.random(d), 1), + (rng.integers(0, 2, n), 0), + (rng.integers(0, 2, d), 1), + (rng.integers(0, 2, (n, d)), 0), + (rng.integers(0, 2, (n, d)), 1), + ]: + print(w_np) + with warnings.catch_warnings(record=True) as warning: + warnings.filterwarnings("always", "invalid value encountered in divide", RuntimeWarning) + warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning) + try: + expected = np_quantile(a_np, q_np, axis=axis, method=m, weights=w_np, keepdims=keepdims) + except IndexError: + print('index error') + continue + if warning: # this means some weights sum was 0, in this case we skip calling xpx.quantile + print('warning') + continue + expected = xp.asarray(expected) + print("not skiped") + + w = xp.asarray(w_np) + actual = quantile(a, q, axis=axis, method=m, weights=w, **kwargs) + xp_assert_close(actual, expected, atol=1e-12) def test_2d_axis(self, xp: ModuleType): x = xp.asarray([[1, 2, 3], [4, 5, 6]]) @@ -1605,8 +1657,6 @@ def test_edge_cases(self, xp: ModuleType): def test_invalid_q(self, xp: ModuleType): x = xp.asarray([1, 2, 3, 4, 5]) - _ = quantile(x, 1.0) - # ^ FIXME: here just to make this test fail for sparse backend # q > 1 should raise with pytest.raises( ValueError, match=r"`q` values must be in the range \[0, 1\]" From 1d8fef7c1f5d60aabbf2a19148ac0287dfdae92a Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 13:25:34 +0200 Subject: [PATCH 12/25] linting: pyright & mypy --- src/array_api_extra/_delegation.py | 12 ++++++------ src/array_api_extra/_lib/_quantile.py | 11 +++++------ tests/test_funcs.py | 23 ++++++++++++----------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index bed61553..3623f33f 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ 
-1110,26 +1110,26 @@ def quantile( ) device = get_device(a) a = xp.asarray(a, dtype=dtype, device=device) - q = xp.asarray(q, dtype=dtype, device=device) + q_arr = xp.asarray(q, dtype=dtype, device=device) # TODO: cast weights here? Assert weights are on the same device as `a`? - if xp.any((q > 1) | (q < 0) | xp.isnan(q)): + if xp.any((q_arr > 1) | (q_arr < 0) | xp.isnan(q_arr)): msg = "`q` values must be in the range [0, 1]" raise ValueError(msg) # Delegate where possible. if is_numpy_namespace(xp) and nan_policy == "propagate": - return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims, weights=weights) + return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights) # No delegation for dask: I couldn't make it work basic_case = method == "linear" and weights is None and nan_policy == "propagate" if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp): - return xp.quantile(a, q, axis=axis, method=method, keepdims=keepdims) + return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims) if basic_case and is_torch_namespace(xp): - return xp.quantile(a, q, dim=axis, interpolation=method, keepdim=keepdims) + return xp.quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims) # XXX: I'm not sure we want to support dask, it seems uterly slow... 
# Otherwise call our implementation (will sort data) return _quantile.quantile( - a, q, axis=axis, method=method, keepdims=keepdims, + a, q_arr, axis=axis, method=method, keepdims=keepdims, nan_policy=nan_policy, weights=weights, xp=xp ) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 9e8613bd..3dd7d9dc 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -48,10 +48,10 @@ def quantile( # numpydoc ignore=PR01,RT01 if not axis_none: res = xp.moveaxis(res, axis, 0) else: - weights = xp.asarray(weights, dtype=xp.float64, device=device) + weights_arr = xp.asarray(weights, dtype=xp.float64, device=device) average = method == 'averaged_inverted_cdf' res = _weighted_quantile( - a, q, weights, n, axis, average, + a, q, weights_arr, n, axis, average, nan_policy=nan_policy, xp=xp, device=device ) @@ -80,7 +80,7 @@ def _quantile( # numpydoc ignore=GL08 if method == "linear": m = 1 - q else: # method is "inverted_cdf" or "averaged_inverted_cdf" - m = 0 + m = xp.asarray(0, dtype=q.dtype) jg = q * float(n) + m - 1 @@ -112,7 +112,6 @@ def _weighted_quantile( """ a is expected to be 1d or 2d. """ - kwargs = dict(n=n, average=average, nan_policy=nan_policy, xp=xp, device=device) a = xp.moveaxis(a, axis, -1) if weights.ndim > 1: weights = xp.moveaxis(weights, axis, -1) @@ -121,7 +120,7 @@ def _weighted_quantile( if a.ndim == 1: x = xp.take(a, sorter) w = xp.take(weights, sorter) - return _weighted_quantile_sorted_1d(x, q, w, **kwargs) + return _weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device) d, = eager_shape(a, axis=0) res = [] @@ -129,7 +128,7 @@ def _weighted_quantile( w = weights if weights.ndim == 1 else weights[idx, ...] 
w = xp.take(w, sorter[idx, ...]) x = xp.take(a[idx, ...], sorter[idx, ...]) - res.append(_weighted_quantile_sorted_1d(x, q, w, **kwargs)) + res.append(_weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device)) res = xp.stack(res, axis=1) return res diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 8a12bb12..298cf274 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1578,18 +1578,18 @@ def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Arra rng = np.random.default_rng() n, d = 10, 20 a_np = rng.random((n, d)) - kwargs = dict(keepdims=keepdims) mask_nan = np.zeros((n, d), dtype=bool) - if nan_policy != "no_nans": + if nan_policy == "no_nans": + nan_policy = "propagate" + else: # from 0% to 100% of NaNs: mask_nan = rng.random((n, d)) < rng.random((n, 1)) # don't put nans in the first row: mask_nan[:] = False a_np[mask_nan] = np.nan - kwargs['nan_policy'] = nan_policy - a = xp.asarray(a_np) - q = xp.asarray(np.copy(q_np)) + a = xp.asarray(a_np, copy=True) + q = xp.asarray(q_np, copy=True) m = 'inverted_cdf' np_quantile = np.quantile @@ -1604,23 +1604,24 @@ def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Arra (rng.integers(0, 2, (n, d)), 0), (rng.integers(0, 2, (n, d)), 1), ]: - print(w_np) with warnings.catch_warnings(record=True) as warning: warnings.filterwarnings("always", "invalid value encountered in divide", RuntimeWarning) warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning) try: - expected = np_quantile(a_np, q_np, axis=axis, method=m, weights=w_np, keepdims=keepdims) + expected = np_quantile( # type: ignore[call-overload] + a_np, np.asarray(q_np), + axis=axis, method=m, weights=w_np, keepdims=keepdims + ) except IndexError: - print('index error') continue if warning: # this means some weights sum was 0, in this case we skip calling xpx.quantile - print('warning') continue expected = xp.asarray(expected) - print("not skiped") w = 
xp.asarray(w_np) - actual = quantile(a, q, axis=axis, method=m, weights=w, **kwargs) + actual = quantile( + a, q, axis=axis, method=m, weights=w, keepdims=keepdims, nan_policy=nan_policy + ) xp_assert_close(actual, expected, atol=1e-12) def test_2d_axis(self, xp: ModuleType): From 26804fedb777777fb1907efdb9ed322d5dfa5435 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 13:32:20 +0200 Subject: [PATCH 13/25] linting: ruff --- src/array_api_extra/_delegation.py | 22 +++++++-- src/array_api_extra/_lib/_quantile.py | 70 ++++++++++++++++++--------- tests/test_funcs.py | 38 ++++++++++----- 3 files changed, 91 insertions(+), 39 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 3623f33f..5045525c 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -1055,7 +1055,8 @@ def quantile( if xp is None: xp = array_namespace(a) if is_pydata_sparse_namespace(xp): - raise ValueError('no supported') + msg = "Sparse backend not supported" + raise ValueError(msg) methods = {"linear", "inverted_cdf", "averaged_inverted_cdf"} if method not in methods: @@ -1096,7 +1097,10 @@ def quantile( msg = "Axis must be specified when shapes of `a` and ̀ weights` differ." raise TypeError(msg) if weights.shape != eager_shape(a, axis): - msg = "Shape of weights must be consistent with shape of a along specified axis." + msg = ( + "Shape of weights must be consistent with shape" + " of a along specified axis." + ) raise ValueError(msg) if axis is None and ndim == 2: msg = "When weights are provided, axis must be specified when `a` is 2d" @@ -1119,7 +1123,9 @@ def quantile( # Delegate where possible. 
if is_numpy_namespace(xp) and nan_policy == "propagate": - return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights) + return xp.quantile( + a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights + ) # No delegation for dask: I couldn't make it work basic_case = method == "linear" and weights is None and nan_policy == "propagate" if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp): @@ -1130,6 +1136,12 @@ def quantile( # XXX: I'm not sure we want to support dask, it seems uterly slow... # Otherwise call our implementation (will sort data) return _quantile.quantile( - a, q_arr, axis=axis, method=method, keepdims=keepdims, - nan_policy=nan_policy, weights=weights, xp=xp + a, + q_arr, + axis=axis, + method=method, + keepdims=keepdims, + nan_policy=nan_policy, + weights=weights, + xp=xp, ) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 3dd7d9dc..f3efa68d 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -49,10 +49,17 @@ def quantile( # numpydoc ignore=PR01,RT01 res = xp.moveaxis(res, axis, 0) else: weights_arr = xp.asarray(weights, dtype=xp.float64, device=device) - average = method == 'averaged_inverted_cdf' + average = method == "averaged_inverted_cdf" res = _weighted_quantile( - a, q, weights_arr, n, axis, average, - nan_policy=nan_policy, xp=xp, device=device + a, + q, + weights_arr, + n, + axis, + average, + nan_policy=nan_policy, + xp=xp, + device=device, ) # reshaping to conform to doc/other libs' behavior @@ -72,15 +79,17 @@ def _quantile( # numpydoc ignore=GL08 a = xp.sort(a, axis=axis, stable=False) mask_nan = xp.any(xp.isnan(a), axis=axis, keepdims=True) if xp.any(mask_nan): - # propogate NaNs: + # propagate NaNs: mask = xp.repeat(mask_nan, n, axis=axis) a = xp.where(mask, xp.nan, a) del mask - if method == "linear": - m = 1 - q - else: # method is "inverted_cdf" or "averaged_inverted_cdf" - m = 
xp.asarray(0, dtype=q.dtype) + m = ( + 1 - q + if method == "linear" + # method is "inverted_cdf" or "averaged_inverted_cdf" + else xp.asarray(0, dtype=q.dtype) + ) jg = q * float(n) + m - 1 @@ -90,9 +99,9 @@ def _quantile( # numpydoc ignore=GL08 # `̀j` and `jp1` are 1d arrays g = jg % 1 - if method == 'inverted_cdf': + if method == "inverted_cdf": g = xp.astype((g > 0), jg.dtype) - elif method == 'averaged_inverted_cdf': + elif method == "averaged_inverted_cdf": g = (1 + xp.astype((g > 0), jg.dtype)) / 2 g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with readonly @@ -106,8 +115,15 @@ def _quantile( # numpydoc ignore=GL08 def _weighted_quantile( - a: Array, q: Array, weights: Array, n: int, axis: int, average: bool, nan_policy: str, - xp: ModuleType, device: Device + a: Array, + q: Array, + weights: Array, + n: int, + axis: int, + average: bool, + nan_policy: str, + xp: ModuleType, + device: Device, ) -> Array: """ a is expected to be 1d or 2d. @@ -122,37 +138,45 @@ def _weighted_quantile( w = xp.take(weights, sorter) return _weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device) - d, = eager_shape(a, axis=0) + (d,) = eager_shape(a, axis=0) res = [] for idx in range(d): w = weights if weights.ndim == 1 else weights[idx, ...] 
w = xp.take(w, sorter[idx, ...]) x = xp.take(a[idx, ...], sorter[idx, ...]) - res.append(_weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device)) - res = xp.stack(res, axis=1) - return res + res.append( + _weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device) + ) + + return xp.stack(res, axis=1) def _weighted_quantile_sorted_1d( - x: Array, q: Array, w: Array, n: int, average: bool, nan_policy: str, - xp: ModuleType, device: Device + x: Array, + q: Array, + w: Array, + n: int, + average: bool, + nan_policy: str, + xp: ModuleType, + device: Device, ) -> Array: if nan_policy == "omit": - w = xp.where(xp.isnan(x), 0., w) + w = xp.where(xp.isnan(x), 0.0, w) elif xp.any(xp.isnan(x)): return xp.full(q.shape, xp.nan, dtype=x.dtype, device=device) cw = xp.cumulative_sum(w) t = cw[-1] * q - i = xp.searchsorted(cw, t, side='left') - j = xp.searchsorted(cw, t, side='right') + i = xp.searchsorted(cw, t, side="left") + j = xp.searchsorted(cw, t, side="right") i = xp.clip(i, 0, n - 1) j = xp.clip(j, 0, n - 1) # Ignore leading `weights=0` observations when `q=0` # see https://github.com/scikit-learn/scikit-learn/pull/20528 - i = xp.where(q == 0., j, i) + i = xp.where(q == 0.0, j, i) if average: # Ignore trailing `weights=0` observations when `q=1` - j = xp.where(q == 1., i, j) + j = xp.where(q == 1.0, i, j) return (xp.take(x, i) + xp.take(x, j)) / 2 return xp.take(x, i) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 298cf274..cf27e676 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1572,9 +1572,11 @@ def test_against_numpy(self, xp: ModuleType, keepdims: bool): xp_assert_close(actual, expected, atol=1e-12) @pytest.mark.parametrize("keepdims", [True, False]) - @pytest.mark.parametrize("nan_policy", ["omit", "no_nans", "propagate"])#, #["omit"])#["no_nans", "propagate"]) - @pytest.mark.parametrize("q_np", [0.5, 0., 1., np.linspace(0, 1, num=11)]) - def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, 
q_np: Array | float, nan_policy: str): + @pytest.mark.parametrize("nan_policy", ["omit", "no_nans", "propagate"]) + @pytest.mark.parametrize("q_np", [0.5, 0.0, 1.0, np.linspace(0, 1, num=11)]) + def test_weighted_against_numpy( + self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str + ): rng = np.random.default_rng() n, d = 10, 20 a_np = rng.random((n, d)) @@ -1590,7 +1592,7 @@ def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Arra a = xp.asarray(a_np, copy=True) q = xp.asarray(q_np, copy=True) - m = 'inverted_cdf' + m = "inverted_cdf" np_quantile = np.quantile if nan_policy == "omit": @@ -1605,22 +1607,36 @@ def test_weighted_against_numpy(self, xp: ModuleType, keepdims: bool, q_np: Arra (rng.integers(0, 2, (n, d)), 1), ]: with warnings.catch_warnings(record=True) as warning: - warnings.filterwarnings("always", "invalid value encountered in divide", RuntimeWarning) - warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning) + divide_msg = "invalid value encountered in divide" + warnings.filterwarnings("always", divide_msg, RuntimeWarning) + nan_slice_msg = "All-NaN slice encountered" + warnings.filterwarnings("ignore", nan_slice_msg, RuntimeWarning) try: expected = np_quantile( # type: ignore[call-overload] - a_np, np.asarray(q_np), - axis=axis, method=m, weights=w_np, keepdims=keepdims + a_np, + np.asarray(q_np), + axis=axis, + method=m, + weights=w_np, + keepdims=keepdims, ) except IndexError: continue - if warning: # this means some weights sum was 0, in this case we skip calling xpx.quantile + if warning: + # this means some weights sum was 0 + # in this case we skip calling xpx.quantile continue expected = xp.asarray(expected) w = xp.asarray(w_np) - actual = quantile( - a, q, axis=axis, method=m, weights=w, keepdims=keepdims, nan_policy=nan_policy + actual = quantile( + a, + q, + axis=axis, + method=m, + weights=w, + keepdims=keepdims, + nan_policy=nan_policy, ) xp_assert_close(actual, 
expected, atol=1e-12) From 3611708c3c9cf89cc7face1667b0d8696e8c8447 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 13:46:46 +0200 Subject: [PATCH 14/25] linting & cleanup --- src/array_api_extra/_delegation.py | 21 ++++++++++++++------- src/array_api_extra/_lib/_quantile.py | 19 +++++++++++-------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 5045525c..9daa7466 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -938,13 +938,16 @@ def quantile( 9. 'normal_unbiased' The first three methods are discontinuous. - Only 'linear' is implemented for now. + Only 'linear', 'inverted_cdf' and 'averaged_inverted_cdf' are implemented. keepdims : bool, optional If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array `a`. + nan_policy : str, optional + 'propagate' (default) or 'omit'. + weights : array_like, optional An array of weights associated with the values in `a`. Each value in `a` contributes to the quantile according to its associated weight. @@ -1121,20 +1124,24 @@ def quantile( msg = "`q` values must be in the range [0, 1]" raise ValueError(msg) - # Delegate where possible. + # Delegate when possible. if is_numpy_namespace(xp) and nan_policy == "propagate": + # TODO: call nanquantile for nan_policy == "omit" once + # https://github.com/numpy/numpy/issues/29709 is fixed return xp.quantile( a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights ) - # No delegation for dask: I couldn't make it work - basic_case = method == "linear" and weights is None and nan_policy == "propagate" - if (basic_case and is_jax_namespace(xp)) or is_cupy_namespace(xp): + # No delegation for dask: I couldn't make it work. 
+ basic_case = method == "linear" and weights is None + jax_or_cupy = is_jax_namespace(xp) or is_cupy_namespace(xp) + if basic_case and nan_policy == "propagate" and jax_or_cupy: return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims) if basic_case and is_torch_namespace(xp): - return xp.quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims) + quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile + return quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims) - # XXX: I'm not sure we want to support dask, it seems uterly slow... # Otherwise call our implementation (will sort data) + # XXX: I'm not sure we want to support dask, it seems uterly slow... return _quantile.quantile( a, q_arr, diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index f3efa68d..efc073fd 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -73,9 +73,10 @@ def quantile( # numpydoc ignore=PR01,RT01 return res[0, ...] if q_scalar else res -def _quantile( # numpydoc ignore=GL08 +def _quantile( # numpydoc ignore=PR01,RT01 a: Array, q: Array, n: int, axis: int, method: str, xp: ModuleType ) -> Array: + """Compute quantile by sorting `a`.""" a = xp.sort(a, axis=axis, stable=False) mask_nan = xp.any(xp.isnan(a), axis=axis, keepdims=True) if xp.any(mask_nan): @@ -114,7 +115,7 @@ def _quantile( # numpydoc ignore=GL08 ) -def _weighted_quantile( +def _weighted_quantile( # numpydoc ignore=PR01,RT01 a: Array, q: Array, weights: Array, @@ -126,7 +127,9 @@ def _weighted_quantile( device: Device, ) -> Array: """ - a is expected to be 1d or 2d. + Compute weighted quantile using searchsorted on CDF. + + `a` is expected to be 1d or 2d. 
""" a = xp.moveaxis(a, axis, -1) if weights.ndim > 1: @@ -151,7 +154,7 @@ def _weighted_quantile( return xp.stack(res, axis=1) -def _weighted_quantile_sorted_1d( +def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08 x: Array, q: Array, w: Array, @@ -165,10 +168,10 @@ def _weighted_quantile_sorted_1d( w = xp.where(xp.isnan(x), 0.0, w) elif xp.any(xp.isnan(x)): return xp.full(q.shape, xp.nan, dtype=x.dtype, device=device) - cw = xp.cumulative_sum(w) - t = cw[-1] * q - i = xp.searchsorted(cw, t, side="left") - j = xp.searchsorted(cw, t, side="right") + cdf = xp.cumulative_sum(w) + t = cdf[-1] * q + i = xp.searchsorted(cdf, t, side="left") + j = xp.searchsorted(cdf, t, side="right") i = xp.clip(i, 0, n - 1) j = xp.clip(j, 0, n - 1) From 7160bae8fc3e72b4040646075193b702c8d1f786 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 15:06:10 +0200 Subject: [PATCH 15/25] fix tests for numpy 1.x --- src/array_api_extra/_delegation.py | 9 +++++---- tests/test_funcs.py | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 9daa7466..e20408f9 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -5,6 +5,7 @@ from typing import Literal from ._lib import _funcs, _quantile +from ._lib._backends import NUMPY_VERSION from ._lib._utils._compat import ( array_namespace, is_cupy_namespace, @@ -1047,7 +1048,6 @@ def quantile( empirical cumulative distribution is simply replaced by its weighted version, i.e. :math:`P(Y \\leq t) = \\frac{1}{\\sum_i w_i} \\sum_i w_i 1_{x_i \\leq t}`. - Only ``method="inverted_cdf"`` supports weights. References ---------- @@ -1125,14 +1125,15 @@ def quantile( raise ValueError(msg) # Delegate when possible. 
- if is_numpy_namespace(xp) and nan_policy == "propagate": + basic_case = method == "linear" and weights is None + np_2 = NUMPY_VERSION >= (2, 0) + if is_numpy_namespace(xp) and nan_policy == "propagate" and (basic_case or np_2): # TODO: call nanquantile for nan_policy == "omit" once # https://github.com/numpy/numpy/issues/29709 is fixed return xp.quantile( a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights ) # No delegation for dask: I couldn't make it work. - basic_case = method == "linear" and weights is None jax_or_cupy = is_jax_namespace(xp) or is_cupy_namespace(xp) if basic_case and nan_policy == "propagate" and jax_or_cupy: return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims) @@ -1141,8 +1142,8 @@ def quantile( return quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims) # Otherwise call our implementation (will sort data) - # XXX: I'm not sure we want to support dask, it seems uterly slow... return _quantile.quantile( + # XXX: I'm not sure we want to support dask, it seems uterly slow... 
a, q_arr, axis=axis, diff --git a/tests/test_funcs.py b/tests/test_funcs.py index cf27e676..0feff80b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1577,6 +1577,8 @@ def test_against_numpy(self, xp: ModuleType, keepdims: bool): def test_weighted_against_numpy( self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str ): + if NUMPY_VERSION < (2, 0): + pytest.xfail(reason="NumPy 1.x does not support weights in quantile") rng = np.random.default_rng() n, d = 10, 20 a_np = rng.random((n, d)) From 0b2cb9b4a24248b01ae030e1f97b074c95c000b4 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 16:14:29 +0200 Subject: [PATCH 16/25] working on coverage --- src/array_api_extra/_delegation.py | 16 +++-- src/array_api_extra/_lib/_quantile.py | 15 ++-- tests/test_funcs.py | 100 +++++++++++++++++--------- 3 files changed, 84 insertions(+), 47 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index e20408f9..8be2c990 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -948,6 +948,7 @@ def quantile( nan_policy : str, optional 'propagate' (default) or 'omit'. + 'omit' is support only when `weights` are provided. weights : array_like, optional An array of weights associated with the values in `a`. Each value in @@ -1125,19 +1126,26 @@ def quantile( raise ValueError(msg) # Delegate when possible. + # Note: No delegation for dask: I couldn't make it work. 
basic_case = method == "linear" and weights is None + np_2 = NUMPY_VERSION >= (2, 0) - if is_numpy_namespace(xp) and nan_policy == "propagate" and (basic_case or np_2): + np_handles_weights = np_2 and nan_policy == "propagate" and method == "inverted_cdf" + if weights is None: + if is_numpy_namespace(xp) and (basic_case or np_2): + quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile + return quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims) + elif is_numpy_namespace(xp) and np_handles_weights: # TODO: call nanquantile for nan_policy == "omit" once # https://github.com/numpy/numpy/issues/29709 is fixed return xp.quantile( a, q_arr, axis=axis, method=method, keepdims=keepdims, weights=weights ) - # No delegation for dask: I couldn't make it work. + jax_or_cupy = is_jax_namespace(xp) or is_cupy_namespace(xp) - if basic_case and nan_policy == "propagate" and jax_or_cupy: + if jax_or_cupy and basic_case and nan_policy == "propagate": return xp.quantile(a, q_arr, axis=axis, method=method, keepdims=keepdims) - if basic_case and is_torch_namespace(xp): + if is_torch_namespace(xp) and basic_case: quantile = xp.quantile if nan_policy == "propagate" else xp.nanquantile return quantile(a, q_arr, dim=axis, interpolation=method, keepdim=keepdims) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index efc073fd..f3f7fefb 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -36,12 +36,6 @@ def quantile( # numpydoc ignore=PR01,RT01 axis = int(axis) (n,) = eager_shape(a, axis) - # If data has length zero along `axis`, the result will be an array of NaNs just - # as if the data had length 1 along axis and were filled with NaNs. 
- if n == 0: - a_shape[axis] = 1 - n = 1 - a = xp.full(tuple(a_shape), xp.nan, dtype=a.dtype, device=device) if weights is None: res = _quantile(a, q, n, axis, method, xp) @@ -93,12 +87,7 @@ def _quantile( # numpydoc ignore=PR01,RT01 ) jg = q * float(n) + m - 1 - j = jg // 1 - j = xp.clip(j, 0.0, float(n - 1)) - jp1 = xp.clip(j + 1, 0.0, float(n - 1)) - # `̀j` and `jp1` are 1d arrays - g = jg % 1 if method == "inverted_cdf": g = xp.astype((g > 0), jg.dtype) @@ -110,6 +99,10 @@ def _quantile( # numpydoc ignore=PR01,RT01 new_g_shape[axis] = g.shape[0] g = xp.reshape(g, tuple(new_g_shape)) + j = xp.clip(j, 0.0, float(n - 1)) + jp1 = xp.clip(j + 1, 0.0, float(n - 1)) + # `̀j` and `jp1` are 1d arrays + return (1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis) + g * xp.take( a, xp.astype(jp1, xp.int64), axis=axis ) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 0feff80b..dd39c7d0 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1,7 +1,7 @@ import math import warnings from types import ModuleType -from typing import Any, cast +from typing import Any, Literal, cast, get_args import hypothesis import hypothesis.extra.numpy as npst @@ -1531,6 +1531,7 @@ def test_kind(self, xp: ModuleType, library: Backend): res = isin(a, b, kind="sort") xp_assert_equal(res, expected) +METHOD = Literal["linear", "inverted_cdf", "averaged_inverted_cdf"] @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no xp.take") class TestQuantile: @@ -1558,21 +1559,67 @@ def test_shape(self, xp: ModuleType): assert quantile(a, q, axis=1, keepdims=True).shape == (2, 3, 1, 5) assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1) + @pytest.mark.parametrize("with_nans", ["no_nans", "with_nans"]) + @pytest.mark.parametrize("method", get_args(METHOD)) + def test_against_numpy_1d(self, xp: ModuleType, with_nans: str, method: METHOD): + rng = np.random.default_rng() + a_np = rng.random(40) + if with_nans == "with_nans": + a_np[rng.random(a_np.shape) < rng.random() 
* 0.5] = np.nan + q_np = np.asarray([0, *rng.random(2), 1]) + a = xp.asarray(a_np) + q = xp.asarray(q_np) + + actual = quantile(a, q, method=method) + expected = np.quantile(a_np, q_np, method=method) + expected = xp.asarray(expected) + xp_assert_close(actual, expected) + + @pytest.mark.parametrize("with_nans", ["no_nans", "with_nans"]) + @pytest.mark.parametrize("method", get_args(METHOD)) @pytest.mark.parametrize("keepdims", [True, False]) - def test_against_numpy(self, xp: ModuleType, keepdims: bool): + def test_against_numpy_nd(self, xp: ModuleType, keepdims: bool, + with_nans: str, method: METHOD): rng = np.random.default_rng() a_np = rng.random((3, 4, 5)) + if with_nans == "with_nans": + a_np[rng.random(a_np.shape) < rng.random()] = np.nan q_np = rng.random(2) a = xp.asarray(a_np) q = xp.asarray(q_np) for axis in [None, *range(a.ndim)]: - actual = quantile(a, q, axis=axis, keepdims=keepdims) - expected = np.quantile(a_np, q_np, axis=axis, keepdims=keepdims) + actual = quantile(a, q, axis=axis, keepdims=keepdims, method=method) + expected = np.quantile( + a_np, q_np, axis=axis, keepdims=keepdims, method=method + ) expected = xp.asarray(expected) - xp_assert_close(actual, expected, atol=1e-12) + xp_assert_close(actual, expected) + + @pytest.mark.parametrize("nan_policy", ["no_nans", "propagate"]) + @pytest.mark.parametrize("with_weights", ["with_weights", "no_weights"]) + def test_against_median( + self, xp: ModuleType, nan_policy: str, with_weights: str, + ): + rng = np.random.default_rng() + n = 40 + a_np = rng.random(n) + w_np = rng.integers(0, 2, n) if with_weights == "with_weights" else None + if nan_policy == "no_nans": + nan_policy = "propagate" + else: + # from 0% to 50% of NaNs: + a_np[rng.random(n) < rng.random(n) * 0.5] = np.nan + m = "averaged_inverted_cdf" + + np_median = np.nanmedian if nan_policy == "omit" else np.median + expected = np_median(a_np if w_np is None else a_np[w_np > 0]) + a = xp.asarray(a_np) + w = xp.asarray(w_np) if w_np is not 
None else None + actual = quantile(a, 0.5, method=m, nan_policy=nan_policy, weights=w) + xp_assert_close(actual, xp.asarray(expected)) @pytest.mark.parametrize("keepdims", [True, False]) - @pytest.mark.parametrize("nan_policy", ["omit", "no_nans", "propagate"]) + @pytest.mark.parametrize("nan_policy", ["no_nans", "propagate", "omit"]) @pytest.mark.parametrize("q_np", [0.5, 0.0, 1.0, np.linspace(0, 1, num=11)]) def test_weighted_against_numpy( self, xp: ModuleType, keepdims: bool, q_np: Array | float, nan_policy: str @@ -1581,7 +1628,7 @@ def test_weighted_against_numpy( pytest.xfail(reason="NumPy 1.x does not support weights in quantile") rng = np.random.default_rng() n, d = 10, 20 - a_np = rng.random((n, d)) + a_2d = rng.random((n, d)) mask_nan = np.zeros((n, d), dtype=bool) if nan_policy == "no_nans": nan_policy = "propagate" @@ -1590,23 +1637,23 @@ def test_weighted_against_numpy( mask_nan = rng.random((n, d)) < rng.random((n, 1)) # don't put nans in the first row: mask_nan[:] = False - a_np[mask_nan] = np.nan + a_2d[mask_nan] = np.nan - a = xp.asarray(a_np, copy=True) q = xp.asarray(q_np, copy=True) - m = "inverted_cdf" + m: METHOD = "inverted_cdf" np_quantile = np.quantile if nan_policy == "omit": np_quantile = np.nanquantile - for w_np, axis in [ - (rng.random(n), 0), - (rng.random(d), 1), - (rng.integers(0, 2, n), 0), - (rng.integers(0, 2, d), 1), - (rng.integers(0, 2, (n, d)), 0), - (rng.integers(0, 2, (n, d)), 1), + for a_np, w_np, axis in [ + (a_2d, rng.random(n), 0), + (a_2d, rng.random(d), 1), + (a_2d[0], rng.random(d), None), + (a_2d, rng.integers(0, 3, n), 0), + (a_2d, rng.integers(0, 2, d), 1), + (a_2d, rng.integers(0, 2, (n, d)), 0), + (a_2d, rng.integers(0, 3, (n, d)), 1), ]: with warnings.catch_warnings(record=True) as warning: divide_msg = "invalid value encountered in divide" @@ -1614,12 +1661,12 @@ def test_weighted_against_numpy( nan_slice_msg = "All-NaN slice encountered" warnings.filterwarnings("ignore", nan_slice_msg, RuntimeWarning) try: - 
expected = np_quantile( # type: ignore[call-overload] + expected = np_quantile( a_np, np.asarray(q_np), axis=axis, method=m, - weights=w_np, + weights=w_np, # type: ignore[arg-type] keepdims=keepdims, ) except IndexError: @@ -1630,6 +1677,7 @@ def test_weighted_against_numpy( continue expected = xp.asarray(expected) + a = xp.asarray(a_np) w = xp.asarray(w_np) actual = quantile( a, @@ -1640,19 +1688,7 @@ def test_weighted_against_numpy( keepdims=keepdims, nan_policy=nan_policy, ) - xp_assert_close(actual, expected, atol=1e-12) - - def test_2d_axis(self, xp: ModuleType): - x = xp.asarray([[1, 2, 3], [4, 5, 6]]) - actual = quantile(x, 0.5, axis=0) - expect = xp.asarray([2.5, 3.5, 4.5], dtype=xp.float64) - xp_assert_close(actual, expect) - - def test_2d_axis_keepdims(self, xp: ModuleType): - x = xp.asarray([[1, 2, 3], [4, 5, 6]]) - actual = quantile(x, 0.5, axis=0, keepdims=True) - expect = xp.asarray([[2.5, 3.5, 4.5]], dtype=xp.float64) - xp_assert_close(actual, expect) + xp_assert_close(actual, expected) def test_methods(self, xp: ModuleType): x = xp.asarray([1, 2, 3, 4, 5]) From 3226659e0e5c942f2958150f3ea0041fc82e4a28 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 16:14:48 +0200 Subject: [PATCH 17/25] working on coverage --- tests/test_funcs.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index dd39c7d0..7204733b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1531,8 +1531,10 @@ def test_kind(self, xp: ModuleType, library: Backend): res = isin(a, b, kind="sort") xp_assert_equal(res, expected) + METHOD = Literal["linear", "inverted_cdf", "averaged_inverted_cdf"] + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no xp.take") class TestQuantile: def test_basic(self, xp: ModuleType): @@ -1578,8 +1580,9 @@ def test_against_numpy_1d(self, xp: ModuleType, with_nans: str, method: METHOD): @pytest.mark.parametrize("with_nans", ["no_nans", "with_nans"]) 
@pytest.mark.parametrize("method", get_args(METHOD)) @pytest.mark.parametrize("keepdims", [True, False]) - def test_against_numpy_nd(self, xp: ModuleType, keepdims: bool, - with_nans: str, method: METHOD): + def test_against_numpy_nd( + self, xp: ModuleType, keepdims: bool, with_nans: str, method: METHOD + ): rng = np.random.default_rng() a_np = rng.random((3, 4, 5)) if with_nans == "with_nans": @@ -1598,7 +1601,10 @@ def test_against_numpy_nd(self, xp: ModuleType, keepdims: bool, @pytest.mark.parametrize("nan_policy", ["no_nans", "propagate"]) @pytest.mark.parametrize("with_weights", ["with_weights", "no_weights"]) def test_against_median( - self, xp: ModuleType, nan_policy: str, with_weights: str, + self, + xp: ModuleType, + nan_policy: str, + with_weights: str, ): rng = np.random.default_rng() n = 40 From c395b84438e747711bd1c3672abe725110f6067a Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 16:39:20 +0200 Subject: [PATCH 18/25] more coverage --- src/array_api_extra/_delegation.py | 7 ++-- tests/test_funcs.py | 52 ++++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 8be2c990..deafa6b8 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -1070,9 +1070,6 @@ def quantile( if nan_policy not in nan_policies: msg = f"`nan_policy` must be one of {nan_policies}" raise ValueError(msg) - if keepdims not in {True, False}: - msg = "If specified, `keepdims` must be True or False." 
- raise ValueError(msg) a = xp.asarray(a) if not xp.isdtype(a.dtype, ("integral", "real floating")): @@ -1090,7 +1087,7 @@ def quantile( raise ValueError(msg) if weights is None: if nan_policy != "propagate": - msg = "" + msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'" raise ValueError(msg) else: if ndim > 2: @@ -1107,7 +1104,7 @@ def quantile( ) raise ValueError(msg) if axis is None and ndim == 2: - msg = "When weights are provided, axis must be specified when `a` is 2d" + msg = "Axis must be specified when `a` and `weights` are 2d." raise ValueError(msg) # Align result dtype with what numpy does: diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 7204733b..374a10ff 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1543,6 +1543,12 @@ def test_basic(self, xp: ModuleType): expect = xp.asarray(3.0, dtype=xp.float64) xp_assert_close(actual, expect) + def test_xp(self, xp: ModuleType): + x = xp.asarray([1, 2, 3, 4, 5]) + actual = quantile(x, 0.5, xp=xp) + expect = xp.asarray(3.0, dtype=xp.float64) + xp_assert_close(actual, expect) + def test_multiple_quantiles(self, xp: ModuleType): x = xp.asarray([1, 2, 3, 4, 5]) actual = quantile(x, xp.asarray([0.25, 0.5, 0.75])) @@ -1729,15 +1735,49 @@ def test_invalid_q(self, xp: ModuleType): ): _ = quantile(x, -0.5) + def test_invalid_shape(self, xp: ModuleType): + with pytest.raises(TypeError, match="at least 1-dimensional"): + _ = quantile(xp.asarray(3.0), 0.5) + with pytest.raises(ValueError, match="not compatible with the dimension"): + _ = quantile(xp.asarray([3.0]), 0.5, axis=1) + # with weights: + method = "inverted_cdf" + shape = (2, 3, 4) + with pytest.raises(ValueError, match="dimension of `a` must be 1 or 2"): + _ = quantile( + xp.ones(shape), 0.5, axis=1, weights=xp.ones(shape), method=method + ) + with pytest.raises(TypeError, match="Axis must be specified"): + _ = quantile(xp.ones((2, 3)), 0.5, weights=xp.ones(3), method=method) + with pytest.raises(ValueError, 
match="Shape of weights must be consistent"): + _ = quantile( + xp.ones((2, 3)), 0.5, axis=0, weights=xp.ones(3), method=method + ) + with pytest.raises(ValueError, match="Axis must be specified"): + _ = quantile(xp.ones((2, 3)), 0.5, weights=xp.ones((2, 3)), method=method) + + def test_invalid_dtype(self, xp: ModuleType): + with pytest.raises(ValueError, match="`a` must have real dtype"): + _ = quantile(xp.ones(5, dtype=xp.bool), 0.5) + + with pytest.raises(ValueError, match="`q` must have real floating dtype"): + _ = quantile(xp.ones(5), xp.asarray([0, 1])) + + def test_invalid_method(self, xp: ModuleType): + with pytest.raises(ValueError, match="`method` must be one of"): + _ = quantile(xp.ones(5), 0.5, method="invalid") + # TODO: with weights? + + def test_invalid_nan_policy(self, xp: ModuleType): + with pytest.raises(ValueError, match="`nan_policy` must be one of"): + _ = quantile(xp.ones(5), 0.5, nan_policy="invalid") + + with pytest.raises(ValueError, match="must be 'propagate'"): + _ = quantile(xp.ones(5), 0.5, nan_policy="omit") + def test_device(self, xp: ModuleType, device: Device): if hasattr(device, "type") and device.type == "meta": # pyright: ignore[reportAttributeAccessIssue] pytest.xfail("No Tensor.item() on meta device") x = xp.asarray([1, 2, 3, 4, 5], device=device) actual = quantile(x, 0.5) assert get_device(actual) == device - - def test_xp(self, xp: ModuleType): - x = xp.asarray([1, 2, 3, 4, 5]) - actual = quantile(x, 0.5, xp=xp) - expect = xp.asarray(3.0, dtype=xp.float64) - xp_assert_close(actual, expect) From 07f70075c575fb39ca66ecf703692bde97340a17 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 16:53:35 +0200 Subject: [PATCH 19/25] fix test --- tests/test_funcs.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 374a10ff..912a4457 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1621,10 +1621,14 @@ def test_against_median( else: # from 0% to 
50% of NaNs: a_np[rng.random(n) < rng.random(n) * 0.5] = np.nan + if w_np is not None: + # ensure at least one NaN on non-null weight: + a_np[w_np > 0][0] = np.nan m = "averaged_inverted_cdf" np_median = np.nanmedian if nan_policy == "omit" else np.median - expected = np_median(a_np if w_np is None else a_np[w_np > 0]) + a_np_med = a_np if w_np is None else a_np[w_np > 0] + expected = np_median(a_np_med) a = xp.asarray(a_np) w = xp.asarray(w_np) if w_np is not None else None actual = quantile(a, 0.5, method=m, nan_policy=nan_policy, weights=w) From e3195299367e8c6dab047d9a21d9fb1493a48e22 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 16:57:27 +0200 Subject: [PATCH 20/25] second attempt to fix test --- tests/test_funcs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 912a4457..e9105163 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1623,7 +1623,8 @@ def test_against_median( a_np[rng.random(n) < rng.random(n) * 0.5] = np.nan if w_np is not None: # ensure at least one NaN on non-null weight: - a_np[w_np > 0][0] = np.nan + nz_weights_idx, = np.where(w_np > 0) + a_np[nz_weights_idx[0]] = np.nan m = "averaged_inverted_cdf" np_median = np.nanmedian if nan_policy == "omit" else np.median From 8ab7d62c5d1a4dd31c27527fb210ccec1516758d Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 18:04:50 +0200 Subject: [PATCH 21/25] more validation --- src/array_api_extra/_delegation.py | 10 +++++++++- tests/test_funcs.py | 17 ++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index deafa6b8..c46205e3 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -1078,6 +1078,8 @@ def quantile( if not xp.isdtype(xp.asarray(q).dtype, "real floating"): msg = "`q` must have real floating dtype." 
raise ValueError(msg) + weights = None if weights is None else xp.asarray(weights) + ndim = a.ndim if ndim < 1: msg = "`a` must be at least 1-dimensional." @@ -1087,9 +1089,15 @@ def quantile( raise ValueError(msg) if weights is None: if nan_policy != "propagate": - msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'" + msg = "When `weights` aren't provided, `nan_policy` must be 'propagate'." raise ValueError(msg) else: + if method not in {"inverted_cdf", "averaged_inverted_cdf"}: + msg = f"`method` '{method}' not supported with weights." + raise ValueError(msg) + if not xp.isdtype(weights.dtype, ("integral", "real floating")): + msg = "`weights` must have real dtype." + raise ValueError(msg) if ndim > 2: msg = "When weights are provided, dimension of `a` must be 1 or 2." raise ValueError(msg) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index e9105163..5c318dde 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1623,7 +1623,7 @@ def test_against_median( a_np[rng.random(n) < rng.random(n) * 0.5] = np.nan if w_np is not None: # ensure at least one NaN on non-null weight: - nz_weights_idx, = np.where(w_np > 0) + (nz_weights_idx,) = np.where(w_np > 0) a_np[nz_weights_idx[0]] = np.nan m = "averaged_inverted_cdf" @@ -1747,17 +1747,21 @@ def test_invalid_shape(self, xp: ModuleType): _ = quantile(xp.asarray([3.0]), 0.5, axis=1) # with weights: method = "inverted_cdf" + shape = (2, 3, 4) with pytest.raises(ValueError, match="dimension of `a` must be 1 or 2"): _ = quantile( xp.ones(shape), 0.5, axis=1, weights=xp.ones(shape), method=method ) + with pytest.raises(TypeError, match="Axis must be specified"): _ = quantile(xp.ones((2, 3)), 0.5, weights=xp.ones(3), method=method) + with pytest.raises(ValueError, match="Shape of weights must be consistent"): _ = quantile( xp.ones((2, 3)), 0.5, axis=0, weights=xp.ones(3), method=method ) + with pytest.raises(ValueError, match="Axis must be specified"): _ = quantile(xp.ones((2, 3)), 0.5, 
weights=xp.ones((2, 3)), method=method) @@ -1765,13 +1769,20 @@ def test_invalid_dtype(self, xp: ModuleType): with pytest.raises(ValueError, match="`a` must have real dtype"): _ = quantile(xp.ones(5, dtype=xp.bool), 0.5) + a = xp.ones(5) with pytest.raises(ValueError, match="`q` must have real floating dtype"): - _ = quantile(xp.ones(5), xp.asarray([0, 1])) + _ = quantile(a, xp.asarray([0, 1])) + + weights = xp.ones(5, dtype=xp.bool) + with pytest.raises(ValueError, match="`weights` must have real dtype"): + _ = quantile(a, 0.5, weights=weights, method="inverted_cdf") def test_invalid_method(self, xp: ModuleType): with pytest.raises(ValueError, match="`method` must be one of"): _ = quantile(xp.ones(5), 0.5, method="invalid") - # TODO: with weights? + + with pytest.raises(ValueError, match="not supported with weights"): + _ = quantile(xp.ones(5), 0.5, method="linear", weights=xp.ones(5)) def test_invalid_nan_policy(self, xp: ModuleType): with pytest.raises(ValueError, match="`nan_policy` must be one of"): From 1b48267179313e8e29a0cee6d31c36cd2ab379e7 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 23 Oct 2025 19:47:06 +0200 Subject: [PATCH 22/25] some more tests --- tests/test_funcs.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 5c318dde..ed2f8a36 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1606,7 +1606,7 @@ def test_against_numpy_nd( @pytest.mark.parametrize("nan_policy", ["no_nans", "propagate"]) @pytest.mark.parametrize("with_weights", ["with_weights", "no_weights"]) - def test_against_median( + def test_against_median_min_max( self, xp: ModuleType, nan_policy: str, @@ -1625,16 +1625,28 @@ def test_against_median( # ensure at least one NaN on non-null weight: (nz_weights_idx,) = np.where(w_np > 0) a_np[nz_weights_idx[0]] = np.nan - m = "averaged_inverted_cdf" - np_median = np.nanmedian if nan_policy == "omit" else np.median a_np_med = a_np if w_np is 
None else a_np[w_np > 0] - expected = np_median(a_np_med) a = xp.asarray(a_np) w = xp.asarray(w_np) if w_np is not None else None - actual = quantile(a, 0.5, method=m, nan_policy=nan_policy, weights=w) + + np_median = np.nanmedian if nan_policy == "omit" else np.median + expected = np_median(a_np_med) + method = "averaged_inverted_cdf" + actual = quantile(a, 0.5, method=method, nan_policy=nan_policy, weights=w) xp_assert_close(actual, xp.asarray(expected)) + for method in ["inverted_cdf", "averaged_inverted_cdf"]: + np_min = np.nanmin if nan_policy == "omit" else np.min + expected = np_min(a_np_med) + actual = quantile(a, 0., method=method, nan_policy=nan_policy, weights=w) + xp_assert_close(actual, xp.asarray(expected)) + + np_max = np.nanmax if nan_policy == "omit" else np.max + expected = np_max(a_np_med) + actual = quantile(a, 1., method=method, nan_policy=nan_policy, weights=w) + xp_assert_close(actual, xp.asarray(expected)) + @pytest.mark.parametrize("keepdims", [True, False]) @pytest.mark.parametrize("nan_policy", ["no_nans", "propagate", "omit"]) @pytest.mark.parametrize("q_np", [0.5, 0.0, 1.0, np.linspace(0, 1, num=11)]) From c71351f17542db46675c82371a4c911423f84404 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Sat, 25 Oct 2025 11:34:50 +0200 Subject: [PATCH 23/25] Fix typo in err msg Co-authored-by: Mathias Hauser --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index c46205e3..cd8c3989 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -948,7 +948,7 @@ def quantile( nan_policy : str, optional 'propagate' (default) or 'omit'. - 'omit' is support only when `weights` are provided. + 'omit' is supported only when `weights` are provided. weights : array_like, optional An array of weights associated with the values in `a`. 
Each value in From ce55335b3ba2acab411ff433f565c7706a46c9d2 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sat, 25 Oct 2025 12:43:58 +0200 Subject: [PATCH 24/25] avoid sorting a; just sort the weights --- src/array_api_extra/_lib/_quantile.py | 44 ++++++++++++++++----------- tests/test_funcs.py | 4 +-- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index f3f7fefb..4d50dfd4 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -130,18 +130,18 @@ def _weighted_quantile( # numpydoc ignore=PR01,RT01 sorter = xp.argsort(a, axis=-1, stable=False) if a.ndim == 1: - x = xp.take(a, sorter) - w = xp.take(weights, sorter) - return _weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device) + return _weighted_quantile_sorted_1d( + a, weights, sorter, q, n, average, nan_policy, xp, device + ) (d,) = eager_shape(a, axis=0) res = [] for idx in range(d): w = weights if weights.ndim == 1 else weights[idx, ...] 
- w = xp.take(w, sorter[idx, ...]) - x = xp.take(a[idx, ...], sorter[idx, ...]) res.append( - _weighted_quantile_sorted_1d(x, q, w, n, average, nan_policy, xp, device) + _weighted_quantile_sorted_1d( + a[idx, ...], w, sorter[idx, ...], q, n, average, nan_policy, xp, device + ) ) return xp.stack(res, axis=1) @@ -149,8 +149,9 @@ def _weighted_quantile( # numpydoc ignore=PR01,RT01 def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08 x: Array, - q: Array, w: Array, + sorter: Array, + q: Array, n: int, average: bool, nan_policy: str, @@ -161,18 +162,25 @@ def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08 w = xp.where(xp.isnan(x), 0.0, w) elif xp.any(xp.isnan(x)): return xp.full(q.shape, xp.nan, dtype=x.dtype, device=device) - cdf = xp.cumulative_sum(w) + + cdf = xp.cumulative_sum(xp.take(w, sorter)) t = cdf[-1] * q + i = xp.searchsorted(cdf, t, side="left") - j = xp.searchsorted(cdf, t, side="right") i = xp.clip(i, 0, n - 1) - j = xp.clip(j, 0, n - 1) - - # Ignore leading `weights=0` observations when `q=0` - # see https://github.com/scikit-learn/scikit-learn/pull/20528 - i = xp.where(q == 0.0, j, i) - if average: - # Ignore trailing `weights=0` observations when `q=1` - j = xp.where(q == 1.0, i, j) - return (xp.take(x, i) + xp.take(x, j)) / 2 + i = xp.take(sorter, i) + + q0 = q == 0.0 + if average or xp.any(q0): + j = xp.searchsorted(cdf, t, side="right") + j = xp.clip(j, 0, n - 1) + j = xp.take(sorter, j) + # Ignore leading `weights=0` observations when `q=0` + i = xp.where(q0, j, i) + + if average: + # Ignore trailing `weights=0` observations when `q=1` + j = xp.where(q == 1.0, i, j) + return (xp.take(x, i) + xp.take(x, j)) / 2 + return xp.take(x, i) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ed2f8a36..09a5babe 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1639,12 +1639,12 @@ def test_against_median_min_max( for method in ["inverted_cdf", "averaged_inverted_cdf"]: np_min = np.nanmin if nan_policy == "omit" else np.min 
expected = np_min(a_np_med) - actual = quantile(a, 0., method=method, nan_policy=nan_policy, weights=w) + actual = quantile(a, 0.0, method=method, nan_policy=nan_policy, weights=w) xp_assert_close(actual, xp.asarray(expected)) np_max = np.nanmax if nan_policy == "omit" else np.max expected = np_max(a_np_med) - actual = quantile(a, 1., method=method, nan_policy=nan_policy, weights=w) + actual = quantile(a, 1.0, method=method, nan_policy=nan_policy, weights=w) xp_assert_close(actual, xp.asarray(expected)) @pytest.mark.parametrize("keepdims", [True, False]) From 7c18a82d7e1603e07bcb4970cf548f671457cb29 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 26 Oct 2025 15:33:45 +0100 Subject: [PATCH 25/25] return max values when all weights are null --- src/array_api_extra/_lib/_quantile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_quantile.py b/src/array_api_extra/_lib/_quantile.py index 4d50dfd4..df23e94f 100644 --- a/src/array_api_extra/_lib/_quantile.py +++ b/src/array_api_extra/_lib/_quantile.py @@ -170,7 +170,7 @@ def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08 i = xp.clip(i, 0, n - 1) i = xp.take(sorter, i) - q0 = q == 0.0 + q0 = t == 0.0 if average or xp.any(q0): j = xp.searchsorted(cdf, t, side="right") j = xp.clip(j, 0, n - 1)