diff --git a/doc/sources/array_api.rst b/doc/sources/array_api.rst index b2eb7a8bee..9ed34ea49e 100644 --- a/doc/sources/array_api.rst +++ b/doc/sources/array_api.rst @@ -96,6 +96,10 @@ The following patched classes have support for array API inputs: - :obj:`sklearn.linear_model.Ridge` - :obj:`sklearnex.linear_model.IncrementalLinearRegression` - :obj:`sklearnex.linear_model.IncrementalRidge` +- :obj:`sklearn.neighbors.KNeighborsClassifier` +- :obj:`sklearn.neighbors.KNeighborsRegressor` +- :obj:`sklearn.neighbors.NearestNeighbors` +- :obj:`sklearn.neighbors.LocalOutlierFactor` .. note:: While full array API support is currently not implemented for all classes, :external+dpnp:doc:`dpnp.ndarray ` diff --git a/onedal/neighbors/neighbors.py b/onedal/neighbors/neighbors.py index e952dddebf..d19d91abeb 100755 --- a/onedal/neighbors/neighbors.py +++ b/onedal/neighbors/neighbors.py @@ -15,27 +15,14 @@ # ============================================================================== from abc import ABCMeta, abstractmethod -from numbers import Integral - -import numpy as np from onedal._device_offload import supports_queue from onedal.common._backend import bind_default_backend from onedal.utils import _sycl_queue_manager as QM -from .._config import _get_config from ..common._estimator_checks import _check_is_fitted, _is_classifier, _is_regressor from ..common._mixin import ClassifierMixin, RegressorMixin from ..datatypes import from_table, to_table -from ..utils._array_api import _get_sycl_namespace -from ..utils.validation import ( - _check_array, - _check_classification_targets, - _check_n_features, - _check_X_y, - _column_or_1d, - _num_samples, -) class NeighborsCommonBase(metaclass=ABCMeta): @@ -76,66 +63,6 @@ def infer(self, *args, **kwargs): ... @abstractmethod def _onedal_fit(self, X, y): ... - def _validate_data( - self, X, y=None, reset=True, validate_separately=None, **check_params - ): - if y is None: - if self.requires_y: - raise ValueError( - f"This {self.__class__.__name__} estimator " - f"requires y to be passed, but the target y is None." - ) - X = _check_array(X, **check_params) - out = X, y - else: - if validate_separately: - # We need this because some estimators validate X and y - # separately, and in general, separately calling _check_array() - # on X and y isn't equivalent to just calling _check_X_y() - # :( - check_X_params, check_y_params = validate_separately - X = _check_array(X, **check_X_params) - y = _check_array(y, **check_y_params) - else: - X, y = _check_X_y(X, y, **check_params) - out = X, y - - if check_params.get("ensure_2d", True): - _check_n_features(self, X, reset=reset) - - return out - - def _get_weights(self, dist, weights): - if weights in (None, "uniform"): - return None - if weights == "distance": - # if user attempts to classify a point that was zero distance from one - # or more training points, those training points are weighted as 1.0 - # and the other points as 0.0 - if dist.dtype is np.dtype(object): - for point_dist_i, point_dist in enumerate(dist): - # check if point_dist is iterable - # (ex: RadiusNeighborClassifier.predict may set an element of - # dist to 1e-6 to represent an 'outlier') - if hasattr(point_dist, "__contains__") and 0.0 in point_dist: - dist[point_dist_i] = point_dist == 0.0 - else: - dist[point_dist_i] = 1.0 / point_dist - else: - with np.errstate(divide="ignore"): - dist = 1.0 / dist - inf_mask = np.isinf(dist) - inf_row = np.any(inf_mask, axis=1) - dist[inf_row] = inf_mask[inf_row] - return dist - elif callable(weights): - return weights(dist) - else: - raise ValueError( - "weights not recognized: should be 'uniform', " - "'distance', or a callable function" - ) - def _get_onedal_params(self, X, y=None, n_neighbors=None): class_count = 0 if self.classes_ is None else len(self.classes_) weights = getattr(self, "weights", "uniform") @@ -176,78 +103,47 @@ def __init__( self.p = p self.metric_params = metric_params - def _validate_targets(self, y, dtype): - arr = _column_or_1d(y, warn=True) - - try: - return arr.astype(dtype, copy=False) - except ValueError: - return arr - - def _validate_n_classes(self): - length = 0 if self.classes_ is None else len(self.classes_) - if length < 2: - raise ValueError( - f"The number of classes has to be greater than one; got {length}" - ) - def _fit(self, X, y): self._onedal_model = None self._tree = None - self._shape = None - self.classes_ = None + # REFACTOR: Shape processing moved to sklearnex layer + # _shape should be set by _process_classification_targets or _process_regression_targets in sklearnex + # self._shape = None + if not hasattr(self, "_shape"): + self._shape = None + # REFACTOR STEP 1: Don't reset classes_ - it may have been set by sklearnex layer + # self.classes_ = None + if not hasattr(self, "classes_"): + self.classes_ = None self.effective_metric_ = getattr(self, "effective_metric_", self.metric) self.effective_metric_params_ = getattr( self, "effective_metric_params_", self.metric_params ) - _, xp, _ = _get_sycl_namespace(X) - use_raw_input = _get_config().get("use_raw_input", False) is True + # REFACTOR: _validate_data call commented out - validation now happens in sklearnex layer + # Original code kept for reference: + # use_raw_input = _get_config().get("use_raw_input", False) is True if y is not None or self.requires_y: - shape = getattr(y, "shape", None) - if not use_raw_input: - X, y = super()._validate_data( - X, y, dtype=[np.float64, np.float32], accept_sparse="csr" - ) - self._shape = shape if shape is not None else y.shape - + # REFACTOR: Classification target processing moved to sklearnex layer + # This code is now commented out - processing MUST happen in sklearnex before calling fit + # Assertion: Verify that sklearnex has done the preprocessing if _is_classifier(self): - if y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1: - self.outputs_2d_ = False - y = y.reshape((-1, 1)) - else: - self.outputs_2d_ = True - - _check_classification_targets(y) - self.classes_ = [] - self._y = np.empty(y.shape, dtype=int) - for k in range(self._y.shape[1]): - classes, self._y[:, k] = np.unique(y[:, k], return_inverse=True) - self.classes_.append(classes) - - if not self.outputs_2d_: - self.classes_ = self.classes_[0] - self._y = self._y.ravel() - - self._validate_n_classes() + # if not hasattr(self, "classes_") or self.classes_ is None: + # raise ValueError( + # "Classification target processing must be done in sklearnex layer before calling onedal fit. " + # "classes_ attribute is not set. This indicates the refactoring is incomplete." + # ) + if not hasattr(self, "_y") or self._y is None: + raise ValueError( + "Classification target processing must be done in sklearnex layer before calling onedal fit. " + "_y attribute is not set. This indicates the refactoring is incomplete." + ) else: + # For regressors, just store y self._y = y - elif not use_raw_input: - X, _ = super()._validate_data(X, dtype=[np.float64, np.float32]) - self.n_samples_fit_ = X.shape[0] self.n_features_in_ = X.shape[1] self._fit_X = X - - if self.n_neighbors is not None: - if self.n_neighbors <= 0: - raise ValueError("Expected n_neighbors > 0. Got %d" % self.n_neighbors) - if not isinstance(self.n_neighbors, Integral): - raise TypeError( - "n_neighbors does not take %s value, " - "enter integer value" % type(self.n_neighbors) - ) - self._fit_method = super()._parse_auto_method( self.algorithm, self.n_samples_fit_, self.n_features_in_ ) @@ -255,125 +151,40 @@ def _fit(self, X, y): _fit_y = None queue = QM.get_global_queue() gpu_device = queue is not None and queue.sycl_device.is_gpu - + # Just pass self._y as-is - sklearnex should have already reshaped it if _is_classifier(self) or (_is_regressor(self) and gpu_device): - _fit_y = self._validate_targets(self._y, X.dtype).reshape((-1, 1)) + _fit_y = self._y result = self._onedal_fit(X, _fit_y) - - if y is not None and _is_regressor(self): - self._y = y if self._shape is None else xp.reshape(y, self._shape) - self._onedal_model = result result = self return result def _kneighbors(self, X=None, n_neighbors=None, return_distance=True): - use_raw_input = _get_config().get("use_raw_input", False) is True - n_features = getattr(self, "n_features_in_", None) - shape = getattr(X, "shape", None) - if n_features and shape and len(shape) > 1 and shape[1] != n_features: - raise ValueError( - ( - f"X has {X.shape[1]} features, " - f"but kneighbors is expecting " - f"{n_features} features as input" - ) - ) + # Still need n_features for _parse_auto_method call later + # n_features = getattr(self, "n_features_in_", None) _check_is_fitted(self) if n_neighbors is None: n_neighbors = self.n_neighbors - elif n_neighbors <= 0: - raise ValueError("Expected n_neighbors > 0. Got %d" % n_neighbors) - else: - if not isinstance(n_neighbors, Integral): - raise TypeError( - "n_neighbors does not take %s value, " - "enter integer value" % type(n_neighbors) - ) - - if X is not None: - query_is_train = False - if not use_raw_input: - X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32]) - else: - query_is_train = True - X = self._fit_X - # Include an extra neighbor to account for the sample itself being - # returned, which is removed later - n_neighbors += 1 - - n_samples_fit = self.n_samples_fit_ - if n_neighbors > n_samples_fit: - if query_is_train: - n_neighbors -= 1 # ok to modify inplace because an error is raised - inequality_str = "n_neighbors < n_samples_fit" - else: - inequality_str = "n_neighbors <= n_samples_fit" - raise ValueError( - f"Expected {inequality_str}, but " - f"n_neighbors = {n_neighbors}, n_samples_fit = {n_samples_fit}, " - f"n_samples = {X.shape[0]}" # include n_samples for common tests - ) - chunked_results = None - method = self._parse_auto_method( - self._fit_method, self.n_samples_fit_, n_features - ) + # onedal now just returns raw results, sklearnex does all processing + # Following PCA pattern: simple onedal layer + if X is None: + X = self._fit_X + # onedal just calls backend and returns raw results + # All post-processing (kd_tree sorting, removing self, return_distance decision) moved to sklearnex params = super()._get_onedal_params(X, n_neighbors=n_neighbors) prediction_results = self._onedal_predict(self._onedal_model, X, params) distances = from_table(prediction_results.distances) indices = from_table(prediction_results.indices) - if method == "kd_tree": - for i in range(distances.shape[0]): - seq = distances[i].argsort() - indices[i] = indices[i][seq] - distances[i] = distances[i][seq] - - if return_distance: - results = distances, indices - else: - results = indices - - if chunked_results is not None: - if return_distance: - neigh_dist, neigh_ind = zip(*chunked_results) - results = np.vstack(neigh_dist), np.vstack(neigh_ind) - else: - results = np.vstack(chunked_results) - - if not query_is_train: - return results - - # If the query data is the same as the indexed data, we would like - # to ignore the first nearest neighbor of every sample, i.e - # the sample itself. - if return_distance: - neigh_dist, neigh_ind = results - else: - neigh_ind = results - - n_queries, _ = X.shape - sample_range = np.arange(n_queries)[:, None] - sample_mask = neigh_ind != sample_range - - # Corner case: When the number of duplicates are more - # than the number of neighbors, the first NN will not - # be the sample, but a duplicate. - # In that case mask the first duplicate. - dup_gr_nbrs = np.all(sample_mask, axis=1) - sample_mask[:, 0][dup_gr_nbrs] = False - - neigh_ind = np.reshape(neigh_ind[sample_mask], (n_queries, n_neighbors - 1)) - - if return_distance: - neigh_dist = np.reshape(neigh_dist[sample_mask], (n_queries, n_neighbors - 1)) - return neigh_dist, neigh_ind - return neigh_ind + # Always return both - sklearnex will decide what to return to user + results = distances, indices + # Return raw results - sklearnex will do all post-processing + return results class KNeighborsClassifier(NeighborsBase, ClassifierMixin): @@ -412,8 +223,10 @@ def infer(self, *args, **kwargs): ... def _onedal_fit(self, X, y): # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function queue = QM.get_global_queue() - params = self._get_onedal_params(X, y) + # REFACTOR: Convert to table FIRST, then get params from table (following PCA pattern) + # This ensures dtype is normalized (array API dtype -> numpy dtype) X_table, y_table = to_table(X, y, queue=queue) + params = self._get_onedal_params(X_table, y) return self.train(params, X_table, y_table).model def _onedal_predict(self, model, X, params): @@ -429,77 +242,6 @@ def _onedal_predict(self, model, X, params): def fit(self, X, y, queue=None): return self._fit(X, y) - @supports_queue - def predict(self, X, queue=None): - use_raw_input = _get_config().get("use_raw_input", False) is True - if not use_raw_input: - X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32]) - onedal_model = getattr(self, "_onedal_model", None) - n_features = getattr(self, "n_features_in_", None) - n_samples_fit_ = getattr(self, "n_samples_fit_", None) - shape = getattr(X, "shape", None) - if n_features and shape and len(shape) > 1 and shape[1] != n_features: - raise ValueError( - ( - f"X has {X.shape[1]} features, " - f"but KNNClassifier is expecting " - f"{n_features} features as input" - ) - ) - - _check_is_fitted(self) - - self._fit_method = self._parse_auto_method( - self.algorithm, n_samples_fit_, n_features - ) - - self._validate_n_classes() - - params = self._get_onedal_params(X) - prediction_result = self._onedal_predict(onedal_model, X, params) - responses = from_table(prediction_result.responses) - - result = self.classes_.take(np.asarray(responses.ravel(), dtype=np.intp)) - return result - - @supports_queue - def predict_proba(self, X, queue=None): - neigh_dist, neigh_ind = self.kneighbors(X, queue=queue) - - classes_ = self.classes_ - _y = self._y - if not self.outputs_2d_: - _y = self._y.reshape((-1, 1)) - classes_ = [self.classes_] - - n_queries = _num_samples(X) - - weights = self._get_weights(neigh_dist, self.weights) - if weights is None: - weights = np.ones_like(neigh_ind) - - all_rows = np.arange(n_queries) - probabilities = [] - for k, classes_k in enumerate(classes_): - pred_labels = _y[:, k][neigh_ind] - proba_k = np.zeros((n_queries, classes_k.size)) - - # a simple ':' index doesn't work right - for i, idx in enumerate(pred_labels.T): # loop is O(n_neighbors) - proba_k[all_rows, idx] += weights[:, i] - - # normalize 'votes' into real [0,1] probabilities - normalizer = proba_k.sum(axis=1)[:, np.newaxis] - normalizer[normalizer == 0.0] = 1.0 - proba_k /= normalizer - - probabilities.append(proba_k) - - if not self.outputs_2d_: - probabilities = probabilities[0] - - return probabilities - @supports_queue def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None): return self._kneighbors(X, n_neighbors, return_distance) @@ -576,22 +318,14 @@ def fit(self, X, y, queue=None): def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None): return self._kneighbors(X, n_neighbors, return_distance) + # REFACTOR: Keep _predict_gpu for GPU backend support (called by sklearnex) + # This is the ONLY prediction method needed in onedal - it calls the backend directly + # All computation logic (weights, averaging, etc.) is in sklearnex def _predict_gpu(self, X): - use_raw_input = _get_config().get("use_raw_input", False) is True - if not use_raw_input: - X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32]) + # REFACTOR: Validation commented out - should be done in sklearnex layer before calling this onedal_model = getattr(self, "_onedal_model", None) n_features = getattr(self, "n_features_in_", None) n_samples_fit_ = getattr(self, "n_samples_fit_", None) - shape = getattr(X, "shape", None) - if n_features and shape and len(shape) > 1 and shape[1] != n_features: - raise ValueError( - ( - f"X has {X.shape[1]} features, " - f"but KNNClassifier is expecting " - f"{n_features} features as input" - ) - ) _check_is_fitted(self) @@ -607,39 +341,6 @@ def _predict_gpu(self, X): return result - def _predict_skl(self, X): - neigh_dist, neigh_ind = self.kneighbors(X) - - weights = self._get_weights(neigh_dist, self.weights) - - _y = self._y - if _y.ndim == 1: - _y = _y.reshape((-1, 1)) - - if weights is None: - y_pred = np.mean(_y[neigh_ind], axis=1) - else: - y_pred = np.empty((X.shape[0], _y.shape[1]), dtype=np.float64) - denom = np.sum(weights, axis=1) - - for j in range(_y.shape[1]): - num = np.sum(_y[neigh_ind, j] * weights, axis=1) - y_pred[:, j] = num / denom - - if self._y.ndim == 1: - y_pred = y_pred.ravel() - - return y_pred - - @supports_queue - def predict(self, X, queue=None): - gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) - is_uniform_weights = getattr(self, "weights", "uniform") == "uniform" - if gpu_device and is_uniform_weights: - return self._predict_gpu(X) - else: - return self._predict_skl(X) - class NearestNeighbors(NeighborsBase): def __init__( @@ -671,8 +372,11 @@ def infer(self, *arg, **kwargs): ... def _onedal_fit(self, X, y): # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function queue = QM.get_global_queue() + # REFACTOR: Convert to table FIRST, then get params from table (following PCA pattern) + # This ensures dtype is normalized (array API dtype -> numpy dtype) + # Note: NearestNeighbors has no y, so only convert X to avoid y becoming a table + X = to_table(X, queue=queue) params = self._get_onedal_params(X, y) - X, y = to_table(X, y, queue=queue) return self.train(params, X).model def _onedal_predict(self, model, X, params): diff --git a/onedal/neighbors/tests/test_knn_classification.py b/onedal/neighbors/tests/test_knn_classification.py index d29bdab345..f3cf0b823a 100755 --- a/onedal/neighbors/tests/test_knn_classification.py +++ b/onedal/neighbors/tests/test_knn_classification.py @@ -19,15 +19,19 @@ from numpy.testing import assert_array_equal from sklearn import datasets -from onedal.neighbors import KNeighborsClassifier from onedal.tests.utils._device_selection import get_queues +# Classification processing now happens in sklearnex layer +from sklearnex.neighbors import KNeighborsClassifier + @pytest.mark.parametrize("queue", get_queues()) def test_iris(queue): + # queue parameter not used with sklearnex, but kept for test parametrization iris = datasets.load_iris() - clf = KNeighborsClassifier(2).fit(iris.data, iris.target, queue=queue) - assert clf.score(iris.data, iris.target, queue=queue) > 0.9 + clf = KNeighborsClassifier(2).fit(iris.data, iris.target) + score = clf.score(iris.data, iris.target) + assert score > 0.9 assert_array_equal(clf.classes_, np.sort(clf.classes_)) @@ -36,14 +40,13 @@ def test_pickle(queue): if queue and queue.sycl_device.is_gpu: pytest.skip("KNN classifier pickling for the GPU sycl_queue is buggy.") iris = datasets.load_iris() - clf = KNeighborsClassifier(2).fit(iris.data, iris.target, queue=queue) - expected = clf.predict(iris.data, queue=queue) - + clf = KNeighborsClassifier(2).fit(iris.data, iris.target) + expected = clf.predict(iris.data) import pickle dump = pickle.dumps(clf) clf2 = pickle.loads(dump) assert type(clf2) == clf.__class__ - result = clf2.predict(iris.data, queue=queue) + result = clf2.predict(iris.data) assert_array_equal(expected, result) diff --git a/sklearnex/neighbors/_lof.py b/sklearnex/neighbors/_lof.py index 63a98164e7..728d09b8c4 100644 --- a/sklearnex/neighbors/_lof.py +++ b/sklearnex/neighbors/_lof.py @@ -29,7 +29,7 @@ from sklearnex.neighbors.knn_unsupervised import NearestNeighbors from ..utils._array_api import get_namespace -from ..utils.validation import check_feature_names +from ..utils.validation import check_feature_names, validate_data @control_n_jobs(decorated_methods=["fit", "kneighbors", "_kneighbors"]) @@ -56,6 +56,7 @@ def _onedal_fit(self, X, y, queue=None): if sklearn_check_version("1.2"): self._validate_params() + # Let _onedal_knn_fit (NearestNeighbors._onedal_fit) handle validation self._onedal_knn_fit(X, y, queue=queue) if self.contamination != "auto": @@ -74,7 +75,6 @@ def _onedal_fit(self, X, y, queue=None): % (self.n_neighbors, n_samples) ) self.n_neighbors_ = max(1, min(self.n_neighbors, n_samples - 1)) - ( self._distances_fit_X_, _neighbors_indices_fit_X_, @@ -108,11 +108,10 @@ def _onedal_fit(self, X, y, queue=None): "Duplicate values are leading to incorrect results. " "Increase the number of neighbors for more accurate results." ) - return self def fit(self, X, y=None): - result = dispatch( + return dispatch( self, "fit", { @@ -122,7 +121,6 @@ def fit(self, X, y=None): X, None, ) - return result def _predict(self, X=None): check_is_fitted(self) @@ -135,7 +133,6 @@ def _predict(self, X=None): else: is_inlier = np.ones(self.n_samples_fit_, dtype=int) is_inlier[self.negative_outlier_factor_ < self.offset_] = -1 - return is_inlier # This had to be done because predict loses the queue when no @@ -149,9 +146,15 @@ def fit_predict(self, X, y=None): return self.fit(X)._predict() def _kneighbors(self, X=None, n_neighbors=None, return_distance=True): + # Validate n_neighbors parameter first + if n_neighbors is not None: + self._validate_n_neighbors(n_neighbors) + check_is_fitted(self) - if X is not None: - check_feature_names(self, X, reset=False) + + # Validate kneighbors parameters (inherited from KNeighborsDispatchingBase) + self._kneighbors_validation(X, n_neighbors) + return dispatch( self, "kneighbors", @@ -172,6 +175,16 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True): def score_samples(self, X): check_is_fitted(self) + # Validate and convert X + xp, _ = get_namespace(X) + X = validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + reset=False, + ) + distances_X, neighbors_indices_X = self._kneighbors( X, n_neighbors=self.n_neighbors_ ) diff --git a/sklearnex/neighbors/common.py b/sklearnex/neighbors/common.py index ed48c48e77..510096532a 100644 --- a/sklearnex/neighbors/common.py +++ b/sklearnex/neighbors/common.py @@ -14,7 +14,9 @@ # limitations under the License. # ============================================================================== +import sys import warnings +from numbers import Integral import numpy as np from scipy import sparse as sp @@ -24,21 +26,317 @@ from sklearn.neighbors._kd_tree import KDTree from sklearn.utils.validation import check_is_fitted +from daal4py.sklearn._n_jobs_support import control_n_jobs from daal4py.sklearn._utils import sklearn_check_version from onedal._device_offload import _transfer_to_host -from onedal.utils.validation import _check_array, _num_features, _num_samples +from onedal.utils._array_api import _is_numpy_namespace +from onedal.utils.validation import ( + _check_array, + _check_classification_targets, + _check_X_y, + _column_or_1d, + _num_features, + _num_samples, +) from .._utils import PatchingConditionsChain from ..base import oneDALEstimator from ..utils._array_api import get_namespace -from ..utils.validation import check_feature_names +from ..utils.validation import check_feature_names, validate_data class KNeighborsDispatchingBase(oneDALEstimator): - def _fit_validation(self, X, y=None): - if sklearn_check_version("1.2"): - self._validate_params() - check_feature_names(self, X, reset=True) + def _parse_auto_method(self, method, n_samples, n_features): + result_method = method + + if method in ["auto", "ball_tree"]: + condition = ( + self.n_neighbors is not None and self.n_neighbors >= n_samples // 2 + ) + if self.metric == "precomputed" or n_features > 15 or condition: + result_method = "brute" + else: + if self.metric == "euclidean": + result_method = "kd_tree" + else: + result_method = "brute" + + return result_method + + def _get_weights(self, dist, weights): + if weights in (None, "uniform"): + return None + if weights == "distance": + # Array API support: get namespace from dist array + xp, _ = get_namespace(dist) + # if user attempts to classify a point that was zero distance from one + # or more training points, those training points are weighted as 1.0 + # and the other points as 0.0 + # Check for object dtype - use string comparison for Array API compatibility + is_object_dtype = str(dist.dtype) == "object" or ( + hasattr(dist.dtype, "kind") and dist.dtype.kind == "O" + ) + if is_object_dtype: + for point_dist_i, point_dist in enumerate(dist): + # check if point_dist is iterable + # (ex: RadiusNeighborClassifier.predict may set an element of + # dist to 1e-6 to represent an 'outlier') + if hasattr(point_dist, "__contains__") and 0.0 in point_dist: + dist[point_dist_i] = point_dist == 0.0 + else: + dist[point_dist_i] = 1.0 / point_dist + else: + with ( + xp.errstate(divide="ignore") + if hasattr(xp, "errstate") + else np.errstate(divide="ignore") + ): + dist = 1.0 / dist + inf_mask = xp.isinf(dist) + inf_row = xp.any(inf_mask, axis=1) + dist[inf_row] = inf_mask[inf_row] + return dist + elif callable(weights): + return weights(dist) + else: + raise ValueError( + "weights not recognized: should be 'uniform', " + "'distance', or a callable function" + ) + + def _compute_weighted_prediction(self, neigh_dist, neigh_ind, weights_param, y_train): + """Compute weighted prediction for regression. + + Args: + neigh_dist: Distances to neighbors + neigh_ind: Indices of neighbors + weights_param: Weight parameter ('uniform', 'distance', or callable) + y_train: Training target values + + Returns: + Predicted values + """ + # Array API support: get namespace from input arrays + xp, _ = get_namespace(neigh_dist, neigh_ind, y_train) + + weights = self._get_weights(neigh_dist, weights_param) + + _y = y_train + if _y.ndim == 1: + _y = xp.reshape(_y, (-1, 1)) + + if weights is None: + # Array API: Use take() per row since array API take() only supports 1-D indices + # Build result by gathering rows one at a time + gathered_list = [] + for i in range(neigh_ind.shape[0]): + # Get indices for this sample's neighbors + sample_indices = neigh_ind[i, ...] # Shape: (n_neighbors,) + # Gather those rows from _y + sample_neighbors = xp.take( + _y, sample_indices, axis=0 + ) # Shape: (n_neighbors, n_outputs) + gathered_list.append(sample_neighbors) + # Stack and compute mean + gathered = xp.stack( + gathered_list, axis=0 + ) # Shape: (n_samples, n_neighbors, n_outputs) + y_pred = xp.mean(gathered, axis=1) + else: + # Create y_pred array - matches original onedal implementation using empty() + # For Array API arrays (dpctl/dpnp), pass device parameter to match input device + # For numpy arrays, device parameter is not supported and not needed + y_pred_shape = (neigh_ind.shape[0], _y.shape[1]) + if not _is_numpy_namespace(xp): + # Array API: pass device to ensure same device as input + y_pred = xp.empty(y_pred_shape, dtype=xp.float64, device=neigh_ind.device) + else: + # Numpy: no device parameter + y_pred = xp.empty(y_pred_shape, dtype=xp.float64) + denom = xp.sum(weights, axis=1) + + for j in range(_y.shape[1]): + # Array API: Iterate over samples to gather values + y_col_j = _y[:, j, ...] # Shape: (n_train_samples,) + gathered_vals = [] + for i in range(neigh_ind.shape[0]): + sample_indices = neigh_ind[i, ...] # Shape: (n_neighbors,) + sample_vals = xp.take( + y_col_j, sample_indices, axis=0 + ) # Shape: (n_neighbors,) + gathered_vals.append(sample_vals) + gathered_j = xp.stack( + gathered_vals, axis=0 + ) # Shape: (n_samples, n_neighbors) + num = xp.sum(gathered_j * weights, axis=1) + y_pred[:, j, ...] = num / denom + + if y_train.ndim == 1: + y_pred = xp.reshape(y_pred, (-1,)) + + return y_pred + + def _compute_class_probabilities( + self, neigh_dist, neigh_ind, weights_param, y_train, classes, outputs_2d + ): + """Compute class probabilities for classification. + + Args: + neigh_dist: Distances to neighbors + neigh_ind: Indices of neighbors + weights_param: Weight parameter ('uniform', 'distance', or callable) + y_train: Encoded training labels + classes: Class labels + outputs_2d: Whether output is 2D (multi-output) + + Returns: + Class probabilities + """ + from ..utils.validation import _num_samples + + # Transfer all arrays to host to ensure they're on the same queue/device + # This is needed especially for SPMD where arrays might be on different queues + _, (neigh_dist, neigh_ind, y_train) = _transfer_to_host( + neigh_dist, neigh_ind, y_train + ) + + # After transfer, get the array namespace (will be numpy for host arrays) + xp, _ = get_namespace(neigh_dist, neigh_ind, y_train) + + _y = y_train + classes_ = classes + if not outputs_2d: + _y = xp.reshape(y_train, (-1, 1)) + classes_ = [classes] + + n_queries = neigh_ind.shape[0] + + weights = self._get_weights(neigh_dist, weights_param) + if weights is None: + # REFACTOR: Ensure weights is float for array API type promotion + # neigh_ind is int, so ones_like would give int, but we need float + weights = xp.ones_like(neigh_ind, dtype=xp.float64) + + probabilities = [] + for k, classes_k in enumerate(classes_): + # Get predicted labels for each neighbor: shape (n_samples, n_neighbors) + # _y[:, k] gives training labels for output k, then gather using neigh_ind + y_col_k = _y[:, k, ...] + + # Array API: Use take() with iteration since take() only supports 1-D indices + pred_labels_list = [] + for i in range(neigh_ind.shape[0]): + sample_indices = neigh_ind[i, ...] + sample_labels = xp.take(y_col_k, sample_indices, axis=0) + pred_labels_list.append(sample_labels) + pred_labels = xp.stack( + pred_labels_list, axis=0 + ) # Shape: (n_queries, n_neighbors) + + proba_k = xp.zeros((n_queries, classes_k.size), dtype=xp.float64) + + # Array API: Cannot use fancy indexing __setitem__ like proba_k[all_rows, idx] = ... + # Instead, build probabilities sample by sample + proba_list = [] + zero_weight = xp.asarray(0.0, dtype=xp.float64) + for sample_idx in range(n_queries): + sample_proba = xp.zeros((classes_k.size,), dtype=xp.float64) + # For this sample, accumulate weights for each neighbor's predicted class + for neighbor_idx in range(pred_labels.shape[1]): + class_label = int(pred_labels[sample_idx, neighbor_idx]) + weight = weights[sample_idx, neighbor_idx] + # Update probability for this class using array indexing + # Create a mask for this class and add weight where mask is True + mask = xp.arange(classes_k.size) == class_label + sample_proba = sample_proba + xp.where(mask, weight, zero_weight) + proba_list.append(sample_proba) + proba_k = xp.stack(proba_list, axis=0) # Shape: (n_queries, n_classes) + + # normalize 'votes' into real [0,1] probabilities + normalizer = xp.sum(proba_k, axis=1)[:, xp.newaxis] + # Use array scalar for comparison and assignment + zero_scalar = xp.asarray(0.0, dtype=xp.float64) + one_scalar = xp.asarray(1.0, dtype=xp.float64) + normalizer = xp.where(normalizer == zero_scalar, one_scalar, normalizer) + proba_k /= normalizer + + probabilities.append(proba_k) + + if not outputs_2d: + probabilities = probabilities[0] + + return probabilities + + def _predict_skl_regression(self, X): + """SKL prediction path for regression - calls kneighbors, computes predictions. + + This method handles X=None (LOOCV) properly by calling self.kneighbors which + has the query_is_train logic. + + Args: + X: Query samples (or None for LOOCV) + Returns: + Predicted regression values + """ + neigh_dist, neigh_ind = self.kneighbors(X) + return self._compute_weighted_prediction( + neigh_dist, neigh_ind, self.weights, self._y + ) + + def _predict_skl_classification(self, X): + """SKL prediction path for classification - calls kneighbors, computes predictions. + + This method handles X=None (LOOCV) properly by calling self.kneighbors which + has the query_is_train logic. + + Args: + X: Query samples (or None for LOOCV) + Returns: + Predicted class labels + """ + neigh_dist, neigh_ind = self.kneighbors(X) + proba = self._compute_class_probabilities( + neigh_dist, neigh_ind, self.weights, self._y, self.classes_, self.outputs_2d_ + ) + # Array API support: get namespace from probability array + xp, _ = get_namespace(proba) + + if not self.outputs_2d_: + # Single output: classes_[argmax(proba, axis=1)] + return self.classes_[xp.argmax(proba, axis=1)] + else: + # Multi-output: apply argmax separately for each output + result = [ + classes_k[xp.argmax(proba_k, axis=1)] + for classes_k, proba_k in zip(self.classes_, proba.T) + ] + return xp.asarray(result).T + + def _validate_targets(self, y, dtype): + arr = _column_or_1d(y, warn=True) + + try: + return arr.astype(dtype, copy=False) + except ValueError: + return arr + + def _validate_n_neighbors(self, n_neighbors): + if n_neighbors is not None: + if n_neighbors <= 0: + raise ValueError("Expected n_neighbors > 0. Got %d" % n_neighbors) + if not isinstance(n_neighbors, Integral): + raise TypeError( + "n_neighbors does not take %s value, " + "enter integer value" % type(n_neighbors) + ) + + def _set_effective_metric(self): + """Set effective_metric_ and effective_metric_params_ without validation. + + Used when we need to set metrics but can't call _fit_validation + (e.g., in SPMD mode with use_raw_input=True where sklearn validation + would try to convert array API to numpy). + """ if self.metric_params is not None and "p" in self.metric_params: if self.p is not None: warnings.warn( @@ -56,6 +354,16 @@ def _fit_validation(self, X, y=None): self.effective_metric_params_["p"] = effective_p self.effective_metric_ = self.metric + + # Convert sklearn metric aliases to canonical names for oneDAL compatibility + metric_aliases = { + "cityblock": "manhattan", + "l1": "manhattan", + "l2": "euclidean", + } + if self.metric in metric_aliases: + self.effective_metric_ = metric_aliases[self.metric] + # For minkowski distance, use more efficient methods where available if self.metric == "minkowski": p = self.effective_metric_params_["p"] @@ -66,9 +374,260 @@ def _fit_validation(self, X, y=None): elif p == np.inf: self.effective_metric_ = "chebyshev" + def _validate_n_classes(self): + """Validate that the classifier has at least 2 classes.""" + length = 0 if self.classes_ is None else len(self.classes_) + if length < 2: + raise ValueError( + f"The number of classes has to be greater than one; got {length}" + ) + + def _validate_kneighbors_bounds(self, n_neighbors, query_is_train, X): + n_samples_fit = self.n_samples_fit_ + if n_neighbors > n_samples_fit: + if query_is_train: + n_neighbors -= 1 # ok to modify inplace because an error is raised + inequality_str = "n_neighbors < n_samples_fit" + else: + inequality_str = "n_neighbors <= n_samples_fit" + raise ValueError( + f"Expected {inequality_str}, but " + f"n_neighbors = {n_neighbors}, n_samples_fit = {n_samples_fit}, " + f"n_samples = {X.shape[0]}" # include n_samples for common tests + ) + + def _kneighbors_validation(self, X, n_neighbors): + """Shared validation for kneighbors method called from sklearnex layer. + + Validates: + - n_neighbors is within valid bounds if provided + + Note: Feature validation (count, names, etc.) happens in validate_data + called by _onedal_kneighbors, so we don't duplicate it here. + """ + # Validate n_neighbors bounds if provided + if n_neighbors is not None: + # Determine if query is the training set + query_is_train = X is None or (hasattr(self, "_fit_X") and X is self._fit_X) + self._validate_kneighbors_bounds( + n_neighbors, query_is_train, X if X is not None else self._fit_X + ) + + def _prepare_kneighbors_inputs(self, X, n_neighbors): + """Prepare inputs for kneighbors call to onedal backend. + + Handles query_is_train case: when X=None, sets X to training data and adds +1 to n_neighbors. + Validates n_neighbors bounds AFTER adding +1 (replicates original onedal behavior). + + NOTE: Caller is responsible for validating X (via validate_data or _check_array). + This function does NOT validate X to avoid double validation and to support + use_raw_input mode where validation should be skipped. + + Args: + X: Query data or None + n_neighbors: Number of neighbors or None + + Returns: + Tuple of (X, n_neighbors, query_is_train) + - X: Processed query data (self._fit_X if original X was None) + - n_neighbors: Adjusted n_neighbors (includes +1 if query_is_train) + - query_is_train: Boolean flag indicating if original X was None + """ + query_is_train = X is None + + if X is not None: + # X validation should already be done by caller + # Do NOT call _check_array here to avoid double validation + # and to support use_raw_input mode + pass + else: + X = self._fit_X + # Include an extra neighbor to account for the sample itself being + # returned, which is removed later + if n_neighbors is None: + n_neighbors = self.n_neighbors + n_neighbors += 1 + + # Validate bounds AFTER adding +1 (replicates original onedal behavior) + # Original code in onedal had validation after n_neighbors += 1 + n_samples_fit = self.n_samples_fit_ + if n_neighbors > n_samples_fit: + n_neighbors_for_msg = ( + n_neighbors - 1 + ) # for error message, show original value + raise ValueError( + f"Expected n_neighbors < n_samples_fit, but " + f"n_neighbors = {n_neighbors_for_msg}, n_samples_fit = {n_samples_fit}, " + f"n_samples = {X.shape[0]}" + ) + + return X, n_neighbors, query_is_train + + def _kneighbors_post_processing( + self, X, n_neighbors, return_distance, result, query_is_train + ): + """Shared post-processing for kneighbors results. + + Following PCA pattern: all post-processing in sklearnex, onedal returns raw results. + Replicates exact logic from main branch onedal._kneighbors() method. + + Handles (in order, matching main branch): + 1. kd_tree sorting: sorts results by distance (BEFORE deciding what to return) + 2. query_is_train case (X=None): removes self from results + 3. return_distance decision: return distances+indices or just indices + + Args: + X: Query data (self._fit_X if query_is_train) + n_neighbors: Number of neighbors (already includes +1 if query_is_train) + return_distance: Whether to return distances to user + result: Raw result from onedal backend - always (distances, indices) + query_is_train: Boolean indicating if original X was None + + Returns: + Post-processed result: (distances, indices) if return_distance else indices + """ + # Array API support: get namespace from result arrays + # onedal always returns both distances and indices (backend computes both) + distances, indices = result + xp, _ = get_namespace(distances, indices) + + # POST-PROCESSING STEP 1: kd_tree sorting (moved from onedal) + # This happens BEFORE deciding what to return, using distances that are always available + # Matches main branch: sorting uses distances even when return_distance=False + if self._fit_method == "kd_tree": + for i in range(distances.shape[0]): + seq = xp.argsort(distances[i]) + indices[i] = indices[i][seq] + distances[i] = distances[i][seq] + + # POST-PROCESSING STEP 2: Decide what to return (moved from onedal) + # This happens AFTER kd_tree sorting + if return_distance: + results = distances, indices + else: + results = indices + + # POST-PROCESSING STEP 3: Remove self from results when query_is_train (moved from onedal) + # This happens LAST, after sorting and after deciding format + if not query_is_train: + return results + + # If the query data is the same as the indexed data, we would like + # to ignore the first nearest neighbor of every sample, i.e the sample itself. + if return_distance: + neigh_dist, neigh_ind = results + else: + neigh_ind = results + + # X is self._fit_X in query_is_train case (set by caller) + n_queries, _ = X.shape + sample_range = xp.arange(n_queries)[:, xp.newaxis] + sample_mask = neigh_ind != sample_range + + # Corner case: When the number of duplicates are more + # than the number of neighbors, the first NN will not + # be the sample, but a duplicate. + # In that case mask the first duplicate. + dup_gr_nbrs = xp.all(sample_mask, axis=1) + sample_mask[:, 0][dup_gr_nbrs] = False + + neigh_ind = xp.reshape(neigh_ind[sample_mask], (n_queries, n_neighbors - 1)) + + if return_distance: + neigh_dist = xp.reshape(neigh_dist[sample_mask], (n_queries, n_neighbors - 1)) + return neigh_dist, neigh_ind + return neigh_ind + + def _process_classification_targets(self, y, skip_validation=False): + """Process classification targets and set class-related attributes. + + Parameters + ---------- + y : array-like + Target values + skip_validation : bool, default=False + If True, skip _check_classification_targets validation. + Used when use_raw_input=True (raw array API arrays like dpctl.usm_ndarray). + """ + # Array API support: get namespace from y + xp, _ = get_namespace(y) + + # y should already be numpy array from validate_data + y = xp.asarray(y) + + # Handle shape processing + shape = getattr(y, "shape", None) + self._shape = shape if shape is not None else y.shape + + if y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1: + self.outputs_2d_ = False + y = xp.reshape(y, (-1, 1)) + else: + self.outputs_2d_ = True + + # Validate classification targets (skip for raw array API inputs) + if not skip_validation: + _check_classification_targets(y) + + # Process classes - note: np.unique is used for class extraction + # This is acceptable as classes are typically numpy arrays in sklearn + self.classes_ = [] + self._y = xp.empty(y.shape, dtype=xp.int32) + for k in range(self._y.shape[1]): + # Use numpy unique for class extraction (standard sklearn pattern) + # Transfer to host first to ensure proper numpy array conversion + y_k_host = np.asarray(_transfer_to_host(y[:, k])[1][0]) + classes, indices = np.unique(y_k_host, return_inverse=True) + self.classes_.append(classes) + self._y[:, k] = xp.asarray(indices, dtype=xp.int32) + + if not self.outputs_2d_: + self.classes_ = self.classes_[0] + self._y = xp.reshape(self._y, (-1,)) + + # Validate we have at least 2 classes + self._validate_n_classes() + + return y + + def _process_regression_targets(self, y): + """Process regression targets and set shape-related attributes. + + REFACTOR: This replicates the EXACT shape processing that was in onedal _fit. + Original onedal code: + shape = getattr(y, "shape", None) + self._shape = shape if shape is not None else y.shape + # (later, after fit) + self._y = y if self._shape is None else xp.reshape(y, self._shape) + + For now, just store _shape and _y as-is. The reshape happens after onedal fit is complete. + """ + # EXACT replication of original onedal shape processing + shape = getattr(y, "shape", None) + self._shape = shape if shape is not None else y.shape + self._y = y + return y + + def _fit_validation(self, X, y=None): + if sklearn_check_version("1.2"): + self._validate_params() + # check_feature_names(self, X, reset=True) + # Validate n_neighbors parameter + self._validate_n_neighbors(self.n_neighbors) + + # Set effective metric and parameters + self._set_effective_metric() + if not isinstance(X, (KDTree, BallTree, _sklearn_NeighborsBase)): + # Use _check_array like main branch, but with array API dtype support + # Get array namespace for array API support + # Don't check for NaN - let oneDAL handle it (will fallback to sklearn if needed) + xp, _ = get_namespace(X) self._fit_X = _check_array( - X, dtype=[np.float64, np.float32], accept_sparse=True + X, + dtype=[xp.float64, xp.float32], + accept_sparse=True, + force_all_finite=False, ) self.n_samples_fit_ = _num_samples(self._fit_X) self.n_features_in_ = _num_features(self._fit_X) @@ -95,7 +654,9 @@ def _fit_validation(self, X, y=None): else: self._fit_method = self.algorithm - if hasattr(self, "_onedal_estimator"): + # Only delete _onedal_estimator if it's an instance attribute, not a class attribute + # (SPMD classes define _onedal_estimator as a staticmethod at class level) + if "_onedal_estimator" in self.__dict__: delattr(self, "_onedal_estimator") # To cover test case when we pass patched # estimator as an input for other estimator @@ -105,7 +666,8 @@ def _fit_validation(self, X, y=None): self._fit_method = X._fit_method self.n_samples_fit_ = X.n_samples_fit_ self.n_features_in_ = X.n_features_in_ - if hasattr(X, "_onedal_estimator"): + # Check if X has _onedal_estimator as an instance attribute (not class attribute) + if "_onedal_estimator" in X.__dict__: self.effective_metric_params_.pop("p") if self._fit_method == "ball_tree": X._tree = BallTree( @@ -199,10 +761,15 @@ def _onedal_supported(self, device, method_name, *data): y = None # To check multioutput, might be overhead if len(data) > 1: - y = np.asarray(data[1]) + # Array API support: get namespace from y + y_input = data[1] + xp, _ = get_namespace(y_input) + y = xp.asarray(y_input) if is_classifier: - class_count = len(np.unique(y)) - if hasattr(self, "_onedal_estimator"): + # Use numpy for unique (standard sklearn pattern) + class_count = len(np.unique(np.asarray(y))) + # Only access _onedal_estimator if it's an instance attribute (not a class-level staticmethod) + if "_onedal_estimator" in self.__dict__: y = self._onedal_estimator._y if y is not None and hasattr(y, "ndim") and hasattr(y, "shape"): is_single_output = y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1 @@ -261,8 +828,10 @@ def _onedal_supported(self, device, method_name, *data): ) return patching_status if method_name in ["predict", "predict_proba", "kneighbors", "score"]: + # Check if _onedal_estimator is an instance attribute (model was trained) + # For SPMD classes, _onedal_estimator is a class-level staticmethod, so we check __dict__ patching_status.and_condition( - hasattr(self, "_onedal_estimator"), "oneDAL model was not trained." + "_onedal_estimator" in self.__dict__, "oneDAL model was not trained." ) return patching_status raise RuntimeError(f"Unknown method {method_name} in {class_name}") @@ -284,13 +853,17 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"): # requires moving data to host to construct the csr_matrix if mode == "connectivity": A_ind = self.kneighbors(X, n_neighbors, return_distance=False) + # Transfer to host - after this, arrays are numpy _, (A_ind,) = _transfer_to_host(A_ind) n_queries = A_ind.shape[0] + # Use numpy after transfer to host A_data = np.ones(n_queries * n_neighbors) elif mode == "distance": A_data, A_ind = self.kneighbors(X, n_neighbors, return_distance=True) + # Transfer to host - after this, arrays are numpy _, (A_data, A_ind) = _transfer_to_host(A_data, A_ind) + # Use numpy after transfer to host A_data = np.reshape(A_data, (-1,)) else: @@ -302,6 +875,7 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"): n_queries = A_ind.shape[0] n_samples_fit = self.n_samples_fit_ n_nonzero = n_queries * n_neighbors + # Use numpy after transfer to host A_indptr = np.arange(0, n_nonzero + 1, n_neighbors) kneighbors_graph = sp.csr_matrix( diff --git a/sklearnex/neighbors/knn_classification.py b/sklearnex/neighbors/knn_classification.py index 7e25fa5ae1..e86e7c433d 100755 --- a/sklearnex/neighbors/knn_classification.py +++ b/sklearnex/neighbors/knn_classification.py @@ -14,6 +14,7 @@ # limitations under the License. # =============================================================================== +import numpy as np from sklearn.metrics import accuracy_score from sklearn.neighbors._classification import ( KNeighborsClassifier as _sklearn_KNeighborsClassifier, @@ -23,13 +24,17 @@ from daal4py.sklearn._n_jobs_support import control_n_jobs from daal4py.sklearn._utils import sklearn_check_version from daal4py.sklearn.utils.validation import get_requires_y_tag +from onedal._device_offload import _transfer_to_host from onedal.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier +from .._config import get_config from .._device_offload import dispatch, wrap_output_data -from ..utils.validation import check_feature_names +from ..utils._array_api import enable_array_api, get_namespace +from ..utils.validation import check_feature_names, validate_data from .common import KNeighborsDispatchingBase +@enable_array_api @control_n_jobs( decorated_methods=["fit", "predict", "predict_proba", "kneighbors", "score"] ) @@ -79,7 +84,7 @@ def fit(self, X, y): @wrap_output_data def predict(self, X): check_is_fitted(self) - check_feature_names(self, X, reset=False) + return dispatch( self, "predict", @@ -93,7 +98,7 @@ def predict(self, X): @wrap_output_data def predict_proba(self, X): check_is_fitted(self) - check_feature_names(self, X, reset=False) + return dispatch( self, "predict_proba", @@ -107,7 +112,7 @@ def predict_proba(self, X): @wrap_output_data def score(self, X, y, sample_weight=None): check_is_fitted(self) - check_feature_names(self, X, reset=False) + return dispatch( self, "score", @@ -122,9 +127,15 @@ def score(self, X, y, sample_weight=None): @wrap_output_data def kneighbors(self, X=None, n_neighbors=None, return_distance=True): + # Validate n_neighbors parameter first + if n_neighbors is not None: + self._validate_n_neighbors(n_neighbors) + check_is_fitted(self) - if X is not None: - check_feature_names(self, X, reset=False) + + # Validate kneighbors parameters (inherited from KNeighborsDispatchingBase) + self._kneighbors_validation(X, n_neighbors) + return dispatch( self, "kneighbors", @@ -138,6 +149,29 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True): ) def _onedal_fit(self, X, y, queue=None): + xp, _ = get_namespace(X) + + # Validation step (follows PCA pattern) + if not get_config()["use_raw_input"]: + X, y = validate_data( + self, + X, + y, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + ) + # Set effective metric after validation + self._set_effective_metric() + else: + # SPMD mode: skip validation but still set effective metric + self._set_effective_metric() + + # Process classification targets before passing to onedal + self._process_classification_targets( + y, skip_validation=get_config()["use_raw_input"] + ) + + # Call onedal backend onedal_params = { "n_neighbors": self.n_neighbors, "weights": self.weights, @@ -150,28 +184,72 @@ def _onedal_fit(self, X, y, queue=None): self._onedal_estimator.requires_y = get_requires_y_tag(self) self._onedal_estimator.effective_metric_ = self.effective_metric_ self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_ + self._onedal_estimator.classes_ = self.classes_ + self._onedal_estimator._y = self._y + self._onedal_estimator.outputs_2d_ = self.outputs_2d_ + self._onedal_estimator._shape = self._shape + self._onedal_estimator.fit(X, y, queue=queue) + # Post-processing self._save_attributes() def _onedal_predict(self, X, queue=None): - return self._onedal_estimator.predict(X, queue=queue) + # Use the unified helper from common.py (calls kneighbors + computes prediction) + # This properly handles X=None (LOOCV) case + # Note: X validation happens in kneighbors + return self._predict_skl_classification(X) def _onedal_predict_proba(self, X, queue=None): - return self._onedal_estimator.predict_proba(X, queue=queue) + # Call kneighbors through sklearnex (self.kneighbors is the sklearnex method) + # This properly handles X=None case (LOOCV) with query_is_train logic + # Note: X validation happens in kneighbors + neigh_dist, neigh_ind = self.kneighbors(X) + + # Use the helper method to compute class probabilities + return self._compute_class_probabilities( + neigh_dist, neigh_ind, self.weights, self._y, self.classes_, self.outputs_2d_ + ) def _onedal_kneighbors( self, X=None, n_neighbors=None, return_distance=True, queue=None ): - return self._onedal_estimator.kneighbors( + # Only skip validation when use_raw_input=True (SPMD mode) + use_raw_input = get_config()["use_raw_input"] + + if X is not None and not use_raw_input: + xp, _ = get_namespace(X) + X = validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + reset=False, + ) + + # Prepare inputs and handle query_is_train case + X, n_neighbors, query_is_train = self._prepare_kneighbors_inputs(X, n_neighbors) + + # Get raw results from onedal backend + result = self._onedal_estimator.kneighbors( X, n_neighbors, return_distance, queue=queue ) - def _onedal_score(self, X, y, sample_weight=None, queue=None): - return accuracy_score( - y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight + # Apply post-processing (kd_tree sorting, removing self from results) + return self._kneighbors_post_processing( + X, n_neighbors, return_distance, result, query_is_train ) + def _onedal_score(self, X, y, sample_weight=None, queue=None): + # Get predictions + y_pred = self._onedal_predict(X, queue=queue) + + # Convert array API to numpy for sklearn's accuracy_score using _transfer_to_host + # This properly handles Array API arrays that don't allow implicit conversion + _, (y, y_pred, sample_weight) = _transfer_to_host(y, y_pred, sample_weight) + + return accuracy_score(y, y_pred, sample_weight=sample_weight) + def _save_attributes(self): self.classes_ = self._onedal_estimator.classes_ self.n_features_in_ = self._onedal_estimator.n_features_in_ diff --git a/sklearnex/neighbors/knn_regression.py b/sklearnex/neighbors/knn_regression.py index ba1626b4ff..1a1760af9d 100755 --- a/sklearnex/neighbors/knn_regression.py +++ b/sklearnex/neighbors/knn_regression.py @@ -18,18 +18,23 @@ from sklearn.neighbors._regression import ( KNeighborsRegressor as _sklearn_KNeighborsRegressor, ) -from sklearn.utils.validation import check_is_fitted +from sklearn.utils.validation import assert_all_finite, check_is_fitted from daal4py.sklearn._n_jobs_support import control_n_jobs from daal4py.sklearn._utils import sklearn_check_version from daal4py.sklearn.utils.validation import get_requires_y_tag +from onedal._device_offload import _transfer_to_host from onedal.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor +from onedal.utils import _sycl_queue_manager as QM +from .._config import get_config from .._device_offload import dispatch, wrap_output_data -from ..utils.validation import check_feature_names +from ..utils._array_api import enable_array_api, get_namespace +from ..utils.validation import check_feature_names, validate_data from .common import KNeighborsDispatchingBase +@enable_array_api("1.5") # validate_data y_numeric requires sklearn >=1.5 @control_n_jobs(decorated_methods=["fit", "predict", "kneighbors", "score"]) class KNeighborsRegressor(KNeighborsDispatchingBase, _sklearn_KNeighborsRegressor): __doc__ = _sklearn_KNeighborsRegressor.__doc__ @@ -77,7 +82,7 @@ def fit(self, X, y): @wrap_output_data def predict(self, X): check_is_fitted(self) - check_feature_names(self, X, reset=False) + return dispatch( self, "predict", @@ -91,7 +96,7 @@ def predict(self, X): @wrap_output_data def score(self, X, y, sample_weight=None): check_is_fitted(self) - check_feature_names(self, X, reset=False) + return dispatch( self, "score", @@ -106,9 +111,15 @@ def score(self, X, y, sample_weight=None): @wrap_output_data def kneighbors(self, X=None, n_neighbors=None, return_distance=True): + # Validate n_neighbors parameter first (before check_is_fitted) + if n_neighbors is not None: + self._validate_n_neighbors(n_neighbors) + check_is_fitted(self) - if X is not None: - check_feature_names(self, X, reset=False) + + # Validate kneighbors parameters (inherited from KNeighborsDispatchingBase) + self._kneighbors_validation(X, n_neighbors) + return dispatch( self, "kneighbors", @@ -122,6 +133,30 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True): ) def _onedal_fit(self, X, y, queue=None): + xp, _ = get_namespace(X, y) + + # Validation step - validates and converts dtypes to float32/float64 + if not get_config()["use_raw_input"]: + X, y = validate_data( + self, + X, + y, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + multi_output=True, + y_numeric=True, + ) + + # Set effective metric after validation + self._set_effective_metric() + else: + # SPMD mode: skip validation but still set effective metric + self._set_effective_metric() + + # Process regression targets before passing to onedal (uses validated y) + self._process_regression_targets(y) + + # Call onedal backend onedal_params = { "n_neighbors": self.n_neighbors, "weights": self.weights, @@ -134,25 +169,103 @@ def _onedal_fit(self, X, y, queue=None): self._onedal_estimator.requires_y = get_requires_y_tag(self) self._onedal_estimator.effective_metric_ = self.effective_metric_ self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_ + self._onedal_estimator._shape = self._shape + + # Reshape _y for GPU backend + queue_instance = QM.get_global_queue() + gpu_device = queue_instance is not None and queue_instance.sycl_device.is_gpu + if gpu_device: + self._onedal_estimator._y = xp.reshape(self._y, (-1, 1)) + else: + self._onedal_estimator._y = self._y + + # Pass validated X and y to onedal (after validate_data converted dtypes) self._onedal_estimator.fit(X, y, queue=queue) + # Post-processing: save attributes and reshape _y self._save_attributes() + if y is not None: + xp, _ = get_namespace(y) + self._y = y if self._shape is None else xp.reshape(y, self._shape) + self._onedal_estimator._y = self._y def _onedal_predict(self, X, queue=None): - return self._onedal_estimator.predict(X, queue=queue) + # Dispatch between GPU and SKL prediction methods + # This logic matches onedal regressor predict() method but computation happens in sklearnex + # Note: X validation happens in kneighbors (for SKL path) or _predict_gpu (for GPU path) + gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) + is_uniform_weights = getattr(self, "weights", "uniform") == "uniform" + + if gpu_device and is_uniform_weights: + # GPU path: call onedal backend directly + return self._predict_gpu(X, queue=queue) + else: + # SKL path: call kneighbors (through sklearnex) then compute in sklearnex + return self._predict_skl(X, queue=queue) + + def _predict_gpu(self, X, queue=None): + """GPU prediction path - calls onedal backend.""" + # Validate X for GPU path (SKL path validation happens in kneighbors) + if X is not None: + xp, _ = get_namespace(X) + # For precomputed metric, only check NaN/inf, don't validate features + if getattr(self, "effective_metric_", self.metric) == "precomputed": + from ..utils.validation import assert_all_finite + + assert_all_finite(X, allow_nan=False, input_name="X") + else: + X = validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + reset=False, + ) + # Call onedal backend for GPU prediction + return self._onedal_estimator._predict_gpu(X) + + def _predict_skl(self, X, queue=None): + """SKL prediction path - calls kneighbors through sklearnex, computes prediction here.""" + # Use the unified helper from common.py (calls kneighbors + computes prediction) + return self._predict_skl_regression(X) def _onedal_kneighbors( self, X=None, n_neighbors=None, return_distance=True, queue=None ): - return self._onedal_estimator.kneighbors( + # Validation step + if X is not None and not get_config()["use_raw_input"]: + xp, _ = get_namespace(X) + X = validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + reset=False, + ) + + # Prepare inputs + X, n_neighbors, query_is_train = self._prepare_kneighbors_inputs(X, n_neighbors) + + # Call onedal backend + result = self._onedal_estimator.kneighbors( X, n_neighbors, return_distance, queue=queue ) - def _onedal_score(self, X, y, sample_weight=None, queue=None): - return r2_score( - y, self._onedal_predict(X, queue=queue), sample_weight=sample_weight + # Post-processing + return self._kneighbors_post_processing( + X, n_neighbors, return_distance, result, query_is_train ) + def _onedal_score(self, X, y, sample_weight=None, queue=None): + y_pred = self._onedal_predict(X, queue=queue) + + # Convert array API/USM arrays back to numpy for r2_score + # r2_score doesn't support Array API, following PCA's pattern with _transfer_to_host + _, host_data = _transfer_to_host(y, y_pred, sample_weight) + y, y_pred, sample_weight = host_data + + return r2_score(y, y_pred, sample_weight=sample_weight) + def _save_attributes(self): self.n_features_in_ = self._onedal_estimator.n_features_in_ self.n_samples_fit_ = self._onedal_estimator.n_samples_fit_ diff --git a/sklearnex/neighbors/knn_unsupervised.py b/sklearnex/neighbors/knn_unsupervised.py index 80da8bb2cf..f2c5d950d0 100755 --- a/sklearnex/neighbors/knn_unsupervised.py +++ b/sklearnex/neighbors/knn_unsupervised.py @@ -23,10 +23,12 @@ from onedal.neighbors import NearestNeighbors as onedal_NearestNeighbors from .._device_offload import dispatch, wrap_output_data -from ..utils.validation import check_feature_names +from ..utils._array_api import enable_array_api, get_namespace +from ..utils.validation import check_feature_names, validate_data from .common import KNeighborsDispatchingBase +@enable_array_api @control_n_jobs(decorated_methods=["fit", "kneighbors", "radius_neighbors"]) class NearestNeighbors(KNeighborsDispatchingBase, _sklearn_NearestNeighbors): __doc__ = _sklearn_NearestNeighbors.__doc__ @@ -73,9 +75,15 @@ def fit(self, X, y=None): @wrap_output_data def kneighbors(self, X=None, n_neighbors=None, return_distance=True): + # Validate n_neighbors parameter first + if n_neighbors is not None: + self._validate_n_neighbors(n_neighbors) + check_is_fitted(self) - if X is not None: - check_feature_names(self, X, reset=False) + + # Validate kneighbors parameters (inherited from KNeighborsDispatchingBase) + self._kneighbors_validation(X, n_neighbors) + return dispatch( self, "kneighbors", @@ -93,7 +101,7 @@ def radius_neighbors( self, X=None, radius=None, return_distance=True, sort_results=False ): if ( - hasattr(self, "_onedal_estimator") + "_onedal_estimator" in self.__dict__ or getattr(self, "_tree", 0) is None and self._fit_method == "kd_tree" ): @@ -129,6 +137,14 @@ def radius_neighbors_graph( ) def _onedal_fit(self, X, y=None, queue=None): + xp, _ = get_namespace(X) + X = validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + ) + onedal_params = { "n_neighbors": self.n_neighbors, "algorithm": self.algorithm, @@ -141,19 +157,48 @@ def _onedal_fit(self, X, y=None, queue=None): self._onedal_estimator.effective_metric_ = self.effective_metric_ self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_ self._onedal_estimator.fit(X, y, queue=queue) - self._save_attributes() def _onedal_predict(self, X, queue=None): + # Validate and convert X + if X is not None: + xp, _ = get_namespace(X) + X = validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + reset=False, + force_all_finite=False, + ) return self._onedal_estimator.predict(X, queue=queue) def _onedal_kneighbors( self, X=None, n_neighbors=None, return_distance=True, queue=None ): - return self._onedal_estimator.kneighbors( + if X is not None: + xp, _ = get_namespace(X) + X = validate_data( + self, + X, + dtype=[xp.float64, xp.float32], + accept_sparse="csr", + reset=False, + ) + + # Prepare inputs and handle query_is_train case + X, n_neighbors, query_is_train = self._prepare_kneighbors_inputs(X, n_neighbors) + + # Get raw results from onedal backend + result = self._onedal_estimator.kneighbors( X, n_neighbors, return_distance, queue=queue ) + # Apply post-processing (kd_tree sorting, removing self from results) + return self._kneighbors_post_processing( + X, n_neighbors, return_distance, result, query_is_train + ) + def _save_attributes(self): self.classes_ = self._onedal_estimator.classes_ self.n_features_in_ = self._onedal_estimator.n_features_in_ diff --git a/sklearnex/spmd/neighbors/__init__.py b/sklearnex/spmd/neighbors/__init__.py index 44cb849591..8036511d9f 100644 --- a/sklearnex/spmd/neighbors/__init__.py +++ b/sklearnex/spmd/neighbors/__init__.py @@ -14,10 +14,6 @@ # limitations under the License. # ============================================================================== -from onedal.spmd.neighbors import ( - KNeighborsClassifier, - KNeighborsRegressor, - NearestNeighbors, -) +from .neighbors import KNeighborsClassifier, KNeighborsRegressor, NearestNeighbors __all__ = ["KNeighborsClassifier", "KNeighborsRegressor", "NearestNeighbors"] diff --git a/sklearnex/spmd/neighbors/neighbors.py b/sklearnex/spmd/neighbors/neighbors.py new file mode 100644 index 0000000000..d333f4530a --- /dev/null +++ b/sklearnex/spmd/neighbors/neighbors.py @@ -0,0 +1,46 @@ +# ============================================================================== +# Copyright 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from onedal.spmd.neighbors import KNeighborsClassifier as onedal_KNeighborsClassifier +from onedal.spmd.neighbors import KNeighborsRegressor as onedal_KNeighborsRegressor +from onedal.spmd.neighbors import NearestNeighbors as onedal_NearestNeighbors + +from ...neighbors import KNeighborsClassifier as base_KNeighborsClassifier +from ...neighbors import KNeighborsRegressor as base_KNeighborsRegressor +from ...neighbors import NearestNeighbors as base_NearestNeighbors + + +class KNeighborsClassifier(base_KNeighborsClassifier): + _onedal_estimator = staticmethod(onedal_KNeighborsClassifier) + + +class KNeighborsRegressor(base_KNeighborsRegressor): + _onedal_estimator = staticmethod(onedal_KNeighborsRegressor) + + def _onedal_predict(self, X, queue=None): + """Override to always use GPU path in SPMD mode. + + SPMD KNN regression always trains on GPU (creating regression.model), + so we must always use the GPU prediction path even with weights='distance'. + The parent class would dispatch to CPU/SKL path for weights='distance', + which would call infer_search() expecting search.model, causing type mismatch. + """ + # Always use GPU path - call parent's _predict_gpu directly + return self._predict_gpu(X, queue=queue) + + +class NearestNeighbors(base_NearestNeighbors): + _onedal_estimator = staticmethod(onedal_NearestNeighbors) diff --git a/sklearnex/tests/test_common.py b/sklearnex/tests/test_common.py index d01597344d..435d7359da 100644 --- a/sklearnex/tests/test_common.py +++ b/sklearnex/tests/test_common.py @@ -106,41 +106,6 @@ "DummyRegressor-fit-n_jobs_check": "default parameters use sklearn", "DummyRegressor-predict-n_jobs_check": "default parameters use sklearn", "DummyRegressor-score-n_jobs_check": "default parameters use sklearn", - # KNeighborsClassifier validate_data issues - will be fixed later - "KNeighborsClassifier-fit-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier-predict_proba-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier-score-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier-kneighbors-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier-predict-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor-fit-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor-score-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor-kneighbors-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor-predict-call_validate_data": "validate_data implementation needs fixing", - "NearestNeighbors-fit-call_validate_data": "validate_data implementation needs fixing", - "NearestNeighbors-kneighbors-call_validate_data": "validate_data implementation needs fixing", - "NearestNeighbors-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing", - "LocalOutlierFactor-fit-call_validate_data": "validate_data implementation needs fixing", - "LocalOutlierFactor-kneighbors-call_validate_data": "validate_data implementation needs fixing", - "LocalOutlierFactor-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing", - "LocalOutlierFactor(novelty=True)-fit-call_validate_data": "validate_data implementation needs fixing", - "LocalOutlierFactor(novelty=True)-kneighbors-call_validate_data": "validate_data implementation needs fixing", - "LocalOutlierFactor(novelty=True)-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier(algorithm='brute')-fit-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier(algorithm='brute')-predict_proba-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier(algorithm='brute')-score-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier(algorithm='brute')-kneighbors-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier(algorithm='brute')-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsClassifier(algorithm='brute')-predict-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor(algorithm='brute')-fit-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor(algorithm='brute')-score-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor(algorithm='brute')-kneighbors-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor(algorithm='brute')-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing", - "KNeighborsRegressor(algorithm='brute')-predict-call_validate_data": "validate_data implementation needs fixing", - "NearestNeighbors(algorithm='brute')-fit-call_validate_data": "validate_data implementation needs fixing", - "NearestNeighbors(algorithm='brute')-kneighbors-call_validate_data": "validate_data implementation needs fixing", - "NearestNeighbors(algorithm='brute')-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing", }