Skip to content

Commit

Permalink
opt: fold unary minus operations
Browse files Browse the repository at this point in the history
This commit introduces a new rule, FoldUnaryMinus, which transforms
constant expressions of the form (UnaryMinus x) to -x.

There were some cases around this which would result in us not correctly
recognizing expressions like -1:::FLOAT as constant, causing us to
perform a full table scan instead of a point lookup.

Release note: None
  • Loading branch information
Justin Jaffray committed Jun 26, 2018
1 parent 95ed31b commit 5d0be37
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 13 deletions.
48 changes: 48 additions & 0 deletions pkg/sql/opt/exec/execbuilder/testdata/select
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,54 @@ scan · · (f float) ·
statement ok
DROP TABLE flt

# ------------------------------------------------------------------------------
# Verify we create the correct spans for negative numbers with extra
# operations.
# ------------------------------------------------------------------------------

statement ok
CREATE TABLE num (
i int null,
unique index (i),
f float null,
unique index (f),
d decimal null,
unique index (d),
n interval null,
unique index (n)
)

query TTTTT
EXPLAIN (TYPES) SELECT i FROM num WHERE i = -1:::INT
----
scan · · (i int) ·
· table num@num_i_key · ·
· spans /-1-/0 · ·

query TTTTT
EXPLAIN (TYPES) SELECT f FROM num WHERE f = -1:::FLOAT
----
scan · · (f float) ·
· table num@num_f_key · ·
· spans /-1-/-1/PrefixEnd · ·

query TTTTT
EXPLAIN (TYPES) SELECT d FROM num WHERE d = -1:::DECIMAL
----
scan · · (d decimal) ·
· table num@num_d_key · ·
· spans /-1-/-1/PrefixEnd · ·

query TTTTT
EXPLAIN (TYPES) SELECT n FROM num WHERE n = -'1h':::INTERVAL
----
scan · · (n interval) ·
· table num@num_n_key · ·
· spans /-1h-/1d-25h · ·

statement ok
DROP TABLE num

# ------------------------------------------------------------------------------
# ANY, ALL tests.
# ------------------------------------------------------------------------------
Expand Down
10 changes: 10 additions & 0 deletions pkg/sql/opt/idxconstraint/testdata/single-column
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ index-constraints vars=(int) index=(@1 desc)
----
[/5 - /5]

index-constraints vars=(int) index=(@1)
@1 = -1
----
[/-1 - /-1]

index-constraints vars=(decimal) index=(@1)
@1 = -2.0
----
[/-2.0 - /-2.0]

index-constraints vars=(int) index=(@1 desc)
@1 IS DISTINCT FROM 5
----
Expand Down
21 changes: 21 additions & 0 deletions pkg/sql/opt/memo/typing.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,24 @@ func findBinaryOverload(op opt.Operator, leftType, rightType types.T) (_ overloa
}
return overload{}, false
}

// EvalUnaryOp evaluates a unary expression on top of a constant value, returning
// a datum referring to the evaluated result. If an appropriate overload is not
// found, EvalUnaryOp returns an error.
func EvalUnaryOp(evalCtx *tree.EvalContext, op opt.Operator, ev ExprView) (tree.Datum, error) {
if !ev.IsConstValue() {
panic("expected const value")
}

unaryOp := opt.UnaryOpReverseMap[op]
datum := ExtractConstDatum(ev)

typ := datum.ResolvedType()
for _, unaryOverloads := range tree.UnaryOps[unaryOp] {
o := unaryOverloads.(tree.UnaryOp)
if o.Typ.Equivalent(typ) {
return o.Fn(evalCtx, datum)
}
}
return nil, fmt.Errorf("no overload found for %s applied to %s", op, typ)
}
19 changes: 19 additions & 0 deletions pkg/sql/opt/norm/custom_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,3 +761,22 @@ func (c *CustomFuncs) IsOne(input memo.GroupID) bool {
}
return false
}

// CanFoldUnaryMinus checks if a constant numeric value can be negated.
func (c *CustomFuncs) CanFoldUnaryMinus(input memo.GroupID) bool {
d := c.f.mem.LookupPrivate(c.f.mem.NormExpr(input).AsConst().Value()).(tree.Datum)
if t, ok := d.(*tree.DInt); ok {
return *t != math.MinInt64
}
return true
}

// NegateNumeric applies a unary minus to a numeric value.
func (c *CustomFuncs) NegateNumeric(input memo.GroupID) memo.GroupID {
ev := memo.MakeNormExprView(c.f.mem, input)
r, err := memo.EvalUnaryOp(c.f.evalCtx, opt.UnaryMinusOp, ev)
if err != nil {
panic(err)
}
return c.f.ConstructConst(c.f.InternDatum(r))
}
8 changes: 8 additions & 0 deletions pkg/sql/opt/norm/rules/numeric.opt
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,11 @@ $left
# EliminateUnaryMinus discards a doubled UnaryMinus operator.
[EliminateUnaryMinus, Normalize]
(UnaryMinus (UnaryMinus $input:*)) => $input

# FoldUnaryMinus negates a constant value within a UnaryMinus.
[FoldUnaryMinus, Normalize]
(UnaryMinus
$input:(Const) & (CanFoldUnaryMinus $input)
)
=>
(NegateNumeric $input)
113 changes: 113 additions & 0 deletions pkg/sql/opt/norm/testdata/rules/numeric
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,116 @@ project
│ └── columns: i:2(int)
└── projections [outer=(2)]
└── variable: a.i [type=int, outer=(2)]

# --------------------------------------------------
# FoldUnaryMinus
# --------------------------------------------------
opt
SELECT -(1:::int)
----
project
├── columns: "?column?":1(int!null)
├── cardinality: [1 - 1]
├── key: ()
├── fd: ()-->(1)
├── values
│ ├── cardinality: [1 - 1]
│ ├── key: ()
│ └── tuple [type=tuple{}]
└── projections
└── const: -1 [type=int]

opt
SELECT -(1:::float)
----
project
├── columns: "?column?":1(float!null)
├── cardinality: [1 - 1]
├── key: ()
├── fd: ()-->(1)
├── values
│ ├── cardinality: [1 - 1]
│ ├── key: ()
│ └── tuple [type=tuple{}]
└── projections
└── const: -1.0 [type=float]

# TODO(justin): it would be better if this produced an error in the optimizer
# rather than falling back to execution to error.
opt format=show-all
SELECT -((-9223372036854775808)::int)
----
project
├── columns: "?column?":1(int)
├── cardinality: [1 - 1]
├── stats: [rows=1]
├── cost: 0.01
├── key: ()
├── fd: ()-->(1)
├── prune: (1)
├── values
│ ├── cardinality: [1 - 1]
│ ├── stats: [rows=1]
│ ├── cost: 0.01
│ ├── key: ()
│ └── tuple [type=tuple{}]
└── projections
└── unary-minus [type=int]
└── const: -9223372036854775808 [type=int]

opt format=show-all
SELECT -(1:::decimal)
----
project
├── columns: "?column?":1(decimal!null)
├── cardinality: [1 - 1]
├── stats: [rows=1]
├── cost: 0.01
├── key: ()
├── fd: ()-->(1)
├── prune: (1)
├── values
│ ├── cardinality: [1 - 1]
│ ├── stats: [rows=1]
│ ├── cost: 0.01
│ ├── key: ()
│ └── tuple [type=tuple{}]
└── projections
└── const: -1 [type=decimal]

opt format=show-all
SELECT -('-1d'::interval);
----
project
├── columns: "?column?":1(interval!null)
├── cardinality: [1 - 1]
├── stats: [rows=1]
├── cost: 0.01
├── key: ()
├── fd: ()-->(1)
├── prune: (1)
├── values
│ ├── cardinality: [1 - 1]
│ ├── stats: [rows=1]
│ ├── cost: 0.01
│ ├── key: ()
│ └── tuple [type=tuple{}]
└── projections
└── const: '1d' [type=interval]

# TODO(justin): this seems incorrect but it's consistent with the existing
# planner. Revisit this: #26932.
opt
SELECT -('-9223372036854775808d'::interval);
----
project
├── columns: "?column?":1(interval!null)
├── cardinality: [1 - 1]
├── key: ()
├── fd: ()-->(1)
├── values
│ ├── cardinality: [1 - 1]
│ ├── key: ()
│ └── tuple [type=tuple{}]
└── projections
└── const: '-9223372036854775808d' [type=interval]
11 changes: 11 additions & 0 deletions pkg/sql/opt/optbuilder/testdata/scalar
Original file line number Diff line number Diff line change
Expand Up @@ -736,3 +736,14 @@ build-scalar
ARRAY['"foo"'::json]
----
error: arrays of jsonb not allowed

opt
SELECT -((-9223372036854775808):::int)
----
project
├── columns: "?column?":1(int)
├── values
│ └── tuple [type=tuple{}]
└── projections
└── unary-minus [type=int]
└── const: -9223372036854775808 [type=int]
4 changes: 2 additions & 2 deletions pkg/sql/opt/rule_name_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 11 additions & 11 deletions pkg/sql/sem/tree/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const SecondsInDay = 24 * 60 * 60
type UnaryOp struct {
Typ types.T
ReturnType types.T
fn func(*EvalContext, Datum) (Datum, error)
Fn func(*EvalContext, Datum) (Datum, error)

types TypeList
retType ReturnTyper
Expand Down Expand Up @@ -106,21 +106,21 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{
UnaryOp{
Typ: types.Int,
ReturnType: types.Int,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
return d, nil
},
},
UnaryOp{
Typ: types.Float,
ReturnType: types.Float,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
return d, nil
},
},
UnaryOp{
Typ: types.Decimal,
ReturnType: types.Decimal,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
return d, nil
},
},
Expand All @@ -130,7 +130,7 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{
UnaryOp{
Typ: types.Int,
ReturnType: types.Int,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
i := MustBeDInt(d)
if i == math.MinInt64 {
return nil, errIntOutOfRange
Expand All @@ -141,14 +141,14 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{
UnaryOp{
Typ: types.Float,
ReturnType: types.Float,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
return NewDFloat(-*d.(*DFloat)), nil
},
},
UnaryOp{
Typ: types.Decimal,
ReturnType: types.Decimal,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
dec := &d.(*DDecimal).Decimal
dd := &DDecimal{}
dd.Decimal.Neg(dec)
Expand All @@ -158,7 +158,7 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{
UnaryOp{
Typ: types.Interval,
ReturnType: types.Interval,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
i := d.(*DInterval).Duration
i.Nanos = -i.Nanos
i.Days = -i.Days
Expand All @@ -172,14 +172,14 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{
UnaryOp{
Typ: types.Int,
ReturnType: types.Int,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
return NewDInt(^MustBeDInt(d)), nil
},
},
UnaryOp{
Typ: types.INet,
ReturnType: types.INet,
fn: func(_ *EvalContext, d Datum) (Datum, error) {
Fn: func(_ *EvalContext, d Datum) (Datum, error) {
ipAddr := MustBeDIPAddr(d).IPAddr
return NewDIPAddr(DIPAddr{ipAddr.Complement()}), nil
},
Expand Down Expand Up @@ -3728,7 +3728,7 @@ func (expr *UnaryExpr) Eval(ctx *EvalContext) (Datum, error) {
if d == DNull {
return DNull, nil
}
res, err := expr.fn.fn(ctx, d)
res, err := expr.fn.Fn(ctx, d)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 5d0be37

Please sign in to comment.