diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index 85faa08621b..c892c100bf6 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -247,47 +247,6 @@ def scatter(object source, Column scatter_map, Column target_column, return next(iter(data.values())) -def _reverse_column(Column source_column): - cdef column_view reverse_column_view = source_column.view() - - cdef unique_ptr[column] c_result - - with nogil: - c_result = move(cpp_copying.reverse( - reverse_column_view - )) - - return Column.from_unique_ptr( - move(c_result) - ) - - -def _reverse_table(source_table): - cdef table_view reverse_table_view = table_view_from_columns(source_table) - - cdef unique_ptr[table] c_result - with nogil: - c_result = move(cpp_copying.reverse( - reverse_table_view - )) - - return data_from_unique_ptr( - move(c_result), - column_names=source_table._column_names, - index_names=source_table._index_names - ) - - -def reverse(object source): - """ - Reversing a column or a table - """ - if isinstance(source, Column): - return _reverse_column(source) - else: - return _reverse_table(source) - - def column_empty_like(Column input_column): cdef column_view input_column_view = input_column.view() diff --git a/python/cudf/cudf/_lib/cpp/copying.pxd b/python/cudf/cudf/_lib/cpp/copying.pxd index a318dc68ac9..be1b6d8069c 100644 --- a/python/cudf/cudf/_lib/cpp/copying.pxd +++ b/python/cudf/cudf/_lib/cpp/copying.pxd @@ -34,14 +34,6 @@ cdef extern from "cudf/copying.hpp" namespace "cudf" nogil: out_of_bounds_policy policy ) except + - cdef unique_ptr[table] reverse ( - const table_view& source_table - ) except + - - cdef unique_ptr[column] reverse ( - const column_view& source_column - ) except + - cdef unique_ptr[column] shift( const column_view& input, size_type offset, diff --git a/python/cudf/cudf/_lib/transpose.pyx b/python/cudf/cudf/_lib/transpose.pyx index bea1164b655..b33a3cefba7 100644 --- a/python/cudf/cudf/_lib/transpose.pyx +++ b/python/cudf/cudf/_lib/transpose.pyx @@ -61,7 +61,7 @@ def transpose(source): if cats is not None: data= [ (name, cudf.core.column.column.build_categorical_column( - codes=cudf.core.column.column.as_column( + codes=cudf.core.column.column.build_column( col.base_data, dtype=col.dtype), mask=col.base_mask, size=col.size, diff --git a/python/cudf/cudf/core/_internals/where.py b/python/cudf/cudf/core/_internals/where.py index 216be8cd511..6c94a84fd37 100644 --- a/python/cudf/cudf/core/_internals/where.py +++ b/python/cudf/cudf/core/_internals/where.py @@ -309,7 +309,7 @@ def where( ): result = cudf.core.column.build_categorical_column( categories=frame._data[column_name].categories, - codes=cudf.core.column.as_column( + codes=cudf.core.column.build_column( result.base_data, dtype=result.dtype ), mask=result.base_mask, @@ -368,7 +368,7 @@ def where( cudf.core.column.CategoricalColumn, frame._data[frame.name], ).categories, - codes=cudf.core.column.as_column( + codes=cudf.core.column.build_column( result.base_data, dtype=result.dtype ), mask=result.base_mask, diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index d2da594fa3b..a8e868ed521 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -713,7 +713,9 @@ def deserialize(cls, header: dict, frames: list) -> CategoricalColumn: data=None, dtype=dtype, mask=mask, - children=(column.as_column(data.base_data, dtype=data.dtype),), + children=( + column.build_column(data.base_data, dtype=data.dtype), + ), ), ) @@ -859,7 +861,7 @@ def slice( codes = self.codes.slice(start, stop, stride) return cudf.core.column.build_categorical_column( categories=self.categories, - codes=cudf.core.column.as_column( + codes=cudf.core.column.build_column( codes.base_data, dtype=codes.dtype ), mask=codes.base_mask, @@ -910,7 +912,7 @@ def sort_by_values( codes, inds = self.as_numerical.sort_by_values(ascending, na_position) col = column.build_categorical_column( categories=self.dtype.categories._values, - codes=column.as_column(codes.base_data, dtype=codes.dtype), + codes=column.build_column(codes.base_data, dtype=codes.dtype), mask=codes.base_mask, size=codes.size, ordered=self.dtype.ordered, @@ -1001,7 +1003,7 @@ def unique(self) -> CategoricalColumn: codes = self.as_numerical.unique() return column.build_categorical_column( categories=self.categories, - codes=column.as_column(codes.base_data, dtype=codes.dtype), + codes=column.build_column(codes.base_data, dtype=codes.dtype), mask=codes.base_mask, offset=codes.offset, size=codes.size, @@ -1044,7 +1046,7 @@ def find_and_replace( df = cudf.DataFrame({"old": to_replace_col, "new": replacement_col}) df = df.drop_duplicates(subset=["old"], keep="last", ignore_index=True) if df._data["old"].null_count == 1: - fill_value = df._data["new"][df._data["old"].isna()][0] + fill_value = df._data["new"][df._data["old"].isnull()][0] if fill_value in self.categories: replaced = self.fillna(fill_value) else: @@ -1060,7 +1062,7 @@ def find_and_replace( else: replaced = self if df._data["new"].null_count > 0: - drop_values = df._data["old"][df._data["new"].isna()] + drop_values = df._data["old"][df._data["new"].isnull()] cur_categories = replaced.categories new_categories = cur_categories[ ~cudf.Series(cur_categories.isin(drop_values)) @@ -1096,7 +1098,7 @@ def find_and_replace( # those categories don't exist anymore # Resetting the index creates a column 'index' that associates # the original integers to the new labels - bmask = new_cats._data["cats"].notna() + bmask = new_cats._data["cats"].notnull() new_cats = cudf.DataFrame( {"cats": new_cats._data["cats"].apply_boolean_mask(bmask)} ).reset_index() @@ -1123,7 +1125,7 @@ def find_and_replace( return column.build_categorical_column( categories=new_cats["cats"], - codes=column.as_column(output.base_data, dtype=output.dtype), + codes=column.build_column(output.base_data, dtype=output.dtype), mask=output.base_mask, offset=output.offset, size=output.size, @@ -1205,7 +1207,7 @@ def fillna( result = column.build_categorical_column( categories=self.dtype.categories._values, - codes=column.as_column(result.base_data, dtype=result.dtype), + codes=column.build_column(result.base_data, dtype=result.dtype), offset=result.offset, size=result.size, mask=result.base_mask, @@ -1301,7 +1303,7 @@ def copy(self, deep: bool = True) -> CategoricalColumn: return column.build_categorical_column( categories=copied_cat, - codes=column.as_column( + codes=column.build_column( copied_col.base_data, dtype=copied_col.dtype ), offset=copied_col.offset, @@ -1312,7 +1314,7 @@ def copy(self, deep: bool = True) -> CategoricalColumn: else: return column.build_categorical_column( categories=self.dtype.categories._values, - codes=column.as_column( + codes=column.build_column( self.codes.base_data, dtype=self.codes.dtype ), mask=self.base_mask, @@ -1374,7 +1376,9 @@ def _concat(objs: MutableSequence[CategoricalColumn]) -> CategoricalColumn: return column.build_categorical_column( categories=column.as_column(cats), - codes=column.as_column(codes_col.base_data, dtype=codes_col.dtype), + codes=column.build_column( + codes_col.base_data, dtype=codes_col.dtype + ), mask=codes_col.base_mask, size=codes_col.size, offset=codes_col.offset, @@ -1386,7 +1390,7 @@ def _with_type_metadata( if isinstance(dtype, CategoricalDtype): return column.build_categorical_column( categories=dtype.categories._values, - codes=column.as_column( + codes=column.build_column( self.codes.base_data, dtype=self.codes.dtype ), mask=self.codes.base_mask, @@ -1522,7 +1526,9 @@ def _set_categories( # codes can't have masks, so take mask out before moving in return column.build_categorical_column( categories=new_cats, - codes=column.as_column(new_codes.base_data, dtype=new_codes.dtype), + codes=column.build_column( + new_codes.base_data, dtype=new_codes.dtype + ), mask=new_codes.base_mask, size=new_codes.size, offset=new_codes.offset, @@ -1609,7 +1615,7 @@ def pandas_categorical_as_column( return column.build_categorical_column( categories=categorical.categories, - codes=column.as_column(codes.base_data, dtype=codes.dtype), + codes=column.build_column(codes.base_data, codes.dtype), size=codes.size, mask=mask, ordered=categorical.ordered, diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index ac287bc6ded..23b6c01ca83 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -62,7 +62,7 @@ ListDtype, StructDtype, ) -from cudf.utils import ioutils, utils +from cudf.utils import utils from cudf.utils.dtypes import ( cudf_dtype_from_pa_type, get_time_unit, @@ -169,14 +169,9 @@ def equals(self, other: ColumnBase, check_dtypes: bool = False) -> bool: return True if other is None or len(self) != len(other): return False - if check_dtypes: - if self.dtype != other.dtype: - return False - null_equals = self._null_equals(other) - return null_equals.all() - - def _null_equals(self, other: ColumnBase) -> ColumnBase: - return self.binary_operator("NULL_EQUALS", other) + if check_dtypes and (self.dtype != other.dtype): + return False + return self.binary_operator("NULL_EQUALS", other).all() def all(self, skipna: bool = True) -> bool: # If all entries are null the result is True, including when the column @@ -188,8 +183,8 @@ def all(self, skipna: bool = True) -> bool: if isinstance(result_col, ColumnBase): return libcudf.reduce.reduce("all", result_col, dtype=np.bool_) - else: - return result_col + + return result_col def any(self, skipna: bool = True) -> bool: # Early exit for fast cases. @@ -201,8 +196,8 @@ def any(self, skipna: bool = True) -> bool: if isinstance(result_col, ColumnBase): return libcudf.reduce.reduce("any", result_col, dtype=np.bool_) - else: - return result_col + + return result_col def __sizeof__(self) -> int: n = 0 @@ -217,10 +212,7 @@ def dropna(self, drop_nan: bool = False) -> ColumnBase: col = self.nans_to_nulls() else: col = self - dropped_col = ( - col.as_frame()._drop_na_rows(drop_nan=drop_nan)._as_column() - ) - return dropped_col + return col.as_frame()._drop_na_rows(drop_nan=drop_nan)._as_column() def to_arrow(self) -> pa.Array: """Convert to PyArrow Array @@ -314,10 +306,7 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase: result = libcudf.interop.from_arrow(data, data.column_names)[0]["None"] - result = result._with_type_metadata( - cudf_dtype_from_pa_type(array.type) - ) - return result + return result._with_type_metadata(cudf_dtype_from_pa_type(array.type)) def _get_mask_as_column(self) -> ColumnBase: return libcudf.transform.mask_to_bools( @@ -372,9 +361,6 @@ def to_array(self, fillna=None) -> np.ndarray: return self.to_gpu_array(fillna=fillna).copy_to_host() - def _reverse(self): - return libcudf.copying.reverse(self) - def _fill( self, fill_value: ScalarLike, @@ -416,10 +402,9 @@ def valid_count(self) -> int: def nullmask(self) -> Buffer: """The gpu buffer for the null-mask """ - if self.nullable: - return self.mask_array_view - else: + if not self.nullable: raise ValueError("Column has no null mask") + return self.mask_array_view def copy(self: T, deep: bool = True) -> T: """Columns are immutable, so a deep copy produces a copy of the @@ -484,6 +469,7 @@ def view(self, dtype: Dtype) -> ColumnBase: + f" total bytes into {dtype} with size {dtype.itemsize}" ) + # This assertion prevents mypy errors below. assert self.base_data is not None new_buf_ptr = ( self.base_data.ptr + self.offset * self.dtype.itemsize @@ -655,11 +641,6 @@ def isnull(self) -> ColumnBase: return result - def isna(self) -> ColumnBase: - """Identify missing values in a Column. Alias for isnull. - """ - return self.isnull() - def notnull(self) -> ColumnBase: """Identify non-missing values in a Column. """ @@ -672,11 +653,6 @@ def notnull(self) -> ColumnBase: return result - def notna(self) -> ColumnBase: - """Identify non-missing values in a Column. Alias for notnull. - """ - return self.notnull() - def find_first_value( self, value: ScalarLike, closest: bool = False ) -> int: @@ -765,9 +741,7 @@ def isin(self, values: Sequence) -> ColumnBase: # typecasting fails return full(len(self), False, dtype="bool") - res = lhs._obtain_isin_result(rhs) - - return res + return lhs._obtain_isin_result(rhs) def _process_values_for_isin( self, values: Sequence @@ -790,7 +764,7 @@ def _isin_earlystop(self, rhs: ColumnBase) -> Union[ColumnBase, None]: """ if self.dtype != rhs.dtype: if self.null_count and rhs.null_count: - return self.isna() + return self.isnull() else: return cudf.core.column.full(len(self), False, dtype="bool") elif self.null_count == 0 and (rhs.null_count == len(rhs)): @@ -809,8 +783,7 @@ def _obtain_isin_result(self, rhs: ColumnBase) -> ColumnBase: ) res = ldf.merge(rdf, on="x", how="left").sort_values(by="orig_order") res = res.drop_duplicates(subset="orig_order", ignore_index=True) - res = res._data["bool"].fillna(False) - return res + return res._data["bool"].fillna(False) def as_mask(self) -> Buffer: """Convert booleans to bitmask @@ -825,20 +798,10 @@ def as_mask(self) -> Buffer: return bools_to_mask(self) - @ioutils.doc_to_dlpack() - def to_dlpack(self): - """{docstring}""" - - return cudf.io.dlpack.to_dlpack(self) - @property def is_unique(self) -> bool: return self.distinct_count() == len(self) - @property - def is_monotonic(self) -> bool: - return self.is_monotonic_increasing - @property def is_monotonic_increasing(self) -> bool: return not self.has_nulls and self.as_frame()._is_sorted( @@ -1040,19 +1003,17 @@ def as_decimal32_column( def apply_boolean_mask(self, mask) -> ColumnBase: mask = as_column(mask, dtype="bool") - result = ( + return ( self.as_frame()._apply_boolean_mask(boolean_mask=mask)._as_column() ) - return result def argsort( self, ascending: bool = True, na_position: builtins.str = "last" ) -> ColumnBase: - sorted_indices = self.as_frame()._get_sorted_inds( + return self.as_frame()._get_sorted_inds( ascending=ascending, na_position=na_position ) - return sorted_indices def __arrow_array__(self, type=None): raise TypeError( @@ -1200,15 +1161,13 @@ def min(self, skipna: bool = None, dtype: Dtype = None): result_col = self._process_for_reduction(skipna=skipna) if isinstance(result_col, ColumnBase): return libcudf.reduce.reduce("min", result_col, dtype=dtype) - else: - return result_col + return result_col def max(self, skipna: bool = None, dtype: Dtype = None): result_col = self._process_for_reduction(skipna=skipna) if isinstance(result_col, ColumnBase): return libcudf.reduce.reduce("max", result_col, dtype=dtype) - else: - return result_col + return result_col def sum( self, skipna: bool = None, dtype: Dtype = None, min_count: int = 0 @@ -1247,11 +1206,11 @@ def corr(self, other: ColumnBase): ) def nans_to_nulls(self: T) -> T: - if self.dtype.kind == "f": - newmask = libcudf.transform.nans_to_nulls(self) - return self.set_mask(newmask) - else: + # Only floats can contain nan. + if self.dtype.kind != "f": return self + newmask = libcudf.transform.nans_to_nulls(self) + return self.set_mask(newmask) def _process_for_reduction( self, skipna: bool = None, min_count: int = 0 @@ -1753,12 +1712,6 @@ def as_column( if dtype is not None: data = data.astype(dtype) - elif type(arbitrary) is Buffer: - if dtype is None: - raise TypeError("dtype cannot be None if 'arbitrary' is a Buffer") - - data = build_column(arbitrary, dtype=dtype) - elif hasattr(arbitrary, "__cuda_array_interface__"): desc = arbitrary.__cuda_array_interface__ current_dtype = np.dtype(desc["typestr"]) @@ -1929,9 +1882,7 @@ def as_column( buffer = Buffer(arbitrary.view("|u1")) mask = None if nan_as_null is None or nan_as_null is True: - data = as_column( - buffer, dtype=arbitrary.dtype, nan_as_null=nan_as_null - ) + data = build_column(buffer, dtype=arbitrary.dtype) data = data._make_copy_with_na_as_null() mask = data.mask @@ -1949,9 +1900,7 @@ def as_column( buffer = Buffer(arbitrary.view("|u1")) mask = None if nan_as_null is None or nan_as_null is True: - data = as_column( - buffer, dtype=arbitrary.dtype, nan_as_null=nan_as_null - ) + data = build_column(buffer, dtype=arbitrary.dtype) data = data._make_copy_with_na_as_null() mask = data.mask @@ -2039,7 +1988,7 @@ def as_column( cudf_dtype = arbitrary._data.dtype data = Buffer(arbitrary._data.view("|u1")) - data = as_column(data, dtype=cudf_dtype) + data = build_column(data, dtype=cudf_dtype) mask = arbitrary._mask mask = bools_to_mask(as_column(mask).unary_operator("not")) diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index eba6764e83d..b1d69316863 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -494,7 +494,7 @@ def _make_copy_with_na_as_null(self): na_value = np.datetime64("nat", self.time_unit) out_col = cudf._lib.replace.replace( self, - as_column( + column.build_column( Buffer(np.array([na_value], dtype=self.dtype).view("|u1")), dtype=self.dtype, ), diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 27ff5da5505..ca5026c2293 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -3,7 +3,7 @@ from __future__ import annotations from types import SimpleNamespace -from typing import Any, Mapping, Sequence, Tuple, Union, cast +from typing import Any, Callable, Mapping, Sequence, Tuple, Union, cast import cupy import numpy as np @@ -61,8 +61,7 @@ def __init__( if data.size % dtype.itemsize: raise ValueError("Buffer size must be divisible by element size") if size is None: - size = data.size // dtype.itemsize - size = size - offset + size = (data.size // dtype.itemsize) - offset super().__init__( data, @@ -118,8 +117,12 @@ def __cuda_array_interface__(self) -> Mapping[str, Any]: return output - def unary_operator(self, unaryop: str) -> ColumnBase: - return _numeric_column_unaryop(self, op=unaryop) + def unary_operator(self, unaryop: Union[str, Callable]) -> ColumnBase: + if callable(unaryop): + return libcudf.transform.transform(self, unaryop) + + unaryop = libcudf.unary.UnaryOp[unaryop.upper()] + return libcudf.unary.unary_operation(self, unaryop) def binary_operator( self, binop: str, rhs: BinaryOperand, reflect: bool = False, @@ -203,8 +206,7 @@ def normalize_binop_value( if self.dtype.kind == "b": other_dtype = min_signed_type(other) if np.isscalar(other): - other = cudf.dtype(other_dtype).type(other) - return other + return cudf.dtype(other_dtype).type(other) else: ary = utils.scalar_broadcast_to( other, size=len(self), dtype=other_dtype @@ -352,7 +354,7 @@ def find_and_replace( df = df.drop_duplicates(subset=["old"], keep="last", ignore_index=True) if df._data["old"].null_count == 1: replaced = replaced.fillna( - df._data["new"][df._data["old"].isna()][0] + df._data["new"][df._data["old"].isnull()][0] ) df = df.dropna(subset=["old"]) @@ -370,10 +372,7 @@ def fillna( """ Fill null values with *fill_value* """ - if fill_nan: - col = self.nans_to_nulls() - else: - col = self + col = self.nans_to_nulls() if fill_nan else self if col.null_count == 0: return col @@ -401,73 +400,72 @@ def fillna( fill_value = cudf.Scalar(fill_value_casted) else: fill_value = column.as_column(fill_value, nan_as_null=False) - # cast safely to the same dtype as self if is_integer_dtype(col.dtype): - fill_value = _safe_cast_to_int(fill_value, col.dtype) + # cast safely to the same dtype as self + if fill_value.dtype != col.dtype: + new_fill_value = fill_value.astype(col.dtype) + if not (new_fill_value == fill_value).all(): + raise TypeError( + f"Cannot safely cast non-equivalent " + f"{col.dtype.type.__name__} to " + f"{cudf.dtype(dtype).type.__name__}" + ) + fill_value = new_fill_value else: fill_value = fill_value.astype(col.dtype) return super(NumericalColumn, col).fillna(fill_value, method) - def find_first_value( - self, value: ScalarLike, closest: bool = False + def _find_value( + self, value: ScalarLike, closest: bool, find: Callable, compare: str ) -> int: - """ - Returns offset of first value that matches. For monotonic - columns, returns the offset of the first larger value - if closest=True. - """ value = to_cudf_compatible_scalar(value) if not is_number(value): raise ValueError("Expected a numeric value") found = 0 if len(self): - found = cudautils.find_first( - self.data_array_view, value, mask=self.mask - ) - if found == -1 and self.is_monotonic and closest: - if value < self.min(): - found = 0 - elif value > self.max(): - found = len(self) - else: - found = cudautils.find_first( - self.data_array_view, value, mask=self.mask, compare="gt", + found = find(self.data_array_view, value, mask=self.mask,) + if found == -1: + if self.is_monotonic_increasing and closest: + found = find( + self.data_array_view, + value, + mask=self.mask, + compare=compare, ) if found == -1: raise ValueError("value not found") - elif found == -1: - raise ValueError("value not found") + else: + raise ValueError("value not found") return found + def find_first_value( + self, value: ScalarLike, closest: bool = False + ) -> int: + """ + Returns offset of first value that matches. For monotonic + columns, returns the offset of the first larger value + if closest=True. + """ + if self.is_monotonic_increasing and closest: + if value < self.min(): + return 0 + elif value > self.max(): + return len(self) + return self._find_value(value, closest, cudautils.find_first, "gt") + def find_last_value(self, value: ScalarLike, closest: bool = False) -> int: """ Returns offset of last value that matches. For monotonic columns, returns the offset of the last smaller value if closest=True. """ - value = to_cudf_compatible_scalar(value) - if not is_number(value): - raise ValueError("Expected a numeric value") - found = 0 - if len(self): - found = cudautils.find_last( - self.data_array_view, value, mask=self.mask, - ) - if found == -1 and self.is_monotonic and closest: + if self.is_monotonic_increasing and closest: if value < self.min(): - found = -1 + return -1 elif value > self.max(): - found = len(self) - 1 - else: - found = cudautils.find_last( - self.data_array_view, value, mask=self.mask, compare="lt", - ) - if found == -1: - raise ValueError("value not found") - elif found == -1: - raise ValueError("value not found") - return found + return len(self) - 1 + return self._find_value(value, closest, cudautils.find_last, "lt") def can_cast_safely(self, to_dtype: DtypeObj) -> bool: """ @@ -505,34 +503,23 @@ def can_cast_safely(self, to_dtype: DtypeObj) -> bool: # Column contains only infs return True - max_ = col.max() - if (min_ >= lower_) and (max_ < upper_): - return True - else: - return False + return (min_ >= lower_) and (col.max() < upper_) # want to cast int to uint elif self.dtype.kind == "i" and to_dtype.kind == "u": i_max_ = np.iinfo(self.dtype).max u_max_ = np.iinfo(to_dtype).max - if self.min() >= 0: - if i_max_ <= u_max_: - return True - if self.max() < u_max_: - return True - return False + return (self.min() >= 0) and ( + (i_max_ <= u_max_) or (self.max() < u_max_) + ) # want to cast uint to int elif self.dtype.kind == "u" and to_dtype.kind == "i": u_max_ = np.iinfo(self.dtype).max i_max_ = np.iinfo(to_dtype).max - if u_max_ <= i_max_: - return True - if self.max() < i_max_: - return True - return False + return (u_max_ <= i_max_) or (self.max() < i_max_) # want to cast int to float elif self.dtype.kind in {"i", "u"} and to_dtype.kind == "f": @@ -545,13 +532,10 @@ def can_cast_safely(self, to_dtype: DtypeObj) -> bool: else: filled = self.fillna(0) - if ( + return ( cudf.Series(filled).astype(to_dtype).astype(filled.dtype) == cudf.Series(filled) - ).all(): - return True - else: - return False + ).all() # want to cast float to int: elif self.dtype.kind == "f" and to_dtype.kind in {"i", "u"}: @@ -561,10 +545,7 @@ def can_cast_safely(self, to_dtype: DtypeObj) -> bool: # best we can do is hope to catch it here and avoid compare if (self.min() >= min_) and (self.max() <= max_): filled = self.fillna(0, fill_nan=False) - if (cudf.Series(filled) % 1 == 0).all(): - return True - else: - return False + return (cudf.Series(filled) % 1 == 0).all() else: return False @@ -574,7 +555,7 @@ def _with_type_metadata(self: ColumnBase, dtype: Dtype) -> ColumnBase: if isinstance(dtype, CategoricalDtype): return column.build_categorical_column( categories=dtype.categories._values, - codes=as_column(self.base_data, dtype=self.dtype), + codes=build_column(self.base_data, dtype=self.dtype), mask=self.base_mask, ordered=dtype.ordered, size=self.size, @@ -602,33 +583,6 @@ def to_pandas( return pd_series -def _numeric_column_unaryop(operand: ColumnBase, op: str) -> ColumnBase: - if callable(op): - return libcudf.transform.transform(operand, op) - - op = libcudf.unary.UnaryOp[op.upper()] - return libcudf.unary.unary_operation(operand, op) - - -def _safe_cast_to_int(col: ColumnBase, dtype: DtypeObj) -> ColumnBase: - """ - Cast given NumericalColumn to given integer dtype safely. - """ - assert is_integer_dtype(dtype) - - if col.dtype == dtype: - return col - - new_col = col.astype(dtype) - if (new_col == col).all(): - return new_col - else: - raise TypeError( - f"Cannot safely cast non-equivalent " - f"{col.dtype.type.__name__} to {cudf.dtype(dtype).type.__name__}" - ) - - def _normalize_find_and_replace_input( input_column_dtype: DtypeObj, col_to_normalize: Union[ColumnBase, list] ) -> ColumnBase: diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 81d4c9adfa1..dba96b9069d 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -540,7 +540,7 @@ def _split_by_character(self): offset_col = self._column.children[0] - res = cudf.core.column.ListColumn( + return cudf.core.column.ListColumn( size=len(self._column), dtype=cudf.ListDtype(self._column.dtype), mask=self._column.mask, @@ -548,7 +548,6 @@ def _split_by_character(self): null_count=self._column.null_count, children=(offset_col, result_col), ) - return res def extract( self, pat: str, flags: int = 0, expand: bool = True @@ -5118,15 +5117,7 @@ def set_base_data(self, value): "StringColumns do not use data attribute of Column, use " "`set_base_children` instead" ) - else: - super().set_base_data(value) - - def set_base_mask(self, value: Optional[Buffer]): - super().set_base_mask(value) - - def set_base_children(self, value: Tuple["column.ColumnBase", ...]): - # TODO: Implement dtype validation of the children here somehow - super().set_base_children(value) + super().set_base_data(value) def __contains__(self, item: ScalarLike) -> bool: if is_scalar(item): @@ -5199,7 +5190,7 @@ def as_datetime_column( ) else: format = datetime.infer_format( - self.apply_boolean_mask(self.notna()).element_indexing(0) + self.apply_boolean_mask(self.notnull()).element_indexing(0) ) return self._as_datetime_or_timedelta_column(out_dtype, format) @@ -5374,7 +5365,7 @@ def find_and_replace( df = cudf.DataFrame({"old": to_replace_col, "new": replacement_col}) df = df.drop_duplicates(subset=["old"], keep="last", ignore_index=True) if df._data["old"].null_count == 1: - res = self.fillna(df._data["new"][df._data["old"].isna()][0]) + res = self.fillna(df._data["new"][df._data["old"].isnull()][0]) df = df.dropna(subset=["old"]) else: res = self diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 0baa4012570..d2dbac4d155 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -45,6 +45,7 @@ from cudf.core.column import ( as_column, build_categorical_column, + build_column, column_empty, concat_columns, ) @@ -6949,7 +6950,9 @@ def _reassign_categories(categories, cols, col_idxs): if idx in categories: cols[name] = build_categorical_column( categories=categories[idx], - codes=as_column(cols[name].base_data, dtype=cols[name].dtype), + codes=build_column( + cols[name].base_data, dtype=cols[name].dtype + ), mask=cols[name].base_mask, offset=cols[name].offset, size=cols[name].size, diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 934682a1996..71f910d7bb9 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1750,9 +1750,6 @@ def _repeat(self, count): result._copy_type_metadata(self) return result - def _reverse(self): - return self.__class__._from_data(*libcudf.copying.reverse(self)) - def _fill(self, fill_values, begin, end, inplace): col_and_fill = zip(self._columns, fill_values) diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 7743fecad49..c8d8837cbaa 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -5521,8 +5521,8 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): result_col = column.as_column(result) if a_col.null_count and b_col.null_count: - a_nulls = a_col.isna() - b_nulls = b_col.isna() + a_nulls = a_col.isnull() + b_nulls = b_col.isnull() null_values = a_nulls | b_nulls if equal_nan is True: @@ -5530,9 +5530,9 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): del a_nulls, b_nulls elif a_col.null_count: - null_values = a_col.isna() + null_values = a_col.isnull() elif b_col.null_count: - null_values = b_col.isna() + null_values = b_col.isnull() else: return Series(result_col, index=index) diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 50fd27f2752..96281c139f2 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -2875,7 +2875,7 @@ def set_null_cases(column_l, column_r, case): "lcol,rcol,ans,case", generate_test_null_equals_columnops_data() ) def test_null_equals_columnops(lcol, rcol, ans, case): - assert lcol._null_equals(rcol).all() == ans + assert lcol.equals(rcol).all() == ans def test_add_series_to_dataframe(): diff --git a/python/cudf/cudf/tests/test_dlpack.py b/python/cudf/cudf/tests/test_dlpack.py index 4b26e2c13bc..a77845bf196 100644 --- a/python/cudf/cudf/tests/test_dlpack.py +++ b/python/cudf/cudf/tests/test_dlpack.py @@ -100,34 +100,12 @@ def test_to_dlpack_index(data_1d): assert str(type(dlt)) == "" -def test_to_dlpack_column(data_1d): - expectation = data_size_expectation_builder(data_1d) - - with expectation: - gs = cudf.Series(data_1d, nan_as_null=False) - dlt = gs._column.to_dlpack() - - # PyCapsules are a C-API thing so couldn't come up with a better way - assert str(type(dlt)) == "" - - -def test_to_dlpack_column_null(data_1d): - expectation = data_size_expectation_builder(data_1d, nan_null_param=True) - - with expectation: - gs = cudf.Series(data_1d, nan_as_null=True) - dlt = gs._column.to_dlpack() - - # PyCapsules are a C-API thing so couldn't come up with a better way - assert str(type(dlt)) == "" - - def test_to_dlpack_cupy_1d(data_1d): expectation = data_size_expectation_builder(data_1d, False) with expectation: gs = cudf.Series(data_1d, nan_as_null=False) cudf_host_array = gs.to_numpy(na_value=np.nan) - dlt = gs._column.to_dlpack() + dlt = gs.to_dlpack() cupy_array = cupy.fromDlpack(dlt) cupy_host_array = cupy_array.get() @@ -191,7 +169,7 @@ def test_to_dlpack_cupy_1d_null(data_1d): with expectation: gs = cudf.Series(data_1d) cudf_host_array = gs.to_numpy(na_value=np.nan) - dlt = gs._column.to_dlpack() + dlt = gs.to_dlpack() cupy_array = cupy.fromDlpack(dlt) cupy_host_array = cupy_array.get() diff --git a/python/dask_cudf/dask_cudf/sorting.py b/python/dask_cudf/dask_cudf/sorting.py index a1ce90ef09a..21bf3aee7d1 100644 --- a/python/dask_cudf/dask_cudf/sorting.py +++ b/python/dask_cudf/dask_cudf/sorting.py @@ -32,7 +32,7 @@ def _set_partitions_pre(s, divisions, ascending=True, na_position="last"): partitions[(partitions < 0) | (partitions >= len(divisions) - 1)] = ( 0 if ascending else (len(divisions) - 2) ) - partitions[s._columns[0].isna().values] = ( + partitions[s._columns[0].isnull().values] = ( len(divisions) - 2 if na_position == "last" else 0 ) return partitions