Skip to content

Commit

Permalink
coltypes: don't shallow copy decimals
Browse files Browse the repository at this point in the history
This commit fixes a correctness bug caused by illegal shallow copy of
decimals. To copy a decimal, you must use the .Set method on it, and its
memory must already be distinct from the source decimal.

To facilitate this and generally simplify things, the "safe" GET method
is removed and replaced by COPYVAL, which copies a scalar value to
another scalar value safely. There were fewer than 5 users of GET, so
this wasn't particularly disruptive.

Release note: None
  • Loading branch information
jordanlewis committed Aug 30, 2019
1 parent f81ac2b commit 8fde2d8
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 54 deletions.
141 changes: 118 additions & 23 deletions pkg/col/coltypes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ package coltypes

import (
"fmt"
"strings"
"text/template"

"github.com/cockroachdb/apd"
)
Expand Down Expand Up @@ -143,6 +145,7 @@ var (
_ = Bool.Swap
_ = Bool.Slice
_ = Bool.CopySlice
_ = Bool.CopyVal
_ = Bool.AppendSlice
_ = Bool.AppendVal
_ = Bool.Len
Expand All @@ -159,34 +162,63 @@ func (t T) GoTypeSliceName() string {
}

// Get is a function that should only be used in templates.
func (t T) Get(safe string, target, i string) string {
if safe != "safe" && safe != "unsafe" {
panic(fmt.Sprintf("unknown safe argument %s, use either safe or unsafe", safe))
}
if t == Bytes {
getString := fmt.Sprintf("%s.Get(%s)", target, i)
if safe == "safe" {
getString = "append([]byte(nil), " + getString + "...)"
}
return getString
func (t T) Get(target, i string) string {
switch t {
case Bytes:
return fmt.Sprintf("%s.Get(%s)", target, i)
}
return fmt.Sprintf("%s[%s]", target, i)
}

// CopyVal is a function that should only be used in templates.
func (t T) CopyVal(dest, src string) string {
switch t {
case Bytes:
return fmt.Sprintf("%[1]s = append(%[1]s[:0], %[2]s...)", dest, src)
case Decimal:
return fmt.Sprintf("%s.Set(&%s)", dest, src)
}
return fmt.Sprintf("%s = %s", dest, src)
}

// Set is a function that should only be used in templates.
func (t T) Set(target, i, new string) string {
if t == Bytes {
switch t {
case Bytes:
return fmt.Sprintf("%s.Set(%s, %s)", target, i, new)
case Decimal:
return fmt.Sprintf("%s[%s].Set(&%s)", target, i, new)
}
return fmt.Sprintf("%s[%s] = %s", target, i, new)
}

// Swap is a function that should only be used in templates.
func (t T) Swap(target, i, j string) string {
if t == Bytes {
return fmt.Sprintf("%s.Swap(%s, %s)", target, i, j)
var tmpl string
switch t {
case Bytes:
tmpl = `{{.Tgt}}.Swap({{.I}}, {{.J}})`
case Decimal:
tmpl = `
{
var __tmp apd.Decimal
__tmp.Set(&{{.Tgt}}[{{.I}}])
{{.Tgt}}[{{.I}}].Set(&{{.Tgt}}[{{.J}}])
{{.Tgt}}[{{.J}}].Set(&{{.Tgt}}[{{.I}}])
}`
default:
tmpl = `{{.Tgt}}[{{.I}}], {{.Tgt}}[{{.J}}] = {{.Tgt}}[{{.J}}], {{.Tgt}}[{{.I}}]`
}
args := map[string]string{
"Tgt": target,
"I": j,
"J": j,
}
return fmt.Sprintf("%[1]s[%[2]s], %[1]s[%[3]s] = %[1]s[%[3]s], %[1]s[%[2]s]", target, i, j)
var buf strings.Builder
if err := template.Must(template.New("").Parse(tmpl)).Execute(&buf, args); err != nil {
panic(err)
}
return buf.String()
}

// Slice is a function that should only be used in templates.
Expand All @@ -199,26 +231,84 @@ func (t T) Slice(target, start, end string) string {

// CopySlice is a function that should only be used in templates.
func (t T) CopySlice(target, src, destIdx, srcStartIdx, srcEndIdx string) string {
if t == Bytes {
return fmt.Sprintf("%s.CopySlice(%s, %s, %s, %s)", target, src, destIdx, srcStartIdx, srcEndIdx)
var tmpl string
switch t {
case Bytes:
tmpl = `{{.Tgt}}.CopySlice({{.Src}}, {{.TgtIdx}}, {{.SrcStart}}, {{.SrcEnd}})`
case Decimal:
tmpl = `{
__tgt_slice := {{.Tgt}}[{{.TgtIdx}}:]
__src_slice := {{.Src}}[{{.SrcStart}}:{{.SrcEnd}}]
for __i := range __src_slice {
__tgt_slice[__i].Set(&__src_slice[__i])
}
}`
default:
tmpl = `copy({{.Tgt}}[{{.TgtIdx}}:], {{.Src}}[{{.SrcStart}}:{{.SrcEnd}}])`
}
args := map[string]string{
"Tgt": target,
"Src": src,
"TgtIdx": destIdx,
"SrcStart": srcStartIdx,
"SrcEnd": srcEndIdx,
}
var buf strings.Builder
if err := template.Must(template.New("").Parse(tmpl)).Execute(&buf, args); err != nil {
panic(err)
}
return fmt.Sprintf("copy(%s[%s:], %s[%s:%s])", target, destIdx, src, srcStartIdx, srcEndIdx)
return buf.String()
}

// AppendSlice is a function that should only be used in templates.
func (t T) AppendSlice(target, src, destIdx, srcStartIdx, srcEndIdx string) string {
if t == Bytes {
return fmt.Sprintf("%s.AppendSlice(%s, %s, %s, %s)", target, src, destIdx, srcStartIdx, srcEndIdx)
var tmpl string
switch t {
case Bytes:
tmpl = `{{.Tgt}}.AppendSlice({{.Src}}, {{.TgtIdx}}, {{.SrcStart}}, {{.SrcEnd}})`
case Decimal:
tmpl = `{
__desiredCap := {{.TgtIdx}} + {{.SrcEnd}} - {{.SrcStart}}
if cap({{.Tgt}}) >= __desiredCap {
{{.Tgt}} = {{.Tgt}}[:__desiredCap]
} else {
__new_slice := make([]apd.Decimal, __desiredCap)
copy(__new_slice, {{.Tgt}})
{{.Tgt}} = __new_slice
}
__src_slice := {{.Src}}[{{.SrcStart}}:{{.SrcEnd}}]
__dst_slice := {{.Tgt}}[{{.TgtIdx}}:]
for __i := range __src_slice {
__dst_slice[__i].Set(&__src_slice[__i])
}
}`
default:
tmpl = `{{.Tgt}} = append({{.Tgt}}[:{{.TgtIdx}}], {{.Src}}[{{.SrcStart}}:{{.SrcEnd}}]...)`
}
return fmt.Sprintf("%[1]s = append(%[1]s[:%s], %s[%s:%s]...)", target, destIdx, src, srcStartIdx, srcEndIdx)
args := map[string]string{
"Tgt": target,
"Src": src,
"TgtIdx": destIdx,
"SrcStart": srcStartIdx,
"SrcEnd": srcEndIdx,
}
var buf strings.Builder
if err := template.Must(template.New("").Parse(tmpl)).Execute(&buf, args); err != nil {
panic(err)
}
return buf.String()
}

// AppendVal is a function that should only be used in templates.
func (t T) AppendVal(target, v string) string {
if t == Bytes {
switch t {
case Bytes:
return fmt.Sprintf("%s.AppendVal(%s)", target, v)
case Decimal:
return fmt.Sprintf(`%[1]s = append(%[1]s, apd.Decimal{})
%[1]s[len(%[1]s)-1].Set(&%[2]s)`, target, v)
}
return fmt.Sprintf("%[1]s = append(%[1]s, %s)", target, v)
return fmt.Sprintf("%[1]s = append(%[1]s, %[2]s)", target, v)
}

// Len is a function that should only be used in templates.
Expand All @@ -239,8 +329,13 @@ func (t T) Range(loopVariableIdent string, target string) string {

// Zero is a function that should only be used in templates.
func (t T) Zero(target string) string {
if t == Bytes {
switch t {
case Bytes:
return fmt.Sprintf("%s.Zero()", target)
case Decimal:
return fmt.Sprintf(`for n := 0; n < len(%[1]s); n++ {
%[1]s[n].SetInt64(0)
}`, target)
}
return fmt.Sprintf("for n := 0; n < len(%[1]s); n += copy(%[1]s[n:], zero%sColumn) {}", target, t.String())
}
2 changes: 1 addition & 1 deletion pkg/sql/exec/any_not_null_agg_tmpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func _FIND_ANY_NOT_NULL(a *anyNotNull_TYPEAgg, nulls *coldata.Nulls, i int, _HAS
// Explicit template language is used here because the type receiver differs
// from the rest of the template file.
// TODO(asubiotto): Figure out a way to alias this.
// v := {{ .Global.Get "unsafe" "col" "int(i)" }}
// v := {{ .Global.Get "col" "int(i)" }}
// {{ .Global.Set "a.vec" "a.curIdx" "v" }}
a.foundNonNullForCurrentGroup = true
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/sql/exec/distinct_tmpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,11 @@ func _CHECK_DISTINCT(
) { // */}}

// {{define "checkDistinct"}}
v := execgen.GET(col, int(checkIdx))
v := execgen.UNSAFEGET(col, int(checkIdx))
var unique bool
_ASSIGN_NE(unique, v, lastVal)
outputCol[outputIdx] = outputCol[outputIdx] || unique
lastVal = v
execgen.COPYVAL(lastVal, v)
// {{end}}

// {{/*
Expand All @@ -388,7 +388,7 @@ func _CHECK_DISTINCT_WITH_NULLS(

// {{define "checkDistinctWithNulls"}}
null := nulls.NullAt(uint16(checkIdx))
v := execgen.GET(col, int(checkIdx))
v := execgen.UNSAFEGET(col, int(checkIdx))
if null != lastValNull {
// Either the current value is null and the previous was not or vice-versa.
outputCol[outputIdx] = true
Expand All @@ -398,7 +398,7 @@ func _CHECK_DISTINCT_WITH_NULLS(
_ASSIGN_NE(unique, v, lastVal)
outputCol[outputIdx] = outputCol[outputIdx] || unique
}
lastVal = v
execgen.COPYVAL(lastVal, v)
lastValNull = null
// {{end}}

Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/exec/execgen/cmd/execgen/data_manipulation_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ var dataManipulationReplacementInfos = []dataManipulationReplacementInfo{
{
templatePlaceholder: "execgen.UNSAFEGET",
numArgs: 2,
replaceWith: "Get \"unsafe\"",
replaceWith: "Get",
},
{
templatePlaceholder: "execgen.GET",
templatePlaceholder: "execgen.COPYVAL",
numArgs: 2,
replaceWith: "Get \"safe\"",
replaceWith: "CopyVal",
},
{
templatePlaceholder: "execgen.SET",
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/exec/execgen/cmd/execgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ func (g *execgen) run(args ...string) bool {
return true
}

var emptyCommentRegex = regexp.MustCompile(`[ \t]*//[ \t]*\n`)
var emptyBlockCommentRegex = regexp.MustCompile(`[ \t]*/\*[ \t]*\*/[ \t]*\n`)
var emptyCommentRegex = regexp.MustCompile(`[ \t]*//[ \t]*$`)
var emptyBlockCommentRegex = regexp.MustCompile(`[ \t]*/\*[ \t]*\*/[ \t]*$`)

func (g *execgen) generate(genFunc generator, out string) error {
var buf bytes.Buffer
Expand Down
16 changes: 8 additions & 8 deletions pkg/sql/exec/execgen/cmd/execgen/projection_ops_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ func (p {{template "opRConstName" .}}) Next(ctx context.Context) coldata.Batch {
projCol := projVec.{{.RetTyp}}()
if sel := batch.Selection(); sel != nil {
for _, i := range sel {
arg := {{.LTyp.Get "unsafe" "col" "int(i)"}}
arg := {{.LTyp.Get "col" "int(i)"}}
{{(.Assign "projCol[i]" "arg" "p.constArg")}}
}
} else {
col = {{.LTyp.Slice "col" "0" "int(n)"}}
colLen := {{.LTyp.Len "col"}}
_ = projCol[colLen-1]
for {{.LTyp.Range "i" "col"}} {
arg := {{.LTyp.Get "unsafe" "col" "i"}}
arg := {{.LTyp.Get "col" "i"}}
{{(.Assign "projCol[i]" "arg" "p.constArg")}}
}
}
Expand Down Expand Up @@ -121,15 +121,15 @@ func (p {{template "opLConstName" .}}) Next(ctx context.Context) coldata.Batch {
projCol := projVec.{{.RetTyp}}()
if sel := batch.Selection(); sel != nil {
for _, i := range sel {
arg := {{.RTyp.Get "unsafe" "col" "int(i)"}}
arg := {{.RTyp.Get "col" "int(i)"}}
{{(.Assign "projCol[i]" "p.constArg" "arg")}}
}
} else {
col = {{.RTyp.Slice "col" "0" "int(n)"}}
colLen := {{.RTyp.Len "col"}}
_ = projCol[colLen-1]
for {{.RTyp.Range "i" "col"}} {
arg := {{.RTyp.Get "unsafe" "col" "i"}}
arg := {{.RTyp.Get "col" "i"}}
{{(.Assign "projCol[i]" "p.constArg" "arg")}}
}
}
Expand Down Expand Up @@ -176,8 +176,8 @@ func (p {{template "opName" .}}) Next(ctx context.Context) coldata.Batch {
col2 := vec2.{{.RTyp}}()
if sel := batch.Selection(); sel != nil {
for _, i := range sel {
arg1 := {{.LTyp.Get "unsafe" "col1" "int(i)"}}
arg2 := {{.RTyp.Get "unsafe" "col2" "int(i)"}}
arg1 := {{.LTyp.Get "col1" "int(i)"}}
arg2 := {{.RTyp.Get "col2" "int(i)"}}
{{(.Assign "projCol[i]" "arg1" "arg2")}}
}
} else {
Expand All @@ -186,8 +186,8 @@ func (p {{template "opName" .}}) Next(ctx context.Context) coldata.Batch {
_ = projCol[colLen-1]
_ = {{.LTyp.Slice "col2" "0" "colLen-1"}}
for {{.LTyp.Range "i" "col1"}} {
arg1 := {{.LTyp.Get "unsafe" "col1" "i"}}
arg2 := {{.LTyp.Get "unsafe" "col2" "i"}}
arg1 := {{.LTyp.Get "col1" "i"}}
arg2 := {{.LTyp.Get "col2" "i"}}
{{(.Assign "projCol[i]" "arg1" "arg2")}}
}
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/sql/exec/execgen/cmd/execgen/selection_ops_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ if sel := batch.Selection(); sel != nil {
sel = sel[:n]
for _, i := range sel {
var cmp bool
arg := {{.Global.LTyp.Get "unsafe" "col" "int(i)"}}
arg := {{.Global.LTyp.Get "col" "int(i)"}}
{{(.Global.Assign "cmp" "arg" "p.constArg")}}
if cmp {{if .HasNulls}}&& !nulls.NullAt(i) {{end}}{
sel[idx] = i
Expand All @@ -56,7 +56,7 @@ if sel := batch.Selection(); sel != nil {
col = {{.Global.LTyp.Slice "col" "0" "int(n)"}}
for {{.Global.LTyp.Range "i" "col"}} {
var cmp bool
arg := {{.Global.LTyp.Get "unsafe" "col" "i"}}
arg := {{.Global.LTyp.Get "col" "i"}}
{{(.Global.Assign "cmp" "arg" "p.constArg")}}
if cmp {{if .HasNulls}}&& !nulls.NullAt(uint16(i)) {{end}}{
sel[idx] = uint16(i)
Expand All @@ -71,8 +71,8 @@ if sel := batch.Selection(); sel != nil {
sel = sel[:n]
for _, i := range sel {
var cmp bool
arg1 := {{.Global.LTyp.Get "unsafe" "col1" "int(i)"}}
arg2 := {{.Global.RTyp.Get "unsafe" "col2" "int(i)"}}
arg1 := {{.Global.LTyp.Get "col1" "int(i)"}}
arg2 := {{.Global.RTyp.Get "col2" "int(i)"}}
{{(.Global.Assign "cmp" "arg1" "arg2")}}
if cmp {{if .HasNulls}}&& !nulls.NullAt(i) {{end}}{
sel[idx] = i
Expand All @@ -87,8 +87,8 @@ if sel := batch.Selection(); sel != nil {
col2 = {{.Global.RTyp.Slice "col2" "0" "col1Len"}}
for {{.Global.LTyp.Range "i" "col1"}} {
var cmp bool
arg1 := {{.Global.LTyp.Get "unsafe" "col1" "i"}}
arg2 := {{.Global.RTyp.Get "unsafe" "col2" "i"}}
arg1 := {{.Global.LTyp.Get "col1" "i"}}
arg2 := {{.Global.RTyp.Get "col2" "i"}}
{{(.Global.Assign "cmp" "arg1" "arg2")}}
if cmp {{if .HasNulls}}&& !nulls.NullAt(uint16(i)) {{end}}{
sel[idx] = uint16(i)
Expand Down
12 changes: 6 additions & 6 deletions pkg/sql/exec/execgen/placeholders.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const nonTemplatePanic = "do not call from non-template code"
// Remove unused warnings.
var (
_ = UNSAFEGET
_ = GET
_ = COPYVAL
_ = SET
_ = SWAP
_ = SLICE
Expand All @@ -36,12 +36,12 @@ func UNSAFEGET(target, i interface{}) interface{} {
return nil
}

// GET is a template function. The Bytes implementation of this function
// performs a copy. Use this if you need to keep values around past the
// lifecycle of a Batch.
func GET(target, i interface{}) interface{} {
// COPYVAL is a template function that can be used to set a scalar to the value
// of another scalar in such a way that the destination won't be modified if the
// source is. You must use this on the result of UNSAFEGET if you wish to store
// that result past the lifetime of the batch you UNSAFEGET'd from.
func COPYVAL(dest, src interface{}) {
execerror.VectorizedInternalPanic(nonTemplatePanic)
return nil
}

// SET is a template function.
Expand Down
Loading

0 comments on commit 8fde2d8

Please sign in to comment.