Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve indexing in methods when applied to DataFrame and Series objects #4317

Merged
merged 13 commits into from
Nov 13, 2021
3 changes: 2 additions & 1 deletion python/cuml/cluster/dbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ class DBSCAN(Base,

cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()

self.labels_ = CumlArray.empty(n_rows, dtype=out_dtype)
self.labels_ = CumlArray.empty(n_rows, dtype=out_dtype,
index=X_m.index)
cdef uintptr_t labels_ptr = self.labels_.ptr

cdef uintptr_t core_sample_indices_ptr = <uintptr_t> NULL
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/cluster/hdbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
self.n_connected_components_ = 1
self.n_leaves_ = n_rows

self.labels_ = CumlArray.empty(n_rows, dtype="int32")
self.labels_ = CumlArray.empty(n_rows, dtype="int32", index=X_m.index)
self.children_ = CumlArray.empty((2, n_rows), dtype="int32")
self.probabilities_ = CumlArray.empty(n_rows, dtype="float32")
self.sizes_ = CumlArray.empty(n_rows, dtype="int32")
Expand Down
3 changes: 2 additions & 1 deletion python/cuml/cluster/kmeans.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,8 @@ class KMeans(Base,

cdef uintptr_t cluster_centers_ptr = self.cluster_centers_.ptr

self.labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32)
self.labels_ = CumlArray.zeros(shape=n_rows, dtype=np.int32,
index=X_m.index)
cdef uintptr_t labels_ptr = self.labels_.ptr

# Sum of squared distances of samples to their closest cluster center.
Expand Down
60 changes: 49 additions & 11 deletions python/cuml/common/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ class CumlArray(Buffer):

@nvtx.annotate(message="common.CumlArray.__init__", category="utils",
domain="cuml_python")
def __init__(self, data=None, owner=None, dtype=None, shape=None,
def __init__(self,
data=None,
index=None,
owner=None,
dtype=None,
shape=None,
order=None):

# Checks of parameters
Expand Down Expand Up @@ -148,6 +153,7 @@ def __init__(self, data=None, owner=None, dtype=None, shape=None,
else:
flattened_data = data

self._index = index
super().__init__(data=flattened_data,
owner=owner,
size=size)
Expand Down Expand Up @@ -179,6 +185,16 @@ def __init__(self, data=None, owner=None, dtype=None, shape=None,
self.strides = ary_interface['strides']
self.order = _strides_to_order(self.strides, self.dtype)

# We use the index as a property to allow for validation/processing
# in the future if needed
@property
def index(self):
return self._index

@index.setter
def index(self, index):
self._index = index

@with_cupy_rmm
def __getitem__(self, slice):
return CumlArray(data=cp.asarray(self).__getitem__(slice))
Expand Down Expand Up @@ -267,7 +283,8 @@ def to_output(self, output_type='cupy', output_dtype=None):
mat = cp.asarray(self, dtype=output_dtype)
if len(mat.shape) == 1:
mat = mat.reshape(mat.shape[0], 1)
return DataFrame(mat)
return DataFrame(mat,
index=self.index)
else:
raise ValueError('cuDF unsupported Array dtype')

Expand All @@ -277,7 +294,9 @@ def to_output(self, output_type='cupy', output_dtype=None):
if len(self.shape) == 1:
if self.dtype not in [np.uint8, np.uint16, np.uint32,
np.uint64, np.float16]:
return Series(self, dtype=output_dtype)
return Series(self,
dtype=output_dtype,
index=self.index)
else:
raise ValueError('cuDF unsupported Array dtype')
elif self.shape[1] > 1:
Expand Down Expand Up @@ -307,7 +326,11 @@ def serialize(self):
@classmethod
@nvtx.annotate(message="common.CumlArray.empty", category="utils",
domain="cuml_python")
def empty(cls, shape, dtype, order='F'):
def empty(cls,
shape,
dtype,
order='F',
index=None):
"""
Create an empty Array with an allocated but uninitialized DeviceBuffer

Expand All @@ -321,12 +344,17 @@ def empty(cls, shape, dtype, order='F'):
Whether to create a F-major or C-major array.
"""

return CumlArray(cp.empty(shape, dtype, order))
return CumlArray(cp.empty(shape, dtype, order), index=index)

@classmethod
@nvtx.annotate(message="common.CumlArray.full", category="utils",
domain="cuml_python")
def full(cls, shape, value, dtype, order='F'):
def full(cls,
shape,
value,
dtype,
order='F',
index=None):
"""
Create an Array with an allocated DeviceBuffer initialized to value.

Expand All @@ -340,12 +368,16 @@ def full(cls, shape, value, dtype, order='F'):
Whether to create a F-major or C-major array.
"""

return CumlArray(cp.full(shape, value, dtype, order))
return CumlArray(cp.full(shape, value, dtype, order), index=index)

@classmethod
@nvtx.annotate(message="common.CumlArray.zeros", category="utils",
domain="cuml_python")
def zeros(cls, shape, dtype='float32', order='F'):
def zeros(cls,
shape,
dtype='float32',
order='F',
index=None):
"""
Create an Array with an allocated DeviceBuffer initialized to zeros.

Expand All @@ -358,12 +390,17 @@ def zeros(cls, shape, dtype='float32', order='F'):
order: string, optional
Whether to create a F-major or C-major array.
"""
return CumlArray.full(value=0, shape=shape, dtype=dtype, order=order)
return CumlArray.full(value=0, shape=shape, dtype=dtype, order=order,
index=index)

@classmethod
@nvtx.annotate(message="common.CumlArray.ones", category="utils",
domain="cuml_python")
def ones(cls, shape, dtype='float32', order='F'):
def ones(cls,
shape,
dtype='float32',
order='F',
index=None):
"""
Create an Array with an allocated DeviceBuffer initialized to zeros.

Expand All @@ -376,7 +413,8 @@ def ones(cls, shape, dtype='float32', order='F'):
order: string, optional
Whether to create a F-major or C-major array.
"""
return CumlArray.full(value=1, shape=shape, dtype=dtype, order=order)
return CumlArray.full(value=1, shape=shape, dtype=dtype, order=order,
index=index)


def _check_low_level_type(data):
Expand Down
1 change: 1 addition & 0 deletions python/cuml/common/array_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(self, data=None,
self.shape = data.shape
self.dtype = self.data.dtype
self.nnz = data.nnz
self.index = None

@nvtx.annotate(message="common.SparseCumlArray.to_output",
category="utils", domain="cuml_python")
Expand Down
20 changes: 14 additions & 6 deletions python/cuml/common/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def check_order(arr_order):
safe_dtype=safe_dtype_conversion)
check_dtype = False

index = getattr(X, 'index', None)

# format conversion

if (isinstance(X, cudf.Series)):
Expand All @@ -322,9 +324,11 @@ def check_order(arr_order):

if isinstance(X, cudf.DataFrame):
if order == 'K':
X_m = CumlArray(data=X.as_gpu_matrix(order='F'))
X_m = CumlArray(data=X.as_gpu_matrix(order='F'),
index=index)
else:
X_m = CumlArray(data=X.as_gpu_matrix(order=order))
X_m = CumlArray(data=X.as_gpu_matrix(order=order),
index=index)

elif isinstance(X, CumlArray):
X_m = X
Expand All @@ -349,7 +353,6 @@ def check_order(arr_order):
if not _check_array_contiguity(X):
debug("Non contiguous array or view detected, a "
"contiguous copy of the data will be done.")
# X = cp.array(X, order=order, copy=True)
make_copy = True

# If we have a host array, we copy it first before changing order
Expand All @@ -359,7 +362,8 @@ def check_order(arr_order):

cp_arr = cp.array(X, copy=make_copy, order=order)

X_m = CumlArray(data=cp_arr)
X_m = CumlArray(data=cp_arr,
index=index)

if deepcopy:
X_m = copy.deepcopy(X_m)
Expand Down Expand Up @@ -404,9 +408,13 @@ def check_order(arr_order):

if (check_order(X_m.order)):
X_m = cp.array(X_m, copy=False, order=order)
X_m = CumlArray(data=X_m)
X_m = CumlArray(data=X_m,
index=index)

return cuml_array(array=X_m, n_rows=n_rows, n_cols=n_cols, dtype=X_m.dtype)
return cuml_array(array=X_m,
n_rows=n_rows,
n_cols=n_cols,
dtype=X_m.dtype)


@nvtx.annotate(message="common.input_utils.input_to_cupy_array",
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/decomposition/pca.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ class PCA(Base,

t_input_data = \
CumlArray.zeros((params.n_rows, params.n_components),
dtype=dtype.type)
dtype=dtype.type, index=X_m.index)

cdef uintptr_t _trans_input_ptr = t_input_data.ptr
cdef uintptr_t components_ptr = self.components_.ptr
Expand Down
14 changes: 7 additions & 7 deletions python/cuml/decomposition/tsvd.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ class TruncatedSVD(Base,
self.singular_values_.ptr

_trans_input_ = CumlArray.zeros((params.n_rows, params.n_components),
dtype=self.dtype)
dtype=self.dtype, index=X_m.index)
cdef uintptr_t t_input_ptr = _trans_input_.ptr

if self.n_components> self.n_cols:
Expand Down Expand Up @@ -389,7 +389,7 @@ class TruncatedSVD(Base,

"""

trans_input, n_rows, _, dtype = \
X_m, n_rows, _, dtype = \
input_to_cuml_array(X, check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
else None))
Expand All @@ -400,9 +400,9 @@ class TruncatedSVD(Base,
params.n_cols = self.n_cols

input_data = CumlArray.zeros((params.n_rows, params.n_cols),
dtype=self.dtype)
dtype=self.dtype, index=X_m.index)

cdef uintptr_t trans_input_ptr = trans_input.ptr
cdef uintptr_t trans_input_ptr = X_m.ptr
cdef uintptr_t input_ptr = input_data.ptr
cdef uintptr_t components_ptr = self.components_.ptr

Expand Down Expand Up @@ -436,7 +436,7 @@ class TruncatedSVD(Base,
Perform dimensionality reduction on X.

"""
input, n_rows, _, dtype = \
X_m, n_rows, _, dtype = \
input_to_cuml_array(X, check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
else None),
Expand All @@ -449,9 +449,9 @@ class TruncatedSVD(Base,

t_input_data = \
CumlArray.zeros((params.n_rows, params.n_components),
dtype=self.dtype)
dtype=self.dtype, index=X_m.index)

cdef uintptr_t input_ptr = input.ptr
cdef uintptr_t input_ptr = X_m.ptr
cdef uintptr_t trans_input_ptr = t_input_data.ptr
cdef uintptr_t components_ptr = self.components_.ptr

Expand Down
2 changes: 1 addition & 1 deletion python/cuml/experimental/linear_model/lars.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class Lars(Base, RegressorMixin):
cdef uintptr_t active_idx_ptr = \
input_to_cuml_array(self.active_).array.ptr

preds = CumlArray.zeros(n_rows, dtype=self.dtype)
preds = CumlArray.zeros(n_rows, dtype=self.dtype, index=X_m.index)

if self.dtype == np.float32:
larsPredict(handle_[0], <float*> X_ptr, <int> n_rows,
Expand Down
12 changes: 7 additions & 5 deletions python/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,13 @@ cdef class ForestInference_impl():
shape += (2,)
else:
shape += (self.num_class,)
preds = CumlArray.empty(shape=shape, dtype=np.float32, order='C')
elif (not isinstance(preds, cudf.Series) and
not rmm.is_cuda_array(preds)):
raise ValueError("Invalid type for output preds,"
" need GPU array")
preds = CumlArray.empty(shape=shape, dtype=np.float32, order='C',
index=X_m.index)
else:
if not hasattr(preds, "__cuda_array_interface__"):
raise ValueError("Invalid type for output preds,"
" need GPU array")
preds.index = X_m.index

cdef uintptr_t preds_ptr
preds_ptr = preds.ptr
Expand Down
6 changes: 2 additions & 4 deletions python/cuml/linear_model/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,15 @@ class LinearPredictMixin:
Predicts `y` values for `X`.

"""
cdef uintptr_t X_ptr
X_m, n_rows, n_cols, dtype = \
input_to_cuml_array(X, check_dtype=self.dtype,
convert_to_dtype=(self.dtype if convert_dtype
else None),
check_cols=self.n_cols)
X_ptr = X_m.ptr

cdef uintptr_t X_ptr = X_m.ptr
cdef uintptr_t coef_ptr = self.coef_.ptr

preds = CumlArray.zeros(n_rows, dtype=dtype)
preds = CumlArray.zeros(n_rows, dtype=dtype, index=X_m.index)
cdef uintptr_t preds_ptr = preds.ptr

cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
Expand Down
4 changes: 2 additions & 2 deletions python/cuml/manifold/t_sne.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,6 @@ class TSNE(Base,
convert_format=False)
n, p = self.X_m.shape
self.sparse_fit = True

# Handle dense inputs
else:
self.X_m, n, p, _ = \
Expand Down Expand Up @@ -451,7 +450,8 @@ class TSNE(Base,
self.embedding_ = CumlArray.zeros(
(n, self.n_components),
order="F",
dtype=np.float32)
dtype=np.float32,
index=self.X_m.index)

cdef uintptr_t embed_ptr = self.embedding_.ptr

Expand Down
8 changes: 6 additions & 2 deletions python/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,8 @@ class UMAP(Base,

self.embedding_ = CumlArray.zeros((self.n_rows,
self.n_components),
order="C", dtype=np.float32)
order="C", dtype=np.float32,
index=self.X_m.index)

if self.hash_input:
with using_output_type("numpy"):
Expand Down Expand Up @@ -720,12 +721,14 @@ class UMAP(Base,
if is_sparse(X):
X_m = SparseCumlArray(X, convert_to_dtype=cupy.float32,
convert_format=False)
index = None
else:
X_m, n_rows, n_cols, dtype = \
input_to_cuml_array(X, order='C', check_dtype=np.float32,
convert_to_dtype=(np.float32
if convert_dtype
else None))
index = X_m.index
n_rows = X_m.shape[0]
n_cols = X_m.shape[1]

Expand All @@ -745,7 +748,8 @@ class UMAP(Base,

embedding = CumlArray.zeros((X_m.shape[0],
self.n_components),
order="C", dtype=np.float32)
order="C", dtype=np.float32,
index=index)
cdef uintptr_t xformed_ptr = embedding.ptr

(knn_indices_m, knn_indices_ctype), (knn_dists_m, knn_dists_ctype) =\
Expand Down
Loading