Skip to content

Commit

Permalink
feat(rust,python): add "sql_expr" function (pola-rs#9248)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored and c-peters committed Jul 14, 2023
1 parent 74def3c commit 6251383
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 1 deletion.
1 change: 1 addition & 0 deletions polars/polars-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ mod sql_expr;
mod table_functions;

pub use context::SQLContext;
pub use sql_expr::sql_expr;
35 changes: 35 additions & 0 deletions polars/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use polars_arrow::error::to_compute_err;
use polars_core::prelude::*;
use polars_lazy::dsl::Expr;
use polars_lazy::prelude::*;
Expand All @@ -7,6 +8,8 @@ use sqlparser::ast::{
Expr as SqlExpr, Function as SQLFunction, JoinConstraint, OrderByExpr, TrimWhereField,
UnaryOperator, Value as SqlValue,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};

use crate::functions::SqlFunctionVisitor;
use crate::SQLContext;
Expand Down Expand Up @@ -465,3 +468,35 @@ pub(super) fn process_join_constraint(
}
polars_bail!(InvalidOperation: "SQL join constraint {:?} is not yet supported", constraint);
}

/// Parse a single SQL expression into a polars [`Expr`].
///
/// The input is parsed with sqlparser's `GenericDialect` (with the
/// `trailing_commas` parser option enabled) and the resulting SQL AST node is
/// translated by `parse_sql_expr` against a freshly created [`SQLContext`].
///
/// # Errors
/// Returns an error when the parser cannot be set up for the input or the
/// input cannot be parsed as an expression (the parser error is converted
/// with `to_compute_err`), or when `parse_sql_expr` fails on the parsed
/// expression.
///
/// # Example
/// ```rust
/// # use polars_sql::sql_expr;
/// # use polars_core::prelude::*;
/// # use polars_lazy::prelude::*;
/// # fn main() {
/// let df = df! {
///     "a" => [1, 2, 3],
/// }
/// .unwrap();
/// let expr = sql_expr("MAX(a)").unwrap();
/// df.lazy().select(vec![expr]).collect().unwrap();
/// # }
/// ```
pub fn sql_expr<S: AsRef<str>>(s: S) -> PolarsResult<Expr> {
    // Configure the parser and feed it the SQL text in one chain.
    let mut sql_parser = Parser::new(&GenericDialect)
        .with_options(ParserOptions {
            trailing_commas: true,
        })
        .try_with_sql(s.as_ref())
        .map_err(to_compute_err)?;

    // Only a single expression is read from the token stream.
    let sql_ast = sql_parser.parse_expr().map_err(to_compute_err)?;

    // Translate the SQL AST node into a polars expression.
    parse_sql_expr(&sql_ast, &SQLContext::new())
}
9 changes: 9 additions & 0 deletions polars/polars-sql/tests/simple_exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,12 @@ fn test_case_expr() {
let df_pl = df.lazy().select(&[case_expr]).collect().unwrap();
assert!(df_sql.frame_equal(&df_pl));
}

// `sql_expr("MIN(a)")` must select the same result as the native `col("a").min()`.
#[test]
fn test_sql_expr() {
    let sample = create_sample_df().unwrap();
    let parsed = sql_expr("MIN(a)").unwrap();

    // Reference result built from the native expression DSL.
    let expected = sample
        .clone()
        .lazy()
        .select(&[col("a").min()])
        .collect()
        .unwrap();
    // Result produced by the expression parsed from SQL.
    let actual = sample.lazy().select(&[parsed]).collect().unwrap();

    assert!(actual.frame_equal(&expected));
}
2 changes: 1 addition & 1 deletion polars/src/sql.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub use polars_sql::{keywords, SQLContext};
pub use polars_sql::{keywords, sql_expr, SQLContext};
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ These functions are available from the polars module root and can be used as exp
std
struct
sum
sql_expr
tail
time
var
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
rolling_corr,
rolling_cov,
select,
sql_expr,
std,
struct,
sum,
Expand Down Expand Up @@ -342,6 +343,7 @@
"threadpool_size",
# selectors
"selectors",
"sql_expr",
]

os.environ["POLARS_ALLOW_EXTENSION"] = "true"
2 changes: 2 additions & 0 deletions py-polars/polars/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
rolling_corr,
rolling_cov,
select,
sql_expr,
std,
sum,
tail,
Expand Down Expand Up @@ -118,4 +119,5 @@
"var",
# polars.functions.whenthen
"when",
"sql_expr",
]
26 changes: 26 additions & 0 deletions py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2652,3 +2652,29 @@ def rolling_corr(
return wrap_expr(
plr.rolling_corr(a._pyexpr, b._pyexpr, window_size, min_periods, ddof)
)


def sql_expr(sql: str) -> Expr:
    """
    Parse a SQL expression to a polars expression.

    Parameters
    ----------
    sql
        SQL expression

    Returns
    -------
    Expr
        The expression produced by wrapping the result of the native
        ``plr.sql_expr`` call.

    Examples
    --------
    >>> df = pl.DataFrame({"a": [2, 1]})
    >>> expr = pl.sql_expr("MAX(a)")
    >>> df.select(expr)
    shape: (1, 1)
    ┌─────┐
    │ a   │
    │ --- │
    │ i64 │
    ╞═════╡
    │ 2   │
    └─────┘

    """
    # Parsing happens on the native side; wrap the result in a Python Expr.
    native_expr = plr.sql_expr(sql)
    return wrap_expr(native_expr)
7 changes: 7 additions & 0 deletions py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,10 @@ pub fn time_range_lazy(
let every = Duration::parse(every);
dsl::functions::time_range(start, end, every, closed.0).into()
}

// Python-facing binding for `polars::sql::sql_expr`; only compiled when the
// `sql` feature is enabled. Failures are routed through `PyPolarsErr` before
// being returned as a Python error.
#[pyfunction]
#[cfg(feature = "sql")]
pub fn sql_expr(sql: &str) -> PyResult<PyExpr> {
    match polars::sql::sql_expr(sql) {
        Ok(parsed) => Ok(parsed.into()),
        Err(err) => Err(PyPolarsErr::from(err).into()),
    }
}
4 changes: 4 additions & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(functions::whenthen::when))
.unwrap();

#[cfg(feature = "sql")]
m.add_wrapped(wrap_pyfunction!(functions::lazy::sql_expr))
.unwrap();

// Functions - I/O
#[cfg(feature = "ipc")]
m.add_wrapped(wrap_pyfunction!(functions::io::read_ipc_schema))
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,10 @@ def test_register_context() -> None:
assert ctx.tables() == ["_lf1", "_lf2"]

assert ctx.tables() == []


def test_sql_expr() -> None:
    # "MIN(a)" parsed from SQL should select the minimum of column "a".
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [4, None, 6]})
    min_of_a = pl.sql_expr("MIN(a)")
    result = frame.select(min_of_a)
    assert result.frame_equal(pl.DataFrame({"a": [1]}))

0 comments on commit 6251383

Please sign in to comment.