diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round/out.sql index 8e25278c7b9a..771a92a03516 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round/out.sql @@ -1,3 +1,3 @@ SELECT - ROUND("t0"."double_col") AS "Round(double_col)" + CAST(ROUND("t0"."double_col", 0) AS Nullable(Int64)) AS "Round(double_col, 0)" FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round_0/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round_0/out.sql index 0540e27c45a9..771a92a03516 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round_0/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round_0/out.sql @@ -1,3 +1,3 @@ SELECT - ROUND("t0"."double_col", 0) AS "Round(double_col, 0)" + CAST(ROUND("t0"."double_col", 0) AS Nullable(Int64)) AS "Round(double_col, 0)" FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round_2/out.sql b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round_2/out.sql index 86e64f8f88bf..fafd92797586 100644 --- a/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round_2/out.sql +++ b/ibis/backends/clickhouse/tests/snapshots/test_functions/test_translate_math_functions/round_2/out.sql @@ -1,3 +1,3 @@ SELECT - ROUND("t0"."double_col", 2) AS "Round(double_col, 2)" + CAST(ROUND("t0"."double_col", 2) AS Nullable(Float64)) AS "Round(double_col, 2)" FROM "functional_alltypes" AS "t0" \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_unary_builtins/test_numeric/round_no_args/out.sql b/ibis/backends/impala/tests/snapshots/test_unary_builtins/test_numeric/round_no_args/out.sql index 8acebb77ef31..dec5d37db9aa 100644 --- a/ibis/backends/impala/tests/snapshots/test_unary_builtins/test_numeric/round_no_args/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_unary_builtins/test_numeric/round_no_args/out.sql @@ -1,3 +1,3 @@ SELECT - CAST(ROUND(`t0`.`double_col`) AS BIGINT) AS `Round(double_col)` + CAST(ROUND(`t0`.`double_col`, 0) AS BIGINT) AS `Round(double_col, 0)` FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/impala/tests/snapshots/test_unary_builtins/test_numeric/round_zero/out.sql b/ibis/backends/impala/tests/snapshots/test_unary_builtins/test_numeric/round_zero/out.sql index 5454c320029e..dec5d37db9aa 100644 --- a/ibis/backends/impala/tests/snapshots/test_unary_builtins/test_numeric/round_zero/out.sql +++ b/ibis/backends/impala/tests/snapshots/test_unary_builtins/test_numeric/round_zero/out.sql @@ -1,3 +1,3 @@ SELECT - ROUND(`t0`.`double_col`, 0) AS `Round(double_col, 0)` + CAST(ROUND(`t0`.`double_col`, 0) AS BIGINT) AS `Round(double_col, 0)` FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 786f67afa6db..bdfb1af0f54a 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -856,9 +856,7 @@ def visit_Floor(self, op, *, arg): return self.cast(self.f.floor(arg), op.dtype) def visit_Round(self, op, *, arg, digits): - if digits is not None: - return sge.Round(this=arg, decimals=digits) - return sge.Round(this=arg) + return self.cast(self.f.round(arg, digits), op.dtype) ### Random Noise diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py index 7fef6d7715f0..791ce102c1f1 100644 --- a/ibis/backends/sql/compilers/oracle.py +++ b/ibis/backends/sql/compilers/oracle.py @@ -500,5 +500,10 @@ def visit_Strip(self, op, *, arg): # remove. return self.visit_RStrip(op, arg=self.visit_LStrip(op, arg=arg)) + def visit_Round(self, op, *, arg, digits): + if op.dtype.is_integer(): + return self.f.round(arg) + return self.cast(self.f.round(arg, digits), op.dtype) + compiler = OracleCompiler() diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py index 3e38d87c5352..6fafa8619e3e 100644 --- a/ibis/backends/sql/compilers/postgres.py +++ b/ibis/backends/sql/compilers/postgres.py @@ -535,13 +535,14 @@ def visit_TypeOf(self, op, *, arg): ) def visit_Round(self, op, *, arg, digits): - if digits is None: - return self.f.round(arg) + dtype = op.dtype - result = self.f.round(self.cast(arg, dt.decimal), digits) - if op.arg.dtype.is_decimal(): - return result - return self.cast(result, dt.float64) + if dtype.is_integer(): + result = self.f.round(arg) + else: + result = self.f.round(self.cast(arg, dt.decimal), digits) + + return self.cast(result, dtype) def visit_Modulus(self, op, *, left, right): # postgres doesn't allow modulus of double precision values, so upcast and diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index 01a04e455416..9deeb4e9a8fc 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -1523,3 +1523,25 @@ def test_bitwise_not_col(backend, alltypes, df): result = expr.execute() expected = ~df.int_col backend.assert_series_equal(result, expected.rename("tmp")) + + +def test_column_round_is_integer(con): + t = ibis.memtable({"x": [1.2, 3.4]}) + expr = t.x.round().cast(int) + result = con.execute(expr) + + one, three = sorted(result.tolist()) + + assert one == 1 + assert isinstance(one, int) + + assert three == 3 + assert isinstance(three, int) + + +def test_scalar_round_is_integer(con): + expr = ibis.literal(1.2).round().cast(int) + result = con.execute(expr) + + assert result == 1 + assert isinstance(result, int) diff --git a/ibis/expr/operations/numeric.py b/ibis/expr/operations/numeric.py index 2f4c90a55605..a8bba2c0740b 100644 --- a/ibis/expr/operations/numeric.py +++ b/ibis/expr/operations/numeric.py @@ -148,19 +148,31 @@ class Round(Value): """Round a value.""" arg: StrictNumeric - # TODO(kszucs): the default should be 0 instead of being None - digits: Optional[Integer] = None + digits: Integer shape = rlz.shape_like("arg") @property def dtype(self): - if self.arg.dtype.is_decimal(): - return self.arg.dtype - elif self.digits is None: - return dt.int64 - else: - return dt.double + digits = self.digits + arg_dtype = self.arg.dtype + + raw_digits = getattr(digits, "value", None) + + # decimals with literal-typed digits return decimals + if arg_dtype.is_decimal() and raw_digits is not None: + return arg_dtype.copy(scale=raw_digits) + + nullable = arg_dtype.nullable + + # if digits are unspecified that means round to an integer + if raw_digits is not None and raw_digits == 0: + return dt.int64.copy(nullable=nullable) + + # otherwise one of the following is true: + # 1. digits are specified as a more complex expression + # 2. self.arg is a double column + return dt.double.copy(nullable=nullable) @public diff --git a/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt b/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt index 23cc70e5b6ac..3fa92abc9f91 100644 --- a/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt @@ -14,7 +14,7 @@ r0 := UnboundTable: alltypes Aggregate[r0] groups: key1: r0.g - key2: Round(r0.f) + key2: Round(r0.f, digits=0) metrics: c: Sum(r0.c) d: Mean(r0.d) \ No newline at end of file diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index d6836c68ad04..2191663aa004 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -56,7 +56,7 @@ def __neg__(self) -> NumericValue: """ return self.negate() - def round(self, digits: int | IntegerValue | None = None) -> NumericValue: + def round(self, digits: int | IntegerValue = 0) -> NumericValue: """Round values to an indicated number of decimal places. Parameters @@ -94,16 +94,16 @@ def round(self, digits: int | IntegerValue | None = None) -> NumericValue: │ 2.54 │ └─────────┘ >>> t.values.round() - ┏━━━━━━━━━━━━━━━┓ - ┃ Round(values) ┃ - ┡━━━━━━━━━━━━━━━┩ - │ int64 │ - ├───────────────┤ - │ 1 │ - │ 2 │ - │ 2 │ - │ 3 │ - └───────────────┘ + ┏━━━━━━━━━━━━━━━━━━┓ + ┃ Round(values, 0) ┃ + ┡━━━━━━━━━━━━━━━━━━┩ + │ int64 │ + ├──────────────────┤ + │ 1 │ + │ 2 │ + │ 2 │ + │ 3 │ + └──────────────────┘ >>> t.values.round(digits=1) ┏━━━━━━━━━━━━━━━━━━┓ ┃ Round(values, 1) ┃ diff --git a/ibis/tests/expr/test_sql_builtins.py b/ibis/tests/expr/test_sql_builtins.py index 578f5fc83f88..1e0f702a4f52 100644 --- a/ibis/tests/expr/test_sql_builtins.py +++ b/ibis/tests/expr/test_sql_builtins.py @@ -148,7 +148,7 @@ def test_sign(functional_alltypes, lineitem): def test_round(functional_alltypes, lineitem): result = functional_alltypes.double_col.round() assert isinstance(result, ir.IntegerColumn) - assert result.op().args[1] is None + assert result.op().args[1] == ibis.literal(0).op() result = functional_alltypes.double_col.round(2) assert isinstance(result, ir.FloatingColumn)