From e64bb9839049a72a2d98422da8604cab3e167ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sun, 17 Dec 2023 23:55:20 +0100 Subject: [PATCH] fix(ir): implicitly convert `None` literals with `dt.Null` type to the requested type during value coercion --- ibis/expr/operations/core.py | 10 ++++++++-- ibis/expr/operations/generic.py | 6 ++++++ ibis/expr/operations/tests/test_generic.py | 21 ++++++++++++++++++++ ibis/expr/tests/test_api.py | 23 ++++++++++++++++++++++ ibis/expr/types/generic.py | 2 +- 5 files changed, 59 insertions(+), 3 deletions(-) diff --git a/ibis/expr/operations/core.py b/ibis/expr/operations/core.py index 7c56816cdbae..c7b4e0a1dc75 100644 --- a/ibis/expr/operations/core.py +++ b/ibis/expr/operations/core.py @@ -68,13 +68,19 @@ def __coerce__( ) -> Self: # note that S=Shape is unused here since the pattern will check the # shape of the value expression after executing Value.__coerce__() - from ibis.expr.operations import Literal + from ibis.expr.operations.generic import NULL, Literal from ibis.expr.types import Expr if isinstance(value, Expr): value = value.op() + if isinstance(value, Value): - return value + if value == NULL: + # treat the NULL literal the same as None to implicitly cast to + # the requested datatype if any + value = None + else: + return value if T is dt.Integer: dtype = dt.infer(int(value)) diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index b34f269c6daa..8f2daaf29de9 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -195,6 +195,9 @@ def name(self): return repr(self.value) +NULL = Literal(None, dt.null) + + @public class ScalarParameter(Scalar, Named): _counter = itertools.count() @@ -313,3 +316,6 @@ def shape(self): def dtype(self): exprs = [*self.results, self.default] return rlz.highest_precedence_dtype(exprs) + + +public(NULL=NULL) diff --git a/ibis/expr/operations/tests/test_generic.py b/ibis/expr/operations/tests/test_generic.py index 4b8d53a1e073..8dd7ff12ebcf 100644 --- a/ibis/expr/operations/tests/test_generic.py +++ b/ibis/expr/operations/tests/test_generic.py @@ -122,3 +122,24 @@ def test_error_message_when_constructing_literal(call, error, snapshot): with pytest.raises(ValidationError) as exc: call() snapshot.assert_match(str(exc.value), f"{error}.txt") + + +def test_implicit_coercion_of_null_literal(): + # GH #7775 + NULL = ops.Literal(None, dt.null) + + value = ops.Value.__coerce__(None, dt.Int8) + expected = ops.Literal(None, dt.int8) + assert value == expected + + value = ops.Value.__coerce__(NULL, dt.Float64) + expected = ops.Literal(None, dt.float64) + assert value == expected + + +def test_NULL(): + assert isinstance(ops.NULL, ops.Literal) + assert ops.NULL.value is None + assert ops.NULL.dtype is dt.null + assert ops.NULL == ops.Literal(None, dt.null) + assert ops.NULL is not ops.Literal(None, dt.int8) diff --git a/ibis/expr/tests/test_api.py b/ibis/expr/tests/test_api.py index 8ae1aca5e8b0..d72a642b088b 100644 --- a/ibis/expr/tests/test_api.py +++ b/ibis/expr/tests/test_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator from datetime import datetime import pandas as pd @@ -9,6 +10,7 @@ import ibis import ibis.expr.datatypes as dt +import ibis.expr.operations as ops import ibis.expr.schema as sch from ibis import _ from ibis.common.exceptions import IbisInputError, IntegrityError @@ -124,3 +126,24 @@ def test_duplicate_columns_in_memtable_not_allowed(): with pytest.raises(IbisInputError, match="Duplicate column names"): ibis.memtable(df) + + +@pytest.mark.parametrize( + "op", + [ + operator.and_, + operator.or_, + operator.xor, + ], +) +def test_implicit_coercion_of_null_literal(op): + # GH #7775 + expr1 = op(ibis.literal(True), ibis.null()) + expr2 = op(ibis.literal(True), None) + + expected = expr1.op().__class__( + ops.Literal(True, dtype=dt.boolean), ops.Literal(None, dtype=dt.boolean) + ) + + assert expr1.op() == expected + assert expr2.op() == expected diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index f41ccc686216..1098f0cd5b22 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1998,7 +1998,7 @@ class NullColumn(Column, NullValue): @public def null(): """Create a NULL/NA scalar.""" - return literal(None) + return ops.NULL.to_expr() @public