From a5b47cfbcc746730dfd7d37a572d664bae895d6a Mon Sep 17 00:00:00 2001 From: Jason Mo Date: Mon, 20 Dec 2021 16:11:45 +0800 Subject: [PATCH] add interface --- bindinfo/handle.go | 6 ++--- ddl/column.go | 4 +-- ddl/partition.go | 4 +-- ddl/reorg.go | 2 +- ddl/util/util.go | 2 +- domain/sysvar_cache.go | 2 +- executor/analyze.go | 2 +- executor/brie.go | 2 +- executor/ddl.go | 2 +- executor/infoschema_reader.go | 4 +-- executor/inspection_profile.go | 2 +- executor/inspection_result.go | 30 +++++++++++----------- executor/inspection_summary.go | 2 +- executor/metrics_reader.go | 4 +-- executor/opt_rule_blacklist.go | 2 +- executor/reload_expr_pushdown_blacklist.go | 2 +- executor/show.go | 8 +++--- executor/show_placement.go | 2 +- executor/simple.go | 8 +++--- executor/simple_test.go | 22 ++++++++++++++++ expression/util.go | 2 +- session/session.go | 20 +++++++++++---- session/session_test.go | 17 +++++------- session/tidb_test.go | 2 +- statistics/handle/handle.go | 8 +++--- telemetry/data_cluster_hardware.go | 2 +- telemetry/data_cluster_info.go | 2 +- telemetry/data_feature_usage.go | 2 +- util/admin/admin.go | 4 +-- util/gcutil/gcutil.go | 2 +- util/sqlexec/restricted_sql_executor.go | 2 ++ 31 files changed, 102 insertions(+), 73 deletions(-) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index b1920047a20b6..d757c9fb578c9 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -131,7 +131,7 @@ func (h *BindHandle) Update(fullLoad bool) (err error) { } exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source + stmt, err := exec.ParseWithParamsInternal(context.TODO(), `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time, create_time`, updateTime) if err != nil { return err @@ -697,7 +697,7 @@ func (h *BindHandle) extractCaptureFilterFromStorage() (filter *captureFilter) { tables: make(map[stmtctx.TableEntry]struct{}), } exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT filter_type, filter_value FROM mysql.capture_plan_baselines_blacklist order by filter_type`) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), `SELECT filter_type, filter_value FROM mysql.capture_plan_baselines_blacklist order by filter_type`) if err != nil { logutil.BgLogger().Warn("[sql-bind] failed to parse query for mysql.capture_plan_baselines_blacklist load", zap.Error(err)) return @@ -923,7 +923,7 @@ func (h *BindHandle) SaveEvolveTasksToStore() { } func getEvolveParameters(ctx sessionctx.Context) (time.Duration, time.Time, time.Time, error) { - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams( + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal( context.TODO(), "SELECT variable_name, variable_value FROM mysql.global_variables WHERE variable_name IN (%?, %?, %?)", variable.TiDBEvolvePlanTaskMaxTime, diff --git a/ddl/column.go b/ddl/column.go index 1507e7437982c..bbcb8fc0051dc 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -996,7 +996,7 @@ func (w *worker) doModifyColumnTypeWithData( } defer w.sessPool.put(ctx) - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), valStr) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal(context.Background(), valStr) if err != nil { job.State = model.JobStateCancelled failpoint.Return(ver, err) @@ -1703,7 +1703,7 @@ func checkForNullValue(ctx context.Context, sctx sessionctx.Context, isDataTrunc } } buf.WriteString(" limit 1") - stmt, err := sctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, buf.String(), paramsList...) + stmt, err := sctx.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal(ctx, buf.String(), paramsList...) if err != nil { return errors.Trace(err) } diff --git a/ddl/partition.go b/ddl/partition.go index 6f6d3b8b0ff59..87a536d5181de 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1551,7 +1551,7 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde } defer w.sessPool.put(ctx) - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(w.ddlJobCtx, sql, paramList...) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal(w.ddlJobCtx, sql, paramList...) if err != nil { return errors.Trace(err) } @@ -1569,7 +1569,7 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde func buildCheckSQLForRangeExprPartition(pi *model.PartitionInfo, index int, schemaName, tableName model.CIStr) (string, []interface{}) { var buf strings.Builder paramList := make([]interface{}, 0, 4) - // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParams(...) + // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParamsInternal(...) // So we write it to the origin sql string here. if index == 0 { buf.WriteString("select 1 from %n.%n where ") diff --git a/ddl/reorg.go b/ddl/reorg.go index d1a8f5b0b4b59..54cda19a7974d 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -341,7 +341,7 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { return statistics.PseudoRowCount } sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" - stmt, err := executor.ParseWithParams(w.ddlJobCtx, sql, tblInfo.ID) + stmt, err := executor.ParseWithParamsInternal(w.ddlJobCtx, sql, tblInfo.ID) if err != nil { return statistics.PseudoRowCount } diff --git a/ddl/util/util.go b/ddl/util/util.go index 62dba84cb4e62..993c0c226f6a0 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -176,7 +176,7 @@ func LoadGlobalVars(ctx context.Context, sctx sessionctx.Context, varNames []str paramNames = append(paramNames, name) } buf.WriteString(")") - stmt, err := e.ParseWithParams(ctx, buf.String(), paramNames...) + stmt, err := e.ParseWithParamsInternal(ctx, buf.String(), paramNames...) if err != nil { return errors.Trace(err) } diff --git a/domain/sysvar_cache.go b/domain/sysvar_cache.go index d89ba88a76ee0..bb0ff2d0c9ab0 100644 --- a/domain/sysvar_cache.go +++ b/domain/sysvar_cache.go @@ -94,7 +94,7 @@ func (do *Domain) fetchTableValues(ctx sessionctx.Context) (map[string]string, e tableContents := make(map[string]string) // Copy all variables from the table to tableContents exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), `SELECT variable_name, variable_value FROM mysql.global_variables`) + stmt, err := exec.ParseWithParamsInternal(context.Background(), `SELECT variable_name, variable_value FROM mysql.global_variables`) if err != nil { return tableContents, err } diff --git a/executor/analyze.go b/executor/analyze.go index 9414b2fbade1e..6388f0875a9fd 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -1570,7 +1570,7 @@ type AnalyzeFastExec struct { func (e *AnalyzeFastExec) calculateEstimateSampleStep() (err error) { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) var stmt ast.StmtNode - stmt, err = exec.ParseWithParams(context.TODO(), "select flag from mysql.stats_histograms where table_id = %?", e.tableID.GetStatisticsID()) + stmt, err = exec.ParseWithParamsInternal(context.TODO(), "select flag from mysql.stats_histograms where table_id = %?", e.tableID.GetStatisticsID()) if err != nil { return } diff --git a/executor/brie.go b/executor/brie.go index 9be7349c494a8..e72958d9d5f7d 100644 --- a/executor/brie.go +++ b/executor/brie.go @@ -462,7 +462,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error) // These queries execute without privilege checking, since the calling statements // such as BACKUP and RESTORE have already been privilege checked. func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error { - stmt, err := gs.se.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, sql) + stmt, err := gs.se.(sqlexec.RestrictedSQLExecutor).ParseWithParamsInternal(ctx, sql) if err != nil { return err } diff --git a/executor/ddl.go b/executor/ddl.go index 4c2be5828b8a8..035877191a228 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -506,7 +506,7 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx zap.String("table", fullti.Name.O), ) exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) if err != nil { return err } diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index 1e4fcae3829ba..06d35d24fce1d 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -190,7 +190,7 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex func getRowCountAllTable(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select table_id, count from mysql.stats_meta") + stmt, err := exec.ParseWithParamsInternal(ctx, "select table_id, count from mysql.stats_meta") if err != nil { return nil, err } @@ -215,7 +215,7 @@ type tableHistID struct { func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[tableHistID]uint64, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") + stmt, err := exec.ParseWithParamsInternal(ctx, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") if err != nil { return nil, err } diff --git a/executor/inspection_profile.go b/executor/inspection_profile.go index 90ea2bc224c5f..b8ab8b383df43 100644 --- a/executor/inspection_profile.go +++ b/executor/inspection_profile.go @@ -167,7 +167,7 @@ func (n *metricNode) getLabelValue(label string) *metricValue { func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error { exec := pb.sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), query) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), query) if err != nil { return err } diff --git a/executor/inspection_result.go b/executor/inspection_result.go index bb26993824684..4062dab3c3176 100644 --- a/executor/inspection_result.go +++ b/executor/inspection_result.go @@ -140,7 +140,7 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct e.statusToInstanceAddress = make(map[string]string) var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select instance,status_address from information_schema.cluster_info;") + stmt, err := exec.ParseWithParamsInternal(ctx, "select instance,status_address from information_schema.cluster_info;") if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -251,7 +251,7 @@ func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.C } var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) + stmt, err := exec.ParseWithParamsInternal(ctx, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -261,7 +261,7 @@ func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.C generateDetail := func(tp, item string) string { var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) + stmt, err := exec.ParseWithParamsInternal(ctx, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -347,7 +347,7 @@ func (c configInspection) inspectCheckConfig(ctx context.Context, sctx sessionct } sql.Reset() fmt.Fprintf(sql, "select type,instance,value from information_schema.%s where %s", cas.table, cas.cond) - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -378,7 +378,7 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct } var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") + stmt, err := exec.ParseWithParamsInternal(ctx, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -405,7 +405,7 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct ipToCount[ip]++ } - stmt, err = exec.ParseWithParams(ctx, "select instance, value from metrics_schema.node_total_memory where time=now()") + stmt, err = exec.ParseWithParamsInternal(ctx, "select instance, value from metrics_schema.node_total_memory where time=now()") if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -473,7 +473,7 @@ func (versionInspection) inspect(ctx context.Context, sctx sessionctx.Context, f exec := sctx.(sqlexec.RestrictedSQLExecutor) var rows []chunk.Row // check the configuration consistent - stmt, err := exec.ParseWithParams(ctx, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") + stmt, err := exec.ParseWithParamsInternal(ctx, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -643,7 +643,7 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx sql.Reset() fmt.Fprintf(sql, "select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", strings.Join(def.Labels, "`,`"), util.MetricSchemaName.L, rule.tbl, condition) - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -698,7 +698,7 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se (select instance,job from metrics_schema.up %[1]s group by instance,job having max(value)-min(value)>0) as t1 join (select instance,min(time) as min_time from metrics_schema.up %[1]s and value=0 group by instance,job) as t2 on t1.instance=t2.instance order by job`, condition) var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -726,7 +726,7 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se // Check from log. sql.Reset() fmt.Fprintf(sql, "select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) - stmt, err = exec.ParseWithParams(ctx, sql.String()) + stmt, err = exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -863,7 +863,7 @@ func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sess (select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 where t1.cpu > %[2]f;`, rule.component, rule.threshold, condition) } - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -1036,7 +1036,7 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess } else { fmt.Fprintf(sql, "select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) } - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -1222,7 +1222,7 @@ func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionF continue } sql := rule.genSQL(filter.timeRange) - stmt, err := exec.ParseWithParams(ctx, sql) + stmt, err := exec.ParseWithParamsInternal(ctx, sql) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -1245,7 +1245,7 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx exec := sctx.(sqlexec.RestrictedSQLExecutor) var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -1259,7 +1259,7 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx sql.Reset() fmt.Fprintf(sql, `select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) var subRows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, sql.String()) + stmt, err := exec.ParseWithParamsInternal(ctx, sql.String()) if err == nil { subRows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } diff --git a/executor/inspection_summary.go b/executor/inspection_summary.go index ffd235451cebb..e709373363e74 100644 --- a/executor/inspection_summary.go +++ b/executor/inspection_summary.go @@ -460,7 +460,7 @@ func (e *inspectionSummaryRetriever) retrieve(ctx context.Context, sctx sessionc util.MetricSchemaName.L, name, cond) } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, sql) + stmt, err := exec.ParseWithParamsInternal(ctx, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/metrics_reader.go b/executor/metrics_reader.go index d6df64dc7a377..2ef9f1196ad94 100644 --- a/executor/metrics_reader.go +++ b/executor/metrics_reader.go @@ -233,7 +233,7 @@ func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx. } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, sql) + stmt, err := exec.ParseWithParamsInternal(ctx, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } @@ -318,7 +318,7 @@ func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sess util.MetricSchemaName.L, name, cond) } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, sql) + stmt, err := exec.ParseWithParamsInternal(ctx, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/opt_rule_blacklist.go b/executor/opt_rule_blacklist.go index 0d915c9eb6966..b013087d925f2 100644 --- a/executor/opt_rule_blacklist.go +++ b/executor/opt_rule_blacklist.go @@ -37,7 +37,7 @@ func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) e // LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist. func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") + stmt, err := exec.ParseWithParamsInternal(context.TODO(), "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") if err != nil { return err } diff --git a/executor/reload_expr_pushdown_blacklist.go b/executor/reload_expr_pushdown_blacklist.go index 81c8ea4f3cccb..0284edba08f38 100644 --- a/executor/reload_expr_pushdown_blacklist.go +++ b/executor/reload_expr_pushdown_blacklist.go @@ -39,7 +39,7 @@ func (e *ReloadExprPushdownBlacklistExec) Next(ctx context.Context, _ *chunk.Chu // LoadExprPushdownBlacklist loads the latest data from table mysql.expr_pushdown_blacklist. func LoadExprPushdownBlacklist(ctx sessionctx.Context) (err error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") + stmt, err := exec.ParseWithParamsInternal(context.TODO(), "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") if err != nil { return err } diff --git a/executor/show.go b/executor/show.go index 2a4eac148f74e..ea4bdf4f1aee8 100644 --- a/executor/show.go +++ b/executor/show.go @@ -342,7 +342,7 @@ func (e *ShowExec) fetchShowBind() error { func (e *ShowExec) fetchShowEngines(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `SELECT * FROM information_schema.engines`) + stmt, err := exec.ParseWithParamsInternal(ctx, `SELECT * FROM information_schema.engines`) if err != nil { return errors.Trace(err) } @@ -473,7 +473,7 @@ func (e *ShowExec) fetchShowTableStatus(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `SELECT + stmt, err := exec.ParseWithParamsInternal(ctx, `SELECT table_name, engine, version, row_format, table_rows, avg_row_length, data_length, max_data_length, index_length, data_free, auto_increment, create_time, update_time, check_time, @@ -1433,7 +1433,7 @@ func (e *ShowExec) fetchShowCreateUser(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `SELECT plugin FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, strings.ToLower(hostName)) + stmt, err := exec.ParseWithParamsInternal(ctx, `SELECT plugin FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, strings.ToLower(hostName)) if err != nil { return errors.Trace(err) } @@ -1453,7 +1453,7 @@ func (e *ShowExec) fetchShowCreateUser(ctx context.Context) error { authplugin = rows[0].GetString(0) } - stmt, err = exec.ParseWithParams(ctx, `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) + stmt, err = exec.ParseWithParamsInternal(ctx, `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) if err != nil { return errors.Trace(err) } diff --git a/executor/show_placement.go b/executor/show_placement.go index d77c0a31000f1..a53a86ffba019 100644 --- a/executor/show_placement.go +++ b/executor/show_placement.go @@ -107,7 +107,7 @@ func (b *showPlacementLabelsResultBuilder) sortMapKeys(m map[string]interface{}) func (e *ShowExec) fetchShowPlacementLabels(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, "SELECT DISTINCT LABEL FROM %n.%n", "INFORMATION_SCHEMA", infoschema.TableTiKVStoreStatus) + stmt, err := exec.ParseWithParamsInternal(ctx, "SELECT DISTINCT LABEL FROM %n.%n", "INFORMATION_SCHEMA", infoschema.TableTiKVStoreStatus) if err != nil { return errors.Trace(err) } diff --git a/executor/simple.go b/executor/simple.go index 82e941eca143f..ee6932ad1dc91 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -966,7 +966,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) if !ok { return errors.Trace(ErrPasswordFormat) } - stmt, err := exec.ParseWithParams(ctx, + stmt, err := exec.ParseWithParamsInternal(ctx, `UPDATE %n.%n SET authentication_string=%?, plugin=%? WHERE Host=%? and User=%?;`, mysql.SystemDB, mysql.UserTable, pwd, spec.AuthOpt.AuthPlugin, strings.ToLower(spec.User.Hostname), spec.User.Username, ) @@ -980,7 +980,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) } if len(privData) > 0 { - stmt, err := exec.ParseWithParams(ctx, "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) + stmt, err := exec.ParseWithParamsInternal(ctx, "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) if err != nil { return err } @@ -1358,7 +1358,7 @@ func (e *SimpleExec) executeDropUser(ctx context.Context, s *ast.DropUserStmt) e func userExists(ctx context.Context, sctx sessionctx.Context, name string, host string) (bool, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, strings.ToLower(host)) + stmt, err := exec.ParseWithParamsInternal(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, strings.ToLower(host)) if err != nil { return false, err } @@ -1441,7 +1441,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error // update mysql.user exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, strings.ToLower(h)) + stmt, err := exec.ParseWithParamsInternal(ctx, `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, strings.ToLower(h)) if err != nil { return err } diff --git a/executor/simple_test.go b/executor/simple_test.go index b8dc034076ec7..6d83b91b08485 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -988,3 +988,25 @@ func (s *testSuite3) TestDropRoleAfterRevoke(c *C) { tk.MustExec("revoke r1, r3 from root;") tk.MustExec("drop role r1;") } + +func (s *testSuiteWithCliBaseCharset) TestUserWithSetNames(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("set names gbk;") + + gbkString := string([]byte{0xD2, 0xBB}) + + tk.MustExec("drop user if exists '一'@'localhost';") + tk.MustExec("create user '一'@'localhost' IDENTIFIED BY '" + gbkString + "';") + + result := tk.MustQuery(`SELECT authentication_string FROM mysql.User WHERE User="一" and Host="localhost";`) + result.Check(testkit.Rows(auth.EncodePassword("一"))) + + tk.MustExec(`ALTER USER '一'@'localhost' IDENTIFIED BY '` + gbkString + gbkString + `';`) + result = tk.MustQuery(`SELECT authentication_string FROM mysql.User WHERE User="一" and Host="localhost";`) + result.Check(testkit.Rows(auth.EncodePassword("一一"))) + + tk.MustExec(`RENAME USER '一'@'localhost' to '一'`) + + tk.MustExec("drop user '一';") +} diff --git a/expression/util.go b/expression/util.go index d7b92329d51f6..3a793cfddc640 100644 --- a/expression/util.go +++ b/expression/util.go @@ -1145,7 +1145,7 @@ func (r *SQLDigestTextRetriever) runFetchDigestQuery(ctx context.Context, sctx s stmt += " where digest in (" + strings.Repeat("%?,", len(inValues)-1) + "%?)" } - stmtNode, err := exec.ParseWithParams(ctx, stmt, inValues...) + stmtNode, err := exec.ParseWithParamsInternal(ctx, stmt, inValues...) if err != nil { return nil, err } diff --git a/session/session.go b/session/session.go index c57628d090f44..a7976a5001bee 100644 --- a/session/session.go +++ b/session/session.go @@ -124,9 +124,9 @@ type Session interface { Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement. // ExecuteStmt executes a parsed statement. ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error) - // Parse is deprecated, use ParseWithParams() instead. + // Parse is deprecated, use ParseWithParams() or ParseWithParamsInternal() instead. Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) - // ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements. + // ExecuteInternal is a helper around ParseWithParamsInternal() and ExecuteStmt(). It is not allowed to execute multiple statements. ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error) String() string // String is used to debug. CommitTxn(context.Context) error @@ -1153,7 +1153,7 @@ func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet, allo // getTableValue executes restricted sql and the result is one column. // It returns a string value. func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { - stmt, err := s.ParseWithParams(ctx, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) + stmt, err := s.ParseWithParamsInternal(ctx, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) if err != nil { return "", err } @@ -1175,7 +1175,7 @@ func (s *session) getTableValue(ctx context.Context, tblName string, varName str // replaceGlobalVariablesTableValue executes restricted sql updates the variable value // It will then notify the etcd channel that the value has changed. func (s *session) replaceGlobalVariablesTableValue(ctx context.Context, varName, val string) error { - stmt, err := s.ParseWithParams(ctx, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) + stmt, err := s.ParseWithParamsInternal(ctx, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) if err != nil { return err } @@ -1250,7 +1250,7 @@ func (s *session) SetGlobalSysVarOnly(name, value string) (err error) { // SetTiDBTableValue implements GlobalVarAccessor.SetTiDBTableValue interface. func (s *session) SetTiDBTableValue(name, value, comment string) error { - stmt, err := s.ParseWithParams(context.TODO(), `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) + stmt, err := s.ParseWithParamsInternal(context.TODO(), `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) if err != nil { return err } @@ -1520,6 +1520,16 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter return stmts[0], nil } +// ParseWithParamsInternal is same as ParseWithParams except set `s.sessionVars.InRestrictedSQL = true` +func (s *session) ParseWithParamsInternal(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) { + origin := s.sessionVars.InRestrictedSQL + s.sessionVars.InRestrictedSQL = true + defer func() { + s.sessionVars.InRestrictedSQL = origin + }() + return s.ParseWithParams(ctx, sql, args...) +} + // ExecRestrictedStmt implements RestrictedSQLExecutor interface. func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( []chunk.Row, []*ast.ResultField, error) { diff --git a/session/session_test.go b/session/session_test.go index 7b5febe0d18e0..4602758eeaa28 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -4377,22 +4377,17 @@ func (s *testSessionSerialSuite) TestProcessInfoIssue22068(c *C) { wg.Wait() } -func (s *testSessionSerialSuite) TestParseWithParams(c *C) { +func (s *testSessionSerialSuite) TestParseWithParamsInternal(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) se := tk.Se exec := se.(sqlexec.RestrictedSQLExecutor) // test compatibility with ExcuteInternal - origin := se.GetSessionVars().InRestrictedSQL - se.GetSessionVars().InRestrictedSQL = true - defer func() { - se.GetSessionVars().InRestrictedSQL = origin - }() - _, err := exec.ParseWithParams(context.TODO(), "SELECT 4") + _, err := exec.ParseWithParamsInternal(context.TODO(), "SELECT 4") c.Assert(err, IsNil) // test charset attack - stmt, err := exec.ParseWithParams(context.TODO(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") + stmt, err := exec.ParseWithParamsInternal(context.TODO(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") c.Assert(err, IsNil) var sb strings.Builder @@ -4402,15 +4397,15 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1") // test invalid sql - _, err = exec.ParseWithParams(context.TODO(), "SELECT") + _, err = exec.ParseWithParamsInternal(context.TODO(), "SELECT") c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*") // test invalid arguments to escape - _, err = exec.ParseWithParams(context.TODO(), "SELECT %?, %?", 3) + _, err = exec.ParseWithParamsInternal(context.TODO(), "SELECT %?, %?", 3) c.Assert(err, ErrorMatches, "missing arguments.*") // test noescape - stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3") + stmt, err = exec.ParseWithParamsInternal(context.TODO(), "SELECT 3") c.Assert(err, IsNil) sb.Reset() diff --git a/session/tidb_test.go b/session/tidb_test.go index 661a8f19d5f47..759eaa02702f4 100644 --- a/session/tidb_test.go +++ b/session/tidb_test.go @@ -37,7 +37,7 @@ func TestSysSessionPoolGoroutineLeak(t *testing.T) { count := 200 stmts := make([]ast.StmtNode, count) for i := 0; i < count; i++ { - stmt, err := se.ParseWithParams(context.Background(), "select * from mysql.user limit 1") + stmt, err := se.ParseWithParamsInternal(context.Background(), "select * from mysql.user limit 1") require.NoError(t, err) stmts[i] = stmt } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 0889d00e431e5..50876ccdf8661 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -128,7 +128,7 @@ func (h *Handle) withRestrictedSQLExecutor(ctx context.Context, fn func(context. func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, sql, params...) + stmt, err := exec.ParseWithParamsInternal(ctx, sql, params...) if err != nil { return nil, nil, errors.Trace(err) } @@ -138,7 +138,7 @@ func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...in func (h *Handle) execRestrictedSQLWithStatsVer(ctx context.Context, statsVer int, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, sql, params...) + stmt, err := exec.ParseWithParamsInternal(ctx, sql, params...) // TODO: An ugly way to set @@tidb_partition_prune_mode. Need to be improved. if _, ok := stmt.(*ast.AnalyzeTableStmt); ok { pruneMode := h.CurrentPruneMode() @@ -155,7 +155,7 @@ func (h *Handle) execRestrictedSQLWithStatsVer(ctx context.Context, statsVer int func (h *Handle) execRestrictedSQLWithSnapshot(ctx context.Context, sql string, snapshot uint64, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, sql, params...) + stmt, err := exec.ParseWithParamsInternal(ctx, sql, params...) if err != nil { return nil, nil, errors.Trace(err) } @@ -1385,7 +1385,7 @@ type statsReader struct { func (sr *statsReader) read(sql string, args ...interface{}) (rows []chunk.Row, fields []*ast.ResultField, err error) { ctx := context.TODO() - stmt, err := sr.ctx.ParseWithParams(ctx, sql, args...) + stmt, err := sr.ctx.ParseWithParamsInternal(ctx, sql, args...) if err != nil { return nil, nil, errors.Trace(err) } diff --git a/telemetry/data_cluster_hardware.go b/telemetry/data_cluster_hardware.go index 611ae2e005384..3c2dca4928d78 100644 --- a/telemetry/data_cluster_hardware.go +++ b/telemetry/data_cluster_hardware.go @@ -69,7 +69,7 @@ func normalizeFieldName(name string) string { func getClusterHardware(ctx sessionctx.Context) ([]*clusterHardwareItem, error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_cluster_info.go b/telemetry/data_cluster_info.go index a1569c3e67634..d1d645a3803d1 100644 --- a/telemetry/data_cluster_info.go +++ b/telemetry/data_cluster_info.go @@ -37,7 +37,7 @@ type clusterInfoItem struct { func getClusterInfo(ctx sessionctx.Context) ([]*clusterInfoItem, error) { // Explicitly list all field names instead of using `*` to avoid potential leaking sensitive info when adding new fields in future. exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) + stmt, err := exec.ParseWithParamsInternal(context.TODO(), `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_feature_usage.go b/telemetry/data_feature_usage.go index cf32510097423..6255eeb2ec9df 100644 --- a/telemetry/data_feature_usage.go +++ b/telemetry/data_feature_usage.go @@ -77,7 +77,7 @@ func getClusterIndexUsageInfo(ctx sessionctx.Context) (cu *ClusterIndexUsage, er exec := ctx.(sqlexec.RestrictedSQLExecutor) // query INFORMATION_SCHEMA.tables to get the latest table information about ClusterIndex - stmt, err := exec.ParseWithParams(context.TODO(), ` + stmt, err := exec.ParseWithParamsInternal(context.TODO(), ` SELECT left(sha2(TABLE_NAME, 256), 6) table_name_hash, TIDB_PK_TYPE, TABLE_SCHEMA, TABLE_NAME FROM information_schema.tables WHERE table_schema not in ('INFORMATION_SCHEMA', 'METRICS_SCHEMA', 'PERFORMANCE_SCHEMA', 'mysql') diff --git a/util/admin/admin.go b/util/admin/admin.go index 3f68393f52833..9b6bb8c5168ce 100644 --- a/util/admin/admin.go +++ b/util/admin/admin.go @@ -328,7 +328,7 @@ func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices }() // Add `` for some names like `table name`. exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) + stmt, err := exec.ParseWithParamsInternal(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) if err != nil { return 0, 0, errors.Trace(err) } @@ -350,7 +350,7 @@ func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices return 0, 0, errors.Trace(err) } for i, idx := range indices { - stmt, err := exec.ParseWithParams(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) + stmt, err := exec.ParseWithParamsInternal(context.Background(), "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) if err != nil { return 0, i, errors.Trace(err) } diff --git a/util/gcutil/gcutil.go b/util/gcutil/gcutil.go index 9216d37b1ad9a..c11a6ca66d996 100644 --- a/util/gcutil/gcutil.go +++ b/util/gcutil/gcutil.go @@ -72,7 +72,7 @@ func ValidateSnapshotWithGCSafePoint(snapshotTS, safePointTS uint64) error { // GetGCSafePoint loads GC safe point time from mysql.tidb. func GetGCSafePoint(ctx sessionctx.Context) (uint64, error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point") + stmt, err := exec.ParseWithParamsInternal(context.Background(), selectVariableValueSQL, "tikv_gc_safe_point") if err != nil { return 0, errors.Trace(err) } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 4be2cae5ce12a..165eae7e8c452 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -46,6 +46,8 @@ type RestrictedSQLExecutor interface { // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. // This function only saves you from processing potentially unsafe parameters. ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) + // ParseWithParamsInternal is same as ParseWithParams except set `s.sessionVars.InRestrictedSQL = true` + ParseWithParamsInternal(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) // ExecRestrictedStmt run sql statement in ctx with some restriction. ExecRestrictedStmt(ctx context.Context, stmt ast.StmtNode, opts ...OptionFuncAlias) ([]chunk.Row, []*ast.ResultField, error) }