Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

exec: support comparisons between all numeric types #39778

Merged
merged 2 commits into from
Aug 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions pkg/col/coltypes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,30 @@ var AllTypes []T
// type in a binary expression.
var CompatibleTypes map[T][]T

// NumberTypes is a slice containing all numeric types.
var NumberTypes = []T{Int8, Int16, Int32, Int64, Float32, Float64, Decimal}

// IntTypes is a slice containing all int types.
var IntTypes = []T{Int8, Int16, Int32, Int64}

// FloatTypes is a slice containing all float types.
var FloatTypes = []T{Float32, Float64}

func init() {
for i := Bool; i < Unhandled; i++ {
AllTypes = append(AllTypes, i)
}

intTypes := []T{Int8, Int16, Int32, Int64}
floatTypes := []T{Float32, Float64}

CompatibleTypes = make(map[T][]T)
CompatibleTypes[Bool] = append(CompatibleTypes[Bool], Bool)
CompatibleTypes[Bytes] = append(CompatibleTypes[Bytes], Bytes)
CompatibleTypes[Decimal] = append(CompatibleTypes[Decimal], Decimal)
CompatibleTypes[Int8] = append(CompatibleTypes[Int8], intTypes...)
CompatibleTypes[Int16] = append(CompatibleTypes[Int16], intTypes...)
CompatibleTypes[Int32] = append(CompatibleTypes[Int32], intTypes...)
CompatibleTypes[Int64] = append(CompatibleTypes[Int64], intTypes...)
CompatibleTypes[Float32] = append(CompatibleTypes[Float32], floatTypes...)
CompatibleTypes[Float64] = append(CompatibleTypes[Float64], floatTypes...)
CompatibleTypes[Decimal] = append(CompatibleTypes[Decimal], NumberTypes...)
CompatibleTypes[Int8] = append(CompatibleTypes[Int8], NumberTypes...)
CompatibleTypes[Int16] = append(CompatibleTypes[Int16], NumberTypes...)
CompatibleTypes[Int32] = append(CompatibleTypes[Int32], NumberTypes...)
CompatibleTypes[Int64] = append(CompatibleTypes[Int64], NumberTypes...)
CompatibleTypes[Float32] = append(CompatibleTypes[Float32], NumberTypes...)
CompatibleTypes[Float64] = append(CompatibleTypes[Float64], NumberTypes...)
}

// FromGoType returns the type for a Go value, if applicable. Shouldn't be used at
Expand Down
220 changes: 194 additions & 26 deletions pkg/sql/exec/execgen/cmd/execgen/overloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func init() {
sameTypeComparisonOpToOverloads = make(map[tree.ComparisonOperator][]*overload, len(comparisonOpName))
anyTypeComparisonOpToOverloads = make(map[tree.ComparisonOperator][]*overload, len(comparisonOpName))
for _, t := range inputTypes {
customizer := typeCustomizers[t]
customizer := typeCustomizers[coltypePair{t, t}]
for _, op := range binOps {
// Skip types that don't have associated binary ops.
switch t {
Expand Down Expand Up @@ -200,8 +200,8 @@ func init() {
hashOverloads = append(hashOverloads, ov)
}
for _, leftType := range inputTypes {
customizer := typeCustomizers[leftType]
for _, rightType := range coltypes.CompatibleTypes[leftType] {
customizer := typeCustomizers[coltypePair{leftType, rightType}]
for _, op := range cmpOps {
opStr := comparisonOpInfix[op]
ov := &overload{
Expand Down Expand Up @@ -258,14 +258,18 @@ func init() {
// (==, <, etc) or binary operator (+, -, etc) semantics.
type typeCustomizer interface{}

// TODO(rafi): make this map keyed by (leftType, rightType) so we can have
// customizers for mixed-type operations.
var typeCustomizers map[coltypes.T]typeCustomizer
// coltypePair is used to key a map that holds all typeCustomizers.
type coltypePair struct {
leftType coltypes.T
rightType coltypes.T
}

var typeCustomizers map[coltypePair]typeCustomizer

// registerTypeCustomizer registers a particular type customizer to a type, for
// usage by templates.
func registerTypeCustomizer(t coltypes.T, customizer typeCustomizer) {
typeCustomizers[t] = customizer
// registerTypeCustomizer registers a particular type customizer to a
// pair of types, for usage by templates.
func registerTypeCustomizer(pair coltypePair, customizer typeCustomizer) {
typeCustomizers[pair] = customizer
}

// binOpTypeCustomizer is a type customizer that changes how the templater
Expand Down Expand Up @@ -301,9 +305,33 @@ type decimalCustomizer struct{}
// floatCustomizers are used for hash functions.
type floatCustomizer struct{ width int }

// intCustomizers are used for hash functions.
// intCustomizers are used for hash functions and overflow handling.
type intCustomizer struct{ width int }

// decimalFloatCustomizer supports mixed type expressions with a decimal
// left-hand side and a float right-hand side.
type decimalFloatCustomizer struct{}

// decimalIntCustomizer supports mixed type expressions with a decimal left-hand
// side and an int right-hand side.
type decimalIntCustomizer struct{}

// floatDecimalCustomizer supports mixed type expressions with a float left-hand
// side and a decimal right-hand side.
type floatDecimalCustomizer struct{}

// intDecimalCustomizer supports mixed type expressions with an int left-hand
// side and a decimal right-hand side.
type intDecimalCustomizer struct{}

// floatIntCustomizer supports mixed type expressions with a float left-hand
// side and an int right-hand side.
type floatIntCustomizer struct{}

// intFloatCustomizer supports mixed type expressions with an int left-hand
// side and a float right-hand side.
type intFloatCustomizer struct{}

func (boolCustomizer) getCmpOpCompareFunc() compareFunc {
return func(target, l, r string) string {
args := map[string]string{"Target": target, "Left": l, "Right": r}
Expand Down Expand Up @@ -386,13 +414,20 @@ func (c floatCustomizer) getHashAssignFunc() assignFunc {
}
}

func (c floatCustomizer) getCmpOpCompareFunc() compareFunc {
func getFloatCmpOpCompareFunc(checkLeftNan, checkRightNan bool) compareFunc {
return func(target, l, r string) string {
args := map[string]string{"Target": target, "Left": l, "Right": r}
args := map[string]interface{}{
"Target": target,
"Left": l,
"Right": r,
"CheckLeftNan": checkLeftNan,
"CheckRightNan": checkRightNan}
buf := strings.Builder{}
// In SQL, NaN is treated as less than all other float values. In Go, any
// comparison with NaN returns false. To allow floats of different sizes to
// be compared, always upcast to float64.
// be compared, always upcast to float64. The CheckLeftNan and CheckRightNan
// flags skip NaN checks when the input is an int (which is necessary to
// pass linting.)
t := template.Must(template.New("").Parse(`
{
a, b := float64({{.Left}}), float64({{.Right}})
Expand All @@ -402,8 +437,8 @@ func (c floatCustomizer) getCmpOpCompareFunc() compareFunc {
{{.Target}} = 1
} else if a == b {
{{.Target}} = 0
} else if math.IsNaN(a) {
if math.IsNaN(b) {
} else if {{ if .CheckLeftNan }} math.IsNaN(a) {{ else }} false {{ end }} {
if {{ if .CheckRightNan }} math.IsNaN(b) {{ else }} false {{ end }} {
{{.Target}} = 0
} else {
{{.Target}} = -1
Expand All @@ -421,6 +456,10 @@ func (c floatCustomizer) getCmpOpCompareFunc() compareFunc {
}
}

func (c floatCustomizer) getCmpOpCompareFunc() compareFunc {
return getFloatCmpOpCompareFunc(true /* checkLeftNan */, true /* checkRightNan */)
}

func (c intCustomizer) getHashAssignFunc() assignFunc {
return func(op overload, target, v, _ string) string {
return fmt.Sprintf("%[1]s = memhash%[3]d(noescape(unsafe.Pointer(&%[2]s)), %[1]s)", target, v, c.width)
Expand Down Expand Up @@ -450,7 +489,6 @@ func (c intCustomizer) getCmpOpCompareFunc() compareFunc {
}
return buf.String()
}

}

func (c intCustomizer) getBinOpAssignFunc() assignFunc {
Expand Down Expand Up @@ -567,17 +605,147 @@ func (c intCustomizer) getBinOpAssignFunc() assignFunc {
}
}

func (c decimalFloatCustomizer) getCmpOpCompareFunc() compareFunc {
return func(target, l, r string) string {
args := map[string]string{"Target": target, "Left": l, "Right": r}
buf := strings.Builder{}
// todo(rafi): is there a way to avoid allocating on each comparison?
t := template.Must(template.New("").Parse(`
{
tmpDec := &apd.Decimal{}
if _, err := tmpDec.SetFloat64(float64({{.Right}})); err != nil {
execerror.NonVectorizedPanic(err)
}
{{.Target}} = tree.CompareDecimals(&{{.Left}}, tmpDec)
}
`))

if err := t.Execute(&buf, args); err != nil {
execerror.VectorizedInternalPanic(err)
}
return buf.String()
}
}

func (c decimalIntCustomizer) getCmpOpCompareFunc() compareFunc {
return func(target, l, r string) string {
args := map[string]string{"Target": target, "Left": l, "Right": r}
buf := strings.Builder{}
// todo(rafi): is there a way to avoid allocating on each comparison?
t := template.Must(template.New("").Parse(`
{
tmpDec := &apd.Decimal{}
tmpDec.SetFinite(int64({{.Right}}), 0)
{{.Target}} = tree.CompareDecimals(&{{.Left}}, tmpDec)
}
`))

if err := t.Execute(&buf, args); err != nil {
execerror.VectorizedInternalPanic(err)
}
return buf.String()
}
}

func (c floatDecimalCustomizer) getCmpOpCompareFunc() compareFunc {
return func(target, l, r string) string {
args := map[string]string{"Target": target, "Left": l, "Right": r}
buf := strings.Builder{}
// todo(rafi): is there a way to avoid allocating on each comparison?
t := template.Must(template.New("").Parse(`
{
tmpDec := &apd.Decimal{}
if _, err := tmpDec.SetFloat64(float64({{.Left}})); err != nil {
execerror.NonVectorizedPanic(err)
}
{{.Target}} = tree.CompareDecimals(tmpDec, &{{.Right}})
}
`))

if err := t.Execute(&buf, args); err != nil {
execerror.VectorizedInternalPanic(err)
}
return buf.String()
}
}

func (c intDecimalCustomizer) getCmpOpCompareFunc() compareFunc {
return func(target, l, r string) string {
args := map[string]string{"Target": target, "Left": l, "Right": r}
buf := strings.Builder{}
// todo(rafi): is there a way to avoid allocating on each comparison?
t := template.Must(template.New("").Parse(`
{
tmpDec := &apd.Decimal{}
tmpDec.SetFinite(int64({{.Left}}), 0)
{{.Target}} = tree.CompareDecimals(tmpDec, &{{.Right}})
}
`))

if err := t.Execute(&buf, args); err != nil {
execerror.VectorizedInternalPanic(err)
}
return buf.String()
}
}

func (c floatIntCustomizer) getCmpOpCompareFunc() compareFunc {
// floatCustomizer's comparison function can be reused since float-int
// comparison works by casting the int.
return getFloatCmpOpCompareFunc(true /* checkLeftNan */, false /* checkRightNan */)
}

func (c intFloatCustomizer) getCmpOpCompareFunc() compareFunc {
// floatCustomizer's comparison function can be reused since int-float
// comparison works by casting the int.
return getFloatCmpOpCompareFunc(false /* checkLeftNan */, true /* checkRightNan */)
}

func registerTypeCustomizers() {
typeCustomizers = make(map[coltypes.T]typeCustomizer)
registerTypeCustomizer(coltypes.Bool, boolCustomizer{})
registerTypeCustomizer(coltypes.Bytes, bytesCustomizer{})
registerTypeCustomizer(coltypes.Decimal, decimalCustomizer{})
registerTypeCustomizer(coltypes.Float32, floatCustomizer{width: 32})
registerTypeCustomizer(coltypes.Float64, floatCustomizer{width: 64})
registerTypeCustomizer(coltypes.Int8, intCustomizer{width: 8})
registerTypeCustomizer(coltypes.Int16, intCustomizer{width: 16})
registerTypeCustomizer(coltypes.Int32, intCustomizer{width: 32})
registerTypeCustomizer(coltypes.Int64, intCustomizer{width: 64})
typeCustomizers = make(map[coltypePair]typeCustomizer)
registerTypeCustomizer(coltypePair{coltypes.Bool, coltypes.Bool}, boolCustomizer{})
registerTypeCustomizer(coltypePair{coltypes.Bytes, coltypes.Bytes}, bytesCustomizer{})
registerTypeCustomizer(coltypePair{coltypes.Decimal, coltypes.Decimal}, decimalCustomizer{})
for _, leftFloatType := range coltypes.FloatTypes {
for _, rightFloatType := range coltypes.FloatTypes {
registerTypeCustomizer(coltypePair{leftFloatType, rightFloatType}, floatCustomizer{width: 64})
}
}
for _, leftIntType := range coltypes.IntTypes {
for _, rightIntType := range coltypes.IntTypes {
registerTypeCustomizer(coltypePair{leftIntType, rightIntType}, intCustomizer{width: 64})
}
}
// Use a customizer of appropriate width when widths are the same.
registerTypeCustomizer(coltypePair{coltypes.Float32, coltypes.Float32}, floatCustomizer{width: 32})
registerTypeCustomizer(coltypePair{coltypes.Float64, coltypes.Float64}, floatCustomizer{width: 64})
registerTypeCustomizer(coltypePair{coltypes.Int8, coltypes.Int8}, intCustomizer{width: 8})
registerTypeCustomizer(coltypePair{coltypes.Int16, coltypes.Int16}, intCustomizer{width: 16})
registerTypeCustomizer(coltypePair{coltypes.Int32, coltypes.Int32}, intCustomizer{width: 32})
registerTypeCustomizer(coltypePair{coltypes.Int64, coltypes.Int64}, intCustomizer{width: 64})

for _, rightFloatType := range coltypes.FloatTypes {
registerTypeCustomizer(coltypePair{coltypes.Decimal, rightFloatType}, decimalFloatCustomizer{})
}
for _, rightIntType := range coltypes.IntTypes {
registerTypeCustomizer(coltypePair{coltypes.Decimal, rightIntType}, decimalIntCustomizer{})
}
for _, leftFloatType := range coltypes.FloatTypes {
registerTypeCustomizer(coltypePair{leftFloatType, coltypes.Decimal}, floatDecimalCustomizer{})
}
for _, leftIntType := range coltypes.IntTypes {
registerTypeCustomizer(coltypePair{leftIntType, coltypes.Decimal}, intDecimalCustomizer{})
}
for _, leftFloatType := range coltypes.FloatTypes {
for _, rightIntType := range coltypes.IntTypes {
registerTypeCustomizer(coltypePair{leftFloatType, rightIntType}, floatIntCustomizer{})
}
}
for _, leftIntType := range coltypes.IntTypes {
for _, rightFloatType := range coltypes.FloatTypes {
registerTypeCustomizer(coltypePair{leftIntType, rightFloatType}, intFloatCustomizer{})
}
}
}

// Avoid unused warning for functions which are only used in templates.
Expand Down
29 changes: 12 additions & 17 deletions pkg/sql/exec/execgen/cmd/execgen/overloads_test_utils_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,31 @@ package main
import (
"io"
"text/template"

"github.com/cockroachdb/cockroach/pkg/col/coltypes"
)

const overloadsTestUtilsTemplate = `
package exec

import (
"math"
"bytes"
"math"

"github.com/cockroachdb/apd"
"github.com/cockroachdb/cockroach/pkg/sql/exec/execerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
)

{{define "opName"}}perform{{.Name}}{{.LTyp}}{{end}}
{{define "opName"}}perform{{.Name}}{{.LTyp}}{{.RTyp}}{{end}}

{{/* The outer range is a coltypes.T, and the inner is the overloads associated
with that type. */}}
{{range .}}
{{/* The range is over all overloads */}}
{{range .}}

func {{template "opName" .}}(a, b {{.LTyp.GoTypeName}}) {{.RetTyp.GoTypeName}} {
{{(.Assign "a" "a" "b")}}
return a
func {{template "opName" .}}(a {{.LTyp.GoTypeName}}, b {{.RTyp.GoTypeName}}) {{.RetTyp.GoTypeName}} {
var r {{.RetTyp.GoTypeName}}
{{(.Assign "r" "a" "b")}}
return r
}

{{end}}
{{end}}
`

Expand All @@ -53,12 +50,10 @@ func genOverloadsTestUtils(wr io.Writer) error {
return err
}

typToOverloads := make(map[coltypes.T][]*overload)
for _, overload := range binaryOpOverloads {
typ := overload.LTyp
typToOverloads[typ] = append(typToOverloads[typ], overload)
}
return tmpl.Execute(wr, typToOverloads)
allOverloads := make([]*overload, 0)
allOverloads = append(allOverloads, binaryOpOverloads...)
allOverloads = append(allOverloads, comparisonOpOverloads...)
return tmpl.Execute(wr, allOverloads)
}

func init() {
Expand Down
Loading