diff --git a/extension/.gitignore b/extension/.gitignore new file mode 100644 index 0000000000000..6f489871ae2a2 --- /dev/null +++ b/extension/.gitignore @@ -0,0 +1 @@ +_import/generated-*.go diff --git a/extension/_import/BUILD.bazel b/extension/_import/BUILD.bazel new file mode 100644 index 0000000000000..1dbb796e583ce --- /dev/null +++ b/extension/_import/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "_import", + srcs = ["import.go"], + importpath = "github.com/pingcap/tidb/extension/_import", + visibility = ["//visibility:public"], +) diff --git a/extension/_import/import.go b/extension/_import/import.go new file mode 100644 index 0000000000000..4b089e78e76e2 --- /dev/null +++ b/extension/_import/import.go @@ -0,0 +1,18 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensionimport + +// This package is used to import extensions outside the tidb repo +// The autogenerated files registering each extension will be placed in this folder diff --git a/server/conn.go b/server/conn.go index 9399626e050b4..59348900af4ea 100644 --- a/server/conn.go +++ b/server/conn.go @@ -820,7 +820,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e } cc.ctx.SetPort(port) if cc.dbname != "" { - err = cc.useDB(context.Background(), cc.dbname) + _, err = cc.useDB(context.Background(), cc.dbname) if err != nil { return err } @@ -1322,7 +1322,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { case mysql.ComQuit: return io.EOF case mysql.ComInitDB: - if err := cc.useDB(ctx, dataStr); err != nil { + node, err := cc.useDB(ctx, dataStr) + cc.onExtensionStmtEnd(node, false, err) + if err != nil { return err } return cc.writeOK(ctx) @@ -1405,19 +1407,19 @@ func (cc *clientConn) writeStats(ctx context.Context) error { return cc.flush(ctx) } -func (cc *clientConn) useDB(ctx context.Context, db string) (err error) { +func (cc *clientConn) useDB(ctx context.Context, db string) (node ast.StmtNode, err error) { // if input is "use `SELECT`", mysql client just send "SELECT" // so we add `` around db. stmts, err := cc.ctx.Parse(ctx, "use `"+db+"`") if err != nil { - return err + return nil, err } _, err = cc.ctx.ExecuteStmt(ctx, stmts[0]) if err != nil { - return err + return nil, err } cc.dbname = db - return + return stmts[0], err } func (cc *clientConn) flush(ctx context.Context) error { @@ -2492,7 +2494,7 @@ func (cc *clientConn) handleResetConnection(ctx context.Context) error { return errors.New("Could not reset connection") } if cc.dbname != "" { // Restore the current DB - err = cc.useDB(context.Background(), cc.dbname) + _, err = cc.useDB(context.Background(), cc.dbname) if err != nil { return err } diff --git a/tidb-server/BUILD.bazel b/tidb-server/BUILD.bazel index 9893f0dccb606..99239b523536e 100644 --- a/tidb-server/BUILD.bazel +++ b/tidb-server/BUILD.bazel @@ -13,6 +13,7 @@ go_library( "//domain/infosync", "//executor", "//extension", + "//extension/_import", "//keyspace", "//kv", "//metrics", diff --git a/tidb-server/main.go b/tidb-server/main.go index e17da7b572d41..c1f1109c35499 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/domain/infosync" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/extension" + _ "github.com/pingcap/tidb/extension/_import" "github.com/pingcap/tidb/keyspace" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics"