From 07747f1afa937c24f6fc8e6a45b15a28fd9d4610 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Sun, 2 Jul 2023 16:28:58 +0200 Subject: [PATCH] refactor(rust,python): refactor `arange` and add `int_range`/`int_ranges` (#9666) --- polars/polars-core/src/serde/mod.rs | 1 - .../polars-plan/src/dsl/function_expr/mod.rs | 28 +++ .../src/dsl/function_expr/range.rs | 162 ++++++++++++++ .../src/dsl/function_expr/schema.rs | 12 + .../polars-plan/src/dsl/functions/range.rs | 210 +++--------------- .../reference/expressions/functions.rst | 7 + py-polars/docs/source/reference/functions.rst | 13 +- py-polars/polars/__init__.py | 4 + py-polars/polars/functions/__init__.py | 4 +- py-polars/polars/functions/range.py | 179 ++++++++++++++- py-polars/polars/type_aliases.py | 3 +- py-polars/src/functions/lazy.rs | 5 - py-polars/src/functions/mod.rs | 1 + py-polars/src/functions/range.rs | 28 +++ py-polars/src/lib.rs | 10 +- py-polars/tests/unit/functions/test_range.py | 75 ++++++- 16 files changed, 541 insertions(+), 201 deletions(-) create mode 100644 polars/polars-lazy/polars-plan/src/dsl/function_expr/range.rs create mode 100644 py-polars/src/functions/range.rs diff --git a/polars/polars-core/src/serde/mod.rs b/polars/polars-core/src/serde/mod.rs index 575dce3a8f33..1551ed476c6a 100644 --- a/polars/polars-core/src/serde/mod.rs +++ b/polars/polars-core/src/serde/mod.rs @@ -4,7 +4,6 @@ pub mod series; #[cfg(test)] mod test { - use super::*; use crate::prelude::*; #[test] diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs index 169753bda36e..433d900cef26 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs @@ -25,6 +25,8 @@ mod list; mod log; mod nan; mod pow; +#[cfg(feature = "arange")] +mod range; #[cfg(all(feature = "rolling_window", feature = "moment"))] mod rolling; #[cfg(feature = "round_series")] @@ -67,6 +69,8 @@ pub use self::boolean::BooleanFunction; pub(crate) use self::cat::CategoricalFunction; #[cfg(feature = "temporal")] pub(super) use self::datetime::TemporalFunction; +#[cfg(feature = "arange")] +pub(super) use self::range::RangeFunction; #[cfg(feature = "strings")] pub(crate) use self::strings::StringFunction; #[cfg(feature = "dtype-struct")] @@ -93,6 +97,8 @@ pub enum FunctionExpr { BinaryExpr(BinaryFunction), #[cfg(feature = "temporal")] TemporalExpr(TemporalFunction), + #[cfg(feature = "arange")] + Range(RangeFunction), #[cfg(feature = "date_offset")] DateOffset(polars_time::Duration), #[cfg(feature = "trigonometry")] @@ -209,6 +215,8 @@ impl Display for FunctionExpr { BinaryExpr(b) => return write!(f, "{b}"), #[cfg(feature = "temporal")] TemporalExpr(fun) => return write!(f, "{fun}"), + #[cfg(feature = "arange")] + Range(func) => return write!(f, "{func}"), #[cfg(feature = "date_offset")] DateOffset(_) => "dt.offset_by", #[cfg(feature = "trigonometry")] @@ -391,6 +399,8 @@ impl From for SpecialEq> { BinaryExpr(s) => s.into(), #[cfg(feature = "temporal")] TemporalExpr(func) => func.into(), + #[cfg(feature = "arange")] + Range(func) => func.into(), #[cfg(feature = "date_offset")] DateOffset(offset) => { @@ -651,3 +661,21 @@ impl From for SpecialEq> { } } } + +#[cfg(feature = "arange")] +impl From for SpecialEq> { + fn from(func: RangeFunction) -> Self { + use RangeFunction::*; + match func { + ARange { step } => { + map_as_slice!(range::arange, step) + } + IntRange { step } => { + map_as_slice!(range::int_range, step) + } + IntRanges { step } => { + map_as_slice!(range::int_ranges, step) + } + } + } +} diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/range.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/range.rs new file mode 100644 index 000000000000..6f0f7b73ddba --- /dev/null +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/range.rs @@ -0,0 +1,162 @@ +use super::*; + +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)] +pub enum RangeFunction { + ARange { step: i64 }, + IntRange { step: i64 }, + IntRanges { step: i64 }, +} + +impl Display for RangeFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use RangeFunction::*; + match self { + ARange { .. } => write!(f, "arange"), + IntRange { .. } => write!(f, "int_range"), + IntRanges { .. } => write!(f, "int_ranges"), + } + } +} + +fn int_range_impl(start: T::Native, end: T::Native, step: i64) -> PolarsResult +where + T: PolarsNumericType, + ChunkedArray: IntoSeries, + std::ops::Range: Iterator, + std::ops::RangeInclusive: DoubleEndedIterator, +{ + let name = "int"; + + let mut ca = match step { + 0 => polars_bail!(InvalidOperation: "step must not be zero"), + 1 => ChunkedArray::::from_iter_values(name, start..end), + 2.. => ChunkedArray::::from_iter_values(name, (start..end).step_by(step as usize)), + _ => { + polars_ensure!(start > end, InvalidOperation: "range must be decreasing if 'step' is negative"); + ChunkedArray::::from_iter_values( + name, + (end..=start).rev().step_by(step.unsigned_abs() as usize), + ) + } + }; + + let is_sorted = if end < start { + IsSorted::Descending + } else { + IsSorted::Ascending + }; + ca.set_sorted_flag(is_sorted); + + Ok(ca.into_series()) +} + +/// Create list entries that are range arrays +/// - if `start` and `end` are a column, every element will expand into an array in a list column. +/// - if `start` and `end` are literals the output will be of `Int64`. +pub(super) fn arange(s: &[Series], step: i64) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + let mut result = if start.len() == 1 && end.len() == 1 { + int_range(s, step) + } else { + int_ranges(s, step) + }?; + + result.rename("arange"); + + Ok(result) +} + +pub(super) fn int_range(s: &[Series], step: i64) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + match start.dtype() { + dt if dt == &IDX_DTYPE => { + let start = start + .idx()? + .get(0) + .ok_or_else(|| polars_err!(NoData: "no data in `start` evaluation"))?; + let end = end.cast(&IDX_DTYPE)?; + let end = end + .idx()? + .get(0) + .ok_or_else(|| polars_err!(NoData: "no data in `end` evaluation"))?; + + int_range_impl::(start, end, step) + } + _ => { + let start = start.cast(&DataType::Int64)?; + let end = end.cast(&DataType::Int64)?; + let start = start + .i64()? + .get(0) + .ok_or_else(|| polars_err!(NoData: "no data in `start` evaluation"))?; + let end = end + .i64()? + .get(0) + .ok_or_else(|| polars_err!(NoData: "no data in `end` evaluation"))?; + int_range_impl::(start, end, step) + } + } +} + +pub(super) fn int_ranges(s: &[Series], step: i64) -> PolarsResult { + let start = &s[0]; + let end = &s[1]; + + let output_name = "int_range"; + + let mut start = start.cast(&DataType::Int64)?; + let mut end = end.cast(&DataType::Int64)?; + + if start.len() != end.len() { + if start.len() == 1 { + start = start.new_from_index(0, end.len()) + } else if end.len() == 1 { + end = end.new_from_index(0, start.len()) + } else { + polars_bail!( + ComputeError: + "lengths of `start`: {} and `end`: {} arguments `\ + cannot be matched in the `arange` expression", + start.len(), end.len() + ); + } + } + + let start = start.i64()?; + let end = end.i64()?; + let mut builder = ListPrimitiveChunkedBuilder::::new( + output_name, + start.len(), + start.len() * 3, + DataType::Int64, + ); + + for (opt_start, opt_end) in start.into_iter().zip(end.into_iter()) { + match (opt_start, opt_end) { + (Some(start_v), Some(end_v)) => match step { + 1 => { + builder.append_iter_values(start_v..end_v); + } + 2.. => { + builder.append_iter_values((start_v..end_v).step_by(step as usize)); + } + _ => { + polars_ensure!(start_v > end_v, InvalidOperation: "range must be decreasing if 'step' is negative"); + builder.append_iter_values( + (end_v..=start_v) + .rev() + .step_by(step.unsigned_abs() as usize), + ) + } + }, + _ => builder.append_null(), + } + } + + Ok(builder.finish().into_series()) +} diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs index 12067a947c16..a0f44620e1a2 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs @@ -85,6 +85,18 @@ impl FunctionExpr { mapper.with_dtype(dtype) } + #[cfg(feature = "arange")] + Range(fun) => { + use RangeFunction::*; + let field = match fun { + ARange { .. } => Field::new("arange", DataType::Int64), // This is not always correct + IntRange { .. } => Field::new("int", DataType::Int64), + IntRanges { .. } => { + Field::new("int_range", DataType::List(Box::new(DataType::Int64))) + } + }; + Ok(field) + } #[cfg(feature = "date_offset")] DateOffset(_) => mapper.with_same_dtype(), #[cfg(feature = "trigonometry")] diff --git a/polars/polars-lazy/polars-plan/src/dsl/functions/range.rs b/polars/polars-lazy/polars-plan/src/dsl/functions/range.rs index a73a8a2a68b2..c04b8d3ffcf9 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/functions/range.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/functions/range.rs @@ -1,191 +1,49 @@ use super::*; -#[cfg(feature = "arange")] -fn arange_impl(start: T::Native, end: T::Native, step: i64) -> PolarsResult> -where - T: PolarsNumericType, - ChunkedArray: IntoSeries, - std::ops::Range: Iterator, - std::ops::RangeInclusive: DoubleEndedIterator, -{ - let mut ca = match step { - 1 => ChunkedArray::::from_iter_values("arange", start..end), - 2.. => ChunkedArray::::from_iter_values("arange", (start..end).step_by(step as usize)), - _ => { - polars_ensure!(start > end, InvalidOperation: "range must be decreasing if 'step' is negative"); - ChunkedArray::::from_iter_values( - "arange", - (end..=start).rev().step_by(step.unsigned_abs() as usize), - ) - } - }; - let is_sorted = if end < start { - IsSorted::Descending - } else { - IsSorted::Ascending - }; - ca.set_sorted_flag(is_sorted); - Ok(Some(ca.into_series())) -} - -// TODO! rewrite this with the apply_private architecture /// Create list entries that are range arrays /// - if `start` and `end` are a column, every element will expand into an array in a list column. /// - if `start` and `end` are literals the output will be of `Int64`. #[cfg(feature = "arange")] pub fn arange(start: Expr, end: Expr, step: i64) -> Expr { - let output_name = "arange"; - - let has_col_without_agg = |e: &Expr| { - has_expr(e, |ae| matches!(ae, Expr::Column(_))) - && - // check if there is no aggregation - !has_expr(e, |ae| { - matches!( - ae, - Expr::Agg(_) - | Expr::Count - | Expr::AnonymousFunction { - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, - .. - }, - .. - } - | Expr::Function { - options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, - .. - }, - .. - }, - ) - }) - }; - let has_lit = |e: &Expr| { - (matches!(e, Expr::Literal(_)) && !matches!(e, Expr::Literal(LiteralValue::Series(_)))) - }; - - let any_column_no_agg = has_col_without_agg(&start) || has_col_without_agg(&end); - let literal_start = has_lit(&start); - let literal_end = has_lit(&end); - - if (literal_start || literal_end) && !any_column_no_agg { - let f = move |sa: Series, sb: Series| { - polars_ensure!(step != 0, InvalidOperation: "step must not be zero"); + let input = vec![start, end]; - match sa.dtype() { - dt if dt == &IDX_DTYPE => { - let start = sa - .idx()? - .get(0) - .ok_or_else(|| polars_err!(NoData: "no data in `start` evaluation"))?; - let sb = sb.cast(&IDX_DTYPE)?; - let end = sb - .idx()? - .get(0) - .ok_or_else(|| polars_err!(NoData: "no data in `end` evaluation"))?; - #[cfg(feature = "bigidx")] - { - arange_impl::(start, end, step) - } - #[cfg(not(feature = "bigidx"))] - { - arange_impl::(start, end, step) - } - } - _ => { - let sa = sa.cast(&DataType::Int64)?; - let sb = sb.cast(&DataType::Int64)?; - let start = sa - .i64()? - .get(0) - .ok_or_else(|| polars_err!(NoData: "no data in `start` evaluation"))?; - let end = sb - .i64()? - .get(0) - .ok_or_else(|| polars_err!(NoData: "no data in `end` evaluation"))?; - arange_impl::(start, end, step) - } - } - }; - apply_binary( - start, - end, - f, - GetOutput::map_field(|input| { - let dtype = if input.data_type() == &IDX_DTYPE { - IDX_DTYPE - } else { - DataType::Int64 - }; - Field::new(output_name, dtype) - }), - ) - .alias(output_name) - } else { - let f = move |sa: Series, sb: Series| { - polars_ensure!(step != 0, InvalidOperation: "step must not be zero"); - let mut sa = sa.cast(&DataType::Int64)?; - let mut sb = sb.cast(&DataType::Int64)?; + Expr::Function { + input, + function: FunctionExpr::Range(RangeFunction::ARange { step }), + options: FunctionOptions { + allow_rename: true, + ..Default::default() + }, + } +} - if sa.len() != sb.len() { - if sa.len() == 1 { - sa = sa.new_from_index(0, sb.len()) - } else if sb.len() == 1 { - sb = sb.new_from_index(0, sa.len()) - } else { - polars_bail!( - ComputeError: - "lengths of `start`: {} and `end`: {} arguments `\ - cannot be matched in the `arange` expression", - sa.len(), sb.len() - ); - } - } +#[cfg(feature = "arange")] +/// Generate a range of integers. +pub fn int_range(start: Expr, end: Expr, step: i64) -> Expr { + let input = vec![start, end]; - let start = sa.i64()?; - let end = sb.i64()?; - let mut builder = ListPrimitiveChunkedBuilder::::new( - output_name, - start.len(), - start.len() * 3, - DataType::Int64, - ); + Expr::Function { + input, + function: FunctionExpr::Range(RangeFunction::IntRange { step }), + options: FunctionOptions { + allow_rename: true, + ..Default::default() + }, + } +} - for (opt_start, opt_end) in start.into_iter().zip(end.into_iter()) { - match (opt_start, opt_end) { - (Some(start_v), Some(end_v)) => match step { - 1 => { - builder.append_iter_values(start_v..end_v); - } - 2.. => { - builder.append_iter_values((start_v..end_v).step_by(step as usize)); - } - _ => { - polars_ensure!(start_v > end_v, InvalidOperation: "range must be decreasing if 'step' is negative"); - builder.append_iter_values( - (end_v..=start_v) - .rev() - .step_by(step.unsigned_abs() as usize), - ) - } - }, - _ => builder.append_null(), - } - } +#[cfg(feature = "arange")] +/// Generate a range of integers for each row of the input columns. +pub fn int_ranges(start: Expr, end: Expr, step: i64) -> Expr { + let input = vec![start, end]; - Ok(Some(builder.finish().into_series())) - }; - apply_binary( - start, - end, - f, - GetOutput::map_field(|_| { - Field::new(output_name, DataType::List(DataType::Int64.into())) - }), - ) - .alias(output_name) + Expr::Function { + input, + function: FunctionExpr::Range(RangeFunction::IntRanges { step }), + options: FunctionOptions { + allow_rename: true, + ..Default::default() + }, } } diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst index 9d2d2718c087..5a785893f8d4 100644 --- a/py-polars/docs/source/reference/expressions/functions.rst +++ b/py-polars/docs/source/reference/expressions/functions.rst @@ -18,6 +18,7 @@ These functions are available from the polars module root and can be used as exp approx_unique arange arg_sort_by + arg_where avg coalesce concat_list @@ -30,6 +31,7 @@ These functions are available from the polars module root and can be used as exp cumsum date datetime + date_range duration element exclude @@ -40,6 +42,8 @@ These functions are available from the polars module root and can be used as exp groups head implode + int_range + int_ranges lit map max @@ -47,6 +51,7 @@ These functions are available from the polars module root and can be used as exp median min n_unique + ones quantile reduce repeat @@ -59,8 +64,10 @@ These functions are available from the polars module root and can be used as exp sql_expr tail time + time_range var when + zeros **Available in expression namespace:** diff --git a/py-polars/docs/source/reference/functions.rst b/py-polars/docs/source/reference/functions.rst index 62d4ad8147b5..44807c91ddd9 100644 --- a/py-polars/docs/source/reference/functions.rst +++ b/py-polars/docs/source/reference/functions.rst @@ -17,24 +17,13 @@ Conversion from_records from_repr -Eager/Lazy functions -~~~~~~~~~~~~~~~~~~~~ -.. autosummary:: - :toctree: api/ - - arg_where - concat - date_range - ones - time_range - zeros - Miscellaneous ~~~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: api/ align_frames + concat Parallelization ~~~~~~~~~~~~~~~ diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 1291accb266b..130914381c0b 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -110,6 +110,8 @@ groups, head, implode, + int_range, + int_ranges, last, lit, map, @@ -308,6 +310,8 @@ "groups", "head", "implode", + "int_range", + "int_ranges", "last", "lit", "map", diff --git a/py-polars/polars/functions/__init__.py b/py-polars/polars/functions/__init__.py index 8f93e87ece43..f9fcec4fb853 100644 --- a/py-polars/polars/functions/__init__.py +++ b/py-polars/polars/functions/__init__.py @@ -53,7 +53,7 @@ tail, var, ) -from polars.functions.range import arange, date_range, time_range +from polars.functions.range import arange, date_range, int_range, int_ranges, time_range from polars.functions.repeat import ones, repeat, zeros from polars.functions.whenthen import when @@ -98,6 +98,8 @@ "groups", "head", "implode", + "int_range", + "int_ranges", "last", "lit", "map", diff --git a/py-polars/polars/functions/range.py b/py-polars/polars/functions/range.py index 3b944c02672b..52a6cc46df75 100644 --- a/py-polars/polars/functions/range.py +++ b/py-polars/polars/functions/range.py @@ -26,7 +26,13 @@ from datetime import date from polars import Expr, Series - from polars.type_aliases import ClosedInterval, PolarsDataType, TimeUnit + from polars.type_aliases import ( + ClosedInterval, + IntoExpr, + PolarsDataType, + PolarsIntegerType, + TimeUnit, + ) if sys.version_info >= (3, 8): from typing import Literal @@ -139,6 +145,177 @@ def arange( return result +@overload +def int_range( + start: int | IntoExpr, + end: int | IntoExpr, + step: int = ..., + *, + eager: Literal[False] = ..., +) -> Expr: + ... + + +@overload +def int_range( + start: int | IntoExpr, + end: int | IntoExpr, + step: int = ..., + *, + eager: Literal[True], +) -> Series: + ... + + +@overload +def int_range( + start: int | IntoExpr, + end: int | IntoExpr, + step: int = ..., + *, + eager: bool, +) -> Expr | Series: + ... + + +def int_range( + start: int | IntoExpr, + end: int | IntoExpr, + step: int = 1, + *, + eager: bool = False, +) -> Expr | Series: + """ + Generate a range of integers. + + Parameters + ---------- + start + Lower bound of the range (inclusive). + end + Upper bound of the range (exclusive). + step + Step size of the range. + eager + Evaluate immediately and return a ``Series``. If set to ``False`` (default), + return an expression instead. + + Returns + ------- + Column of data type ``Int64``. + + Examples + -------- + >>> pl.int_range(0, 3, eager=True) + shape: (3,) + Series: 'int' [i64] + [ + 0 + 1 + 2 + ] + + """ + start = parse_as_expression(start) + end = parse_as_expression(end) + result = wrap_expr(plr.int_range(start, end, step)) + + if eager: + return F.select(result).to_series() + + return result + + +@overload +def int_ranges( + start: IntoExpr, + end: IntoExpr, + step: int = ..., + *, + dtype: PolarsIntegerType = ..., + eager: Literal[False] = ..., +) -> Expr: + ... + + +@overload +def int_ranges( + start: IntoExpr, + end: IntoExpr, + step: int = ..., + *, + dtype: PolarsIntegerType = ..., + eager: Literal[True], +) -> Series: + ... + + +@overload +def int_ranges( + start: IntoExpr, + end: IntoExpr, + step: int = ..., + *, + dtype: PolarsIntegerType = ..., + eager: bool, +) -> Expr | Series: + ... + + +def int_ranges( + start: IntoExpr, + end: IntoExpr, + step: int = 1, + *, + dtype: PolarsIntegerType = Int64, + eager: bool = False, +) -> Expr | Series: + """ + Generate a range of integers for each row of the input columns. + + Parameters + ---------- + start + Lower bound of the range (inclusive). + end + Upper bound of the range (exclusive). + step + Step size of the range. + dtype + Integer data type of the ranges. Defaults to ``Int64``. + eager + Evaluate immediately and return a ``Series``. If set to ``False`` (default), + return an expression instead. + + Returns + ------- + Column of data type ``List(dtype)``. + + Examples + -------- + >>> df = pl.DataFrame({"start": [1, -1], "end": [3, 2]}) + >>> df.with_columns(pl.int_ranges("start", "end")) + shape: (2, 3) + ┌───────┬─────┬────────────┐ + │ start ┆ end ┆ int_range │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ list[i64] │ + ╞═══════╪═════╪════════════╡ + │ 1 ┆ 3 ┆ [1, 2] │ + │ -1 ┆ 2 ┆ [-1, 0, 1] │ + └───────┴─────┴────────────┘ + + """ + start = parse_as_expression(start) + end = parse_as_expression(end) + result = wrap_expr(plr.int_ranges(start, end, step, dtype)) + + if eager: + return F.select(result).to_series() + + return result + + @overload def date_range( start: date | datetime | Expr | str, diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index 384128e46282..bab0225ed86b 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from polars import DataFrame, Expr, LazyFrame, Series - from polars.datatypes import DataType, DataTypeClass, TemporalType + from polars.datatypes import DataType, DataTypeClass, IntegralType, TemporalType from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa @@ -39,6 +39,7 @@ # Data types PolarsDataType: TypeAlias = Union["DataTypeClass", "DataType"] PolarsTemporalType: TypeAlias = Union[Type["TemporalType"], "TemporalType"] +PolarsIntegerType: TypeAlias = Union[Type["IntegralType"], "IntegralType"] OneOrMoreDataTypes: TypeAlias = Union[PolarsDataType, Iterable[PolarsDataType]] PythonDataType: TypeAlias = Union[ Type[int], diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index 8c8cfb233137..64e07e92e107 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -20,11 +20,6 @@ macro_rules! set_unwrapped_or_0 { }; } -#[pyfunction] -pub fn arange(start: PyExpr, end: PyExpr, step: i64) -> PyExpr { - dsl::arange(start.inner, end.inner, step).into() -} - #[pyfunction] pub fn rolling_corr( x: PyExpr, diff --git a/py-polars/src/functions/mod.rs b/py-polars/src/functions/mod.rs index 02b8c04c4519..e26d7e2a28b2 100644 --- a/py-polars/src/functions/mod.rs +++ b/py-polars/src/functions/mod.rs @@ -3,4 +3,5 @@ pub mod io; pub mod lazy; pub mod meta; pub mod misc; +pub mod range; pub mod whenthen; diff --git a/py-polars/src/functions/range.rs b/py-polars/src/functions/range.rs new file mode 100644 index 000000000000..0ddc21266365 --- /dev/null +++ b/py-polars/src/functions/range.rs @@ -0,0 +1,28 @@ +use polars::lazy::dsl; +use pyo3::prelude::*; + +use crate::prelude::*; +use crate::PyExpr; + +#[pyfunction] +pub fn arange(start: PyExpr, end: PyExpr, step: i64) -> PyExpr { + dsl::arange(start.inner, end.inner, step).into() +} + +#[pyfunction] +pub fn int_range(start: PyExpr, end: PyExpr, step: i64) -> PyExpr { + dsl::int_range(start.inner, end.inner, step).into() +} + +#[pyfunction] +pub fn int_ranges(start: PyExpr, end: PyExpr, step: i64, dtype: Wrap) -> PyExpr { + let dtype = dtype.0; + + let mut result = dsl::int_ranges(start.inner, end.inner, step); + + if dtype != DataType::Int64 { + result = result.cast(DataType::List(Box::new(dtype))) + } + + result.into() +} diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 13635ec63526..f275bea134e4 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -92,9 +92,15 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(functions::eager::time_range_eager)) .unwrap(); - // Functions - lazy - m.add_wrapped(wrap_pyfunction!(functions::lazy::arange)) + // Functions - range + m.add_wrapped(wrap_pyfunction!(functions::range::arange)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::range::int_range)) + .unwrap(); + m.add_wrapped(wrap_pyfunction!(functions::range::int_ranges)) .unwrap(); + + // Functions - lazy m.add_wrapped(wrap_pyfunction!(functions::lazy::arg_sort_by)) .unwrap(); m.add_wrapped(wrap_pyfunction!(functions::lazy::arg_where)) diff --git a/py-polars/tests/unit/functions/test_range.py b/py-polars/tests/unit/functions/test_range.py index a0afd40c13df..0940478615f2 100644 --- a/py-polars/tests/unit/functions/test_range.py +++ b/py-polars/tests/unit/functions/test_range.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import date, datetime, time, timedelta -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pandas as pd import pytest @@ -9,7 +9,7 @@ import polars as pl from polars.datatypes import DTYPE_TEMPORAL_UNITS from polars.exceptions import ComputeError, TimeZoneAwareConstructorWarning -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: from zoneinfo import ZoneInfo @@ -69,6 +69,77 @@ def test_arange_name() -> None: assert result_lazy.name == expected_name +def test_arange_schema() -> None: + result = pl.LazyFrame().select(pl.arange(-3, 3)) + + expected_schema = {"arange": pl.Int64} + assert result.schema == expected_schema + assert result.collect().schema == expected_schema + + +def test_int_range() -> None: + result = pl.int_range(0, 3) + expected = pl.Series("int", [0, 1, 2]) + assert_series_equal(pl.select(result).to_series(), expected) + + +def test_int_range_eager() -> None: + result = pl.int_range(0, 3, eager=True) + expected = pl.Series("int", [0, 1, 2]) + assert_series_equal(result, expected) + + +def test_int_range_schema() -> None: + result = pl.LazyFrame().select(pl.int_range(-3, 3)) + + expected_schema = {"int": pl.Int64} + assert result.schema == expected_schema + assert result.collect().schema == expected_schema + + +@pytest.mark.parametrize( + ("start", "end", "expected"), + [ + ("a", "b", pl.Series("int_range", [[1, 2], [2, 3]])), + (-1, "a", pl.Series("int_range", [[-1, 0], [-1, 0, 1]])), + ("b", 4, pl.Series("int_range", [[3], []])), + ], +) +def test_int_ranges(start: Any, end: Any, expected: pl.Series) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + + result = df.select(pl.int_ranges(start, end)) + assert_series_equal(result.to_series(), expected) + + +def test_int_ranges_eager() -> None: + start = pl.Series([1, 2]) + result = pl.int_ranges(start, 4, eager=True) + + expected = pl.Series("int_range", [[1, 2, 3], [2, 3]]) + assert_series_equal(result, expected) + + +def test_int_ranges_schema_dtype_default() -> None: + lf = pl.LazyFrame({"start": [1, 2], "end": [3, 4]}) + + result = lf.select(pl.int_ranges("start", "end")) + + expected_schema = {"int_range": pl.List(pl.Int64)} + assert result.schema == expected_schema + assert result.collect().schema == expected_schema + + +def test_int_ranges_schema_dtype_arg() -> None: + lf = pl.LazyFrame({"start": [1, 2], "end": [3, 4]}) + + result = lf.select(pl.int_ranges("start", "end", dtype=pl.UInt16)) + + expected_schema = {"int_range": pl.List(pl.UInt16)} + assert result.schema == expected_schema + assert result.collect().schema == expected_schema + + def test_date_range() -> None: result = pl.date_range( date(1985, 1, 1), date(2015, 7, 1), timedelta(days=1, hours=12), eager=True