diff --git a/executor/update_test.go b/executor/update_test.go index 164b1f5eef45f..7169c5e790cf1 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -87,6 +87,41 @@ func (s *testUpdateSuite) TestUpdateGenColInTxn(c *C) { `1 2`)) } +func (s *testUpdateSuite) TestSafeUpdates(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int, c int, d int, key(a), key(b, c))") + + c.Assert(tk.ExecToErr("delete from t"), IsNil) + c.Assert(tk.ExecToErr("update t set a=2"), IsNil) + + safeUpdatesErr := func(err error) { + c.Assert(err, Not(IsNil)) + c.Assert(err.Error(), Equals, "[planner:1175]You are using safe update mode and you tried to update a table without a WHERE that uses a KEY column") + } + + tk.MustExec("set @@sql_safe_updates = 1") + safeUpdatesErr(tk.ExecToErr("delete from t")) + safeUpdatesErr(tk.ExecToErr("delete from t where 1=1")) + safeUpdatesErr(tk.ExecToErr("delete from t where c=1")) + safeUpdatesErr(tk.ExecToErr("delete from t where d=1")) + safeUpdatesErr(tk.ExecToErr("delete from t where a=1 or 1=1")) + safeUpdatesErr(tk.ExecToErr("update t set a=2")) + safeUpdatesErr(tk.ExecToErr("update t set a=2 where c=1")) + safeUpdatesErr(tk.ExecToErr("update t set a=2 where d=1")) + safeUpdatesErr(tk.ExecToErr("update t set a=2 where a=1 or 1=1")) + + tk.MustExec("delete from t where a=1") + tk.MustExec("delete from t where b=1") + tk.MustExec("delete from t where b>1") + tk.MustExec("delete from t where b=1 and c>0") + tk.MustExec("update t set a=2 where a=1") + tk.MustExec("update t set a=2 where b=1") + tk.MustExec("update t set a=2 where b>1") + tk.MustExec("update t set a=2 where b=1 and c>0") +} + func (s *testUpdateSuite) TestUpdateWithAutoidSchema(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test`) diff --git a/planner/core/errors.go b/planner/core/errors.go index 2794872c23f61..defe54f16c84c 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -80,9 +80,11 @@ var ( ErrCartesianProductUnsupported = dbterror.ClassOptimizer.NewStd(mysql.ErrCartesianProductUnsupported) ErrStmtNotFound = dbterror.ClassOptimizer.NewStd(mysql.ErrPreparedStmtNotFound) ErrAmbiguous = dbterror.ClassOptimizer.NewStd(mysql.ErrNonUniq) + ErrUpdateWithoutKeyInSafeMode = dbterror.ClassOptimizer.NewStd(mysql.ErrUpdateWithoutKeyInSafeMode) ErrUnresolvedHintName = dbterror.ClassOptimizer.NewStd(mysql.ErrUnresolvedHintName) ErrNotHintUpdatable = dbterror.ClassOptimizer.NewStd(mysql.ErrNotHintUpdatable) ErrWarnConflictingHint = dbterror.ClassOptimizer.NewStd(mysql.ErrWarnConflictingHint) + // Since we cannot know if user logged in with a password, use message of ErrAccessDeniedNoPassword instead ErrAccessDenied = dbterror.ClassOptimizer.NewStdErr(mysql.ErrAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDeniedNoPassword], "", "") ) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index e23666ffa8523..4a28bac16cf33 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -655,6 +655,10 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter t = invalidTask candidates := ds.skylinePruning(prop) + svars := ds.ctx.GetSessionVars() + safeUpdateMode := (svars.StmtCtx.InDeleteStmt || svars.StmtCtx.InUpdateStmt) && svars.EnableSafeUpdates + numRangeScanPlans := 0 + cntPlan = 0 for _, candidate := range candidates { path := candidate.path @@ -717,6 +721,9 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter pointGetTask = ds.convertToBatchPointGet(prop, candidate) } if !pointGetTask.invalid() { + if safeUpdateMode && notFullScanPlan(pointGetTask.plan()) { + numRangeScanPlans++ + } cntPlan += 1 planCounter.Dec(1) } @@ -741,6 +748,9 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter return nil, 0, err } if !tblTask.invalid() { + if safeUpdateMode && notFullScanPlan(tblTask.plan()) { + numRangeScanPlans++ + } cntPlan += 1 planCounter.Dec(1) } @@ -761,6 +771,9 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter return nil, 0, err } if !idxTask.invalid() { + if safeUpdateMode && notFullScanPlan(idxTask.plan()) { + numRangeScanPlans++ + } cntPlan += 1 planCounter.Dec(1) } @@ -772,9 +785,29 @@ func (ds *DataSource) findBestTask(prop *property.PhysicalProperty, planCounter } } + if safeUpdateMode && numRangeScanPlans == 0 { + return nil, 0, ErrUpdateWithoutKeyInSafeMode + } return } +func notFullScanPlan(p PhysicalPlan) bool { + if p == nil { + return false + } + switch scan := p.(type) { + case *PhysicalTableScan: + return !scan.isFullScan() + case *PhysicalIndexScan: + return !scan.isFullScan() + case *PhysicalIndexLookUpReader: + return notFullScanPlan(scan.indexPlan) + case *PointGetPlan, *BatchPointGetPlan: + return true + } + return false +} + func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, candidate *candidatePath) (task task, err error) { if prop.TaskTp != property.RootTaskType || !prop.IsEmpty() { return invalidTask, nil diff --git a/session/session.go b/session/session.go index 4ad2170ebbff1..43ec580e45bb3 100644 --- a/session/session.go +++ b/session/session.go @@ -2099,6 +2099,7 @@ var builtinGlobalVariable = []string{ variable.InnodbLockWaitTimeout, variable.WindowingUseHighPrecision, variable.SQLSelectLimit, + variable.SQLSafeUpdates, /* TiDB specific global variables: */ variable.TiDBSkipASCIICheck, diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 60e7185352727..eaf712acb526b 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -731,6 +731,9 @@ type SessionVars struct { // PartitionPruneMode indicates how and when to prune partitions. PartitionPruneMode atomic2.String + + // EnableSafeUpdates indicates if safe-update mode is enabled. + EnableSafeUpdates bool } // UseDynamicPartitionPrune indicates whether use new dynamic partition prune. @@ -1523,6 +1526,8 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { s.EnableChangeColumnType = TiDBOptOn(val) case TiDBEnableAmendPessimisticTxn: s.EnableAmendPessimisticTxn = TiDBOptOn(val) + case SQLSafeUpdates: + s.EnableSafeUpdates = TiDBOptOn(val) } s.systems[name] = val return nil