diff --git a/pkg/sql/opt/exec/execbuilder/testdata/select b/pkg/sql/opt/exec/execbuilder/testdata/select index c705d02692c3..0e555cfffb32 100644 --- a/pkg/sql/opt/exec/execbuilder/testdata/select +++ b/pkg/sql/opt/exec/execbuilder/testdata/select @@ -521,6 +521,54 @@ scan · · (f float) · statement ok DROP TABLE flt +# ------------------------------------------------------------------------------ +# Verify we create the correct spans for negative numbers with extra +# operations. +# ------------------------------------------------------------------------------ + +statement ok +CREATE TABLE num ( + i int null, + unique index (i), + f float null, + unique index (f), + d decimal null, + unique index (d), + n interval null, + unique index (n) +) + +query TTTTT +EXPLAIN (TYPES) SELECT i FROM num WHERE i = -1:::INT +---- +scan · · (i int) · +· table num@num_i_key · · +· spans /-1-/0 · · + +query TTTTT +EXPLAIN (TYPES) SELECT f FROM num WHERE f = -1:::FLOAT +---- +scan · · (f float) · +· table num@num_f_key · · +· spans /-1-/-1/PrefixEnd · · + +query TTTTT +EXPLAIN (TYPES) SELECT d FROM num WHERE d = -1:::DECIMAL +---- +scan · · (d decimal) · +· table num@num_d_key · · +· spans /-1-/-1/PrefixEnd · · + +query TTTTT +EXPLAIN (TYPES) SELECT n FROM num WHERE n = -'1h':::INTERVAL +---- +scan · · (n interval) · +· table num@num_n_key · · +· spans /-1h-/1d-25h · · + +statement ok +DROP TABLE num + # ------------------------------------------------------------------------------ # ANY, ALL tests. # ------------------------------------------------------------------------------ diff --git a/pkg/sql/opt/idxconstraint/testdata/single-column b/pkg/sql/opt/idxconstraint/testdata/single-column index f98b682892c4..667effbfc76f 100644 --- a/pkg/sql/opt/idxconstraint/testdata/single-column +++ b/pkg/sql/opt/idxconstraint/testdata/single-column @@ -228,6 +228,16 @@ index-constraints vars=(int) index=(@1 desc) ---- [/5 - /5] +index-constraints vars=(int) index=(@1) +@1 = -1 +---- +[/-1 - /-1] + +index-constraints vars=(decimal) index=(@1) +@1 = -2.0 +---- +[/-2.0 - /-2.0] + index-constraints vars=(int) index=(@1 desc) @1 IS DISTINCT FROM 5 ---- diff --git a/pkg/sql/opt/memo/typing.go b/pkg/sql/opt/memo/typing.go index 60e3e86fe257..f52b38e64ae1 100644 --- a/pkg/sql/opt/memo/typing.go +++ b/pkg/sql/opt/memo/typing.go @@ -347,3 +347,24 @@ func findBinaryOverload(op opt.Operator, leftType, rightType types.T) (_ overloa } return overload{}, false } + +// EvalUnaryOp evaluates a unary expression on top of a constant value, returning +// a datum referring to the evaluated result. If an appropriate overload is not +// found, EvalUnaryOp returns an error. +func EvalUnaryOp(evalCtx *tree.EvalContext, op opt.Operator, ev ExprView) (tree.Datum, error) { + if !ev.IsConstValue() { + panic("expected const value") + } + + unaryOp := opt.UnaryOpReverseMap[op] + datum := ExtractConstDatum(ev) + + typ := datum.ResolvedType() + for _, unaryOverloads := range tree.UnaryOps[unaryOp] { + o := unaryOverloads.(tree.UnaryOp) + if o.Typ.Equivalent(typ) { + return o.Fn(evalCtx, datum) + } + } + return nil, fmt.Errorf("no overload found for %s applied to %s", op, typ) +} diff --git a/pkg/sql/opt/norm/custom_funcs.go b/pkg/sql/opt/norm/custom_funcs.go index 49cf6222e660..da5d71aceb0e 100644 --- a/pkg/sql/opt/norm/custom_funcs.go +++ b/pkg/sql/opt/norm/custom_funcs.go @@ -761,3 +761,22 @@ func (c *CustomFuncs) IsOne(input memo.GroupID) bool { } return false } + +// CanFoldUnaryMinus checks if a constant numeric value can be negated. +func (c *CustomFuncs) CanFoldUnaryMinus(input memo.GroupID) bool { + d := c.f.mem.LookupPrivate(c.f.mem.NormExpr(input).AsConst().Value()).(tree.Datum) + if t, ok := d.(*tree.DInt); ok { + return *t != math.MinInt64 + } + return true +} + +// NegateNumeric applies a unary minus to a numeric value. +func (c *CustomFuncs) NegateNumeric(input memo.GroupID) memo.GroupID { + ev := memo.MakeNormExprView(c.f.mem, input) + r, err := memo.EvalUnaryOp(c.f.evalCtx, opt.UnaryMinusOp, ev) + if err != nil { + panic(err) + } + return c.f.ConstructConst(c.f.InternDatum(r)) +} diff --git a/pkg/sql/opt/norm/rules/numeric.opt b/pkg/sql/opt/norm/rules/numeric.opt index 046273a9c30b..57c93370ec0c 100644 --- a/pkg/sql/opt/norm/rules/numeric.opt +++ b/pkg/sql/opt/norm/rules/numeric.opt @@ -65,3 +65,11 @@ $left # EliminateUnaryMinus discards a doubled UnaryMinus operator. [EliminateUnaryMinus, Normalize] (UnaryMinus (UnaryMinus $input:*)) => $input + +# FoldUnaryMinus negates a constant value within a UnaryMinus. +[FoldUnaryMinus, Normalize] +(UnaryMinus + $input:(Const) & (CanFoldUnaryMinus $input) +) +=> +(NegateNumeric $input) diff --git a/pkg/sql/opt/norm/testdata/rules/numeric b/pkg/sql/opt/norm/testdata/rules/numeric index 2b13e4c4477a..4767d3208172 100644 --- a/pkg/sql/opt/norm/testdata/rules/numeric +++ b/pkg/sql/opt/norm/testdata/rules/numeric @@ -130,3 +130,116 @@ project │ └── columns: i:2(int) └── projections [outer=(2)] └── variable: a.i [type=int, outer=(2)] + +# -------------------------------------------------- +# FoldUnaryMinus +# -------------------------------------------------- +opt +SELECT -(1:::int) +---- +project + ├── columns: "?column?":1(int!null) + ├── cardinality: [1 - 1] + ├── key: () + ├── fd: ()-->(1) + ├── values + │ ├── cardinality: [1 - 1] + │ ├── key: () + │ └── tuple [type=tuple{}] + └── projections + └── const: -1 [type=int] + +opt +SELECT -(1:::float) +---- +project + ├── columns: "?column?":1(float!null) + ├── cardinality: [1 - 1] + ├── key: () + ├── fd: ()-->(1) + ├── values + │ ├── cardinality: [1 - 1] + │ ├── key: () + │ └── tuple [type=tuple{}] + └── projections + └── const: -1.0 [type=float] + +# TODO(justin): it would be better if this produced an error in the optimizer +# rather than falling back to execution to error. +opt format=show-all +SELECT -((-9223372036854775808)::int) +---- +project + ├── columns: "?column?":1(int) + ├── cardinality: [1 - 1] + ├── stats: [rows=1] + ├── cost: 0.01 + ├── key: () + ├── fd: ()-->(1) + ├── prune: (1) + ├── values + │ ├── cardinality: [1 - 1] + │ ├── stats: [rows=1] + │ ├── cost: 0.01 + │ ├── key: () + │ └── tuple [type=tuple{}] + └── projections + └── unary-minus [type=int] + └── const: -9223372036854775808 [type=int] + +opt format=show-all +SELECT -(1:::decimal) +---- +project + ├── columns: "?column?":1(decimal!null) + ├── cardinality: [1 - 1] + ├── stats: [rows=1] + ├── cost: 0.01 + ├── key: () + ├── fd: ()-->(1) + ├── prune: (1) + ├── values + │ ├── cardinality: [1 - 1] + │ ├── stats: [rows=1] + │ ├── cost: 0.01 + │ ├── key: () + │ └── tuple [type=tuple{}] + └── projections + └── const: -1 [type=decimal] + +opt format=show-all +SELECT -('-1d'::interval); +---- +project + ├── columns: "?column?":1(interval!null) + ├── cardinality: [1 - 1] + ├── stats: [rows=1] + ├── cost: 0.01 + ├── key: () + ├── fd: ()-->(1) + ├── prune: (1) + ├── values + │ ├── cardinality: [1 - 1] + │ ├── stats: [rows=1] + │ ├── cost: 0.01 + │ ├── key: () + │ └── tuple [type=tuple{}] + └── projections + └── const: '1d' [type=interval] + +# TODO(justin): this seems incorrect but it's consistent with the existing +# planner. Revisit this: #26932. +opt +SELECT -('-9223372036854775808d'::interval); +---- +project + ├── columns: "?column?":1(interval!null) + ├── cardinality: [1 - 1] + ├── key: () + ├── fd: ()-->(1) + ├── values + │ ├── cardinality: [1 - 1] + │ ├── key: () + │ └── tuple [type=tuple{}] + └── projections + └── const: '-9223372036854775808d' [type=interval] diff --git a/pkg/sql/opt/optbuilder/testdata/scalar b/pkg/sql/opt/optbuilder/testdata/scalar index 62e7c787a4ce..053ce1a63068 100644 --- a/pkg/sql/opt/optbuilder/testdata/scalar +++ b/pkg/sql/opt/optbuilder/testdata/scalar @@ -736,3 +736,14 @@ build-scalar ARRAY['"foo"'::json] ---- error: arrays of jsonb not allowed + +opt +SELECT -((-9223372036854775808):::int) +---- +project + ├── columns: "?column?":1(int) + ├── values + │ └── tuple [type=tuple{}] + └── projections + └── unary-minus [type=int] + └── const: -9223372036854775808 [type=int] diff --git a/pkg/sql/opt/rule_name_string.go b/pkg/sql/opt/rule_name_string.go index bc58f62df309..03013fdda114 100644 --- a/pkg/sql/opt/rule_name_string.go +++ b/pkg/sql/opt/rule_name_string.go @@ -4,9 +4,9 @@ package opt import "strconv" -const _RuleName_name = "InvalidRuleNameNumManualRuleNamesEliminateEmptyAndEliminateEmptyOrEliminateSingletonAndOrSimplifyAndSimplifyOrSimplifyFiltersFoldNullAndOrNegateComparisonEliminateNotNegateAndNegateOrExtractRedundantClauseExtractRedundantSubclauseCommuteVarInequalityCommuteConstInequalityNormalizeCmpPlusConstNormalizeCmpMinusConstNormalizeCmpConstMinusNormalizeTupleEqualityFoldNullComparisonLeftFoldNullComparisonRightFoldIsNullFoldNonNullIsNullFoldIsNotNullFoldNonNullIsNotNullCommuteNullIsDecorrelateJoinTryDecorrelateSelectTryDecorrelateProjectTryDecorrelateProjectSelectTryDecorrelateScalarGroupByHoistSelectExistsHoistSelectNotExistsHoistSelectSubqueryHoistProjectSubqueryHoistJoinSubqueryHoistValuesSubqueryNormalizeAnyFilterNormalizeNotAnyFilterEliminateDistinctEliminateGroupByProjectPushSelectIntoInlinableProjectInlineProjectInProjectEnsureJoinFiltersAndEnsureJoinFiltersPushFilterIntoJoinLeftPushFilterIntoJoinRightSimplifyLeftJoinSimplifyRightJoinEliminateSemiJoinEliminateAntiJoinEliminateJoinNoColsLeftEliminateJoinNoColsRightEliminateLimitPushLimitIntoProjectPushOffsetIntoProjectEliminateMax1RowFoldPlusZeroFoldZeroPlusFoldMinusZeroFoldMultOneFoldOneMultFoldDivOneInvertMinusEliminateUnaryMinusEliminateProjectEliminateProjectProjectPruneProjectColsPruneScanColsPruneSelectColsPruneLimitColsPruneOffsetColsPruneJoinLeftColsPruneJoinRightColsPruneAggColsPruneGroupByColsPruneValuesColsPruneRowNumberColsCommuteVarCommuteConstEliminateCoalesceSimplifyCoalesceEliminateCastFoldNullCastFoldNullUnaryFoldNullBinaryLeftFoldNullBinaryRightFoldNullInNonEmptyFoldNullInEmptyFoldNullNotInEmptyNormalizeInConstFoldInNullEliminateExistsProjectEliminateExistsGroupByEliminateSelectEnsureSelectFiltersAndEnsureSelectFiltersMergeSelectsPushSelectIntoProjectPushSelectIntoJoinLeftPushSelectIntoJoinRightMergeSelectInnerJoinPushSelectIntoGroupByGenerateMergeJoinsPushLimitIntoScanPushLimitIntoLookupJoinGenerateIndexScansConstrainScanPushFilterIntoLookupJoinNoRemainderPushFilterIntoLookupJoinConstrainLookupJoinIndexScanNumRuleNames" +const _RuleName_name = "InvalidRuleNameNumManualRuleNamesEliminateEmptyAndEliminateEmptyOrEliminateSingletonAndOrSimplifyAndSimplifyOrSimplifyFiltersFoldNullAndOrNegateComparisonEliminateNotNegateAndNegateOrExtractRedundantClauseExtractRedundantSubclauseCommuteVarInequalityCommuteConstInequalityNormalizeCmpPlusConstNormalizeCmpMinusConstNormalizeCmpConstMinusNormalizeTupleEqualityFoldNullComparisonLeftFoldNullComparisonRightFoldIsNullFoldNonNullIsNullFoldIsNotNullFoldNonNullIsNotNullCommuteNullIsDecorrelateJoinTryDecorrelateSelectTryDecorrelateProjectTryDecorrelateProjectSelectTryDecorrelateScalarGroupByHoistSelectExistsHoistSelectNotExistsHoistSelectSubqueryHoistProjectSubqueryHoistJoinSubqueryHoistValuesSubqueryNormalizeAnyFilterNormalizeNotAnyFilterEliminateDistinctEliminateGroupByProjectPushSelectIntoInlinableProjectInlineProjectInProjectEnsureJoinFiltersAndEnsureJoinFiltersPushFilterIntoJoinLeftPushFilterIntoJoinRightSimplifyLeftJoinSimplifyRightJoinEliminateSemiJoinEliminateAntiJoinEliminateJoinNoColsLeftEliminateJoinNoColsRightEliminateLimitPushLimitIntoProjectPushOffsetIntoProjectEliminateMax1RowFoldPlusZeroFoldZeroPlusFoldMinusZeroFoldMultOneFoldOneMultFoldDivOneInvertMinusEliminateUnaryMinusFoldUnaryMinusEliminateProjectEliminateProjectProjectPruneProjectColsPruneScanColsPruneSelectColsPruneLimitColsPruneOffsetColsPruneJoinLeftColsPruneJoinRightColsPruneAggColsPruneGroupByColsPruneValuesColsPruneRowNumberColsCommuteVarCommuteConstEliminateCoalesceSimplifyCoalesceEliminateCastFoldNullCastFoldNullUnaryFoldNullBinaryLeftFoldNullBinaryRightFoldNullInNonEmptyFoldNullInEmptyFoldNullNotInEmptyNormalizeInConstFoldInNullEliminateExistsProjectEliminateExistsGroupByEliminateSelectEnsureSelectFiltersAndEnsureSelectFiltersMergeSelectsPushSelectIntoProjectPushSelectIntoJoinLeftPushSelectIntoJoinRightMergeSelectInnerJoinPushSelectIntoGroupByGenerateMergeJoinsPushLimitIntoScanPushLimitIntoLookupJoinGenerateIndexScansConstrainScanPushFilterIntoLookupJoinNoRemainderPushFilterIntoLookupJoinConstrainLookupJoinIndexScanNumRuleNames" -var _RuleName_index = [...]uint16{0, 15, 33, 50, 66, 89, 100, 110, 125, 138, 154, 166, 175, 183, 205, 230, 250, 272, 293, 315, 337, 359, 381, 404, 414, 431, 444, 464, 477, 492, 512, 533, 560, 587, 604, 624, 643, 663, 680, 699, 717, 738, 755, 778, 808, 830, 850, 867, 889, 912, 928, 945, 962, 979, 1002, 1026, 1040, 1060, 1081, 1097, 1109, 1121, 1134, 1145, 1156, 1166, 1177, 1196, 1212, 1235, 1251, 1264, 1279, 1293, 1308, 1325, 1343, 1355, 1371, 1386, 1404, 1414, 1426, 1443, 1459, 1472, 1484, 1497, 1515, 1534, 1552, 1567, 1585, 1601, 1611, 1633, 1655, 1670, 1692, 1711, 1723, 1744, 1766, 1789, 1809, 1830, 1848, 1865, 1888, 1906, 1919, 1954, 1978, 2006, 2018} +var _RuleName_index = [...]uint16{0, 15, 33, 50, 66, 89, 100, 110, 125, 138, 154, 166, 175, 183, 205, 230, 250, 272, 293, 315, 337, 359, 381, 404, 414, 431, 444, 464, 477, 492, 512, 533, 560, 587, 604, 624, 643, 663, 680, 699, 717, 738, 755, 778, 808, 830, 850, 867, 889, 912, 928, 945, 962, 979, 1002, 1026, 1040, 1060, 1081, 1097, 1109, 1121, 1134, 1145, 1156, 1166, 1177, 1196, 1210, 1226, 1249, 1265, 1278, 1293, 1307, 1322, 1339, 1357, 1369, 1385, 1400, 1418, 1428, 1440, 1457, 1473, 1486, 1498, 1511, 1529, 1548, 1566, 1581, 1599, 1615, 1625, 1647, 1669, 1684, 1706, 1725, 1737, 1758, 1780, 1803, 1823, 1844, 1862, 1879, 1902, 1920, 1933, 1968, 1992, 2020, 2032} func (i RuleName) String() string { if i >= RuleName(len(_RuleName_index)-1) { diff --git a/pkg/sql/sem/tree/eval.go b/pkg/sql/sem/tree/eval.go index ea382ae32ee1..e85c8d55fc7a 100644 --- a/pkg/sql/sem/tree/eval.go +++ b/pkg/sql/sem/tree/eval.go @@ -68,7 +68,7 @@ const SecondsInDay = 24 * 60 * 60 type UnaryOp struct { Typ types.T ReturnType types.T - fn func(*EvalContext, Datum) (Datum, error) + Fn func(*EvalContext, Datum) (Datum, error) types TypeList retType ReturnTyper @@ -106,21 +106,21 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{ UnaryOp{ Typ: types.Int, ReturnType: types.Int, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { return d, nil }, }, UnaryOp{ Typ: types.Float, ReturnType: types.Float, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { return d, nil }, }, UnaryOp{ Typ: types.Decimal, ReturnType: types.Decimal, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { return d, nil }, }, @@ -130,7 +130,7 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{ UnaryOp{ Typ: types.Int, ReturnType: types.Int, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { i := MustBeDInt(d) if i == math.MinInt64 { return nil, errIntOutOfRange @@ -141,14 +141,14 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{ UnaryOp{ Typ: types.Float, ReturnType: types.Float, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { return NewDFloat(-*d.(*DFloat)), nil }, }, UnaryOp{ Typ: types.Decimal, ReturnType: types.Decimal, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { dec := &d.(*DDecimal).Decimal dd := &DDecimal{} dd.Decimal.Neg(dec) @@ -158,7 +158,7 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{ UnaryOp{ Typ: types.Interval, ReturnType: types.Interval, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { i := d.(*DInterval).Duration i.Nanos = -i.Nanos i.Days = -i.Days @@ -172,14 +172,14 @@ var UnaryOps = map[UnaryOperator]unaryOpOverload{ UnaryOp{ Typ: types.Int, ReturnType: types.Int, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { return NewDInt(^MustBeDInt(d)), nil }, }, UnaryOp{ Typ: types.INet, ReturnType: types.INet, - fn: func(_ *EvalContext, d Datum) (Datum, error) { + Fn: func(_ *EvalContext, d Datum) (Datum, error) { ipAddr := MustBeDIPAddr(d).IPAddr return NewDIPAddr(DIPAddr{ipAddr.Complement()}), nil }, @@ -3728,7 +3728,7 @@ func (expr *UnaryExpr) Eval(ctx *EvalContext) (Datum, error) { if d == DNull { return DNull, nil } - res, err := expr.fn.fn(ctx, d) + res, err := expr.fn.Fn(ctx, d) if err != nil { return nil, err }