Skip to content

Commit

Permalink
Merge pull request #10605 from eisenstatdavid/basic-collated-string
Browse files Browse the repository at this point in the history
sql: collation support, phase one
  • Loading branch information
David Eisenstat authored Nov 14, 2016
2 parents 180ee1d + 82d8588 commit 172ae88
Show file tree
Hide file tree
Showing 11 changed files with 706 additions and 261 deletions.
2 changes: 2 additions & 0 deletions pkg/sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,8 @@ func checkResultType(typ parser.Type) error {
// Compare all types that cannot rely on == equality.
istype := typ.FamilyEqual
switch {
case istype(parser.TypeCollatedString):
return nil
case istype(parser.TypePlaceholder):
return errors.Errorf("could not determine data type of %s", typ)
case istype(parser.TypeTuple):
Expand Down
107 changes: 107 additions & 0 deletions pkg/sql/parser/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ import (
"time"
"unsafe"

"golang.org/x/text/collate"
"golang.org/x/text/language"

"gopkg.in/inf.v0"

"github.com/cockroachdb/cockroach/pkg/roachpb"
Expand Down Expand Up @@ -548,6 +551,110 @@ func (d *DString) Size() uintptr {
return unsafe.Sizeof(*d) + uintptr(len(*d))
}

// DCollatedString is the Datum for a string value tagged with a locale. Its
// fields are meant to be treated as immutable once the value has been built.
type DCollatedString struct {
	// Contents is the string value; Locale is the locale it is collated under.
	Contents, Locale string
	// key is the collation key computed for Contents under Locale. Ordering
	// between two collated strings is decided by comparing keys bytewise.
	key []byte
}

// CollationEnvironment stores the state needed by NewDCollatedString to
// construct collation keys efficiently. The zero value is usable: the cache
// map is allocated lazily by getCacheEntry.
type CollationEnvironment struct {
// cache maps a locale string to its entry (interned locale plus collator).
cache map[string]collationEnvironmentCacheEntry
// buffer is the scratch collate.Buffer that NewDCollatedString passes to the
// collator and resets once the key has been copied out.
buffer collate.Buffer
}

// collationEnvironmentCacheEntry is the per-locale value stored in
// CollationEnvironment.cache.
type collationEnvironmentCacheEntry struct {
// locale is interned: NewDCollatedString stores this string, rather than the
// caller's argument, in every DCollatedString it builds for the locale.
locale string
// collator is an expensive factory, so getCacheEntry builds it once per
// locale and reuses it from the cache afterwards.
collator *collate.Collator
}

// getCacheEntry returns the cache entry for locale. On a miss it allocates the
// cache map if necessary, builds a collator for the parsed locale, stores the
// new entry and returns it. language.MustParse is used for parsing, so an
// invalid locale panics.
func (env *CollationEnvironment) getCacheEntry(locale string) collationEnvironmentCacheEntry {
	if cached, found := env.cache[locale]; found {
		return cached
	}
	// Reading a nil map is fine, but it must exist before the store below.
	if env.cache == nil {
		env.cache = make(map[string]collationEnvironmentCacheEntry)
	}
	fresh := collationEnvironmentCacheEntry{locale, collate.New(language.MustParse(locale))}
	env.cache[locale] = fresh
	return fresh
}

// NewDCollatedString builds a *DCollatedString for contents under locale,
// using env to reuse the per-locale collator and the scratch buffer. The
// returned datum carries the interned locale string from env's cache.
//
// Panics if locale is invalid. Not safe for concurrent use, since env's cache
// and buffer are mutated without synchronization.
func NewDCollatedString(contents string, locale string, env *CollationEnvironment) *DCollatedString {
	cached := env.getCacheEntry(locale)
	scratch := cached.collator.KeyFromString(&env.buffer, contents)
	// Copy the key into its own slice before the buffer is reset, so the
	// datum does not depend on env.buffer afterwards.
	ownedKey := make([]byte, len(scratch))
	copy(ownedKey, scratch)
	env.buffer.Reset()
	return &DCollatedString{Contents: contents, Locale: cached.locale, key: ownedKey}
}

// Format implements the NodeFormatter interface. It writes Contents to buf via
// encodeSQLString; the Locale is not part of the output and the flags f are
// not consulted.
func (d *DCollatedString) Format(buf *bytes.Buffer, f FmtFlags) {
encodeSQLString(buf, d.Contents)
}

// ResolvedType implements the TypedExpr interface. It returns a
// TCollatedString built from the datum's Locale.
func (d *DCollatedString) ResolvedType() Type {
return TCollatedString{d.Locale}
}

// Compare implements the Datum interface. It returns 1 when other is NULL and
// otherwise the bytewise ordering of the two collation keys. It panics if
// other is not a *DCollatedString or carries a different Locale.
func (d *DCollatedString) Compare(other Datum) int {
	if other == DNull {
		// NULL is less than any non-NULL value.
		return 1
	}
	if rhs, isCollated := other.(*DCollatedString); isCollated && rhs.Locale == d.Locale {
		return bytes.Compare(d.key, rhs.key)
	}
	panic(makeUnsupportedComparisonMessage(d, other))
}

// HasPrev implements the Datum interface. It always reports false; Prev is
// unsupported for collated strings and panics.
func (*DCollatedString) HasPrev() bool {
return false
}

// Prev implements the Datum interface. It is unsupported for collated strings
// and always panics with the message from makeUnsupportedMethodMessage;
// HasPrev reports false accordingly.
func (d *DCollatedString) Prev() Datum {
panic(makeUnsupportedMethodMessage(d, "Prev"))
}

// HasNext implements the Datum interface. It always reports false; Next is
// unsupported for collated strings and panics.
func (*DCollatedString) HasNext() bool {
return false
}

// Next implements the Datum interface. It is unsupported for collated strings
// and always panics with the message from makeUnsupportedMethodMessage;
// HasNext reports false accordingly.
func (d *DCollatedString) Next() Datum {
panic(makeUnsupportedMethodMessage(d, "Next"))
}

// IsMax implements the Datum interface. It always reports false, whatever the
// datum's contents.
func (*DCollatedString) IsMax() bool {
return false
}

// IsMin implements the Datum interface. It reports whether the datum's
// contents are the empty string.
func (d *DCollatedString) IsMin() bool {
	return len(d.Contents) == 0
}

// Size implements the Datum interface. The result is the size of the struct
// itself plus the byte lengths of the contents, the locale and the collation
// key.
func (d *DCollatedString) Size() uintptr {
	variable := len(d.Contents) + len(d.Locale) + len(d.key)
	return unsafe.Sizeof(*d) + uintptr(variable)
}

// DBytes is the bytes Datum. The underlying type is a string because we want
// the immutability, but this may contain arbitrary bytes.
type DBytes string
Expand Down
66 changes: 66 additions & 0 deletions pkg/sql/parser/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package parser

import (
"bytes"
"fmt"
"math"
"math/big"
Expand Down Expand Up @@ -883,6 +884,13 @@ var CmpOps = map[ComparisonOperator]cmpOpOverload{
return DBool(*left.(*DString) == *right.(*DString)), nil
},
},
CmpOp{
LeftType: TypeCollatedString,
RightType: TypeCollatedString,
fn: func(_ *EvalContext, left Datum, right Datum) (DBool, error) {
return DBool(bytes.Equal(left.(*DCollatedString).key, right.(*DCollatedString).key)), nil
},
},
CmpOp{
LeftType: TypeBytes,
RightType: TypeBytes,
Expand Down Expand Up @@ -1031,6 +1039,13 @@ var CmpOps = map[ComparisonOperator]cmpOpOverload{
return DBool(*left.(*DString) < *right.(*DString)), nil
},
},
CmpOp{
LeftType: TypeCollatedString,
RightType: TypeCollatedString,
fn: func(_ *EvalContext, left Datum, right Datum) (DBool, error) {
return DBool(bytes.Compare(left.(*DCollatedString).key, right.(*DCollatedString).key) < 0), nil
},
},
CmpOp{
LeftType: TypeBytes,
RightType: TypeBytes,
Expand Down Expand Up @@ -1178,6 +1193,13 @@ var CmpOps = map[ComparisonOperator]cmpOpOverload{
return DBool(*left.(*DString) <= *right.(*DString)), nil
},
},
CmpOp{
LeftType: TypeCollatedString,
RightType: TypeCollatedString,
fn: func(_ *EvalContext, left Datum, right Datum) (DBool, error) {
return DBool(bytes.Compare(left.(*DCollatedString).key, right.(*DCollatedString).key) <= 0), nil
},
},
CmpOp{
LeftType: TypeBytes,
RightType: TypeBytes,
Expand Down Expand Up @@ -1323,6 +1345,7 @@ var CmpOps = map[ComparisonOperator]cmpOpOverload{
makeEvalTupleIn(TypeFloat),
makeEvalTupleIn(TypeDecimal),
makeEvalTupleIn(TypeString),
makeEvalTupleIn(TypeCollatedString),
makeEvalTupleIn(TypeBytes),
makeEvalTupleIn(TypeDate),
makeEvalTupleIn(TypeTimestamp),
Expand Down Expand Up @@ -1476,6 +1499,8 @@ type EvalContext struct {
// (false) or not (true). It is set to true conditionally by
// EXPLAIN(TYPES[, NORMALIZE]).
SkipNormalize bool

collationEnv CollationEnvironment
}

// GetStmtTimestamp retrieves the current statement timestamp as per
Expand Down Expand Up @@ -1683,6 +1708,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
return MakeDBool(v.Sign() != 0), nil
case *DString:
return ParseDBool(string(*v))
case *DCollatedString:
return ParseDBool(v.Contents)
}

case *IntColType:
Expand Down Expand Up @@ -1710,6 +1737,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
return NewDInt(DInt(i)), nil
case *DString:
return ParseDInt(string(*v))
case *DCollatedString:
return ParseDInt(v.Contents)
case *DTimestamp:
return NewDInt(DInt(v.Unix())), nil
case *DTimestampTZ:
Expand Down Expand Up @@ -1739,6 +1768,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
return NewDFloat(DFloat(f)), nil
case *DString:
return ParseDFloat(string(*v))
case *DCollatedString:
return ParseDFloat(v.Contents)
case *DTimestamp:
micros := float64(v.Nanosecond() / int(time.Microsecond))
return NewDFloat(DFloat(float64(v.Unix()) + micros*1e-6)), nil
Expand Down Expand Up @@ -1776,6 +1807,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
return d, nil
case *DString:
return ParseDDecimal(string(*v))
case *DCollatedString:
return ParseDDecimal(v.Contents)
case *DTimestamp:
var res DDecimal
val := res.UnscaledBig()
Expand Down Expand Up @@ -1809,6 +1842,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
s = DString(d.String())
case *DString:
s = *t
case *DCollatedString:
s = DString(t.Contents)
case *DBytes:
if !utf8.ValidString(string(*t)) {
return nil, fmt.Errorf("invalid utf8: %q", string(*t))
Expand All @@ -1828,6 +1863,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
switch t := d.(type) {
case *DString:
return NewDBytes(DBytes(*t)), nil
case *DCollatedString:
return NewDBytes(DBytes(t.Contents)), nil
case *DBytes:
return d, nil
}
Expand All @@ -1836,6 +1873,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
switch d := d.(type) {
case *DString:
return ParseDDate(string(*d), ctx.GetLocation())
case *DCollatedString:
return ParseDDate(d.Contents, ctx.GetLocation())
case *DDate:
return d, nil
case *DInt:
Expand All @@ -1851,6 +1890,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
switch d := d.(type) {
case *DString:
return ParseDTimestamp(string(*d), time.Microsecond)
case *DCollatedString:
return ParseDTimestamp(d.Contents, time.Microsecond)
case *DDate:
year, month, day := time.Unix(int64(*d)*secondsInDay, 0).UTC().Date()
return MakeDTimestamp(time.Date(year, month, day, 0, 0, 0, 0, time.UTC), time.Microsecond), nil
Expand All @@ -1867,6 +1908,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
switch d := d.(type) {
case *DString:
return ParseDTimestampTZ(string(*d), ctx.GetLocation(), time.Microsecond)
case *DCollatedString:
return ParseDTimestampTZ(d.Contents, ctx.GetLocation(), time.Microsecond)
case *DDate:
year, month, day := time.Unix(int64(*d)*secondsInDay, 0).UTC().Date()
return MakeDTimestampTZ(time.Date(year, month, day, 0, 0, 0, 0, ctx.GetLocation()), time.Microsecond), nil
Expand All @@ -1883,6 +1926,8 @@ func (expr *CastExpr) Eval(ctx *EvalContext) (Datum, error) {
switch v := d.(type) {
case *DString:
return ParseDInterval(string(*v))
case *DCollatedString:
return ParseDInterval(v.Contents)
case *DInt:
// An integer duration represents a duration in microseconds.
return &DInterval{Duration: duration.Duration{Nanos: int64(*v) * 1000}}, nil
Expand All @@ -1899,6 +1944,22 @@ func (expr *AnnotateTypeExpr) Eval(ctx *EvalContext) (Datum, error) {
return expr.Expr.(TypedExpr).Eval(ctx)
}

// Eval implements the TypedExpr interface. It evaluates the operand and
// returns its string contents as a new collated string under expr.Locale,
// with the collation key computed through the context's collation
// environment.
//
// An error from evaluating the operand is returned alongside DNull. A NULL
// operand yields NULL. Any operand that is neither NULL, a *DString nor a
// *DCollatedString panics.
func (expr *CollateExpr) Eval(ctx *EvalContext) (Datum, error) {
	d, err := expr.Expr.(TypedExpr).Eval(ctx)
	if err != nil {
		return DNull, err
	}
	if d == DNull {
		// A string-typed operand can still evaluate to NULL at run time (for
		// example a NULL column value); COLLATE of NULL is NULL, not an
		// invalid argument.
		return DNull, nil
	}
	switch d := d.(type) {
	case *DString:
		return NewDCollatedString(string(*d), expr.Locale, &ctx.collationEnv), nil
	case *DCollatedString:
		// Re-collate under the requested locale; the key must be recomputed.
		return NewDCollatedString(d.Contents, expr.Locale, &ctx.collationEnv), nil
	default:
		panic(fmt.Sprintf("invalid argument to COLLATE: %s", d))
	}
}

// Eval implements the TypedExpr interface.
func (expr *CoalesceExpr) Eval(ctx *EvalContext) (Datum, error) {
for _, e := range expr.Exprs {
Expand Down Expand Up @@ -2281,6 +2342,11 @@ func (t *DString) Eval(_ *EvalContext) (Datum, error) {
return t, nil
}

// Eval implements the TypedExpr interface. A collated string datum evaluates
// to itself, so the receiver is returned unchanged with a nil error; the
// context is not used.
func (t *DCollatedString) Eval(_ *EvalContext) (Datum, error) {
return t, nil
}

// Eval implements the TypedExpr interface.
func (t *DTimestamp) Eval(_ *EvalContext) (Datum, error) {
return t, nil
Expand Down
17 changes: 17 additions & 0 deletions pkg/sql/parser/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,21 @@ func (node *AnnotateTypeExpr) annotationType() Type {
return typ
}

// CollateExpr represents an (expr COLLATE locale) expression.
type CollateExpr struct {
// Expr is the operand being collated.
Expr Expr
// Locale is the locale name; Format writes it verbatim after " COLLATE ".
Locale string

// typeAnnotation is embedded; presumably it carries the expression's
// resolved type once type checking has run — TODO confirm.
typeAnnotation
}

// Format implements the NodeFormatter interface. It writes the operand using
// the given flags, followed by " COLLATE " and the locale exactly as stored.
func (node *CollateExpr) Format(buf *bytes.Buffer, f FmtFlags) {
	FormatNode(buf, f, node.Expr)
	buf.WriteString(" COLLATE " + node.Locale)
}

// String returns the node rendered by AsString; the one-line String methods
// that follow do the same for their own node types.
func (node *AliasedTableExpr) String() string { return AsString(node) }
func (node *ParenTableExpr) String() string { return AsString(node) }
func (node *JoinTableExpr) String() string { return AsString(node) }
Expand All @@ -1053,6 +1068,7 @@ func (node *BinaryExpr) String() string { return AsString(node) }
// String returns the node rendered by AsString; the one-line String methods
// that follow do the same for their own expression and datum types.
func (node *CaseExpr) String() string { return AsString(node) }
func (node *CastExpr) String() string { return AsString(node) }
func (node *CoalesceExpr) String() string { return AsString(node) }
func (node *CollateExpr) String() string { return AsString(node) }
func (node *ComparisonExpr) String() string { return AsString(node) }
func (node *DBool) String() string { return AsString(node) }
func (node *DBytes) String() string { return AsString(node) }
Expand All @@ -1062,6 +1078,7 @@ func (node *DFloat) String() string { return AsString(node) }
// String returns the datum rendered by AsString; the one-line String methods
// that follow do the same for their own datum types.
func (node *DInt) String() string { return AsString(node) }
func (node *DInterval) String() string { return AsString(node) }
func (node *DString) String() string { return AsString(node) }
func (node *DCollatedString) String() string { return AsString(node) }
func (node *DTimestamp) String() string { return AsString(node) }
func (node *DTimestampTZ) String() string { return AsString(node) }
func (node *DTuple) String() string { return AsString(node) }
Expand Down
Loading

0 comments on commit 172ae88

Please sign in to comment.