Skip to content

Commit

Permalink
session: load variables before parsing SQL (#51466)
Browse files Browse the repository at this point in the history
close #51387
  • Loading branch information
YangKeao authored Mar 5, 2024
1 parent 61b66aa commit 5e6cb16
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 1 deletion.
32 changes: 32 additions & 0 deletions pkg/server/internal/testserverclient/server_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package testserverclient

import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -2629,4 +2630,35 @@ func (cli *TestServerClient) RunTestInfoschemaClientErrors(t *testing.T) {
})
}

func (cli *TestServerClient) RunTestSQLModeIsLoadedBeforeQuery(t *testing.T) {
cli.RunTests(t, nil, func(dbt *testkit.DBTestKit) {
ctx := context.Background()

conn, err := dbt.GetDB().Conn(ctx)
require.NoError(t, err)
_, err = conn.ExecContext(ctx, "set global sql_mode='NO_BACKSLASH_ESCAPES';")
require.NoError(t, err)
_, err = conn.ExecContext(ctx, `
CREATE TABLE t1 (
id bigint(20) NOT NULL,
t text DEFAULT NULL,
PRIMARY KEY (id)
);`)
require.NoError(t, err)

// use another new connection
conn1, err := dbt.GetDB().Conn(ctx)
require.NoError(t, err)
_, err = conn1.ExecContext(ctx, "insert into t1 values (1, 'ab\\\\c');")
require.NoError(t, err)
result, err := conn1.QueryContext(ctx, "select t from t1 where id = 1;")
require.NoError(t, err)
require.True(t, result.Next())
var tStr string
require.NoError(t, result.Scan(&tStr))

require.Equal(t, "ab\\\\c", tStr)
})
}

//revive:enable:exported
2 changes: 1 addition & 1 deletion pkg/server/tests/commontest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ go_test(
"tidb_test.go",
],
flaky = True,
shard_count = 47,
shard_count = 48,
deps = [
"//pkg/config",
"//pkg/ddl/util",
Expand Down
5 changes: 5 additions & 0 deletions pkg/server/tests/commontest/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3067,3 +3067,8 @@ func TestPrepareCount(t *testing.T) {
require.Equal(t, prepareCnt, atomic.LoadInt64(&variable.PreparedStmtCount))
require.NoError(t, qctx.Close())
}

func TestSQLModeIsLoadedBeforeQuery(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)
ts.RunTestSQLModeIsLoadedBeforeQuery(t)
}
7 changes: 7 additions & 0 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1652,6 +1652,13 @@ func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec
func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) {
logutil.Logger(ctx).Debug("parse", zap.String("sql", sql))
parseStartTime := time.Now()

// Load the session variables to the context.
// This is necessary for the parser to get the current sql_mode.
if err := s.loadCommonGlobalVariablesIfNeeded(); err != nil {
return nil, err
}

stmts, warns, err := s.ParseSQL(ctx, sql, s.sessionVars.GetParseParams()...)
if err != nil {
s.rollbackOnError(ctx)
Expand Down

0 comments on commit 5e6cb16

Please sign in to comment.