Skip to content

Commit

Permalink
fix(api): ensure consistent typing in round output type (#10351)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Oct 27, 2024
1 parent 18224f8 commit b2b0925
Show file tree
Hide file tree
Showing 13 changed files with 73 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
ROUND("t0"."double_col") AS "Round(double_col)"
CAST(ROUND("t0"."double_col", 0) AS Nullable(Int64)) AS "Round(double_col, 0)"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
ROUND("t0"."double_col", 0) AS "Round(double_col, 0)"
CAST(ROUND("t0"."double_col", 0) AS Nullable(Int64)) AS "Round(double_col, 0)"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
ROUND("t0"."double_col", 2) AS "Round(double_col, 2)"
CAST(ROUND("t0"."double_col", 2) AS Nullable(Float64)) AS "Round(double_col, 2)"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
CAST(ROUND(`t0`.`double_col`) AS BIGINT) AS `Round(double_col)`
CAST(ROUND(`t0`.`double_col`, 0) AS BIGINT) AS `Round(double_col, 0)`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
ROUND(`t0`.`double_col`, 0) AS `Round(double_col, 0)`
CAST(ROUND(`t0`.`double_col`, 0) AS BIGINT) AS `Round(double_col, 0)`
FROM `functional_alltypes` AS `t0`
4 changes: 1 addition & 3 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,9 +856,7 @@ def visit_Floor(self, op, *, arg):
return self.cast(self.f.floor(arg), op.dtype)

def visit_Round(self, op, *, arg, digits):
if digits is not None:
return sge.Round(this=arg, decimals=digits)
return sge.Round(this=arg)
return self.cast(self.f.round(arg, digits), op.dtype)

### Random Noise

Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,5 +500,10 @@ def visit_Strip(self, op, *, arg):
# remove.
return self.visit_RStrip(op, arg=self.visit_LStrip(op, arg=arg))

def visit_Round(self, op, *, arg, digits):
if op.dtype.is_integer():
return self.f.round(arg)
return self.cast(self.f.round(arg, digits), op.dtype)


compiler = OracleCompiler()
13 changes: 7 additions & 6 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,14 @@ def visit_TypeOf(self, op, *, arg):
)

def visit_Round(self, op, *, arg, digits):
if digits is None:
return self.f.round(arg)
dtype = op.dtype

result = self.f.round(self.cast(arg, dt.decimal), digits)
if op.arg.dtype.is_decimal():
return result
return self.cast(result, dt.float64)
if dtype.is_integer():
result = self.f.round(arg)
else:
result = self.f.round(self.cast(arg, dt.decimal), digits)

return self.cast(result, dtype)

def visit_Modulus(self, op, *, left, right):
# postgres doesn't allow modulus of double precision values, so upcast and
Expand Down
22 changes: 22 additions & 0 deletions ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,3 +1523,25 @@ def test_bitwise_not_col(backend, alltypes, df):
result = expr.execute()
expected = ~df.int_col
backend.assert_series_equal(result, expected.rename("tmp"))


def test_column_round_is_integer(con):
t = ibis.memtable({"x": [1.2, 3.4]})
expr = t.x.round().cast(int)
result = con.execute(expr)

one, three = sorted(result.tolist())

assert one == 1
assert isinstance(one, int)

assert three == 3
assert isinstance(three, int)


def test_scalar_round_is_integer(con):
expr = ibis.literal(1.2).round().cast(int)
result = con.execute(expr)

assert result == 1
assert isinstance(result, int)
28 changes: 20 additions & 8 deletions ibis/expr/operations/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,31 @@ class Round(Value):
"""Round a value."""

arg: StrictNumeric
# TODO(kszucs): the default should be 0 instead of being None
digits: Optional[Integer] = None
digits: Integer

shape = rlz.shape_like("arg")

@property
def dtype(self):
if self.arg.dtype.is_decimal():
return self.arg.dtype
elif self.digits is None:
return dt.int64
else:
return dt.double
digits = self.digits
arg_dtype = self.arg.dtype

raw_digits = getattr(digits, "value", None)

# decimals with literal-typed digits return decimals
if arg_dtype.is_decimal() and raw_digits is not None:
return arg_dtype.copy(scale=raw_digits)

nullable = arg_dtype.nullable

# if digits are unspecified that means round to an integer
if raw_digits is not None and raw_digits == 0:
return dt.int64.copy(nullable=nullable)

# otherwise one of the following is true:
# 1. digits are specified as a more complex expression
# 2. self.arg is a double column
return dt.double.copy(nullable=nullable)


@public
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ r0 := UnboundTable: alltypes
Aggregate[r0]
groups:
key1: r0.g
key2: Round(r0.f)
key2: Round(r0.f, digits=0)
metrics:
c: Sum(r0.c)
d: Mean(r0.d)
22 changes: 11 additions & 11 deletions ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __neg__(self) -> NumericValue:
"""
return self.negate()

def round(self, digits: int | IntegerValue | None = None) -> NumericValue:
def round(self, digits: int | IntegerValue = 0) -> NumericValue:
"""Round values to an indicated number of decimal places.
Parameters
Expand Down Expand Up @@ -94,16 +94,16 @@ def round(self, digits: int | IntegerValue | None = None) -> NumericValue:
│ 2.54 │
└─────────┘
>>> t.values.round()
┏━━━━━━━━━━━━━━━┓
┃ Round(values) ┃
┡━━━━━━━━━━━━━━━┩
│ int64 │
├───────────────┤
│ 1 │
│ 2 │
│ 2 │
│ 3 │
└───────────────┘
┏━━━━━━━━━━━━━━━━━━
┃ Round(values, 0) ┃
┡━━━━━━━━━━━━━━━━━━
│ int64
├──────────────────
1 │
2 │
2 │
3 │
└──────────────────
>>> t.values.round(digits=1)
┏━━━━━━━━━━━━━━━━━━┓
┃ Round(values, 1) ┃
Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_sql_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_sign(functional_alltypes, lineitem):
def test_round(functional_alltypes, lineitem):
result = functional_alltypes.double_col.round()
assert isinstance(result, ir.IntegerColumn)
assert result.op().args[1] is None
assert result.op().args[1] == ibis.literal(0).op()

result = functional_alltypes.double_col.round(2)
assert isinstance(result, ir.FloatingColumn)
Expand Down

0 comments on commit b2b0925

Please sign in to comment.