From 707b0a4e3813a458b17027396357f59c625bf68f Mon Sep 17 00:00:00 2001
From: Weizhen Wang
Date: Tue, 27 Feb 2024 17:05:02 +0800
Subject: [PATCH] parser: support (Row(..),Row(..))=>(..) in the binding mode (#51319)

close pingcap/tidb#51222
---
 pkg/bindinfo/tests/BUILD.bazel  |  2 +-
 pkg/bindinfo/tests/bind_test.go | 21 ++++++++++
 pkg/parser/digester.go          | 74 ++++++++++++++++++++++++++++++++-
 3 files changed, 94 insertions(+), 3 deletions(-)

diff --git a/pkg/bindinfo/tests/BUILD.bazel b/pkg/bindinfo/tests/BUILD.bazel
index aedeac1666806..033d3b6ca0440 100644
--- a/pkg/bindinfo/tests/BUILD.bazel
+++ b/pkg/bindinfo/tests/BUILD.bazel
@@ -9,7 +9,7 @@ go_test(
     ],
     flaky = True,
     race = "on",
-    shard_count = 15,
+    shard_count = 16,
     deps = [
         "//pkg/bindinfo",
         "//pkg/bindinfo/internal",
diff --git a/pkg/bindinfo/tests/bind_test.go b/pkg/bindinfo/tests/bind_test.go
index 995c241993460..17329931f041b 100644
--- a/pkg/bindinfo/tests/bind_test.go
+++ b/pkg/bindinfo/tests/bind_test.go
@@ -859,3 +859,24 @@ func TestJoinOrderHintWithBinding(t *testing.T) {
 
 	tk.MustExec("drop global binding for select * from t1 join t2 on t1.a=t2.a join t3 on t2.b=t3.b")
 }
+
+func TestNormalizeStmtForBinding(t *testing.T) {
+	tests := []struct {
+		sql        string
+		normalized string
+		digest     string
+	}{
+		{"select 1 from b where (x,y) in ((1, 3), ('3', 1))", "select ? from `b` where row ( `x` , `y` ) in ( ... )", "ab6c607d118c24030807f8d1c7c846ec23e3b752fd88ed763bb8e26fbfa56a83"},
+		{"select 1 from b where (x,y) in ((1, 3), ('3', 1), (2, 3))", "select ? from `b` where row ( `x` , `y` ) in ( ... )", "ab6c607d118c24030807f8d1c7c846ec23e3b752fd88ed763bb8e26fbfa56a83"},
+		{"select 1 from b where (x,y) in ((1, 3), ('3', 1), (2, 3),('x', 'y'))", "select ? from `b` where row ( `x` , `y` ) in ( ... )", "ab6c607d118c24030807f8d1c7c846ec23e3b752fd88ed763bb8e26fbfa56a83"},
+		{"select 1 from b where (x,y) in ((1, 3), ('3', 1), (2, 3),('x', 'y'),('x', 'y'))", "select ? from `b` where row ( `x` , `y` ) in ( ... )", "ab6c607d118c24030807f8d1c7c846ec23e3b752fd88ed763bb8e26fbfa56a83"},
+		{"select 1 from b where (x) in ((1), ('3'), (2),('x'),('x'))", "select ? from `b` where ( `x` ) in ( ( ... ) )", "03e6e1eb3d76b69363922ff269284b359ca73351001ba0e82d3221c740a6a14c"},
+		{"select 1 from b where (x) in ((1), ('3'), (2),('x'))", "select ? from `b` where ( `x` ) in ( ( ... ) )", "03e6e1eb3d76b69363922ff269284b359ca73351001ba0e82d3221c740a6a14c"},
+	}
+	for _, test := range tests {
+		stmt, _, _ := internal.UtilNormalizeWithDefaultDB(t, test.sql)
+		n, digest := norm.NormalizeStmtForBinding(stmt, norm.WithFuzz(true))
+		require.Equal(t, test.normalized, n)
+		require.Equal(t, test.digest, digest)
+	}
+}
diff --git a/pkg/parser/digester.go b/pkg/parser/digester.go
index e120cc40f8bd5..d3ac73ab6bdf1 100644
--- a/pkg/parser/digester.go
+++ b/pkg/parser/digester.go
@@ -229,13 +229,15 @@ func (d *sqlDigester) normalize(sql string, keepHint bool, forBinding bool, forP
 			continue
 		}
 
-		d.reduceLit(&currTok)
+		d.reduceLit(&currTok, forBinding)
 		if forPlanReplayerReload {
 			// Apply for plan replayer to match specific rules, changing IN (...) to IN (?). This can avoid plan replayer load failures caused by parse errors.
 			d.replaceSingleLiteralWithInList(&currTok)
 		} else if forBinding {
 			// Apply binding matching specific rules, IN (?) => IN ( ... ) #44298
 			d.reduceInListWithSingleLiteral(&currTok)
+			// In (Row(...)) => In (...) #51222
+			d.reduceInRowListWithSingleLiteral(&currTok)
 		}
 
 		if currTok.tok == identifier {
@@ -311,7 +313,7 @@ func (d *sqlDigester) reduceOptimizerHint(tok *token) (reduced bool) {
 	return
 }
 
-func (d *sqlDigester) reduceLit(currTok *token) {
+func (d *sqlDigester) reduceLit(currTok *token, forBinding bool) {
 	if !d.isLit(*currTok) {
 		return
 	}
@@ -346,6 +348,17 @@ func (d *sqlDigester) reduceLit(currTok *token) {
 		currTok.lit = "..."
 		return
 	}
+	// reduce "In (row(...), row(...))" to "In (row(...))"
+	// Finally, it will be reduced to "In (...)". Issue: #51222
+	if forBinding {
+		last9 := d.tokens.back(9)
+		if d.isGenericRowListsWithIn(last9) {
+			d.tokens.popBack(5) // drop `<generic> ) , row (`, keeping `in ( row (`; currTok becomes the merged list
+			currTok.tok = genericSymbolList
+			currTok.lit = "..."
+			return
+		}
+	}
 
 	// order by n => order by n
 	if currTok.tok == intLit {
@@ -378,6 +391,41 @@ func (d *sqlDigester) isGenericLists(last4 []token) bool {
 	return true
 }
 
+// isGenericRowListsWithIn reports whether last9 is exactly `in ( row ( <generic> ) , row (`, the prefix used for In (Row(...), Row(...)) => In (Row(...)).
+func (d *sqlDigester) isGenericRowListsWithIn(last9 []token) bool {
+	if len(last9) < 9 { // indexes 0..8 are read below, so all nine tokens must be present
+		return false
+	}
+	if !d.isInKeyword(last9[0]) {
+		return false
+	}
+	if last9[1].lit != "(" {
+		return false
+	}
+	if !d.isRowKeyword(last9[2]) {
+		return false
+	}
+	if last9[3].lit != "(" {
+		return false
+	}
+	if !(last9[4].tok == genericSymbol || last9[4].tok == genericSymbolList) {
+		return false
+	}
+	if last9[5].lit != ")" {
+		return false
+	}
+	if !d.isComma(last9[6]) {
+		return false
+	}
+	if !d.isRowKeyword(last9[7]) {
+		return false
+	}
+	if last9[8].lit != "(" {
+		return false
+	}
+	return true
+}
+
 // IN (...) => IN (?) Issue: #43192
 func (d *sqlDigester) replaceSingleLiteralWithInList(currTok *token) {
 	last5 := d.tokens.back(5)
@@ -408,6 +456,23 @@ func (d *sqlDigester) reduceInListWithSingleLiteral(currTok *token) {
 	}
 }
 
+// reduceInRowListWithSingleLiteral rewrites In (Row(...)) => In (...) when currTok is the closing paren of the IN list. Issue: #51222
+func (d *sqlDigester) reduceInRowListWithSingleLiteral(currTok *token) {
+	last6 := d.tokens.back(6)
+	if len(last6) == 6 &&
+		d.isInKeyword(last6[0]) &&
+		d.isLeftParen(last6[1]) &&
+		d.isRowKeyword(last6[2]) &&
+		d.isLeftParen(last6[3]) &&
+		(last6[4].tok == genericSymbolList || last6[4].tok == genericSymbol) &&
+		d.isRightParen(last6[5]) &&
+		d.isRightParen(*currTok) {
+		d.tokens.popBack(4) // drop `row ( <generic> )`, keeping `in (`
+		d.tokens.pushBack(token{genericSymbolList, "..."})
+		return
+	}
+}
+
 func (d *sqlDigester) isPrefixByUnary(currTok int) (isUnary bool) {
 	if !d.isNumLit(currTok) {
 		return
@@ -531,6 +596,11 @@ func (*sqlDigester) isInKeyword(tok token) (isInKeyword bool) {
 	return
 }
 
+func (*sqlDigester) isRowKeyword(tok token) (isRowKeyword bool) {
+	isRowKeyword = tok.lit == "row"
+	return
+}
+
 type token struct {
 	tok int
 	lit string