Skip to content

Commit

Permalink
if it accepts expression, it should accept column
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Sep 11, 2023
1 parent d404a73 commit a6ef75f
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 50 deletions.
8 changes: 0 additions & 8 deletions spec/API_specification/dataframe_api/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,9 @@
TypeVar,
Union,
Protocol,
TYPE_CHECKING,
TypeAlias
)
from enum import Enum

if TYPE_CHECKING:
from .expression_object import Expression
from .eagercolumn_object import EagerColumn

IntoExpression: TypeAlias = Expression | EagerColumn

# Type alias: Mypy needs Any, but for readability we need to make clear this
# is a Python scalar (i.e., an instance of `bool`, `int`, `float`, `str`, etc.)
Scalar = Any
Expand Down
10 changes: 5 additions & 5 deletions spec/API_specification/dataframe_api/dataframe_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .eagerframe_object import EagerFrame
from .eagercolumn_object import EagerColumn
from .groupby_object import GroupBy
from ._types import NullType, Scalar, IntoExpression
from ._types import NullType, Scalar


__all__ = ["DataFrame"]
Expand Down Expand Up @@ -92,7 +92,7 @@ def groupby(self, *keys: str) -> GroupBy:
"""
...

def select(self, *names: str | Expression) -> DataFrame:
def select(self, *names: str | Expression | EagerColumn[Any]) -> DataFrame:
"""
Select multiple columns, either by name or by expressions.
Expand Down Expand Up @@ -137,7 +137,7 @@ def slice_rows(
"""
...

def filter(self, mask: IntoExpression) -> DataFrame:
def filter(self, mask: Expression | EagerColumn[bool]) -> DataFrame:
"""
Select a subset of rows corresponding to a mask.
Expand Down Expand Up @@ -216,7 +216,7 @@ def update_columns(self, *columns: Expression | EagerColumn[Any]) -> DataFrame:
Parameters
----------
columns : Expression, EagerColumn, or sequence of either
columns : Expression | EagerColumn
Column(s) to update. If updating multiple columns, they must all have
different names.
Expand Down Expand Up @@ -273,7 +273,7 @@ def column_names(self) -> list[str]:

def sort(
self,
*keys: str | Expression,
*keys: str | Expression | EagerColumn[Any],
ascending: Sequence[bool] | bool = True,
nulls_position: Literal['first', 'last'] = 'last',
) -> DataFrame:
Expand Down
4 changes: 2 additions & 2 deletions spec/API_specification/dataframe_api/eagercolumn_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class EagerColumn(Generic[DType]):
"""
EagerColumn object
Instantiate via :meth:`EagerFrame.get_column_by_name`.
Instantiate via :meth:`EagerFrame.get_column`.
If you need to use this within the context of a :class`DataFrame` operation
(such as `:meth:`DataFrame.filter`) then you can convert it to an expression
Expand Down Expand Up @@ -106,7 +106,7 @@ def slice_rows(
...


def filter(self: EagerColumn[DType], mask: EagerColumn[Bool]) -> EagerColumn[DType]:
def filter(self: EagerColumn[DType], mask: Expression | EagerColumn[Bool]) -> EagerColumn[DType]:
"""
Select a subset of rows corresponding to a mask.
Expand Down
18 changes: 9 additions & 9 deletions spec/API_specification/dataframe_api/eagerframe_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .expression_object import Expression
from .dataframe_object import DataFrame
from .groupby_object import GroupBy
from ._types import NullType, Scalar, IntoExpression
from ._types import NullType, Scalar


__all__ = ["EagerFrame"]
Expand Down Expand Up @@ -89,7 +89,7 @@ def get_column(self, name: str, /) -> EagerColumn[Any]:
"""
...

def select(self, *columns: str | Expression) -> EagerFrame:
def select(self, *columns: str | Expression | EagerColumn[Any]) -> EagerFrame:
"""
Select multiple columns by name.
Expand All @@ -115,7 +115,7 @@ def select(self, *columns: str | Expression) -> EagerFrame:
"""
...

def get_rows(self, indices: Expression) -> EagerFrame:
def get_rows(self, indices: Expression | EagerColumn[Any]) -> EagerFrame:
"""
Select a subset of rows, similar to `ndarray.take`.
Expand Down Expand Up @@ -148,7 +148,7 @@ def slice_rows(
"""
...

def filter(self, mask: IntoExpression) -> EagerFrame:
def filter(self, mask: Expression | EagerColumn[bool]) -> EagerFrame:
"""
Select a subset of rows corresponding to a mask.
Expand All @@ -173,15 +173,15 @@ def insert_columns(self, *columns: Expression | EagerColumn[Any]) -> EagerFrame:
.. code-block:: python
new_column = df.get_column_by_name('a') + 1
new_column = df.get_column('a') + 1
df = df.insert_columns(new_column.rename('a_plus_1'))
If you need to insert the column at a different location, combine with
:meth:`select`, e.g.:
.. code-block:: python
new_column = df.get_column_by_name('a') + 1
new_column = df.get_column('a') + 1
new_columns_names = ['a_plus_1'] + df.get_column_names()
df = df.insert_columns(new_column.rename('a_plus_1'))
df = df.select(new_column_names)
Expand All @@ -203,12 +203,12 @@ def update_columns(self, *columns: Expression | EagerColumn[Any]) -> EagerFrame:
.. code-block:: python
new_column = df.get_column_by_name('a') + 1
new_column = df.get_column('a') + 1
df = df.update_column(new_column.rename('b').to_expression())
Parameters
----------
columns : IntoExpression | Sequence[IntoExpression]
columns : Expression | EagerColumn
Column(s) to update. If updating multiple columns, they must all have
different names.
Expand Down Expand Up @@ -265,7 +265,7 @@ def column_names(self) -> list[str]:

def sort(
self,
*keys: str | Expression,
*keys: str | Expression | EagerColumn[Any],
ascending: Sequence[bool] | bool = True,
nulls_position: Literal['first', 'last'] = 'last',
) -> EagerFrame:
Expand Down
47 changes: 23 additions & 24 deletions spec/API_specification/dataframe_api/expression_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@


if TYPE_CHECKING:
from ._types import DType
from . import Bool
from ._types import NullType, Scalar
from .eagercolumn_object import EagerColumn


__all__ = ['Expression']
Expand Down Expand Up @@ -92,7 +91,7 @@ def len(self) -> Expression:
Return the number of rows.
"""

def get_rows(self: Expression, indices: Expression) -> Expression:
def get_rows(self, indices: Expression | EagerColumn[Any]) -> Expression:
"""
Select a subset of rows, similar to `ndarray.take`.
Expand All @@ -104,7 +103,7 @@ def get_rows(self: Expression, indices: Expression) -> Expression:
...

def slice_rows(
self: Expression, start: int | None, stop: int | None, step: int | None
self, start: int | None, stop: int | None, step: int | None
) -> Expression:
"""
Select a subset of rows corresponding to a slice.
Expand All @@ -121,7 +120,7 @@ def slice_rows(
"""
...

def filter(self, mask: Expression) -> Expression:
def filter(self, mask: Expression | EagerColumn[bool]) -> Expression:
"""
Select a subset of rows corresponding to a mask.
Expand Down Expand Up @@ -225,7 +224,7 @@ def __eq__(self, other: Expression | Scalar) -> Expression: # type: ignore[over
Expression
"""

def __ne__(self: Expression, other: Expression | Scalar) -> Expression: # type: ignore[override]
def __ne__(self, other: Expression | Scalar) -> Expression: # type: ignore[override]
"""
Compare for non-equality.
Expand All @@ -243,7 +242,7 @@ def __ne__(self: Expression, other: Expression | Scalar) -> Expression: # type:
Expression
"""

def __ge__(self: Expression, other: Expression | Scalar) -> Expression:
def __ge__(self, other: Expression | Scalar) -> Expression:
"""
Compare for "greater than or equal to" `other`.
Expand All @@ -259,7 +258,7 @@ def __ge__(self: Expression, other: Expression | Scalar) -> Expression:
Expression
"""

def __gt__(self: Expression, other: Expression | Scalar) -> Expression:
def __gt__(self, other: Expression | Scalar) -> Expression:
"""
Compare for "greater than" `other`.
Expand All @@ -275,7 +274,7 @@ def __gt__(self: Expression, other: Expression | Scalar) -> Expression:
Expression
"""

def __le__(self: Expression, other: Expression | Scalar) -> Expression:
def __le__(self, other: Expression | Scalar) -> Expression:
"""
Compare for "less than or equal to" `other`.
Expand All @@ -291,7 +290,7 @@ def __le__(self: Expression, other: Expression | Scalar) -> Expression:
Expression
"""

def __lt__(self: Expression, other: Expression | Scalar) -> Expression:
def __lt__(self, other: Expression | Scalar) -> Expression:
"""
Compare for "less than" `other`.
Expand All @@ -307,7 +306,7 @@ def __lt__(self: Expression, other: Expression | Scalar) -> Expression:
Expression
"""

def __and__(self: Expression, other: Expression | bool) -> Expression:
def __and__(self, other: Expression | bool) -> Expression:
"""
Apply logical 'and' to `other` expression (or scalar) and this expression.
Expand All @@ -328,7 +327,7 @@ def __and__(self: Expression, other: Expression | bool) -> Expression:
If `self` or `other` is not boolean.
"""

def __or__(self: Expression, other: Expression | bool) -> Expression:
def __or__(self, other: Expression | bool) -> Expression:
"""
Apply logical 'or' to `other` expression (or scalar) and this expression.
Expand All @@ -349,7 +348,7 @@ def __or__(self: Expression, other: Expression | bool) -> Expression:
If `self` or `other` is not boolean.
"""

def __add__(self: Expression, other: Expression | Scalar) -> Expression:
def __add__(self, other: Expression | Scalar) -> Expression:
"""
Add `other` expression or scalar to this expression.
Expand All @@ -365,7 +364,7 @@ def __add__(self: Expression, other: Expression | Scalar) -> Expression:
Expression
"""

def __sub__(self: Expression, other: Expression | Scalar) -> Expression:
def __sub__(self, other: Expression | Scalar) -> Expression:
"""
Subtract `other` expression or scalar from this expression.
Expand Down Expand Up @@ -481,7 +480,7 @@ def __divmod__(self, other: Expression | Scalar) -> tuple[Expression, Expression
tuple[Expression, Expression]
"""

def __invert__(self: Expression) -> Expression:
def __invert__(self) -> Expression:
"""
Invert truthiness of (boolean) elements.
Expand All @@ -491,7 +490,7 @@ def __invert__(self: Expression) -> Expression:
If any of the expression's expressions is not boolean.
"""

def any(self: Expression, *, skip_nulls: bool = True) -> Expression:
def any(self, *, skip_nulls: bool = True) -> Expression:
"""
Reduction returns a bool.
Expand All @@ -501,7 +500,7 @@ def any(self: Expression, *, skip_nulls: bool = True) -> Expression:
If expression is not boolean.
"""

def all(self: Expression, *, skip_nulls: bool = True) -> Expression:
def all(self, *, skip_nulls: bool = True) -> Expression:
"""
Reduction returns a bool.
Expand Down Expand Up @@ -595,26 +594,26 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> Expres
Whether to skip null values.
"""

def cumulative_max(self: Expression) -> Expression:
def cumulative_max(self) -> Expression:
"""
Reduction returns a expression. Any data type that supports comparisons
must be supported. The returned value has the same dtype as the expression.
"""

def cumulative_min(self: Expression) -> Expression:
def cumulative_min(self) -> Expression:
"""
Reduction returns a expression. Any data type that supports comparisons
must be supported. The returned value has the same dtype as the expression.
"""

def cumulative_sum(self: Expression) -> Expression:
def cumulative_sum(self) -> Expression:
"""
Reduction returns a expression. Must be supported for numerical and
datetime data types. The returned value has the same dtype as the
expression.
"""

def cumulative_prod(self: Expression) -> Expression:
def cumulative_prod(self) -> Expression:
"""
Reduction returns a expression. Must be supported for numerical and
datetime data types. The returned value has the same dtype as the
Expand Down Expand Up @@ -659,7 +658,7 @@ def is_nan(self) -> Expression:
In particular, does not check for `np.timedelta64('NaT')`.
"""

def is_in(self: Expression, values: Expression) -> Expression:
def is_in(self, values: Expression | EagerColumn[Any]) -> Expression:
"""
Indicate whether the value at each row matches any value in `values`.
Expand Down Expand Up @@ -698,7 +697,7 @@ def unique_indices(self, *, skip_nulls: bool = True) -> Expression:
"""
...

def fill_nan(self: Expression, value: float | NullType, /) -> Expression:
def fill_nan(self, value: float | NullType, /) -> Expression:
"""
Fill floating point ``nan`` values with the given fill value.
Expand All @@ -712,7 +711,7 @@ def fill_nan(self: Expression, value: float | NullType, /) -> Expression:
"""
...

def fill_null(self: Expression, value: Scalar, /) -> Expression:
def fill_null(self, value: Scalar, /) -> Expression:
"""
Fill null values with the given fill value.
Expand Down
1 change: 0 additions & 1 deletion spec/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@
('py:class', 'optional'),
('py:class', 'NullType'),
('py:class', 'GroupBy'),
('py:class', 'IntoExpression'),
]
# NOTE: this alias handling isn't used yet - added in anticipation of future
# need based on dataframe API aliases.
Expand Down
2 changes: 1 addition & 1 deletion spec/design_topics/python_builtin_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class EagerColumn:
def mean(self, skip_nulls: bool = True) -> float | NullType:
...

larger = df2 > df1.get_column_by_name('foo').mean()
larger = df2 > df1.get_column('foo').mean()
```

For a GPU dataframe library, it is desirable for all data to reside on the GPU,
Expand Down

0 comments on commit a6ef75f

Please sign in to comment.