diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c2953289a..9c91d8d627 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ - PR #3112: Speed test_array - PR #3111: Adding Cython to Code Coverage - PR #3129: Update notebooks README +- PR #3040: Improved Array Conversion with CumlArrayDescriptor and Decorators - PR #3134: Improving the Deprecation Message Formatting in Documentation ## Bug Fixes diff --git a/python/cuml/__init__.py b/python/cuml/__init__.py index a230661ad8..63151af9c2 100644 --- a/python/cuml/__init__.py +++ b/python/cuml/__init__.py @@ -82,7 +82,7 @@ # Output type configuration -global_output_type = 'input' +global_output_type = None from cuml.common.memory_utils import set_global_output_type, using_output_type diff --git a/python/cuml/benchmark/datagen.py b/python/cuml/benchmark/datagen.py index 5d37ae30cc..923758d1ed 100644 --- a/python/cuml/benchmark/datagen.py +++ b/python/cuml/benchmark/datagen.py @@ -219,7 +219,8 @@ def _convert_to_gpuarray(data, order='F'): gs = cudf.Series.from_pandas(data) return cuda.as_cuda_array(gs) else: - return input_utils.input_to_dev_array(data, order=order)[0] + return input_utils.input_to_cuml_array( + data, order=order)[0].to_output("numba") def _convert_to_gpuarray_c(data): diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx index e4e57220cd..1173c12af9 100644 --- a/python/cuml/cluster/dbscan.pyx +++ b/python/cuml/cluster/dbscan.pyx @@ -30,6 +30,8 @@ from cuml.common.base import Base from cuml.common.doc_utils import generate_docstring from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array +from cuml.common import using_output_type +from cuml.common.array_descriptor import CumlArrayDescriptor from collections import defaultdict @@ -186,6 +188,9 @@ class DBSCAN(Base): `_. """ + labels_ = CumlArrayDescriptor() + core_sample_indices_ = CumlArrayDescriptor() + def __init__(self, eps=0.5, handle=None, min_samples=5, verbose=False, max_mbytes_per_batch=None, output_type=None, calc_core_sample_indices=True): @@ -196,18 +201,17 @@ class DBSCAN(Base): self.calc_core_sample_indices = calc_core_sample_indices # internal array attributes - self._labels_ = None # accessed via estimator.labels_ + self.labels_ = None - # accessed via estimator._core_sample_indices_ when - # self.calc_core_sample_indices == True - self._core_sample_indices_ = None + # One used when `self.calc_core_sample_indices == True` + self.core_sample_indices_ = None # C++ API expects this to be numeric. if self.max_mbytes_per_batch is None: self.max_mbytes_per_batch = 0 @generate_docstring(skip_parameters_heading=True) - def fit(self, X, out_dtype="int32"): + def fit(self, X, out_dtype="int32") -> "DBSCAN": """ Perform DBSCAN clustering from features. @@ -218,11 +222,6 @@ class DBSCAN(Base): "int64", np.int64}. """ - self._set_base_attributes(output_type=X, n_features=X) - - if self._labels_ is not None: - del self._labels_ - if out_dtype not in ["int32", np.int32, "int64", np.int64]: raise ValueError("Invalid value for out_dtype. " "Valid values are {'int32', 'int64', " @@ -236,16 +235,16 @@ class DBSCAN(Base): cdef handle_t* handle_ = self.handle.getHandle() - self._labels_ = CumlArray.empty(n_rows, dtype=out_dtype) - cdef uintptr_t labels_ptr = self._labels_.ptr + self.labels_ = CumlArray.empty(n_rows, dtype=out_dtype) + cdef uintptr_t labels_ptr = self.labels_.ptr cdef uintptr_t core_sample_indices_ptr = NULL # Create the output core_sample_indices only if needed if self.calc_core_sample_indices: - self._core_sample_indices_ = \ + self.core_sample_indices_ = \ CumlArray.empty(n_rows, dtype=out_dtype) - core_sample_indices_ptr = self._core_sample_indices_.ptr + core_sample_indices_ptr = self.core_sample_indices_.ptr if self.dtype == np.float32: if out_dtype is "int32" or out_dtype is np.int32: @@ -303,20 +302,21 @@ class DBSCAN(Base): # Finally, resize the core_sample_indices array if necessary if self.calc_core_sample_indices: - # Temp convert to cupy array only once - core_samples_cupy = self._core_sample_indices_.to_output("cupy") + # Temp convert to cupy array (better than using `cupy.asarray`) + with using_output_type("cupy"): - # First get the min index. These have to monotonically increasing, - # so the min index should be the first returned -1 - min_index = cp.argmin(core_samples_cupy).item() + # First get the min index. These have to monotonically + # increasing, so the min index should be the first returned -1 + min_index = cp.argmin(self.core_sample_indices_).item() - # Check for the case where there are no -1's - if (min_index == 0 and core_samples_cupy[min_index].item() != -1): - # Nothing to delete. The array has no -1's - pass - else: - self._core_sample_indices_ = \ - self._core_sample_indices_[:min_index] + # Check for the case where there are no -1's + if ((min_index == 0 and + self.core_sample_indices_[min_index].item() != -1)): + # Nothing to delete. The array has no -1's + pass + else: + self.core_sample_indices_ = \ + self.core_sample_indices_[:min_index] return self @@ -325,7 +325,7 @@ class DBSCAN(Base): 'type': 'dense', 'description': 'Cluster labels', 'shape': '(n_samples, 1)'}) - def fit_predict(self, X, out_dtype="int32"): + def fit_predict(self, X, out_dtype="int32") -> CumlArray: """ Performs clustering on X and returns cluster labels. diff --git a/python/cuml/cluster/kmeans.pyx b/python/cuml/cluster/kmeans.pyx index 4a9bdfe68d..b37e619eb0 100644 --- a/python/cuml/cluster/kmeans.pyx +++ b/python/cuml/cluster/kmeans.pyx @@ -21,12 +21,14 @@ import cudf import numpy as np import rmm import warnings +import typing from libcpp cimport bool from libc.stdint cimport uintptr_t, int64_t from libc.stdlib cimport calloc, malloc, free from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import Base from cuml.common.doc_utils import generate_docstring from cuml.raft.common.handle cimport handle_t @@ -259,6 +261,9 @@ class KMeans(Base): `_. """ + labels_ = CumlArrayDescriptor() + cluster_centers_ = CumlArrayDescriptor() + def __init__(self, handle=None, n_clusters=8, max_iter=300, tol=1e-4, verbose=False, random_state=1, init='scalable-k-means++', n_init=1, oversampling_factor=2.0, @@ -275,8 +280,8 @@ class KMeans(Base): self.max_samples_per_batch=int(max_samples_per_batch) # internal array attributes - self._labels_ = None # accessed via estimator.labels_ - self._cluster_centers_ = None # accessed via estimator.cluster_centers_ # noqa + self.labels_ = None + self.cluster_centers_ = None cdef KMeansParams params params.n_clusters = self.n_clusters @@ -301,7 +306,7 @@ class KMeans(Base): else: self.init = 'preset' params.init = Array - self._cluster_centers_, n_rows, self.n_cols, self.dtype = \ + self.cluster_centers_, n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(init, order='C', check_dtype=[np.float32, np.float64]) @@ -316,13 +321,11 @@ class KMeans(Base): self._params = params @generate_docstring() - def fit(self, X, sample_weight=None): + def fit(self, X, sample_weight=None) -> "KMeans": """ Compute k-means clustering with X. """ - self._set_base_attributes(output_type=X, n_features=X) - if self.init == 'preset': check_cols = self.n_cols check_dtype = self.dtype @@ -349,15 +352,15 @@ class KMeans(Base): cdef uintptr_t sample_weight_ptr = sample_weight_m.ptr - self._labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32) - cdef uintptr_t labels_ptr = self._labels_.ptr + self.labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32) + cdef uintptr_t labels_ptr = self.labels_.ptr if (self.init in ['scalable-k-means++', 'k-means||', 'random']): - self._cluster_centers_ = \ + self.cluster_centers_ = \ CumlArray.zeros(shape=(self.n_clusters, self.n_cols), dtype=self.dtype, order='C') - cdef uintptr_t cluster_centers_ptr = self._cluster_centers_.ptr + cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr cdef float inertiaf = 0 cdef double inertiad = 0 @@ -409,7 +412,7 @@ class KMeans(Base): 'type': 'dense', 'description': 'Cluster indexes', 'shape': '(n_samples, 1)'}) - def fit_predict(self, X, sample_weight=None): + def fit_predict(self, X, sample_weight=None) -> CumlArray: """ Compute cluster centers and predict cluster index for each sample. @@ -417,7 +420,8 @@ class KMeans(Base): return self.fit(X, sample_weight=sample_weight).labels_ def _predict_labels_inertia(self, X, convert_dtype=False, - sample_weight=None): + sample_weight=None) -> typing.Tuple[CumlArray, + float]: """ Predict the closest cluster each sample in X belongs to. @@ -446,8 +450,6 @@ class KMeans(Base): Sum of squared distances of samples to their closest cluster center. """ - out_type = self._get_output_type(X) - X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype @@ -468,10 +470,10 @@ class KMeans(Base): cdef handle_t* handle_ = self.handle.getHandle() - cdef uintptr_t cluster_centers_ptr = self._cluster_centers_.ptr + cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr - self._labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32) - cdef uintptr_t labels_ptr = self._labels_.ptr + self.labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32) + cdef uintptr_t labels_ptr = self.labels_.ptr # Sum of squared distances of samples to their closest cluster center. cdef float inertiaf = 0 @@ -511,13 +513,13 @@ class KMeans(Base): self.handle.sync() del(X_m) del(sample_weight_m) - return self._labels_.to_output(out_type), inertia + return self.labels_, inertia @generate_docstring(return_values={'name': 'preds', 'type': 'dense', 'description': 'Cluster indexes', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=False, sample_weight=None): + def predict(self, X, convert_dtype=False, sample_weight=None) -> CumlArray: """ Predict the closest cluster each sample in X belongs to. @@ -532,14 +534,12 @@ class KMeans(Base): 'type': 'dense', 'description': 'Transformed data', 'shape': '(n_samples, n_clusters)'}) - def transform(self, X, convert_dtype=False): + def transform(self, X, convert_dtype=False) -> CumlArray: """ Transform X to a cluster-distance space. """ - out_type = self._get_output_type(X) - X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype @@ -550,7 +550,7 @@ class KMeans(Base): cdef handle_t* handle_ = self.handle.getHandle() - cdef uintptr_t cluster_centers_ptr = self._cluster_centers_.ptr + cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr preds = CumlArray.zeros(shape=(n_rows, self.n_clusters), dtype=self.dtype, @@ -589,7 +589,7 @@ class KMeans(Base): self.handle.sync() del(X_m) - return preds.to_output(out_type) + return preds @generate_docstring(return_values={'name': 'score', 'type': 'float', @@ -610,7 +610,7 @@ class KMeans(Base): 'type': 'dense', 'description': 'Transformed data', 'shape': '(n_samples, n_clusters)'}) - def fit_transform(self, X, convert_dtype=False): + def fit_transform(self, X, convert_dtype=False) -> CumlArray: """ Compute clustering and transform X to cluster-distance space. diff --git a/python/cuml/cluster/kmeans_mg.pyx b/python/cuml/cluster/kmeans_mg.pyx index cbac199b72..487fecfe01 100644 --- a/python/cuml/cluster/kmeans_mg.pyx +++ b/python/cuml/cluster/kmeans_mg.pyx @@ -72,7 +72,7 @@ class KMeansMG(KMeans): def __init__(self, **kwargs): super(KMeansMG, self).__init__(**kwargs) - def fit(self, X): + def fit(self, X) -> "KMeansMG": """ Compute k-means clustering with X in a multi-node multi-GPU setting. @@ -84,7 +84,6 @@ class KMeansMG(KMeans): ndarray, cuda array interface compliant array like CuPy """ - self._set_base_attributes(n_features=X) X_m, self.n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(X, order='C') @@ -94,12 +93,12 @@ class KMeansMG(KMeans): cdef handle_t* handle_ = self.handle.getHandle() if (self.init in ['scalable-k-means++', 'k-means||', 'random']): - self._cluster_centers_ = CumlArray.zeros(shape=(self.n_clusters, - self.n_cols), - dtype=self.dtype, - order='C') + self.cluster_centers_ = CumlArray.zeros(shape=(self.n_clusters, + self.n_cols), + dtype=self.dtype, + order='C') - cdef uintptr_t cluster_centers_ptr = self._cluster_centers_.ptr + cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr cdef size_t n_rows = self.n_rows cdef size_t n_cols = self.n_cols diff --git a/python/cuml/common/__init__.py b/python/cuml/common/__init__.py index 8a12f2ac36..e86261045b 100644 --- a/python/cuml/common/__init__.py +++ b/python/cuml/common/__init__.py @@ -37,9 +37,5 @@ ## legacy to be removed after complete CumlAray migration -from cuml.common.numba_utils import zeros -from cuml.common.input_utils import get_cudf_column_ptr -from cuml.common.input_utils import get_dev_array_ptr -from cuml.common.input_utils import input_to_dev_array from cuml.common.input_utils import sparse_scipy_to_cp from cuml.common.timing_utils import timed diff --git a/python/cuml/common/array.py b/python/cuml/common/array.py index 91c8e71846..be8c511231 100644 --- a/python/cuml/common/array.py +++ b/python/cuml/common/array.py @@ -24,9 +24,11 @@ from cuml.common.memory_utils import _get_size_from_shape from cuml.common.memory_utils import _order_to_strides from cuml.common.memory_utils import _strides_to_order +from cuml.common.memory_utils import class_with_cupy_rmm from numba import cuda +@class_with_cupy_rmm(ignore_pattern=["serialize"]) class CumlArray(Buffer): """ @@ -170,6 +172,7 @@ def __init__(self, data=None, owner=None, dtype=None, shape=None, self.strides = ary_interface['strides'] self.order = _strides_to_order(self.strides, self.dtype) + @with_cupy_rmm def __getitem__(self, slice): return CumlArray(data=cp.asarray(self).__getitem__(slice)) @@ -199,7 +202,9 @@ def __cuda_array_interface__(self): } return output - @with_cupy_rmm + def item(self): + return cp.asarray(self).item() + def to_output(self, output_type='cupy', output_dtype=None): """ Convert array to output format @@ -234,6 +239,8 @@ def to_output(self, output_type='cupy', output_dtype=None): else: output_type = 'dataframe' + assert output_type != "mirror" + if output_type == 'cupy': return cp.asarray(self, dtype=output_dtype) @@ -265,8 +272,8 @@ def to_output(self, output_type='cupy', output_dtype=None): else: raise ValueError('cuDF unsupported Array dtype') elif self.shape[1] > 1: - raise ValueError('Only single dimensional arrays can be \ - transformed to cuDF Series. ') + raise ValueError('Only single dimensional arrays can be ' + 'transformed to cuDF Series. ') else: if self.dtype not in [np.uint8, np.uint16, np.uint32, np.uint64, np.float16]: @@ -274,6 +281,8 @@ def to_output(self, output_type='cupy', output_dtype=None): else: raise ValueError('cuDF unsupported Array dtype') + return self + def serialize(self): header, frames = super(CumlArray, self).serialize() header["constructor-kwargs"] = { @@ -299,9 +308,7 @@ def empty(cls, shape, dtype, order='F'): Whether to create a F-major or C-major array. """ - size, _ = _get_size_from_shape(shape, dtype) - dbuf = DeviceBuffer(size=size) - return CumlArray(data=dbuf, shape=shape, dtype=dtype, order=order) + return CumlArray(cp.empty(shape, dtype, order)) @classmethod def full(cls, shape, value, dtype, order='F'): @@ -317,11 +324,8 @@ def full(cls, shape, value, dtype, order='F'): order: string, optional Whether to create a F-major or C-major array. """ - size, _ = _get_size_from_shape(shape, dtype) - dbuf = DeviceBuffer(size=size) - cp.asarray(dbuf).view(dtype=dtype).fill(value) - return CumlArray(data=dbuf, shape=shape, dtype=dtype, - order=order) + + return CumlArray(cp.full(shape, value, dtype, order)) @classmethod def zeros(cls, shape, dtype='float32', order='F'): diff --git a/python/cuml/common/array_descriptor.py b/python/cuml/common/array_descriptor.py new file mode 100644 index 0000000000..304b5842a7 --- /dev/null +++ b/python/cuml/common/array_descriptor.py @@ -0,0 +1,145 @@ +# +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from dataclasses import dataclass, field +from cuml.common.array import CumlArray +import cuml +from cuml.common.input_utils import input_to_cuml_array, determine_array_type + + +@dataclass +class CumlArrayDescriptorMeta: + + # The type for the input value. One of: _input_type_to_str + input_type: str + + # Dict containing values in different formats. One entry per type. Both the + # input type and any cached converted types will be stored. Erased on set + values: dict = field(default_factory=dict) + + def get_input_value(self): + + assert self.input_type in self.values, \ + "Missing value for input_type {}".format(self.input_type) + + return self.values[self.input_type] + + def __getstate__(self): + # Need to only return the input_value from + return { + "input_type": self.input_type, + "input_value": self.get_input_value() + } + + def __setstate__(self, d): + self.input_type = d["input_type"] + self.values = {self.input_type: d["input_value"]} + + +class CumlArrayDescriptor(): + """ + Python descriptor object to control getting/setting `CumlArray` attributes + on `Base` objects. See the Estimator Guide for an in depth guide. + """ + def __set_name__(self, owner, name): + self.name = name + + def _get_meta(self, + instance, + throw_on_missing=False) -> CumlArrayDescriptorMeta: + + if (throw_on_missing): + if (self.name not in instance.__dict__): + raise AttributeError() + + return instance.__dict__.setdefault( + self.name, CumlArrayDescriptorMeta(input_type=None, values={})) + + def _to_output(self, instance, to_output_type, to_output_dtype=None): + + existing = self._get_meta(instance, throw_on_missing=True) + + # Handle input_type==None which means we have a non-array object stored + if (existing.input_type is None): + # Dont save in the cache. Just return the value + return existing.values[existing.input_type] + + # Return a cached value if it exists + if (to_output_type in existing.values): + return existing.values[to_output_type] + + # If the input type was anything but CumlArray, need to create one now + if ("cuml" not in existing.values): + existing.values["cuml"] = input_to_cuml_array( + existing.get_input_value(), order="K").array + + cuml_arr: CumlArray = existing.values["cuml"] + + # Do the conversion + output = cuml_arr.to_output(output_type=to_output_type, + output_dtype=to_output_dtype) + + # Cache the value + existing.values[to_output_type] = output + + return output + + def __get__(self, instance, owner): + + if (instance is None): + return self + + existing = self._get_meta(instance, throw_on_missing=True) + + assert len(existing.values) > 0 + + # Get the global output type + output_type = cuml.global_output_type + + # First, determine if we need to call to_output at all + if (output_type == "mirror"): + # We must be internal, just return the input type + return existing.get_input_value() + + else: + # We are external, determine the target output type + if (output_type is None): + # Default to the owning base object output_type + output_type = instance.output_type + + if (output_type == "input"): + # Default to the owning base object, _input_type + output_type = instance._input_type + + return self._to_output(instance, output_type) + + def __set__(self, instance, value): + + existing = self._get_meta(instance) + + # Determine the type + existing.input_type = determine_array_type(value) + + # Clear any existing values + existing.values.clear() + + # Set the existing value + existing.values[existing.input_type] = value + + def __delete__(self, instance): + + if (instance is not None): + del instance.__dict__[self.name] diff --git a/python/cuml/common/array_sparse.py b/python/cuml/common/array_sparse.py index e43d8e9339..505623fb78 100644 --- a/python/cuml/common/array_sparse.py +++ b/python/cuml/common/array_sparse.py @@ -15,18 +15,18 @@ # import cupyx as cpx import numpy as np - from cuml.common.import_utils import has_scipy -from cuml.common.input_utils import input_to_cuml_array +from cuml.common.memory_utils import class_with_cupy_rmm from cuml.common.logger import debug -from cuml.common.memory_utils import with_cupy_rmm +import cuml.common if has_scipy(): import scipy.sparse -class SparseCumlArray: +@class_with_cupy_rmm() +class SparseCumlArray(): """ SparseCumlArray abstracts sparse array allocations. This will accept either a Scipy or Cupy sparse array and construct CumlArrays @@ -69,7 +69,6 @@ class SparseCumlArray: Number of nonzeros in underlying arrays """ - @with_cupy_rmm def __init__(self, data=None, convert_to_dtype=False, convert_index=np.int32, @@ -105,15 +104,15 @@ def __init__(self, data=None, # Note: Only 32-bit indexing is supported currently. # In CUDA11, Cusparse provides 64-bit function calls # but these are not yet used in RAFT/Cuml - self.indptr, _, _, _ = input_to_cuml_array( + self.indptr, _, _, _ = cuml.common.input_to_cuml_array( data.indptr, check_dtype=convert_index, convert_to_dtype=convert_index) - self.indices, _, _, _ = input_to_cuml_array( + self.indices, _, _, _ = cuml.common.input_to_cuml_array( data.indices, check_dtype=convert_index, convert_to_dtype=convert_index) - self.data, _, _, _ = input_to_cuml_array( + self.data, _, _, _ = cuml.common.input_to_cuml_array( data.data, check_dtype=data.dtype, convert_to_dtype=convert_to_dtype) @@ -121,7 +120,6 @@ def __init__(self, data=None, self.dtype = self.data.dtype self.nnz = data.nnz - @with_cupy_rmm def to_output(self, output_type='cupy', output_format=None, output_dtype=None): @@ -142,6 +140,10 @@ def to_output(self, output_type='cupy', Optionally cast the array to a specified dtype, creating a copy if necessary. """ + # Treat numpy and scipy as the same + if (output_type == "numpy"): + output_type = "scipy" + output_dtype = self.data.dtype \ if output_dtype is None else output_dtype diff --git a/python/cuml/common/base.pyx b/python/cuml/common/base.pyx index 23465b5402..8c310c1108 100644 --- a/python/cuml/common/base.pyx +++ b/python/cuml/common/base.pyx @@ -16,27 +16,18 @@ # distutils: language = c++ -import cuml -import cuml.common.cuda -import cuml.raft.common.handle -import cuml.common.logger as logger -from cuml.common import input_to_cuml_array import inspect -from cudf.core import Series as cuSeries -from cudf.core import DataFrame as cuDataFrame -from cuml.common.array import CumlArray +import cuml.common +import cuml.common.cuda +import cuml.common.logger as logger +import cuml.internals +import cuml.raft.common.handle from cuml.common.doc_utils import generate_docstring -from cupy import ndarray as cupyArray -from numba.cuda import devicearray as numbaArray -from numpy import ndarray as numpyArray -from pandas import DataFrame as pdDataFrame -from pandas import Series as pdSeries +import cuml.common.input_utils -from numba import cuda - -class Base: +class Base(metaclass=cuml.internals.BaseMetaClass): """ Base class for all the ML algos. It handles some of the common operations across all algos. Every ML algo class exposed at cython level must inherit @@ -160,8 +151,9 @@ class Base: base.handle.sync() del base # optional! """ - - def __init__(self, handle=None, verbose=False, + def __init__(self, + handle=None, + verbose=False, output_type=None): """ Constructor. All children must call init method of this base class. @@ -182,8 +174,9 @@ class Base: self.output_type = _check_output_type_str( cuml.global_output_type if output_type is None else output_type) - - self._mirror_input = True if self.output_type == 'input' else False + self._input_type = None + self.target_dtype = None + self.n_features_in_ = None def __repr__(self): """ @@ -259,25 +252,14 @@ class Base: def __getattr__(self, attr): """ - Method gives access to the correct format of cuml Array attribute to - the users. Any variable that starts with `_` and is a cuml Array - will return as the cuml Array converted to the appropriate format. + Redirects to `solver_model` if the attribute exists. """ - real_name = '_' + attr - # using __dict__ due to a bug with scikit-learn hyperparam - # when doing hasattr. github issue #1736 - if real_name in self.__dict__.keys(): - if isinstance(self.__dict__[real_name], CumlArray): - return self.__dict__[real_name].to_output(self.output_type) - else: - return self.__dict__[real_name] + if attr == "solver_model": + return self.__dict__['solver_model'] + if "solver_model" in self.__dict__.keys(): + return getattr(self.solver_model, attr) else: - if attr == "solver_model": - return self.__dict__['solver_model'] - if "solver_model" in self.__dict__.keys(): - return getattr(self.solver_model, attr) - else: - raise AttributeError + raise AttributeError def _set_base_attributes(self, output_type=None, @@ -321,23 +303,32 @@ class Base: if n_features is not None: self._set_n_features_in(n_features) - def _set_output_type(self, input): - if self.output_type == 'input' or self._mirror_input: - self.output_type = _input_to_type(input) + def _set_output_type(self, inp): + self._input_type = cuml.common.input_utils.determine_array_type(inp) - def _get_output_type(self, input): + def _get_output_type(self, inp): """ Method to be called by predict/transform methods of inheriting classes. Returns the appropriate output type depending on the type of the input, class output type and global output type. """ - if self._mirror_input: - return _input_to_type(input) - else: - return self.output_type + + # Default to the global type + output_type = cuml.global_output_type + + # If its None, default to our type + if (output_type is None or output_type == "mirror"): + output_type = self.output_type + + # If we are input, get the type from the input + if output_type == 'input': + output_type = cuml.common.input_utils.determine_array_type(inp) + + return output_type def _set_target_dtype(self, target): - self.target_dtype = _input_target_to_dtype(target) + self.target_dtype = cuml.common.input_utils.determine_array_dtype( + target) def _get_target_dtype(self): """ @@ -363,10 +354,14 @@ class RegressorMixin: _estimator_type = "regressor" - @generate_docstring(return_values={'name': 'score', - 'type': 'float', - 'description': 'R^2 of self.predict(X) ' - 'wrt. y.'}) + @generate_docstring( + return_values={ + 'name': 'score', + 'type': 'float', + 'description': 'R^2 of self.predict(X) ' + 'wrt. y.' + }) + @cuml.internals.api_base_return_any_skipall def score(self, X, y, **kwargs): """ Scoring function for regression estimators @@ -390,21 +385,22 @@ class ClassifierMixin: _estimator_type = "classifier" - @generate_docstring(return_values={'name': 'score', - 'type': 'float', - 'description': 'Accuracy of \ - self.predict(X) wrt. y \ - (fraction where y == \ - pred_y)'}) + @generate_docstring( + return_values={ + 'name': + 'score', + 'type': + 'float', + 'description': ('Accuracy of self.predict(X) wrt. y ' + '(fraction where y == pred_y)') + }) + @cuml.internals.api_base_return_any_skipall def score(self, X, y, **kwargs): """ Scoring function for classifier estimators based on mean accuracy. """ from cuml.metrics.accuracy import accuracy_score - from cuml.common import input_to_dev_array - - y_m = input_to_dev_array(y)[0] if hasattr(self, 'handle'): handle = self.handle @@ -412,59 +408,34 @@ class ClassifierMixin: handle = None preds = self.predict(X, **kwargs) - return accuracy_score(y_m, preds, handle=handle) + return accuracy_score(y, preds, handle=handle) # Internal, non class owned helper functions +def _check_output_type_str(output_str): -_input_type_to_str = { - numpyArray: 'numpy', - cupyArray: 'cupy', - cuSeries: 'cudf', - cuDataFrame: 'cudf', - pdSeries: 'numpy', - pdDataFrame: 'numpy' -} - - -def _input_to_type(input): - # function to access _input_to_str, while still using the correct - # numba check for a numba device_array - if type(input) in _input_type_to_str.keys(): - return _input_type_to_str[type(input)] - elif numbaArray.is_cuda_ndarray(input): - return 'numba' - else: - return 'cupy' + if (output_str is None): + return "input" + assert output_str != "mirror", \ + ("Cannot pass output_type='mirror' in Base.__init__(). Did you forget " + "to pass `output_type=self.output_type` to a child estimator? " + "Currently `cuml.global_output_type==`{}`" + ).format(cuml.global_output_type) -def _check_output_type_str(output_str): if isinstance(output_str, str): output_type = output_str.lower() + # Check for valid output types + "input" if output_type in ['numpy', 'cupy', 'cudf', 'numba', 'input']: - return output_str - else: - raise ValueError(("output_type must be one of " - "'numpy', 'cupy', 'cudf', 'numba', or 'input'." - " Got: '{}'" - ).format(output_str)) - else: - raise ValueError(("output_type must be a string" - " Got: '{}'" - ).format(type(output_str))) - - -def _input_target_to_dtype(target): - canonical_input_types = tuple(_input_type_to_str.keys()) + # Return the original version if nothing has changed, otherwise + # return the lowered. This is to try and keep references the same + # to support sklearn.base.clone() where possible + return output_str if output_type == output_str else output_type - if isinstance(target, (cuDataFrame, pdDataFrame)): - # Assume single-label target - dtype = target[target.columns[0]].dtype - elif isinstance(target, canonical_input_types): - dtype = target.dtype - else: - dtype = None - return dtype + # Did not match any acceptable value + raise ValueError("output_type must be one of " + + "'numpy', 'cupy', 'cudf' or 'numba'" + + "Got: {}".format(output_str)) def _determine_stateless_output_type(output_type, input_obj): @@ -482,6 +453,6 @@ def _determine_stateless_output_type(output_type, input_obj): # If we are using 'input', determine the the type from the input object if temp_output == 'input': - temp_output = _input_to_type(input_obj) + temp_output = cuml.common.input_utils.determine_array_type(input_obj) return temp_output diff --git a/python/cuml/common/input_utils.py b/python/cuml/common/input_utils.py index 9a84571ae9..adfea879ad 100644 --- a/python/cuml/common/input_utils.py +++ b/python/cuml/common/input_utils.py @@ -15,19 +15,25 @@ # import copy +from collections import namedtuple + import cudf import cupy as cp import cupyx +import numba.cuda import numpy as np import pandas as pd - -from collections import namedtuple -from cuml.common import CumlArray +import cuml.internals +import cuml.common.array +from cuml.common.array import CumlArray +from cuml.common.array_sparse import SparseCumlArray +from cuml.common.import_utils import has_scipy from cuml.common.logger import debug -from cuml.common.memory_utils import with_cupy_rmm +from cuml.common.memory_utils import ArrayInfo from cuml.common.memory_utils import _check_array_contiguity -from numba import cuda +if has_scipy(): + import scipy.sparse cuml_array = namedtuple('cuml_array', 'array n_rows n_cols dtype') @@ -35,25 +41,34 @@ # in all algos. Github issue #1716 inp_array = namedtuple('inp_array', 'array pointer n_rows n_cols dtype') +unsupported_cudf_dtypes = [ + np.uint8, np.uint16, np.uint32, np.uint64, np.float16 +] -def get_dev_array_ptr(ary): - """ - Returns ctype pointer of a numba style device array +_input_type_to_str = { + CumlArray: "cuml", + SparseCumlArray: "cuml", + np.ndarray: "numpy", + cp.ndarray: "cupy", + cudf.Series: "cudf", + cudf.DataFrame: "cudf", + pd.Series: "numpy", + pd.DataFrame: "numpy", + numba.cuda.devicearray.DeviceNDArrayBase: "numba", + cupyx.scipy.sparse.spmatrix: "cupy", +} - Deprecated: will be removed once all codebase uses cuml Array - See Github issue #1716 - """ - return ary.device_ctypes_pointer.value +_sparse_types = [ + SparseCumlArray, + cupyx.scipy.sparse.spmatrix, +] +if has_scipy(): + _input_type_to_str.update({ + scipy.sparse.spmatrix: "numpy", + }) -def get_cudf_column_ptr(col): - """ - Returns pointer of a cudf Series - - Deprecated: will be removed once all codebase uses cuml Array - See Github issue #1716 - """ - return col.__cuda_array_interface__['data'][0] + _sparse_types.append(scipy.sparse.spmatrix) def get_supported_input_type(X): @@ -80,6 +95,13 @@ def get_supported_input_type(X): If the array-like object is supported, the type is returned. Otherwise, `None` is returned. """ + # Check CumlArray first to shorten search time + if isinstance(X, CumlArray): + return CumlArray + + if isinstance(X, SparseCumlArray): + return SparseCumlArray + if (isinstance(X, cudf.Series)): if X.null_count != 0: return None @@ -96,24 +118,98 @@ def get_supported_input_type(X): if isinstance(X, cudf.DataFrame): return cudf.DataFrame - if isinstance(X, CumlArray): - return CumlArray + if numba.cuda.devicearray.is_cuda_ndarray(X): + return numba.cuda.devicearray.DeviceNDArrayBase if hasattr(X, "__cuda_array_interface__"): return cp.ndarray if hasattr(X, "__array_interface__"): - return np.ndarray + # For some reason, numpy scalar types also implement + # `__array_interface__`. See numpy.generic.__doc__. Exclude those types + # as well as np.dtypes + if (not isinstance(X, np.generic) and not isinstance(X, type)): + return np.ndarray + + if cupyx.scipy.sparse.isspmatrix(X): + return cupyx.scipy.sparse.spmatrix + + if has_scipy(): + if (scipy.sparse.isspmatrix(X)): + return scipy.sparse.spmatrix # Return None if this type isnt supported return None -@with_cupy_rmm -def input_to_cuml_array(X, order='F', deepcopy=False, - check_dtype=False, convert_to_dtype=False, - check_cols=False, check_rows=False, - fail_on_order=False, force_contiguous=True): +def determine_array_type(X): + if (X is None): + return None + + # Get the generic type + gen_type = get_supported_input_type(X) + + return None if gen_type is None else _input_type_to_str[gen_type] + + +def determine_array_dtype(X): + + if (X is None): + return None + + canonical_input_types = tuple(_input_type_to_str.keys()) + + if isinstance(X, (cudf.DataFrame, pd.DataFrame)): + # Assume single-label target + dtype = X[X.columns[0]].dtype + elif isinstance(X, canonical_input_types): + dtype = X.dtype + else: + dtype = None + + return dtype + + +def determine_array_type_full(X): + """ + Returns a tuple of the array type, and a boolean if it is sparse + + Parameters + ---------- + X : array-like + Input array to test + + Returns + ------- + (string, bool) Returns a tuple of the array type string and a boolean if it + is a sparse array. + """ + if (X is None): + return None, None + + # Get the generic type + gen_type = get_supported_input_type(X) + + if (gen_type is None): + return None, None + + return _input_type_to_str[gen_type], gen_type in _sparse_types + + +def is_array_like(X): + return determine_array_type(X) is not None + + +@cuml.internals.api_return_any() +def input_to_cuml_array(X, + order='F', + deepcopy=False, + check_dtype=False, + convert_to_dtype=False, + check_cols=False, + check_rows=False, + fail_on_order=False, + force_contiguous=True): """ Convert input X to CumlArray. @@ -177,6 +273,17 @@ def input_to_cuml_array(X, order='F', deepcopy=False, A new CumlArray and associated data. """ + def check_order(arr_order): + if order != 'K' and arr_order != order: + if fail_on_order: + raise ValueError("Expected " + order_to_str(order) + + " major order, but got the opposite.") + else: + debug("Expected " + order_to_str(order) + " major order, " + "but got the opposite. Converting data, this will " + "result in additional memory utilization.") + return True + return False # dtype conversion @@ -193,8 +300,8 @@ def input_to_cuml_array(X, order='F', deepcopy=False, if (isinstance(X, cudf.Series)): if X.null_count != 0: - raise ValueError("Error: cuDF Series has missing/null values, \ - which are not supported by cuML.") + raise ValueError("Error: cuDF Series has missing/null values, " + "which are not supported by cuML.") # converting pandas to numpy before sending it to CumlArray if isinstance(X, pd.DataFrame) or isinstance(X, pd.Series): @@ -213,13 +320,27 @@ def input_to_cuml_array(X, order='F', deepcopy=False, elif hasattr(X, "__array_interface__") or \ hasattr(X, "__cuda_array_interface__"): + # Since we create the array with the correct order here, do the order + # check now if necessary + interface = getattr(X, "__array_interface__", None) or getattr( + X, "__cuda_array_interface__", None) + + arr_info = ArrayInfo.from_interface(interface) + + check_order(arr_info.order) + + make_copy = False + if force_contiguous or hasattr(X, "__array_interface__"): if not _check_array_contiguity(X): - debug("Non contiguous array or view detected, a \ - contiguous copy of the data will be done. ") - X = cp.array(X, order=order, copy=True) + debug("Non contiguous array or view detected, a " + "contiguous copy of the data will be done.") + # X = cp.array(X, order=order, copy=True) + make_copy = True + + cp_arr = cp.array(X, copy=make_copy, order=order) - X_m = CumlArray(data=X) + X_m = CumlArray(data=cp_arr) if deepcopy: X_m = copy.deepcopy(X_m) @@ -255,32 +376,53 @@ def input_to_cuml_array(X, order='F', deepcopy=False, if check_cols: if n_cols != check_cols: raise ValueError("Expected " + str(check_cols) + - " columns but got " + str(n_cols) + - " columns.") + " columns but got " + str(n_cols) + " columns.") if check_rows: if n_rows != check_rows: - raise ValueError("Expected " + str(check_rows) + - " rows but got " + str(n_rows) + - " rows.") - - if order != 'K' and X_m.order != order: - if fail_on_order: - raise ValueError("Expected " + order_to_str(order) + - " major order, but got the opposite.") - else: - debug("Expected " + order_to_str(order) + " major order, " - "but got the opposite. Converting data, this will " - "result in additional memory utilization.") - X_m = cp.array(X_m, copy=False, order=order) - X_m = CumlArray(data=X_m) + raise ValueError("Expected " + str(check_rows) + " rows but got " + + str(n_rows) + " rows.") + + if (check_order(X_m.order)): + X_m = cp.array(X_m, copy=False, order=order) + X_m = CumlArray(data=X_m) return cuml_array(array=X_m, n_rows=n_rows, n_cols=n_cols, dtype=X_m.dtype) -def input_to_host_array(X, order='F', deepcopy=False, - check_dtype=False, convert_to_dtype=False, - check_cols=False, check_rows=False, +def input_to_cupy_array(X, + order='F', + deepcopy=False, + check_dtype=False, + convert_to_dtype=False, + check_cols=False, + check_rows=False, + fail_on_order=False, + force_contiguous=True) -> cuml_array: + """ + Identical to input_to_cuml_array but it returns a cupy array instead of + CumlArray + """ + out_data = input_to_cuml_array(X, + order=order, + deepcopy=deepcopy, + check_dtype=check_dtype, + convert_to_dtype=convert_to_dtype, + check_cols=check_cols, + check_rows=check_rows, + fail_on_order=fail_on_order, + force_contiguous=force_contiguous) + + return out_data._replace(array=out_data.array.to_output("cupy")) + + +def input_to_host_array(X, + order='F', + deepcopy=False, + check_dtype=False, + convert_to_dtype=False, + check_cols=False, + check_rows=False, fail_on_order=False): """ Convert input X to host array (NumPy) suitable for C++ methods that accept @@ -371,87 +513,7 @@ def input_to_host_array(X, order='F', deepcopy=False, dtype=ary_tuple.dtype) -def input_to_dev_array(X, order='F', deepcopy=False, - check_dtype=False, convert_to_dtype=False, - check_cols=False, check_rows=False, - fail_on_order=False): - """ - *** Deprecated, used in classes that have not migrated to use cuML Array - yet. Please use input_to_cuml_array instead for cuml Array. - See Github issue #1716 *** - - Convert input X to device array suitable for C++ methods. - - Acceptable input formats: - - * cuDF Dataframe - returns a deep copy always. - * cuDF Series - returns by reference or a deep copy depending on - `deepcopy`. - * Numpy array - returns a copy in device always - * cuda array interface compliant array (like Cupy) - returns a - reference unless `deepcopy`=True. - * numba device array - returns a reference unless deepcopy=True - - Parameters - ---------- - - X : cuDF.DataFrame, cuDF.Series, numba array, NumPy array or any - cuda_array_interface compliant array like CuPy or pytorch. - - order: string (default: 'F') - Whether to return a F-major or C-major array. Used to check the order - of the input. If fail_on_order=True method will raise ValueError, - otherwise it will convert X to be of order `order`. - - deepcopy: boolean (default: False) - Set to True to always return a deep copy of X. - - check_dtype: np.dtype (default: False) - Set to a np.dtype to throw an error if X is not of dtype `check_dtype`. - - convert_to_dtype: np.dtype (default: False) - Set to a dtype if you want X to be converted to that dtype if it is - not that dtype already. - - check_cols: int (default: False) - Set to an int `i` to check that input X has `i` columns. Set to False - (default) to not check at all. - - check_rows: boolean (default: False) - Set to an int `i` to check that input X has `i` columns. Set to False - (default) to not check at all. - - fail_on_order: boolean (default: False) - Set to True if you want the method to raise a ValueError if X is not - of order `order`. - - Returns - ------- - `inp_array`: namedtuple('inp_array', 'array pointer n_rows n_cols dtype') - - A new device array if the input was not a numba device - array. It is a reference to the input X if it was a numba device array - or cuda array interface compliant (like cupy) - - """ - - ary_tuple = input_to_cuml_array(X, - order=order, - deepcopy=deepcopy, - check_dtype=check_dtype, - convert_to_dtype=convert_to_dtype, - check_cols=check_cols, - check_rows=check_rows, - fail_on_order=fail_on_order) - - return inp_array(array=cuda.as_cuda_array(ary_tuple.array), - pointer=ary_tuple.array.ptr, - n_rows=ary_tuple.n_rows, - n_cols=ary_tuple.n_cols, - dtype=ary_tuple.dtype) - - -@with_cupy_rmm +@cuml.internals.api_return_any() def convert_dtype(X, to_dtype=np.float32, legacy=True): """ Convert X to be of dtype `dtype`, raising a TypeError @@ -470,12 +532,12 @@ def convert_dtype(X, to_dtype=np.float32, legacy=True): elif isinstance(X, (cudf.Series, cudf.DataFrame, pd.Series, pd.DataFrame)): return X.astype(to_dtype, copy=False) - elif cuda.is_cuda_array(X): + elif numba.cuda.is_cuda_array(X): X_m = cp.asarray(X) X_m = X_m.astype(to_dtype, copy=False) if legacy: - return cuda.as_cuda_array(X_m) + return numba.cuda.as_cuda_array(X_m) else: return CumlArray(data=X_m) @@ -501,16 +563,14 @@ def _typecast_will_lose_information(X, target_dtype): if X.dtype.type == target_dtype: return False - return ( - (X < target_dtype_range.min) | - (X > target_dtype_range.max) - ).any() + return ((X < target_dtype_range.min) | + (X > target_dtype_range.max)).any() elif isinstance(X, (pd.DataFrame, cudf.DataFrame)): X_m = X.values return _typecast_will_lose_information(X_m, target_dtype) - elif cuda.is_cuda_array(X): + elif numba.cuda.is_cuda_array(X): X_m = cp.asarray(X) return _typecast_will_lose_information(X_m, target_dtype) diff --git a/python/cuml/common/memory_utils.py b/python/cuml/common/memory_utils.py index dffb591d7c..4c391606af 100644 --- a/python/cuml/common/memory_utils.py +++ b/python/cuml/common/memory_utils.py @@ -15,15 +15,17 @@ # import contextlib +import functools +import operator +import re +from dataclasses import dataclass +from functools import wraps + import cuml import cupy as cp -import functools import numpy as np -import operator import rmm - from cuml.common.import_utils import check_min_cupy_version -from functools import wraps from numba import cuda as nbcuda try: @@ -35,6 +37,37 @@ pass +@dataclass(frozen=True) +class ArrayInfo: + """ + Calculate the necessary shape, order, stride and dtype of an array from an + ``__array_interface__`` or ``__cuda_array_interface__`` + """ + shape: tuple + order: str + dtype: np.dtype + strides: tuple + + @staticmethod + def from_interface(interface: dict) -> "ArrayInfo": + out_shape = interface['shape'] + out_type = np.dtype(interface['typestr']) + out_order = "C" + out_strides = None + + if interface.get('strides', None) is None: + out_order = 'C' + out_strides = _order_to_strides(out_order, out_shape, out_type) + else: + out_strides = interface['strides'] + out_order = _strides_to_order(out_strides, out_type) + + return ArrayInfo(shape=out_shape, + order=out_order, + dtype=out_type, + strides=out_strides) + + def with_cupy_rmm(func): """ @@ -50,14 +83,79 @@ def fx(...): a = cp.arange(10) # uses RMM for allocation """ + + if (func.__dict__.get("__cuml_rmm_wrapped", False)): + return func + @wraps(func) def cupy_rmm_wrapper(*args, **kwargs): with cupy_using_allocator(rmm.rmm_cupy_allocator): return func(*args, **kwargs) + # Mark the function as already wrapped + cupy_rmm_wrapper.__dict__["__cuml_rmm_wrapped"] = True + return cupy_rmm_wrapper +def class_with_cupy_rmm(skip_init=False, + skip_private=True, + skip_dunder=True, + ignore_pattern: list = []): + + regex_list = ignore_pattern + + if (skip_private): + # Match private but not dunder + regex_list.append(r"^_(?!(_))\w+$") + + if (skip_dunder): + if (not skip_init): + # Make sure to not match __init__ + regex_list.append(r"^__(?!(init))\w+__$") + else: + # Match all dunder + regex_list.append(r"^__\w+__$") + elif (skip_init): + regex_list.append(r"^__init__$") + + final_regex = '(?:%s)' % '|'.join(regex_list) + + def inner(klass): + + for attributeName, attribute in klass.__dict__.items(): + + # Skip patters that dont match + if (re.match(final_regex, attributeName)): + continue + + if callable(attribute): + + # Passed the ignore patters. Wrap the function (will do nothing + # if already wrapped) + setattr(klass, attributeName, with_cupy_rmm(attribute)) + + # Class/Static methods work differently since they are descriptors + # (and not callable). Instead unwrap the function, and rewrap it + elif (isinstance(attribute, classmethod)): + unwrapped = attribute.__func__ + + setattr(klass, + attributeName, + classmethod(with_cupy_rmm(unwrapped))) + + elif (isinstance(attribute, staticmethod)): + unwrapped = attribute.__func__ + + setattr(klass, + attributeName, + staticmethod(with_cupy_rmm(unwrapped))) + + return klass + + return inner + + def rmm_cupy_ary(cupy_fn, *args, **kwargs): """ @@ -140,13 +238,13 @@ def _strides_to_order(strides, dtype): def _order_to_strides(order, shape, dtype): itemsize = cp.dtype(dtype).itemsize if isinstance(shape, int): - return (itemsize,) + return (itemsize, ) elif len(shape) == 0: return None elif len(shape) == 1: - return (itemsize,) + return (itemsize, ) elif order == 'C': dim_minor = shape[1] * itemsize @@ -172,7 +270,7 @@ def _get_size_from_shape(shape, dtype): itemsize = cp.dtype(dtype).itemsize if isinstance(shape, int): size = itemsize * shape - shape = (shape,) + shape = (shape, ) elif isinstance(shape, tuple): size = functools.reduce(operator.mul, shape) size = size * itemsize @@ -310,17 +408,18 @@ def set_global_output_type(output_type): CPU memory. """ - if isinstance(output_type, str): + if (isinstance(output_type, str)): output_type = output_type.lower() - if output_type in ['numpy', 'cupy', 'cudf', 'numba', 'input']: - cuml.global_output_type = output_type - else: - raise ValueError('Parameter output_type must be one of ' + - '"series", "dataframe", cupy", "numpy", ' + - '"numba" or "input') - else: - raise ValueError('Parameter output_type must be one of "series" ' + - '"dataframe", cupy", "numpy", "numba" or "input') + + # Check for allowed types. Allow 'cuml' to support internal estimators + if output_type not in [ + 'numpy', 'cupy', 'cudf', 'numba', 'cuml', "input", None + ]: + # Omit 'cuml' from the error message. Should only be used internally + raise ValueError('Parameter output_type must be one of "numpy", ' + '"cupy", cudf", "numba", "input" or None') + + cuml.global_output_type = output_type @contextlib.contextmanager @@ -405,21 +504,12 @@ def using_output_type(output_type): """ - if isinstance(output_type, str): - output_type = output_type.lower() - if output_type in ['numpy', 'cupy', 'cudf', 'numba', 'input']: - prev_output_type = cuml.global_output_type - try: - cuml.global_output_type = output_type - yield - finally: - cuml.global_output_type = prev_output_type - else: - raise ValueError('Parameter output_type must be one of "series" ' + - '"dataframe", cupy", "numpy", "numba" or "input') - else: - raise ValueError('Parameter output_type must be one of "series" ' + - '"dataframe", cupy", "numpy", "numba" or "input') + prev_output_type = cuml.global_output_type + try: + set_global_output_type(output_type) + yield prev_output_type + finally: + cuml.global_output_type = prev_output_type @with_cupy_rmm diff --git a/python/cuml/common/sparsefuncs.py b/python/cuml/common/sparsefuncs.py index 3f3fa770f6..8b0e0b082c 100644 --- a/python/cuml/common/sparsefuncs.py +++ b/python/cuml/common/sparsefuncs.py @@ -16,7 +16,7 @@ import math import cupy as cp import cupyx -from cuml.common import with_cupy_rmm +import cuml.internals from cuml.common.kernel_utils import cuda_kernel_factory @@ -74,7 +74,7 @@ def _map_l2_norm_kernel(dtype): return cuda_kernel_factory(map_kernel_str, dtype, "map_l2_norm_kernel") -@with_cupy_rmm +@cuml.internals.api_return_any() def csr_row_normalize_l1(X, inplace=True): """Row normalize for csr matrix using the l1 norm""" if not inplace: @@ -87,7 +87,7 @@ def csr_row_normalize_l1(X, inplace=True): return X -@with_cupy_rmm +@cuml.internals.api_return_any() def csr_row_normalize_l2(X, inplace=True): """Row normalize for csr matrix using the l2 norm""" if not inplace: @@ -100,7 +100,7 @@ def csr_row_normalize_l2(X, inplace=True): return X -@with_cupy_rmm +@cuml.internals.api_return_any() def csr_diag_mul(X, y, inplace=True): """Multiply a sparse X matrix with diagonal matrix y""" if not inplace: @@ -111,6 +111,7 @@ def csr_diag_mul(X, y, inplace=True): return X +@cuml.internals.api_return_any() def create_csr_matrix_from_count_df(count_df, empty_doc_ids, n_doc, n_features, dtype=cp.float32): """ diff --git a/python/cuml/datasets/arima.pyx b/python/cuml/datasets/arima.pyx index d1e19730fa..e6a625131c 100644 --- a/python/cuml/datasets/arima.pyx +++ b/python/cuml/datasets/arima.pyx @@ -16,10 +16,12 @@ # distutils: language = c++ -import cuml +import warnings + import numpy as np from cuml.common.array import CumlArray as cumlArray +import cuml.internals from cuml.raft.common.handle cimport handle_t from cuml.raft.common.handle import Handle from cuml.tsa.arima cimport ARIMAOrder @@ -64,6 +66,7 @@ inp_to_dtype = { } +@cuml.internals.api_return_array() def make_arima(batch_size=1000, n_obs=100, order=(1, 1, 1), seasonal_order=(0, 0, 0, 0), intercept=False, random_state=None, dtype='double', output_type='cupy', @@ -97,6 +100,13 @@ def make_arima(batch_size=1000, n_obs=100, order=(1, 1, 1), or 'double' output_type: {'cudf', 'cupy', 'numpy'} Type of the returned dataset + + .. deprecated:: 0.17 + `output_type` is deprecated in 0.17 and will be removed in 0.18. + Please use the module level output type control, + `cuml.global_output_type`. + See :ref:`output-data-type-configuration` for more info. + handle: cuml.Handle If it is None, a new one is created just for this function call @@ -106,6 +116,15 @@ def make_arima(batch_size=1000, n_obs=100, order=(1, 1, 1), Array of the requested type containing the generated dataset """ + # Check for deprecated `output_type` and warn. Set manually if specified + if (output_type is not None): + warnings.warn("Using the `output_type` argument is deprecated and " + "will be removed in 0.18. Please specify the output " + "type using `cuml.using_output_type()` instead", + DeprecationWarning) + + cuml.internals.set_api_output_type(output_type) + cdef ARIMAOrder cpp_order cpp_order.p, cpp_order.d, cpp_order.q = order cpp_order.P, cpp_order.D, cpp_order.Q, cpp_order.s = seasonal_order @@ -142,4 +161,4 @@ def make_arima(batch_size=1000, n_obs=100, order=(1, 1, 1), noise_scale, intercept_scale, random_state) - return out.to_output(output_type) + return out diff --git a/python/cuml/datasets/blobs.py b/python/cuml/datasets/blobs.py index 8367010739..44e28671d3 100644 --- a/python/cuml/datasets/blobs.py +++ b/python/cuml/datasets/blobs.py @@ -19,7 +19,7 @@ from collections.abc import Iterable import cupy as cp import numpy as np -from cuml.common import with_cupy_rmm +import cuml.internals from cuml.datasets.utils import _create_rs_generator @@ -65,7 +65,7 @@ def _get_centers(rs, centers, center_box, n_samples, n_features, dtype): return centers, n_centers -@with_cupy_rmm +@cuml.internals.api_return_any() def make_blobs(n_samples=100, n_features=2, centers=None, cluster_std=1.0, center_box=(-10.0, 10.0), shuffle=True, random_state=None, return_centers=False, order='F', dtype='float32'): diff --git a/python/cuml/datasets/classification.py b/python/cuml/datasets/classification.py index eacfbe096e..41315b4333 100644 --- a/python/cuml/datasets/classification.py +++ b/python/cuml/datasets/classification.py @@ -13,10 +13,9 @@ # limitations under the License. # - +import cuml.internals from cuml.common.import_utils import has_sklearn from cuml.datasets.utils import _create_rs_generator -from cuml.common import with_cupy_rmm import cupy as cp import numpy as np @@ -42,7 +41,7 @@ def _generate_hypercube(samples, dimensions, rng): return out -@with_cupy_rmm +@cuml.internals.api_return_any() def make_classification(n_samples=100, n_features=20, n_informative=2, n_redundant=2, n_repeated=0, n_classes=2, n_clusters_per_class=2, weights=None, flip_y=0.01, diff --git a/python/cuml/datasets/regression.pyx b/python/cuml/datasets/regression.pyx index 2cfdc9581e..6472825263 100644 --- a/python/cuml/datasets/regression.pyx +++ b/python/cuml/datasets/regression.pyx @@ -16,14 +16,15 @@ # distutils: language = c++ -import cuml +import typing + import numpy as np +import cuml.internals +from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t from cuml.raft.common.handle import Handle -from cuml.common import get_dev_array_ptr, zeros - from libcpp cimport bool from libc.stdint cimport uint64_t, uintptr_t @@ -71,13 +72,26 @@ inp_to_dtype = { } -def make_regression(n_samples=100, n_features=2, n_informative=2, n_targets=1, - bias=0.0, effective_rank=None, tail_strength=0.5, - noise=0.0, shuffle=True, coef=False, random_state=None, - dtype='single', handle=None): +@cuml.internals.api_return_generic() +def make_regression( + n_samples=100, + n_features=2, + n_informative=2, + n_targets=1, + bias=0.0, + effective_rank=None, + tail_strength=0.5, + noise=0.0, + shuffle=True, + coef=False, + random_state=None, + dtype='single', + handle=None +) -> typing.Union[typing.Tuple[CumlArray, CumlArray], + typing.Tuple[CumlArray, CumlArray, CumlArray]]: """Generate a random regression problem. - See https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html + See https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html # noqa: E501 Examples -------- @@ -148,7 +162,12 @@ def make_regression(n_samples=100, n_features=2, n_informative=2, n_targets=1, coef : device array of shape [n_features, n_targets], optional The coefficient of the underlying linear model. It is returned only if coef is True. - """ # noqa + """ + + # Set the default output type to "cupy". This will be ignored if the user + # has set `cuml.global_output_type`. Only necessary for array generation + # methods that do not take an array as input + cuml.internals.set_api_output_type("cupy") if dtype not in ['single', 'float', 'double', np.float32, np.float64]: raise TypeError("dtype must be either 'float' or 'double'") @@ -161,17 +180,19 @@ def make_regression(n_samples=100, n_features=2, n_informative=2, n_targets=1, handle = Handle() if handle is None else handle cdef handle_t* handle_ = handle.getHandle() - out = zeros((n_samples, n_features), dtype=dtype, order='C') - cdef uintptr_t out_ptr = get_dev_array_ptr(out) + out = CumlArray.zeros((n_samples, n_features), dtype=dtype, order='C') + cdef uintptr_t out_ptr = out.ptr - values = zeros((n_samples, n_targets), dtype=dtype, order='C') - cdef uintptr_t values_ptr = get_dev_array_ptr(values) + values = CumlArray.zeros((n_samples, n_targets), dtype=dtype, order='C') + cdef uintptr_t values_ptr = values.ptr cdef uintptr_t coef_ptr coef_ptr = NULL if coef: - coefs = zeros((n_features, n_targets), dtype=dtype, order='C') - coef_ptr = get_dev_array_ptr(coefs) + coefs = CumlArray.zeros((n_features, n_targets), + dtype=dtype, + order='C') + coef_ptr = coefs.ptr if random_state is None: random_state = randint(0, 1e18) diff --git a/python/cuml/decomposition/base_mg.pyx b/python/cuml/decomposition/base_mg.pyx index bdedec6eec..8615034c88 100644 --- a/python/cuml/decomposition/base_mg.pyx +++ b/python/cuml/decomposition/base_mg.pyx @@ -31,10 +31,10 @@ from cython.operator cimport dereference as deref from cuml.common.array import CumlArray import cuml.common.opg_data_utils_mg as opg +import cuml.internals from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t from cuml.decomposition.utils cimport * -from cuml.common import input_to_dev_array, zeros from cuml.common import input_to_cuml_array from cuml.common.opg_data_utils_mg cimport * @@ -44,6 +44,7 @@ class BaseDecompositionMG(object): def __init__(self, **kwargs): super(BaseDecompositionMG, self).__init__(**kwargs) + @cuml.internals.api_base_return_any_skipall def fit(self, X, total_rows, n_cols, partsToRanks, rank, _transform=False): """ diff --git a/python/cuml/decomposition/pca.pyx b/python/cuml/decomposition/pca.pyx index 0a7e64b718..0e0e40a720 100644 --- a/python/cuml/decomposition/pca.pyx +++ b/python/cuml/decomposition/pca.pyx @@ -27,23 +27,23 @@ from enum import IntEnum import rmm -import cuml - from libcpp cimport bool from libc.stdint cimport uintptr_t from cython.operator cimport dereference as deref +import cuml.internals from cuml.common.array import CumlArray from cuml.common.base import Base -from cuml.common.base import _input_to_type from cuml.common.doc_utils import generate_docstring from cuml.raft.common.handle cimport handle_t from cuml.raft.common.handle import Handle import cuml.common.logger as logger from cuml.decomposition.utils cimport * -from cuml.common import input_to_cuml_array -from cuml.common import with_cupy_rmm +from cuml.common.input_utils import input_to_cuml_array +from cuml.common.input_utils import input_to_cupy_array +from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common import using_output_type from cuml.prims.stats import cov from cuml.common.input_utils import sparse_scipy_to_cp @@ -109,6 +109,7 @@ class Solver(IntEnum): class PCA(Base): + """ PCA (Principal Component Analysis) is a fundamental dimensionality reduction technique used to combine features in X in linear combinations @@ -143,13 +144,13 @@ class PCA(Base): pca_float.fit(gdf_float) print(f'components: {pca_float.components_}') - print(f'explained variance: {pca_float._explained_variance_}') - exp_var = pca_float._explained_variance_ratio_ + print(f'explained variance: {pca_float.explained_variance_}') + exp_var = pca_float.explained_variance_ratio_ print(f'explained variance ratio: {exp_var}') - print(f'singular values: {pca_float._singular_values_}') - print(f'mean: {pca_float._mean_}') - print(f'noise variance: {pca_float._noise_variance_}') + print(f'singular values: {pca_float.singular_values_}') + print(f'mean: {pca_float.mean_}') + print(f'noise variance: {pca_float.noise_variance_}') trans_gdf_float = pca_float.transform(gdf_float) print(f'Inverse: {trans_gdf_float}') @@ -286,6 +287,14 @@ class PCA(Base): `_. """ + components_ = CumlArrayDescriptor() + explained_variance_ = CumlArrayDescriptor() + explained_variance_ratio_ = CumlArrayDescriptor() + singular_values_ = CumlArrayDescriptor() + mean_ = CumlArrayDescriptor() + noise_variance_ = CumlArrayDescriptor() + trans_input_ = CumlArrayDescriptor() + def __init__(self, copy=True, handle=None, iterated_power=15, n_components=None, random_state=None, svd_solver='auto', tol=1e-7, verbose=False, whiten=False, @@ -303,19 +312,13 @@ class PCA(Base): self.c_algorithm = self._get_algorithm_c_name(self.svd_solver) # internal array attributes - self._components_ = None # accessed via estimator.components_ - self._trans_input_ = None # accessed via estimator.trans_input_ - self._explained_variance_ = None - # accessed via estimator.explained_variance_ - - self._explained_variance_ratio_ = None - # accessed via estimator.explained_variance_ratio_ - - self._singular_values_ = None - # accessed via estimator.singular_values_ - - self._mean_ = None # accessed via estimator.mean_ - self._noise_variance_ = None # accessed via estimator.noise_variance_ + self.components_ = None + self.trans_input_ = None + self.explained_variance_ = None + self.explained_variance_ratio_ = None + self.singular_values_ = None + self.mean_ = None + self.noise_variance_ = None # This variable controls whether a sparse model was fit # This can be removed once there is more inter-operability @@ -359,19 +362,18 @@ class PCA(Base): def _initialize_arrays(self, n_components, n_rows, n_cols): - self._components_ = CumlArray.zeros((n_components, n_cols), - dtype=self.dtype) - self._explained_variance_ = CumlArray.zeros(n_components, - dtype=self.dtype) - self._explained_variance_ratio_ = CumlArray.zeros(n_components, - dtype=self.dtype) - self._mean_ = CumlArray.zeros(n_cols, dtype=self.dtype) + self.components_ = CumlArray.zeros((n_components, n_cols), + dtype=self.dtype) + self.explained_variance_ = CumlArray.zeros(n_components, + dtype=self.dtype) + self.explained_variance_ratio_ = CumlArray.zeros(n_components, + dtype=self.dtype) + self.mean_ = CumlArray.zeros(n_cols, dtype=self.dtype) - self._singular_values_ = CumlArray.zeros(n_components, - dtype=self.dtype) - self._noise_variance_ = CumlArray.zeros(1, dtype=self.dtype) + self.singular_values_ = CumlArray.zeros(n_components, + dtype=self.dtype) + self.noise_variance_ = CumlArray.zeros(1, dtype=self.dtype) - @with_cupy_rmm def _sparse_fit(self, X): self._sparse_model = True @@ -383,66 +385,51 @@ class PCA(Base): # NOTE: All intermediate calculations are done using cupy.ndarray and # then converted to CumlArray at the end to minimize conversions # between types - covariance, temp_mean_, _ = cov(X, X, return_mean=True) + covariance, self.mean_, _ = cov(X, X, return_mean=True) - temp_explained_variance_, temp_components_ = \ + self.explained_variance_, self.components_ = \ cp.linalg.eigh(covariance, UPLO='U') # NOTE: We reverse the eigen vector and eigen values here # because cupy provides them in ascending order. Make a copy otherwise # it is not C_CONTIGUOUS anymore and would error when converting to # CumlArray - temp_explained_variance_ = temp_explained_variance_[::-1].copy() + self.explained_variance_ = self.explained_variance_[::-1] - temp_components_ = cp.flip(temp_components_, axis=1) + self.components_ = cp.flip(self.components_, axis=1) - temp_components_ = temp_components_.T[:self.n_components, :] + self.components_ = self.components_.T[:self.n_components, :] - temp_explained_variance_ratio_ = temp_explained_variance_ / cp.sum( - temp_explained_variance_) + self.explained_variance_ratio_ = self.explained_variance_ / cp.sum( + self.explained_variance_) if self.n_components < min(self.n_rows, self.n_cols): - temp_noise_variance_ = \ - temp_explained_variance_[self.n_components:].mean() + self.noise_variance_ = \ + self.explained_variance_[self.n_components:].mean() else: - temp_noise_variance_ = cp.array([0.0]) + self.noise_variance_ = cp.array([0.0]) - temp_explained_variance_ = \ - temp_explained_variance_[:self.n_components] + self.explained_variance_ = \ + self.explained_variance_[:self.n_components] - temp_explained_variance_ratio_ = \ - temp_explained_variance_ratio_[:self.n_components] + self.explained_variance_ratio_ = \ + self.explained_variance_ratio_[:self.n_components] # Truncating negative explained variance values to 0 - temp_singular_values_ = \ - cp.where(temp_explained_variance_ < 0, 0, - temp_explained_variance_) - temp_singular_values_ = \ - cp.sqrt(temp_singular_values_ * (self.n_rows - 1)) - - # Since temp_components_ can have a negative stride, copy it to get a - # new contiguous array - temp_components_ = temp_components_.copy() - - # Finally, store everything as CumlArray to support `to_output` - self._mean_ = CumlArray(temp_mean_) - self._explained_variance_ = CumlArray(temp_explained_variance_) - self._components_ = CumlArray(temp_components_) - self._noise_variance_ = CumlArray(temp_noise_variance_) - self._explained_variance_ratio_ = \ - CumlArray(temp_explained_variance_ratio_) - self._singular_values_ = CumlArray(temp_singular_values_) + self.singular_values_ = \ + cp.where(self.explained_variance_ < 0, 0, + self.explained_variance_) + self.singular_values_ = \ + cp.sqrt(self.singular_values_ * (self.n_rows - 1)) return self @generate_docstring(X='dense_sparse') - def fit(self, X, y=None): + def fit(self, X, y=None) -> "PCA": """ Fit the model with X. y is currently ignored. """ - self._set_base_attributes(output_type=X, n_features=X) - if cupyx.scipy.sparse.issparse(X): return self._sparse_fit(X) elif scipy.sparse.issparse(X): @@ -460,24 +447,25 @@ class PCA(Base): raise ValueError('Number of components should not be greater than' 'the number of columns in the data') + # Calling _initialize_arrays, guarantees everything is CumlArray self._initialize_arrays(params.n_components, params.n_rows, params.n_cols) - cdef uintptr_t comp_ptr = self._components_.ptr + cdef uintptr_t comp_ptr = self.components_.ptr cdef uintptr_t explained_var_ptr = \ - self._explained_variance_.ptr + self.explained_variance_.ptr cdef uintptr_t explained_var_ratio_ptr = \ - self._explained_variance_ratio_.ptr + self.explained_variance_ratio_.ptr cdef uintptr_t singular_vals_ptr = \ - self._singular_values_.ptr + self.singular_values_.ptr - cdef uintptr_t _mean_ptr = self._mean_.ptr + cdef uintptr_t _mean_ptr = self.mean_.ptr cdef uintptr_t noise_vars_ptr = \ - self._noise_variance_.ptr + self.noise_variance_.ptr cdef handle_t* handle_ = self.handle.getHandle() if self.dtype == np.float32: @@ -512,7 +500,8 @@ class PCA(Base): 'type': 'dense_sparse', 'description': 'Transformed values', 'shape': '(n_samples, n_components)'}) - def fit_transform(self, X, y=None): + @cuml.internals.api_base_return_array_skipall + def fit_transform(self, X, y=None) -> CumlArray: """ Fit the model with X and apply the dimensionality reduction on X. @@ -520,28 +509,26 @@ class PCA(Base): return self.fit(X).transform(X) - @with_cupy_rmm + @cuml.internals.api_base_return_array_skipall def _sparse_inverse_transform(self, X, return_sparse=False, - sparse_tol=1e-10, out_type=None): + sparse_tol=1e-10) -> CumlArray: # NOTE: All intermediate calculations are done using cupy.ndarray and # then converted to CumlArray at the end to minimize conversions # between types - temp_components_ = cp.asarray(self._components_) - temp_mean_ = self._mean_.to_output("cupy") if self.whiten: - temp_components_ *= (1 / cp.sqrt(self.n_rows - 1)) - temp_components_ *= self._singular_values_ + cp.multiply(self.components_, + (1 / cp.sqrt(self.n_rows - 1)), out=self.components_) + cp.multiply(self.components_, + self.singular_values_, out=self.components_) - X_inv = X.dot(temp_components_) - X_inv += temp_mean_ + X_inv = cp.dot(X, self.components_) + cp.add(X_inv, self.mean_, out=X_inv) if self.whiten: - temp_components_ /= self._singular_values_ - temp_components_ *= cp.sqrt(self.n_rows - 1) - - self._components_ = CumlArray(temp_components_) + self.components_ /= self.singular_values_ + self.components_ *= cp.sqrt(self.n_rows - 1) if return_sparse: X_inv = cp.where(X_inv < sparse_tol, 0, X_inv) @@ -550,20 +537,15 @@ class PCA(Base): return X_inv - if out_type == 'cupy': - return X_inv - else: - X_inv, _, _, _ = input_to_cuml_array(X_inv, order='K') - return X_inv.to_output(out_type) + return X_inv @generate_docstring(X='dense_sparse', return_values={'name': 'X_inv', 'type': 'dense_sparse', 'description': 'Transformed values', 'shape': '(n_samples, n_features)'}) - @with_cupy_rmm def inverse_transform(self, X, convert_dtype=False, - return_sparse=False, sparse_tol=1e-10): + return_sparse=False, sparse_tol=1e-10) -> CumlArray: """ Transform data back to its original space. @@ -571,28 +553,22 @@ class PCA(Base): """ - out_type = self._get_output_type(X) - if cupyx.scipy.sparse.issparse(X): return self._sparse_inverse_transform(X, return_sparse=return_sparse, - sparse_tol=sparse_tol, - out_type=out_type) + sparse_tol=sparse_tol) elif scipy.sparse.issparse(X): X = sparse_scipy_to_cp(X) return self._sparse_inverse_transform(X, return_sparse=return_sparse, - sparse_tol=sparse_tol, - out_type=out_type) + sparse_tol=sparse_tol) elif self._sparse_model: X, _, _, _ = \ - input_to_cuml_array(X, order='K', + input_to_cupy_array(X, order='K', check_dtype=[cp.float32, cp.float64]) - X = X.to_output(output_type='cupy') return self._sparse_inverse_transform(X, return_sparse=return_sparse, - sparse_tol=sparse_tol, - out_type=out_type) + sparse_tol=sparse_tol) X_m, n_rows, _, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, @@ -613,9 +589,9 @@ class PCA(Base): dtype=dtype.type) cdef uintptr_t input_ptr = input_data.ptr - cdef uintptr_t components_ptr = self._components_.ptr - cdef uintptr_t singular_vals_ptr = self._singular_values_.ptr - cdef uintptr_t _mean_ptr = self._mean_.ptr + cdef uintptr_t components_ptr = self.components_.ptr + cdef uintptr_t singular_vals_ptr = self.singular_values_.ptr + cdef uintptr_t _mean_ptr = self.mean_.ptr cdef handle_t* h_ = self.handle.getHandle() if dtype.type == np.float32: @@ -639,44 +615,35 @@ class PCA(Base): # following transfers start self.handle.sync() - return input_data.to_output(out_type) + return input_data - @with_cupy_rmm - def _sparse_transform(self, X, out_type=None): + @cuml.internals.api_base_return_array_skipall + def _sparse_transform(self, X) -> CumlArray: # NOTE: All intermediate calculations are done using cupy.ndarray and # then converted to CumlArray at the end to minimize conversions # between types - temp_components_ = self._components_.to_output("cupy") - temp_mean_ = self._mean_.to_output("cupy") - - if self.whiten: - temp_components_ *= cp.sqrt(self.n_rows - 1) - temp_components_ /= self._singular_values_ + with using_output_type("cupy"): - X = X - temp_mean_ - X_transformed = X.dot(temp_components_.T) + if self.whiten: + self.components_ *= cp.sqrt(self.n_rows - 1) + self.components_ /= self.singular_values_ - if self.whiten: - temp_components_ *= self._singular_values_ - temp_components_ *= (1 / cp.sqrt(self.n_rows - 1)) + X = X - self.mean_ + X_transformed = X.dot(self.components_.T) - self._components_ = CumlArray(temp_components_) + if self.whiten: + self.components_ *= self.singular_values_ + self.components_ *= (1 / cp.sqrt(self.n_rows - 1)) - if self._get_output_type(X) == 'cupy': - return X_transformed - else: - X_transformed, _, _, _ = \ - input_to_cuml_array(X_transformed, order='K') - return X_transformed.to_output(out_type) + return X_transformed @generate_docstring(X='dense_sparse', return_values={'name': 'trans', 'type': 'dense_sparse', 'description': 'Transformed values', 'shape': '(n_samples, n_components)'}) - @with_cupy_rmm - def transform(self, X, convert_dtype=False): + def transform(self, X, convert_dtype=False) -> CumlArray: """ Apply dimensionality reduction to X. @@ -685,19 +652,16 @@ class PCA(Base): """ - out_type = self._get_output_type(X) - if cupyx.scipy.sparse.issparse(X): - return self._sparse_transform(X, out_type=out_type) + return self._sparse_transform(X) elif scipy.sparse.issparse(X): X = sparse_scipy_to_cp(X) - return self._sparse_transform(X, out_type=out_type) + return self._sparse_transform(X) elif self._sparse_model: X, _, _, _ = \ - input_to_cuml_array(X, order='K', + input_to_cupy_array(X, order='K', check_dtype=[cp.float32, cp.float64]) - X = X.to_output(output_type='cupy') - return self._sparse_transform(X, out_type=out_type) + return self._sparse_transform(X) X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, @@ -719,10 +683,10 @@ class PCA(Base): dtype=dtype.type) cdef uintptr_t _trans_input_ptr = t_input_data.ptr - cdef uintptr_t components_ptr = self._components_.ptr + cdef uintptr_t components_ptr = self.components_.ptr cdef uintptr_t singular_vals_ptr = \ - self._singular_values_.ptr - cdef uintptr_t _mean_ptr = self._mean_.ptr + self.singular_values_.ptr + cdef uintptr_t _mean_ptr = self.mean_.ptr cdef handle_t* handle_ = self.handle.getHandle() if dtype.type == np.float32: @@ -746,7 +710,7 @@ class PCA(Base): # following transfers start self.handle.sync() - return t_input_data.to_output(out_type) + return t_input_data def get_param_names(self): return super().get_param_names() + \ diff --git a/python/cuml/decomposition/pca_mg.pyx b/python/cuml/decomposition/pca_mg.pyx index 94cf4ef2c4..175ffdd7e0 100644 --- a/python/cuml/decomposition/pca_mg.pyx +++ b/python/cuml/decomposition/pca_mg.pyx @@ -32,10 +32,10 @@ from cython.operator cimport dereference as deref from cuml.common.array import CumlArray import cuml.common.opg_data_utils_mg as opg +import cuml.internals from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t from cuml.decomposition.utils cimport paramsSolver -from cuml.common import input_to_dev_array, zeros from cuml.common.opg_data_utils_mg cimport * from cuml.decomposition import PCA @@ -127,15 +127,16 @@ class PCAMG(BaseDecompositionMG, PCA): return params + @cuml.internals.api_base_return_any_skipall def _call_fit(self, X, rank, part_desc, arg_params): - cdef uintptr_t comp_ptr = self._components_.ptr - cdef uintptr_t explained_var_ptr = self._explained_variance_.ptr + cdef uintptr_t comp_ptr = self.components_.ptr + cdef uintptr_t explained_var_ptr = self.explained_variance_.ptr cdef uintptr_t explained_var_ratio_ptr = \ - self._explained_variance_ratio_.ptr - cdef uintptr_t singular_vals_ptr = self._singular_values_.ptr - cdef uintptr_t mean_ptr = self._mean_.ptr - cdef uintptr_t noise_vars_ptr = self._noise_variance_.ptr + self.explained_variance_ratio_.ptr + cdef uintptr_t singular_vals_ptr = self.singular_values_.ptr + cdef uintptr_t mean_ptr = self.mean_.ptr + cdef uintptr_t noise_vars_ptr = self.noise_variance_.ptr cdef handle_t* handle_ = self.handle.getHandle() cdef paramsPCAMG *params = arg_params diff --git a/python/cuml/decomposition/tsvd.pyx b/python/cuml/decomposition/tsvd.pyx index cf5bcd0170..7881141172 100644 --- a/python/cuml/decomposition/tsvd.pyx +++ b/python/cuml/decomposition/tsvd.pyx @@ -26,7 +26,6 @@ import rmm from libcpp cimport bool from libc.stdint cimport uintptr_t -import cuml from cuml.common.array import CumlArray from cuml.common.base import Base @@ -34,6 +33,7 @@ from cuml.common.doc_utils import generate_docstring from cuml.raft.common.handle cimport handle_t from cuml.decomposition.utils cimport * from cuml.common import input_to_cuml_array +from cuml.common.array_descriptor import CumlArrayDescriptor from cython.operator cimport dereference as deref @@ -241,6 +241,11 @@ class TruncatedSVD(Base): """ + components_ = CumlArrayDescriptor() + explained_variance_ = CumlArrayDescriptor() + explained_variance_ratio_ = CumlArrayDescriptor() + singular_values_ = CumlArrayDescriptor() + def __init__(self, algorithm='full', handle=None, n_components=1, n_iter=15, random_state=None, tol=1e-7, verbose=False, output_type=None): @@ -255,15 +260,12 @@ class TruncatedSVD(Base): self.c_algorithm = self._get_algorithm_c_name(self.algorithm) # internal array attributes - self._components_ = None # accessed via estimator.components_ - self._explained_variance_ = None - # accessed via estimator.explained_variance_ + self.components_ = None + self.explained_variance_ = None - self._explained_variance_ratio_ = None - # accessed via estimator.explained_variance_ratio_ + self.explained_variance_ratio_ = None - self._singular_values_ = None - # accessed via estimator.singular_values_ + self.singular_values_ = None def _get_algorithm_c_name(self, algorithm): algo_map = { @@ -290,19 +292,17 @@ class TruncatedSVD(Base): def _initialize_arrays(self, n_components, n_rows, n_cols): - self._components_ = CumlArray.zeros((n_components, n_cols), - dtype=self.dtype) - self._explained_variance_ = CumlArray.zeros(n_components, - dtype=self.dtype) - self._explained_variance_ratio_ = CumlArray.zeros(n_components, - dtype=self.dtype) - self._mean_ = CumlArray.zeros(n_cols, dtype=self.dtype) - self._singular_values_ = CumlArray.zeros(n_components, - dtype=self.dtype) - self._noise_variance_ = CumlArray.zeros(1, dtype=self.dtype) + self.components_ = CumlArray.zeros((n_components, n_cols), + dtype=self.dtype) + self.explained_variance_ = CumlArray.zeros(n_components, + dtype=self.dtype) + self.explained_variance_ratio_ = CumlArray.zeros(n_components, + dtype=self.dtype) + self.singular_values_ = CumlArray.zeros(n_components, + dtype=self.dtype) @generate_docstring() - def fit(self, X, y=None): + def fit(self, X, y=None) -> "TruncatedSVD": """ Fit LSI model on training cudf DataFrame X. y is currently ignored. @@ -316,14 +316,12 @@ class TruncatedSVD(Base): 'type': 'dense', 'description': 'Reduced version of X', 'shape': '(n_samples, n_components)'}) - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None) -> CumlArray: """ Fit LSI model to X and perform dimensionality reduction on X. y is currently ignored. """ - self._set_base_attributes(output_type=X, n_features=X) - X_m, self.n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) cdef uintptr_t input_ptr = X_m.ptr @@ -333,16 +331,16 @@ class TruncatedSVD(Base): self._initialize_arrays(self.n_components, self.n_rows, self.n_cols) - cdef uintptr_t comp_ptr = self._components_.ptr + cdef uintptr_t comp_ptr = self.components_.ptr cdef uintptr_t explained_var_ptr = \ - self._explained_variance_.ptr + self.explained_variance_.ptr cdef uintptr_t explained_var_ratio_ptr = \ - self._explained_variance_ratio_.ptr + self.explained_variance_ratio_.ptr cdef uintptr_t singular_vals_ptr = \ - self._singular_values_.ptr + self.singular_values_.ptr _trans_input_ = CumlArray.zeros((params.n_rows, params.n_components), dtype=self.dtype) @@ -375,14 +373,13 @@ class TruncatedSVD(Base): # following transfers start self.handle.sync() - out_type = self._get_output_type(X) - return _trans_input_.to_output(out_type) + return _trans_input_ @generate_docstring(return_values={'name': 'X_original', 'type': 'dense', 'description': 'X in original space', 'shape': '(n_samples, n_features)'}) - def inverse_transform(self, X, convert_dtype=False): + def inverse_transform(self, X, convert_dtype=False) -> CumlArray: """ Transform X back to its original space. Returns X_original whose transform would be X. @@ -404,7 +401,7 @@ class TruncatedSVD(Base): cdef uintptr_t trans_input_ptr = trans_input.ptr cdef uintptr_t input_ptr = input_data.ptr - cdef uintptr_t components_ptr = self._components_.ptr + cdef uintptr_t components_ptr = self.components_.ptr cdef handle_t* handle_ = self.handle.getHandle() @@ -425,14 +422,13 @@ class TruncatedSVD(Base): # following transfers start self.handle.sync() - out_type = self._get_output_type(X) - return input_data.to_output(out_type) + return input_data @generate_docstring(return_values={'name': 'X_new', 'type': 'dense', 'description': 'Reduced version of X', 'shape': '(n_samples, n_components)'}) - def transform(self, X, convert_dtype=False): + def transform(self, X, convert_dtype=False) -> CumlArray: """ Perform dimensionality reduction on X. @@ -454,7 +450,7 @@ class TruncatedSVD(Base): cdef uintptr_t input_ptr = input.ptr cdef uintptr_t trans_input_ptr = t_input_data.ptr - cdef uintptr_t components_ptr = self._components_.ptr + cdef uintptr_t components_ptr = self.components_.ptr cdef handle_t* handle_ = self.handle.getHandle() @@ -475,8 +471,7 @@ class TruncatedSVD(Base): # following transfers start self.handle.sync() - out_type = self._get_output_type(X) - return t_input_data.to_output(out_type) + return t_input_data def get_param_names(self): return super().get_param_names() + \ diff --git a/python/cuml/decomposition/tsvd_mg.pyx b/python/cuml/decomposition/tsvd_mg.pyx index e50de8af66..adfdc174f5 100644 --- a/python/cuml/decomposition/tsvd_mg.pyx +++ b/python/cuml/decomposition/tsvd_mg.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t, uint32_t, uint64_t from cython.operator cimport dereference as deref +import cuml.internals from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t from cuml.decomposition.utils cimport * @@ -68,14 +69,15 @@ class TSVDMG(BaseDecompositionMG, TruncatedSVD): def __init__(self, **kwargs): super(TSVDMG, self).__init__(**kwargs) + @cuml.internals.api_base_return_any_skipall def _call_fit(self, X, trans, rank, input_desc, trans_desc, arg_params): - cdef uintptr_t comp_ptr = self._components_.ptr - cdef uintptr_t explained_var_ptr = self._explained_variance_.ptr + cdef uintptr_t comp_ptr = self.components_.ptr + cdef uintptr_t explained_var_ptr = self.explained_variance_.ptr cdef uintptr_t explained_var_ratio_ptr = \ - self._explained_variance_ratio_.ptr - cdef uintptr_t singular_vals_ptr = self._singular_values_.ptr + self.explained_variance_ratio_.ptr + cdef uintptr_t singular_vals_ptr = self.singular_values_.ptr cdef handle_t* handle_ = self.handle.getHandle() cdef paramsTSVD *params = arg_params diff --git a/python/cuml/ensemble/randomforest_common.pyx b/python/cuml/ensemble/randomforest_common.pyx index 192711fa69..8e1de06199 100644 --- a/python/cuml/ensemble/randomforest_common.pyx +++ b/python/cuml/ensemble/randomforest_common.pyx @@ -18,6 +18,7 @@ import ctypes import cupy as cp import math import warnings +import typing import numpy as np from cuml import ForestInference @@ -25,13 +26,15 @@ from cuml.fil.fil import TreeliteModel from cuml.raft.common.handle import Handle from cuml.common.base import Base from cuml.common.array import CumlArray +import cuml.internals from cython.operator cimport dereference as deref from cuml.ensemble.randomforest_shared import treelite_serialize, \ treelite_deserialize from cuml.ensemble.randomforest_shared cimport * -from cuml.common import input_to_cuml_array, with_cupy_rmm +from cuml.common import input_to_cuml_array +from cuml.common.array_descriptor import CumlArrayDescriptor class BaseRandomForestModel(Base): @@ -48,6 +51,8 @@ class BaseRandomForestModel(Base): criterion_dict = {'0': GINI, '1': ENTROPY, '2': MSE, '3': MAE, '4': CRITERION_END} + classes_ = CumlArrayDescriptor() + def __init__(self, *, split_criterion, seed=None, n_streams=8, n_estimators=100, max_depth=16, handle=None, max_features='auto', @@ -149,7 +154,7 @@ class BaseRandomForestModel(Base): self.treelite_handle = None self.treelite_serialized_model = None - def _get_max_feat_val(self): + def _get_max_feat_val(self) -> float: if type(self.max_features) == int: return self.max_features/self.n_cols elif type(self.max_features) == float: @@ -228,10 +233,12 @@ class BaseRandomForestModel(Base): self.treelite_handle = tl_handle return self.treelite_handle - @with_cupy_rmm - def _dataset_setup_for_fit(self, X, y, convert_dtype): - self._set_output_type(X) - self._set_n_features_in(X) + @cuml.internals.api_base_return_generic(set_output_type=True, + set_n_features_in=True, + get_output_type=False) + def _dataset_setup_for_fit( + self, X, y, + convert_dtype) -> typing.Tuple[CumlArray, CumlArray, float]: # Reset the old tree data for new fit call self._reset_forest_data() @@ -252,16 +259,14 @@ class BaseRandomForestModel(Base): if y_dtype != np.int32: raise TypeError("The labels `y` need to be of dtype" " `int32`") - temp_classes = cp.unique(y_m) - self.num_classes = len(temp_classes) + self.classes_ = cp.unique(y_m) + self.num_classes = len(self.classes_) for i in range(self.num_classes): - if i not in temp_classes: + if i not in self.classes_: raise ValueError("The labels need " "to be consecutive values from " "0 to the number of unique label values") - # Save internally as CumlArray - self._classes_ = CumlArray(temp_classes) else: y_m, _, _, y_dtype = \ input_to_cuml_array( @@ -313,8 +318,8 @@ class BaseRandomForestModel(Base): def _predict_model_on_gpu(self, X, algo, convert_dtype, fil_sparse_format, threshold=0.5, - output_class=False, predict_proba=False): - out_type = self._get_output_type(X) + output_class=False, + predict_proba=False) -> CumlArray: _, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='F', check_cols=self.n_cols) @@ -334,7 +339,8 @@ class BaseRandomForestModel(Base): _check_fil_parameter_validity(depth=self.max_depth, fil_sparse_format=fil_sparse_format, algo=algo) - fil_model = ForestInference() + fil_model = ForestInference(handle=self.handle, verbose=self.verbose, + output_type=self.output_type) tl_to_fil_model = \ fil_model.load_using_treelite_handle(treelite_handle, output_class=output_class, @@ -342,8 +348,10 @@ class BaseRandomForestModel(Base): algo=algo, storage_type=storage_type) - preds = tl_to_fil_model.predict(X, output_type=out_type, - predict_proba=predict_proba) + if (predict_proba): + preds = tl_to_fil_model.predict_proba(X) + else: + preds = tl_to_fil_model.predict(X) return preds def get_param_names(self): @@ -427,7 +435,9 @@ def _obtain_fil_model(treelite_handle, depth, fil_sparse_format=fil_sparse_format, algo=algo) - fil_model = ForestInference() + # Use output_type="input" to prevent an error + fil_model = ForestInference(output_type="input") + tl_to_fil_model = \ fil_model.load_using_treelite_handle(treelite_handle, output_class=output_class, diff --git a/python/cuml/ensemble/randomforest_shared.pxd b/python/cuml/ensemble/randomforest_shared.pxd index 9f4a474773..b8178d7e88 100644 --- a/python/cuml/ensemble/randomforest_shared.pxd +++ b/python/cuml/ensemble/randomforest_shared.pxd @@ -29,8 +29,6 @@ from cuml.raft.common.handle import Handle from cuml import ForestInference from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t -from cuml.common import get_cudf_column_ptr, get_dev_array_ptr, \ - input_to_dev_array, zeros cimport cuml.common.cuda cdef extern from "treelite/c_api.h": diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index db24a2a304..4c5f2969ea 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -26,10 +26,11 @@ import cuml.common.logger as logger from cuml import ForestInference from cuml.common.array import CumlArray from cuml.common.base import ClassifierMixin +import cuml.internals from cuml.common.doc_utils import generate_docstring from cuml.common.doc_utils import insert_into_docstring from cuml.raft.common.handle import Handle -from cuml.common import input_to_cuml_array, rmm_cupy_ary +from cuml.common import input_to_cuml_array from cuml.ensemble.randomforest_common import BaseRandomForestModel from cuml.ensemble.randomforest_common import _obtain_fil_model @@ -416,6 +417,9 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): @generate_docstring(skip_parameters_heading=True, y='dense_intdtype', convert_dtype_cast='np.float32') + @cuml.internals.api_base_return_any(set_output_type=False, + set_output_dtype=True, + set_n_features_in=False) def fit(self, X, y, convert_dtype=True): """ Perform Random Forest Classification on the input data @@ -427,8 +431,6 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): y to be of dtype int32. This will increase memory used for the method. """ - self._set_base_attributes(target_dtype=y) - X_m, y_m, max_feature_val = self._dataset_setup_for_fit(X, y, convert_dtype) cdef uintptr_t X_ptr, y_ptr @@ -503,10 +505,8 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): del y_m return self - def _predict_model_on_cpu(self, X, convert_dtype): - out_type = self._get_output_type(X) - out_dtype = self._get_target_dtype() - + @cuml.internals.api_base_return_array(get_output_dtype=True) + def _predict_model_on_cpu(self, X, convert_dtype) -> CumlArray: cdef uintptr_t X_ptr X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', @@ -550,7 +550,7 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): self.handle.sync() # synchronous w/o a stream del(X_m) - return preds.to_output(output_type=out_type, output_dtype=out_dtype) + return preds @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, 1)')]) @@ -558,7 +558,7 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): output_class=True, threshold=0.5, algo='auto', num_classes=None, convert_dtype=True, - fil_sparse_format='auto'): + fil_sparse_format='auto') -> CumlArray: """ Predicts the labels for X. @@ -646,7 +646,7 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): return preds - def _predict_get_all(self, X, convert_dtype=True): + def _predict_get_all(self, X, convert_dtype=True) -> CumlArray: """ Predicts the labels for X. @@ -662,7 +662,6 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): y : NumPy Dense vector (int) of shape (n_samples, 1) """ - out_type = self._get_output_type(X) cdef uintptr_t X_ptr, preds_ptr X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', @@ -704,14 +703,14 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): % (str(self.dtype))) self.handle.sync() del(X_m) - return preds.to_output(out_type) + return preds @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, 1)')]) def predict_proba(self, X, output_class=True, threshold=0.5, algo='auto', num_classes=None, convert_dtype=True, - fil_sparse_format='auto'): + fil_sparse_format='auto') -> CumlArray: """ Predicts class probabilites for X. This function uses the GPU implementation of predict. Therefore, data with 'dtype = np.float32' diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index 6ab080daac..b894b207f3 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -24,12 +24,13 @@ import cuml.common.logger as logger from cuml import ForestInference from cuml.common.array import CumlArray +import cuml.internals from cuml.common.base import RegressorMixin from cuml.common.doc_utils import generate_docstring from cuml.common.doc_utils import insert_into_docstring from cuml.raft.common.handle import Handle -from cuml.common import input_to_cuml_array, rmm_cupy_ary +from cuml.common import input_to_cuml_array from cuml.ensemble.randomforest_common import BaseRandomForestModel from cuml.ensemble.randomforest_common import _obtain_fil_model @@ -403,6 +404,7 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): fil_sparse_format=fil_sparse_format) @generate_docstring() + @cuml.internals.api_base_return_any_skipall def fit(self, X, y, convert_dtype=True): """ Perform Random Forest Regression on the input data @@ -475,8 +477,7 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): del y_m return self - def _predict_model_on_cpu(self, X, convert_dtype): - out_type = self._get_output_type(X) + def _predict_model_on_cpu(self, X, convert_dtype) -> CumlArray: cdef uintptr_t X_ptr X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', @@ -521,13 +522,13 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): self.handle.sync() # synchronous w/o a stream del(X_m) - return preds.to_output(out_type) + return preds @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, 1)')]) def predict(self, X, predict_model="GPU", algo='auto', convert_dtype=True, - fil_sparse_format='auto'): + fil_sparse_format='auto') -> CumlArray: """ Predicts the labels for X. diff --git a/python/cuml/experimental/decomposition/incremental_pca.py b/python/cuml/experimental/decomposition/incremental_pca.py index b302a8e499..cd41e42b07 100644 --- a/python/cuml/experimental/decomposition/incremental_pca.py +++ b/python/cuml/experimental/decomposition/incremental_pca.py @@ -21,7 +21,8 @@ import scipy from cuml import Base from cuml.common import input_to_cuml_array -from cuml.common import with_cupy_rmm +from cuml.common.input_utils import input_to_cupy_array +import cuml.internals from cuml.common.array import CumlArray from cuml.decomposition import PCA @@ -199,11 +200,9 @@ def __init__(self, handle=None, n_components=None, *, whiten=False, output_type=output_type) self.batch_size = batch_size self._hyperparams = ["n_components", "whiten", "copy", "batch_size"] - self._cupy_attributes = True self._sparse_model = True - @with_cupy_rmm - def fit(self, X, y=None): + def fit(self, X, y=None) -> "IncrementalPCA": """ Fit the model with X, using minibatches of size batch_size. @@ -221,26 +220,21 @@ def fit(self, X, y=None): Returns the instance itself. """ - - self._set_base_attributes(output_type=X) - self.n_samples_seen_ = 0 - self._mean_ = .0 + self.mean_ = .0 self.var_ = .0 if scipy.sparse.issparse(X) or cupyx.scipy.sparse.issparse(X): X = _validate_sparse_input(X) else: - X, n_samples, n_features, self.dtype = \ - input_to_cuml_array(X, order='K', - check_dtype=[cp.float32, cp.float64]) - # NOTE: While we cast the input to a cupy array here, we still # respect the `output_type` parameter in the constructor. This # is done by PCA, which IncrementalPCA inherits from. PCA's # transform and inverse transform convert the output to the # required type. - X = X.to_output(output_type='cupy') + X, n_samples, n_features, self.dtype = \ + input_to_cupy_array(X, order='K', + check_dtype=[cp.float32, cp.float64]) n_samples, n_features = X.shape @@ -259,8 +253,8 @@ def fit(self, X, y=None): return self - @with_cupy_rmm - def partial_fit(self, X, y=None, check_input=True): + @cuml.internals.api_base_return_any_skipall + def partial_fit(self, X, y=None, check_input=True) -> "IncrementalPCA": """ Incremental fit with X. All of X is processed as a single batch. @@ -291,20 +285,19 @@ def partial_fit(self, X, y=None, check_input=True): self._set_output_type(X) X, n_samples, n_features, self.dtype = \ - input_to_cuml_array(X, order='K', + input_to_cupy_array(X, order='K', check_dtype=[cp.float32, cp.float64]) - X = X.to_output(output_type='cupy') else: n_samples, n_features = X.shape - if not hasattr(self, '_components_'): - self._components_ = None + if not hasattr(self, 'components_'): + self.components_ = None if self.n_components is None: - if self._components_ is None: + if self.components_ is None: self.n_components_ = min(n_samples, n_features) else: - self.n_components_ = self._components_.shape[0] + self.n_components_ = self.components_.shape[0] elif not 1 <= self.n_components <= n_features: raise ValueError("n_components=%r invalid for n_features=%d, need " "more rows than columns for IncrementalPCA " @@ -316,27 +309,22 @@ def partial_fit(self, X, y=None, check_input=True): else: self.n_components_ = self.n_components - if (self._components_ is not None) and (self._components_.shape[0] != - self.n_components_): + if (self.components_ is not None) and (self.components_.shape[0] != + self.n_components_): raise ValueError("Number of input features has changed from %i " "to %i between calls to partial_fit! Try " "setting n_components to a fixed value." % - (self._components_.shape[0], self.n_components_)) - - if not self._cupy_attributes: - self._cumlarray_to_cupy_attrs() - self._cupy_attributes = True - + (self.components_.shape[0], self.n_components_)) # This is the first partial_fit if not hasattr(self, 'n_samples_seen_'): self.n_samples_seen_ = 0 - self._mean_ = .0 + self.mean_ = .0 self.var_ = .0 # Update stats - they are 0 if this is the first step col_mean, col_var, n_total_samples = \ _incremental_mean_and_var( - X, last_mean=self._mean_, last_variance=self.var_, + X, last_mean=self.mean_, last_variance=self.var_, last_sample_count=cp.repeat(cp.asarray([self.n_samples_seen_]), X.shape[1])) n_total_samples = n_total_samples[0] @@ -351,9 +339,9 @@ def partial_fit(self, X, y=None, check_input=True): # Build matrix of combined previous basis and new data mean_correction = \ cp.sqrt((self.n_samples_seen_ * n_samples) / - n_total_samples) * (self._mean_ - col_batch_mean) - X = cp.vstack((self._singular_values_.reshape((-1, 1)) * - self._components_, X, mean_correction)) + n_total_samples) * (self.mean_ - col_batch_mean) + X = cp.vstack((self.singular_values_.reshape((-1, 1)) * + self.components_, X, mean_correction)) U, S, V = cp.linalg.svd(X, full_matrices=False) U, V = _svd_flip(U, V, u_based_decision=False) @@ -361,27 +349,22 @@ def partial_fit(self, X, y=None, check_input=True): explained_variance_ratio = S ** 2 / cp.sum(col_var * n_total_samples) self.n_samples_seen_ = n_total_samples - self._components_ = V[:self.n_components_] - self._singular_values_ = S[:self.n_components_] - self._mean_ = col_mean + self.components_ = V[:self.n_components_] + self.singular_values_ = S[:self.n_components_] + self.mean_ = col_mean self.var_ = col_var - self._explained_variance_ = explained_variance[:self.n_components_] - self._explained_variance_ratio_ = \ + self.explained_variance_ = explained_variance[:self.n_components_] + self.explained_variance_ratio_ = \ explained_variance_ratio[:self.n_components_] if self.n_components_ < n_features: - self._noise_variance_ = \ + self.noise_variance_ = \ explained_variance[self.n_components_:].mean() else: - self._noise_variance_ = 0. - - if self._cupy_attributes: - self._cupy_to_cumlarray_attrs() - self._cupy_attributes = False + self.noise_variance_ = 0. return self - @with_cupy_rmm - def transform(self, X, convert_dtype=False): + def transform(self, X, convert_dtype=False) -> CumlArray: """ Apply dimensionality reduction to X. @@ -409,7 +392,6 @@ def transform(self, X, convert_dtype=False): """ if scipy.sparse.issparse(X) or cupyx.scipy.sparse.issparse(X): - out_type = self._get_output_type(X) X = _validate_sparse_input(X) @@ -421,7 +403,7 @@ def transform(self, X, convert_dtype=False): output, _, _, _ = \ input_to_cuml_array(cp.vstack(output), order='K') - return output.to_output(out_type) + return output else: return super().transform(X) @@ -429,24 +411,6 @@ def get_param_names(self): # Skip super() since we dont pass any extra parameters in __init__ return Base.get_param_names(self) + self._hyperparams - def _cupy_to_cumlarray_attrs(self): - self._components_ = CumlArray(self._components_.copy()) - self._mean_ = CumlArray(self._mean_) - self._noise_variance_ = CumlArray(self._noise_variance_) - self._singular_values_ = CumlArray(self._singular_values_) - self._explained_variance_ = CumlArray(self._explained_variance_.copy()) - self._explained_variance_ratio_ = \ - CumlArray(self._explained_variance_ratio_) - - def _cumlarray_to_cupy_attrs(self): - self._components_ = self._components_.to_output("cupy") - self._mean_ = self._mean_.to_output("cupy") - self._noise_variance_ = self._noise_variance_.to_output("cupy") - self._singular_values_ = self._singular_values_.to_output("cupy") - self._explained_variance_ = self._explained_variance_.to_output("cupy") - self._explained_variance_ratio_ = \ - self._explained_variance_ratio_.to_output("cupy") - def _validate_sparse_input(X): """ diff --git a/python/cuml/feature_extraction/_tfidf.py b/python/cuml/feature_extraction/_tfidf.py index d01f8e0fa6..ffe54324d1 100644 --- a/python/cuml/feature_extraction/_tfidf.py +++ b/python/cuml/feature_extraction/_tfidf.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import cuml.internals from cuml.common.exceptions import NotFittedError import cupy as cp import cupyx -from cuml.common import with_cupy_rmm from cuml.common.sparsefuncs import csr_row_normalize_l1, csr_row_normalize_l2 from cuml.common.sparsefuncs import csr_diag_mul from cuml.common.array import CumlArray @@ -132,7 +132,6 @@ def __init__(self, *, norm='l2', use_idf=True, smooth_idf=True, self.smooth_idf = smooth_idf self.sublinear_tf = sublinear_tf - @with_cupy_rmm def _set_doc_stats(self, X): """ We set the following document level statistics here: @@ -152,7 +151,6 @@ def _set_doc_stats(self, X): return - @with_cupy_rmm def _set_idf_diag(self): """ Sets idf_diagonal sparse array @@ -172,8 +170,8 @@ def _set_idf_diag(self): # Free up memory occupied by below del self.__df - @with_cupy_rmm - def fit(self, X): + @cuml.internals.api_base_return_any_skipall + def fit(self, X) -> "TfidfTransformer": """Learn the idf vector (global term weights). Parameters @@ -189,7 +187,7 @@ def fit(self, X): return self - @with_cupy_rmm + @cuml.internals.api_base_return_any_skipall def transform(self, X, copy=True): """Transform a count matrix to a tf or tf-idf representation @@ -239,6 +237,7 @@ def transform(self, X, copy=True): return X + @cuml.internals.api_base_return_any_skipall def fit_transform(self, X, copy=True): """ Fit TfidfTransformer to X, then transform X. diff --git a/python/cuml/fil/fil.pyx b/python/cuml/fil/fil.pyx index c8c9b8e994..046227f87b 100644 --- a/python/cuml/fil/fil.pyx +++ b/python/cuml/fil/fil.pyx @@ -30,6 +30,7 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free +import cuml.internals from cuml.common.array import CumlArray from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t @@ -195,7 +196,7 @@ cdef extern from "cuml/fil/fil.h" namespace "ML::fil": cdef class ForestInference_impl(): - cpdef object handle + cdef object handle cdef forest_t forest_data cdef size_t num_output_groups cdef bool output_class @@ -237,7 +238,7 @@ cdef class ForestInference_impl(): logger.info('storage_type=="sparse8" is an experimental feature') return storage_type_dict[storage_type_str] - def predict(self, X, output_type='numpy', + def predict(self, X, output_dtype=None, predict_proba=False, preds=None): """ Returns the results of forest inference on the examples in X @@ -246,10 +247,6 @@ cdef class ForestInference_impl(): ---------- X : float32 array-like (device or host) shape = (n_samples, n_features) For optimal performance, pass a device array with C-style layout - output_type : string (default = 'numpy') - possible options are : {'input', 'cudf', 'cupy', 'numpy'}, optional - Variable to control output type of the results and attributes of - the estimators. preds : float32 device array, shape = n_samples predict_proba : bool, whether to output class probabilities(vs classes) Supported only for binary classification. output format @@ -261,6 +258,10 @@ cdef class ForestInference_impl(): Predicted results of type as defined by the output_type variable """ + + # Set the output_dtype. None is fine here + cuml.internals.set_api_output_dtype(output_dtype) + if (not self.output_class) and predict_proba: raise NotImplementedError("Predict_proba function is not available" " for Regression models. If you are " @@ -304,11 +305,9 @@ cdef class ForestInference_impl(): # special case due to predict and predict_proba # both coming from the same CUDA/C++ function if predict_proba: - output_dtype = None - return preds.to_output( - output_type=output_type, - output_dtype=output_dtype - ) + cuml.internals.set_api_output_dtype(None) + + return preds def load_from_treelite_model_handle(self, uintptr_t model_handle, @@ -376,11 +375,10 @@ cdef class ForestInference_impl(): return self def __dealloc__(self): - cdef handle_t* handle_ =\ - self.handle.getHandle() + cdef handle_t* handle_ = self.handle.getHandle() + if self.forest_data !=NULL: - free(handle_[0], - self.forest_data) + free(handle_[0], self.forest_data) class ForestInference(Base): @@ -474,7 +472,7 @@ class ForestInference(Base): verbose=verbose) self._impl = ForestInference_impl(self.handle) - def predict(self, X, preds=None): + def predict(self, X, preds=None) -> CumlArray: """ Predicts the labels for X with the loaded forest model. By default, the result is the raw floating point output @@ -498,10 +496,9 @@ class ForestInference(Base): GPU array of length n_samples with inference results (or 'preds' filled with inference results if preds was specified) """ - out_type = self._get_output_type(X) - return self._impl.predict(X, out_type, predict_proba=False, preds=None) + return self._impl.predict(X, predict_proba=False, preds=None) - def predict_proba(self, X, preds=None): + def predict_proba(self, X, preds=None) -> CumlArray: """ Predicts the class probabilities for X with the loaded forest model. The result is the raw floating point output @@ -523,9 +520,7 @@ class ForestInference(Base): GPU array of shape (n_samples,2) with inference results (or 'preds' filled with inference results if preds was specified) """ - out_type = self._get_output_type(X) - - return self._impl.predict(X, out_type, predict_proba=True, preds=None) + return self._impl.predict(X, predict_proba=True, preds=None) def load_from_treelite_model(self, model, output_class=False, algo='auto', @@ -726,7 +721,10 @@ class ForestInference(Base): A Forest Inference model which can be used to perform inferencing on the random forest model. """ - return self._impl.load_using_treelite_handle(model_handle, - output_class, - algo, threshold, - str(storage_type)) + self._impl.load_using_treelite_handle(model_handle, + output_class, + algo, threshold, + str(storage_type)) + + # DO NOT RETURN self._impl here!! + return self diff --git a/python/cuml/internals/__init__.py b/python/cuml/internals/__init__.py index 64158e83ba..2a8d9ac9da 100644 --- a/python/cuml/internals/__init__.py +++ b/python/cuml/internals/__init__.py @@ -14,4 +14,28 @@ # limitations under the License. # +from cuml.internals.base_helpers import BaseMetaClass +from cuml.internals.api_decorators import ( + api_base_fit_transform, + api_base_return_any_skipall, + api_base_return_any, + api_base_return_array_skipall, + api_base_return_array, + api_base_return_autoarray, + api_base_return_generic_skipall, + api_base_return_generic, + api_base_return_sparse_array, + api_ignore, + api_return_any, + api_return_array_skipall, + api_return_array, + api_return_generic, + api_return_sparse_array, + exit_internal_api, +) +from cuml.internals.api_context_managers import ( + in_internal_api, + set_api_output_dtype, + set_api_output_type, +) from cuml.internals.internals import GraphBasedDimRedCallback diff --git a/python/cuml/internals/api_context_managers.py b/python/cuml/internals/api_context_managers.py new file mode 100644 index 0000000000..634f704ac9 --- /dev/null +++ b/python/cuml/internals/api_context_managers.py @@ -0,0 +1,499 @@ +# +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import contextlib +import threading +import typing +from collections import deque + +import cuml +import cuml.common +import cuml.common.array +import cuml.common.array_sparse +import cuml.common.input_utils +import rmm + +try: + from cupy.cuda import using_allocator as cupy_using_allocator +except ImportError: + try: + from cupy.cuda.memory import using_allocator as cupy_using_allocator + except ImportError: + pass + +# Use _F as a type variable for decorators. See: +# https://github.com/python/mypy/pull/8336/files#diff-eb668b35b7c0c4f88822160f3ca4c111f444c88a38a3b9df9bb8427131538f9cR260 +_F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) + + +@contextlib.contextmanager +def _using_mirror_output_type(): + """ + Sets cuml.global_output_type to "mirror" for internal API handling. We need + a separate function since `cuml.using_output_type()` doesn't accept + "mirror" + + Yields + ------- + string + Returns the previous value in cuml.global_output_type + """ + prev_output_type = cuml.global_output_type + try: + cuml.global_output_type = "mirror" + yield prev_output_type + finally: + cuml.global_output_type = prev_output_type + + +global_output_type_data = threading.local() +global_output_type_data.root_cm = None + + +def in_internal_api(): + return global_output_type_data.root_cm is not None + + +def set_api_output_type(output_type: str): + assert (global_output_type_data.root_cm is not None) + + # Quick exit + if (isinstance(output_type, str)): + global_output_type_data.root_cm.output_type = output_type + return + + # Try to convert any array objects to their type + array_type = cuml.common.input_utils.determine_array_type(output_type) + + # Ensure that this is an array-like object + assert output_type is None or array_type is not None + + global_output_type_data.root_cm.output_type = array_type + + +def set_api_output_dtype(output_dtype): + assert (global_output_type_data.root_cm is not None) + + # Try to convert any array objects to their type + if (output_dtype is not None + and cuml.common.input_utils.is_array_like(output_dtype)): + output_dtype = cuml.common.input_utils.determine_array_dtype( + output_dtype) + + assert (output_dtype is not None) + + global_output_type_data.root_cm.output_dtype = output_dtype + + +class InternalAPIContext(contextlib.ExitStack): + def __init__(self): + super().__init__() + + def cleanup(): + global_output_type_data.root_cm = None + + self.callback(cleanup) + + self.enter_context(cupy_using_allocator(rmm.rmm_cupy_allocator)) + self.prev_output_type = self.enter_context(_using_mirror_output_type()) + + self._output_type = None + self.output_dtype = None + + # Set the output type to the prev_output_type. If "input", set to None + # to allow inner functions to specify the input + self.output_type = (None if self.prev_output_type == "input" else + self.prev_output_type) + + self._count = 0 + + self.call_stack = {} + + global_output_type_data.root_cm = self + + @property + def output_type(self): + return self._output_type + + @output_type.setter + def output_type(self, value: str): + self._output_type = value + + def pop_all(self): + """Preserve the context stack by transferring it to a new instance.""" + new_stack = contextlib.ExitStack() + new_stack._exit_callbacks = self._exit_callbacks + self._exit_callbacks = deque() + return new_stack + + def __enter__(self) -> int: + + self._count += 1 + + return self._count + + def __exit__(self, *exc_details): + + self._count -= 1 + + return + + @contextlib.contextmanager + def push_output_types(self): + try: + old_output_type = self.output_type + old_output_dtype = self.output_dtype + + self.output_type = None + self.output_dtype = None + + yield + + finally: + self.output_type = (old_output_type if old_output_type is not None + else self.output_type) + self.output_dtype = (old_output_dtype if old_output_dtype + is not None else self.output_dtype) + + +def get_internal_context() -> InternalAPIContext: + + # Dask workers can have a separate thread access the object requiring this + # check + if (not hasattr(global_output_type_data, "root_cm")): + global_output_type_data.root_cm = None + + if (global_output_type_data.root_cm is None): + return InternalAPIContext() + + return global_output_type_data.root_cm + + +class ProcessEnter(object): + def __init__(self, context: "InternalAPIContextBase"): + super().__init__() + + self._context = context + + self._process_enter_cbs: typing.Deque[typing.Callable] = deque() + + def process_enter(self): + + for cb in self._process_enter_cbs: + cb() + + +class ProcessReturn(object): + def __init__(self, context: "InternalAPIContextBase"): + super().__init__() + + self._context = context + + self._process_return_cbs: typing.Deque[typing.Callable[ + [typing.Any], typing.Any]] = deque() + + def process_return(self, ret_val): + + for cb in self._process_return_cbs: + ret_val = cb(ret_val) + + return ret_val + + +EnterT = typing.TypeVar("EnterT", bound=ProcessEnter) +ProcessT = typing.TypeVar("ProcessT", bound=ProcessReturn) + + +class InternalAPIContextBase(contextlib.ExitStack, + typing.Generic[EnterT, ProcessT]): + + ProcessEnter_Type: typing.Type[EnterT] = None + ProcessReturn_Type: typing.Type[ProcessT] = None + + def __init__(self, func=None, args=None): + super().__init__() + + self._func = func + self._args = args + + self.root_cm = get_internal_context() + + self.is_root = False + + self._enter_obj: ProcessEnter = self.ProcessEnter_Type(self) + self._process_obj: ProcessReturn = None + + def __enter__(self): + + # Enter the root context to know if we are the root cm + self.is_root = self.enter_context(self.root_cm) == 1 + + # If we are the first, push any callbacks from the root into this CM + # If we are not the first, this will have no effect + self.push(self.root_cm.pop_all()) + + self._enter_obj.process_enter() + + # Now create the process functions since we know if we are root or not + self._process_obj = self.ProcessReturn_Type(self) + + return super().__enter__() + + def process_return(self, ret_val): + + return self._process_obj.process_return(ret_val) + + def __class_getitem__(cls: typing.Type["InternalAPIContextBase"], params): + + param_names = [ + param.__name__ if hasattr(param, '__name__') else str(param) + for param in params + ] + + type_name = f'{cls.__name__}[{", ".join(param_names)}]' + + ns = { + "ProcessEnter_Type": params[0], + "ProcessReturn_Type": params[1], + } + + return type(type_name, (cls, ), ns) + + +class ProcessEnterBaseMixin(ProcessEnter): + def __init__(self, context: "InternalAPIContextBase"): + super().__init__(context) + + self.base_obj: cuml.Base = self._context._args[0] + + +class ProcessEnterReturnAny(ProcessEnter): + pass + + +class ProcessEnterReturnArray(ProcessEnter): + def __init__(self, context: "InternalAPIContextBase"): + super().__init__(context) + + self._process_enter_cbs.append(self.push_output_types) + + def push_output_types(self): + + self._context.enter_context(self._context.root_cm.push_output_types()) + + +class ProcessEnterBaseReturnArray(ProcessEnterReturnArray, + ProcessEnterBaseMixin): + def __init__(self, context: "InternalAPIContextBase"): + super().__init__(context) + + # IMPORTANT: Only perform output type processing if + # `root_cm.output_type` is None. Since we default to using the incoming + # value if its set, there is no need to do any processing if the user + # has specified the output type + if (self._context.root_cm.prev_output_type is None + or self._context.root_cm.prev_output_type == "input"): + self._process_enter_cbs.append(self.base_output_type_callback) + + def base_output_type_callback(self): + + root_cm = self._context.root_cm + + def set_output_type(): + output_type = root_cm.output_type + + # Check if output_type is None, can happen if no output type has + # been set by estimator + if (output_type is None): + output_type = self.base_obj.output_type + + if (output_type == "input"): + output_type = self.base_obj._input_type + + if (output_type != root_cm.output_type): + set_api_output_type(output_type) + + assert (output_type != "mirror") + + self._context.callback(set_output_type) + + +class ProcessReturnAny(ProcessReturn): + pass + + +class ProcessReturnArray(ProcessReturn): + def __init__(self, context: "InternalAPIContextBase"): + super().__init__(context) + + self._process_return_cbs.append(self.convert_to_cumlarray) + + if (self._context.is_root or cuml.global_output_type != "mirror"): + self._process_return_cbs.append(self.convert_to_outputtype) + + def convert_to_cumlarray(self, ret_val): + + # Get the output type + ret_val_type_str = cuml.common.input_utils.determine_array_type( + ret_val) + + # If we are a supported array and not already cuml, convert to cuml + if (ret_val_type_str is not None and ret_val_type_str != "cuml"): + ret_val = cuml.common.input_utils.input_to_cuml_array( + ret_val, order="K").array + + return ret_val + + def convert_to_outputtype(self, ret_val): + + output_type = cuml.global_output_type + + if (output_type is None or output_type == "mirror" + or output_type == "input"): + output_type = self._context.root_cm.output_type + + assert (output_type is not None + and output_type != "mirror" + and output_type != "input"), \ + ("Invalid root_cm.output_type: " + "'{}'.").format(output_type) + + return ret_val.to_output( + output_type=output_type, + output_dtype=self._context.root_cm.output_dtype) + + +class ProcessReturnSparseArray(ProcessReturnArray): + + def convert_to_cumlarray(self, ret_val): + + # Get the output type + ret_val_type_str, is_sparse = \ + cuml.common.input_utils.determine_array_type_full(ret_val) + + # If we are a supported array and not already cuml, convert to cuml + if (ret_val_type_str is not None and ret_val_type_str != "cuml"): + if is_sparse: + ret_val = cuml.common.array_sparse.SparseCumlArray( + ret_val, convert_index=False) + else: + ret_val = cuml.common.input_utils.input_to_cuml_array( + ret_val, order="K").array + + return ret_val + + +class ProcessReturnGeneric(ProcessReturnArray): + def __init__(self, context: "InternalAPIContextBase"): + super().__init__(context) + + # Clear the existing callbacks to allow processing one at a time + self._single_array_cbs = self._process_return_cbs + + # Make a new queue + self._process_return_cbs = deque() + + self._process_return_cbs.append(self.process_generic) + + def process_single(self, ret_val): + for cb in self._single_array_cbs: + ret_val = cb(ret_val) + + return ret_val + + def process_tuple(self, ret_val: tuple): + + # Convert to a list + out_val = list(ret_val) + + for idx, item in enumerate(out_val): + + out_val[idx] = self.process_generic(item) + + return tuple(out_val) + + def process_dict(self, ret_val): + + for name, item in ret_val.items(): + + ret_val[name] = self.process_generic(item) + + return ret_val + + def process_list(self, ret_val): + + for idx, item in enumerate(ret_val): + + ret_val[idx] = self.process_generic(item) + + return ret_val + + def process_generic(self, ret_val): + + if (cuml.common.input_utils.is_array_like(ret_val)): + return self.process_single(ret_val) + + if (isinstance(ret_val, tuple)): + return self.process_tuple(ret_val) + + if (isinstance(ret_val, dict)): + return self.process_dict(ret_val) + + if (isinstance(ret_val, list)): + return self.process_list(ret_val) + + return ret_val + + +class ReturnAnyCM(InternalAPIContextBase[ProcessEnterReturnAny, + ProcessReturnAny]): + pass + + +class ReturnArrayCM(InternalAPIContextBase[ProcessEnterReturnArray, + ProcessReturnArray]): + pass + + +class ReturnSparseArrayCM(InternalAPIContextBase[ProcessEnterReturnArray, + ProcessReturnSparseArray]): + pass + + +class ReturnGenericCM(InternalAPIContextBase[ProcessEnterReturnArray, + ProcessReturnGeneric]): + pass + + +class BaseReturnAnyCM(InternalAPIContextBase[ProcessEnterReturnAny, + ProcessReturnAny]): + pass + + +class BaseReturnArrayCM(InternalAPIContextBase[ProcessEnterBaseReturnArray, + ProcessReturnArray]): + pass + + +class BaseReturnSparseArrayCM( + InternalAPIContextBase[ProcessEnterBaseReturnArray, + ProcessReturnSparseArray]): + pass + + +class BaseReturnGenericCM(InternalAPIContextBase[ProcessEnterBaseReturnArray, + ProcessReturnGeneric]): + pass diff --git a/python/cuml/internals/api_decorators.py b/python/cuml/internals/api_decorators.py new file mode 100644 index 0000000000..6c455474d3 --- /dev/null +++ b/python/cuml/internals/api_decorators.py @@ -0,0 +1,732 @@ +# +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import contextlib +import functools +import inspect +import typing +from functools import wraps + +import cuml +import cuml.common +import cuml.common.array +import cuml.common.array_sparse +import cuml.common.input_utils +from cuml.internals.api_context_managers import BaseReturnAnyCM +from cuml.internals.api_context_managers import BaseReturnArrayCM +from cuml.internals.api_context_managers import BaseReturnGenericCM +from cuml.internals.api_context_managers import BaseReturnSparseArrayCM +from cuml.internals.api_context_managers import InternalAPIContextBase +from cuml.internals.api_context_managers import ReturnAnyCM +from cuml.internals.api_context_managers import ReturnArrayCM +from cuml.internals.api_context_managers import ReturnGenericCM +from cuml.internals.api_context_managers import ReturnSparseArrayCM +from cuml.internals.api_context_managers import global_output_type_data +from cuml.internals.api_context_managers import set_api_output_dtype +from cuml.internals.api_context_managers import set_api_output_type +from cuml.internals.base_helpers import _get_base_return_type + +# Use _F as a type variable for decorators. See: +# https://github.com/python/mypy/pull/8336/files#diff-eb668b35b7c0c4f88822160f3ca4c111f444c88a38a3b9df9bb8427131538f9cR260 +_F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) + + +class DecoratorMetaClass(type): + """ + This metaclass is used to prevent wrapping functions multiple times by + adding `__cuml_is_wrapped = True` to the function __dict__ + """ + def __new__(cls, classname, bases, classDict): + + if ("__call__" in classDict): + + func = classDict["__call__"] + + @wraps(func) + def wrap_call(*args, **kwargs): + ret_val = func(*args, **kwargs) + + ret_val.__dict__["__cuml_is_wrapped"] = True + + return ret_val + + classDict["__call__"] = wrap_call + + return type.__new__(cls, classname, bases, classDict) + + +class WithArgsDecoratorMixin(object): + """ + This decorator mixin handles processing the input arguments for all api + decorators. It supplies the input_arg, target_arg properties + """ + def __init__(self, + *, + input_arg: str = ..., + target_arg: str = ..., + needs_self=True, + needs_input=False, + needs_target=False): + super().__init__() + + # For input_arg and target_arg, use Ellipsis to auto detect, None to + # skip (this has different functionality on Base where it can determine + # the output type like CumlArrayDescriptor) + self.input_arg = input_arg + self.target_arg = target_arg + + self.needs_self = needs_self + self.needs_input = needs_input + self.needs_target = needs_target + + def prep_arg_to_use(self, func) -> bool: + + # Determine from the signature what processing needs to be done. This + # is executed once per function on import + sig = inspect.signature(func, follow_wrapped=True) + sig_args = list(sig.parameters.keys()) + + self.has_self = "self" in sig.parameters and sig_args.index( + "self") == 0 + + if (not self.has_self and self.needs_self): + raise Exception("No self found on function!") + + # Return early if we dont need args + if (not self.needs_input and not self.needs_target): + return + + self_offset = (1 if self.has_self else 0) + + if (self.needs_input): + input_arg_to_use = self.input_arg + input_arg_to_use_name = None + + # if input_arg is None, then set to first non self argument + if (input_arg_to_use is ...): + + # Check for "X" in input args + if ("X" in sig_args): + input_arg_to_use = "X" + else: + if (len(sig.parameters) <= self_offset): + raise Exception("No input_arg could be determined!") + + input_arg_to_use = sig_args[self_offset] + + # Now convert that to an index + if (isinstance(input_arg_to_use, str)): + input_arg_to_use_name = input_arg_to_use + input_arg_to_use = sig_args.index(input_arg_to_use) + + assert input_arg_to_use != -1 and input_arg_to_use is not None, \ + "Could not determine input_arg" + + # Save the name and argument to use later + self.input_arg_to_use = input_arg_to_use + self.input_arg_to_use_name = input_arg_to_use_name + + if (self.needs_target): + + target_arg_to_use = self.target_arg + target_arg_to_use_name = None + + # if input_arg is None, then set to first non self argument + if (target_arg_to_use is ...): + + # Check for "y" in args + if ("y" in sig_args): + target_arg_to_use = "y" + else: + if (len(sig.parameters) <= self_offset + 1): + raise Exception("No target_arg could be determined!") + + target_arg_to_use = sig_args[self_offset + 1] + + # Now convert that to an index + if (isinstance(target_arg_to_use, str)): + target_arg_to_use_name = target_arg_to_use + target_arg_to_use = sig_args.index(target_arg_to_use) + + assert target_arg_to_use != -1 and target_arg_to_use is not None, \ + "Could not determine target_arg" + + # Save the name and argument to use later + self.target_arg_to_use = target_arg_to_use + self.target_arg_to_use_name = target_arg_to_use_name + + return True + + def get_arg_values(self, *args, **kwargs): + """ + This function is called once per function invocation to get the values + of self, input and target. + + Returns + ------- + tuple + Returns a tuple of self, input, target values + + Raises + ------ + IndexError + Raises an exception if the specified input argument is not + available or called with the wrong number of arguments + """ + self_val = None + input_val = None + target_val = None + + if (self.has_self): + self_val = args[0] + + if (self.needs_input): + # Check if its set to a string + if (isinstance(self.input_arg_to_use, str)): + input_val = kwargs[self.input_arg_to_use] + + # If all arguments are set by name, then this can happen + elif (self.input_arg_to_use >= len(args)): + # Check for the name in kwargs + if (self.input_arg_to_use_name in kwargs): + input_val = kwargs[self.input_arg_to_use_name] + else: + raise IndexError( + ("Specified input_arg idx: {}, and argument name: {}, " + "were not found in args or kwargs").format( + self.input_arg_to_use, + self.input_arg_to_use_name)) + else: + # Otherwise return the index + input_val = args[self.input_arg_to_use] + + if (self.needs_target): + # Check if its set to a string + if (isinstance(self.target_arg_to_use, str)): + target_val = kwargs[self.target_arg_to_use] + + # If all arguments are set by name, then this can happen + elif (self.target_arg_to_use >= len(args)): + # Check for the name in kwargs + if (self.target_arg_to_use_name in kwargs): + target_val = kwargs[self.target_arg_to_use_name] + else: + raise IndexError(( + "Specified target_arg idx: {}, and argument name: {}, " + "were not found in args or kwargs").format( + self.target_arg_to_use, + self.target_arg_to_use_name)) + else: + # Otherwise return the index + target_val = args[self.target_arg_to_use] + + return self_val, input_val, target_val + + +class HasSettersDecoratorMixin(object): + """ + This mixin is responsible for handling any "set_XXX" methods used by api + decorators. Mostly used by `fit()` functions + """ + def __init__(self, + *, + set_output_type=True, + set_output_dtype=False, + set_n_features_in=True) -> None: + + super().__init__() + + self.set_output_type = set_output_type + self.set_output_dtype = set_output_dtype + self.set_n_features_in = set_n_features_in + + self.has_setters = (self.set_output_type or self.set_output_dtype + or self.set_n_features_in) + + def do_setters(self, *, self_val, input_val, target_val): + if (self.set_output_type): + assert input_val is not None, \ + "`set_output_type` is False but no input_arg detected" + self_val._set_output_type(input_val) + + if (self.set_output_dtype): + assert target_val is not None, \ + "`set_output_dtype` is True but no target_arg detected" + self_val._set_target_dtype(target_val) + + if (self.set_n_features_in): + assert input_val is not None, \ + "`set_n_features_in` is False but no input_arg detected" + if (len(input_val.shape) >= 2): + self_val._set_n_features_in(input_val) + + def has_setters_input(self): + return self.set_output_type or self.set_n_features_in + + def has_setters_target(self): + return self.set_output_dtype + + +class HasGettersDecoratorMixin(object): + """ + This mixin is responsible for handling any "get_XXX" methods used by api + decorators. Used for many functions like `predict()`, `transform()`, etc. + """ + def __init__(self, + *, + get_output_type=False, + get_output_dtype=False) -> None: + + super().__init__() + + self.get_output_type = get_output_type + self.get_output_dtype = get_output_dtype + + self.has_getters = (self.get_output_type or self.get_output_dtype) + + def do_getters_with_self_no_input(self, *, self_val): + if (self.get_output_type): + out_type = self_val.output_type + + if (out_type == "input"): + out_type = self_val._input_type + + set_api_output_type(out_type) + + if (self.get_output_dtype): + set_api_output_dtype(self_val._get_target_dtype()) + + def do_getters_with_self(self, *, self_val, input_val): + if (self.get_output_type): + out_type = self_val._get_output_type(input_val) + assert out_type is not None, \ + ("`get_output_type` is False but output_type could not " + "be determined from input_arg") + set_api_output_type(out_type) + + if (self.get_output_dtype): + set_api_output_dtype(self_val._get_target_dtype()) + + def do_getters_no_self(self, *, input_val, target_val): + if (self.get_output_type): + assert input_val is not None, \ + "`get_output_type` is False but no input_arg detected" + set_api_output_type( + cuml.common.input_utils.determine_array_type(input_val)) + + if (self.get_output_dtype): + assert target_val is not None, \ + "`get_output_dtype` is False but no target_arg detected" + set_api_output_dtype( + cuml.common.input_utils.determine_array_dtype(target_val)) + + def has_getters_input(self): + return self.get_output_type + + def has_getters_target(self, needs_self): + return False if needs_self else self.get_output_dtype + + +class ReturnDecorator(metaclass=DecoratorMetaClass): + def __init__(self): + super().__init__() + + self.do_autowrap = False + + def __call__(self, func: _F) -> _F: + raise NotImplementedError() + + def _recreate_cm(self, func, args) -> InternalAPIContextBase: + raise NotImplementedError() + + +class ReturnAnyDecorator(ReturnDecorator): + def __call__(self, func: _F) -> _F: + @wraps(func) + def inner(*args, **kwargs): + with self._recreate_cm(func, args): + return func(*args, **kwargs) + + return inner + + def _recreate_cm(self, func, args): + return ReturnAnyCM(func, args) + + +class BaseReturnAnyDecorator(ReturnDecorator, + HasSettersDecoratorMixin, + WithArgsDecoratorMixin): + def __init__(self, + *, + input_arg: str = ..., + target_arg: str = ..., + set_output_type=True, + set_output_dtype=False, + set_n_features_in=True) -> None: + + ReturnDecorator.__init__(self) + HasSettersDecoratorMixin.__init__(self, + set_output_type=set_output_type, + set_output_dtype=set_output_dtype, + set_n_features_in=set_n_features_in) + WithArgsDecoratorMixin.__init__(self, + input_arg=input_arg, + target_arg=target_arg, + needs_self=True, + needs_input=self.has_setters_input(), + needs_target=self.has_setters_target()) + + self.do_autowrap = self.has_setters + + def __call__(self, func: _F) -> _F: + + self.prep_arg_to_use(func) + + @wraps(func) + def inner_with_setters(*args, **kwargs): + + with self._recreate_cm(func, args): + + self_val, input_val, target_val = \ + self.get_arg_values(*args, **kwargs) + + self.do_setters(self_val=self_val, + input_val=input_val, + target_val=target_val) + + return func(*args, **kwargs) + + @wraps(func) + def inner(*args, **kwargs): + + with self._recreate_cm(func, args): + return func(*args, **kwargs) + + # Return the function depending on whether or not we do any automatic + # wrapping + return inner_with_setters if self.has_setters else inner + + def _recreate_cm(self, func, args): + return BaseReturnAnyCM(func, args) + + +class ReturnArrayDecorator(ReturnDecorator, + HasGettersDecoratorMixin, + WithArgsDecoratorMixin): + def __init__(self, + *, + input_arg: str = ..., + target_arg: str = ..., + get_output_type=False, + get_output_dtype=False) -> None: + + ReturnDecorator.__init__(self) + HasGettersDecoratorMixin.__init__(self, + get_output_type=get_output_type, + get_output_dtype=get_output_dtype) + WithArgsDecoratorMixin.__init__( + self, + input_arg=input_arg, + target_arg=target_arg, + needs_self=False, + needs_input=self.has_getters_input(), + needs_target=self.has_getters_target(False)) + + self.do_autowrap = self.has_getters + + def __call__(self, func: _F) -> _F: + + self.prep_arg_to_use(func) + + @wraps(func) + def inner_with_getters(*args, **kwargs): + with self._recreate_cm(func, args) as cm: + + # Get input/target values + _, input_val, target_val = self.get_arg_values(*args, **kwargs) + + # Now execute the getters + self.do_getters_no_self(input_val=input_val, + target_val=target_val) + + # Call the function + ret_val = func(*args, **kwargs) + + return cm.process_return(ret_val) + + @wraps(func) + def inner(*args, **kwargs): + with self._recreate_cm(func, args) as cm: + + ret_val = func(*args, **kwargs) + + return cm.process_return(ret_val) + + return inner_with_getters if self.has_getters else inner + + def _recreate_cm(self, func, args): + + return ReturnArrayCM(func, args) + + +class ReturnSparseArrayDecorator(ReturnArrayDecorator): + def _recreate_cm(self, func, args): + + return ReturnSparseArrayCM(func, args) + + +class BaseReturnArrayDecorator(ReturnDecorator, + HasSettersDecoratorMixin, + HasGettersDecoratorMixin, + WithArgsDecoratorMixin): + def __init__(self, + *, + input_arg: str = ..., + target_arg: str = ..., + get_output_type=True, + get_output_dtype=False, + set_output_type=False, + set_output_dtype=False, + set_n_features_in=False) -> None: + + ReturnDecorator.__init__(self) + HasSettersDecoratorMixin.__init__(self, + set_output_type=set_output_type, + set_output_dtype=set_output_dtype, + set_n_features_in=set_n_features_in) + HasGettersDecoratorMixin.__init__(self, + get_output_type=get_output_type, + get_output_dtype=get_output_dtype) + WithArgsDecoratorMixin.__init__( + self, + input_arg=input_arg, + target_arg=target_arg, + needs_self=True, + needs_input=(self.has_setters_input() or self.has_getters_input()) + and input_arg is not None, + needs_target=self.has_setters_target() + or self.has_getters_target(True)) + + self.do_autowrap = self.has_setters or self.has_getters + + def __call__(self, func: _F) -> _F: + + self.prep_arg_to_use(func) + + @wraps(func) + def inner_set_get(*args, **kwargs): + with self._recreate_cm(func, args) as cm: + + # Get input/target values + self_val, input_val, target_val = \ + self.get_arg_values(*args, **kwargs) + + # Must do the setters first + self.do_setters(self_val=self_val, + input_val=input_val, + target_val=target_val) + + # Now execute the getters + if (self.needs_input): + self.do_getters_with_self(self_val=self_val, + input_val=input_val) + else: + self.do_getters_with_self_no_input(self_val=self_val) + + # Call the function + ret_val = func(*args, **kwargs) + + return cm.process_return(ret_val) + + @wraps(func) + def inner_set(*args, **kwargs): + with self._recreate_cm(func, args) as cm: + + # Get input/target values + self_val, input_val, target_val = \ + self.get_arg_values(*args, **kwargs) + + # Must do the setters first + self.do_setters(self_val=self_val, + input_val=input_val, + target_val=target_val) + + # Call the function + ret_val = func(*args, **kwargs) + + return cm.process_return(ret_val) + + @wraps(func) + def inner_get(*args, **kwargs): + with self._recreate_cm(func, args) as cm: + + # Get input/target values + self_val, input_val, _ = self.get_arg_values(*args, **kwargs) + + # Do the getters + if (self.needs_input): + self.do_getters_with_self(self_val=self_val, + input_val=input_val) + else: + self.do_getters_with_self_no_input(self_val=self_val) + + # Call the function + ret_val = func(*args, **kwargs) + + return cm.process_return(ret_val) + + @wraps(func) + def inner(*args, **kwargs): + with self._recreate_cm(func, args) as cm: + + # Call the function + ret_val = func(*args, **kwargs) + + return cm.process_return(ret_val) + + # Return the function depending on whether or not we do any automatic + # wrapping + if (self.has_getters and self.has_setters): + return inner_set_get + elif (self.has_getters): + return inner_get + elif (self.has_setters): + return inner_set + else: + return inner + + def _recreate_cm(self, func, args): + + return BaseReturnArrayCM(func, args) + + +class BaseReturnSparseArrayDecorator(BaseReturnArrayDecorator): + def _recreate_cm(self, func, args): + + return BaseReturnSparseArrayCM(func, args) + + +class ReturnGenericDecorator(ReturnArrayDecorator): + def _recreate_cm(self, func, args): + + return ReturnGenericCM(func, args) + + +class BaseReturnGenericDecorator(BaseReturnArrayDecorator): + def _recreate_cm(self, func, args): + + return BaseReturnGenericCM(func, args) + + +class BaseReturnArrayFitTransformDecorator(BaseReturnArrayDecorator): + """ + Identical to `BaseReturnArrayDecorator`, however the defaults have been + changed to better suit `fit_transform` methods + """ + def __init__(self, + *, + input_arg: str = ..., + target_arg: str = ..., + get_output_type=True, + get_output_dtype=False, + set_output_type=True, + set_output_dtype=False, + set_n_features_in=True) -> None: + + super().__init__(input_arg=input_arg, + target_arg=target_arg, + get_output_type=get_output_type, + get_output_dtype=get_output_dtype, + set_output_type=set_output_type, + set_output_dtype=set_output_dtype, + set_n_features_in=set_n_features_in) + + +api_return_any = ReturnAnyDecorator +api_base_return_any = BaseReturnAnyDecorator +api_return_array = ReturnArrayDecorator +api_base_return_array = BaseReturnArrayDecorator +api_return_generic = ReturnGenericDecorator +api_base_return_generic = BaseReturnGenericDecorator +api_base_fit_transform = BaseReturnArrayFitTransformDecorator + +api_return_sparse_array = ReturnSparseArrayDecorator +api_base_return_sparse_array = BaseReturnSparseArrayDecorator + +api_return_array_skipall = ReturnArrayDecorator(get_output_dtype=False, + get_output_type=False) + +api_base_return_any_skipall = BaseReturnAnyDecorator(set_output_type=False, + set_n_features_in=False) +api_base_return_array_skipall = BaseReturnArrayDecorator(get_output_type=False) +api_base_return_generic_skipall = BaseReturnGenericDecorator( + get_output_type=False) + + +def api_ignore(func: _F) -> _F: + + func.__dict__["__cuml_is_wrapped"] = True + + return func + + +@contextlib.contextmanager +def exit_internal_api(): + + assert (global_output_type_data.root_cm is not None) + + try: + old_root_cm = global_output_type_data.root_cm + + global_output_type_data.root_cm = None + + # Set the global output type to the previous value to pretend we never + # entered the API + with cuml.using_output_type(old_root_cm.prev_output_type): + + yield + + finally: + global_output_type_data.root_cm = old_root_cm + + +def mirror_args( + wrapped: _F, + assigned=('__doc__', '__annotations__'), + updated=functools.WRAPPER_UPDATES) -> typing.Callable[[_F], _F]: + return wraps(wrapped=wrapped, assigned=assigned, updated=updated) + + +@mirror_args(BaseReturnArrayDecorator) +def api_base_return_autoarray(*args, **kwargs): + def inner(func: _F) -> _F: + # Determine the array return type and choose + return_type = _get_base_return_type(None, func) + + if (return_type == "generic"): + func = api_base_return_generic(*args, **kwargs)(func) + elif (return_type == "array"): + func = api_base_return_array(*args, **kwargs)(func) + elif (return_type == "sparsearray"): + func = api_base_return_sparse_array(*args, **kwargs)(func) + elif (return_type == "base"): + assert False, \ + ("Must use api_base_return_autoarray decorator on function " + "that returns some array") + + return func + + return inner diff --git a/python/cuml/internals/base_helpers.py b/python/cuml/internals/base_helpers.py new file mode 100644 index 0000000000..151d7add46 --- /dev/null +++ b/python/cuml/internals/base_helpers.py @@ -0,0 +1,145 @@ +# +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import typing + +import cuml +import cuml.internals +import cuml.common + + +def _process_generic(gen_type): + + # Check if the type is not a generic. If not, must return "generic" if + # subtype is CumlArray otherwise None + if (not isinstance(gen_type, typing._GenericAlias)): + if (issubclass(gen_type, cuml.common.CumlArray)): + return "generic" + + # We don't handle SparseCumlArray at this time + if (issubclass(gen_type, cuml.common.SparseCumlArray)): + raise NotImplementedError( + "Generic return types with SparseCumlArray are not supported " + "at this time") + + # Otherwise None (keep processing) + return None + + # Its a generic type by this point. Support Union, Tuple, Dict and List + supported_gen_types = [ + tuple, + dict, + list, + typing.Union, + ] + + if (gen_type.__origin__ in supported_gen_types): + # Check for a CumlArray type in the args + for arg in gen_type.__args__: + inner_type = _process_generic(arg) + + if (inner_type is not None): + return inner_type + else: + raise NotImplementedError("Unknow generic type: {}".format(gen_type)) + + return None + + +def _get_base_return_type(class_name, attr): + + if (not hasattr(attr, "__annotations__") + or "return" not in attr.__annotations__): + return None + + try: + type_hints = typing.get_type_hints(attr) + + if ("return" in type_hints): + + ret_type = type_hints["return"] + + is_generic = isinstance(ret_type, typing._GenericAlias) + + if (is_generic): + return _process_generic(ret_type) + elif (issubclass(ret_type, cuml.common.CumlArray)): + return "array" + elif (issubclass(ret_type, cuml.common.SparseCumlArray)): + return "sparsearray" + elif (issubclass(ret_type, cuml.Base)): + return "base" + else: + return None + except NameError: + # A NameError is raised if the return type is the same as the + # type being defined (which is incomplete). Check that here and + # return base if the name matches + if (attr.__annotations__["return"] == class_name): + return "base" + except Exception: + assert False, "Shouldnt get here" + return None + + return None + + +def _wrap_attribute(class_name: str, + attribute_name: str, + attribute, + **kwargs): + + # Skip items marked with autowrap_ignore + if ("__cuml_is_wrapped" in attribute.__dict__ + and attribute.__dict__["__cuml_is_wrapped"]): + return attribute + + return_type = _get_base_return_type(class_name, attribute) + + if (return_type == "generic"): + attribute = cuml.internals.api_base_return_generic(**kwargs)(attribute) + elif (return_type == "array"): + attribute = cuml.internals.api_base_return_array(**kwargs)(attribute) + elif (return_type == "sparsearray"): + attribute = cuml.internals.api_base_return_sparse_array( + **kwargs)(attribute) + elif (return_type == "base"): + attribute = cuml.internals.api_base_return_any(**kwargs)(attribute) + elif (not attribute_name.startswith("_")): + # Only replace public functions with return any + attribute = cuml.internals.api_return_any()(attribute) + + return attribute + + +class BaseMetaClass(type): + def __new__(cls, classname, bases, classDict): + + for attributeName, attribute in classDict.items(): + # Must be a function + if callable(attribute): + classDict[attributeName] = _wrap_attribute( + classname, attributeName, attribute) + elif isinstance(attribute, property): + # Need to wrap the getter if it exists + if (hasattr(attribute, "fget") and attribute.fget is not None): + classDict[attributeName] = attribute.getter( + _wrap_attribute(classname, + attributeName, + attribute.fget, + input_arg=None)) + + return type.__new__(cls, classname, bases, classDict) diff --git a/python/cuml/linear_model/base_mg.pyx b/python/cuml/linear_model/base_mg.pyx index 8642a40409..0eb38051db 100644 --- a/python/cuml/linear_model/base_mg.pyx +++ b/python/cuml/linear_model/base_mg.pyx @@ -25,6 +25,7 @@ import rmm from libc.stdint cimport uintptr_t from cython.operator cimport dereference as deref +import cuml.internals from cuml.common.base import Base from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t @@ -35,6 +36,7 @@ from cuml.decomposition.utils cimport * class MGFitMixin(object): + @cuml.internals.api_base_return_any_skipall def fit(self, input_data, n_rows, n_cols, partsToSizes, rank): """ Fit function for MNMG linear regression classes @@ -69,9 +71,9 @@ class MGFitMixin(object): check_dtype=self.dtype) y_arys.append(y_m) - self._coef_ = CumlArray.zeros(self.n_cols, - dtype=self.dtype) - cdef uintptr_t coef_ptr = self._coef_.ptr + self.coef_ = CumlArray.zeros(self.n_cols, + dtype=self.dtype) + cdef uintptr_t coef_ptr = self.coef_.ptr coef_ptr_arg = coef_ptr cdef uintptr_t rank_to_sizes = opg.build_rank_size_pair(partsToSizes, diff --git a/python/cuml/linear_model/elastic_net.pyx b/python/cuml/linear_model/elastic_net.pyx index 49eddd4c42..fb942bc814 100644 --- a/python/cuml/linear_model/elastic_net.pyx +++ b/python/cuml/linear_model/elastic_net.pyx @@ -19,6 +19,8 @@ from cuml.solvers import CD from cuml.common.base import Base, RegressorMixin from cuml.common.doc_utils import generate_docstring +from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor class ElasticNet(Base, RegressorMixin): @@ -141,6 +143,8 @@ class ElasticNet(Base, RegressorMixin): `_. """ + coef_ = CumlArrayDescriptor() + def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True, normalize=False, max_iter=1000, tol=1e-3, selection='cyclic', handle=None, output_type=None, verbose=False): @@ -171,7 +175,6 @@ class ElasticNet(Base, RegressorMixin): self.alpha = alpha self.l1_ratio = l1_ratio - self._coef_ = None self.intercept_ = None self.fit_intercept = fit_intercept self.normalize = normalize @@ -206,12 +209,11 @@ class ElasticNet(Base, RegressorMixin): raise ValueError(msg.format(l1_ratio)) @generate_docstring() - def fit(self, X, y, convert_dtype=True): + def fit(self, X, y, convert_dtype=True) -> "ElasticNet": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X, n_features=X) self.solver_model.fit(X, y, convert_dtype=convert_dtype) return self @@ -220,7 +222,7 @@ class ElasticNet(Base, RegressorMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=True): + def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts `y` values for `X`. diff --git a/python/cuml/linear_model/lasso.pyx b/python/cuml/linear_model/lasso.pyx index 24eccb3d27..21013195e3 100644 --- a/python/cuml/linear_model/lasso.pyx +++ b/python/cuml/linear_model/lasso.pyx @@ -16,6 +16,8 @@ # distutils: language = c++ +import cuml.internals +from cuml.common.array import CumlArray from cuml.solvers import CD from cuml.common.base import Base, RegressorMixin from cuml.common.doc_utils import generate_docstring @@ -175,12 +177,11 @@ class Lasso(Base, RegressorMixin): raise ValueError(msg.format(alpha)) @generate_docstring() - def fit(self, X, y, convert_dtype=True): + def fit(self, X, y, convert_dtype=True) -> "Lasso": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X, n_features=X) self.solver_model.fit(X, y, convert_dtype=convert_dtype) return self @@ -189,7 +190,8 @@ class Lasso(Base, RegressorMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=True): + @cuml.internals.api_base_return_array_skipall + def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts the y for X. diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index 2c22249e45..54b5c0be7a 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -29,6 +29,7 @@ from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import Base, RegressorMixin from cuml.common.doc_utils import generate_docstring from cuml.raft.common.handle cimport handle_t @@ -197,6 +198,9 @@ class LinearRegression(Base, RegressorMixin): """ + coef_ = CumlArrayDescriptor() + intercept_ = CumlArrayDescriptor() + def __init__(self, algorithm='eig', fit_intercept=True, normalize=False, handle=None, verbose=False, output_type=None): super(LinearRegression, self).__init__(handle=handle, @@ -204,8 +208,8 @@ class LinearRegression(Base, RegressorMixin): output_type=output_type) # internal array attributes - self._coef_ = None # accessed via estimator.coef_ - self._intercept_ = None # accessed via estimator.intercept_ + self.coef_ = None + self.intercept_ = None self.fit_intercept = fit_intercept self.normalize = normalize @@ -225,13 +229,11 @@ class LinearRegression(Base, RegressorMixin): }[algorithm] @generate_docstring() - def fit(self, X, y, convert_dtype=True): + def fit(self, X, y, convert_dtype=True) -> "LinearRegression": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X, n_features=X) - cdef uintptr_t X_ptr, y_ptr X_m, n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) @@ -258,8 +260,8 @@ class LinearRegression(Base, RegressorMixin): "column currently.", UserWarning) self.algo = 0 - self._coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) - cdef uintptr_t coef_ptr = self._coef_.ptr + self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) + cdef uintptr_t coef_ptr = self.coef_.ptr cdef float c_intercept1 cdef double c_intercept2 @@ -304,14 +306,11 @@ class LinearRegression(Base, RegressorMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=True): + def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts `y` values for `X`. """ - - out_type = self._get_output_type(X) - cdef uintptr_t X_ptr X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, @@ -320,7 +319,7 @@ class LinearRegression(Base, RegressorMixin): check_cols=self.n_cols) X_ptr = X_m.ptr - cdef uintptr_t coef_ptr = self._coef_.ptr + cdef uintptr_t coef_ptr = self.coef_.ptr preds = CumlArray.zeros(n_rows, dtype=dtype) cdef uintptr_t preds_ptr = preds.ptr @@ -348,7 +347,7 @@ class LinearRegression(Base, RegressorMixin): del(X_m) - return preds.to_output(out_type) + return preds def get_param_names(self): return super().get_param_names() + \ diff --git a/python/cuml/linear_model/linear_regression_mg.pyx b/python/cuml/linear_model/linear_regression_mg.pyx index 24bc3ab332..bddafc3eb5 100644 --- a/python/cuml/linear_model/linear_regression_mg.pyx +++ b/python/cuml/linear_model/linear_regression_mg.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t, uint32_t, uint64_t from cython.operator cimport dereference as deref +import cuml.internals from cuml.common.base import Base from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t @@ -68,6 +69,7 @@ class LinearRegressionMG(MGFitMixin, LinearRegression): def __init__(self, **kwargs): super(LinearRegressionMG, self).__init__(**kwargs) + @cuml.internals.api_base_return_any_skipall def _fit(self, X, y, coef_ptr, input_desc): cdef float float_intercept diff --git a/python/cuml/linear_model/logistic_regression.pyx b/python/cuml/linear_model/logistic_regression.pyx index d367eadd65..5f10cd18b5 100644 --- a/python/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/linear_model/logistic_regression.pyx @@ -19,12 +19,15 @@ import cupy as cp import pprint +import cuml.internals from cuml.solvers import QN from cuml.common.base import Base, ClassifierMixin +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.array import CumlArray from cuml.common.doc_utils import generate_docstring import cuml.common.logger as logger -from cuml.common import input_to_cuml_array, with_cupy_rmm +from cuml.common import input_to_cuml_array +from cuml.common import using_output_type supported_penalties = ["l1", "l2", "none", "elasticnet"] @@ -170,6 +173,8 @@ class LogisticRegression(Base, ClassifierMixin): `_. """ + classes_ = CumlArrayDescriptor() + def __init__( self, penalty="l2", @@ -257,22 +262,19 @@ class LogisticRegression(Base, ClassifierMixin): self.verb_prefix = "" @generate_docstring() - @with_cupy_rmm - def fit(self, X, y, convert_dtype=True): + @cuml.internals.api_base_return_any(set_output_dtype=True) + def fit(self, X, y, convert_dtype=True) -> "LogisticRegression": """ Fit the model with X and y. """ - self.solver_model._set_base_attributes(target_dtype=y) - self._set_base_attributes(output_type=X, n_features=X) - # Converting y to device array here to use `unique` function - # since calling input_to_dev_array again in QN has no cost + # since calling input_to_cuml_array again in QN has no cost # Not needed to check dtype since qn class checks it already y_m, _, _, _ = input_to_cuml_array(y) - self._classes_ = CumlArray(cp.unique(y_m)) - self._num_classes = len(self._classes_) + self.classes_ = cp.unique(y_m) + self._num_classes = len(self.classes_) if self._num_classes > 2: loss = "softmax" @@ -296,14 +298,15 @@ class LogisticRegression(Base, ClassifierMixin): ) if logger.should_log_for(logger.level_trace): - logger.trace(self.verb_prefix + "Coefficients: " + - str(self._coef_.to_output("cupy"))) - if self.fit_intercept: - logger.trace( - self.verb_prefix - + "Intercept: " - + str(self._intercept_.to_output("cupy")) - ) + with using_output_type("cupy"): + logger.trace(self.verb_prefix + "Coefficients: " + + str(self.solver_model.coef_)) + if self.fit_intercept: + logger.trace( + self.verb_prefix + + "Intercept: " + + str(self.solver_model.intercept_) + ) return self @@ -311,7 +314,7 @@ class LogisticRegression(Base, ClassifierMixin): 'type': 'dense', 'description': 'Confidence score', 'shape': '(n_samples, n_classes)'}) - def decision_function(self, X, convert_dtype=False): + def decision_function(self, X, convert_dtype=False) -> CumlArray: """ Gives confidence score for X @@ -319,13 +322,14 @@ class LogisticRegression(Base, ClassifierMixin): return self.solver_model._decision_function( X, convert_dtype=convert_dtype - ).to_output(output_type=self._get_output_type(X)) + ) @generate_docstring(return_values={'name': 'preds', 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=True): + @cuml.internals.api_base_return_array(get_output_dtype=True) + def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts the y for X. @@ -337,8 +341,7 @@ class LogisticRegression(Base, ClassifierMixin): 'description': 'Predicted class \ probabilities', 'shape': '(n_samples, n_classes)'}) - @with_cupy_rmm - def predict_proba(self, X, convert_dtype=True): + def predict_proba(self, X, convert_dtype=True) -> CumlArray: """ Predicts the class probabilities for each class in X @@ -354,7 +357,7 @@ class LogisticRegression(Base, ClassifierMixin): 'description': 'Logaright of predicted \ class probabilities', 'shape': '(n_samples, n_classes)'}) - def predict_log_proba(self, X, convert_dtype=True): + def predict_log_proba(self, X, convert_dtype=True) -> CumlArray: """ Predicts the log class probabilities for each class in X @@ -365,9 +368,10 @@ class LogisticRegression(Base, ClassifierMixin): log_proba=True ) - def _predict_proba_impl(self, X, convert_dtype=False, log_proba=False): - out_type = self._get_output_type(X) - + def _predict_proba_impl(self, + X, + convert_dtype=False, + log_proba=False) -> CumlArray: # TODO: # We currently need to grab the dtype and ncols attributes via the # qn solver due to https://github.com/rapidsai/cuml/issues/2404 @@ -397,8 +401,7 @@ class LogisticRegression(Base, ClassifierMixin): if log_proba: proba = cp.log(proba) - proba = CumlArray(proba) - return proba.to_output(out_type) + return proba def get_param_names(self): return super().get_param_names() + [ diff --git a/python/cuml/linear_model/mbsgd_classifier.pyx b/python/cuml/linear_model/mbsgd_classifier.pyx index f70b13b798..90f923c15d 100644 --- a/python/cuml/linear_model/mbsgd_classifier.pyx +++ b/python/cuml/linear_model/mbsgd_classifier.pyx @@ -15,6 +15,9 @@ # # distutils: language = c++ + +import cuml.internals +from cuml.common.array import CumlArray from cuml.common.base import Base, ClassifierMixin from cuml.common.doc_utils import generate_docstring from cuml.solvers import SGD @@ -175,12 +178,11 @@ class MBSGDClassifier(Base, ClassifierMixin): self.solver_model = SGD(**self.get_params()) @generate_docstring() - def fit(self, X, y, convert_dtype=True): + def fit(self, X, y, convert_dtype=True) -> "MBSGDClassifier": """ Fit the model with X and y. """ - self._set_base_attributes(n_features=X) self.solver_model._estimator_type = self._estimator_type self.solver_model.fit(X, y, convert_dtype=convert_dtype) return self @@ -189,7 +191,8 @@ class MBSGDClassifier(Base, ClassifierMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=False): + @cuml.internals.api_base_return_array_skipall + def predict(self, X, convert_dtype=False) -> CumlArray: """ Predicts the y for X. diff --git a/python/cuml/linear_model/mbsgd_regressor.pyx b/python/cuml/linear_model/mbsgd_regressor.pyx index bee609cbf7..68a8ac6b80 100644 --- a/python/cuml/linear_model/mbsgd_regressor.pyx +++ b/python/cuml/linear_model/mbsgd_regressor.pyx @@ -15,6 +15,9 @@ # # distutils: language = c++ + +import cuml.internals +from cuml.common.array import CumlArray from cuml.common.base import Base, RegressorMixin from cuml.common.doc_utils import generate_docstring from cuml.solvers import SGD @@ -171,12 +174,11 @@ class MBSGDRegressor(Base, RegressorMixin): self.solver_model = SGD(**self.get_params()) @generate_docstring() - def fit(self, X, y, convert_dtype=True): + def fit(self, X, y, convert_dtype=True) -> "MBSGDRegressor": """ Fit the model with X and y. """ - self._set_base_attributes(n_features=X) self.solver_model.fit(X, y, convert_dtype=convert_dtype) return self @@ -184,7 +186,8 @@ class MBSGDRegressor(Base, RegressorMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=False): + @cuml.internals.api_base_return_array_skipall + def predict(self, X, convert_dtype=False) -> CumlArray: """ Predicts the y for X. diff --git a/python/cuml/linear_model/ridge.pyx b/python/cuml/linear_model/ridge.pyx index f6cc24c043..9634b05f75 100644 --- a/python/cuml/linear_model/ridge.pyx +++ b/python/cuml/linear_model/ridge.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import Base, RegressorMixin from cuml.common.array import CumlArray from cuml.common.doc_utils import generate_docstring @@ -207,6 +208,9 @@ class Ridge(Base, RegressorMixin): `_. """ + coef_ = CumlArrayDescriptor() + intercept_ = CumlArrayDescriptor() + def __init__(self, alpha=1.0, solver='eig', fit_intercept=True, normalize=False, handle=None, output_type=None, verbose=False): @@ -229,8 +233,8 @@ class Ridge(Base, RegressorMixin): output_type=output_type) # internal array attributes - self._coef_ = None # accessed via estimator.coef_ - self._intercept_ = None # accessed via estimator.intercept_ + self.coef_ = None + self.intercept_ = None self.alpha = alpha self.fit_intercept = fit_intercept @@ -257,12 +261,11 @@ class Ridge(Base, RegressorMixin): }[algorithm] @generate_docstring() - def fit(self, X, y, convert_dtype=True): + def fit(self, X, y, convert_dtype=True) -> "Ridge": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X, n_features=X) cdef uintptr_t X_ptr, y_ptr X_m, n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) @@ -291,8 +294,8 @@ class Ridge(Base, RegressorMixin): self.n_alpha = 1 - self._coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) - cdef uintptr_t coef_ptr = self._coef_.ptr + self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) + cdef uintptr_t coef_ptr = self.coef_.ptr cdef float c_intercept1 cdef double c_intercept2 @@ -345,13 +348,11 @@ class Ridge(Base, RegressorMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=True): + def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts the y for X. """ - out_type = self._get_output_type(X) - cdef uintptr_t X_ptr X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, @@ -360,7 +361,7 @@ class Ridge(Base, RegressorMixin): check_cols=self.n_cols) X_ptr = X_m.ptr - cdef uintptr_t coef_ptr = self._coef_.ptr + cdef uintptr_t coef_ptr = self.coef_.ptr preds = CumlArray.zeros(n_rows, dtype=dtype) cdef uintptr_t preds_ptr = preds.ptr @@ -388,7 +389,7 @@ class Ridge(Base, RegressorMixin): del(X_m) - return preds.to_output(out_type) + return preds def get_param_names(self): return super().get_param_names() + \ diff --git a/python/cuml/linear_model/ridge_mg.pyx b/python/cuml/linear_model/ridge_mg.pyx index a143607baf..fa7f85c30d 100644 --- a/python/cuml/linear_model/ridge_mg.pyx +++ b/python/cuml/linear_model/ridge_mg.pyx @@ -27,6 +27,7 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t, uint32_t, uint64_t from cython.operator cimport dereference as deref +import cuml.internals from cuml.common.base import Base from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t @@ -71,6 +72,7 @@ class RidgeMG(MGFitMixin, Ridge): def __init__(self, **kwargs): super(RidgeMG, self).__init__(**kwargs) + @cuml.internals.api_base_return_any_skipall def _fit(self, X, y, coef_ptr, input_desc): cdef float float_intercept diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 15e91118c5..c9fd1e149d 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -20,13 +20,14 @@ # cython: wraparound = False import cudf -import cuml import ctypes import numpy as np import inspect import pandas as pd import warnings +import cuml.internals +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t import cuml.common.logger as logger @@ -192,6 +193,9 @@ class TSNE(Base): (https://arxiv.org/abs/1807.11824). """ + + embedding_ = CumlArrayDescriptor() + def __init__(self, n_components=2, perplexity=30.0, @@ -322,12 +326,11 @@ class TSNE(Base): self.post_learning_rate = learning_rate * 2 @generate_docstring(convert_dtype_cast='np.float32') - def fit(self, X, convert_dtype=True): + def fit(self, X, convert_dtype=True) -> "TSNE": """ Fit X into an embedded space. """ - self._set_base_attributes(n_features=X) cdef int n, p cdef handle_t* handle_ = self.handle.getHandle() if handle_ == NULL: @@ -415,13 +418,13 @@ class TSNE(Base): (self.method == 'barnes_hut')) # Clean up memory - self._embedding_ = Y + self.embedding_ = Y return self def __del__(self): - if hasattr(self, '_embedding_'): - del self._embedding_ - self._embedding_ = None + if hasattr(self, 'embedding_'): + del self.embedding_ + self.embedding_ = None @generate_docstring(convert_dtype_cast='np.float32', return_values={'name': 'X_new', @@ -430,17 +433,22 @@ class TSNE(Base): training data in \ low-dimensional space.', 'shape': '(n_samples, n_components)'}) - def fit_transform(self, X, convert_dtype=True): + @cuml.internals.api_base_return_array_skipall + def fit_transform(self, X, convert_dtype=True) -> CumlArray: """ Fit X into an embedded space and return that transformed output. + """ + return self.fit(X, convert_dtype=convert_dtype)._transform(X) - + def _transform(self, X) -> CumlArray: + """ + Internal transform function to allow base wrappers default + functionality to work """ - self.fit(X, convert_dtype=convert_dtype) - out_type = self._get_output_type(X) - data = self._embedding_.to_output(out_type) - del self._embedding_ + data = self.embedding_ + + del self.embedding_ return data diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 41e83534b5..5d54058ebb 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -16,8 +16,8 @@ # distutils: language = c++ +import typing import cudf -import cuml import ctypes import numpy as np import pandas as pd @@ -32,13 +32,16 @@ import numba.cuda as cuda from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix,\ coo_matrix as cp_coo_matrix, csc_matrix as cp_csc_matrix +import cuml.internals +from cuml.common import using_output_type from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t from cuml.common.doc_utils import generate_docstring from cuml.common.input_utils import input_to_cuml_array -from cuml.common.memory_utils import with_cupy_rmm +from cuml.common.memory_utils import using_output_type from cuml.common.import_utils import has_scipy from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor import rmm @@ -296,6 +299,9 @@ class UMAP(Base): `_ """ + X_m = CumlArrayDescriptor() + embedding_ = CumlArrayDescriptor() + def __init__(self, n_neighbors=15, n_components=2, @@ -380,8 +386,8 @@ class UMAP(Base): self.optim_batch_size = optim_batch_size self.callback = callback # prevent callback destruction - self._X_m = None # accessed via X_m - self._embedding_ = None # accessed via embedding_ + self.X_m = None + self.embedding_ = None self.validate_hyperparams() @@ -457,8 +463,13 @@ class UMAP(Base): params, covar = curve_fit(curve, xv, yv) return params[0], params[1] - @with_cupy_rmm - def _extract_knn_graph(self, knn_graph, convert_dtype=True): + @cuml.internals.api_base_return_generic_skipall + def _extract_knn_graph( + self, + knn_graph, + convert_dtype=True + ) -> typing.Tuple[typing.Tuple[CumlArray, typing.Any], + typing.Tuple[CumlArray, typing.Any]]: if has_scipy(): from scipy.sparse import csr_matrix, coo_matrix, csc_matrix else: @@ -509,9 +520,8 @@ class UMAP(Base): @generate_docstring(convert_dtype_cast='np.float32', skip_parameters_heading=True) - @with_cupy_rmm def fit(self, X, y=None, convert_dtype=True, - knn_graph=None): + knn_graph=None) -> "UMAP": """ Fit X into an embedded space. @@ -544,7 +554,7 @@ class UMAP(Base): raise ValueError("Cannot provide a KNN graph when in \ semi-supervised mode with categorical target_metric for now.") - self._X_m, self.n_rows, self.n_dims, dtype = \ + self.X_m, self.n_rows, self.n_dims, dtype = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype @@ -554,8 +564,6 @@ class UMAP(Base): raise ValueError("There needs to be more than 1 sample to " "build nearest the neighbors graph") - self._set_base_attributes(output_type=X, n_features=X) - (knn_indices_m, knn_indices_ctype), (knn_dists_m, knn_dists_ctype) =\ self._extract_knn_graph(knn_graph, convert_dtype) @@ -564,18 +572,19 @@ class UMAP(Base): self.n_neighbors = min(self.n_rows, self.n_neighbors) - self._embedding_ = CumlArray.zeros((self.n_rows, + self.embedding_ = CumlArray.zeros((self.n_rows, self.n_components), - order="C", dtype=np.float32) + order="C", dtype=np.float32) if self.hash_input: - self.input_hash = joblib.hash(self._X_m.to_output('numpy')) + with using_output_type("numpy"): + self.input_hash = joblib.hash(self.X_m) cdef handle_t * handle_ = \ self.handle.getHandle() - cdef uintptr_t x_raw = self._X_m.ptr - cdef uintptr_t embed_raw = self._embedding_.ptr + cdef uintptr_t x_raw = self.X_m.ptr + cdef uintptr_t embed_raw = self.embedding_.ptr cdef UMAPParams* umap_params = \ UMAP._build_umap_params(self) @@ -622,8 +631,9 @@ class UMAP(Base): data in \ low-dimensional space.', 'shape': '(n_samples, n_components)'}) + @cuml.internals.api_base_fit_transform() def fit_transform(self, X, y=None, convert_dtype=True, - knn_graph=None): + knn_graph=None) -> CumlArray: """ Fit X into an embedded space and return that transformed output. @@ -656,10 +666,9 @@ class UMAP(Base): CSR/COO preferred other formats will go through conversion to CSR """ - self.fit(X, y, convert_dtype=convert_dtype, - knn_graph=knn_graph) - out_type = self._get_output_type(X) - return self._embedding_.to_output(out_type) + self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph) + + return self.embedding_ @generate_docstring(convert_dtype_cast='np.float32', skip_parameters_heading=True, @@ -669,9 +678,7 @@ class UMAP(Base): data in \ low-dimensional space.', 'shape': '(n_samples, n_components)'}) - @with_cupy_rmm - def transform(self, X, convert_dtype=True, - knn_graph=None): + def transform(self, X, convert_dtype=True, knn_graph=None) -> CumlArray: """ Transform X into the existing embedded space and return that transformed output. @@ -723,16 +730,14 @@ class UMAP(Base): raise ValueError("n_features of X must match n_features of " "training data") - out_type = self._get_output_type(X) - if self.hash_input and joblib.hash(X_m.to_output('numpy')) == \ self.input_hash: - ret = self._embedding_.to_output(out_type) + del X_m - return ret + return self.embedding_ embedding = CumlArray.zeros((X_m.shape[0], - self.n_components), + self.n_components), order="C", dtype=np.float32) cdef uintptr_t xformed_ptr = embedding.ptr @@ -745,8 +750,8 @@ class UMAP(Base): cdef handle_t * handle_ = \ self.handle.getHandle() - cdef uintptr_t orig_x_raw = self._X_m.ptr - cdef uintptr_t embed_ptr = self._embedding_.ptr + cdef uintptr_t orig_x_raw = self.X_m.ptr + cdef uintptr_t embed_ptr = self.embedding_.ptr cdef UMAPParams* umap_params = \ UMAP._build_umap_params(self) @@ -767,9 +772,8 @@ class UMAP(Base): UMAP._destroy_umap_params(umap_params) - ret = embedding.to_output(out_type) del X_m - return ret + return embedding def get_param_names(self): return super().get_param_names() + [ diff --git a/python/cuml/metrics/_classification.py b/python/cuml/metrics/_classification.py index ad24926590..7b388f1955 100644 --- a/python/cuml/metrics/_classification.py +++ b/python/cuml/metrics/_classification.py @@ -16,12 +16,16 @@ import cupy as cp import numpy as np -from cuml.common.memory_utils import with_cupy_rmm -from cuml.common import input_to_cuml_array +import cuml.internals +from cuml.common.input_utils import input_to_cupy_array -@with_cupy_rmm -def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None): +@cuml.internals.api_return_any() +def log_loss(y_true, + y_pred, + eps=1e-15, + normalize=True, + sample_weight=None) -> float: """ Log loss, aka logistic loss or cross-entropy loss. This is the loss function used in (multinomial) logistic regression and extensions of it such as neural networks, defined as the negative @@ -66,21 +70,19 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None): """ y_true, n_rows, n_cols, ytype = \ - input_to_cuml_array(y_true, check_dtype=[np.int32, np.int64, + input_to_cupy_array(y_true, check_dtype=[np.int32, np.int64, np.float32, np.float64]) - y_true = y_true.to_output('cupy') if y_true.dtype.kind == 'f' and np.any(y_true != y_true.astype(int)): raise ValueError("'y_true' can only have integer values") if y_true.min() < 0: raise ValueError("'y_true' cannot have negative values") y_pred, _, _, _ = \ - input_to_cuml_array(y_pred, check_dtype=[np.int32, np.int64, + input_to_cupy_array(y_pred, check_dtype=[np.int32, np.int64, np.float32, np.float64], check_rows=n_rows) - y_pred = y_pred.to_output('cupy') y_true_max = y_true.max() if (y_pred.ndim == 1 and y_true_max > 1) \ or (y_pred.ndim > 1 and y_pred.shape[1] <= y_true_max): diff --git a/python/cuml/metrics/_ranking.py b/python/cuml/metrics/_ranking.py index 37263ad627..031a234e26 100644 --- a/python/cuml/metrics/_ranking.py +++ b/python/cuml/metrics/_ranking.py @@ -14,15 +14,18 @@ # limitations under the License. # +import typing import cupy as cp import numpy as np -from cuml.common.memory_utils import with_cupy_rmm -from cuml.common import input_to_cuml_array +import cuml.internals +from cuml.common.array import CumlArray +from cuml.common.input_utils import input_to_cupy_array import math -@with_cupy_rmm -def precision_recall_curve(y_true, probs_pred): +@cuml.internals.api_return_generic(get_output_type=True) +def precision_recall_curve( + y_true, probs_pred) -> typing.Tuple[CumlArray, CumlArray, CumlArray]: """ Compute precision-recall pairs for different probability thresholds @@ -86,17 +89,14 @@ def precision_recall_curve(y_true, probs_pred): """ y_true, n_rows, n_cols, ytype = \ - input_to_cuml_array(y_true, check_dtype=[np.int32, np.int64, + input_to_cupy_array(y_true, check_dtype=[np.int32, np.int64, np.float32, np.float64]) y_score, _, _, _ = \ - input_to_cuml_array(probs_pred, check_dtype=[np.int32, np.int64, + input_to_cupy_array(probs_pred, check_dtype=[np.int32, np.int64, np.float32, np.float64], check_rows=n_rows, check_cols=n_cols) - y_true = y_true.to_output('cupy') - y_score = y_score.to_output('cupy') - if cp.any(y_true) == 0: raise ValueError("precision_recall_curve cannot be used when " "y_true is all zero.") @@ -116,7 +116,7 @@ def precision_recall_curve(y_true, probs_pred): return precision, recall, thresholds -@with_cupy_rmm +@cuml.internals.api_return_any() def roc_auc_score(y_true, y_score): """ Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) @@ -151,11 +151,11 @@ def roc_auc_score(y_true, y_score): """ y_true, n_rows, n_cols, ytype = \ - input_to_cuml_array(y_true, check_dtype=[np.int32, np.int64, + input_to_cupy_array(y_true, check_dtype=[np.int32, np.int64, np.float32, np.float64]) y_score, _, _, _ = \ - input_to_cuml_array(y_score, check_dtype=[np.int32, np.int64, + input_to_cupy_array(y_score, check_dtype=[np.int32, np.int64, np.float32, np.float64], check_rows=n_rows, check_cols=n_cols) return _binary_roc_auc_score(y_true, y_score) @@ -191,8 +191,6 @@ def _binary_clf_curve(y_true, y_score): def _binary_roc_auc_score(y_true, y_score): """Compute binary roc_auc_score using cupy""" - y_true = y_true.to_output('cupy') - y_score = y_score.to_output('cupy') if cp.unique(y_true).shape[0] == 1: raise ValueError("roc_auc_score cannot be used when " diff --git a/python/cuml/metrics/accuracy.pyx b/python/cuml/metrics/accuracy.pyx index 85c43b9fa8..f002e888d0 100644 --- a/python/cuml/metrics/accuracy.pyx +++ b/python/cuml/metrics/accuracy.pyx @@ -22,8 +22,10 @@ from libc.stdint cimport uintptr_t import cudf +import cuml.internals + +from cuml.common.input_utils import input_to_cuml_array from cuml.raft.common.handle cimport handle_t -from cuml.common import input_to_dev_array from cuml.raft.common.handle import Handle cimport cuml.common.cuda @@ -35,6 +37,7 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": int n) except + +@cuml.internals.api_return_any() def accuracy_score(ground_truth, predictions, handle=None, convert_dtype=True): """ Calcuates the accuracy score of a classification model. @@ -58,15 +61,19 @@ def accuracy_score(ground_truth, predictions, handle=None, convert_dtype=True): handle.getHandle() cdef uintptr_t preds_ptr, ground_truth_ptr - preds_m, preds_ptr, n_rows, _, _ = \ - input_to_dev_array(predictions, - convert_to_dtype=np.int32 - if convert_dtype else None) - - ground_truth_m, ground_truth_ptr, _, _, ground_truth_dtype=\ - input_to_dev_array(ground_truth, - convert_to_dtype=np.int32 - if convert_dtype else None) + preds_m, n_rows, _, _ = \ + input_to_cuml_array(predictions, + convert_to_dtype=np.int32 + if convert_dtype else None) + + preds_ptr = preds_m.ptr + + ground_truth_m, _, _, ground_truth_dtype=\ + input_to_cuml_array(ground_truth, + convert_to_dtype=np.int32 + if convert_dtype else None) + + ground_truth_ptr = ground_truth_m.ptr acc = accuracy_score_py(handle_[0], preds_ptr, diff --git a/python/cuml/metrics/cluster/adjusted_rand_index.pyx b/python/cuml/metrics/cluster/adjusted_rand_index.pyx index f5cb8b3f61..4c656846c8 100644 --- a/python/cuml/metrics/cluster/adjusted_rand_index.pyx +++ b/python/cuml/metrics/cluster/adjusted_rand_index.pyx @@ -21,6 +21,7 @@ import warnings from libc.stdint cimport uintptr_t +import cuml.internals from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array from cuml.raft.common.handle import Handle @@ -34,8 +35,9 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": int n) +@cuml.internals.api_return_any() def adjusted_rand_score(labels_true, labels_pred, handle=None, - convert_dtype=True): + convert_dtype=True) -> float: """ Adjusted_rand_score is a clustering similarity metric based on the Rand index and is corrected for chance. diff --git a/python/cuml/metrics/cluster/completeness_score.pyx b/python/cuml/metrics/cluster/completeness_score.pyx index 37f76fffda..c6cfaa39d0 100644 --- a/python/cuml/metrics/cluster/completeness_score.pyx +++ b/python/cuml/metrics/cluster/completeness_score.pyx @@ -16,6 +16,7 @@ # distutils: language = c++ +import cuml.internals from cuml.raft.common.handle cimport handle_t from libc.stdint cimport uintptr_t from cuml.metrics.cluster.utils import prepare_cluster_metric_inputs @@ -29,7 +30,8 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": const int upper_class_range) except + -def cython_completeness_score(labels_true, labels_pred, handle=None): +@cuml.internals.api_return_any() +def cython_completeness_score(labels_true, labels_pred, handle=None) -> float: """ Completeness metric of a cluster labeling given a ground truth. diff --git a/python/cuml/metrics/cluster/entropy.pyx b/python/cuml/metrics/cluster/entropy.pyx index c7c3e20452..a1c6360122 100644 --- a/python/cuml/metrics/cluster/entropy.pyx +++ b/python/cuml/metrics/cluster/entropy.pyx @@ -16,14 +16,17 @@ # distutils: language = c++ import math +import typing import numpy as np import cupy as cp from libc.stdint cimport uintptr_t +import cuml.internals from cuml.raft.common.handle cimport handle_t -from cuml.common import with_cupy_rmm, input_to_cuml_array +from cuml.common import CumlArray +from cuml.common.input_utils import input_to_cupy_array from cuml.raft.common.handle import Handle cimport cuml.common.cuda @@ -35,23 +38,23 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": const int upper_class_range) except + -@with_cupy_rmm -def _prepare_cluster_input(cluster): +@cuml.internals.api_return_generic() +def _prepare_cluster_input(cluster) -> typing.Tuple[CumlArray, int, int, int]: """Helper function to avoid code duplication for clustering metrics.""" - cluster_m, n_rows, _, _ = input_to_cuml_array( + cluster_m, n_rows, _, _ = input_to_cupy_array( cluster, check_dtype=np.int32, check_cols=1 ) - cp_ground_truth_m = cluster_m.to_output(output_type='cupy') - lower_class_range = cp.min(cp_ground_truth_m) - upper_class_range = cp.max(cp_ground_truth_m) + lower_class_range = cp.min(cluster_m).item() + upper_class_range = cp.max(cluster_m).item() return cluster_m, n_rows, lower_class_range, upper_class_range -def cython_entropy(clustering, base=None, handle=None): +@cuml.internals.api_return_any() +def cython_entropy(clustering, base=None, handle=None) -> float: """ Computes the entropy of a distribution for given probability values. diff --git a/python/cuml/metrics/cluster/homogeneity_score.pyx b/python/cuml/metrics/cluster/homogeneity_score.pyx index a9ccddf95e..16b3d7b39f 100644 --- a/python/cuml/metrics/cluster/homogeneity_score.pyx +++ b/python/cuml/metrics/cluster/homogeneity_score.pyx @@ -16,6 +16,7 @@ # distutils: language = c++ +import cuml.internals from cuml.raft.common.handle cimport handle_t from libc.stdint cimport uintptr_t from cuml.metrics.cluster.utils import prepare_cluster_metric_inputs @@ -29,7 +30,8 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": const int upper_class_range) except + -def cython_homogeneity_score(labels_true, labels_pred, handle=None): +@cuml.internals.api_return_any() +def cython_homogeneity_score(labels_true, labels_pred, handle=None) -> float: """ Computes the homogeneity metric of a cluster labeling given a ground truth. diff --git a/python/cuml/metrics/cluster/mutual_info_score.pyx b/python/cuml/metrics/cluster/mutual_info_score.pyx index 6f26817840..7b54b192b9 100644 --- a/python/cuml/metrics/cluster/mutual_info_score.pyx +++ b/python/cuml/metrics/cluster/mutual_info_score.pyx @@ -16,6 +16,7 @@ # distutils: language = c++ +import cuml.internals from cuml.raft.common.handle cimport handle_t from libc.stdint cimport uintptr_t @@ -32,7 +33,8 @@ cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics": const int upper_class_range) except + -def cython_mutual_info_score(labels_true, labels_pred, handle=None): +@cuml.internals.api_return_any() +def cython_mutual_info_score(labels_true, labels_pred, handle=None) -> float: """ Computes the Mutual Information between two clusterings. diff --git a/python/cuml/metrics/cluster/utils.pyx b/python/cuml/metrics/cluster/utils.pyx index da57c3058b..27fff60d6b 100644 --- a/python/cuml/metrics/cluster/utils.pyx +++ b/python/cuml/metrics/cluster/utils.pyx @@ -15,13 +15,16 @@ # # distutils: language = c++ + import cupy as cp + +import cuml.internals from cuml.metrics.utils import sorted_unique_labels from cuml.prims.label import make_monotonic -from cuml.common import with_cupy_rmm, input_to_cuml_array +from cuml.common import input_to_cuml_array -@with_cupy_rmm +@cuml.internals.api_return_generic(get_output_type=True) def prepare_cluster_metric_inputs(labels_true, labels_pred): """Helper function to avoid code duplication for homogeneity score, mutual info score and completeness score. diff --git a/python/cuml/metrics/confusion_matrix.py b/python/cuml/metrics/confusion_matrix.py index 81901feaa9..2ef1a615b8 100644 --- a/python/cuml/metrics/confusion_matrix.py +++ b/python/cuml/metrics/confusion_matrix.py @@ -18,17 +18,20 @@ import cupy as cp import cupyx +import cuml.internals from cuml.common import input_to_cuml_array -from cuml.common.memory_utils import with_cupy_rmm +from cuml.common import using_output_type +from cuml.common.array import CumlArray +from cuml.common.input_utils import input_to_cupy_array from cuml.metrics.utils import sorted_unique_labels from cuml.prims.label import make_monotonic -@with_cupy_rmm +@cuml.internals.api_return_array(get_output_type=True) def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None, - normalize=None): + normalize=None) -> CumlArray: """Compute confusion matrix to evaluate the accuracy of a classification. Parameters @@ -67,25 +70,24 @@ def confusion_matrix(y_true, y_pred, n_labels = len(labels) else: labels, n_labels, _, _ = \ - input_to_cuml_array(labels, check_dtype=dtype, check_cols=1) - labels = labels.to_output('cupy') + input_to_cupy_array(labels, check_dtype=dtype, check_cols=1) if sample_weight is None: sample_weight = cp.ones(n_rows, dtype=dtype) else: sample_weight, _, _, _ = \ - input_to_cuml_array(sample_weight, + input_to_cupy_array(sample_weight, check_dtype=[cp.float32, cp.float64, cp.int32, cp.int64], check_rows=n_rows, check_cols=n_cols) - sample_weight = sample_weight.to_output('cupy') if normalize not in ['true', 'pred', 'all', None]: msg = "normalize must be one of " \ f"{{'true', 'pred', 'all', None}}, got {normalize}." raise ValueError(msg) - y_true, _ = make_monotonic(y_true, labels, copy=True) - y_pred, _ = make_monotonic(y_pred, labels, copy=True) + with using_output_type("cupy"): + y_true, _ = make_monotonic(y_true, labels, copy=True) + y_pred, _ = make_monotonic(y_pred, labels, copy=True) # intersect y_pred, y_true with labels, eliminate items not in labels ind = cp.logical_and(y_pred < n_labels, y_true < n_labels) diff --git a/python/cuml/metrics/pairwise_distances.pyx b/python/cuml/metrics/pairwise_distances.pyx index 53e9cae2f8..54da9df4af 100644 --- a/python/cuml/metrics/pairwise_distances.pyx +++ b/python/cuml/metrics/pairwise_distances.pyx @@ -16,15 +16,17 @@ # distutils: language = c++ +import warnings + from libcpp cimport bool from libc.stdint cimport uintptr_t from cuml.raft.common.handle cimport handle_t from cuml.raft.common.handle import Handle import cupy as cp import numpy as np +import cuml.internals from cuml.common.base import _determine_stateless_output_type -from cuml.common import (get_cudf_column_ptr, get_dev_array_ptr, - input_to_cuml_array, CumlArray, logger, with_cupy_rmm) +from cuml.common import (input_to_cuml_array, CumlArray, logger) from cuml.metrics.cluster.utils import prepare_cluster_metric_inputs cdef extern from "cuml/distance/distance_type.h" namespace "ML::Distance": @@ -100,7 +102,7 @@ def _determine_metric(metric_str): raise ValueError("Unknown metric: {}".format(metric_str)) -@with_cupy_rmm +@cuml.internals.api_return_array(get_output_type=True) def pairwise_distances(X, Y=None, metric="euclidean", handle=None, convert_dtype=True, output_type=None, **kwds): """ @@ -148,6 +150,12 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, module level, `cuml.global_output_type`. See :ref:`output-data-type-configuration` for more info. + .. deprecated:: 0.17 + `output_type` is deprecated in 0.17 and will be removed in 0.18. + Please use the module level output type control, + `cuml.global_output_type`. + See :ref:`output-data-type-configuration` for more info. + Returns ------- D : array [n_samples_x, n_samples_x] or [n_samples_x, n_samples_y] @@ -183,12 +191,18 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, [12., 10.]]) """ + # Check for deprecated `output_type` and warn. Set manually if specified + if (output_type is not None): + warnings.warn("Using the `output_type` argument is deprecated and " + "will be removed in 0.18. Please specify the output " + "type using `cuml.using_output_type()` instead", + DeprecationWarning) + + cuml.internals.set_api_output_type(output_type) + handle = Handle() if handle is None else handle cdef handle_t *handle_ = handle.getHandle() - # Determine the input type to convert to when returning - output_type = _determine_stateless_output_type(output_type, X) - # Get the input arrays, preserve order and type where possible X_m, n_samples_x, n_features_x, dtype_x = \ input_to_cuml_array(X, order="K", check_dtype=[np.float32, np.float64]) @@ -273,4 +287,4 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, del X_m del Y_m - return dest_m.to_output(output_type) + return dest_m diff --git a/python/cuml/metrics/regression.pyx b/python/cuml/metrics/regression.pyx index 60324699d1..daf5775f88 100644 --- a/python/cuml/metrics/regression.pyx +++ b/python/cuml/metrics/regression.pyx @@ -21,14 +21,16 @@ import cupy as cp from libc.stdint cimport uintptr_t +import cuml.internals +from cuml.common.array import CumlArray from cuml.raft.common.handle import Handle from cuml.raft.common.handle cimport handle_t from cuml.metrics cimport regression from cuml.common.input_utils import input_to_cuml_array -from cuml.common.memory_utils import with_cupy_rmm -def r2_score(y, y_hat, convert_dtype=True, handle=None): +@cuml.internals.api_return_any() +def r2_score(y, y_hat, convert_dtype=True, handle=None) -> double: """ Calculates r2 score between y and y_hat @@ -138,7 +140,6 @@ def _prepare_input_reg(y_true, y_pred, sample_weight, multioutput): return y_true, y_pred, sample_weight, multioutput, raw_multioutput -@with_cupy_rmm def _mse(y_true, y_pred, sample_weight, multioutput, squared, raw_multioutput): """Helper to compute the mean squared error""" output_errors = cp.subtract(y_true, y_pred) @@ -153,7 +154,7 @@ def _mse(y_true, y_pred, sample_weight, multioutput, squared, raw_multioutput): return mse if squared else cp.sqrt(mse) -@with_cupy_rmm +@cuml.internals.api_return_any() def mean_squared_error(y_true, y_pred, sample_weight=None, multioutput='uniform_average', @@ -198,7 +199,7 @@ def mean_squared_error(y_true, y_pred, raw_multioutput) -@with_cupy_rmm +@cuml.internals.api_return_any() def mean_absolute_error(y_true, y_pred, sample_weight=None, multioutput='uniform_average'): @@ -249,7 +250,7 @@ def mean_absolute_error(y_true, y_pred, return cp.average(output_errors, weights=multioutput) -@with_cupy_rmm +@cuml.internals.api_return_any() def mean_squared_log_error(y_true, y_pred, sample_weight=None, multioutput='uniform_average', diff --git a/python/cuml/metrics/trustworthiness.pyx b/python/cuml/metrics/trustworthiness.pyx index b8fff10683..2dae1ec129 100644 --- a/python/cuml/metrics/trustworthiness.pyx +++ b/python/cuml/metrics/trustworthiness.pyx @@ -23,10 +23,10 @@ import warnings from numba import cuda from libc.stdint cimport uintptr_t +import cuml.internals +from cuml.common.input_utils import input_to_cuml_array from cuml.raft.common.handle import Handle from cuml.raft.common.handle cimport handle_t -from cuml.common import get_cudf_column_ptr, get_dev_array_ptr, \ - input_to_dev_array cdef extern from "cuml/distance/distance_type.h" namespace "ML::Distance": @@ -52,9 +52,10 @@ def _get_array_ptr(obj): return obj.device_ctypes_pointer.value +@cuml.internals.api_return_any() def trustworthiness(X, X_embedded, handle=None, n_neighbors=5, metric='euclidean', should_downcast=True, - convert_dtype=False, batch_size=512): + convert_dtype=False, batch_size=512) -> double: """ Expresses to what extent the local structure is retained in embedding. The score is defined in the range [0, 1]. @@ -90,15 +91,18 @@ def trustworthiness(X, X_embedded, handle=None, n_neighbors=5, cdef uintptr_t d_X_ptr cdef uintptr_t d_X_embedded_ptr - X_m, d_X_ptr, n_samples, n_features, dtype1 = \ - input_to_dev_array(X, order='C', check_dtype=np.float32, - convert_to_dtype=(np.float32 if convert_dtype - else None)) - X_m2, d_X_embedded_ptr, n_rows, n_components, dtype2 = \ - input_to_dev_array(X_embedded, order='C', - check_dtype=np.float32, - convert_to_dtype=(np.float32 if convert_dtype - else None)) + X_m, n_samples, n_features, dtype1 = \ + input_to_cuml_array(X, order='C', check_dtype=np.float32, + convert_to_dtype=(np.float32 if convert_dtype + else None)) + d_X_ptr = X_m.ptr + + X_m2, n_rows, n_components, dtype2 = \ + input_to_cuml_array(X_embedded, order='C', + check_dtype=np.float32, + convert_to_dtype=(np.float32 if convert_dtype + else None)) + d_X_embedded_ptr = X_m2.ptr handle = Handle() if handle is None else handle cdef handle_t* handle_ = handle.getHandle() diff --git a/python/cuml/naive_bayes/naive_bayes.py b/python/cuml/naive_bayes/naive_bayes.py index 39dcc7ab91..b1cfa64f08 100644 --- a/python/cuml/naive_bayes/naive_bayes.py +++ b/python/cuml/naive_bayes/naive_bayes.py @@ -13,26 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # - - -import cupy as cp -import cupyx -import cupy.prof import math import warnings -from cuml.common import with_cupy_rmm +import cupy as cp +import cupy.prof +import cupyx from cuml.common import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import Base from cuml.common.doc_utils import generate_docstring -from cuml.common.input_utils import input_to_cuml_array -from cuml.common.kernel_utils import cuda_kernel_factory from cuml.common.import_utils import has_scipy -from cuml.prims.label import make_monotonic -from cuml.prims.label import check_labels -from cuml.prims.label import invert_labels - +from cuml.common.input_utils import input_to_cuml_array, input_to_cupy_array +from cuml.common.kernel_utils import cuda_kernel_factory from cuml.metrics import accuracy_score +from cuml.prims.label import check_labels, invert_labels, make_monotonic def count_features_coo_kernel(float_dtype, int_dtype): @@ -62,8 +57,7 @@ def count_features_coo_kernel(float_dtype, int_dtype): atomicAdd(out + ((col * n_classes) + label), val); }''' - return cuda_kernel_factory(kernel_str, - (float_dtype, int_dtype), + return cuda_kernel_factory(kernel_str, (float_dtype, int_dtype), "count_features_coo") @@ -77,8 +71,7 @@ def count_classes_kernel(float_dtype, int_dtype): atomicAdd(out + label, 1); }''' - return cuda_kernel_factory(kernel_str, - (float_dtype, int_dtype), + return cuda_kernel_factory(kernel_str, (float_dtype, int_dtype), "count_classes") @@ -110,13 +103,11 @@ def count_features_dense_kernel(float_dtype, int_dtype): atomicAdd(out + ((col * n_classes) + label), val); }''' - return cuda_kernel_factory(kernel_str, - (float_dtype, int_dtype,), + return cuda_kernel_factory(kernel_str, (float_dtype, int_dtype), "count_features_dense") class MultinomialNB(Base): - """ Naive Bayes classifier for multinomial models @@ -206,7 +197,13 @@ class MultinomialNB(Base): 0.9244298934936523 """ - @with_cupy_rmm + + classes_ = CumlArrayDescriptor() + class_count_ = CumlArrayDescriptor() + feature_count_ = CumlArrayDescriptor() + class_log_prior_ = CumlArrayDescriptor() + feature_log_prob_ = CumlArrayDescriptor() + def __init__(self, alpha=1.0, fit_prior=True, @@ -234,19 +231,19 @@ def __init__(self, @generate_docstring(X='dense_sparse') @cp.prof.TimeRangeDecorator(message="fit()", color_id=0) - @with_cupy_rmm - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None) -> "MultinomialNB": """ Fit Naive Bayes classifier according to X, y """ - self._set_base_attributes(output_type=X) return self.partial_fit(X, y, sample_weight) @cp.prof.TimeRangeDecorator(message="fit()", color_id=0) - @with_cupy_rmm - def _partial_fit(self, X, y, sample_weight=None, _classes=None): - self._set_output_type(X) + def _partial_fit(self, + X, + y, + sample_weight=None, + _classes=None) -> "MultinomialNB": if has_scipy(): from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix @@ -264,9 +261,9 @@ def _partial_fit(self, X, y, sample_weight=None, _classes=None): X = cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=X.shape) else: - X = input_to_cuml_array(X, order='K').array.to_output('cupy') + X = input_to_cupy_array(X, order='K').array - y = input_to_cuml_array(y).array.to_output('cupy') + y = input_to_cupy_array(y).array Y, label_classes = make_monotonic(y, copy=True) @@ -274,17 +271,16 @@ def _partial_fit(self, X, y, sample_weight=None, _classes=None): self.fit_called_ = True if _classes is not None: _classes, *_ = input_to_cuml_array(_classes, order='K') - check_labels(Y, _classes.to_output('cupy')) - self._classes_ = _classes + check_labels(Y, _classes) + self.classes_ = _classes else: - self._classes_ = CumlArray(data=label_classes) + self.classes_ = label_classes self._n_classes_ = self.classes_.shape[0] self._n_features_ = X.shape[1] - self._init_counters(self._n_classes_, self._n_features_, - X.dtype) + self._init_counters(self._n_classes_, self._n_features_, X.dtype) else: - check_labels(Y, self._classes_) + check_labels(Y, self.classes_) self._count(X, Y) @@ -293,7 +289,6 @@ def _partial_fit(self, X, y, sample_weight=None, _classes=None): return self - @with_cupy_rmm def update_log_probs(self): """ Updates the log probabilities. This enables lazy update for @@ -304,8 +299,11 @@ def update_log_probs(self): self._update_feature_log_prob(self.alpha) self._update_class_log_prior(class_prior=self._class_prior_) - @with_cupy_rmm - def partial_fit(self, X, y, classes=None, sample_weight=None): + def partial_fit(self, + X, + y, + classes=None, + sample_weight=None) -> "MultinomialNB": """ Incremental fit on a batch of samples. @@ -342,23 +340,24 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): self : object """ - return self._partial_fit(X, y, sample_weight=sample_weight, + return self._partial_fit(X, + y, + sample_weight=sample_weight, _classes=classes) @generate_docstring(X='dense_sparse', - return_values={'name': 'y_hat', - 'type': 'dense', - 'description': 'Predicted values', - 'shape': '(n_rows, 1)'}) + return_values={ + 'name': 'y_hat', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_rows, 1)' + }) @cp.prof.TimeRangeDecorator(message="predict()", color_id=1) - @with_cupy_rmm - def predict(self, X): + def predict(self, X) -> CumlArray: """ Perform classification on an array of test vectors X. """ - out_type = self._get_output_type(X) - if has_scipy(): from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix else: @@ -375,29 +374,30 @@ def predict(self, X): X = cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=X.shape) else: - X = input_to_cuml_array(X, order='K').array.to_output('cupy') + X = input_to_cupy_array(X, order='K').array jll = self._joint_log_likelihood(X) indices = cp.argmax(jll, axis=1).astype(self.classes_.dtype) y_hat = invert_labels(indices, classes=self.classes_) - return CumlArray(data=y_hat).to_output(out_type) - - @generate_docstring(X='dense_sparse', - return_values={'name': 'C', - 'type': 'dense', - 'description': 'Returns the log-probability of the samples for each class in the \ - model. The columns correspond to the classes in sorted order, as \ - they appear in the attribute `classes_`.', # noqa - 'shape': '(n_rows, 1)'}) - @with_cupy_rmm - def predict_log_proba(self, X): + return y_hat + + @generate_docstring( + X='dense_sparse', + return_values={ + 'name': 'C', + 'type': 'dense', + 'description': ( + 'Returns the log-probability of the samples for each class in ' + 'the model. The columns correspond to the classes in sorted ' + 'order, as they appear in the attribute `classes_`.'), + 'shape': '(n_rows, 1)' + }) + def predict_log_proba(self, X) -> CumlArray: """ Return log-probability estimates for the test vector X. """ - out_type = self._get_output_type(X) - if has_scipy(): from scipy.sparse import isspmatrix as scipy_sparse_isspmatrix else: @@ -414,7 +414,7 @@ def predict_log_proba(self, X): X = cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=X.shape) else: - X = input_to_cuml_array(X, order='K').array.to_output('cupy') + X = input_to_cupy_array(X, order='K').array jll = self._joint_log_likelihood(X) @@ -435,32 +435,38 @@ def predict_log_proba(self, X): if log_prob_x.ndim < 2: log_prob_x = log_prob_x.reshape((1, log_prob_x.shape[0])) result = jll - log_prob_x.T - return CumlArray(result).to_output(out_type) - - @generate_docstring(X='dense_sparse', - return_values={'name': 'C', - 'type': 'dense', - 'description': 'Returns the probability of the samples for each class in the \ - model. The columns correspond to the classes in sorted order, as \ - they appear in the attribute `classes_`.', # noqa - 'shape': '(n_rows, 1)'}) - @with_cupy_rmm - def predict_proba(self, X): + return result + + @generate_docstring( + X='dense_sparse', + return_values={ + 'name': 'C', + 'type': 'dense', + 'description': ( + 'Returns the probability of the samples for each class in the ' + 'model. The columns correspond to the classes in sorted order,' + ' as they appear in the attribute `classes_`.'), + 'shape': '(n_rows, 1)' + }) + def predict_proba(self, X) -> CumlArray: """ Return probability estimates for the test vector X. """ - out_type = self._get_output_type(X) result = cp.exp(self.predict_log_proba(X)) - return CumlArray(result).to_output(out_type) - - @generate_docstring(X='dense_sparse', - return_values={'name': 'score', - 'type': 'float', - 'description': 'Mean accuracy of \ - self.predict(X) with respect to y.'}) - @with_cupy_rmm - def score(self, X, y, sample_weight=None): + return result + + @generate_docstring( + X='dense_sparse', + return_values={ + 'name': + 'score', + 'type': + 'float', + 'description': + 'Mean accuracy of self.predict(X) with respect to y.' + }) + def score(self, X, y, sample_weight=None) -> float: """ Return the mean accuracy on the given test data and labels. @@ -475,11 +481,12 @@ def score(self, X, y, sample_weight=None): return accuracy_score(y_hat, cp.asarray(y, dtype=y.dtype)) def _init_counters(self, n_effective_classes, n_features, dtype): - self._class_count_ = CumlArray.zeros(n_effective_classes, - order="F", dtype=dtype) - self._feature_count_ = CumlArray.zeros((n_effective_classes, - n_features), - order="F", dtype=dtype) + self.class_count_ = cp.zeros(n_effective_classes, + order="F", + dtype=dtype) + self.feature_count_ = cp.zeros((n_effective_classes, n_features), + order="F", + dtype=dtype) def _count(self, X, Y): """ @@ -499,7 +506,11 @@ def _count(self, X, Y): warnings.warn("Y dtype does not match classes_ dtype. Y will be " "converted, which will increase memory consumption") - counts = cp.zeros((self._n_classes_, self._n_features_), order="F", + # Make sure Y is a cupy array, not CumlArray + Y = cp.asarray(Y) + + counts = cp.zeros((self._n_classes_, self._n_features_), + order="F", dtype=X.dtype) class_c = cp.zeros(self._n_classes_, order="F", dtype=X.dtype) @@ -512,9 +523,9 @@ def _count(self, X, Y): if cupyx.scipy.sparse.isspmatrix(X): X = X.tocoo() - count_features_coo = count_features_coo_kernel(X.dtype, - labels_dtype) - count_features_coo((math.ceil(X.nnz / 32),), (32,), + count_features_coo = count_features_coo_kernel( + X.dtype, labels_dtype) + count_features_coo((math.ceil(X.nnz / 32), ), (32, ), (counts, X.row, X.col, @@ -523,30 +534,30 @@ def _count(self, X, Y): n_rows, n_cols, Y, - self._n_classes_, False)) + self._n_classes_, + False)) else: - count_features_dense = count_features_dense_kernel(X.dtype, - labels_dtype) - count_features_dense((math.ceil(n_rows / 32), - math.ceil(n_cols / 32), 1), - (32, 32, 1), - (counts, - X, - n_rows, - n_cols, - Y, - self._n_classes_, - False, - X.flags["C_CONTIGUOUS"])) + count_features_dense = count_features_dense_kernel( + X.dtype, labels_dtype) + count_features_dense( + (math.ceil(n_rows / 32), math.ceil(n_cols / 32), 1), + (32, 32, 1), + (counts, + X, + n_rows, + n_cols, + Y, + self._n_classes_, + False, + X.flags["C_CONTIGUOUS"])) count_classes = count_classes_kernel(X.dtype, labels_dtype) - count_classes((math.ceil(n_rows / 32),), (32,), - (class_c, n_rows, Y)) + count_classes((math.ceil(n_rows / 32), ), (32, ), (class_c, n_rows, Y)) - self._feature_count_ = CumlArray(self._feature_count_ + counts) - self._class_count_ = CumlArray(self._class_count_ + class_c) + self.feature_count_ = self.feature_count_ + counts + self.class_count_ = self.class_count_ + class_c def _update_class_log_prior(self, class_prior=None): @@ -556,16 +567,17 @@ def _update_class_log_prior(self, class_prior=None): raise ValueError("Number of classes must match " "number of priors") - self._class_log_prior_ = cp.log(class_prior) + self.class_log_prior_ = cp.log(class_prior) elif self.fit_prior: - log_class_count = cp.log(self._class_count_) - self._class_log_prior_ = \ - CumlArray(log_class_count - cp.log( - cp.asarray(self._class_count_).sum())) + log_class_count = cp.log(self.class_count_) + + self.class_log_prior_ = \ + log_class_count - cp.log( + self.class_count_.sum()) else: - self._class_log_prior_ = CumlArray(cp.full(self._n_classes_, - -1*math.log(self._n_classes_))) + self.class_log_prior_ = cp.full(self._n_classes_, + -1 * math.log(self._n_classes_)) def _update_feature_log_prob(self, alpha): """ @@ -577,10 +589,10 @@ def _update_feature_log_prob(self, alpha): alpha : float amount of smoothing to apply (0. means no smoothing) """ - smoothed_fc = cp.asarray(self._feature_count_) + alpha + smoothed_fc = self.feature_count_ + alpha smoothed_cc = smoothed_fc.sum(axis=1).reshape(-1, 1) - self._feature_log_prob_ = CumlArray(cp.log(smoothed_fc) - - cp.log(smoothed_cc.reshape(-1, 1))) + self.feature_log_prob_ = cp.log(smoothed_fc) - cp.log( + smoothed_cc.reshape(-1, 1)) def _joint_log_likelihood(self, X): """ @@ -592,8 +604,8 @@ def _joint_log_likelihood(self, X): X : array-like of size (n_samples, n_features) """ - ret = X.dot(cp.asarray(self._feature_log_prob_).T) - ret += cp.asarray(self._class_log_prior_) + ret = X.dot(self.feature_log_prob_.T) + ret += self.class_log_prior_ return ret def get_param_names(self): diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx index a92108995a..e89b563449 100644 --- a/python/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/neighbors/kneighbors_classifier.pyx @@ -16,10 +16,14 @@ # distutils: language = c++ +import typing + from cuml.neighbors.nearest_neighbors import NearestNeighbors +import cuml.internals from cuml.common.array import CumlArray from cuml.common import input_to_cuml_array +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import ClassifierMixin from cuml.common.doc_utils import generate_docstring @@ -33,8 +37,6 @@ from cython.operator cimport dereference as deref from cuml.raft.common.handle cimport handle_t from libcpp.vector cimport vector -from cuml.common import with_cupy_rmm - from libcpp cimport bool from libcpp.memory cimport shared_ptr @@ -142,6 +144,9 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): `_. """ + y = CumlArrayDescriptor() + classes_ = CumlArrayDescriptor() + def __init__(self, weights="uniform", *, handle=None, verbose=False, output_type=None, **kwargs): super(KNeighborsClassifier, self).__init__( @@ -150,8 +155,8 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): output_type=output_type, **kwargs) - self._y = None - self._classes_ = None + self.y = None + self.classes_ = None self.weights = weights if weights != "uniform": @@ -159,21 +164,19 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): "supported currently.") @generate_docstring(convert_dtype_cast='np.float32') - @with_cupy_rmm - def fit(self, X, y, convert_dtype=True): + @cuml.internals.api_base_return_any(set_output_dtype=True) + def fit(self, X, y, convert_dtype=True) -> "KNeighborsClassifier": """ Fit a GPU index for k-nearest neighbors classifier model. """ - self._set_base_attributes(output_type=X, target_dtype=y) - super(KNeighborsClassifier, self).fit(X, convert_dtype) - self._y, _, _, _ = \ + self.y, _, _, _ = \ input_to_cuml_array(y, order='F', check_dtype=np.int32, convert_to_dtype=(np.int32 if convert_dtype else None)) - self._classes_ = CumlArray(cp.unique(self._y)) + self.classes_ = cp.unique(self.y) return self @generate_docstring(convert_dtype_cast='np.float32', @@ -181,16 +184,13 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): 'type': 'dense', 'description': 'Labels predicted', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=True): + @cuml.internals.api_base_return_array(get_output_dtype=True) + def predict(self, X, convert_dtype=True) -> CumlArray: """ Use the trained k-nearest neighbors classifier to predict the labels for X """ - - out_type = self._get_output_type(X) - out_dtype = self._get_target_dtype() - knn_indices = self.kneighbors(X, return_distance=False, convert_dtype=convert_dtype) @@ -201,7 +201,7 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): else None)) cdef uintptr_t inds_ctype = inds.ptr - out_cols = self._y.shape[1] if len(self._y.shape) == 2 else 1 + out_cols = self.y.shape[1] if len(self.y.shape) == 2 else 1 out_shape = (n_rows, out_cols) if out_cols > 1 else n_rows @@ -213,7 +213,7 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): # classification cdef uintptr_t y_ptr for i in range(out_cols): - col = self._y[:, i] if out_cols > 1 else self._y + col = self.y[:, i] if out_cols > 1 else self.y y_ptr = col.ptr y_vec.push_back(y_ptr) @@ -233,23 +233,23 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): self.handle.sync() - return classes.to_output(output_type=out_type, output_dtype=out_dtype) + return classes @generate_docstring(convert_dtype_cast='np.float32', return_values={'name': 'X_new', 'type': 'dense', 'description': 'Labels probabilities', 'shape': '(n_samples, 1)'}) - @with_cupy_rmm - def predict_proba(self, X, convert_dtype=True): + @cuml.internals.api_base_return_generic() + def predict_proba( + self, + X, + convert_dtype=True) -> typing.Union[CumlArray, typing.Tuple]: """ Use the trained k-nearest neighbors classifier to predict the label probabilities for X """ - - out_type = self._get_output_type(X) - knn_indices = self.kneighbors(X, return_distance=False, convert_dtype=convert_dtype) @@ -261,7 +261,7 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): else None)) cdef uintptr_t inds_ctype = inds.ptr - out_cols = self._y.shape[1] if len(self._y.shape) == 2 else 1 + out_cols = self.y.shape[1] if len(self.y.shape) == 2 else 1 cdef vector[int*] *y_vec = new vector[int*]() cdef vector[float*] *out_vec = new vector[float*]() @@ -270,7 +270,7 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): cdef uintptr_t classes_ptr cdef uintptr_t y_ptr for out_col in range(out_cols): - col = self._y[:, out_col] if out_cols > 1 else self._y + col = self.y[:, out_col] if out_cols > 1 else self.y classes = CumlArray.zeros((n_rows, len(cp.unique(cp.asarray(col)))), dtype=np.float32, @@ -298,7 +298,7 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): final_classes = [] for out_class in out_classes: - final_classes.append(out_class.to_output(out_type)) + final_classes.append(out_class) return final_classes[0] \ if len(final_classes) == 1 else tuple(final_classes) diff --git a/python/cuml/neighbors/kneighbors_classifier_mg.pyx b/python/cuml/neighbors/kneighbors_classifier_mg.pyx index 2e4d8f5733..d72ac2f42b 100644 --- a/python/cuml/neighbors/kneighbors_classifier_mg.pyx +++ b/python/cuml/neighbors/kneighbors_classifier_mg.pyx @@ -16,8 +16,11 @@ # distutils: language = c++ +import typing + import numpy as np +import cuml.internals from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array @@ -76,9 +79,23 @@ class KNeighborsClassifierMG(KNeighborsMG): super(KNeighborsClassifierMG, self).__init__(**kwargs) self.batch_size = batch_size - def predict(self, data, data_parts_to_ranks, data_nrows, - query, query_parts_to_ranks, query_nrows, - uniq_labels, n_unique, ncols, rank, convert_dtype): + @cuml.internals.api_base_return_generic_skipall + def predict( + self, + data, + data_parts_to_ranks, + data_nrows, + query, + query_parts_to_ranks, + query_nrows, + uniq_labels, + n_unique, + ncols, + rank, + convert_dtype + ) -> typing.Tuple[typing.List[CumlArray], + typing.List[CumlArray], + typing.List[CumlArray]]: """ Predict labels for a query from previously stored index and index labels. @@ -103,7 +120,7 @@ class KNeighborsClassifierMG(KNeighborsMG): ------- predictions : labels, indices, distances """ - out_type = self.get_out_type(data, query) + self.get_out_type(data, query) input = self.gen_local_input(data, data_parts_to_ranks, data_nrows, query, query_parts_to_ranks, query_nrows, @@ -178,17 +195,15 @@ class KNeighborsClassifierMG(KNeighborsMG): free(out_result_local_parts.at(i)) free(out_result_local_parts) - output = list(map(lambda o: o.to_output(out_type), output_cais)) - output_i = list(map(lambda o: o.to_output(out_type), - result['cais']['indices'])) - output_d = list(map(lambda o: o.to_output(out_type), - result['cais']['distances'])) - - return output, output_i, output_d + return output_cais, \ + result['cais']['indices'], \ + result['cais']['distances'] + @cuml.internals.api_base_return_generic_skipall def predict_proba(self, data, data_parts_to_ranks, data_nrows, query, query_parts_to_ranks, query_nrows, - uniq_labels, n_unique, ncols, rank, convert_dtype): + uniq_labels, n_unique, ncols, rank, + convert_dtype) -> tuple: """ Predict labels for a query from previously stored index and index labels. @@ -213,7 +228,7 @@ class KNeighborsClassifierMG(KNeighborsMG): ------- predictions : labels, indices, distances """ - out_type = self.get_out_type(data, query) + self.get_out_type(data, query) input = self.gen_local_input(data, data_parts_to_ranks, data_nrows, query, query_parts_to_ranks, query_nrows, @@ -291,7 +306,6 @@ class KNeighborsClassifierMG(KNeighborsMG): probas_out = [] for i in range(n_outputs): - probas_out.append(list(map(lambda o: o.to_output(out_type), - proba_cais[i]))) + probas_out.append(proba_cais[i]) return tuple(probas_out) diff --git a/python/cuml/neighbors/kneighbors_mg.pyx b/python/cuml/neighbors/kneighbors_mg.pyx index 29493731c4..4839a55eb5 100644 --- a/python/cuml/neighbors/kneighbors_mg.pyx +++ b/python/cuml/neighbors/kneighbors_mg.pyx @@ -18,6 +18,7 @@ import numpy as np +import cuml.internals from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array @@ -46,6 +47,8 @@ class KNeighborsMG(NearestNeighbors): out_type = self.output_type if len(query) > 0: out_type = self._get_output_type(query[0]) + + cuml.internals.set_api_output_type(out_type) return out_type def gen_local_input(self, data, data_parts_to_ranks, data_nrows, diff --git a/python/cuml/neighbors/kneighbors_regressor.pyx b/python/cuml/neighbors/kneighbors_regressor.pyx index 467fe5178e..f626d36082 100644 --- a/python/cuml/neighbors/kneighbors_regressor.pyx +++ b/python/cuml/neighbors/kneighbors_regressor.pyx @@ -18,8 +18,10 @@ from cuml.neighbors.nearest_neighbors import NearestNeighbors +import cuml.internals from cuml.common.array import CumlArray from cuml.common import input_to_cuml_array +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import RegressorMixin from cuml.common.doc_utils import generate_docstring @@ -142,6 +144,8 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): `_. """ + y = CumlArrayDescriptor() + def __init__(self, weights="uniform", *, handle=None, verbose=False, output_type=None, **kwargs): super(KNeighborsRegressor, self).__init__( @@ -149,21 +153,22 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): verbose=verbose, output_type=output_type, **kwargs) - self._y = None + self.y = None self.weights = weights if weights != "uniform": raise ValueError("Only uniform weighting strategy " "is supported currently.") @generate_docstring(convert_dtype_cast='np.float32') - def fit(self, X, y, convert_dtype=True): + def fit(self, X, y, convert_dtype=True) -> "KNeighborsRegressor": """ Fit a GPU index for k-nearest neighbors regression model. """ self._set_target_dtype(y) + super(KNeighborsRegressor, self).fit(X, convert_dtype=convert_dtype) - self._y, _, _, _ = \ + self.y, _, _, _ = \ input_to_cuml_array(y, order='F', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype @@ -175,15 +180,14 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, n_features)'}) - def predict(self, X, convert_dtype=True): + def predict(self, X, convert_dtype=True) -> CumlArray: """ Use the trained k-nearest neighbors regression model to predict the labels for X """ - - out_type = self._get_output_type(X) - out_dtype = self._get_target_dtype() if convert_dtype else None + if (convert_dtype): + cuml.internals.set_api_output_dtype(self._get_target_dtype()) knn_indices = self.kneighbors(X, return_distance=False, convert_dtype=convert_dtype) @@ -195,7 +199,7 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): else None)) cdef uintptr_t inds_ctype = inds.ptr - res_cols = 1 if len(self._y.shape) == 1 else self._y.shape[1] + res_cols = 1 if len(self.y.shape) == 1 else self.y.shape[1] res_shape = n_rows if res_cols == 1 else (n_rows, res_cols) results = CumlArray.zeros(res_shape, dtype=np.float32, order="C") @@ -205,7 +209,7 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): cdef vector[float*] *y_vec = new vector[float*]() for col_num in range(res_cols): - col = self._y if res_cols == 1 else self._y[:, col_num] + col = self.y if res_cols == 1 else self.y[:, col_num] y_ptr = col.ptr y_vec.push_back(y_ptr) @@ -223,7 +227,7 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): self.handle.sync() - return results.to_output(out_type, output_dtype=out_dtype) + return results def get_param_names(self): return super().get_param_names() + ["weights"] diff --git a/python/cuml/neighbors/kneighbors_regressor_mg.pyx b/python/cuml/neighbors/kneighbors_regressor_mg.pyx index d615ec9226..d1cd90094d 100644 --- a/python/cuml/neighbors/kneighbors_regressor_mg.pyx +++ b/python/cuml/neighbors/kneighbors_regressor_mg.pyx @@ -16,8 +16,11 @@ # distutils: language = c++ +import typing + import numpy as np +import cuml.internals from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array @@ -68,9 +71,22 @@ class KNeighborsRegressorMG(KNeighborsMG): super(KNeighborsRegressorMG, self).__init__(**kwargs) self.batch_size = batch_size - def predict(self, data, data_parts_to_ranks, data_nrows, - query, query_parts_to_ranks, query_nrows, - ncols, n_outputs, rank, convert_dtype): + @cuml.internals.api_base_return_generic_skipall + def predict( + self, + data, + data_parts_to_ranks, + data_nrows, + query, + query_parts_to_ranks, + query_nrows, + ncols, + n_outputs, + rank, + convert_dtype + ) -> typing.Tuple[typing.List[CumlArray], + typing.List[CumlArray], + typing.List[CumlArray]]: """ Predict outputs for a query from previously stored index and index labels. @@ -93,9 +109,7 @@ class KNeighborsRegressorMG(KNeighborsMG): ------- predictions : outputs, indices, distances """ - out_type = self.get_out_type(data, query) - - out_type = self.get_out_type(data, query) + self.get_out_type(data, query) input = self.gen_local_input(data, data_parts_to_ranks, data_nrows, query, query_parts_to_ranks, query_nrows, @@ -148,10 +162,6 @@ class KNeighborsRegressorMG(KNeighborsMG): free(out_result_local_parts.at(i)) free(out_result_local_parts) - output = list(map(lambda o: o.to_output(out_type), output_cais)) - output_i = list(map(lambda o: o.to_output(out_type), - result['cais']['indices'])) - output_d = list(map(lambda o: o.to_output(out_type), - result['cais']['distances'])) - - return output, output_i, output_d + return output_cais, \ + result['cais']['indices'], \ + result['cais']['distances'] diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index b7dbc0376c..fcc68582e6 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -16,16 +16,20 @@ # distutils: language = c++ +import typing + import numpy as np import cupy as cp import cupyx import cudf import ctypes -import cuml import warnings +import cuml.internals from cuml.common.base import Base +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.array import CumlArray +from cuml.common.array_sparse import SparseCumlArray from cuml.common.doc_utils import generate_docstring from cuml.common.doc_utils import insert_into_docstring from cuml.common import input_to_cuml_array @@ -187,6 +191,9 @@ class NearestNeighbors(Base): For additional docs, see `scikit-learn's NearestNeighbors `_. """ + + X_m = CumlArrayDescriptor() + def __init__(self, n_neighbors=5, verbose=False, @@ -218,19 +225,17 @@ class NearestNeighbors(Base): self.algorithm = algorithm @generate_docstring() - def fit(self, X, convert_dtype=True): + def fit(self, X, convert_dtype=True) -> "NearestNeighbors": """ Fit GPU index for performing nearest neighbor queries. """ - self._set_base_attributes(output_type=X, n_features=X) - if len(X.shape) != 2: raise ValueError("data should be two dimensional") self.n_dims = X.shape[1] - self._X_m, n_rows, n_cols, dtype = \ + self.X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='F', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype @@ -282,8 +287,13 @@ class NearestNeighbors(Base): return_values=[('dense', '(n_samples, n_features)'), ('dense', '(n_samples, n_features)')]) - def kneighbors(self, X=None, n_neighbors=None, return_distance=True, - convert_dtype=True): + def kneighbors( + self, + X=None, + n_neighbors=None, + return_distance=True, + convert_dtype=True + ) -> typing.Union[CumlArray, typing.Tuple[CumlArray, CumlArray]]: """ Query the GPU index for the k nearest neighbors of column vectors in X. @@ -315,7 +325,7 @@ class NearestNeighbors(Base): return self._kneighbors(X, n_neighbors, return_distance, convert_dtype) def _kneighbors(self, X=None, n_neighbors=None, return_distance=True, - convert_dtype=True, _output_cumlarray=False): + convert_dtype=True): """ Query the GPU index for the k nearest neighbors of column vectors in X. @@ -343,25 +353,25 @@ class NearestNeighbors(Base): Returns ------- - distances: cuDF DataFrame, pandas DataFrame, numpy or cupy ndarray + distances: cupy ndarray The distances of the k-nearest neighbors for each column vector in X - indices: cuDF DataFrame, pandas DataFrame, numpy or cupy ndarray + indices: cupy ndarray The indices of the k-nearest neighbors for each column vector in X """ n_neighbors = self.n_neighbors if n_neighbors is None else n_neighbors use_training_data = X is None if X is None: - X = self._X_m + X = self.X_m n_neighbors += 1 if (n_neighbors is None and self.n_neighbors is None) \ or n_neighbors <= 0: raise ValueError("k or n_neighbors must be a positive integers") - if n_neighbors > self._X_m.shape[0]: + if n_neighbors > self.X_m.shape[0]: raise ValueError("n_neighbors must be <= number of " "samples in index") @@ -390,9 +400,9 @@ class NearestNeighbors(Base): cdef vector[float*] *inputs = new vector[float*]() cdef vector[int] *sizes = new vector[int]() - cdef uintptr_t idx_ptr = self._X_m.ptr + cdef uintptr_t idx_ptr = self.X_m.ptr inputs.push_back(idx_ptr) - sizes.push_back(self._X_m.shape[0]) + sizes.push_back(self.X_m.shape[0]) cdef handle_t* handle_ = self.handle.getHandle() @@ -421,29 +431,19 @@ class NearestNeighbors(Base): self.handle.sync() - if _output_cumlarray: - return (D_ndarr, I_ndarr) if return_distance else I_ndarr - - out_type = self._get_output_type(X) - I_output = I_ndarr.to_output(out_type) - if return_distance: - D_output = D_ndarr.to_output(out_type) - # drop first column if using training data as X # this will need to be moved to the C++ layer (cuml issue #2562) if use_training_data: - if out_type in {'cupy', 'numpy', 'numba'}: - return (D_output[:, 1:], I_output[:, 1:]) \ - if return_distance else I_output[:, 1:] - else: - I_output.drop(I_output.columns[0], axis=1) - if return_distance: - D_output.drop(D_output.columns[0], axis=1) + D_ndarr = D_ndarr[:, 1:] + I_ndarr = I_ndarr[:, 1:] - return (D_output, I_output) if return_distance else I_output + return (D_ndarr, I_ndarr) if return_distance else I_ndarr @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')]) - def kneighbors_graph(self, X=None, n_neighbors=None, mode='connectivity'): + def kneighbors_graph(self, + X=None, + n_neighbors=None, + mode='connectivity') -> SparseCumlArray: """ Find the k nearest neighbors of column vectors in X and return as a sparse matrix in CSR format. @@ -471,7 +471,7 @@ class NearestNeighbors(Base): numpy's CSR sparse graph (host) """ - if not self._X_m: + if not self.X_m: raise ValueError('This NearestNeighbors instance has not been ' 'fitted yet, call "fit" before using this ' 'estimator') @@ -481,40 +481,36 @@ class NearestNeighbors(Base): if mode == 'connectivity': ind_mlarr = self._kneighbors(X, n_neighbors, - return_distance=False, - _output_cumlarray=True) + return_distance=False) n_samples = ind_mlarr.shape[0] distances = cp.ones(n_samples * n_neighbors, dtype=np.float32) elif mode == 'distance': - dist_mlarr, ind_mlarr = self._kneighbors(X, n_neighbors, - _output_cumlarray=True) - distances = dist_mlarr.to_output('cupy')[:, 1:] if X is None \ - else dist_mlarr.to_output('cupy') - distances = cp.ravel(distances) + dist_mlarr, ind_mlarr = self._kneighbors(X, n_neighbors) + distances = dist_mlarr + distances = cp.ravel(cp.asarray(distances)) else: raise ValueError('Unsupported mode, must be one of "connectivity"' ' or "distance" but got "%s" instead' % mode) - indices = ind_mlarr.to_output('cupy')[:, 1:] if X is None \ - else ind_mlarr.to_output('cupy') + indices = ind_mlarr n_samples = indices.shape[0] - n_samples_fit = self._X_m.shape[0] + n_samples_fit = self.X_m.shape[0] n_nonzero = n_samples * n_neighbors rowptr = cp.arange(0, n_nonzero + 1, n_neighbors) sparse_csr = cupyx.scipy.sparse.csr_matrix((distances, - cp.ravel(indices), - rowptr), shape=(n_samples, - n_samples_fit)) + cp.ravel( + cp.asarray(indices)), + rowptr), + shape=(n_samples, + n_samples_fit)) - if self._get_output_type(X) is 'numpy': - return sparse_csr.get() - else: - return sparse_csr + return sparse_csr +@cuml.internals.api_return_sparse_array() def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False, handle=None, algorithm="brute", metric="euclidean", p=2, include_self=False, metric_params=None, output_type=None): @@ -574,6 +570,12 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False, module level, `cuml.global_output_type`. See :ref:`output-data-type-configuration` for more info. + .. deprecated:: 0.17 + `output_type` is deprecated in 0.17 and will be removed in 0.18. + Please use the module level output type control, + `cuml.global_output_type`. + See :ref:`output-data-type-configuration` for more info. + Returns ------- A : sparse graph in CSR format, shape = (n_samples, n_samples_fit) @@ -584,6 +586,14 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False, numpy's CSR sparse graph (host) """ + + # Check for deprecated `output_type` and warn. Set manually if specified + if (output_type is not None): + warnings.warn("Using the `output_type` argument is deprecated and " + "will be removed in 0.18. Please specify the output " + "type using `cuml.using_output_type()` instead", + DeprecationWarning) + X = NearestNeighbors(n_neighbors, verbose, handle, algorithm, metric, p, metric_params=metric_params, output_type=output_type).fit(X) @@ -591,9 +601,10 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False, if include_self == 'auto': include_self = mode == 'connectivity' - if not include_self: - query = None - else: - query = X.X_m + with cuml.internals.exit_internal_api(): + if not include_self: + query = None + else: + query = X.X_m return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode) diff --git a/python/cuml/neighbors/nearest_neighbors_mg.pyx b/python/cuml/neighbors/nearest_neighbors_mg.pyx index d5de5800e1..ee841f9970 100644 --- a/python/cuml/neighbors/nearest_neighbors_mg.pyx +++ b/python/cuml/neighbors/nearest_neighbors_mg.pyx @@ -22,12 +22,13 @@ import numpy as np import pandas as pd import cudf import ctypes -import cuml import warnings +import typing from cuml.common.base import Base from cuml.common.array import CumlArray from cuml.common import input_to_cuml_array +import cuml.internals from cython.operator cimport dereference as deref @@ -153,9 +154,20 @@ class NearestNeighborsMG(NearestNeighbors): super(NearestNeighborsMG, self).__init__(**kwargs) self.batch_size = batch_size - def kneighbors(self, indices, index_m, n, index_parts_to_ranks, - queries, query_m, query_parts_to_ranks, - rank, n_neighbors=None, convert_dtype=True): + @cuml.internals.api_base_return_generic_skipall + def kneighbors( + self, + indices, + index_m, + n, + index_parts_to_ranks, + queries, + query_m, + query_parts_to_ranks, + rank, + n_neighbors=None, + convert_dtype=True + ) -> typing.Tuple[typing.List[CumlArray], typing.List[CumlArray]]: """ Query the kneighbors of an index @@ -178,7 +190,9 @@ class NearestNeighborsMG(NearestNeighbors): output indices, output distances """ self._set_base_attributes(output_type=indices[0]) - out_type = self._get_output_type(queries[0]) + + # Specify the output return type + cuml.internals.set_api_output_type(self._get_output_type(queries[0])) n_neighbors = self.n_neighbors if n_neighbors is None else n_neighbors @@ -245,11 +259,6 @@ class NearestNeighborsMG(NearestNeighbors): self.handle.sync() - output_i = list(map(lambda x: x.to_output(out_type), - output_i_arrs)) - output_d = list(map(lambda x: x.to_output(out_type), - output_d_arrs)) - _free_mem(idx_desc, q_desc, out_i_vec, @@ -257,4 +266,4 @@ class NearestNeighborsMG(NearestNeighbors): idx_local_parts, q_local_parts) - return output_i, output_d + return output_i_arrs, output_d_arrs diff --git a/python/cuml/preprocessing/LabelEncoder.py b/python/cuml/preprocessing/LabelEncoder.py index f8c9cfae32..2e395d8fed 100644 --- a/python/cuml/preprocessing/LabelEncoder.py +++ b/python/cuml/preprocessing/LabelEncoder.py @@ -19,7 +19,6 @@ from cuml import Base -from cuml.common.memory_utils import with_cupy_rmm from cuml.common.exceptions import NotFittedError @@ -152,7 +151,6 @@ def _validate_keywords(self): "got {0}.".format(self.handle_unknown)) raise ValueError(msg) - @with_cupy_rmm def fit(self, y, _classes=None): """ Fit a LabelEncoder (nvcategory) instance to a set of categories @@ -235,7 +233,6 @@ def fit_transform(self, y: cudf.Series) -> cudf.Series: self._fitted = True return cudf.Series(y._column.codes, index=y.index) - @with_cupy_rmm def inverse_transform(self, y: cudf.Series) -> cudf.Series: """ Revert ordinal label to original label diff --git a/python/cuml/preprocessing/encoders.py b/python/cuml/preprocessing/encoders.py index 60c0e17d79..164ffd7164 100644 --- a/python/cuml/preprocessing/encoders.py +++ b/python/cuml/preprocessing/encoders.py @@ -23,7 +23,6 @@ from cudf.core import GenericIndex import cuml.common.logger as logger -from cuml.common import with_cupy_rmm import warnings @@ -246,8 +245,11 @@ def fit(self, X): if type(self.categories) is str and self.categories == 'auto': self._features = X.columns self._encoders = { - feature: LabelEncoder(handle_unknown=self.handle_unknown).fit( - self._unique(X[feature])) + feature: LabelEncoder(handle=self.handle, + verbose=self.verbose, + output_type=self.output_type, + handle_unknown=self.handle_unknown).fit( + self._unique(X[feature])) for feature in self._features } else: @@ -258,8 +260,14 @@ def fit(self, X): " it has to be of shape (n_features, _).") self._encoders = dict() for feature in self._features: - le = LabelEncoder(handle_unknown=self.handle_unknown) + + le = LabelEncoder(handle=self.handle, + verbose=self.verbose, + output_type=self.output_type, + handle_unknown=self.handle_unknown) + self._encoders[feature] = le.fit(self.categories[feature]) + if self.handle_unknown == 'error': if self._has_unknown(X[feature], self._encoders[feature].classes_): @@ -290,7 +298,6 @@ def fit_transform(self, X): X = self._check_input(X) return self.fit(X).transform(X) - @with_cupy_rmm def transform(self, X): """ Transform X using one-hot encoding. @@ -328,9 +335,11 @@ def transform(self, X): # If we exceed the max value, upconvert if (max_value > np.iinfo(col_idx.dtype).max): col_idx = col_idx.astype(np.min_scalar_type(max_value)) - logger.debug("Upconverting column: '{}', to dtype: '{}', \ - to support up to {} classes".format( - feature, np.min_scalar_type(max_value), max_value)) + logger.debug("Upconverting column: '{}', to dtype: '{}', " + "to support up to {} classes".format( + feature, + np.min_scalar_type(max_value), + max_value)) # increase indices to take previous features into account col_idx += j @@ -381,7 +390,6 @@ def transform(self, X): "Calculated column code dtypes: {}.\n" "Internal Error: {}".format(input_types_str, repr(e))) - @with_cupy_rmm def inverse_transform(self, X): """ Convert the data back to the original representation. @@ -454,11 +462,10 @@ def inverse_transform(self, X): return result def get_param_names(self): - return super().get_param_names() + \ - [ - "categories", - "drop", - "sparse", - "dtype", - "handle_unknown", - ] + return super().get_param_names() + [ + "categories", + "drop", + "sparse", + "dtype", + "handle_unknown", + ] diff --git a/python/cuml/preprocessing/label.py b/python/cuml/preprocessing/label.py index 13a4f1c9f1..88f0bb9e2b 100644 --- a/python/cuml/preprocessing/label.py +++ b/python/cuml/preprocessing/label.py @@ -16,16 +16,17 @@ import cupy as cp import cupyx -from cuml.prims.label import make_monotonic, check_labels, \ - invert_labels - from cuml import Base -from cuml.common import rmm_cupy_ary, with_cupy_rmm, CumlArray -from cuml.common import has_scipy +import cuml.internals +from cuml.common import CumlArray, has_scipy +from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common.array_sparse import SparseCumlArray +from cuml.prims.label import check_labels, invert_labels, make_monotonic +@cuml.internals.api_return_sparse_array() def label_binarize(y, classes, neg_label=0, pos_label=1, - sparse_output=False): + sparse_output=False) -> SparseCumlArray: """ A stateless helper function to dummy encode multi-class labels. @@ -39,17 +40,20 @@ def label_binarize(y, classes, neg_label=0, pos_label=1, sparse_output : bool whether to return sparse array """ - classes = rmm_cupy_ary(cp.asarray, classes, dtype=classes.dtype) - labels = rmm_cupy_ary(cp.asarray, y, dtype=y.dtype) + classes = cp.asarray(classes, dtype=classes.dtype) + labels = cp.asarray(y, dtype=y.dtype) if not check_labels(labels, classes): raise ValueError("Unseen classes encountered in input") - row_ind = rmm_cupy_ary(cp.arange, 0, labels.shape[0], 1, - dtype=y.dtype) + row_ind = cp.arange(0, labels.shape[0], 1, + dtype=y.dtype) col_ind, _ = make_monotonic(labels, classes, copy=True) - val = rmm_cupy_ary(cp.full, row_ind.shape[0], pos_label, dtype=y.dtype) + # Convert from CumlArray to cupy + col_ind = cp.asarray(col_ind) + + val = cp.full(row_ind.shape[0], pos_label, dtype=y.dtype) sp = cupyx.scipy.sparse.coo_matrix((val, (row_ind, col_ind)), shape=(col_ind.shape[0], @@ -145,6 +149,8 @@ class LabelBinarizer(Base): [ 0 5 10 7 2 4 1 0 0 4 3 2 1] """ + classes_ = CumlArrayDescriptor() + def __init__(self, neg_label=0, pos_label=1, @@ -171,10 +177,9 @@ def __init__(self, self.neg_label = neg_label self.pos_label = pos_label self.sparse_output = sparse_output - self._classes_ = None + self.classes_ = None - @with_cupy_rmm - def fit(self, y): + def fit(self, y) -> "LabelBinarizer": """ Fit label binarizer @@ -189,8 +194,6 @@ def fit(self, y): self : returns an instance of self. """ - self._set_output_type(y) - if y.ndim > 2: raise ValueError("labels cannot be greater than 2 dimensions") @@ -200,16 +203,15 @@ def fit(self, y): if unique_classes != [0, 1]: raise ValueError("2-d array can must be binary") - self._classes_ = CumlArray(cp.arange(0, y.shape[1])) + self.classes_ = cp.arange(0, y.shape[1]) else: - self._classes_ = CumlArray(cp.unique(y).astype(y.dtype)) + self.classes_ = cp.unique(y).astype(y.dtype) cp.cuda.Stream.null.synchronize() return self - @with_cupy_rmm - def fit_transform(self, y): + def fit_transform(self, y) -> SparseCumlArray: """ Fit label binarizer and transform multi-class labels to their dummy-encoded representation. @@ -225,7 +227,7 @@ def fit_transform(self, y): """ return self.fit(y).transform(y) - def transform(self, y): + def transform(self, y) -> SparseCumlArray: """ Transform multi-class labels to their dummy-encoded representation labels. @@ -238,12 +240,12 @@ def transform(self, y): ------- arr : array with encoded labels """ - return label_binarize(y, self._classes_, + return label_binarize(y, self.classes_, pos_label=self.pos_label, neg_label=self.neg_label, sparse_output=self.sparse_output) - def inverse_transform(self, y, threshold=None): + def inverse_transform(self, y, threshold=None) -> CumlArray: """ Transform binary labels back to original multi-class labels @@ -267,18 +269,15 @@ def inverse_transform(self, y, threshold=None): # If we are already given multi-class, just return it. if cupyx.scipy.sparse.isspmatrix(y): - y_mapped = y.tocsr().indices.astype(self._classes_.dtype) + y_mapped = y.tocsr().indices.astype(self.classes_.dtype) elif scipy_sparse_isspmatrix(y): y = y.tocsr() - y_mapped = rmm_cupy_ary(cp.array, y.indices, - dtype=y.indices.dtype) + y_mapped = cp.array(y.indices, dtype=y.indices.dtype) else: - y_mapped = rmm_cupy_ary(cp.argmax, - rmm_cupy_ary(cp.asarray, y, - dtype=y.dtype), - axis=1).astype(y.dtype) + y_mapped = cp.argmax(cp.asarray(y, dtype=y.dtype), + axis=1).astype(y.dtype) - return invert_labels(y_mapped, self._classes_) + return invert_labels(y_mapped, self.classes_) def get_param_names(self): return super().get_param_names() + [ diff --git a/python/cuml/preprocessing/model_selection.py b/python/cuml/preprocessing/model_selection.py index f3a94140cc..0ec3359eba 100644 --- a/python/cuml/preprocessing/model_selection.py +++ b/python/cuml/preprocessing/model_selection.py @@ -21,7 +21,6 @@ import warnings from cuml.common.memory_utils import _strides_to_order -from cuml.common.memory_utils import rmm_cupy_ary from numba import cuda from typing import Union @@ -401,11 +400,11 @@ def train_test_split(X, if shuffle: # Shuffle the data if random_state is None or isinstance(random_state, int): - idxs = rmm_cupy_ary(cp.arange, X.shape[0]) + idxs = cp.arange(X.shape[0]) random_state = cp.random.RandomState(seed=random_state) elif isinstance(random_state, cp.random.RandomState): - idxs = rmm_cupy_ary(cp.arange, X.shape[0]) + idxs = cp.arange(X.shape[0]) elif isinstance(random_state, np.random.RandomState): idxs = np.arange(X.shape[0]) diff --git a/python/cuml/prims/label/classlabels.py b/python/cuml/prims/label/classlabels.py index 591b68f9a2..a901754dd0 100644 --- a/python/cuml/prims/label/classlabels.py +++ b/python/cuml/prims/label/classlabels.py @@ -14,10 +14,14 @@ # limitations under the License. # +import typing import cupy as cp import math +from cuml.common.input_utils import input_to_cupy_array + +import cuml.internals +from cuml.common.array import CumlArray -from cuml.common.memory_utils import rmm_cupy_ary from cuml.common.kernel_utils import cuda_kernel_factory @@ -109,8 +113,11 @@ def _validate_kernel(dtype): "validate_labels_kernel") -def make_monotonic(labels, classes=None, copy=False): - +@cuml.internals.api_return_generic(input_arg="labels", + get_output_type=True) +def make_monotonic(labels, + classes=None, + copy=False) -> typing.Tuple[CumlArray, CumlArray]: """ Takes a set of labels that might not be drawn from the set [0, n-1] and renumbers them to be drawn that @@ -133,17 +140,15 @@ def make_monotonic(labels, classes=None, copy=False): mapped_labels : array-like of size (n,) classes : array-like of size (n_classes,) """ - - labels = rmm_cupy_ary(cp.asarray, labels, dtype=labels.dtype) - - if copy: - labels = labels.copy() + labels = input_to_cupy_array(labels, deepcopy=copy).array if labels.ndim != 1: raise ValueError("Labels array must be 1D") if classes is None: - classes = rmm_cupy_ary(cp.unique, labels) + classes = cp.unique(labels) + else: + classes = input_to_cupy_array(classes).array smem = labels.dtype.itemsize * int(classes.shape[0]) @@ -158,7 +163,8 @@ def make_monotonic(labels, classes=None, copy=False): return labels, classes -def check_labels(labels, classes): +@cuml.internals.api_return_any() +def check_labels(labels, classes) -> bool: """ Validates that a set of labels is drawn from the unique set of given classes. @@ -176,13 +182,13 @@ def check_labels(labels, classes): result : boolean """ + labels = input_to_cupy_array(labels, order="K").array + classes = input_to_cupy_array(classes, order="K").array + if labels.dtype != classes.dtype: raise ValueError("Labels and classes must have same dtype (%s != %s" % (labels.dtype, classes.dtype)) - labels = rmm_cupy_ary(cp.asarray, labels, dtype=labels.dtype) - classes = rmm_cupy_ary(cp.asarray, classes, dtype=classes.dtype) - if labels.ndim != 1: raise ValueError("Labels array must be 1D") @@ -198,7 +204,9 @@ def check_labels(labels, classes): return valid[0] == 1 -def invert_labels(labels, classes, copy=False): +@cuml.internals.api_return_array(input_arg="labels", + get_output_type=True) +def invert_labels(labels, classes, copy=False) -> CumlArray: """ Takes a set of labels that have been mapped to be drawn from a monotonically increasing set and inverts them to @@ -221,16 +229,12 @@ def invert_labels(labels, classes, copy=False): inverted labels : array-like of size (n,) """ + labels = input_to_cupy_array(labels, deepcopy=copy).array + classes = input_to_cupy_array(classes).array if labels.dtype != classes.dtype: raise ValueError("Labels and classes must have same dtype (%s != %s" % (labels.dtype, classes.dtype)) - labels = rmm_cupy_ary(cp.asarray, labels, dtype=labels.dtype) - classes = rmm_cupy_ary(cp.asarray, classes, dtype=classes.dtype) - - if copy: - labels = labels.copy() - smem = labels.dtype.itemsize * len(classes) inverse_map = _inverse_map_kernel(labels.dtype) inverse_map((math.ceil(len(labels) / 32),), (32,), diff --git a/python/cuml/prims/stats/covariance.py b/python/cuml/prims/stats/covariance.py index f94f51f17d..10b53d5456 100644 --- a/python/cuml/prims/stats/covariance.py +++ b/python/cuml/prims/stats/covariance.py @@ -18,7 +18,7 @@ import cupyx import math -from cuml.common.memory_utils import with_cupy_rmm +import cuml.internals from cuml.common.kernel_utils import cuda_kernel_factory cov_kernel_str = r''' @@ -41,7 +41,7 @@ def _cov_kernel(dtype): "cov_kernel") -@with_cupy_rmm +@cuml.internals.api_return_any() def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False): """ diff --git a/python/cuml/random_projection/random_projection.pyx b/python/cuml/random_projection/random_projection.pyx index 7bf26afb7d..4d8381024f 100644 --- a/python/cuml/random_projection/random_projection.pyx +++ b/python/cuml/random_projection/random_projection.pyx @@ -22,6 +22,7 @@ import numpy as np from libc.stdint cimport uintptr_t from libcpp cimport bool +import cuml.internals from cuml.common.array import CumlArray from cuml.common.base import Base from cuml.raft.common.handle cimport * @@ -224,6 +225,7 @@ cdef class BaseRandomProjection(): def density(self, value): self.params.density = value + @cuml.internals.api_base_return_any() def fit(self, X, y=None): """ Fit the model. This function generates the random matrix on GPU. @@ -241,8 +243,6 @@ cdef class BaseRandomProjection(): generated random matrix as attributes """ - self._set_base_attributes(output_type=X, n_features=X) - _, n_samples, n_features, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) @@ -259,6 +259,7 @@ cdef class BaseRandomProjection(): return self + @cuml.internals.api_base_return_array() def transform(self, X, convert_dtype=True): """ Apply transformation on provided data. This function outputs @@ -283,9 +284,6 @@ cdef class BaseRandomProjection(): Result of multiplication between input matrix and random matrix """ - - out_type = self._get_output_type(X) - X_m, n_samples, n_features, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype @@ -319,8 +317,9 @@ cdef class BaseRandomProjection(): self.handle.sync() - return X_new.to_output(out_type) + return X_new + @cuml.internals.api_base_return_array(get_output_type=False) def fit_transform(self, X, convert_dtype=True): return self.fit(X).transform(X, convert_dtype) diff --git a/python/cuml/solvers/cd.pyx b/python/cuml/solvers/cd.pyx index 38a36c553e..f9e6094f21 100644 --- a/python/cuml/solvers/cd.pyx +++ b/python/cuml/solvers/cd.pyx @@ -25,15 +25,13 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free -from cuml.common.array import CumlArray +from cuml.common import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import Base from cuml.common.doc_utils import generate_docstring from cuml.raft.common.handle cimport handle_t -from cuml.common import get_cudf_column_ptr -from cuml.common import get_dev_array_ptr -from cuml.common import input_to_dev_array -from cuml.common import zeros from cuml.common.input_utils import input_to_cuml_array +from cuml.common.memory_utils import with_cupy_rmm cdef extern from "cuml/solvers/solver.hpp" namespace "ML::Solver": @@ -187,6 +185,8 @@ class CD(Base): """ + coef_ = CumlArrayDescriptor() + def __init__(self, loss='squared_loss', alpha=0.0001, l1_ratio=0.15, fit_intercept=True, normalize=False, max_iter=1000, tol=1e-3, shuffle=True, handle=None, output_type=None, verbose=False): @@ -207,7 +207,7 @@ class CD(Base): self.tol = tol self.shuffle = shuffle self.intercept_value = 0.0 - self._coef_ = None # accessed via estimator.coef_ + self.coef_ = None self.intercept_ = None def _check_alpha(self, alpha): @@ -222,14 +222,12 @@ class CD(Base): }[self.loss] @generate_docstring() - def fit(self, X, y, convert_dtype=False): + def fit(self, X, y, convert_dtype=False) -> "CD": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X) - X_m, n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) @@ -244,8 +242,8 @@ class CD(Base): self.n_alpha = 1 - self._coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) - cdef uintptr_t coef_ptr = self._coef_.ptr + self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) + cdef uintptr_t coef_ptr = self.coef_.ptr cdef float c_intercept1 cdef double c_intercept2 @@ -296,13 +294,11 @@ class CD(Base): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=False): + def predict(self, X, convert_dtype=False) -> CumlArray: """ Predicts the y for X. """ - out_type = self._get_output_type(X) - X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype @@ -310,7 +306,7 @@ class CD(Base): check_cols=self.n_cols) cdef uintptr_t X_ptr = X_m.ptr - cdef uintptr_t coef_ptr = self._coef_.ptr + cdef uintptr_t coef_ptr = self.coef_.ptr preds = CumlArray.zeros(n_rows, dtype=self.dtype) cdef uintptr_t preds_ptr = preds.ptr @@ -340,7 +336,7 @@ class CD(Base): del(X_m) - return preds.to_output(out_type) + return preds def get_param_names(self): return super().get_param_names() + [ diff --git a/python/cuml/solvers/cd_mg.pyx b/python/cuml/solvers/cd_mg.pyx index 8f072e58b3..b8621b94c4 100644 --- a/python/cuml/solvers/cd_mg.pyx +++ b/python/cuml/solvers/cd_mg.pyx @@ -24,6 +24,7 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t, uint32_t, uint64_t from cython.operator cimport dereference as deref +import cuml.internals from cuml.common.base import Base from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t @@ -74,6 +75,7 @@ class CDMG(MGFitMixin, CD): def __init__(self, **kwargs): super(CDMG, self).__init__(**kwargs) + @cuml.internals.api_base_return_any_skipall def _fit(self, X, y, coef_ptr, input_desc): cdef float float_intercept diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx index b3a8c8c786..684bd89a03 100644 --- a/python/cuml/solvers/qn.pyx +++ b/python/cuml/solvers/qn.pyx @@ -22,8 +22,10 @@ import numpy as np from libcpp cimport bool from libc.stdint cimport uintptr_t +import cuml.internals from cuml.common.array import CumlArray from cuml.common.base import Base +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.doc_utils import generate_docstring from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array @@ -245,6 +247,9 @@ class QN(Base): """ + _coef_ = CumlArrayDescriptor() + intercept_ = CumlArrayDescriptor() + def __init__(self, loss='sigmoid', fit_intercept=True, l1_strength=0.0, l2_strength=0.0, max_iter=1000, tol=1e-3, linesearch_max_iter=50, lbfgs_memory=5, @@ -261,7 +266,8 @@ class QN(Base): self.linesearch_max_iter = linesearch_max_iter self.lbfgs_memory = lbfgs_memory self.num_iter = 0 - self._coef_ = None # accessed via coef_ + self._coef_ = None + self.intercept_ = None if loss not in ['sigmoid', 'softmax', 'normal']: raise ValueError("loss " + str(loss) + " not supported.") @@ -275,15 +281,20 @@ class QN(Base): 'normal': 1 }[loss] + @property + @cuml.internals.api_base_return_array_skipall + def coef_(self): + if self.fit_intercept: + return self._coef_[0:-1] + else: + return self._coef_ + @generate_docstring() - @with_cupy_rmm - def fit(self, X, y, convert_dtype=False): + def fit(self, X, y, convert_dtype=False) -> "QN": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X) - X_m, n_rows, self.n_cols, self.dtype = input_to_cuml_array( X, order='F', check_dtype=[np.float32, np.float64] ) @@ -374,6 +385,8 @@ class QN(Base): self.num_iters = num_iters + self._calc_intercept() + self.handle.sync() del X_m @@ -381,7 +394,8 @@ class QN(Base): return self - def _decision_function(self, X, convert_dtype=False): + @cuml.internals.api_base_return_array_skipall + def _decision_function(self, X, convert_dtype=False) -> CumlArray: """ Gives confidence score for X @@ -441,6 +455,8 @@ class QN(Base): self.loss_type, scores_ptr) + self._calc_intercept() + self.handle.sync() del X_m @@ -451,14 +467,12 @@ class QN(Base): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=False): + @cuml.internals.api_base_return_array(get_output_dtype=True) + def predict(self, X, convert_dtype=False) -> CumlArray: """ Predicts the y for X. """ - out_type = self._get_output_type(X) - out_dtype = self._get_target_dtype() - X_m, n_rows, n_cols, self.dtype = input_to_cuml_array( X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype else None), @@ -496,28 +510,28 @@ class QN(Base): self.loss_type, pred_ptr) + self._calc_intercept() + self.handle.sync() del X_m - return preds.to_output(output_type=out_type, output_dtype=out_dtype) + return preds def score(self, X, y): return accuracy_score(y, self.predict(X)) - def __getattr__(self, attr): - if attr == 'intercept_': - if self.fit_intercept: - return self._coef_[-1].to_output(self.output_type) - else: - return CumlArray.zeros(shape=1) - elif attr == 'coef_': - if self.fit_intercept: - return self._coef_[0:-1].to_output(self.output_type) - else: - return self._coef_.to_output(self.output_type) + def _calc_intercept(self): + """ + If `fit_intercept == True`, then the last row of `coef_` contains + `intercept_`. This should be called after every function that sets + `coef_` + """ + + if (self.fit_intercept): + self.intercept_ = self._coef_[-1] else: - return super().__getattr__(attr) + self.intercept_ = CumlArray.zeros(shape=1) def get_param_names(self): return super().get_param_names() + \ diff --git a/python/cuml/solvers/sgd.pyx b/python/cuml/solvers/sgd.pyx index f70f14922f..091fada5ed 100644 --- a/python/cuml/solvers/sgd.pyx +++ b/python/cuml/solvers/sgd.pyx @@ -15,6 +15,8 @@ # distutils: language = c++ +import typing + import ctypes import cudf import numpy as np @@ -26,8 +28,10 @@ from libcpp cimport bool from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free +import cuml.internals from cuml.common.base import Base from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.doc_utils import generate_docstring from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array, with_cupy_rmm @@ -217,6 +221,9 @@ class SGD(Base): """ + coef_ = CumlArrayDescriptor() + classes_ = CumlArrayDescriptor() + def __init__(self, loss='squared_loss', penalty='none', alpha=0.0001, l1_ratio=0.15, fit_intercept=True, epochs=1000, tol=1e-3, shuffle=True, learning_rate='constant', eta0=0.001, @@ -278,7 +285,7 @@ class SGD(Base): self.batch_size = batch_size self.n_iter_no_change = n_iter_no_change self.intercept_value = 0.0 - self._coef_ = None # accessed via coef_ + self.coef_ = None self.intercept_ = None def _check_alpha(self, alpha): @@ -303,14 +310,12 @@ class SGD(Base): }[self.penalty] @generate_docstring() - @with_cupy_rmm - def fit(self, X, y, convert_dtype=False): + @cuml.internals.api_base_return_any(set_output_dtype=True) + def fit(self, X, y, convert_dtype=False) -> "SGD": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X, target_dtype=y) - X_m, n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) @@ -322,16 +327,16 @@ class SGD(Base): _estimator_type = getattr(self, '_estimator_type', None) if _estimator_type == "classifier": - self._classes_ = CumlArray(cp.unique(y_m)) + self.classes_ = cp.unique(y_m) cdef uintptr_t X_ptr = X_m.ptr cdef uintptr_t y_ptr = y_m.ptr self.n_alpha = 1 - self._coef_ = CumlArray.zeros(self.n_cols, - dtype=self.dtype) - cdef uintptr_t coef_ptr = self._coef_.ptr + self.coef_ = CumlArray.zeros(self.n_cols, + dtype=self.dtype) + cdef uintptr_t coef_ptr = self.coef_.ptr cdef float c_intercept1 cdef double c_intercept2 @@ -395,13 +400,11 @@ class SGD(Base): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=False): + def predict(self, X, convert_dtype=False) -> CumlArray: """ Predicts the y for X. """ - output_type = self._get_output_type(X) - X_m, n_rows, n_cols, self.dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype @@ -410,7 +413,7 @@ class SGD(Base): cdef uintptr_t X_ptr = X_m.ptr - cdef uintptr_t coef_ptr = self._coef_.ptr + cdef uintptr_t coef_ptr = self.coef_.ptr preds = CumlArray.zeros(n_rows, dtype=self.dtype) cdef uintptr_t preds_ptr = preds.ptr @@ -439,20 +442,18 @@ class SGD(Base): del(X_m) - return preds.to_output(output_type) + return preds @generate_docstring(return_values={'name': 'preds', 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predictClass(self, X, convert_dtype=False): + @cuml.internals.api_base_return_array(get_output_dtype=True) + def predictClass(self, X, convert_dtype=False) -> CumlArray: """ Predicts the y for X. """ - output_type = self._get_output_type(X) - out_dtype = self._get_target_dtype() - X_m, n_rows, n_cols, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype @@ -460,7 +461,7 @@ class SGD(Base): check_cols=self.n_cols) cdef uintptr_t X_ptr = X_m.ptr - cdef uintptr_t coef_ptr = self._coef_.ptr + cdef uintptr_t coef_ptr = self.coef_.ptr preds = CumlArray.zeros(n_rows, dtype=dtype) cdef uintptr_t preds_ptr = preds.ptr cdef handle_t* handle_ = self.handle.getHandle() @@ -488,7 +489,7 @@ class SGD(Base): del(X_m) - return preds.to_output(output_type=output_type, output_dtype=out_dtype) + return preds def get_param_names(self): return super().get_param_names() + [ diff --git a/python/cuml/svm/svc.pyx b/python/cuml/svm/svc.pyx index 8e35acd84f..cfc3d17536 100644 --- a/python/cuml/svm/svc.pyx +++ b/python/cuml/svm/svc.pyx @@ -15,6 +15,8 @@ # distutils: language = c++ +import typing + import ctypes import cudf import cupy as cp @@ -25,12 +27,15 @@ from numba import cuda from cython.operator cimport dereference as deref from libc.stdint cimport uintptr_t +import cuml.internals from cuml.common.array import CumlArray from cuml.common.base import Base, ClassifierMixin from cuml.common.doc_utils import generate_docstring +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.logger import warn from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array, input_to_host_array, with_cupy_rmm +from cuml.common.input_utils import input_to_cupy_array from cuml.preprocessing import LabelEncoder from cuml.common.memory_utils import using_output_type from libcpp cimport bool, nullptr @@ -101,24 +106,6 @@ cdef extern from "cuml/svm/svc.hpp" namespace "ML::SVM": svmModel[math_t] &m) except + -def _to_output(X, out_type): - """ Convert array X to out_type. - - X can be host (numpy) array. - - Arguments: - X: cuDF.DataFrame, cuDF.Series, numba array, NumPy array or any - cuda_array_interface compliant array like CuPy or pytorch. - - out_type: string (as defined by the CumlArray's to_output method). - """ - if out_type == 'numpy' and isinstance(X, np.ndarray): - return X - else: - X, _, _, _ = input_to_cuml_array(X) - return X.to_output(output_type=out_type) - - class SVC(SVMBase, ClassifierMixin): """ SVC (C-Support Vector Classification) @@ -253,6 +240,7 @@ class SVC(SVMBase, ClassifierMixin): `_ """ + def __init__(self, handle=None, C=1, kernel='rbf', degree=3, gamma='scale', coef0=0.0, tol=1e-3, cache_size=200.0, max_iter=-1, nochange_steps=1000, verbose=False, @@ -269,14 +257,15 @@ class SVC(SVMBase, ClassifierMixin): self.svmType = C_SVC @property + @cuml.internals.api_base_return_array_skipall def classes_(self): if self.probability: return self.prob_svc.classes_ else: - return self.unique_labels + return self._unique_labels_ - @with_cupy_rmm - def _apply_class_weight(self, sample_weight, y_m): + @cuml.internals.api_base_return_array_skipall + def _apply_class_weight(self, sample_weight, y_m) -> CumlArray: """ Scale the sample weights with the class weights. @@ -299,7 +288,9 @@ class SVC(SVMBase, ClassifierMixin): if self.class_weight is None: return sample_weight - le = LabelEncoder() + le = LabelEncoder(handle=self.handle, + verbose=self.verbose, + output_type=self.output_type) labels = y_m.to_output(output_type='series') encoded_labels = cp.asarray(le.fit_transform(labels)) @@ -320,10 +311,9 @@ class SVC(SVMBase, ClassifierMixin): if sample_weight is None: sample_weight = cp.ones(y_m.shape, dtype=self.dtype) else: - sample_weight_m, _, _, _ = \ - input_to_cuml_array(sample_weight, convert_to_dtype=self.dtype, + sample_weight, _, _, _ = \ + input_to_cupy_array(sample_weight, convert_to_dtype=self.dtype, check_rows=self.n_rows, check_cols=1) - sample_weight = sample_weight_m.to_output(output_type='cupy') for label, weight in class_weight.items(): sample_weight[encoded_labels==label] *= weight @@ -331,30 +321,36 @@ class SVC(SVMBase, ClassifierMixin): return sample_weight @generate_docstring(y='dense_anydtype') - @with_cupy_rmm - def fit(self, X, y, sample_weight=None, convert_dtype=True): + @cuml.internals.api_base_return_any(set_output_dtype=True) + def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "SVC": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X, target_dtype=y, n_features=X) if self.probability: params = self.get_params() params["probability"] = False + + # Ensure it always outputs numpy + params["output_type"] = "numpy" + # Currently CalibratedClassifierCV expects data on the host, see # https://github.com/rapidsai/cuml/issues/2608 X, _, _, _, _ = input_to_host_array(X) y, _, _, _, _ = input_to_host_array(y) - with using_output_type('numpy'): - if not has_sklearn(): - raise RuntimeError( - "Scikit-learn is needed to use SVM probabilities") - self.prob_svc = CalibratedClassifierCV(SVC(**params), cv=5, - method='sigmoid') + if not has_sklearn(): + raise RuntimeError( + "Scikit-learn is needed to use SVM probabilities") + + self.prob_svc = CalibratedClassifierCV(SVC(**params), + cv=5, + method='sigmoid') + + with cuml.internals.exit_internal_api(): self.prob_svc.fit(X, y) - self._fit_status_ = 0 + self._fit_status_ = 0 return self X_m, self.n_rows, self.n_cols, self.dtype = \ @@ -379,7 +375,7 @@ class SVC(SVMBase, ClassifierMixin): sample_weight_ptr = sample_weight_m.ptr self._dealloc() # delete any previously fitted model - self._coef_ = None + self.coef_ = None cdef KernelParams _kernel_params = self._get_kernel_params(X_m) cdef svmParameter param = self._get_svm_params() @@ -415,7 +411,7 @@ class SVC(SVMBase, ClassifierMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=True): + def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts the class labels for X. The returned y values are the class labels associated to sign(decision_function(X)). @@ -423,11 +419,13 @@ class SVC(SVMBase, ClassifierMixin): if self.probability: self._check_is_fitted('prob_svc') - out_type = self._get_output_type(X) + X, _, _, _, _ = input_to_host_array(X) - preds = self.prob_svc.predict(X) - # prob_svc has numpy output type, change it if it is necessary: - return _to_output(preds, out_type) + + with cuml.internals.exit_internal_api(): + preds = self.prob_svc.predict(X) + # prob_svc has numpy output type, change it if it is necessary: + return preds else: return super(SVC, self).predict(X, True, convert_dtype) @@ -437,7 +435,7 @@ class SVC(SVMBase, ClassifierMixin): 'description': 'Predicted \ probabilities', 'shape': '(n_samples, n_classes)'}) - def predict_proba(self, X, log=False): + def predict_proba(self, X, log=False) -> CumlArray: """ Predicts the class probabilities for X. @@ -452,13 +450,17 @@ class SVC(SVMBase, ClassifierMixin): if self.probability: self._check_is_fitted('prob_svc') - out_type = self._get_output_type(X) + X, _, _, _, _ = input_to_host_array(X) - preds = self.prob_svc.predict_proba(X) - if (log): - preds = np.log(preds) - # prob_svc has numpy output type, change it if it is necessary: - return _to_output(preds, out_type) + + # Exit the internal API when calling sklearn code (forces numpy + # conversion) + with cuml.internals.exit_internal_api(): + preds = self.prob_svc.predict_proba(X) + if (log): + preds = np.log(preds) + # prob_svc has numpy output type, change it if it is necessary: + return preds else: raise AttributeError("This classifier is not fitted to predict " "probabilities. Fit a new classifier with " @@ -469,7 +471,8 @@ class SVC(SVMBase, ClassifierMixin): 'description': 'Log of predicted \ probabilities', 'shape': '(n_samples, n_classes)'}) - def predict_log_proba(self, X): + @cuml.internals.api_base_return_array_skipall + def predict_log_proba(self, X) -> CumlArray: """ Predicts the log probabilities for X (returns log(predict_proba(x)). @@ -483,14 +486,13 @@ class SVC(SVMBase, ClassifierMixin): 'description': 'Decision function \ values', 'shape': '(n_samples, 1)'}) - def decision_function(self, X): + def decision_function(self, X) -> CumlArray: """ Calculates the decision function values for X. """ if self.probability: self._check_is_fitted('prob_svc') - out_type = self._get_output_type(X) # Probabilistic SVC is an ensemble of simple SVC classifiers # fitted to different subset of the training data. As such, it # does not have a single decision function. (During prediction @@ -499,13 +501,14 @@ class SVC(SVMBase, ClassifierMixin): # be useful for visualization, but predictions should be made # using the probabilities. df = np.zeros((X.shape[0],)) - with using_output_type('numpy'): + + with cuml.internals.exit_internal_api(): for clf in self.prob_svc.calibrated_classifiers_: df = df + clf.base_estimator.decision_function(X) df = df / len(self.prob_svc.calibrated_classifiers_) - return _to_output(df, out_type) + return df else: - return super(SVC, self).predict(X, False) + return super().predict(X, False) def get_param_names(self): params = super().get_param_names() + \ diff --git a/python/cuml/svm/svm_base.pyx b/python/cuml/svm/svm_base.pyx index 142045db99..bdd38d2300 100644 --- a/python/cuml/svm/svm_base.pyx +++ b/python/cuml/svm/svm_base.pyx @@ -25,11 +25,14 @@ from numba import cuda from cython.operator cimport dereference as deref from libc.stdint cimport uintptr_t +import cuml.internals from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import Base from cuml.common.exceptions import NotFittedError from cuml.raft.common.handle cimport handle_t from cuml.common import input_to_cuml_array +from cuml.common import using_output_type from libcpp cimport bool cdef extern from "cuml/matrix/kernelparams.h" namespace "MLCommon::Matrix": @@ -196,6 +199,14 @@ class SVMBase(Base): `_ """ + + dual_coef_ = CumlArrayDescriptor() + support_ = CumlArrayDescriptor() + support_vectors_ = CumlArrayDescriptor() + intercept_ = CumlArrayDescriptor() + _internal_coef_ = CumlArrayDescriptor() + _unique_labels_ = CumlArrayDescriptor() + def __init__(self, handle=None, C=1, kernel='rbf', degree=3, gamma='auto', coef0=0.0, tol=1e-3, cache_size=200.0, max_iter=-1, nochange_steps=1000, verbose=False, @@ -220,15 +231,15 @@ class SVMBase(Base): self._fit_status_ = -1 # Attributes (parameters of the fitted model) - self._dual_coef_ = None - self._support_ = None - self._support_vectors_ = None - self._intercept_ = None - self._n_support_ = None + self.dual_coef_ = None + self.support_ = None + self.support_vectors_ = None + self.intercept_ = None + self.n_support_ = None self._c_kernel = self._get_c_kernel(kernel) self._gamma_val = None # the actual numerical value used for training - self._coef_ = None # value of the coef_ attribute, only for lin kernel + self.coef_ = None # value of the coef_ attribute, only for lin kernel self.dtype = None self._model = None # structure of the model parameters self._freeSvmBuffers = False # whether to call the C++ lib for cleanup @@ -298,8 +309,8 @@ class SVMBase(Base): return self.gamma def _calc_coef(self): - return cupy.dot(cupy.asarray(self._dual_coef_), - cupy.asarray(self._support_vectors_)) + with using_output_type("cupy"): + return cupy.dot(self.dual_coef_, self.support_vectors_) def _check_is_fitted(self, attr): if not hasattr(self, attr) or (getattr(self, attr) is None): @@ -308,15 +319,20 @@ class SVMBase(Base): raise NotFittedError(msg) @property + @cuml.internals.api_base_return_array_skipall def coef_(self): if self._c_kernel != LINEAR: - raise AttributeError("coef_ is only available for linear kernels") + raise RuntimeError("coef_ is only available for linear kernels") if self._model is None: raise RuntimeError("Call fit before prediction") - if self._coef_ is None: - self._coef_ = CumlArray(self._calc_coef()) - # Call the base class to perform the to_output conversion - return super().__getattr__("coef_") + if self._internal_coef_ is None: + self._internal_coef_ = self._calc_coef() + # Call the base class to perform the output conversion + return self._internal_coef_ + + @coef_.setter + def coef_(self, value): + self._internal_coef_ = value def _get_kernel_params(self, X=None): """ Wrap the kernel parameters in a KernelParams obtect """ @@ -342,6 +358,7 @@ class SVMBase(Base): param.svmType = self.svmType return param + @cuml.internals.api_base_return_any_skipall def _get_svm_model(self): """ Wrap the fitted model parameters into an svmModel structure. This is used if the model is loaded by pickle, the self._model struct @@ -349,42 +366,42 @@ class SVMBase(Base): """ cdef svmModel[float] *model_f cdef svmModel[double] *model_d - if self._dual_coef_ is None: + if self.dual_coef_ is None: # the model is not fitted in this case return None if self.dtype == np.float32: model_f = new svmModel[float]() - model_f.n_support = self._n_support_ + model_f.n_support = self.n_support_ model_f.n_cols = self.n_cols - model_f.b = self._intercept_ + model_f.b = self.intercept_.item() model_f.dual_coefs = \ - self._dual_coef_.ptr + self.dual_coef_.ptr model_f.x_support = \ - self._support_vectors_.ptr + self.support_vectors_.ptr model_f.support_idx = \ - self._support_.ptr + self.support_.ptr model_f.n_classes = self._n_classes if self._n_classes > 0: model_f.unique_labels = \ - self._unique_labels.ptr + self._unique_labels_.ptr else: model_f.unique_labels = NULL return model_f else: model_d = new svmModel[double]() - model_d.n_support = self._n_support_ + model_d.n_support = self.n_support_ model_d.n_cols = self.n_cols - model_d.b = self._intercept_ + model_d.b = self.intercept_.item() model_d.dual_coefs = \ - self._dual_coef_.ptr + self.dual_coef_.ptr model_d.x_support = \ - self._support_vectors_.ptr + self.support_vectors_.ptr model_d.support_idx = \ - self._support_.ptr + self.support_.ptr model_d.n_classes = self._n_classes if self._n_classes > 0: model_d.unique_labels = \ - self._unique_labels.ptr + self._unique_labels_.ptr else: model_d.unique_labels = NULL return model_d @@ -404,71 +421,71 @@ class SVMBase(Base): if model_f.n_support == 0: self._fit_status_ = 1 # incorrect fit return - self._intercept_ = model_f.b - self._n_support_ = model_f.n_support + self.intercept_ = CumlArray.full(1, model_f.b, np.float32) + self.n_support_ = model_f.n_support - self._dual_coef_ = CumlArray( + self.dual_coef_ = CumlArray( data=model_f.dual_coefs, - shape=(1, self._n_support_), + shape=(1, self.n_support_), dtype=self.dtype, order='F') - self._support_ = CumlArray( + self.support_ = CumlArray( data=model_f.support_idx, - shape=(self._n_support_,), + shape=(self.n_support_,), dtype=np.int32, order='F') - self._support_vectors_ = CumlArray( + self.support_vectors_ = CumlArray( data=model_f.x_support, - shape=(self._n_support_, self.n_cols), + shape=(self.n_support_, self.n_cols), dtype=self.dtype, order='F') self._n_classes = model_f.n_classes if self._n_classes > 0: - self._unique_labels = CumlArray( + self._unique_labels_ = CumlArray( data=model_f.unique_labels, shape=(self._n_classes,), dtype=self.dtype, order='F') else: - self._unique_labels = None + self._unique_labels_ = None else: model_d = self._model if model_d.n_support == 0: self._fit_status_ = 1 # incorrect fit return - self._intercept_ = model_d.b - self._n_support_ = model_d.n_support + self.intercept_ = CumlArray.full(1, model_d.b, np.float64) + self.n_support_ = model_d.n_support - self._dual_coef_ = CumlArray( + self.dual_coef_ = CumlArray( data=model_d.dual_coefs, - shape=(1, self._n_support_), + shape=(1, self.n_support_), dtype=self.dtype, order='F') - self._support_ = CumlArray( + self.support_ = CumlArray( data=model_d.support_idx, - shape=(self._n_support_,), + shape=(self.n_support_,), dtype=np.int32, order='F') - self._support_vectors_ = CumlArray( + self.support_vectors_ = CumlArray( data=model_d.x_support, - shape=(self._n_support_, self.n_cols), + shape=(self.n_support_, self.n_cols), dtype=self.dtype, order='F') self._n_classes = model_d.n_classes if self._n_classes > 0: - self._unique_labels = CumlArray( + self._unique_labels_ = CumlArray( data=model_d.unique_labels, shape=(self._n_classes,), dtype=self.dtype, order='F') else: - self._unique_labels = None + self._unique_labels_ = None - def predict(self, X, predict_class, convert_dtype=True): + def predict(self, X, predict_class, convert_dtype=True) -> CumlArray: """ Predicts the y for X, where y is either the decision function value (if predict_class == False), or the label associated with X. @@ -489,12 +506,13 @@ class SVMBase(Base): y : cuDF Series Dense vector (floats or doubles) of shape (n_samples, 1) """ - out_type = self._get_output_type(X) if predict_class: out_dtype = self._get_target_dtype() else: out_dtype = self.dtype + cuml.internals.set_api_output_dtype(out_dtype) + self._check_is_fitted('_model') X_m, n_rows, n_cols, pred_dtype = \ @@ -528,7 +546,7 @@ class SVMBase(Base): del(X_m) - return preds.to_output(output_type=out_type, output_dtype=out_dtype) + return preds def get_param_names(self): return super().get_param_names() + [ diff --git a/python/cuml/svm/svr.pyx b/python/cuml/svm/svr.pyx index 06711016be..14156d245e 100644 --- a/python/cuml/svm/svr.pyx +++ b/python/cuml/svm/svr.pyx @@ -234,12 +234,11 @@ class SVR(SVMBase, RegressorMixin): self.svmType = EPSILON_SVR @generate_docstring() - def fit(self, X, y, sample_weight=None, convert_dtype=True): + def fit(self, X, y, sample_weight=None, convert_dtype=True) -> "SVR": """ Fit the model with X and y. """ - self._set_base_attributes(output_type=X, n_features=X) cdef uintptr_t X_ptr, y_ptr X_m, self.n_rows, self.n_cols, self.dtype = \ @@ -263,7 +262,7 @@ class SVR(SVMBase, RegressorMixin): sample_weight_ptr = sample_weight_m.ptr self._dealloc() # delete any previously fitted model - self._coef_ = None + self.coef_ = None cdef KernelParams _kernel_params = self._get_kernel_params(X_m) cdef svmParameter param = self._get_svm_params() @@ -299,7 +298,7 @@ class SVR(SVMBase, RegressorMixin): 'type': 'dense', 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) - def predict(self, X, convert_dtype=True): + def predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts the values for X. diff --git a/python/cuml/test/conftest.py b/python/cuml/test/conftest.py index 1f0f79fc74..e2ee68200a 100644 --- a/python/cuml/test/conftest.py +++ b/python/cuml/test/conftest.py @@ -14,192 +14,16 @@ # limitations under the License. # -import os -import sys import cupy as cp -import cupyx import pytest -from pytest import Item from sklearn.datasets import fetch_20newsgroups from sklearn.feature_extraction.text import CountVectorizer -import numbers - - -# Stores incorrect uses of CumlArray on cuml.common.base.Base to print at the -# end -bad_cuml_array_loc = set() - - -def checked_isinstance(obj, class_name_dot_separated): - """ - Small helper function to check instance of object that doesn't import - class_path at import time, only at check time. Returns False if - class_path cannot be imported. - - Parameters: - ----------- - obj: Python object - object to check if it is instance of a class - class_name_dot_separated: list of str - List of classes to check whether object is an instance of, each item - can be a full dot separated class like - 'cuml.dask.preprocessing.LabelEncoder' - """ - ret = False - for class_path in class_name_dot_separated: - module_name, class_name = class_path.rsplit(".", 1) - module = sys.modules[module_name] - module_class = getattr(module, class_name, None) - - if module_class is not None: - ret = isinstance(obj, module_class) or ret - - return ret def pytest_configure(config): cp.cuda.set_allocator(None) -# Use the runtest_makereport hook to get the result of the test. This is -# necessary because pytest has some magic to extract the Cython source file -# from the traceback -@pytest.hookimpl(hookwrapper=True) -def pytest_runtest_makereport(item: Item, call): - - # Yield to the default implementation and get the result - outcome = yield - report = outcome.get_result() - - if (report.failed): - - # Save the abs path to this file. We will only mark bad CumlArray uses - # if the assertion failure comes from this file - conf_test_path = os.path.abspath(__file__) - - found_assert = False - - # Ensure these attributes exist. They can be missing if something else - # failed outside of the test - if (hasattr(report.longrepr, "reprtraceback") - and hasattr(report.longrepr.reprtraceback, "reprentries")): - - for entry in reversed(report.longrepr.reprtraceback.reprentries): - - if (not found_assert and - entry.reprfileloc.message.startswith("AssertionError") - and os.path.abspath( - entry.reprfileloc.path) == conf_test_path): - found_assert = True - elif (found_assert): - true_path = "{}:{}".format(entry.reprfileloc.path, - entry.reprfileloc.lineno) - - bad_cuml_array_loc.add( - (true_path, entry.reprfileloc.message)) - - break - - -# Closing hook to display the file/line numbers at the end of the test -def pytest_unconfigure(config): - def split_exists(filename: str) -> bool: - strip_colon = filename[:filename.rfind(":")] - return os.path.exists(strip_colon) - - if (len(bad_cuml_array_loc) > 0): - - print("Incorrect CumlArray uses in class derived from " - "cuml.common.base.Base:") - - prefix = "" - - # Depending on where pytest was launched from, it may need to append - # "python" - if (not os.path.basename(os.path.abspath( - os.curdir)).endswith("python")): - prefix = "python" - - for location, message in bad_cuml_array_loc: - - combined_path = os.path.abspath(location) - - # Try appending prefix if that file doesnt exist - if (not split_exists(combined_path)): - combined_path = os.path.abspath(os.path.join(prefix, location)) - - # If that still doesnt exist, just use the original - if (not split_exists(combined_path)): - combined_path = location - - print("{} {}".format(combined_path, message)) - - print("See https://github.com/rapidsai/cuml/issues/2456#issuecomment-666106406" # noqa - " for more information on naming conventions") - - -# This fixture will monkeypatch cuml.common.base.Base to check for incorrect -# uses of CumlArray. -@pytest.fixture(autouse=True) -def fail_on_bad_cuml_array_name(monkeypatch, request): - - if 'no_bad_cuml_array_check' in request.keywords: - return - - from cuml.common import CumlArray - from cuml.common.base import Base - from cuml.common.input_utils import get_supported_input_type - - def patched__setattr__(self, name, value): - - if name == 'classes_' and \ - checked_isinstance(self, - ['cuml.dask.preprocessing.LabelEncoder', - 'cuml.preprocessing.LabelEncoder']): - # For label encoder, classes_ stores the set of unique classes - # which is strings, and can't be saved as cuml array - # even called `get_supported_input_type` causes a failure. - pass - else: - supported_type = get_supported_input_type(value) - - if name == 'idf_': - # We skip this test because idf_' for tfidf setter returns - # a sparse diagonal matrix and getter gets a cupy array - # see discussion at: - # https://github.com/rapidsai/cuml/pull/2698/files#r471865982 - pass - elif (supported_type == CumlArray): - assert name.startswith("_"), "Invalid CumlArray Use! CumlArray \ - attributes need a leading underscore. Attribute: '{}' In: {}" \ - .format(name, self.__repr__()) - elif (supported_type == cp.ndarray and - cupyx.scipy.sparse.issparse(value)): - # Leave sparse matrices alone for now. - pass - elif (supported_type is not None): - if not isinstance(value, numbers.Number): - # Is this an estimated property? - # If so, should always be CumlArray - assert not name.endswith("_"), "Invalid Estimated Array-Like \ - Attribute! Estimated attributes should always be \ - CumlArray. \ - Attribute: '{}' In: {}".format(name, self.__repr__()) - assert not name.startswith("_"), "Invalid Public Array-Like \ - Attribute! Public array-like attributes should always \ - be CumlArray. Attribute: '{}' In: {}".format( - name, self.__repr__()) - else: - # Estimated properties can be numbers - pass - - return super(Base, self).__setattr__(name, value) - - # Monkeypatch CumlArray.__setattr__ to test for incorrect uses of - # array-like objects - monkeypatch.setattr(Base, "__setattr__", patched__setattr__) - - @pytest.fixture(scope="module") def nlp_20news(): twenty_train = fetch_20newsgroups(subset='train', diff --git a/python/cuml/test/test_array.py b/python/cuml/test/test_array.py index 1840191c17..4d4847e779 100644 --- a/python/cuml/test/test_array.py +++ b/python/cuml/test/test_array.py @@ -281,7 +281,7 @@ def test_create_empty(shape, dtype, order): else: assert ary.shape == shape assert ary.dtype == np.dtype(dtype) - assert isinstance(ary._owner, DeviceBuffer) + assert isinstance(ary._owner.data.mem._owner, DeviceBuffer) @pytest.mark.parametrize('shape', test_shapes) diff --git a/python/cuml/test/test_cuml_descr_decor.py b/python/cuml/test/test_cuml_descr_decor.py new file mode 100644 index 0000000000..cc8945df81 --- /dev/null +++ b/python/cuml/test/test_cuml_descr_decor.py @@ -0,0 +1,349 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pickle + +import cuml +import cuml.internals +import cupy as cp +import numpy as np +import pytest +from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common.input_utils import determine_array_dtype +from cuml.common.input_utils import determine_array_type +from cuml.common.input_utils import input_to_cuml_array +from cuml.common.input_utils import unsupported_cudf_dtypes + +test_input_types = ['numpy', 'numba', 'cupy', 'cudf'] + +test_output_types_str = ['numpy', 'numba', 'cupy', 'cudf'] + +test_dtypes_short = [ + np.uint8, + np.float16, + np.int32, + np.float64, +] + +test_shapes = [10, (10, 1), (10, 5), (1, 10)] + + +class TestEstimator(cuml.Base): + + input_any_ = CumlArrayDescriptor() + + def _set_input(self, X): + self.input_any_ = X + + @cuml.internals.api_base_return_any() + def store_input(self, X): + self.input_any_ = X + + @cuml.internals.api_return_any() + def get_input(self): + return self.input_any_ + + # === Standard Functions === + def fit(self, X, convert_dtype=True) -> "TestEstimator": + + return self + + def predict(self, X, convert_dtype=True) -> CumlArray: + + return X + + def transform(self, X, convert_dtype=False) -> CumlArray: + + pass + + def fit_transform(self, X, y=None) -> CumlArray: + + return self.fit(X).transform(X) + + +def array_identical(a, b): + + cupy_a = input_to_cuml_array(a, order="K").array + cupy_b = input_to_cuml_array(b, order="K").array + + if len(a) == 0 and len(b) == 0: + return True + + if (cupy_a.shape != cupy_b.shape): + return False + + if (cupy_a.dtype != cupy_b.dtype): + return False + + if (cupy_a.order != cupy_b.order): + return False + + return cp.all(cp.asarray(cupy_a) == cp.asarray(cupy_b)).item() + + +def create_input(input_type, input_dtype, input_shape, input_order): + rand_ary = cp.ones(input_shape, dtype=input_dtype, order=input_order) + + cuml_ary = CumlArray(rand_ary) + + return cuml_ary.to_output(input_type) + + +def create_output(X_in, output_type): + + cuml_ary_tuple = input_to_cuml_array(X_in, order="K") + + return cuml_ary_tuple.array.to_output(output_type) + + +@pytest.mark.parametrize('input_type', test_input_types) +def test_pickle(input_type): + + if (input_type == "numba"): + pytest.skip("numba arrays cant be picked at this time") + + est = TestEstimator() + + X_in = create_input(input_type, np.float32, (10, 5), "C") + + est.store_input(X_in) + + # Loop and verify we have filled the cache + for out_type in test_output_types_str: + with cuml.using_output_type(out_type): + assert array_identical(est.input_any_, + create_output(X_in, out_type)) + + est_pickled_bytes = pickle.dumps(est) + est_unpickled: TestEstimator = pickle.loads(est_pickled_bytes) + + # Assert that we only resture the input + assert est_unpickled.__dict__["input_any_"].input_type == input_type + assert len(est_unpickled.__dict__["input_any_"].values) == 1 + + assert array_identical(est.get_input(), est_unpickled.get_input()) + assert array_identical(est.input_any_, est_unpickled.input_any_) + + # Loop one more time with the picked one to make sure it works right + for out_type in test_output_types_str: + with cuml.using_output_type(out_type): + assert array_identical(est.input_any_, + create_output(X_in, out_type)) + + est_unpickled.output_type = out_type + + assert array_identical(est_unpickled.input_any_, + create_output(X_in, out_type)) + + +@pytest.mark.parametrize('input_type', test_input_types) +@pytest.mark.parametrize('input_dtype', [np.float32, np.int16]) +@pytest.mark.parametrize('input_shape', [10, (10, 5)]) +@pytest.mark.parametrize('output_type', test_output_types_str) +def test_dec_input_output(input_type, input_dtype, input_shape, output_type): + + if (input_type == "cudf" or output_type == "cudf"): + if (input_dtype in unsupported_cudf_dtypes): + pytest.skip("Unsupported cudf combination") + + X_in = create_input(input_type, input_dtype, input_shape, "C") + X_out = create_output(X_in, output_type) + + # Test with output_type="input" + est = TestEstimator(output_type="input") + + est.store_input(X_in) + + # Test is was stored internally correctly + assert X_in is est.get_input() + + assert est.__dict__["input_any_"].input_type == input_type + + # Check the current type matches input type + assert determine_array_type(est.input_any_) == input_type + + assert array_identical(est.input_any_, X_in) + + # Switch output type and check type and equality + with cuml.using_output_type(output_type): + + assert determine_array_type(est.input_any_) == output_type + + assert array_identical(est.input_any_, X_out) + + # Now Test with output_type=output_type + est = TestEstimator(output_type=output_type) + + est.store_input(X_in) + + # Check the current type matches output type + assert determine_array_type(est.input_any_) == output_type + + assert array_identical(est.input_any_, X_out) + + with cuml.using_output_type("input"): + + assert determine_array_type(est.input_any_) == input_type + + assert array_identical(est.input_any_, X_in) + + +@pytest.mark.parametrize('input_type', test_input_types) +@pytest.mark.parametrize('input_dtype', [np.float32, np.int16]) +@pytest.mark.parametrize('input_shape', test_shapes) +def test_auto_fit(input_type, input_dtype, input_shape): + """ + Test autowrapping on fit that will set output_type, and n_features + """ + X_in = create_input(input_type, input_dtype, input_shape, "C") + + # Test with output_type="input" + est = TestEstimator() + + est.fit(X_in) + + def calc_n_features(shape): + if (isinstance(shape, tuple) and len(shape) >= 1): + + # When cudf and shape[1] is used, a series is created which will + # remove the last shape + if (input_type == "cudf" and shape[1] == 1): + return None + + return shape[1] + + return None + + assert est._input_type == input_type + assert est.target_dtype is None + assert est.n_features_in_ == calc_n_features(input_shape) + + +@pytest.mark.parametrize('input_type', test_input_types) +@pytest.mark.parametrize('base_output_type', test_input_types) +@pytest.mark.parametrize('global_output_type', + test_output_types_str + ["input", None]) +def test_auto_predict(input_type, base_output_type, global_output_type): + """ + Test autowrapping on predict that will set target_type + """ + X_in = create_input(input_type, np.float32, (10, 10), "F") + + # Test with output_type="input" + est = TestEstimator() + + # With cuml.global_output_type == None, this should return the input type + X_out = est.predict(X_in) + + assert determine_array_type(X_out) == input_type + + assert array_identical(X_in, X_out) + + # Test with output_type=base_output_type + est = TestEstimator(output_type=base_output_type) + + # With cuml.global_output_type == None, this should return the + # base_output_type + X_out = est.predict(X_in) + + assert determine_array_type(X_out) == base_output_type + + assert array_identical(X_in, X_out) + + # Test with global_output_type, should return global_output_type + with cuml.using_output_type(global_output_type): + X_out = est.predict(X_in) + + target_output_type = global_output_type + + if (target_output_type is None or target_output_type == "input"): + target_output_type = base_output_type + + if (target_output_type == "input"): + target_output_type = input_type + + assert determine_array_type(X_out) == target_output_type + + assert array_identical(X_in, X_out) + + +@pytest.mark.parametrize('input_arg', ["X", "y", "bad", ...]) +@pytest.mark.parametrize('target_arg', ["X", "y", "bad", ...]) +@pytest.mark.parametrize('get_output_type', [True, False]) +@pytest.mark.parametrize('get_output_dtype', [True, False]) +def test_return_array(input_arg: str, + target_arg: str, + get_output_type: bool, + get_output_dtype: bool): + """ + Test autowrapping on predict that will set target_type + """ + + input_type_X = "numpy" + input_dtype_X = np.float64 + + input_type_Y = "cupy" + input_dtype_Y = np.int32 + + inner_type = "numba" + inner_dtype = np.float16 + + X_in = create_input(input_type_X, input_dtype_X, (10, 10), "F") + Y_in = create_input(input_type_Y, input_dtype_Y, (10, 10), "F") + + def test_func(X, y): + + if (not get_output_type): + cuml.internals.set_api_output_type(inner_type) + + if (not get_output_dtype): + cuml.internals.set_api_output_dtype(inner_dtype) + + return X + + if (input_arg == "bad" or target_arg == "bad"): + pytest.xfail("Expected error with bad arg name") + + test_func = cuml.internals.api_return_array( + input_arg=input_arg, + target_arg=target_arg, + get_output_type=get_output_type, + get_output_dtype=get_output_dtype)(test_func) + + X_out = test_func(X=X_in, y=Y_in) + + target_type = None + target_dtype = None + + if (not get_output_type): + target_type = inner_type + else: + if (input_arg == "y"): + target_type = input_type_Y + else: + target_type = input_type_X + + if (not get_output_dtype): + target_dtype = inner_dtype + else: + if (target_arg == "X"): + target_dtype = input_dtype_X + else: + target_dtype = input_dtype_Y + + assert determine_array_type(X_out) == target_type + + assert determine_array_dtype(X_out) == target_dtype diff --git a/python/cuml/test/test_dbscan.py b/python/cuml/test/test_dbscan.py index bb46a4505f..e4eb6367e9 100644 --- a/python/cuml/test/test_dbscan.py +++ b/python/cuml/test/test_dbscan.py @@ -26,7 +26,6 @@ from sklearn.preprocessing import StandardScaler from sklearn.metrics import adjusted_rand_score - dataset_names = ['noisy_moons', 'varied', 'aniso', 'blobs', 'noisy_circles', 'no_structure'] diff --git a/python/cuml/test/test_fit_function.py b/python/cuml/test/test_fit_function.py index 6a0059dab5..8693e9484c 100644 --- a/python/cuml/test/test_fit_function.py +++ b/python/cuml/test/test_fit_function.py @@ -8,7 +8,11 @@ def func_positional_arg(func): - if hasattr(func, "__code__"): + + if hasattr(func, "__wrapped__"): + return func_positional_arg(func.__wrapped__) + + elif hasattr(func, "__code__"): all_args = func.__code__.co_argcount if func.__defaults__ is not None: kwargs = len(func.__defaults__) @@ -37,7 +41,8 @@ def test_fit_function(dataset, model_name): "TSNE", "TruncatedSVD", "AutoARIMA", - "MultinomialNB" + "MultinomialNB", + "LabelEncoder", ]: pytest.xfail("These models are not tested yet") diff --git a/python/cuml/test/test_input_utils.py b/python/cuml/test/test_input_utils.py index 606b3cb078..1a5446e7fb 100644 --- a/python/cuml/test/test_input_utils.py +++ b/python/cuml/test/test_input_utils.py @@ -22,7 +22,6 @@ from pandas import DataFrame as pdDF from cuml.common import input_to_cuml_array, CumlArray -from cuml.common import input_to_dev_array from cuml.common import input_to_host_array from cuml.common import has_cupy from cuml.common.input_utils import convert_dtype @@ -205,15 +204,15 @@ def test_dtype_check(dtype, check_dtype, input_type, order): pytest.skip('cupy not installed') if dtype == check_dtype: - _, _, _, _, got_dtype = \ - input_to_dev_array(input_data, check_dtype=check_dtype, - order=order) + _, _, _, got_dtype = \ + input_to_cuml_array(input_data, check_dtype=check_dtype, + order=order) assert got_dtype == check_dtype else: with pytest.raises(TypeError): - _, _, _, _, got_dtype = \ - input_to_dev_array(input_data, check_dtype=check_dtype, - order=order) + _, _, _, got_dtype = \ + input_to_cuml_array(input_data, check_dtype=check_dtype, + order=order) @pytest.mark.parametrize('num_rows', test_num_rows) diff --git a/python/cuml/test/test_kneighbors_classifier.py b/python/cuml/test/test_kneighbors_classifier.py index 8882514696..73f5198092 100644 --- a/python/cuml/test/test_kneighbors_classifier.py +++ b/python/cuml/test/test_kneighbors_classifier.py @@ -18,6 +18,7 @@ import cudf +import cuml from cuml.neighbors import KNeighborsClassifier as cuKNN from sklearn.neighbors import KNeighborsClassifier as skKNN @@ -224,10 +225,10 @@ def test_predict_non_gaussian(n_samples, n_features, n_neighbors, n_query): knn_cuml = cuKNN(n_neighbors=n_neighbors) knn_cuml.fit(X_device_train, y_device_train) - cuml_result = knn_cuml.predict(X_device_test) + with cuml.using_output_type("numpy"): + cuml_result = knn_cuml.predict(X_device_test) - assert np.array_equal( - np.asarray(cuml_result.to_gpu_array()), sk_result) + assert np.array_equal(cuml_result, sk_result) @pytest.mark.parametrize("n_classes", [2, 5]) diff --git a/python/cuml/test/test_module_config.py b/python/cuml/test/test_module_config.py index 1d2c959d0b..3bb7db4b23 100644 --- a/python/cuml/test/test_module_config.py +++ b/python/cuml/test/test_module_config.py @@ -44,6 +44,16 @@ } +@pytest.fixture(scope="function", params=global_input_configs) +def global_output_type(request): + + output_type = request.param + + yield output_type + + # Ensure we reset the type at the end of the test + cuml.set_global_output_type(None) + ############################################################################### # Tests # ############################################################################### @@ -65,30 +75,28 @@ def test_default_global_output_type(input_type): assert isinstance(res, test_output_types[input_type]) -@pytest.mark.parametrize('global_type', global_input_configs) @pytest.mark.parametrize('input_type', global_input_types) -def test_global_output_type(global_type, input_type): +def test_global_output_type(global_output_type, input_type): dataset = get_small_dataset(input_type) - cuml.set_global_output_type(global_type) + cuml.set_global_output_type(global_output_type) dbscan_float = cuml.DBSCAN(eps=1.0, min_samples=1) dbscan_float.fit(dataset) res = dbscan_float.labels_ - if global_type == 'numba': + if global_output_type == 'numba': assert is_cuda_array(res) else: - assert isinstance(res, test_output_types[global_type]) + assert isinstance(res, test_output_types[global_output_type]) -@pytest.mark.parametrize('global_type', global_input_configs) @pytest.mark.parametrize('context_type', global_input_configs) -def test_output_type_context_mgr(global_type, context_type): +def test_output_type_context_mgr(global_output_type, context_type): dataset = get_small_dataset('numba') - test_type = 'cupy' if global_type != 'cupy' else 'numpy' + test_type = 'cupy' if global_output_type != 'cupy' else 'numpy' cuml.set_global_output_type(test_type) # use cuml context manager @@ -111,9 +119,6 @@ def test_output_type_context_mgr(global_type, context_type): res = dbscan_float.labels_ assert isinstance(res, test_output_types[test_type]) - # reset cuml global output type to 'input' for further tests - cuml.set_global_output_type('input') - ############################################################################### # Utility Functions # diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index 1e1d07db18..d5bd9ea0b9 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -33,7 +33,6 @@ import sklearn import cuml from cuml.common import has_scipy -from cuml.common.array import CumlArray def predict(neigh_ind, _y, n_neighbors): @@ -147,9 +146,10 @@ def test_knn_separate_index_search(input_type, nrows, n_feats, k, metric): D_cuml_arr = D_cuml I_cuml_arr = I_cuml - # Assert the cuml model was properly reverted - np.testing.assert_allclose(knn_cu._X_m.to_output("numpy"), X_orig, - atol=1e-3, rtol=1e-3) + with cuml.using_output_type("numpy"): + # Assert the cuml model was properly reverted + np.testing.assert_allclose(knn_cu.X_m, X_orig, + atol=1e-3, rtol=1e-3) if metric == 'braycurtis': diff = D_cuml_arr - D_sk @@ -199,26 +199,6 @@ def test_knn_x_none(input_type, nrows, n_feats, k, metric): assert I_cuml.all() == I_sk.all() -@pytest.mark.parametrize('input_type', ['dataframe', 'ndarray']) -def test_knn_return_cumlarray(input_type): - n_samples = 50 - n_feats = 50 - k = 5 - - X, _ = make_blobs(n_samples=n_samples, - n_features=n_feats, random_state=0) - - if input_type == "dataframe": - X = cudf.DataFrame(X) - - knn_cu = cuKNN() - knn_cu.fit(X) - indices, distances = knn_cu._kneighbors(X, k, _output_cumlarray=True) - - assert isinstance(indices, CumlArray) - assert isinstance(distances, CumlArray) - - def test_knn_fit_twice(): """ Test that fitting a model twice does not fail. diff --git a/python/cuml/test/test_pickle.py b/python/cuml/test/test_pickle.py index 268b4a3247..af8caf400b 100644 --- a/python/cuml/test/test_pickle.py +++ b/python/cuml/test/test_pickle.py @@ -364,6 +364,7 @@ def test_unfit_clone(model_name): # Cloning runs into many of the same problems as pickling mod = all_models[model_name]() + clone(mod) # TODO: check parameters exactly? @@ -422,7 +423,7 @@ def assert_model(pickled_model, X_test): assert array_equal(result["neighbors"], D_after) state = pickled_model.__dict__ assert state["n_indices"] == 1 - assert "_X_m" in state + assert "X_m" in state pickle_save_load(tmpdir, create_mod, assert_model) @@ -449,13 +450,13 @@ def create_mod(): def assert_model(loaded_model, X): state = loaded_model.__dict__ assert state["n_indices"] == 0 - assert "_X_m" not in state + assert "X_m" not in state loaded_model.fit(X[0]) state = loaded_model.__dict__ assert state["n_indices"] == 1 - assert "_X_m" in state + assert "X_m" in state pickle_save_load(tmpdir, create_mod, assert_model) @@ -509,7 +510,7 @@ def assert_model(pickled_model, X): result["fit_model"] = pickled_model.fit(X) result["data"] = X result["trust"] = trustworthiness( - X, pickled_model._embedding_.to_output('numpy'), 10) + X, pickled_model.embedding_, 10) def create_mod_2(): model = result["fit_model"] @@ -517,7 +518,7 @@ def create_mod_2(): def assert_second_model(pickled_model, X): trust_after = trustworthiness( - X, pickled_model._embedding_.to_output('numpy'), 10) + X, pickled_model.embedding_, 10) assert result["trust"] == trust_after pickle_save_load(tmpdir, create_mod, assert_model) diff --git a/python/cuml/test/test_preproc_utils.py b/python/cuml/test/test_preproc_utils.py index a96ee71976..3d8bb5b962 100644 --- a/python/cuml/test/test_preproc_utils.py +++ b/python/cuml/test/test_preproc_utils.py @@ -16,7 +16,7 @@ import pytest from cuml.datasets import make_classification, make_blobs -from ..thirdparty_adapters import to_output_type +from cuml.thirdparty_adapters import to_output_type from numpy.testing import assert_allclose as np_assert_allclose import numpy as np diff --git a/python/cuml/test/test_preprocessing.py b/python/cuml/test/test_preprocessing.py index 6ab2eb901e..ff22ba27cc 100644 --- a/python/cuml/test/test_preprocessing.py +++ b/python/cuml/test/test_preprocessing.py @@ -15,16 +15,17 @@ import pytest -from ..experimental.preprocessing import StandardScaler as cuStandardScaler, \ - MinMaxScaler as cuMinMaxScaler, \ - MaxAbsScaler as cuMaxAbsScaler, \ - Normalizer as cuNormalizer, \ - Binarizer as cuBinarizer, \ - PolynomialFeatures as cuPolynomialFeatures, \ - SimpleImputer as cuSimpleImputer, \ - RobustScaler as cuRobustScaler, \ - KBinsDiscretizer as cuKBinsDiscretizer -from ..experimental.preprocessing import scale as cu_scale, \ +from cuml.experimental.preprocessing import \ + StandardScaler as cuStandardScaler, \ + MinMaxScaler as cuMinMaxScaler, \ + MaxAbsScaler as cuMaxAbsScaler, \ + Normalizer as cuNormalizer, \ + Binarizer as cuBinarizer, \ + PolynomialFeatures as cuPolynomialFeatures, \ + SimpleImputer as cuSimpleImputer, \ + RobustScaler as cuRobustScaler, \ + KBinsDiscretizer as cuKBinsDiscretizer +from cuml.experimental.preprocessing import scale as cu_scale, \ minmax_scale as cu_minmax_scale, \ normalize as cu_normalize, \ add_dummy_feature as cu_add_dummy_feature, \ @@ -46,18 +47,20 @@ from sklearn.impute import SimpleImputer as skSimpleImputer from sklearn.preprocessing import KBinsDiscretizer as skKBinsDiscretizer -from ..thirdparty_adapters.sparsefuncs_fast import csr_mean_variance_axis0, \ - csc_mean_variance_axis0, \ - _csc_mean_variance_axis0, \ - inplace_csr_row_normalize_l1, \ - inplace_csr_row_normalize_l2 - -from .test_preproc_utils import clf_dataset, int_dataset, blobs_dataset, \ - sparse_clf_dataset, \ - sparse_blobs_dataset, \ - sparse_int_dataset # noqa: F401 -from .test_preproc_utils import assert_allclose -from ..common.import_utils import check_cupy8 +from cuml.thirdparty_adapters.sparsefuncs_fast import \ + csr_mean_variance_axis0, \ + csc_mean_variance_axis0, \ + _csc_mean_variance_axis0, \ + inplace_csr_row_normalize_l1, \ + inplace_csr_row_normalize_l2 + +from cuml.test.test_preproc_utils import \ + clf_dataset, int_dataset, blobs_dataset, \ + sparse_clf_dataset, \ + sparse_blobs_dataset, \ + sparse_int_dataset # noqa: F401 +from cuml.test.test_preproc_utils import assert_allclose +from cuml.common.import_utils import check_cupy8 import numpy as np import cupy as cp diff --git a/python/cuml/test/test_prims.py b/python/cuml/test/test_prims.py index ddbdd7a0ca..3ac8aee0be 100644 --- a/python/cuml/test/test_prims.py +++ b/python/cuml/test/test_prims.py @@ -42,24 +42,24 @@ def test_monotonic_validate_invert_labels(arr_type, dtype, copy): cp.cuda.Stream.null.synchronize() - assert array_equal(monotonic.get(), np.array([0, 2, 1, 4, 3, 4])) + assert array_equal(monotonic, np.array([0, 2, 1, 4, 3, 4])) # We only care about in-place updating if data is on device if arr_type == "cp": if copy: - assert array_equal(arr_orig.get(), arr.get()) + assert array_equal(arr_orig, arr) else: - assert array_equal(arr.get(), monotonic.get()) + assert array_equal(arr, monotonic) wrong_classes = cp.asarray([0, 1, 2], dtype=dtype) - val_labels = check_labels(monotonic.get(), classes=wrong_classes) + val_labels = check_labels(monotonic, classes=wrong_classes) cp.cuda.Stream.null.synchronize() assert not val_labels correct_classes = cp.asarray([0, 1, 2, 3, 4], dtype=dtype) - val_labels = check_labels(monotonic.get(), classes=correct_classes) + val_labels = check_labels(monotonic, classes=correct_classes) cp.cuda.Stream.null.synchronize() @@ -76,8 +76,8 @@ def test_monotonic_validate_invert_labels(arr_type, dtype, copy): if arr_type == "cp": if copy: - assert array_equal(monotonic_copy.get(), monotonic.get()) + assert array_equal(monotonic_copy, monotonic) else: - assert array_equal(monotonic.get(), arr_orig.get()) + assert array_equal(monotonic, arr_orig) - assert array_equal(inverted.get(), original) + assert array_equal(inverted, original) diff --git a/python/cuml/test/test_random_forest.py b/python/cuml/test/test_random_forest.py index 686a557149..1272750e91 100644 --- a/python/cuml/test/test_random_forest.py +++ b/python/cuml/test/test_random_forest.py @@ -23,6 +23,7 @@ from numba import cuda +import cuml from cuml.ensemble import RandomForestClassifier as curfc from cuml.ensemble import RandomForestRegressor as curfr from cuml.metrics import r2_score @@ -525,11 +526,11 @@ def test_rf_classification_sparse(small_clf, datatype, fil_acc = accuracy_score(y_test, fil_preds) fil_model = cuml_model.convert_to_fil_model() - input_type = 'numpy' - fil_model_preds = fil_model.predict(X_test, - output_type=input_type) - fil_model_acc = accuracy_score(y_test, fil_model_preds) - assert fil_acc == fil_model_acc + + with cuml.using_output_type("numpy"): + fil_model_preds = fil_model.predict(X_test) + fil_model_acc = accuracy_score(y_test, fil_model_preds) + assert fil_acc == fil_model_acc tl_model = cuml_model.convert_to_treelite_model() assert num_treees == tl_model.num_trees @@ -588,13 +589,12 @@ def test_rf_regression_sparse(special_reg, datatype, fil_sparse_format, algo): fil_model = cuml_model.convert_to_fil_model() - input_type = 'numpy' - fil_model_preds = fil_model.predict(X_test, - output_type=input_type) - fil_model_preds = np.reshape(fil_model_preds, np.shape(y_test)) - fil_model_r2 = r2_score(y_test, fil_model_preds, - convert_dtype=datatype) - assert fil_r2 == fil_model_r2 + with cuml.using_output_type("numpy"): + fil_model_preds = fil_model.predict(X_test) + fil_model_preds = np.reshape(fil_model_preds, np.shape(y_test)) + fil_model_r2 = r2_score(y_test, fil_model_preds, + convert_dtype=datatype) + assert fil_r2 == fil_model_r2 tl_model = cuml_model.convert_to_treelite_model() assert num_treees == tl_model.num_trees diff --git a/python/cuml/thirdparty_adapters/adapters.py b/python/cuml/thirdparty_adapters/adapters.py index f9666eacd8..ba6aaf4527 100644 --- a/python/cuml/thirdparty_adapters/adapters.py +++ b/python/cuml/thirdparty_adapters/adapters.py @@ -16,6 +16,7 @@ import numpy as np import cupy as cp from cuml.common import input_to_cuml_array +from cuml.common.input_utils import input_to_cupy_array from cupy.sparse import csr_matrix as gpu_csr_matrix from cupy.sparse import csc_matrix as gpu_csc_matrix from cupy.sparse import csc_matrix as gpu_coo_matrix @@ -274,10 +275,9 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True, new_array = new_array.astype(correct_dtype) return new_array else: - X, n_rows, n_cols, dtype = input_to_cuml_array(array, + X, n_rows, n_cols, dtype = input_to_cupy_array(array, order=order, deepcopy=copy) - X = X.to_output('cupy') if correct_dtype != dtype: X = X.astype(correct_dtype) check_finite(X, force_all_finite) diff --git a/python/cuml/tsa/arima.pyx b/python/cuml/tsa/arima.pyx index 7404aa53ef..ddc5f7c7a2 100644 --- a/python/cuml/tsa/arima.pyx +++ b/python/cuml/tsa/arima.pyx @@ -25,7 +25,9 @@ from libcpp cimport bool from libcpp.vector cimport vector from typing import List, Tuple, Dict, Mapping, Optional, Union -from cuml.common.array import CumlArray as cumlArray +import cuml.internals +from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.base import Base from cuml.common.cuda import nvtx_range_wrap from cuml.raft.common.handle cimport handle_t @@ -109,16 +111,16 @@ cdef class ARIMAParamsWrapper: cdef ARIMAOrder order = model.order cdef uintptr_t d_mu_ptr = \ - model._mu_.ptr if order.k else NULL + model.mu_.ptr if order.k else NULL cdef uintptr_t d_ar_ptr = \ - model._ar_.ptr if order.p else NULL + model.ar_.ptr if order.p else NULL cdef uintptr_t d_ma_ptr = \ - model._ma_.ptr if order.q else NULL + model.ma_.ptr if order.q else NULL cdef uintptr_t d_sar_ptr = \ - model._sar_.ptr if order.P else NULL + model.sar_.ptr if order.P else NULL cdef uintptr_t d_sma_ptr = \ - model._sma_.ptr if order.Q else NULL - cdef uintptr_t d_sigma2_ptr = model._sigma2_.ptr + model.sma_.ptr if order.Q else NULL + cdef uintptr_t d_sigma2_ptr = model.sigma2_.ptr self.params.mu = d_mu_ptr self.params.ar = d_ar_ptr @@ -257,6 +259,17 @@ class ARIMA(Base): """ + d_y = CumlArrayDescriptor() + # TODO: (MDD) Should this be public? Its not listed in the attributes doc + _d_y_diff = CumlArrayDescriptor() + + mu_ = CumlArrayDescriptor() + ar_ = CumlArrayDescriptor() + ma_ = CumlArrayDescriptor() + sar_ = CumlArrayDescriptor() + sma_ = CumlArrayDescriptor() + sigma2_ = CumlArrayDescriptor() + def __init__(self, endog, order: Tuple[int, int, int] = (1, 1, 1), @@ -305,7 +318,7 @@ class ARIMA(Base): "Required: max(p+s*P, q+s*Q) <= 1024") # Get device array. Float64 only for now. - self._d_y, self.n_obs, self.batch_size, self.dtype \ + self.d_y, self.n_obs, self.batch_size, self.dtype \ = input_to_cuml_array(endog, check_dtype=np.float64) if self.n_obs < d + s * D + 1: @@ -314,21 +327,32 @@ class ARIMA(Base): self.simple_differencing = simple_differencing - # Compute the differenced series - self._d_y_diff = cumlArray.empty( + self._d_y_diff = CumlArray.empty( (self.n_obs - d - s * D, self.batch_size), self.dtype) - cdef uintptr_t d_y_ptr = self._d_y.ptr + + self.n_obs_diff = self.n_obs - d - D * s + + self._initial_calc() + + @cuml.internals.api_base_return_any_skipall + def _initial_calc(self): + """ + This separates the initial calculation from the initialization to make + the CumlArrayDescriptors work + """ + + # Compute the differenced series + cdef uintptr_t d_y_ptr = self.d_y.ptr cdef uintptr_t d_y_diff_ptr = self._d_y_diff.ptr cdef handle_t* handle_ = self.handle.getHandle() batched_diff(handle_[0], d_y_diff_ptr, d_y_ptr, self.batch_size, self.n_obs, self.order) # Create a version of the order for the differenced series - cdef ARIMAOrder cpp_order_diff = cpp_order + cdef ARIMAOrder cpp_order_diff = self.order cpp_order_diff.d = 0 cpp_order_diff.D = 0 self.order_diff = cpp_order_diff - self.n_obs_diff = self.n_obs - d - D * s def __str__(self): cdef ARIMAOrder order = self.order @@ -342,6 +366,7 @@ class ARIMA(Base): order.p, order.d, order.q, intercept_str, self.batch_size) @nvtx_range_wrap + @cuml.internals.api_base_return_any_skipall def _ic(self, ic_type: str): """Wrapper around C++ information_criterion """ @@ -351,10 +376,10 @@ class ARIMA(Base): self.order_diff if self.simple_differencing else self.order cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params - ic = cumlArray.empty(self.batch_size, self.dtype) + ic = CumlArray.empty(self.batch_size, self.dtype) cdef uintptr_t d_ic_ptr = ic.ptr cdef uintptr_t d_y_kf_ptr = \ - self._d_y_diff.ptr if self.simple_differencing else self._d_y.ptr + self._d_y_diff.ptr if self.simple_differencing else self.d_y.ptr n_obs_kf = (self.n_obs_diff if self.simple_differencing else self.n_obs) @@ -371,20 +396,20 @@ class ARIMA(Base): order_kf, cpp_params, d_ic_ptr, ic_type_id) - return ic.to_output(self.output_type) + return ic @property - def aic(self): + def aic(self) -> CumlArray: """Akaike Information Criterion""" return self._ic("aic") @property - def aicc(self): + def aicc(self) -> CumlArray: """Corrected Akaike Information Criterion""" return self._ic("aicc") @property - def bic(self): + def bic(self) -> CumlArray: """Bayesian Information Criterion""" return self._ic("bic") @@ -394,7 +419,8 @@ class ARIMA(Base): cdef ARIMAOrder order = self.order return order.p + order.P + order.q + order.Q + order.k + 1 - def get_fit_params(self) -> Dict[str, object]: + @cuml.internals.api_base_return_autoarray(input_arg=None) + def get_fit_params(self) -> Dict[str, CumlArray]: """Get all the fit parameters. Not to be confused with get_params Note: pack() can be used to get a compact vector of the parameters @@ -434,7 +460,7 @@ class ARIMA(Base): if param_name in params: array, *_ = input_to_cuml_array(params[param_name], check_dtype=np.float64) - setattr(self, "_{}_".format(param_name), array) + setattr(self, "{}_".format(param_name), array) def get_param_names(self): raise NotImplementedError @@ -467,7 +493,13 @@ class ARIMA(Base): "`get_params` and `set_params`.") @nvtx_range_wrap - def predict(self, start=0, end=None, level=None): + @cuml.internals.api_base_return_autoarray(input_arg=None) + def predict( + self, + start=0, + end=None, + level=None + ) -> Union[CumlArray, Tuple[CumlArray, CumlArray, CumlArray]]: """Compute in-sample and/or out-of-sample prediction for each series Parameters @@ -529,25 +561,24 @@ class ARIMA(Base): end = self.n_obs cdef handle_t* handle_ = self.handle.getHandle() - predict_size = end - start # allocate predictions and intervals device memory cdef uintptr_t d_y_p_ptr = NULL cdef uintptr_t d_lower_ptr = NULL cdef uintptr_t d_upper_ptr = NULL - d_y_p = cumlArray.empty((predict_size, self.batch_size), + d_y_p = CumlArray.empty((predict_size, self.batch_size), dtype=np.float64, order="F") d_y_p_ptr = d_y_p.ptr if level is not None: - d_lower = cumlArray.empty((predict_size, self.batch_size), + d_lower = CumlArray.empty((predict_size, self.batch_size), dtype=np.float64, order="F") - d_upper = cumlArray.empty((predict_size, self.batch_size), + d_upper = CumlArray.empty((predict_size, self.batch_size), dtype=np.float64, order="F") d_lower_ptr = d_lower.ptr d_upper_ptr = d_upper.ptr - cdef uintptr_t d_y_ptr = self._d_y.ptr + cdef uintptr_t d_y_ptr = self.d_y.ptr cpp_predict(handle_[0], d_y_ptr, self.batch_size, self.n_obs, start, end, order, @@ -557,14 +588,19 @@ class ARIMA(Base): d_lower_ptr, d_upper_ptr) if level is None: - return d_y_p.to_output(self.output_type) + return d_y_p else: - return (d_y_p.to_output(self.output_type), - d_lower.to_output(self.output_type), - d_upper.to_output(self.output_type)) + return (d_y_p, + d_lower, + d_upper) @nvtx_range_wrap - def forecast(self, nsteps: int, level=None): + @cuml.internals.api_base_return_generic_skipall + def forecast( + self, + nsteps: int, + level=None + ) -> Union[CumlArray, Tuple[CumlArray, CumlArray, CumlArray]]: """Forecast the given model `nsteps` into the future. Parameters @@ -599,28 +635,30 @@ class ARIMA(Base): return self.predict(self.n_obs, self.n_obs + nsteps, level) + @cuml.internals.api_base_return_any_skipall def _create_arrays(self): """Create the parameter arrays if non-existing""" cdef ARIMAOrder order = self.order - if order.k and not hasattr(self, "_mu_"): - self._mu_ = cumlArray.empty(self.batch_size, np.float64) - if order.p and not hasattr(self, "_ar_"): - self._ar_ = cumlArray.empty((order.p, self.batch_size), + if order.k and not hasattr(self, "mu_"): + self.mu_ = CumlArray.empty(self.batch_size, np.float64) + if order.p and not hasattr(self, "ar_"): + self.ar_ = CumlArray.empty((order.p, self.batch_size), + np.float64) + if order.q and not hasattr(self, "ma_"): + self.ma_ = CumlArray.empty((order.q, self.batch_size), + np.float64) + if order.P and not hasattr(self, "sar_"): + self.sar_ = CumlArray.empty((order.P, self.batch_size), np.float64) - if order.q and not hasattr(self, "_ma_"): - self._ma_ = cumlArray.empty((order.q, self.batch_size), + if order.Q and not hasattr(self, "sma_"): + self.sma_ = CumlArray.empty((order.Q, self.batch_size), np.float64) - if order.P and not hasattr(self, "_sar_"): - self._sar_ = cumlArray.empty((order.P, self.batch_size), - np.float64) - if order.Q and not hasattr(self, "_sma_"): - self._sma_ = cumlArray.empty((order.Q, self.batch_size), - np.float64) - if not hasattr(self, "_sigma2_"): - self._sigma2_ = cumlArray.empty(self.batch_size, np.float64) + if not hasattr(self, "sigma2_"): + self.sigma2_ = CumlArray.empty(self.batch_size, np.float64) @nvtx_range_wrap + @cuml.internals.api_base_return_any_skipall def _estimate_x0(self): """Internal method. Estimate initial parameters of the model. """ @@ -629,7 +667,7 @@ class ARIMA(Base): cdef ARIMAOrder order = self.order cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params - cdef uintptr_t d_y_ptr = self._d_y.ptr + cdef uintptr_t d_y_ptr = self.d_y.ptr cdef handle_t* handle_ = self.handle.getHandle() # Call C++ function @@ -637,13 +675,14 @@ class ARIMA(Base): self.batch_size, self.n_obs, order) @nvtx_range_wrap + @cuml.internals.api_base_return_any_skipall def fit(self, start_params: Optional[Mapping[str, object]] = None, opt_disp: int = -1, h: float = 1e-8, maxiter: int = 1000, method="ml", - truncate: int = 0): + truncate: int = 0) -> "ARIMA": r"""Fit the ARIMA model to each time series. Parameters @@ -680,7 +719,7 @@ class ARIMA(Base): observations """ def fit_helper(x_in, fit_method): - cdef uintptr_t d_y_ptr = self._d_y.ptr + cdef uintptr_t d_y_ptr = self.d_y.ptr def f(x: np.ndarray) -> np.ndarray: """The (batched) energy functional returning the negative @@ -732,6 +771,7 @@ class ARIMA(Base): return self @nvtx_range_wrap + @cuml.internals.api_base_return_any_skipall def _loglike(self, x, trans=True, method="ml", truncate=0): """Compute the batched log-likelihood for the given parameters. @@ -767,12 +807,12 @@ class ARIMA(Base): cdef uintptr_t d_x_ptr = d_x_array.ptr cdef uintptr_t d_y_kf_ptr = \ - self._d_y_diff.ptr if diff else self._d_y.ptr + self._d_y_diff.ptr if diff else self.d_y.ptr cdef handle_t* handle_ = self.handle.getHandle() n_obs_kf = (self.n_obs_diff if diff else self.n_obs) - d_vs = cumlArray.empty((n_obs_kf, self.batch_size), dtype=np.float64, + d_vs = CumlArray.empty((n_obs_kf, self.batch_size), dtype=np.float64, order="F") cdef uintptr_t d_vs_ptr = d_vs.ptr @@ -785,6 +825,7 @@ class ARIMA(Base): return np.array(vec_loglike, dtype=np.float64) @nvtx_range_wrap + @cuml.internals.api_base_return_any_skipall def _loglike_grad(self, x, h=1e-8, trans=True, method="ml", truncate=0): """Compute the gradient (via finite differencing) of the batched log-likelihood. @@ -818,7 +859,7 @@ class ARIMA(Base): cdef LoglikeMethod ll_method = CSS if method == "css" else MLE diff = ll_method != MLE or self.simple_differencing - grad = cumlArray.empty(N * self.batch_size, np.float64) + grad = CumlArray.empty(N * self.batch_size, np.float64) cdef uintptr_t d_grad = grad.ptr cdef ARIMAOrder order_kf = self.order_diff if diff else self.order @@ -828,7 +869,7 @@ class ARIMA(Base): cdef uintptr_t d_x_ptr = d_x_array.ptr cdef uintptr_t d_y_kf_ptr = \ - self._d_y_diff.ptr if diff else self._d_y.ptr + self._d_y_diff.ptr if diff else self.d_y.ptr cdef handle_t* handle_ = self.handle.getHandle() @@ -859,7 +900,7 @@ class ARIMA(Base): cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params cdef uintptr_t d_y_kf_ptr = \ - self._d_y_diff.ptr if self.simple_differencing else self._d_y.ptr + self._d_y_diff.ptr if self.simple_differencing else self.d_y.ptr n_obs_kf = (self.n_obs_diff if self.simple_differencing else self.n_obs) @@ -867,7 +908,7 @@ class ARIMA(Base): cdef LoglikeMethod ll_method = MLE diff = self.simple_differencing - d_vs = cumlArray.empty((n_obs_kf, self.batch_size), dtype=np.float64, + d_vs = CumlArray.empty((n_obs_kf, self.batch_size), dtype=np.float64, order="F") cdef uintptr_t d_vs_ptr = d_vs.ptr @@ -919,7 +960,7 @@ class ARIMA(Base): cdef ARIMAOrder order = self.order cdef ARIMAParams[double] cpp_params = ARIMAParamsWrapper(self).params - d_x_array = cumlArray.empty(self.complexity * self.batch_size, + d_x_array = CumlArray.empty(self.complexity * self.batch_size, np.float64) cdef uintptr_t d_x_ptr = d_x_array.ptr @@ -929,6 +970,7 @@ class ARIMA(Base): return d_x_array.to_output("numpy") @nvtx_range_wrap + @cuml.internals.api_base_return_any_skipall def _batched_transform(self, x, isInv=False): """Applies Jones transform or inverse transform to a parameter vector diff --git a/python/cuml/tsa/auto_arima.pyx b/python/cuml/tsa/auto_arima.pyx index 9f5612a0c5..883a832747 100644 --- a/python/cuml/tsa/auto_arima.pyx +++ b/python/cuml/tsa/auto_arima.pyx @@ -16,6 +16,8 @@ # distutils: language = c++ +import typing + import ctypes import itertools from libc.stdint cimport uintptr_t @@ -25,13 +27,15 @@ import numpy as np import cupy as cp -import cuml +import cuml.internals from cuml.common import logger -from cuml.common.array import CumlArray as cumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.common.array import CumlArray from cuml.common.base import Base from cuml.raft.common.handle cimport handle_t from cuml.raft.common.handle import Handle -from cuml.common.input_utils import input_to_cuml_array +from cuml.common import input_to_cuml_array +from cuml.common import using_output_type from cuml.tsa.arima import ARIMA from cuml.tsa.seasonality import seas_test from cuml.tsa.stationarity import kpss_test @@ -168,6 +172,8 @@ class AutoARIMA(Base): """ + d_y = CumlArrayDescriptor() + def __init__(self, endog, handle=None, @@ -182,11 +188,12 @@ class AutoARIMA(Base): self._set_base_attributes(output_type=endog) # Get device array. Float64 only for now. - self._d_y, self.n_obs, self.batch_size, self.dtype \ + self.d_y, self.n_obs, self.batch_size, self.dtype \ = input_to_cuml_array(endog, check_dtype=np.float64) self.simple_differencing = simple_differencing + @cuml.internals.api_return_any() def search(self, s=None, d=range(3), @@ -277,20 +284,23 @@ class AutoARIMA(Base): D_options = _parse_sequence("D", D, 0, 1) if not s: # Non-seasonal -> D=0 - data_D = {0: (self._d_y, d_index)} + data_D = {0: (self.d_y, d_index)} elif len(D_options) == 1: # D is specified by the user - data_D = {D_options[0]: (self._d_y, d_index)} + data_D = {D_options[0]: (self.d_y, d_index)} else: # D is chosen with a seasonal differencing test if seasonal_test not in tests_map: raise ValueError("Unknown seasonal diff test: {}" .format(seasonal_test)) - mask_cp = tests_map[seasonal_test](self._d_y.to_output("cupy"), s) + + with using_output_type("cupy"): + mask_cp = tests_map[seasonal_test](self.d_y, s) + mask = input_to_cuml_array(mask_cp)[0] del mask_cp data_D = {} - out0, index0, out1, index1 = _divide_by_mask(self._d_y, mask, + out0, index0, out1, index1 = _divide_by_mask(self.d_y, mask, d_index) if out0 is not None: data_D[0] = (out0, index0) @@ -397,6 +407,7 @@ class AutoARIMA(Base): self.id_to_model, self.id_to_pos = _build_division_map(id_tracker, self.batch_size) + @cuml.internals.api_base_return_any_skipall def fit(self, h: float = 1e-8, maxiter: int = 1000, @@ -423,7 +434,14 @@ class AutoARIMA(Base): logger.debug("Fitting {} ({})".format(model, method)) model.fit(h=h, maxiter=maxiter, method=method, truncate=truncate) - def predict(self, start=0, end=None, level=None): + @cuml.internals.api_base_return_generic_skipall + def predict( + self, + start=0, + end=None, + level=None + ) -> typing.Union[CumlArray, typing.Tuple[CumlArray, CumlArray, + CumlArray]]: """Compute in-sample and/or out-of-sample prediction for each series Parameters @@ -463,12 +481,12 @@ class AutoARIMA(Base): # Put all the predictions together y_p = _merge_series(pred_list, self.id_to_model, self.id_to_pos, - self.batch_size).to_output(self.output_type) + self.batch_size) if level is not None: lower = _merge_series(lower_list, self.id_to_model, self.id_to_pos, - self.batch_size).to_output(self.output_type) + self.batch_size) upper = _merge_series(upper_list, self.id_to_model, self.id_to_pos, - self.batch_size).to_output(self.output_type) + self.batch_size) # Return the results if level is None: @@ -476,7 +494,13 @@ class AutoARIMA(Base): else: return y_p, lower, upper - def forecast(self, nsteps: int, level=None): + @cuml.internals.api_base_return_generic_skipall + def forecast(self, + nsteps: int, + level=None) -> typing.Union[CumlArray, + typing.Tuple[CumlArray, + CumlArray, + CumlArray]]: """Forecast `nsteps` into the future. Parameters @@ -533,11 +557,11 @@ def _divide_by_mask(original, mask, batch_id, handle=None): Parameters ---------- - original : cumlArray (float32 or float64) + original : CumlArray (float32 or float64) Original batch - mask : cumlArray (bool) + mask : CumlArray (bool) Boolean mask: False for the 1st sub-batch and True for the second - batch_id : cumlArray (int) + batch_id : CumlArray (int) Integer array to track the id of each member in the initial batch handle : cuml.Handle Specifies the cuml.handle that holds internal CUDA state for @@ -549,14 +573,14 @@ def _divide_by_mask(original, mask, batch_id, handle=None): Returns ------- - out0 : cumlArray (float32 or float64) + out0 : CumlArray (float32 or float64) Sub-batch 0, or None if empty - batch0_id : cumlArray (int) + batch0_id : CumlArray (int) Indices of the members of the sub-batch 0 in the initial batch, or None if empty - out1 : cumlArray (float32 or float64) + out1 : CumlArray (float32 or float64) Sub-batch 1, or None if empty - batch1_id : cumlArray (int) + batch1_id : CumlArray (int) Indices of the members of the sub-batch 1 in the initial batch, or None if empty """ @@ -570,7 +594,7 @@ def _divide_by_mask(original, mask, batch_id, handle=None): handle = Handle() cdef handle_t* handle_ = handle.getHandle() - index = cumlArray.empty(batch_size, np.int32) + index = CumlArray.empty(batch_size, np.int32) cdef uintptr_t d_index = index.ptr cdef uintptr_t d_mask = mask.ptr @@ -580,8 +604,8 @@ def _divide_by_mask(original, mask, batch_id, handle=None): d_index, batch_size) - out0 = cumlArray.empty((n_obs, batch_size - nb_true), dtype) - out1 = cumlArray.empty((n_obs, nb_true), dtype) + out0 = CumlArray.empty((n_obs, batch_size - nb_true), dtype) + out1 = CumlArray.empty((n_obs, nb_true), dtype) # Type declarations (can't be in if-else statements) cdef uintptr_t d_out0 @@ -607,8 +631,8 @@ def _divide_by_mask(original, mask, batch_id, handle=None): # If both sub-batches have elements else: - out0 = cumlArray.empty((n_obs, batch_size - nb_true), dtype) - out1 = cumlArray.empty((n_obs, nb_true), dtype) + out0 = CumlArray.empty((n_obs, batch_size - nb_true), dtype) + out1 = CumlArray.empty((n_obs, nb_true), dtype) d_out0 = out0.ptr d_out1 = out1.ptr @@ -633,8 +657,8 @@ def _divide_by_mask(original, mask, batch_id, handle=None): n_obs) # Also keep track of the original id of the series in the batch - batch0_id = cumlArray.empty(batch_size - nb_true, np.int32) - batch1_id = cumlArray.empty(nb_true, np.int32) + batch0_id = CumlArray.empty(batch_size - nb_true, np.int32) + batch1_id = CumlArray.empty(nb_true, np.int32) d_batch0_id = batch0_id.ptr d_batch1_id = batch1_id.ptr d_batch_id = batch_id.ptr @@ -655,13 +679,13 @@ def _divide_by_min(original, metrics, batch_id, handle=None): """Divide a given batch into multiple sub-batches according to the values of the given metrics, by selecting the minimum value for each member - Parameters + Parameters: ---------- - original : cumlArray (float32 or float64) + original : CumlArray (float32 or float64) Original batch - metrics : cumlArray (float32 or float64) + metrics : CumlArray (float32 or float64) Matrix of shape (batch_size, n_sub) containing the metrics to minimize - batch_id : cumlArray (int) + batch_id : CumlArray (int) Integer array to track the id of each member in the initial batch handle : cuml.Handle Specifies the cuml.handle that holds internal CUDA state for @@ -673,9 +697,9 @@ def _divide_by_min(original, metrics, batch_id, handle=None): Returns ------- - sub_batches : List[cumlArray] (float32 or float64) + sub_batches : List[CumlArray] (float32 or float64) List of arrays containing each sub-batch, or None if empty - sub_id : List[cumlArray] (int) + sub_id : List[CumlArray] (int) List of arrays containing the indices of each member in the initial batch, or None if empty """ @@ -690,8 +714,8 @@ def _divide_by_min(original, metrics, batch_id, handle=None): handle = Handle() cdef handle_t* handle_ = handle.getHandle() - batch_buffer = cumlArray.empty(batch_size, np.int32) - index_buffer = cumlArray.empty(batch_size, np.int32) + batch_buffer = CumlArray.empty(batch_size, np.int32) + index_buffer = CumlArray.empty(batch_size, np.int32) cdef vector[int] size_buffer size_buffer.resize(n_sub) @@ -720,7 +744,7 @@ def _divide_by_min(original, metrics, batch_id, handle=None): # Build a list of cuML arrays for the sub-batches and a vector of pointers # to be passed to the next C++ step - sub_batches = [cumlArray.empty((n_obs, s), dtype) if s else None + sub_batches = [CumlArray.empty((n_obs, s), dtype) if s else None for s in size_buffer] cdef vector[uintptr_t] sub_ptr sub_ptr.resize(n_sub) @@ -753,7 +777,7 @@ def _divide_by_min(original, metrics, batch_id, handle=None): # Keep track of the id of the series if requested cdef vector[uintptr_t] id_ptr - sub_id = [cumlArray.empty(s, np.int32) if s else None + sub_id = [CumlArray.empty(s, np.int32) if s else None for s in size_buffer] id_ptr.resize(n_sub) for i in range(n_sub): @@ -781,16 +805,16 @@ def _build_division_map(id_tracker, batch_size, handle=None): Parameters ---------- - id_tracker : List[cumlArray] (int) + id_tracker : List[CumlArray] (int) List of the index arrays of each sub-batch batch_size : int Size of the initial batch Returns ------- - id_to_model : cumlArray (int) + id_to_model : CumlArray (int) Associates each batch member with a model - id_to_pos : cumlArray (int) + id_to_pos : CumlArray (int) Position of each member in the respective sub-batch """ if handle is None: @@ -799,8 +823,8 @@ def _build_division_map(id_tracker, batch_size, handle=None): n_sub = len(id_tracker) - id_to_pos = cumlArray.empty(batch_size, np.int32) - id_to_model = cumlArray.empty(batch_size, np.int32) + id_to_pos = CumlArray.empty(batch_size, np.int32) + id_to_model = CumlArray.empty(batch_size, np.int32) cdef vector[uintptr_t] id_ptr cdef vector[int] size_vec @@ -833,18 +857,18 @@ def _merge_series(data_in, id_to_sub, id_to_pos, batch_size, handle=None): Parameters ---------- - data_in : List[cumlArray] (float32 or float64) + data_in : List[CumlArray] (float32 or float64) List of sub-batches to merge - id_to_model : cumlArray (int) + id_to_model : CumlArray (int) Associates each member of the batch with a sub-batch - id_to_pos : cumlArray (int) + id_to_pos : CumlArray (int) Position of each member of the batch in its respective sub-batch batch_size : int Size of the initial batch Returns ------- - data_out : cumlArray (float32 or float64) + data_out : CumlArray (float32 or float64) Merged batch """ dtype = data_in[0].dtype @@ -860,7 +884,7 @@ def _merge_series(data_in, id_to_sub, id_to_pos, batch_size, handle=None): for i in range(n_sub): in_ptr[i] = data_in[i].ptr - data_out = cumlArray.empty((n_obs, batch_size), dtype) + data_out = CumlArray.empty((n_obs, batch_size), dtype) cdef uintptr_t hd_in = in_ptr.data() cdef uintptr_t d_id_to_pos = id_to_pos.ptr diff --git a/python/cuml/tsa/holtwinters.pyx b/python/cuml/tsa/holtwinters.pyx index 4972b92a30..9f0d15e969 100644 --- a/python/cuml/tsa/holtwinters.pyx +++ b/python/cuml/tsa/holtwinters.pyx @@ -18,12 +18,14 @@ import cudf import cupy as cp import numpy as np -from numba import cuda from libc.stdint cimport uintptr_t -from cuml.common import input_to_dev_array -from cuml.common import get_dev_array_ptr -from cuml.common import numba_utils + +import cuml.internals +from cuml.common.input_utils import input_to_cupy_array +from cuml.common import using_output_type from cuml.common.base import Base +from cuml.common.array import CumlArray +from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.raft.common.handle cimport handle_t cdef extern from "cuml/tsa/holtwinters_params.h" namespace "ML": @@ -173,6 +175,13 @@ class ExponentialSmoothing(Base): See :ref:`output-data-type-configuration` for more info. """ + + forecasted_points = CumlArrayDescriptor() + level = CumlArrayDescriptor() + trend = CumlArrayDescriptor() + season = CumlArrayDescriptor() + SSE = CumlArrayDescriptor() + def __init__(self, endog, seasonal="additive", seasonal_periods=2, start_periods=2, ts_num=1, eps=2.24e-3, handle=None, @@ -240,30 +249,28 @@ class ExponentialSmoothing(Base): self.fit_executed_flag = False self.h = 0 - def _check_dims(self, ts_input, is_cudf=False): + def _check_dims(self, ts_input, is_cudf=False) -> CumlArray: err_mess = ("ExponentialSmoothing initialized with " + str(self.ts_num) + " time series, but data has dimension ") - if len(ts_input.shape) == 1: - self.n = ts_input.shape[0] + + is_cudf = isinstance(ts_input, cudf.DataFrame) + + mod_ts_input = input_to_cupy_array(ts_input, order="C").array + + if len(mod_ts_input.shape) == 1: + self.n = mod_ts_input.shape[0] if self.ts_num != 1: raise ValueError(err_mess + "1.") - if(is_cudf): - mod_ts_input = ts_input.as_gpu_matrix() - elif(isinstance(ts_input, cudf.Series)): - mod_ts_input = ts_input.to_gpu_array() - else: - mod_ts_input = ts_input elif len(ts_input.shape) == 2: if(is_cudf): - d1 = ts_input.shape[0] - d2 = ts_input.shape[1] - mod_ts_input = ts_input.as_gpu_matrix()\ - .reshape((d1*d2,)) + d1 = mod_ts_input.shape[0] + d2 = mod_ts_input.shape[1] + mod_ts_input = mod_ts_input.reshape((d1*d2,)) else: - d1 = ts_input.shape[1] - d2 = ts_input.shape[0] - mod_ts_input = ts_input.ravel() + d1 = mod_ts_input.shape[1] + d2 = mod_ts_input.shape[0] + mod_ts_input = mod_ts_input.ravel() self.n = d1 if self.ts_num != d2: raise ValueError(err_mess + str(d2)) @@ -271,23 +278,17 @@ class ExponentialSmoothing(Base): raise ValueError("Data input must have 1 or 2 dimensions.") return mod_ts_input - def fit(self): + @cuml.internals.api_base_return_any_skipall + def fit(self) -> "ExponentialSmoothing": """ Perform fitting on the given `endog` dataset. Calculates the level, trend, season, and SSE components. """ - if isinstance(self.endog, cudf.Series): - arr = self._check_dims(self.endog) - elif isinstance(self.endog, cudf.DataFrame): - arr = self._check_dims(self.endog, True) - elif cuda.is_cuda_array(self.endog): - try: - import cupy as cp - arr = self._check_dims(self.endog) - except Exception: - arr = cuda.as_cuda_array(self.endog).copy_to_host() - if isinstance(self.endog, np.ndarray): - arr = self._check_dims(self.endog) + + X_m = self._check_dims(self.endog) + + self.dtype = X_m.dtype + if self.n < self.start_periods*self.seasonal_periods: raise ValueError("Length of time series (" + str(self.n) + ") must be at least freq*start_periods (" + @@ -302,8 +303,7 @@ class ExponentialSmoothing(Base): cdef int leveltrend_coef_offset, season_coef_offset cdef int error_len - X_m, input_ptr, _, _, self.dtype = \ - input_to_dev_array(arr, order='C') + input_ptr = X_m.ptr buffer_size( self.n, self.ts_num, self.seasonal_periods, @@ -317,14 +317,14 @@ class ExponentialSmoothing(Base): cdef handle_t* handle_ = self.handle.getHandle() cdef uintptr_t level_ptr, trend_ptr, season_ptr, SSE_ptr - self.level = numba_utils.zeros(components_len, dtype=self.dtype) - self.trend = numba_utils.zeros(components_len, dtype=self.dtype) - self.season = numba_utils.zeros(components_len, dtype=self.dtype) - self.SSE = numba_utils.zeros(self.ts_num, dtype=self.dtype) - level_ptr = get_dev_array_ptr(self.level) - trend_ptr = get_dev_array_ptr(self.trend) - season_ptr = get_dev_array_ptr(self.season) - SSE_ptr = get_dev_array_ptr(self.SSE) + self.level = CumlArray.zeros(components_len, dtype=self.dtype) + self.trend = CumlArray.zeros(components_len, dtype=self.dtype) + self.season = CumlArray.zeros(components_len, dtype=self.dtype) + self.SSE = CumlArray.zeros(self.ts_num, dtype=self.dtype) + level_ptr = self.level.ptr + trend_ptr = self.trend.ptr + season_ptr = self.season.ptr + SSE_ptr = self.SSE.ptr cdef float eps_f = np.float32(self.eps) cdef double eps_d = np.float64(self.eps) @@ -352,9 +352,12 @@ class ExponentialSmoothing(Base): " and float64 input, but input type " + str(self.dtype) + " passed.") num_rows = int(components_len/self.ts_num) - self.level = self.level.reshape((self.ts_num, num_rows), order='F') - self.trend = self.trend.reshape((self.ts_num, num_rows), order='F') - self.season = self.season.reshape((self.ts_num, num_rows), order='F') + + with using_output_type("cupy"): + self.level = self.level.reshape((self.ts_num, num_rows), order='F') + self.trend = self.trend.reshape((self.ts_num, num_rows), order='F') + self.season = self.season.reshape((self.ts_num, num_rows), + order='F') self.handle.sync() self.fit_executed_flag = True @@ -395,12 +398,13 @@ class ExponentialSmoothing(Base): if h > self.h: self.h = h - self.forecasted_points = numba_utils.zeros(self.ts_num*h, - dtype=self.dtype) - forecast_ptr = get_dev_array_ptr(self.forecasted_points) - level_ptr = get_dev_array_ptr(self.level) - trend_ptr = get_dev_array_ptr(self.trend) - season_ptr = get_dev_array_ptr(self.season) + self.forecasted_points = CumlArray.zeros(self.ts_num*h, + dtype=self.dtype) + with using_output_type("cuml"): + forecast_ptr = self.forecasted_points.ptr + level_ptr = self.level.ptr + trend_ptr = self.trend.ptr + season_ptr = self.season.ptr if self.dtype == np.float32: forecast(handle_[0], self.n, @@ -421,9 +425,11 @@ class ExponentialSmoothing(Base): trend_ptr, season_ptr, forecast_ptr) - self.forecasted_points =\ - self.forecasted_points.reshape((self.ts_num, h), - order='F') + + with using_output_type("cupy"): + self.forecasted_points =\ + self.forecasted_points.reshape((self.ts_num, h), + order='F') self.handle.sync() if index is None: diff --git a/python/cuml/tsa/seasonality.pyx b/python/cuml/tsa/seasonality.pyx index f6bc98e0d5..318375ffcc 100644 --- a/python/cuml/tsa/seasonality.pyx +++ b/python/cuml/tsa/seasonality.pyx @@ -20,9 +20,8 @@ import numpy as np from libc.stdint cimport uintptr_t from libcpp cimport bool -import cuml -from cuml.common.array import CumlArray as cumlArray -from cuml.common.base import _input_to_type +import cuml.internals +from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t from cuml.common.input_utils import input_to_host_array, input_to_cuml_array @@ -47,7 +46,8 @@ def python_seas_test(y, batch_size, n_obs, s, threshold=0.64): return results -def seas_test(y, s, output_type="input", handle=None): +@cuml.internals.api_return_array(input_arg="y", get_output_type=True) +def seas_test(y, s, handle=None) -> CumlArray: """ Perform Wang, Smith & Hyndman's test to decide whether seasonal differencing is needed @@ -78,9 +78,6 @@ def seas_test(y, s, output_type="input", handle=None): "ERROR: Invalid period for the seasonal differencing test: {}" .format(s)) - if output_type == "input": - output_type = _input_to_type(y) - # At the moment we use a host array h_y, _, n_obs, batch_size, dtype = \ input_to_host_array(y, check_dtype=[np.float32, np.float64]) @@ -88,4 +85,4 @@ def seas_test(y, s, output_type="input", handle=None): # Temporary: Python implementation python_res = python_seas_test(h_y, batch_size, n_obs, s) d_res, *_ = input_to_cuml_array(np.array(python_res), check_dtype=np.bool) - return d_res.to_output(output_type) + return d_res diff --git a/python/cuml/tsa/stationarity.pyx b/python/cuml/tsa/stationarity.pyx index e6f1146eab..42397fe70e 100644 --- a/python/cuml/tsa/stationarity.pyx +++ b/python/cuml/tsa/stationarity.pyx @@ -20,12 +20,10 @@ import numpy as np from libc.stdint cimport uintptr_t from libcpp cimport bool -import cuml -from cuml.common.array import CumlArray as cumlArray -from cuml.common.base import _input_to_type +import cuml.internals +from cuml.common.array import CumlArray from cuml.raft.common.handle cimport handle_t from cuml.raft.common.handle import Handle - from cuml.common.input_utils import input_to_cuml_array @@ -49,8 +47,9 @@ cdef extern from "cuml/tsa/stationarity.h" namespace "ML": double pval_threshold) -def kpss_test(y, d=0, D=0, s=0, pval_threshold=0.05, output_type="input", - handle=None): +@cuml.internals.api_return_array(input_arg="y", get_output_type=True) +def kpss_test(y, d=0, D=0, s=0, pval_threshold=0.05, + handle=None) -> CumlArray: """ Perform the KPSS stationarity test on the data differenced according to the given order @@ -86,14 +85,11 @@ def kpss_test(y, d=0, D=0, s=0, pval_threshold=0.05, output_type="input", input_to_cuml_array(y, check_dtype=[np.float32, np.float64]) cdef uintptr_t d_y_ptr = d_y.ptr - if output_type == "input": - output_type = _input_to_type(y) - if handle is None: handle = Handle() cdef handle_t* handle_ = handle.getHandle() - results = cumlArray.empty(batch_size, dtype=np.bool) + results = CumlArray.empty(batch_size, dtype=np.bool) cdef uintptr_t d_results = results.ptr # Call C++ function @@ -114,4 +110,4 @@ def kpss_test(y, d=0, D=0, s=0, pval_threshold=0.05, output_type="input", d, D, s, pval_threshold) - return results.to_output(output_type) + return results diff --git a/wiki/python/DEVELOPER_GUIDE.md b/wiki/python/DEVELOPER_GUIDE.md index 10c0e0882b..a6d81140ab 100644 --- a/wiki/python/DEVELOPER_GUIDE.md +++ b/wiki/python/DEVELOPER_GUIDE.md @@ -18,13 +18,15 @@ Refer to the section on thread safety in [C++ DEVELOPER_GUIDE.md](../cpp/DEVELOP 1. Make sure that this algo has been implemented in the C++ side. Refer to [C++ DEVELOPER_GUIDE.md](../cpp/DEVELOPER_GUIDE.md) for guidelines on developing in C++. 2. Refer to the [next section](DEVELOPER_GUIDE.md#creating-python-wrapper-class-for-an-existing-ml-algo) for the remaining steps. -## Creating python wrapper class for an existing ML algo +## Creating python estimator wrapper class 1. Create a corresponding algoName.pyx file inside `python/cuml` folder. 2. Ensure that the folder structure inside here reflects that of sklearn's. Example, `pca.pyx` should be kept inside the `decomposition` sub-folder of `python/cuml`. . Match the corresponding scikit-learn's interface as closely as possible. Refer to their [developer guide](https://scikit-learn.org/stable/developers/contributing.html#apis-of-scikit-learn-objects) on API design of sklearn objects for details. 3. Always make sure to have your class inherit from `cuml.Base` class as your parent/ancestor. 4. Ensure that the estimator's output fields follow the 'underscore on both sides' convention explained in the documentation of `cuml.Base`. This allows it to support configurable output types. +For an in-depth guide to creating estimators, see the [Estimator Guide](ESTIMATOR_GUIDE.md) + ## Error handling If you are trying to call into cuda runtime APIs inside `cuml.cuda`, in case of any errors, they'll raise a `cuml.cuda.CudaRuntimeError`. For example: ```python diff --git a/wiki/python/ESTIMATOR_GUIDE.md b/wiki/python/ESTIMATOR_GUIDE.md new file mode 100644 index 0000000000..7a81031e17 --- /dev/null +++ b/wiki/python/ESTIMATOR_GUIDE.md @@ -0,0 +1,832 @@ +# cuML Python Estimators Developer Guide + +This guide is meant to help developers follow the correct patterns when creating/modifying any cuML Estimator object and ensure a uniform cuML API. + +**Note:** This guide is long, because it includes internal details on how cuML manages input and output types for advanced use cases. But for the vast majority of estimators, the requirements are very simple and can follow the example patterns shown below in the [Quick Start Guide](#quick-start-guide). + +To start, it's recommended to read the following Scikit-learn documentation: + +1. [Scikit-learn's Estimator Docs](https://scikit-learn.org/stable/developers/develop.html) + 1. cuML Estimator design follows Scikit-learn very closely. We will only cover portions where our design differs from this document + 2. If short on time, pay close attention to these sections, as these topics have caused pain points in the past: + 1. [Instantiation](https://scikit-learn.org/stable/developers/develop.html#estimated-attributes) + 2. [Estimated Attributes](https://scikit-learn.org/stable/developers/develop.html#estimated-attributes) + 3. [`get_params` and `set_params`](https://scikit-learn.org/stable/developers/develop.html#estimated-attributes) + 4. [cloning](https://scikit-learn.org/stable/developers/develop.html#estimated-attributes) +2. [Scikit-learn's Docstring Guide](https://scikit-learn.org/stable/developers/contributing.html#guidelines-for-writing-documentation) + 1. We follow the same guidelines for specifying array-like objects, array shapes, dtypes, and default values + +## Quick Start Guide + +At a high level, all cuML Estimators must: +1. Inherit from `cuml.common.base.Base` + ```python + from cuml.common.base import Base + + class MyEstimator(Base): + ... + ``` +2. Follow the Scikit-learn estimator guidelines found [here](https://scikit-learn.org/stable/developers/develop.html) +3. Include the `Base.__init__()` arguments available in the new Estimator's `__init__()` + ```python + class MyEstimator(Base): + + def __init__(self, *, extra_arg=True, handle=None, verbose=False, output_type=None): + super().__init__(handle=handle, verbose=verbose, output_type=output_type) + ... + ``` +4. Declare each array-like attribute the new Estimator will compute as a class variable for automatic array type conversion + ```python + from cuml.common.array_descriptor import CumlArrayDescriptor + + class MyEstimator(Base): + + labels_ = CumlArrayDescriptor() + + def __init__(self): + ... + ``` +5. Add input and return type annotations to public API functions OR wrap those functions explicitly with conversion decorators (see [this example](#non-standard-predict) for a non-standard use case) + ```python + class MyEstimator(Base): + + def fit(self, X) -> "MyEstimator": + ... + + def predict(self, X) -> CumlArray: + ... + ``` +6. Implement `get_param_names()` including values returned by `super().get_param_names()` + ```python + def get_param_names(self): + return super().get_param_names() + [ + "eps", + "min_samples", + ] + ``` + +For the majority of estimators, the above steps will be sufficient to correctly work with the cuML library and ensure a consistent API. However, situations may arise where an estimator differs from the standard pattern and some of the functionality needs to be customized. The remainder of this guide takes a deep dive into the estimator functionality to assist developers when building estimators. + +## Background + +Some background is necessary to understand the design of estimators and how to work around any non-standard situations. + +### Input and Output Types in cuML + +In cuML we support both ingesting and generating a variety of different object types. Estimators should be able to accept and return any array type. The types that are supported as of release 0.17: + + - cuDF DataFrame or Series + - Pandas DataFrame or Series + - NumPy Arrays + - Numba Device Arrays + - CuPy arrays + - CumlArray type (Internal to the `cuml` API only.) + +When converting between types, it's important to minimize the CPU<->GPU type conversions as much as possible. Conversions such as NumPy -> CuPy or Numba -> Pandas DataFrame will incur a performance penalty as memory is copied from device to host or vice-versa. + +Converting between types of the same device, i.e. CPU<->CPU or GPU<->GPU, do not have as significant of a penalty, though they may still increase memory usage (this is particularly true for the array <-> dataframe conversion. i.e. when converting from CuPy to cuDF, memory usage may increase slightly). + +Finally, conversions between Numba<->CuPy<->CumlArray incur the least amount of overhead since only the device pointer is moved from one class to another. + +Internally, all arrays should be converted to `CumlArray` as much as possible since it is compatible with all output types and can be easily converted. + +### Specifying the Array Output Type + +Users can choose which array type should be returned by cuml by either: +1. Individually setting the output_type property on an estimator class (i.e `Base(output_type="numpy")`) +2. Globally setting the `cuml.global_output_type` +3. Temporarily setting the `cuml.global_output_type` via the `cuml.using_output_type` context manager + +**Note:** Setting `cuml.global_output_type` (either directly or via `cuml.set_output_type()` or `cuml.using_output_type()`) will take precedence over any value in `Base.output_type` + +Changing the array output type will alter the return value of estimator functions (i.e. `predict()`, `transform()`), and the return value for array-like estimator attributes (i.e. `my_estimator.classes_` or `my_estimator.coef_`) + +All output_types (including `cuml.global_output_type`) are specified using an all lowercase string. These strings can be passed in an estimators constructor or via `cuml.set_global_output_type` and `cuml.using_output_type`. Accepted values are: + + - `None`: (Default) No global value set. Will use individual values from estimators output_type + - `"input"`: Similar to `None`. Will mirror the same type as any array passed into the estimator + - `"numba"`: Returns Numba Device Arrays + - `"numpy"`: Returns Numpy Arrays + - `"cudf"`: Returns cuDF DataFrame if cols > 1, else cuDF Series + - `"cupy"`: Returns CuPy Device Arrays + +**Note:** There is an additional option `"mirror"` which can only be set by internal API calls and is not user accessible. This value is only used internally by the `CumlArrayDescriptor` to mirror any input value set. + +### Ingesting Arrays + +When the input array type isn't known, the correct and safest way to ingest arrays is using `cuml.common.input_to_cuml_array`. This method can handle all supported types, is capable of checking the array order, can enforce a specific dtype, and can raise errors on incorrect array sizes: + +```python +def fit(self, X): + cuml_array, dtype, cols, rows = input_to_cuml_array(X, order="K") + ... +``` + +### Returning Arrays + +The `CumlArray` class can convert to any supported array type using the `to_output(output_type: str)` method. However, doing this explicitly is almost never needed in practice and **should be avoided**. Directly converting arrays with `to_output()` will circumvent the automatic conversion system potentially causing extra or incorrect array conversions. + +## Estimator Design + +All estimators (any class that is a child of `cuml.common.base.Base`) have a similar structure. In addition to the guidelines specified in the [SkLearn Estimator Docs](https://scikit-learn.org/stable/developers/develop.html), cuML implements a few additional rules. + +### Initialization + +All estimators should match the arguments (including the default value) in `Base.__init__` and pass these values to `super().__init__()`. As of 0.17, all estimators should accept `handle`, `verbose` and `output_type`. + +In addition, is recommended to force keyword arguments to prevent breaking changes if arguments are added or removed in future versions. For example, all arguments below after `*` must be passed by keyword: + +```python +def __init__(self, *, eps=0.5, min_samples=5, max_mbytes_per_batch=None, + calc_core_sample_indices=True, handle=None, verbose=False, output_type=None): +``` + +Finally, do not alter any input arguments - if you do, it will prevent proper cloning of the estimator. See Scikit-learn's [section](https://scikit-learn.org/stable/developers/develop.html#instantiation) on instantiation for more info. + +For example, the following `__init__` shows what **NOT** to do: +```python +def __init__(self, my_option="option1"): + if (my_option == "option1"): + self.my_option = 1 + else: + self.my_option = 2 +``` + +This will break cloning since the value of `self.my_option` is not a valid input to `__init__`. Instead, `my_option` should be saved as an attribute as-is. + +### Implementing `get_param_names()` + +To support cloning, estimators need to implement the function `get_param_names()`. The returned value should be a list of strings of all estimator attributes that are necessary to duplicate the estimator. This method is used in `Base.get_params()` which will collect the collect the estimator param values from this list and pass this dictionary to a new estimator constructor. Therefore, all strings returned by `get_param_names()` should be arguments in `__init__()` otherwise an invalid argument exception will be raised. Most estimators implement `get_param_names()` similar to: + +```python +def get_param_names(self): + return super().get_param_names() + [ + "eps", + "min_samples", + ] +``` + +**Note:** Be sure to include `super().get_param_names()` in the returned list to properly set the `super()` attributes. + +### Estimator Array-Like Attributes + +Any array-like attribute stored in an estimator needs to be convertible to the user's desired output type. To make it easier to store array-like objects in a class that derives from `Base`, the `cuml.common.array_descriptor.CumlArrayDescriptor` was created. The `CumlArrayDescriptor` class is a Python descriptor object which allows cuML to implement customized attribute lookup, storage and deletion code that can be reused on all estimators. + +The `CumlArrayDescriptor` behaves different when accessed internally (from within one of `cuml`'s functions) vs. externally (for user code outside the cuml module). Internally, it behaves exactly like a normal attribute and will return the previous value set. Externally, the array will get converted to the user's desired output type lazily and repeated conversion will be cached. + +Performing the arrray conversion lazily (i.e. converting the input array to the desired output type, only when the attribute it read from for the first time) can greatly help reduce memory consumption, but can have unintended impacts the developers should be aware of. For example, benchmarking should take into account the lazy evaluation and ensure the array conversion is included in any profiling. + +#### Defining Array-Like Attributes + +To use the `CumlArrayDescriptor` in an estimator, any array-like attributes need to be specified by creating a `CumlArrayDescriptor` as a class variable. + +```python +from cuml.common.array_descriptor import CumlArrayDescriptor + +class TestEstimator(cuml.Base): + + # Class variables outside of any function + my_cuml_array_ = CumlArrayDescriptor() + + def __init__(self, ...): + ... +``` + +This gives the developer full control over which attributes are arrays and the name for the array-like attribute (something that was not true before `0.17`). + +#### Working with `CumlArrayDescriptor` + +Once an `CumlArrayDescriptor` attribute has been defined, developers can use the attribute as they normally would. Consider the following example estimator: + +```python +import cupy as cp +import cuml +from cuml.common.array_descriptor import CumlArrayDescriptor + +class SampleEstimator(cuml.Base): + + # Class variables outside of any function + my_cuml_array_ = CumlArrayDescriptor() + my_cupy_array_ = CumlArrayDescriptor() + my_other_array_ = CumlArrayDescriptor() + + def __init__(self, ...): + + # Initialize to None (not mandatory) + self.my_cuml_array_ = None + + # Init with a cupy array + self.my_cupy_array_ = cp.zeros((10, 10)) + + def fit(self, X): + # Stores the type of `X` and sets the output type if self.output_type == "input" + self._set_output_type(X) + + # Set my_cuml_array_ with a CumlArray + self.my_cuml_array_, *_ = input_to_cuml_array(X, order="K") + + # Access `my_cupy_array_` normally and set to another attribute + # The internal type of my_other_array_ will be a CuPy array + self.my_other_array_ = cp.ones((10, 10)) + self.my_cupy_array_ + + return self +``` + +Just like any normal attribute, `CumlArrayDescriptor` attributes will return the same value that was set into the attribute _unless accessed externally_ (more on that below). However, developers can convert the type of an array-like attribute by using the `cuml.global_output_type` functionality and reading from the attribute. For example, we could add a `score()` function to `TestEstimator`: + +```python +def score(self): + + # Set the global output type to numpy + with cuml.using_output_type("numpy"): + # Accessing my_cuml_array_ will return a numpy array and + # the result can be returned directly + return np.sum(self.my_cuml_array_, axis=0) +``` + +This has the same benefits of lazy conversion and caching as when descriptors are used externally. + +#### CumlArrayDescriptor External Functionality + +Externally, when users read from a `CumlArrayDescriptor` attribute, the array data will be converted to the correct output type _lazily_ when the attribute is read from. For example, building off the above `TestEstimator`: + +```python +my_est = SampleEstimator() + +# Print the default output_type and value for `my_cuml_array_` +# By default, `output_type` is set to `cuml.global_output_type` +# If `cuml.global_output_type == None`, `output_type` is set to "input" +print(my_est.output_type) # Output: "input" +print(my_est.my_cuml_array_) # Output: None +print(my_est.my_other_array_) # Output: AttributeError! my_other_array_ was never set + +# Call fit() with a numpy array as the input +np_arr = np.ones((10,)) +my_est.fit(np_arr) # This will load data into attributes + +# `my_cuml_array_` was set internally as a CumlArray. Externally, we can check the type +print(type(my_est.my_cuml_array_)) # Output: Numpy (saved from the input of `fit`) + +# Calling fit again with cupy arrays, will have a similar effect +my_est.fit(cp.ones((10,))) +print(type(my_est.my_cuml_array_)) # Output: CuPy + +# Setting the `output_type` will change all descriptor properties +# and ignore the input type +my_est.output_type = "cudf" + +# Reading any of the attributes will convert the type lazily +print(type(my_est.my_cuml_array_)) # Output: cuDF object + +# Setting the global_output_type, overrides the estimator output_type attribute +with cuml.using_output_type("cupy"): + print(type(my_est.my_cuml_array_)) # Output: cupy + +# Once the global_output_type is restored, we return to the estimator output_type +print(type(my_est.my_cuml_array_)) # Output: cuDF. Using a cached value! +``` + +For more information about `CumlArrayDescriptor` and it's implementation, see the [CumlArrayDescriptor Details]() section of the Appendix. + +### Estimator Methods + +To allow estimator methods to accept a wide variety of inputs and outputs, a set of decorators have been created to wrap estimator functions (and all `cuml` API functions as well) and perform the standard conversions automatically. cuML provides 2 options to for performing the standard array type conversions: + +1. For many common patterns used in functions like `fit()`, `predict()`, `transform()`, `cuml.Base` can automatically perform the data conversions as long as a method has the necessary type annotations. +2. Decorators can be manually added to methods to handle more advanced use cases + +#### Option 1: Automatic Array Conversion From Type Annotation + +To automatically convert array-like objects being returned by an Estimator method, a new metaclass has been added to `Base` that can scan the return type information of an Estimator method and infer which, if any, array conversion should be done. For example, if a method returns a type of `Base`, cuML can assume this method is likely similar to `fit()` and should call `Base._set_base_attributes()` before calling the method. If a method returns a type of `CumlArray`, cuML can assume this method is similar to `predict()` or `transform()`, and the return value is an array that may need to be converted using the output type calculated in `Base._get_output_type()`. + +The full set of return types rules that will be applied by the `Base` metaclass are: + +| Return Type | Converts Array Type? | Common Methods | Notes | +| :---------: | :-----------: | :----------- | :----------- | +| `Base` | No | `fit()` | Any type that inherits or `isinstance` of `Base` will work | +| `CumlArray` | Yes | `predict()`, `transform()` | Functions can return any array-like object (`np.ndarray`, `cp.ndarray`, etc. all accepted) | +| `SparseCumlArray` | Yes | `predict()`, `transform()` | Functions can return any sparse array-like object (`scipy`, `cupyx.scipy` sparse arrays accepted) | +| `dict`, `tuple`, `list` or `typing.Union` | Yes | | Functions must return a generic object that contains an array-like object. No sparse arrays are supported | + +Simply setting the return type of a method is all that is necessary to automatically convert the return type (with the added benefit of adding more information to the code). Below are some examples to show simple methods using automatic array conversion. + +##### `fit()` + +```python +def fit(self, X) -> "KMeans": + + # Convert the input to CumlArray + self.coef_ = input_to_cuml_array(X, order="K").array + + return self +``` + +**Notes:** + - Any type that derives from `Base` can be used as the return type for `fit()`. In python, to indicate returning `self` from a function, class type can be surrounded in quotes to prevent an import error. + + +##### `predict()` + +```python +def predict(self, X) -> CumlArray: + # Convert to CumlArray + X_m = input_to_cuml_array(X, order="K").array + + # Call a cuda function + X_m = cp.asarray(X_m) + cp.ones(X_m.shape) + + # Directly return a cupy array + return X_m +``` + +**Notes:** + - Its not necessary to convert to `CumlArray` and casting with `to_output` before returning. This function directly returned a `cp.ndarray` object. Any array-like object can be returned. + +#### Option 2: Manual Estimator Method Decoration + +While the automatic converions from type annotations works for many estimator functions, sometimes its necessary to explicitly decorate an estimator method. This allows developers greater flexibility over the input argument, output type and output dtype. + +Which decorator to use for an estimator function is determined by 2 factors: + +1. Function return type +2. Whether the function is on a class deriving from `Base` + +The full set of descriptors can be organized by these two factors: + +| Return Type-> | Array-Like | Sparse Array-Like | Generic | Any | +| -----------: | :-----------: | :-----------: | :-----------: | :-----------: | +| `Base` | `@api_base_return_array` | `@api_base_return_sparse_array` |`@api_base_return_generic` | `@api_base_return_any` | +| Non-`Base` | `@api_return_array` | `@api_return_sparse_array` | `@api_return_generic` | `@api_return_any` | + +Simply choosing the decorator based off the return type and if the function is on `Base` will work most of the time. The decorator default options were designed to work on most estimator functions without much customization. + +An in-depth discussion of how these decorators work, when each should be used, and their default options can be found in the Appendix. For now, we will show an example method that uses a non-standard input argument name, and also requires converting the array dtype: + +##### Non-Standard `predict()` + +```python +@cuml.internals.api_base_return_array(input_arg="X_in", get_output_dtype=True) +def predict(self, X): + # Convert to CumlArray + X_m = input_to_cuml_array(X, order="K").array + + # Call a cuda function + X_m = cp.asarray(X_m) + cp.ones(X_m.shape) + + # Return the cupy array directly + return X_m +``` + +**Notes:** + - The decorator argument `input_arg` can be used to specify which input should be considered the "input". + - In reality, this isn't necessary for this example. The decorator will look for an argument named `"X"` or default to the first, non `self`, argument. + - It's not necessary to convert to `CumlArray` and casting with `to_output` before returning. This function directly returned a `cp.ndarray` object. Any array-like object can be returned. + - Specifying `get_output_dtype=True` in the decorator argument instructs the decorator to also calculate the dtype in addition to the output type. + +## Do's And Don'ts + +### **Do:** Add Return Typing Information to Estimator Functions + +Adding the return type to estimator functions will allow the `Base` meta-class to automatically decorate functions based on their return type. + +**Do this:** +```python +def fit(self, X, y, convert_dtype=True) -> "KNeighborsRegressor": +def predict(self, X, convert_dtype=True) -> CumlArray: +def kneighbors_graph(self, X=None, n_neighbors=None, mode='connectivity') -> SparseCumlArray: +def predict(self, start=0, end=None, level=None) -> typing.Union[CumlArray, float]: +``` + +**Not this:** +```python +def fit(self, X, y, convert_dtype=True): +def predict(self, X, convert_dtype=True): +def kneighbors_graph(self, X=None, n_neighbors=None, mode='connectivity'): +def predict(self, start=0, end=None, level=None): +``` + +### **Do:** Return Array-Like Objects Directly + +There is no need to convert the array type before returning it. Simply return any array-like object and it will be automatically converted + +**Do this:** +```python +def predict(self) -> CumlArray: + cp_arr = cp.ones((10,)) + + return cp_arr +``` + +**Not this:** +```python +def predict(self, X, y) -> CumlArray: + cp_arr = cp.ones((10,)) + + # Don't be tempted to use `CumlArray(cp_arr)` here either + cuml_arr = input_to_cuml_array(cp_arr, order="K").array + + return cuml_arr.to_output(self._get_output_type(X)) +``` + +### **Don't:** Use `CumlArray.to_output()` directly + +Using `CumlArray.to_output()` is no longer necessary except in very rare circumstances. Converting array types is best handled with `input_to_cuml_array` or `cuml.using_output_type()` when retrieving `CumlArrayDescriptor` values. + +**Do this:** +```python +def _private_func(self) -> CumlArray: + return cp.ones((10,)) + +def predict(self, X, y) -> CumlArray: + + self.my_cupy_attribute_ = cp.zeros((10,)) + + with cuml.using_output_type("numpy"): + np_arr = self._private_func() + + return self.my_cupy_attribute_ + np_arr +``` + +**Not this:** +```python +def _private_func(self) -> CumlArray: + return cp.ones((10,)) + +def predict(self, X, y) -> CumlArray: + + self.my_cupy_attribute_ = cp.zeros((10,)) + + np_arr = CumlArray(self._private_func()).to_output("numpy") + + return CumlArray(self.my_cupy_attribute_).to_output("numpy") + np_arr +``` + +### **Don't:** Perform parameter modification in `__init__()` + +Input arguments to `__init__()` should be stored as they were passed in. Parameter modification, such as converting parameter strings to integers, should be done in `fit()` or a helper private function. + +While it's more verbose, altering the parameters in `__init__` will break the estimator's ability to be used in `clone()`. + +**Do this:** +```python +class TestEstimator(cuml.Base): + + def __init__(self, method_name: str, ...): + super().__init__(...) + + self.method_name = method_name + + def _method_int(self) -> int: + return 1 if self.method_name == "type1" else 0 + + def fit(self, X) -> "TestEstimator": + + # Call external code from Cython + my_external_func(X.ptr, self._method_int()) + + return self +``` + +**Not this:** +```python +class TestEstimator(cuml.Base): + + def __init__(self, method_name: str, ...): + super().__init__(...) + + self.method_name = 1 if method_name == "type1" else 0 + + def fit(self, X) -> "TestEstimator": + + # Call external code from Cython + my_external_func(X.ptr, self.method_name) + + return self +``` + +## Appendix + +This section contains more in-depth information about the decorators and descriptors to help developers understand whats going on behind the scenes + +### Estimator Array-Like Attributes + +#### Automatic Decoration Rules + +Adding decorators to every estimator function just to use the decorator default values would be very repetitive and unnecessary. Because most of estimator functions follow a similar pattern, a new meta-class has been created to automatically decorate estimator functions based off their return type. This meta class will decorate functions according to a few rules: + +1. If a functions has been manually decorated, it will not be automatically decorated +2. If an estimator function returns an instance of `Base`, then `@api_base_return_any()` will be applied. +3. If an estimator function returns a `CumlArray`, then `@api_base_return_array()` will be applied. +3. If an estimator function returns a `SparseCumlArray`, then `@api_base_return_sparse_array()` will be applied. +4. If an estimator function returns a `dict`, `tuple`, `list` or `typing.Union`, then `@api_base_return_generic()` will be applied. + +| Return Type | Decorator | Notes | +| :-----------: | :-----------: | :----------- | +| `Base` | `@api_base_return_any(set_output_type=True, set_n_features_in=True)` | Any type that `isinstance` of `Base` will work | +| `CumlArray` | `@api_base_return_array(get_output_type=True)` | Functions can return any array-like object | +| `SparseCumlArray` | `@api_base_return_sparse_array(get_output_type=True)` | Functions can return any sparse array-like object | +| `dict`, `tuple`, `list` or `typing.Union` | `@api_base_return_generic(get_output_type=True)` | Functions must return a generic object that contains an array-like object. No sparse arrays are supported | + +#### `CumlArrayDescriptor` Internals + +The internal representation of `CumlArrayDescriptor` is a `CumlArrayDescriptorMeta` object. To inspect the internal representation, the attribute value must be directly accessed from the estimator's `__dict__` (`getattr` and `__getattr__` will perform the conversion). For example: + +```python +my_est = TestEstimator() +my_est.fit(cp.ones((10,))) + +# Access the CumlArrayDescriptorMeta value directly. No array conversion will occur +print(my_est.__dict__["my_cuml_array_"]) +# Output: CumlArrayDescriptorMeta(input_type='cupy', values={'cuml': , 'numpy': array([ 0, 1, 1, 2, 2, -1, -1, ... + +# Values from CumlArrayDescriptorMeta can be specifically read +print(my_est.__dict__["my_cuml_array_"].input_type) +# Output: "cupy" + +# The input value can be accessed +print(my_est.__dict__["my_cuml_array_"].get_input_value()) +# Output: CumlArray ... +``` + +### Estimator Methods + +#### Common Functionality + +All of these decorators perform the same basic steps with a few small differences. The common steps performed by each decorator is: + +1. Set `cuml.global_output_type = "mirror"` + 1. When `"mirror"` is used as the global output type, that indicates we are in an internal cuML API call. The `CumlArrayDescriptor` keys off this value to change between internal and external functionality +2. Set CuPy allocator to use RMM + 1. This replaces the existing decorator `@with_cupy_rmm` + 2. Unlike before, the CuPy allocator is only set once per API call +3. Set the estimator input attributes. Can be broken down into 3 steps: + 1. Set `_input_type` attribute + 2. Set `target_dtype` attribute + 3. Set `n_features` attribute +4. Call the desired function +5. Get the estimator output type. Can be broken down into 2 steps: + 1. Get `output_type` + 2. Get `output_dtype` +6. Convert the return value + 1. This will ultimately call `CumlArray.to_output(output_type=output_type, output_dtype=output_dtype) + +While the above list of steps may seem excessive for every call, most functions follow this general form, but may skip a few steps depending on a couple of factors. For example, Step #3 is necessary on functions that modify the estimator's estimated attributes, such as `fit()`, but is not necessary for functions like `predict()` or `transform()`. And Step #5/6 are only necessary when returning array-like objects and are omitted when returning any other type. + +Functionally, you can think of these decorators equivalent to the following pseudocode: +```python +def my_func(self, X): + with cuml.using_ouput_type("mirror"): + with cupy.cuda.cupy_using_allocator(rmm.rmm_cupy_allocator): + # Set the input properties + self._set_base_attributes(output_type=X, n_features=X) + + # Do actual calculation returning an array-like object + ret_val = self._my_func(X) + + # Get the output type + output_type = self._get_output_type(X) + + # Convert array-like to CumlArray + ret_val = input_to_cuml_array(ret_val, order="K").array + + # Convert CumlArray to desired output_type + return ret_val.to_output(output_type) +``` + +Keep the above pseudocode in mind when working with these decorators since their goal is to replace many of these repetitive functions. + +### Decorator Defaults + +Every function in `cuml` is slightly different and some `fit()` functions may need to set the `target_dtype` or some `predict()` functions may need to skip getting the output type. To handle these situations, all of the decorators take arguments to configure their functionality. + +Since the decorator's functionality is very similar, so are their arguments. All of the decorators take similar arguments that will be outlined below. + +| Argument | Type | Default | Meaning | +| :-----------: | :-----------: | :-----------: | :----------- | +| `input_arg` | `str` | `'X'` or 1st non-self argument | Determines which input argument to use for `_set_output_type()` and `_set_n_features_in()` | +| `target_arg` | `str` | `'y'` or 2nd non-self argument | Determines which input argument to use for `_set_target_dtype()` | +| `set_output_type` | `bool` | Varies | Whether to call `_set_output_type(input_arg)` | +| `set_output_dtype` | `bool` | `False` | Whether to call `_set_target_dtype(target_arg)` | +| `set_n_features_in` | `bool` | Varies | Whether to call `_set_n_features_in(input_arg)` | +| `get_output_type` | `bool` | Varies | Whether to call `_get_output_type(input_arg)` | +| `get_output_dtype` | `bool` | `False` | Whether to call `_get_target_dtype()` | + +An example of how these arguments can be used is below: + +**Before:** +```python +@with_cupy_rmm +def predict(self, X, y): + # Determine the output type and dtype + out_type = self._get_output_type(y) + out_dtype = self._get_target_dtype() + + # Convert to CumlArray + X_m = input_to_cuml_array(X, order="K").array + + # Call a cuda function + someCudaFunction(X_m.ptr) + + # Convert the CudaArray to the desired output + return X_m.to_output(output_type=out_type, output_dtype=out_dtype) +``` + +**After:** +```python +@cuml.internals.api_base_return_array(input_arg="y", get_output_dtype=True) +def predict(self, X): + # Convert to CumlArray + X_m = input_to_cuml_array(X, order="K").array + + # Call a cuda function + someCudaFunction(X_m.ptr) + + # Convert the CudaArray to the desired output + return X_m +``` + +#### Before `0.17` and After Comparison + +For developers used to the `0.16` architecture it can be helpful to see examples of estimator methods from `0.16` compared to `0.17` and after. This section shows a few examples side-by-side to illustrate the changes. + +##### `fit()` + + + + + + + + + + + + + + +
BeforeAfter
+ +```python +@with_cupy_rmm +def fit(self, X): + # Set the base input attributes + self._set_base_attributes(output_type=X, n_features=X) + + self.coef_ = input_to_cuml_array(X, order="K").array + + return self +``` + + + +```python + +def fit(self, X) -> "KMeans": + + + + self.coef_ = input_to_cuml_array(X, order="K").array + + return self +``` + +
+ +**Notes:** + - `@with_cupy_rmm` is no longer needed. This is automatically applied for every public method of estimators + - `self._set_base_attributes()` no longer needs to be called. + +##### `predict()` + + + + + + + + + + + + + + +
BeforeAfter
+ +```python +@with_cupy_rmm +def predict(self, X, y): + # Determine the output type and dtype + out_type = self._get_output_type(y) + + # Convert to CumlArray + X_m = input_to_cuml_array(X, order="K").array + + # Do some calculation with cupy + X_m = cp.asarray(X_m) + cp.ones(X_m.shape) + + # Convert back to CumlArray + X_m = CumlArray(X_m) + + # Convert the CudaArray to the desired output + return X_m.to_output(output_type=out_type) +``` + + +```python + +def predict(self, X) -> CumlArray: + + + + # Convert to CumlArray + X_m = input_to_cuml_array(X, order="K").array + + # Call a cuda function + X_m = cp.asarray(X_m) + cp.ones(X_m.shape) + + + + + # Directly return a cupy array + return X_m +``` + +
+ +**Notes:** + - `@with_cupy_rmm` is no longer needed. This is automatically applied for every public method of estimators + - `self._get_output_type()` no longer needs to be called. The output type is determined automatically + - Its not necessary to convert to `CumlArray` and casting with `to_output` before returning. This function directly returned a `cp.ndarray` object. Any array-like object can be returned. + +##### `predict()` with `dtype` + + + + + + + + + + + + + + +
BeforeAfter
+ +```python + +@with_cupy_rmm +def predict(self, X_in): + # Determine the output_type + out_type = self._get_output_type(X_in) + out_dtype = self._get_target_dtype() + + # Convert to CumlArray + X_m = input_to_cuml_array(X_in, order="K").array + + # Call a cuda function + X_m = cp.asarray(X_m) + cp.ones(X_m.shape) + + # Convert back to CumlArray + X_m = CumlArray(X_m) + + # Convert the CudaArray to the desired output and dtype + return X_m.to_output(output_type=out_type, output_dtype=out_dtype) +``` + + + +```python + +@api_base_return_array(input_arg="X_in", get_output_dtype=True) +def predict(self, X): + + + + + # Convert to CumlArray + X_m = input_to_cuml_array(X, order="K").array + + # Call a cuda function + X_m = cp.asarray(X_m) + cp.ones(X_m.shape) + + + + + # Return the cupy array directly + return X_m + +``` + +
+ +**Notes:** + - `@with_cupy_rmm` is no longer needed. This is automatically applied with every decorator + - The decorator argument `input_arg` can be used to specify which input should be considered the "input". + - In reality, this isn't necessary for this example. The decorator will look for an argument named `"X"` or default to the first, non `self`, argument. + - `self._get_output_type()` and `self._get_target_dtype()` no longer needs to be called. Both the output type and dtype are determined automatically + - It's not necessary to convert to `CumlArray` and casting with `to_output` before returning. This function directly returned a `cp.ndarray` object. Any array-like object can be returned. + - Specifying `get_output_dtype=True` in the decorator argument instructs the decorator to also calculate the dtype in addition to the output type. \ No newline at end of file