diff --git a/polars/polars-sql/src/functions.rs b/polars/polars-sql/src/functions.rs index 09589c3a0b7d..ec6f73785b62 100644 --- a/polars/polars-sql/src/functions.rs +++ b/polars/polars-sql/src/functions.rs @@ -498,14 +498,15 @@ impl SqlFunctionVisitor<'_> { fn visit_unary_no_window(&self, f: impl Fn(Expr) -> Expr) -> PolarsResult { let function = self.func; + let args = extract_args(function); - if let FunctionArgExpr::Expr(sql_expr) = args[0] { - // parse the inner sql expr -- e.g. SUM(a) -> a - let expr = parse_sql_expr(sql_expr, self.ctx)?; - // apply the function on the inner expr -- e.g. SUM(a) -> SUM - Ok(f(expr)) - } else { - not_supported_error(function.name.0[0].value.as_str(), &args) + match args.as_slice() { + [FunctionArgExpr::Expr(sql_expr)] => { + let expr = parse_sql_expr(sql_expr, self.ctx)?; + // apply the function on the inner expr -- e.g. SUM(a) -> SUM + Ok(f(expr)) + } + _ => self.not_supported_error(), } } @@ -519,52 +520,37 @@ impl SqlFunctionVisitor<'_> { ) -> PolarsResult { let function = self.func; let args = extract_args(function); - if let FunctionArgExpr::Expr(sql_expr) = args[0] { - let expr = - self.apply_window_spec(parse_sql_expr(sql_expr, self.ctx)?, &function.over)?; - if let FunctionArgExpr::Expr(sql_expr) = args[1] { - let expr2 = Arg::from_sql_expr(sql_expr, self.ctx)?; + match args.as_slice() { + [FunctionArgExpr::Expr(sql_expr), FunctionArgExpr::Expr(sql_expr2)] => { + let expr = parse_sql_expr(sql_expr, self.ctx)?; + let expr2 = Arg::from_sql_expr(sql_expr2, self.ctx)?; f(expr, expr2) - } else { - not_supported_error(function.name.0[0].value.as_str(), &args) } - } else { - not_supported_error(function.name.0[0].value.as_str(), &args) + _ => self.not_supported_error(), } } fn visit_count(&self) -> PolarsResult { let args = extract_args(self.func); - Ok(match (args.len(), self.func.distinct) { + match (self.func.distinct, args.as_slice()) { // count() - (0, false) => count(), - // count(distinct) - (0, true) => return not_supported_error("count", &args), - (1, false) => match args[0] { - // count(col) - FunctionArgExpr::Expr(sql_expr) => { - let expr = self - .apply_window_spec(parse_sql_expr(sql_expr, self.ctx)?, &self.func.over)?; - expr.count() - } - // count(*) - FunctionArgExpr::Wildcard => count(), - // count(tbl.*) is not supported - _ => return not_supported_error("count", &args), - }, - (1, true) => { - // count(distinct col) - if let FunctionArgExpr::Expr(sql_expr) = args[0] { - let expr = self - .apply_window_spec(parse_sql_expr(sql_expr, self.ctx)?, &self.func.over)?; - expr.n_unique() - } else { - // count(distinct *) or count(distinct tbl.*) is not supported - return not_supported_error("count", &args); - } + (false, []) => Ok(count()), + // count(column_name) + (false, [FunctionArgExpr::Expr(sql_expr)]) => { + let expr = + self.apply_window_spec(parse_sql_expr(sql_expr, self.ctx)?, &self.func.over)?; + Ok(expr.count()) } - _ => return not_supported_error("count", &args), - }) + // count(*) + (false, [FunctionArgExpr::Wildcard]) => Ok(count()), + // count(distinct column_name) + (true, [FunctionArgExpr::Expr(sql_expr)]) => { + let expr = + self.apply_window_spec(parse_sql_expr(sql_expr, self.ctx)?, &self.func.over)?; + Ok(expr.n_unique()) + } + _ => self.not_supported_error(), + } } fn apply_window_spec( @@ -604,14 +590,14 @@ impl SqlFunctionVisitor<'_> { None => expr, }) } -} -fn not_supported_error(function_name: &str, args: &Vec<&FunctionArgExpr>) -> PolarsResult { - polars_bail!( - InvalidOperation: - "function `{}` with args {:?} is not supported in polars-sql", - function_name, args - ); + fn not_supported_error(&self) -> PolarsResult { + polars_bail!( + InvalidOperation: + "No function matches the given name and arguments: `{}`", + self.func.to_string() + ); + } } fn extract_args(sql_function: &SQLFunction) -> Vec<&FunctionArgExpr> { diff --git a/polars/polars-sql/tests/simple_exprs.rs b/polars/polars-sql/tests/simple_exprs.rs index 350a596b859a..822de48bd5b8 100644 --- a/polars/polars-sql/tests/simple_exprs.rs +++ b/polars/polars-sql/tests/simple_exprs.rs @@ -533,3 +533,20 @@ fn test_sql_expr() { let expected = df.lazy().select(&[col("a").min()]).collect().unwrap(); assert!(actual.frame_equal(&expected)); } + +#[test] +fn test_iss_9471() { + let sql = r#" + SELECT + ABS(a,a,a,a,1,2,3,XYZRandomLetters,"XYZRandomLetters") as "abs", + FROM df"#; + let df = df! { + "a" => [-4, -3, -2, -1, 0, 1, 2, 3, 4], + } + .unwrap() + .lazy(); + let mut context = SQLContext::new(); + context.register("df", df.clone()); + let res = context.execute(sql); + assert!(res.is_err()) +}