Skip to content

Commit

Permalink
Add more complete type annotations in polars interpreter (rapidsai#15942
Browse files Browse the repository at this point in the history
)

We can check this with:

    pyright --verifytypes cudf_polars --ignoreexternal

Which reports a "type completeness" score of around 94%. This will
improve once pylibcudf gets type stubs.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - James Lamb (https://github.com/jameslamb)
  - Matthew Roeschke (https://github.com/mroeschke)

URL: rapidsai#15942
  • Loading branch information
wence- authored Jun 6, 2024
1 parent 61da924 commit 3468fa1
Show file tree
Hide file tree
Showing 11 changed files with 287 additions and 123 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ repos:
- id: rapids-dependency-file-generator
args: ["--clean"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.3
rev: v0.4.8
hooks:
- id: ruff
files: python/.*$
Expand Down
5 changes: 4 additions & 1 deletion python/cudf_polars/cudf_polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@

from __future__ import annotations

__all__: list[str] = []
from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir

__all__: list[str] = ["execute_with_cudf", "translate_ir"]
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import polars as pl

from cudf_polars.dsl.ir import IR
from cudf_polars.typing import NodeTraverser

__all__: list[str] = ["execute_with_cudf"]

Expand All @@ -33,7 +34,7 @@ def _callback(
return ir.evaluate(cache={}).to_polars()


def execute_with_cudf(nt, *, raise_on_fail: bool = False) -> None:
def execute_with_cudf(nt: NodeTraverser, *, raise_on_fail: bool = False) -> None:
"""
A post optimization callback that attempts to execute the plan with cudf.
Expand Down
13 changes: 7 additions & 6 deletions python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import polars as pl

Expand All @@ -17,6 +17,7 @@
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence, Set

import pyarrow as pa
from typing_extensions import Self

import cudf
Expand Down Expand Up @@ -44,13 +45,13 @@ def copy(self) -> Self:

def to_polars(self) -> pl.DataFrame:
"""Convert to a polars DataFrame."""
return pl.from_arrow(
plc.interop.to_arrow(
self.table,
[plc.interop.ColumnMetadata(name=c.name) for c in self.columns],
)
table: pa.Table = plc.interop.to_arrow(
self.table,
[plc.interop.ColumnMetadata(name=c.name) for c in self.columns],
)

return cast(pl.DataFrame, pl.from_arrow(table))

@cached_property
def column_names_set(self) -> frozenset[str]:
"""Return the column names as a set."""
Expand Down
55 changes: 39 additions & 16 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,14 @@ def is_equal(self, other: Any) -> bool:
other.children
)

def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Equality of expressions."""
if type(self) != type(other) or hash(self) != hash(other):
return False
else:
return self.is_equal(other)

def __ne__(self, other) -> bool:
def __ne__(self, other: Any) -> bool:
"""Inequality of expressions."""
return not self.__eq__(other)

Expand Down Expand Up @@ -285,6 +285,8 @@ class NamedExpr:
# when evaluating expressions themselves, only when constructing
# named return values in dataframe (IR) nodes.
__slots__ = ("name", "value")
value: Expr
name: str

def __init__(self, name: str, value: Expr) -> None:
self.name = name
Expand All @@ -298,15 +300,15 @@ def __repr__(self) -> str:
"""Repr of the expression."""
return f"NamedExpr({self.name}, {self.value}"

def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Equality of two expressions."""
return (
type(self) is type(other)
and self.name == other.name
and self.value == other.value
)

def __ne__(self, other) -> bool:
def __ne__(self, other: Any) -> bool:
"""Inequality of expressions."""
return not self.__eq__(other)

Expand Down Expand Up @@ -344,9 +346,10 @@ def collect_agg(self, *, depth: int) -> AggInfo:
class Literal(Expr):
__slots__ = ("value",)
_non_child = ("dtype", "value")
value: pa.Scalar
value: pa.Scalar[Any]
children: tuple[()]

def __init__(self, dtype: plc.DataType, value: pa.Scalar) -> None:
def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None:
super().__init__(dtype)
assert value.type == plc.interop.to_arrow(dtype)
self.value = value
Expand All @@ -367,6 +370,7 @@ class Col(Expr):
__slots__ = ("name",)
_non_child = ("dtype", "name")
name: str
children: tuple[()]

def __init__(self, dtype: plc.DataType, name: str) -> None:
self.dtype = dtype
Expand All @@ -388,6 +392,8 @@ def collect_agg(self, *, depth: int) -> AggInfo:


class Len(Expr):
children: tuple[()]

def do_evaluate(
self,
df: DataFrame,
Expand All @@ -410,8 +416,15 @@ def collect_agg(self, *, depth: int) -> AggInfo:
class BooleanFunction(Expr):
__slots__ = ("name", "options", "children")
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

def __init__(self, dtype: plc.DataType, name: str, options: tuple, *children: Expr):
def __init__(
self,
dtype: plc.DataType,
name: pl_expr.BooleanFunction,
options: tuple[Any, ...],
*children: Expr,
) -> None:
super().__init__(dtype)
self.options = options
self.name = name
Expand Down Expand Up @@ -610,14 +623,15 @@ def do_evaluate(
class StringFunction(Expr):
__slots__ = ("name", "options", "children")
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

def __init__(
self,
dtype: plc.DataType,
name: pl_expr.StringFunction,
options: tuple,
options: tuple[Any, ...],
*children: Expr,
):
) -> None:
super().__init__(dtype)
self.options = options
self.name = name
Expand Down Expand Up @@ -661,10 +675,11 @@ def do_evaluate(
class Sort(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
children: tuple[Expr]

def __init__(
self, dtype: plc.DataType, options: tuple[bool, bool, bool], column: Expr
):
) -> None:
super().__init__(dtype)
self.options = options
self.children = (column,)
Expand Down Expand Up @@ -696,14 +711,15 @@ def do_evaluate(
class SortBy(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
children: tuple[Expr, ...]

def __init__(
self,
dtype: plc.DataType,
options: tuple[bool, tuple[bool], tuple[bool]],
column: Expr,
*by: Expr,
):
) -> None:
super().__init__(dtype)
self.options = options
self.children = (column, *by)
Expand Down Expand Up @@ -734,8 +750,9 @@ def do_evaluate(
class Gather(Expr):
__slots__ = ("children",)
_non_child = ("dtype",)
children: tuple[Expr, Expr]

def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr):
def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None:
super().__init__(dtype)
self.children = (values, indices)

Expand Down Expand Up @@ -775,6 +792,7 @@ def do_evaluate(
class Filter(Expr):
__slots__ = ("children",)
_non_child = ("dtype",)
children: tuple[Expr, Expr]

def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr):
super().__init__(dtype)
Expand All @@ -801,8 +819,9 @@ def do_evaluate(
class RollingWindow(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
children: tuple[Expr]

def __init__(self, dtype: plc.DataType, options: Any, agg: Expr):
def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None:
super().__init__(dtype)
self.options = options
self.children = (agg,)
Expand All @@ -811,8 +830,9 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr):
class GroupedRollingWindow(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
children: tuple[Expr, ...]

def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr):
def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> None:
super().__init__(dtype)
self.options = options
self.children = (agg, *by)
Expand All @@ -821,8 +841,9 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr):
class Cast(Expr):
__slots__ = ("children",)
_non_child = ("dtype",)
children: tuple[Expr]

def __init__(self, dtype: plc.DataType, value: Expr):
def __init__(self, dtype: plc.DataType, value: Expr) -> None:
super().__init__(dtype)
self.children = (value,)

Expand All @@ -848,6 +869,7 @@ def collect_agg(self, *, depth: int) -> AggInfo:
class Agg(Expr):
__slots__ = ("name", "options", "op", "request", "children")
_non_child = ("dtype", "name", "options")
children: tuple[Expr]

def __init__(
self, dtype: plc.DataType, name: str, options: Any, value: Expr
Expand Down Expand Up @@ -1007,7 +1029,7 @@ def _last(self, column: Column) -> Column:

def do_evaluate(
self,
df,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
Expand All @@ -1022,6 +1044,7 @@ def do_evaluate(
class BinOp(Expr):
__slots__ = ("op", "children")
_non_child = ("dtype", "op")
children: tuple[Expr, Expr]

def __init__(
self,
Expand Down
Loading

0 comments on commit 3468fa1

Please sign in to comment.