From fe097706c79aabc485290bb3d7e7422546b0949c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=B6=85?= Date: Mon, 24 Oct 2022 12:25:55 +0800 Subject: [PATCH] extension: support bootstrap for extension (#38589) close pingcap/tidb#38588 --- expression/BUILD.bazel | 1 + extension/BUILD.bazel | 3 ++ extension/bootstrap_test.go | 48 +++++++++++++++++ extension/extensionimpl/BUILD.bazel | 16 ++++++ extension/extensionimpl/bootstrap.go | 78 ++++++++++++++++++++++++++++ extension/extensions.go | 16 ++++++ extension/manifest.go | 30 +++++++++++ extension/registry.go | 36 +++++++++---- session/BUILD.bazel | 1 + session/session.go | 5 ++ 10 files changed, 223 insertions(+), 11 deletions(-) create mode 100644 extension/bootstrap_test.go create mode 100644 extension/extensionimpl/BUILD.bazel create mode 100644 extension/extensionimpl/bootstrap.go diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel index 224531f88361e..b5cd5ff84300f 100644 --- a/expression/BUILD.bazel +++ b/expression/BUILD.bazel @@ -99,6 +99,7 @@ go_library( "//util/parser", "//util/plancodec", "//util/printer", + "//util/sem", "//util/set", "//util/size", "//util/sqlexec", diff --git a/extension/BUILD.bazel b/extension/BUILD.bazel index 34e2cddbb2012..abd5b7121b6b7 100644 --- a/extension/BUILD.bazel +++ b/extension/BUILD.bazel @@ -22,6 +22,7 @@ go_library( go_test( name = "extension_test", srcs = [ + "bootstrap_test.go", "function_test.go", "main_test.go", "registry_test.go", @@ -29,12 +30,14 @@ go_test( embed = [":extension"], deps = [ "//expression", + "//parser/auth", "//privilege/privileges", "//sessionctx/variable", "//testkit", "//testkit/testsetup", "//types", "//util/chunk", + "//util/sem", "@com_github_pingcap_errors//:errors", "@com_github_stretchr_testify//require", "@org_uber_go_goleak//:goleak", diff --git a/extension/bootstrap_test.go b/extension/bootstrap_test.go new file mode 100644 index 0000000000000..ae4a7ff03f091 --- /dev/null +++ b/extension/bootstrap_test.go @@ -0,0 +1,48 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extension_test + +import ( + "testing" + + "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" +) + +func TestBootstrap(t *testing.T) { + defer func() { + extension.Reset() + }() + + extension.Reset() + require.NoError(t, extension.Register("test1", extension.WithBootstrapSQL("create table test.t1 (a int)"))) + require.NoError(t, extension.Register("test2", extension.WithBootstrap(func(ctx extension.BootstrapContext) error { + _, err := ctx.ExecuteSQL(ctx, "insert into test.t1 values(1)") + require.NoError(t, err) + + rows, err := ctx.ExecuteSQL(ctx, "select * from test.t1 where a=1") + require.NoError(t, err) + + require.Equal(t, 1, len(rows)) + require.Equal(t, int64(1), rows[0].GetInt64(0)) + return nil + }))) + require.NoError(t, extension.Setup()) + + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustQuery("select * from test.t1").Check(testkit.Rows("1")) +} diff --git a/extension/extensionimpl/BUILD.bazel b/extension/extensionimpl/BUILD.bazel new file mode 100644 index 0000000000000..4719d7777b43c --- /dev/null +++ b/extension/extensionimpl/BUILD.bazel @@ -0,0 +1,16 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "extensionimpl", + srcs = ["bootstrap.go"], + importpath = "github.com/pingcap/tidb/extension/extensionimpl", + visibility = ["//visibility:public"], + deps = [ + "//domain", + "//extension", + "//kv", + "//util/chunk", + "//util/sqlexec", + "@com_github_pingcap_errors//:errors", + ], +) diff --git a/extension/extensionimpl/bootstrap.go b/extension/extensionimpl/bootstrap.go new file mode 100644 index 0000000000000..0154971e7d9de --- /dev/null +++ b/extension/extensionimpl/bootstrap.go @@ -0,0 +1,78 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensionimpl + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/extension" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/sqlexec" +) + +type bootstrapContext struct { + context.Context + sqlExecutor sqlexec.SQLExecutor +} + +func (c *bootstrapContext) ExecuteSQL(ctx context.Context, sql string) (rows []chunk.Row, err error) { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBootstrap) + rs, err := c.sqlExecutor.ExecuteInternal(ctx, sql) + if err != nil { + return nil, err + } + + if rs == nil { + return nil, nil + } + + defer func() { + closeErr := rs.Close() + if err == nil { + err = closeErr + } + }() + + return sqlexec.DrainRecordSet(ctx, rs, 8) +} + +// Bootstrap bootstrap all extensions +func Bootstrap(ctx context.Context, do *domain.Domain) error { + extensions, err := extension.GetExtensions() + if err != nil { + return err + } + + if extensions == nil { + return nil + } + + pool := do.SysSessionPool() + sctx, err := pool.Get() + if err != nil { + return err + } + defer pool.Put(sctx) + + executor, ok := sctx.(sqlexec.SQLExecutor) + if !ok { + return errors.Errorf("type '%T' cannot be casted to 'sqlexec.SQLExecutor'", sctx) + } + + return extensions.Bootstrap(&bootstrapContext{ctx, executor}) +} diff --git a/extension/extensions.go b/extension/extensions.go index eaa6722eaf951..b7a3000d6aea3 100644 --- a/extension/extensions.go +++ b/extension/extensions.go @@ -28,3 +28,19 @@ func (es *Extensions) Manifests() []*Manifest { copy(manifests, es.manifests) return manifests } + +// Bootstrap bootstrap all extensions +func (es *Extensions) Bootstrap(ctx BootstrapContext) error { + if es == nil { + return nil + } + + for _, m := range es.manifests { + if m.bootstrap != nil { + if err := m.bootstrap(ctx); err != nil { + return err + } + } + } + return nil +} diff --git a/extension/manifest.go b/extension/manifest.go index 0835da1fb4495..c01e95aef24b8 100644 --- a/extension/manifest.go +++ b/extension/manifest.go @@ -15,8 +15,11 @@ package extension import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/chunk" ) // Option represents an option to initialize an extension @@ -51,11 +54,38 @@ func WithClose(fn func()) Option { } } +// BootstrapContext is the context used by extension in bootstrap +type BootstrapContext interface { + context.Context + // ExecuteSQL is used to execute a sql + ExecuteSQL(ctx context.Context, sql string) ([]chunk.Row, error) +} + +// WithBootstrap specifies the bootstrap func of an extension +func WithBootstrap(fn func(BootstrapContext) error) Option { + return func(m *Manifest) { + m.bootstrap = fn + } +} + +// WithBootstrapSQL the bootstrap SQL list +func WithBootstrapSQL(sqlList ...string) Option { + return WithBootstrap(func(ctx BootstrapContext) error { + for _, sql := range sqlList { + if _, err := ctx.ExecuteSQL(ctx, sql); err != nil { + return err + } + } + return nil + }) +} + // Manifest is an extension's manifest type Manifest struct { name string sysVariables []*variable.SysVar dynPrivs []string + bootstrap func(BootstrapContext) error funcs []*FunctionDef close func() } diff --git a/extension/registry.go b/extension/registry.go index b60bcf5ee3344..4e838181f9b56 100644 --- a/extension/registry.go +++ b/extension/registry.go @@ -30,14 +30,30 @@ type registry struct { close func() } +// Setup sets up the extensions +func (r *registry) Setup() error { + r.Lock() + defer r.Unlock() + + if _, err := r.doSetup(); err != nil { + return err + } + return nil +} + // Extensions returns the extensions after setup func (r *registry) Extensions() (*Extensions, error) { r.RLock() - defer r.RUnlock() - if !r.setup { - return nil, errors.New("The extensions has not been setup") + if r.setup { + extensions := r.extensions + r.RUnlock() + return extensions, nil } - return r.extensions, nil + r.RUnlock() + + r.Lock() + defer r.Unlock() + return r.doSetup() } // RegisterFactory registers a new extension with a factory @@ -68,17 +84,15 @@ func (r *registry) RegisterFactory(name string, factory func() ([]Option, error) } // Setup setups all extensions -func (r *registry) Setup() (err error) { - r.Lock() - defer r.Unlock() +func (r *registry) doSetup() (_ *Extensions, err error) { if r.setup { - return nil + return r.extensions, nil } if len(r.factories) == 0 { r.extensions = nil r.setup = true - return nil + return nil, nil } clearBuilder := &clearFuncBuilder{} @@ -102,13 +116,13 @@ func (r *registry) Setup() (err error) { }) if err != nil { - return err + return nil, err } } r.extensions = &Extensions{manifests: manifests} r.setup = true r.close = clearBuilder.Build() - return nil + return r.extensions, nil } // Reset resets the registry. It is only used by test diff --git a/session/BUILD.bazel b/session/BUILD.bazel index 73d83cf663cc8..e32172cd4b6ba 100644 --- a/session/BUILD.bazel +++ b/session/BUILD.bazel @@ -25,6 +25,7 @@ go_library( "//errno", "//executor", "//expression", + "//extension/extensionimpl", "//infoschema", "//kv", "//meta", diff --git a/session/session.go b/session/session.go index 684c0fb1b6967..441dd70b7e811 100644 --- a/session/session.go +++ b/session/session.go @@ -49,6 +49,7 @@ import ( "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/extension/extensionimpl" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" @@ -2888,6 +2889,10 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { return nil, err } + if err = extensionimpl.Bootstrap(context.Background(), dom); err != nil { + return nil, err + } + if len(cfg.Instance.PluginLoad) > 0 { err := plugin.Init(context.Background(), plugin.Config{EtcdClient: dom.GetEtcdClient()}) if err != nil {