Skip to content

Commit

Permalink
feat(python): Streamline creation of empty frame from Schema (#20267)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Dec 13, 2024
1 parent f599e88 commit 08ddc4b
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 10 deletions.
39 changes: 36 additions & 3 deletions py-polars/polars/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
from collections import OrderedDict
from collections.abc import Mapping
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Literal, Union, overload

from polars._typing import PythonDataType
from polars.datatypes import DataType, DataTypeClass, is_polars_dtype
Expand All @@ -12,12 +12,13 @@
if TYPE_CHECKING:
from collections.abc import Iterable

from polars import DataFrame, LazyFrame

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias


if sys.version_info >= (3, 10):

def _required_init_args(tp: DataTypeClass) -> bool:
Expand All @@ -35,7 +36,6 @@ def _required_init_args(tp: DataTypeClass) -> bool:
BaseSchema = OrderedDict[str, DataType]
SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType]


__all__ = ["Schema"]


Expand Down Expand Up @@ -152,6 +152,39 @@ def dtypes(self) -> list[DataType]:
"""
return list(self.values())

@overload
def to_frame(self, *, eager: Literal[False] = ...) -> LazyFrame: ...

@overload
def to_frame(self, *, eager: Literal[True]) -> DataFrame: ...

def to_frame(self, *, eager: bool = True) -> DataFrame | LazyFrame:
"""
Create an empty DataFrame (or LazyFrame) from this Schema.
Parameters
----------
eager
If True, create a DataFrame; otherwise, create a LazyFrame.
Examples
--------
>>> s = pl.Schema({"x": pl.Int32(), "y": pl.String()})
>>> s.to_frame()
shape: (0, 2)
┌─────┬─────┐
│ x ┆ y │
│ --- ┆ --- │
│ i32 ┆ str │
╞═════╪═════╡
└─────┴─────┘
>>> s.to_frame(eager=False) # doctest: +IGNORE_RESULT
<LazyFrame at 0x11BC0AD80>
"""
from polars import DataFrame, LazyFrame

return DataFrame(schema=self) if eager else LazyFrame(schema=self)

def len(self) -> int:
"""
Get the number of schema entries.
Expand Down
38 changes: 31 additions & 7 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ def test_schema() -> None:
pl.Schema({"foo": pl.String, "bar": pl.List})


@pytest.mark.parametrize(
"schema",
[
pl.Schema(),
pl.Schema({"foo": pl.Int8()}),
pl.Schema({"foo": pl.Datetime("us"), "bar": pl.String()}),
pl.Schema(
{
"foo": pl.UInt32(),
"bar": pl.Categorical("physical"),
"baz": pl.Struct({"x": pl.Int64(), "y": pl.Float64()}),
}
),
],
)
def test_schema_empty_frame(schema: pl.Schema) -> None:
assert_frame_equal(
schema.to_frame(),
pl.DataFrame(schema=schema),
)


def test_schema_equality() -> None:
s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()})
s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
Expand Down Expand Up @@ -248,13 +270,15 @@ def test_lazy_agg_lit_explode() -> None:
assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": [[1]]}, schema=schema)) # type: ignore[arg-type]


@pytest.mark.parametrize("expr_op", [
"approx_n_unique", "arg_max", "arg_min", "bitwise_and", "bitwise_or",
"bitwise_xor", "count", "entropy", "first", "has_nulls", "implode", "kurtosis",
"last", "len", "lower_bound", "max", "mean", "median", "min", "n_unique", "nan_max",
"nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
"var"
]) # fmt: skip
@pytest.mark.parametrize(
"expr_op", [
"approx_n_unique", "arg_max", "arg_min", "bitwise_and", "bitwise_or",
"bitwise_xor", "count", "entropy", "first", "has_nulls", "implode", "kurtosis",
"last", "len", "lower_bound", "max", "mean", "median", "min", "n_unique", "nan_max",
"nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
"var"
]
) # fmt: skip
@pytest.mark.parametrize("lhs", [pl.col("b"), pl.lit(1, dtype=pl.Int64).alias("b")])
def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None:
op = getattr(pl.Expr, expr_op)
Expand Down

0 comments on commit 08ddc4b

Please sign in to comment.