diff --git a/src/executor/engine/join_engine.go b/src/executor/engine/join_engine.go index 9c70b68e..e42d20c4 100644 --- a/src/executor/engine/join_engine.go +++ b/src/executor/engine/join_engine.go @@ -101,7 +101,7 @@ func (j *JoinEngine) Execute(ctx *xcontext.ResultContext) error { return operator.ExecSubPlan(j.log, j.node, ctx) } -// execBindVars used to execute querys with bindvas. +// execBindVars used to execute querys with bindvars. func (j *JoinEngine) execBindVars(ctx *xcontext.ResultContext, bindVars map[string]*querypb.BindVariable, wantfields bool) error { var err error lctx := xcontext.NewResultContext() @@ -115,16 +115,15 @@ func (j *JoinEngine) execBindVars(ctx *xcontext.ResultContext, bindVars map[stri } for _, lrow := range lctx.Results.Rows { - blend := true + leftMatch := true matchCnt := 0 for _, idx := range j.node.LeftTmpCols { - vn := lrow[idx].ToNative() - if vn.(int64) == 0 { - blend = false + if !sqltypes.CastToBool(lrow[idx]) { + leftMatch = false break } } - if blend { + if leftMatch { for k, col := range j.node.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) } diff --git a/src/executor/engine/merge_join.go b/src/executor/engine/merge_join.go index 8da8ee8b..888dce02 100644 --- a/src/executor/engine/merge_join.go +++ b/src/executor/engine/merge_join.go @@ -143,17 +143,16 @@ func concatLeftAndRight(lrows, rrows [][]sqltypes.Value, node *planner.JoinNode, return } - blend := true + leftMatch := true matchCnt := 0 for _, idx := range node.LeftTmpCols { - vn := lrow[idx].ToNative() - if vn == nil || vn.(int64) == 0 { - blend = false + if !sqltypes.CastToBool(lrow[idx]) { + leftMatch = false break } } - if blend { + if leftMatch { for _, rrow := range rrows { match := true for _, filter := range node.CmpFilter { diff --git a/src/vendor/github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes/arithmetic.go b/src/vendor/github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes/arithmetic.go index 8312f482..18942af0 100644 --- a/src/vendor/github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes/arithmetic.go +++ b/src/vendor/github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes/arithmetic.go @@ -538,3 +538,31 @@ func Cast(v Value, typ querypb.Type) (Value, error) { // go through full validation. return NewValue(typ, v.val) } + +func isNumZero(v numeric) bool { + switch v.typ { + case Uint64: + return v.uval == 0 + case Int64: + return v.ival == 0 + case Float64: + return v.fval == 0 + case Decimal: + return v.dval.IsZero() + } + panic("unreachable") +} + +// CastToBool used to cast the Value to a boolean. +func CastToBool(v Value) bool { + if v.IsNull() { + return false + } + + lv, err := newNumeric(v) + if err != nil { + return false + } + + return !isNumZero(lv) +} diff --git a/src/vendor/github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes/arithmetic_test.go b/src/vendor/github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes/arithmetic_test.go index cdfb422b..6024e4e1 100644 --- a/src/vendor/github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes/arithmetic_test.go +++ b/src/vendor/github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes/arithmetic_test.go @@ -1250,3 +1250,57 @@ func TestMax(t *testing.T) { } } } + +func TestCastToBool(t *testing.T) { + tcases := []struct { + v Value + out bool + }{{ + v: NewInt64(12), + out: true, + }, { + v: NewInt64(0), + out: false, + }, { + v: NewUint64(1), + out: true, + }, { + v: NewUint64(0), + out: false, + }, { + v: NewFloat64(1), + out: true, + }, { + v: NewFloat64(0), + out: false, + }, { + v: testVal(Decimal, "0"), + out: false, + }, { + v: testVal(Decimal, "1.2"), + out: true, + }, { + v: testVal(VarChar, "1"), + out: true, + }, { + v: testVal(VarChar, "0"), + out: false, + }, { + v: testVal(VarChar, "1.2"), + out: true, + }, { + v: testVal(VarChar, "abcd"), + out: false, + }, { + v: NULL, + out: false, + }, { + v: testVal(Uint64, "1.2"), + out: false, + }} + + for _, tcase := range tcases { + got := CastToBool(tcase.v) + assert.Equal(t, tcase.out, got) + } +}