Skip to content

Commit

Permalink
planner enhancement: nice bindvar names for update (#11581)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>

Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay authored Oct 27, 2022
1 parent 6ac536c commit 781181b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 25 deletions.
31 changes: 23 additions & 8 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ func (nz *normalizer) WalkStatement(cursor *Cursor) bool {
nz.convertLiteral(node, cursor)
case *ComparisonExpr:
nz.convertComparison(node)
case *UpdateExpr:
nz.convertUpdateExpr(node)
case *ColName, TableName:
// Common node types that never contain Literal or ListArgs but create a lot of object
// allocations.
Expand Down Expand Up @@ -198,28 +200,34 @@ func (nz *normalizer) convertComparison(node *ComparisonExpr) {
}

func (nz *normalizer) rewriteOtherComparisons(node *ComparisonExpr) {
col, ok := node.Left.(*ColName)
newR := nz.parameterize(node.Left, node.Right)
if newR != nil {
node.Right = newR
}
}

func (nz *normalizer) parameterize(left, right Expr) Expr {
col, ok := left.(*ColName)
if !ok {
return
return nil
}
lit, ok := node.Right.(*Literal)
lit, ok := right.(*Literal)
if !ok {
return
return nil
}
err := validateLiteral(lit)
if err != nil {
nz.err = err
return
return nil
}

bval := SQLToBindvar(lit)
if bval == nil {
return
return nil
}
key := keyFor(bval, lit)
bvname := nz.decideBindVarName(key, lit, col, bval)

node.Right = Argument(bvname)
return Argument(bvname)
}

func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, bval *querypb.BindVariable) string {
Expand Down Expand Up @@ -268,6 +276,13 @@ func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) {
node.Right = ListArg(bvname)
}

func (nz *normalizer) convertUpdateExpr(node *UpdateExpr) {
newR := nz.parameterize(node.Name, node.Expr)
if newR != nil {
node.Expr = newR
}
}

func SQLToBindvar(node SQLNode) *querypb.BindVariable {
if node, ok := node.(*Literal); ok {
var v sqltypes.Value
Expand Down
20 changes: 9 additions & 11 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,17 @@ func TestNormalize(t *testing.T) {
}, {
// val should be reused only in subqueries of DMLs
in: "update a set v1=(select 5 from t), v2=5, v3=(select 5 from t), v4=5",
outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv2, v3 = (select :bv1 from t), v4 = :bv3",
outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv1, v3 = (select :bv1 from t), v4 = :bv1",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(5),
"bv2": sqltypes.Int64BindVariable(5),
"bv3": sqltypes.Int64BindVariable(5),
},
}, {
// list vars should work for DMLs also
in: "update a set v1=5 where v2 in (1, 4, 5)",
outstmt: "update a set v1 = :bv1 where v2 in ::bv2",
outstmt: "update a set v1 = :v1 where v2 in ::bv1",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.Int64BindVariable(5),
"bv2": sqltypes.TestBindVariable([]any{1, 4, 5}),
"v1": sqltypes.Int64BindVariable(5),
"bv1": sqltypes.TestBindVariable([]any{1, 4, 5}),
},
}, {
// Hex number values should work for selects
Expand All @@ -157,10 +155,10 @@ func TestNormalize(t *testing.T) {
},
}, {
// Hex number values should work for DMLs
in: "update a set v1 = 0x12",
outstmt: "update a set v1 = :bv1",
in: "update a set foo = 0x12",
outstmt: "update a set foo = :foo",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.HexNumBindVariable([]byte("0x12")),
"foo": sqltypes.HexNumBindVariable([]byte("0x12")),
},
}, {
// Bin values work fine
Expand All @@ -172,9 +170,9 @@ func TestNormalize(t *testing.T) {
}, {
// Bin value does not convert for DMLs
in: "update a set v1 = b'11'",
outstmt: "update a set v1 = :bv1",
outstmt: "update a set v1 = :v1",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.HexNumBindVariable([]byte("0x3")),
"v1": sqltypes.HexNumBindVariable([]byte("0x3")),
},
}, {
// ORDER BY column_position
Expand Down
12 changes: 6 additions & 6 deletions go/vt/vtgate/executor_dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,10 @@ func TestUpdateNormalize(t *testing.T) {
_, err := executorExec(executor, "/* leading */ update user set a=2 where id = 1 /* trailing */", nil)
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{{
Sql: "/* leading */ update `user` set a = :vtg1 where id = :id /* trailing */",
Sql: "/* leading */ update `user` set a = :a where id = :id /* trailing */",
BindVariables: map[string]*querypb.BindVariable{
"vtg1": sqltypes.TestBindVariable(int64(2)),
"id": sqltypes.TestBindVariable(int64(1)),
"a": sqltypes.TestBindVariable(int64(2)),
"id": sqltypes.TestBindVariable(int64(1)),
},
}}
assertQueries(t, sbc1, wantQueries)
Expand All @@ -353,10 +353,10 @@ func TestUpdateNormalize(t *testing.T) {
_, err = executorExec(executor, "/* leading */ update user set a=2 where id = 1 /* trailing */", nil)
require.NoError(t, err)
wantQueries = []*querypb.BoundQuery{{
Sql: "/* leading */ update `user` set a = :vtg1 where id = :id /* trailing */",
Sql: "/* leading */ update `user` set a = :a where id = :id /* trailing */",
BindVariables: map[string]*querypb.BindVariable{
"vtg1": sqltypes.TestBindVariable(int64(2)),
"id": sqltypes.TestBindVariable(int64(1)),
"a": sqltypes.TestBindVariable(int64(2)),
"id": sqltypes.TestBindVariable(int64(1)),
},
}}
assertQueries(t, sbc1, nil)
Expand Down

0 comments on commit 781181b

Please sign in to comment.