diff --git a/executor/plan_replayer.go b/executor/plan_replayer.go index 61464aea02dae..2a589f1668786 100644 --- a/executor/plan_replayer.go +++ b/executor/plan_replayer.go @@ -377,7 +377,7 @@ func dumpVariables(ctx sessionctx.Context, zw *zip.Writer) error { if err != nil { return err } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0]) + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) if err != nil { return err } @@ -404,7 +404,7 @@ func dumpSessionBindings(ctx sessionctx.Context, zw *zip.Writer) error { if err != nil { return err } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0]) + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], true) if err != nil { return err } @@ -428,7 +428,7 @@ func dumpGlobalBindings(ctx sessionctx.Context, zw *zip.Writer) error { if err != nil { return err } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0]) + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) if err != nil { return err } @@ -463,7 +463,7 @@ func dumpExplain(ctx sessionctx.Context, zw *zip.Writer, sql string, isAnalyze b return err } } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0]) + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) if err != nil { return err } @@ -527,7 +527,7 @@ func getShowCreateTable(pair tableNamePair, zw *zip.Writer, ctx sessionctx.Conte if err != nil { return err } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0]) + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) if err != nil { return err } @@ -559,7 +559,7 @@ func getShowCreateTable(pair tableNamePair, zw *zip.Writer, ctx sessionctx.Conte return nil } -func resultSetToStringSlice(ctx context.Context, rs sqlexec.RecordSet) ([][]string, error) { +func resultSetToStringSlice(ctx context.Context, rs sqlexec.RecordSet, emptyAsNil bool) ([][]string, error) { rows, err := getRows(ctx, rs) if err != nil { return nil, err @@ -580,6 +580,9 @@ func resultSetToStringSlice(ctx context.Context, rs sqlexec.RecordSet) ([][]stri if err != nil { return nil, err } + if len(iRow[j]) < 1 && emptyAsNil { + iRow[j] = "" + } } } sRows[i] = iRow @@ -685,6 +688,64 @@ func loadSetTiFlashReplica(ctx sessionctx.Context, z *zip.Reader) error { return nil } +func loadAllBindings(ctx sessionctx.Context, z *zip.Reader) error { + for _, f := range z.File { + if strings.Compare(f.Name, sessionBindingFile) == 0 { + err := loadBindings(ctx, f, true) + if err != nil { + return err + } + } else if strings.Compare(f.Name, globalBindingFile) == 0 { + err := loadBindings(ctx, f, false) + if err != nil { + return err + } + } + } + return nil +} + +func loadBindings(ctx sessionctx.Context, f *zip.File, isSession bool) error { + r, err := f.Open() + if err != nil { + return errors.AddStack(err) + } + //nolint: errcheck + defer r.Close() + buf := new(bytes.Buffer) + _, err = buf.ReadFrom(r) + if err != nil { + return errors.AddStack(err) + } + if len(buf.String()) < 1 { + return nil + } + bindings := strings.Split(buf.String(), "\n") + for _, binding := range bindings { + cols := strings.Split(binding, "\t") + if len(cols) < 3 { + continue + } + originSQL := cols[0] + bindingSQL := cols[1] + enabled := cols[3] + if strings.Compare(enabled, "enabled") == 0 { + sql := fmt.Sprintf("CREATE %s BINDING FOR %s USING %s", func() string { + if isSession { + return "SESSION" + } + return "GLOBAL" + }(), originSQL, bindingSQL) + c := context.Background() + _, err = ctx.(sqlexec.SQLExecutor).Execute(c, sql) + if err != nil { + return err + } + } + } + return nil +} + func loadVariables(ctx sessionctx.Context, z *zip.Reader) error { unLoadVars := make([]string, 0) for _, zipFile := range z.File { @@ -839,5 +900,11 @@ func (e *PlanReplayerLoadInfo) Update(data []byte) error { } } } + + err = loadAllBindings(e.Ctx, z) + if err != nil { + logutil.BgLogger().Warn("load bindings failed", zap.Error(err)) + e.Ctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("load bindings failed, err:%v", err)) + } return nil }