From 08ddc4b2a320dc03a4d6c3eda210bd384fad6d0e Mon Sep 17 00:00:00 2001
From: Alexander Beedie
Date: Fri, 13 Dec 2024 11:58:07 +0400
Subject: [PATCH] feat(python): Streamline creation of empty frame from `Schema` (#20267)

---
 py-polars/polars/schema.py          | 39 ++++++++++++++++++++++++++---
 py-polars/tests/unit/test_schema.py | 38 ++++++++++++++++++++++------
 2 files changed, 67 insertions(+), 10 deletions(-)

diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py
index fb1b8268bf2f..35d72dd6d493 100644
--- a/py-polars/polars/schema.py
+++ b/py-polars/polars/schema.py
@@ -3,7 +3,7 @@
 import sys
 from collections import OrderedDict
 from collections.abc import Mapping
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Literal, Union, overload
 
 from polars._typing import PythonDataType
 from polars.datatypes import DataType, DataTypeClass, is_polars_dtype
@@ -12,12 +12,13 @@
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
+    from polars import DataFrame, LazyFrame
+
     if sys.version_info >= (3, 10):
         from typing import TypeAlias
     else:
         from typing_extensions import TypeAlias
 
-
 if sys.version_info >= (3, 10):
 
     def _required_init_args(tp: DataTypeClass) -> bool:
@@ -35,7 +36,6 @@ def _required_init_args(tp: DataTypeClass) -> bool:
 BaseSchema = OrderedDict[str, DataType]
 SchemaInitDataType: TypeAlias = Union[DataType, DataTypeClass, PythonDataType]
 
-
 __all__ = ["Schema"]
 
 
@@ -152,6 +152,39 @@ def dtypes(self) -> list[DataType]:
         """
         return list(self.values())
 
+    @overload
+    def to_frame(self, *, eager: Literal[False]) -> LazyFrame: ...
+
+    @overload
+    def to_frame(self, *, eager: Literal[True] = ...) -> DataFrame: ...
+
+    def to_frame(self, *, eager: bool = True) -> DataFrame | LazyFrame:
+        """
+        Create an empty DataFrame (or LazyFrame) from this Schema.
+
+        Parameters
+        ----------
+        eager
+            If True, create a DataFrame; otherwise, create a LazyFrame.
+
+        Examples
+        --------
+        >>> s = pl.Schema({"x": pl.Int32(), "y": pl.String()})
+        >>> s.to_frame()
+        shape: (0, 2)
+        ┌─────┬─────┐
+        │ x   ┆ y   │
+        │ --- ┆ --- │
+        │ i32 ┆ str │
+        ╞═════╪═════╡
+        └─────┴─────┘
+        >>> s.to_frame(eager=False)  # doctest: +IGNORE_RESULT
+
+        """
+        from polars import DataFrame, LazyFrame
+
+        return DataFrame(schema=self) if eager else LazyFrame(schema=self)
+
     def len(self) -> int:
         """
         Get the number of schema entries.
diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py
index 7ebae8e22f01..bdfc4bd21195 100644
--- a/py-polars/tests/unit/test_schema.py
+++ b/py-polars/tests/unit/test_schema.py
@@ -24,6 +24,28 @@ def test_schema() -> None:
         pl.Schema({"foo": pl.String, "bar": pl.List})
 
 
+@pytest.mark.parametrize(
+    "schema",
+    [
+        pl.Schema(),
+        pl.Schema({"foo": pl.Int8()}),
+        pl.Schema({"foo": pl.Datetime("us"), "bar": pl.String()}),
+        pl.Schema(
+            {
+                "foo": pl.UInt32(),
+                "bar": pl.Categorical("physical"),
+                "baz": pl.Struct({"x": pl.Int64(), "y": pl.Float64()}),
+            }
+        ),
+    ],
+)
+def test_schema_empty_frame(schema: pl.Schema) -> None:
+    assert_frame_equal(
+        schema.to_frame(),
+        pl.DataFrame(schema=schema),
+    )
+
+
 def test_schema_equality() -> None:
     s1 = pl.Schema({"foo": pl.Int8(), "bar": pl.Float64()})
     s2 = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
@@ -248,13 +270,15 @@ def test_lazy_agg_lit_explode() -> None:
     assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": [[1]]}, schema=schema))  # type: ignore[arg-type]
 
 
-@pytest.mark.parametrize("expr_op", [
-    "approx_n_unique", "arg_max", "arg_min", "bitwise_and", "bitwise_or",
-    "bitwise_xor", "count", "entropy", "first", "has_nulls", "implode", "kurtosis",
-    "last", "len", "lower_bound", "max", "mean", "median", "min", "n_unique", "nan_max",
-    "nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
-    "var"
-])  # fmt: skip
+@pytest.mark.parametrize(
+    "expr_op", [
+        "approx_n_unique", "arg_max", "arg_min", "bitwise_and", "bitwise_or",
+        "bitwise_xor", "count", "entropy", "first", "has_nulls", "implode", "kurtosis",
+        "last", "len", "lower_bound", "max", "mean", "median", "min", "n_unique", "nan_max",
+        "nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
+        "var"
+    ]
+)  # fmt: skip
 @pytest.mark.parametrize("lhs", [pl.col("b"), pl.lit(1, dtype=pl.Int64).alias("b")])
 def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None:
     op = getattr(pl.Expr, expr_op)