From 62513830d9ee1010f43ef9bc23961533211e8658 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Wed, 7 Jun 2023 01:06:16 -0500 Subject: [PATCH] feat(rust,python): add "sql_expr" function (#9248) --- polars/polars-sql/src/lib.rs | 1 + polars/polars-sql/src/sql_expr.rs | 35 +++++++++++++++++++ polars/polars-sql/tests/simple_exprs.rs | 9 +++++ polars/src/sql.rs | 2 +- .../reference/expressions/functions.rst | 1 + py-polars/polars/__init__.py | 2 ++ py-polars/polars/functions/__init__.py | 2 ++ py-polars/polars/functions/lazy.py | 26 ++++++++++++++ py-polars/src/functions/lazy.rs | 7 ++++ py-polars/src/lib.rs | 4 +++ py-polars/tests/unit/test_sql.py | 7 ++++ 11 files changed, 95 insertions(+), 1 deletion(-) diff --git a/polars/polars-sql/src/lib.rs b/polars/polars-sql/src/lib.rs index 89bb66dd454c..35dd67d76370 100644 --- a/polars/polars-sql/src/lib.rs +++ b/polars/polars-sql/src/lib.rs @@ -8,3 +8,4 @@ mod sql_expr; mod table_functions; pub use context::SQLContext; +pub use sql_expr::sql_expr; diff --git a/polars/polars-sql/src/sql_expr.rs b/polars/polars-sql/src/sql_expr.rs index 3d8ab7921b4c..febdd8c73617 100644 --- a/polars/polars-sql/src/sql_expr.rs +++ b/polars/polars-sql/src/sql_expr.rs @@ -1,3 +1,4 @@ +use polars_arrow::error::to_compute_err; use polars_core::prelude::*; use polars_lazy::dsl::Expr; use polars_lazy::prelude::*; @@ -7,6 +8,8 @@ use sqlparser::ast::{ Expr as SqlExpr, Function as SQLFunction, JoinConstraint, OrderByExpr, TrimWhereField, UnaryOperator, Value as SqlValue, }; +use sqlparser::dialect::GenericDialect; +use sqlparser::parser::{Parser, ParserOptions}; use crate::functions::SqlFunctionVisitor; use crate::SQLContext; @@ -465,3 +468,35 @@ pub(super) fn process_join_constraint( } polars_bail!(InvalidOperation: "SQL join constraint {:?} is not yet supported", constraint); } + +/// parse a SQL expression to a polars expression +/// # Example +/// ```rust +/// # use polars_sql::{SQLContext, sql_expr}; +/// # use polars_core::prelude::*; +/// # use polars_lazy::prelude::*; +/// # fn main() { +/// +/// let mut ctx = SQLContext::new(); +/// let df = df! { +/// "a" => [1, 2, 3], +/// } +/// .unwrap(); +/// let expr = sql_expr("MAX(a)").unwrap(); +/// df.lazy().select(vec![expr]).collect().unwrap(); +/// # } +/// ``` +pub fn sql_expr>(s: S) -> PolarsResult { + let ctx = SQLContext::new(); + + let mut parser = Parser::new(&GenericDialect); + parser = parser.with_options(ParserOptions { + trailing_commas: true, + }); + + let mut ast = parser.try_with_sql(s.as_ref()).map_err(to_compute_err)?; + + let expr = ast.parse_expr().map_err(to_compute_err)?; + + parse_sql_expr(&expr, &ctx) +} diff --git a/polars/polars-sql/tests/simple_exprs.rs b/polars/polars-sql/tests/simple_exprs.rs index 309ec91c49ca..350a596b859a 100644 --- a/polars/polars-sql/tests/simple_exprs.rs +++ b/polars/polars-sql/tests/simple_exprs.rs @@ -524,3 +524,12 @@ fn test_case_expr() { let df_pl = df.lazy().select(&[case_expr]).collect().unwrap(); assert!(df_sql.frame_equal(&df_pl)); } + +#[test] +fn test_sql_expr() { + let df = create_sample_df().unwrap(); + let expr = sql_expr("MIN(a)").unwrap(); + let actual = df.clone().lazy().select(&[expr]).collect().unwrap(); + let expected = df.lazy().select(&[col("a").min()]).collect().unwrap(); + assert!(actual.frame_equal(&expected)); +} diff --git a/polars/src/sql.rs b/polars/src/sql.rs index f4e1fe05ede8..e0451dc7c505 100644 --- a/polars/src/sql.rs +++ b/polars/src/sql.rs @@ -1 +1 @@ -pub use polars_sql::{keywords, SQLContext}; +pub use polars_sql::{keywords, sql_expr, SQLContext}; diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst index faa721eeb5b9..9d2d2718c087 100644 --- a/py-polars/docs/source/reference/expressions/functions.rst +++ b/py-polars/docs/source/reference/expressions/functions.rst @@ -56,6 +56,7 @@ These functions are available from the polars module root and can be used as exp std struct sum + sql_expr tail time var diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 08f4afe16acf..6e30b4cbcf21 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -121,6 +121,7 @@ rolling_corr, rolling_cov, select, + sql_expr, std, struct, sum, @@ -342,6 +343,7 @@ "threadpool_size", # selectors "selectors", + "sql_expr", ] os.environ["POLARS_ALLOW_EXTENSION"] = "true" diff --git a/py-polars/polars/functions/__init__.py b/py-polars/polars/functions/__init__.py index cf72e37cc3dc..8f93e87ece43 100644 --- a/py-polars/polars/functions/__init__.py +++ b/py-polars/polars/functions/__init__.py @@ -47,6 +47,7 @@ rolling_corr, rolling_cov, select, + sql_expr, std, sum, tail, @@ -118,4 +119,5 @@ "var", # polars.functions.whenthen "when", + "sql_expr", ] diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 1c0b5e7bb99b..aa6356fdfe54 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -2652,3 +2652,29 @@ def rolling_corr( return wrap_expr( plr.rolling_corr(a._pyexpr, b._pyexpr, window_size, min_periods, ddof) ) + + +def sql_expr(sql: str) -> Expr: + """ + Parse a SQL expression to a polars expression. + + Parameters + ---------- + sql + SQL expression + + Examples + -------- + >>> df = pl.DataFrame({"a": [2, 1]}) + >>> expr = pl.sql_expr("MAX(a)") + >>> df.select(expr) + shape: (1, 1) + ┌─────┐ + │ a │ + │ --- │ + │ i64 │ + ╞═════╡ + │ 2 │ + └─────┘ + """ + return wrap_expr(plr.sql_expr(sql)) diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index e421d7262dc5..277ee73860e7 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -452,3 +452,10 @@ pub fn time_range_lazy( let every = Duration::parse(every); dsl::functions::time_range(start, end, every, closed.0).into() } + +#[pyfunction] +#[cfg(feature = "sql")] +pub fn sql_expr(sql: &str) -> PyResult { + let expr = polars::sql::sql_expr(sql).map_err(PyPolarsErr::from)?; + Ok(expr.into()) +} diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index ee3d7b2e57d7..08ad1a42ac08 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -166,6 +166,10 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(functions::whenthen::when)) .unwrap(); + #[cfg(feature = "sql")] + m.add_wrapped(wrap_pyfunction!(functions::lazy::sql_expr)) + .unwrap(); + // Functions - I/O #[cfg(feature = "ipc")] m.add_wrapped(wrap_pyfunction!(functions::io::read_ipc_schema)) diff --git a/py-polars/tests/unit/test_sql.py b/py-polars/tests/unit/test_sql.py index 532d6bcf8993..5d325cd84618 100644 --- a/py-polars/tests/unit/test_sql.py +++ b/py-polars/tests/unit/test_sql.py @@ -234,3 +234,10 @@ def test_register_context() -> None: assert ctx.tables() == ["_lf1", "_lf2"] assert ctx.tables() == [] + + +def test_sql_expr() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, None, 6]}) + sql_expr = pl.sql_expr("MIN(a)") + expected = pl.DataFrame({"a": [1]}) + assert df.select(sql_expr).frame_equal(expected)