From 781181b7e02c6dbc3260ca1115e3f77adc0b57f0 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 27 Oct 2022 14:18:51 +0200 Subject: [PATCH] planner enhancement: nice bindvar names for update (#11581) Signed-off-by: Andres Taylor Signed-off-by: Andres Taylor --- go/vt/sqlparser/normalizer.go | 31 ++++++++++++++++++++++-------- go/vt/sqlparser/normalizer_test.go | 20 +++++++++---------- go/vt/vtgate/executor_dml_test.go | 12 ++++++------ 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index d92202416fb..5267042f682 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -72,6 +72,8 @@ func (nz *normalizer) WalkStatement(cursor *Cursor) bool { nz.convertLiteral(node, cursor) case *ComparisonExpr: nz.convertComparison(node) + case *UpdateExpr: + nz.convertUpdateExpr(node) case *ColName, TableName: // Common node types that never contain Literal or ListArgs but create a lot of object // allocations. @@ -198,28 +200,34 @@ func (nz *normalizer) convertComparison(node *ComparisonExpr) { } func (nz *normalizer) rewriteOtherComparisons(node *ComparisonExpr) { - col, ok := node.Left.(*ColName) + newR := nz.parameterize(node.Left, node.Right) + if newR != nil { + node.Right = newR + } +} + +func (nz *normalizer) parameterize(left, right Expr) Expr { + col, ok := left.(*ColName) if !ok { - return + return nil } - lit, ok := node.Right.(*Literal) + lit, ok := right.(*Literal) if !ok { - return + return nil } err := validateLiteral(lit) if err != nil { nz.err = err - return + return nil } bval := SQLToBindvar(lit) if bval == nil { - return + return nil } key := keyFor(bval, lit) bvname := nz.decideBindVarName(key, lit, col, bval) - - node.Right = Argument(bvname) + return Argument(bvname) } func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, bval *querypb.BindVariable) string { @@ -268,6 +276,13 @@ func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) { node.Right = ListArg(bvname) } +func (nz *normalizer) convertUpdateExpr(node *UpdateExpr) { + newR := nz.parameterize(node.Name, node.Expr) + if newR != nil { + node.Expr = newR + } +} + func SQLToBindvar(node SQLNode) *querypb.BindVariable { if node, ok := node.(*Literal); ok { var v sqltypes.Value diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 9eddd3d7308..14952ac1226 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -120,19 +120,17 @@ func TestNormalize(t *testing.T) { }, { // val should be reused only in subqueries of DMLs in: "update a set v1=(select 5 from t), v2=5, v3=(select 5 from t), v4=5", - outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv2, v3 = (select :bv1 from t), v4 = :bv3", + outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv1, v3 = (select :bv1 from t), v4 = :bv1", outbv: map[string]*querypb.BindVariable{ "bv1": sqltypes.Int64BindVariable(5), - "bv2": sqltypes.Int64BindVariable(5), - "bv3": sqltypes.Int64BindVariable(5), }, }, { // list vars should work for DMLs also in: "update a set v1=5 where v2 in (1, 4, 5)", - outstmt: "update a set v1 = :bv1 where v2 in ::bv2", + outstmt: "update a set v1 = :v1 where v2 in ::bv1", outbv: map[string]*querypb.BindVariable{ - "bv1": sqltypes.Int64BindVariable(5), - "bv2": sqltypes.TestBindVariable([]any{1, 4, 5}), + "v1": sqltypes.Int64BindVariable(5), + "bv1": sqltypes.TestBindVariable([]any{1, 4, 5}), }, }, { // Hex number values should work for selects @@ -157,10 +155,10 @@ func TestNormalize(t *testing.T) { }, }, { // Hex number values should work for DMLs - in: "update a set v1 = 0x12", - outstmt: "update a set v1 = :bv1", + in: "update a set foo = 0x12", + outstmt: "update a set foo = :foo", outbv: map[string]*querypb.BindVariable{ - "bv1": sqltypes.HexNumBindVariable([]byte("0x12")), + "foo": sqltypes.HexNumBindVariable([]byte("0x12")), }, }, { // Bin values work fine @@ -172,9 +170,9 @@ func TestNormalize(t *testing.T) { }, { // Bin value does not convert for DMLs in: "update a set v1 = b'11'", - outstmt: "update a set v1 = :bv1", + outstmt: "update a set v1 = :v1", outbv: map[string]*querypb.BindVariable{ - "bv1": sqltypes.HexNumBindVariable([]byte("0x3")), + "v1": sqltypes.HexNumBindVariable([]byte("0x3")), }, }, { // ORDER BY column_position diff --git a/go/vt/vtgate/executor_dml_test.go b/go/vt/vtgate/executor_dml_test.go index d0016315f16..dbf5a38b70c 100644 --- a/go/vt/vtgate/executor_dml_test.go +++ b/go/vt/vtgate/executor_dml_test.go @@ -338,10 +338,10 @@ func TestUpdateNormalize(t *testing.T) { _, err := executorExec(executor, "/* leading */ update user set a=2 where id = 1 /* trailing */", nil) require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "/* leading */ update `user` set a = :vtg1 where id = :id /* trailing */", + Sql: "/* leading */ update `user` set a = :a where id = :id /* trailing */", BindVariables: map[string]*querypb.BindVariable{ - "vtg1": sqltypes.TestBindVariable(int64(2)), - "id": sqltypes.TestBindVariable(int64(1)), + "a": sqltypes.TestBindVariable(int64(2)), + "id": sqltypes.TestBindVariable(int64(1)), }, }} assertQueries(t, sbc1, wantQueries) @@ -353,10 +353,10 @@ func TestUpdateNormalize(t *testing.T) { _, err = executorExec(executor, "/* leading */ update user set a=2 where id = 1 /* trailing */", nil) require.NoError(t, err) wantQueries = []*querypb.BoundQuery{{ - Sql: "/* leading */ update `user` set a = :vtg1 where id = :id /* trailing */", + Sql: "/* leading */ update `user` set a = :a where id = :id /* trailing */", BindVariables: map[string]*querypb.BindVariable{ - "vtg1": sqltypes.TestBindVariable(int64(2)), - "id": sqltypes.TestBindVariable(int64(1)), + "a": sqltypes.TestBindVariable(int64(2)), + "id": sqltypes.TestBindVariable(int64(1)), }, }} assertQueries(t, sbc1, nil)