From e26875e88c941d8cd7efd10593b161ab130f1779 Mon Sep 17 00:00:00 2001 From: xia Date: Thu, 22 Oct 2015 14:43:09 +0800 Subject: [PATCH 01/15] parser: add parse replace statment --- parser/parser.y | 43 ++++++++++++++++++++++++++++++++++++++++--- parser/scanner.l | 2 ++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/parser/parser.y b/parser/parser.y index ef5ed414e8644..50aa5efd782a9 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -217,6 +217,7 @@ import ( references "REFERENCES" regexp "REGEXP" repeat "REPEAT" + replace "REPLACE" right "RIGHT" rlike "RLIKE" rollback "ROLLBACK" @@ -432,7 +433,7 @@ import ( IndexName "index name" IndexType "index type" InsertIntoStmt "INSERT INTO statement" - InsertRest "Rest part of INSERT INTO statement" + InsertRest "Rest part of INSERT/REPLACE INTO statement" IntoOpt "INTO or EmptyString" JoinTable "join table" JoinType "join type" @@ -474,6 +475,8 @@ import ( PrivType "Privilege type" ReferDef "Reference definition" RegexpSym "REGEXP or RLIKE" + ReplaceIntoStmt "REPLACE INTO statement" + ReplacePriority "replace statement priority" RollbackStmt "ROLLBACK statement" SelectLockOpt "FOR UPDATE or LOCK IN SHARE MODE," SelectStmt "SELECT statement" @@ -1823,7 +1826,40 @@ OnDuplicateKeyUpdate: $$ = $5 } -/***********************************Insert Statments END************************************/ +/************************************************************************************ + * + * Replace Statments + * + * TODO: support PARTITION + **********************************************************************************/ +ReplaceIntoStmt: + "REPLACE" ReplacePriority IntoOpt TableIdent InsertRest + { + + x := &stmts.ReplaceIntoStmt{ColNames: $5.(*stmts.InsertIntoStmt).ColNames, + Lists: $5.(*stmts.InsertIntoStmt).Lists, + Sel: $5.(*stmts.InsertIntoStmt).Sel, + SetList: $5.(*stmts.InsertIntoStmt).Setlist} + x.Priority = $2.(int) + x.TableIdent = $4.(table.Ident) + $$ = x + if yylex.(*lexer).root { + break + } + } + +ReplacePriority: + { + $$ = stmts.NoPriority + } +| "LOW_PRIORITY" + { + $$ = stmts.LowPriority + } +| "DELAYED" + { + $$ = stmts.DelayedPriority + } Literal: "false" @@ -2776,7 +2812,6 @@ PrimaryFactor: } | PrimaryExpression - Priority: { $$ = stmts.NoPriority @@ -3533,6 +3568,7 @@ Statement: | InsertIntoStmt | PreparedStmt | RollbackStmt +| ReplaceIntoStmt | SelectStmt | SetStmt | ShowStmt @@ -3552,6 +3588,7 @@ ExplainableStmt: | DeleteFromStmt | UpdateStmt | InsertIntoStmt +| ReplaceIntoStmt StatementList: Statement diff --git a/parser/scanner.l b/parser/scanner.l index 8ce33d9011dd3..bc4a35790f7ac 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -373,6 +373,7 @@ rand {r}{a}{n}{d} repeat {r}{e}{p}{e}{a}{t} references {r}{e}{f}{e}{r}{e}{n}{c}{e}{s} regexp {r}{e}{g}{e}{x}{p} +replace {r}{e}{p}{l}{a}{c}{e} right {r}{i}{g}{h}{t} rlike {r}{l}{i}{k}{e} rollback {r}{o}{l}{l}{b}{a}{c}{k} @@ -802,6 +803,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} return repeat {regexp} return regexp {references} return references +{replace} return replace {rlike} return rlike {sys_var} lval.item = string(l.val) From 5e9042b47afa8f79671e8e961bcc1b5207af710c Mon Sep 17 00:00:00 2001 From: xia Date: Thu, 22 Oct 2015 14:45:54 +0800 Subject: [PATCH 02/15] parser: add test of parsing replace statment --- parser/parser_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/parser/parser_test.go b/parser/parser_test.go index 65b890047c864..09663ec0ea8de 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -160,6 +160,17 @@ func (s *testParserSuite) TestDMLStmt(c *C) { {"INSERT INTO foo (a,b,) VALUES (42,314,)", false}, {"INSERT INTO foo () VALUES ()", true}, {"INSERT INTO foo VALUE ()", true}, + + {"REPLACE INTO foo VALUES (1 || 2)", true}, + {"REPLACE INTO foo VALUES (1 | 2)", true}, + {"REPLACE INTO foo VALUES (false || true)", true}, + {"REPLACE INTO foo VALUES (bar(5678))", false}, + {"REPLACE INTO foo VALUES ()", true}, + {"REPLACE INTO foo (a,b) VALUES (42,314)", true}, + {"REPLACE INTO foo (a,b,) VALUES (42,314)", false}, + {"REPLACE INTO foo (a,b,) VALUES (42,314,)", false}, + {"REPLACE INTO foo () VALUES ()", true}, + {"REPLACE INTO foo VALUE ()", true}, // 40 {`SELECT stuff.id FROM stuff From 3447ac2f554a1fb113eabc048497637b3e32346e Mon Sep 17 00:00:00 2001 From: xia Date: Thu, 22 Oct 2015 17:57:59 +0800 Subject: [PATCH 03/15] stmt: add replace statment --- stmt/stmts/replace.go | 99 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 stmt/stmts/replace.go diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go new file mode 100644 index 0000000000000..0494f9e72244c --- /dev/null +++ b/stmt/stmts/replace.go @@ -0,0 +1,99 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package stmts + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/plan" + "github.com/pingcap/tidb/rset" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/util/format" +) + +// ReplaceIntoStmt is a statement works exactly like insert, except that if +// an old row in the table has the same value as a new row for a PRIMARY KEY or a UNIQUE index, +// the old row is deleted before the new row is inserted. +// See: https://dev.mysql.com/doc/refman/5.7/en/replace.html +type ReplaceIntoStmt struct { + ColNames []string + Lists [][]expression.Expression + Sel plan.Planner + TableIdent table.Ident + SetList []*expression.Assignment + Priority int + + Text string +} + +// Explain implements the stmt.Statement Explain interface. +func (s *ReplaceIntoStmt) Explain(ctx context.Context, w format.Formatter) { + w.Format("%s\n", s.Text) +} + +// IsDDL implements the stmt.Statement IsDDL interface. +func (s *ReplaceIntoStmt) IsDDL() bool { + return false +} + +// OriginText implements the stmt.Statement OriginText interface. +func (s *ReplaceIntoStmt) OriginText() string { + return s.Text +} + +// SetText implements the stmt.Statement SetText interface. +func (s *ReplaceIntoStmt) SetText(text string) { + s.Text = text +} + +// Exec implements the stmt.Statement Exec interface. +func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { + stmt := &InsertIntoStmt{ColNames: s.ColNames, Lists: s.Lists, Priority: s.Priority, + Sel: s.Sel, Setlist: s.SetList, TableIdent: s.TableIdent, Text: s.Text} + t, err := getTable(ctx, stmt.TableIdent) + if err != nil { + return nil, errors.Trace(err) + } + cols, err := stmt.getColumns(t.Cols()) + if err != nil { + return nil, errors.Trace(err) + } + + // Process `replace ... (select ...)` + if stmt.Sel != nil { + return stmt.execSelect(t, cols, ctx) + } + // Process `replace ... set x=y ...` + if err = stmt.getSetList(); err != nil { + return nil, errors.Trace(err) + } + m, err := stmt.getDefaultValues(ctx, t.Cols()) + if err != nil { + return nil, errors.Trace(err) + } + replaceValueCount := len(stmt.Lists[0]) + + for i, list := range stmt.Lists { + if err = stmt.checkValueCount(replaceValueCount, len(list), i, cols); err != nil { + return nil, errors.Trace(err) + } + // TODO: remove that has the same value as a new row for a PRIMARY KEY or a UNIQUE index. + if _, _, err = stmt.addRecord(ctx, t, cols, list, m); err != nil { + return nil, errors.Trace(err) + } + } + + return nil, nil +} From c5cb36a7090466237ad0ab169dca4485f0208714 Mon Sep 17 00:00:00 2001 From: xia Date: Thu, 22 Oct 2015 17:58:38 +0800 Subject: [PATCH 04/15] stmt: add replace statment test --- stmt/stmts/replace_test.go | 88 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 stmt/stmts/replace_test.go diff --git a/stmt/stmts/replace_test.go b/stmt/stmts/replace_test.go new file mode 100644 index 0000000000000..f622b68d77a1c --- /dev/null +++ b/stmt/stmts/replace_test.go @@ -0,0 +1,88 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package stmts_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb" + "github.com/pingcap/tidb/stmt/stmts" +) + +func (s *testStmtSuite) TestReplace(c *C) { + testSQL := `drop table if exists replace_test; + create table replace_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 int, c3 int default 1); + replace replace_test (c1) values (1),(2),(NULL);` + mustExec(c, s.testDB, testSQL) + + errReplaceSelectSQL := `replace replace_test (c1) values ();` + tx := mustBegin(c, s.testDB) + _, err := tx.Exec(errReplaceSelectSQL) + c.Assert(err, NotNil) + tx.Rollback() + + errReplaceSelectSQL = `replace replace_test (c1, c2) values (1,2),(1);` + tx = mustBegin(c, s.testDB) + _, err = tx.Exec(errReplaceSelectSQL) + c.Assert(err, NotNil) + tx.Rollback() + + errReplaceSelectSQL = `replace replace_test (xxx) values (3);` + tx = mustBegin(c, s.testDB) + _, err = tx.Exec(errReplaceSelectSQL) + c.Assert(err, NotNil) + tx.Rollback() + + errReplaceSelectSQL = `replace replace_test_xxx (c1) values ();` + tx = mustBegin(c, s.testDB) + _, err = tx.Exec(errReplaceSelectSQL) + c.Assert(err, NotNil) + tx.Rollback() + + replaceSetSQL := `replace replace_test set c1 = 3;` + mustExec(c, s.testDB, replaceSetSQL) + + stmtList, err := tidb.Compile(replaceSetSQL) + c.Assert(err, IsNil) + c.Assert(stmtList, HasLen, 1) + + testStmt, ok := stmtList[0].(*stmts.ReplaceIntoStmt) + c.Assert(ok, IsTrue) + c.Assert(testStmt.IsDDL(), IsFalse) + c.Assert(len(testStmt.OriginText()), Greater, 0) + + errReplaceSelectSQL = `replace replace_test set c1 = 4, c1 = 5;` + tx = mustBegin(c, s.testDB) + _, err = tx.Exec(errReplaceSelectSQL) + c.Assert(err, NotNil) + tx.Rollback() + + errReplaceSelectSQL = `replace replace_test set xxx = 6;` + tx = mustBegin(c, s.testDB) + _, err = tx.Exec(errReplaceSelectSQL) + c.Assert(err, NotNil) + tx.Rollback() + + replaceSelectSQL := `create table replace_test_1 (id int, c1 int); replace replace_test_1 select id, c1 from replace_test;` + mustExec(c, s.testDB, replaceSelectSQL) + + replaceSelectSQL = `create table replace_test_2 (id int, c1 int); + replace replace_test_1 select id, c1 from replace_test union select id * 10, c1 * 10 from replace_test;` + mustExec(c, s.testDB, replaceSelectSQL) + + errReplaceSelectSQL = `replace replace_test_1 select c1 from replace_test;` + tx = mustBegin(c, s.testDB) + _, err = tx.Exec(errReplaceSelectSQL) + c.Assert(err, NotNil) + tx.Rollback() +} From b9f0bbe088296edfea90bcae8cd69daec22a4759 Mon Sep 17 00:00:00 2001 From: xia Date: Thu, 22 Oct 2015 17:59:12 +0800 Subject: [PATCH 05/15] stmt: adjust the code format --- stmt/stmts/insert.go | 229 ++++++++++++++++++++++++------------------- 1 file changed, 128 insertions(+), 101 deletions(-) diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index 61251f56f8d5f..af936e0b04b36 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -109,7 +109,7 @@ func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx conte marked[cols[i].Offset] = struct{}{} } - if err = s.initDefaultValues(ctx, t, t.Cols(), data0, marked); err != nil { + if err = s.initDefaultValues(ctx, t, data0, marked); err != nil { return nil, errors.Trace(err) } @@ -185,15 +185,44 @@ func (s *InsertIntoStmt) getColumns(tableCols []*column.Col) ([]*column.Col, err return cols, nil } +func (s *InsertIntoStmt) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) { + m := map[interface{}]interface{}{} + for _, v := range cols { + if value, ok, err := getDefaultValue(ctx, v); ok { + if err != nil { + return nil, errors.Trace(err) + } + + m[v.Name.L] = value + } + } + + return m, nil +} + +func (s *InsertIntoStmt) getSetList() error { + if len(s.Setlist) > 0 { + if len(s.Lists) > 0 { + return errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) + } + + var l []expression.Expression + for _, v := range s.Setlist { + l = append(l, v.Expr) + } + s.Lists = append(s.Lists, l) + } + + return nil +} + // Exec implements the stmt.Statement Exec interface. func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { t, err := getTable(ctx, s.TableIdent) if err != nil { return nil, errors.Trace(err) } - - tableCols := t.Cols() - cols, err := s.getColumns(tableCols) + cols, err := s.getColumns(t.Cols()) if err != nil { return nil, errors.Trace(err) } @@ -204,123 +233,121 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) } // Process `insert ... set x=y...` - if len(s.Setlist) > 0 { - if len(s.Lists) > 0 { - return nil, errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) - } - - var l []expression.Expression - for _, v := range s.Setlist { - l = append(l, v.Expr) - } - s.Lists = append(s.Lists, l) + if err = s.getSetList(); err != nil { + return nil, errors.Trace(err) } - m := map[interface{}]interface{}{} - for _, v := range tableCols { - var ( - value interface{} - ok bool - ) - value, ok, err = getDefaultValue(ctx, v) - if ok { - if err != nil { - return nil, errors.Trace(err) - } - - m[v.Name.L] = value - } + m, err := s.getDefaultValues(ctx, t.Cols()) + if err != nil { + return nil, errors.Trace(err) } - insertValueCount := len(s.Lists[0]) - toUpdateColumns, err0 := getOnDuplicateUpdateColumns(s.OnDuplicate, t) - if err0 != nil { - return nil, errors.Trace(err0) + toUpdateColumns, err := getOnDuplicateUpdateColumns(s.OnDuplicate, t) + if err != nil { + return nil, errors.Trace(err) } - toUpdateArgs := map[interface{}]interface{}{} for i, list := range s.Lists { - r := make([]interface{}, len(tableCols)) - valueCount := len(list) - - if insertValueCount != valueCount { - // "insert into t values (), ()" is valid. - // "insert into t values (), (1)" is not valid. - // "insert into t values (1), ()" is not valid. - // "insert into t values (1,2), (1)" is not valid. - // So the value count must be same for all insert list. - return nil, errors.Errorf("Column count doesn't match value count at row %d", i+1) - } - - if valueCount == 0 && len(s.ColNames) > 0 { - // "insert into t (c1) values ()" is not valid. - return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0) - } else if valueCount > 0 && valueCount != len(cols) { - return nil, errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount) - } - - // Clear last insert id. - variable.GetSessionVars(ctx).SetLastInsertID(0) - - marked := make(map[int]struct{}, len(list)) - for i, expr := range list { - // For "insert into t values (default)" Default Eval. - m[expression.ExprEvalDefaultName] = cols[i].Name.O - - val, evalErr := expr.Eval(ctx, m) - if evalErr != nil { - return nil, errors.Trace(evalErr) - } - r[cols[i].Offset] = val - marked[cols[i].Offset] = struct{}{} - } - - if err := s.initDefaultValues(ctx, t, tableCols, r, marked); err != nil { - return nil, errors.Trace(err) - } - - if err = column.CastValues(ctx, r, cols); err != nil { - return nil, errors.Trace(err) - } - if err = column.CheckNotNull(tableCols, r); err != nil { + if err = s.checkValueCount(insertValueCount, len(list), i, cols); err != nil { return nil, errors.Trace(err) } - - // Notes: incompatible with mysql - // MySQL will set last insert id to the first row, as follows: - // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` - // `insert t (c1) values(1),(2),(3);` - // Last insert id will be 1, not 3. - h, err := t.AddRecord(ctx, r) + row, h, err := s.addRecord(ctx, t, cols, list, m) if err == nil { continue } - if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { + + if h == -1 || len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } - // On duplicate key Update the duplicate row. - // Evaluate the updated value. - // TODO: report rows affected and last insert id. - data, err := t.Row(ctx, h) - if err != nil { + if err = execOnDuplicateUpdate(ctx, t, row, h, toUpdateColumns); err != nil { return nil, errors.Trace(err) } + } - toUpdateArgs[expression.ExprEvalValuesFunc] = func(name string) (interface{}, error) { - c, err1 := findColumnByName(t, name) - if err1 != nil { - return nil, errors.Trace(err1) - } - return r[c.Offset], nil + return nil, nil +} + +func (s *InsertIntoStmt) checkValueCount(insertValueCount, valueCount, num int, cols []*column.Col) error { + if insertValueCount != valueCount { + // "insert into t values (), ()" is valid. + // "insert into t values (), (1)" is not valid. + // "insert into t values (1), ()" is not valid. + // "insert into t values (1,2), (1)" is not valid. + // So the value count must be same for all insert list. + return errors.Errorf("Column count doesn't match value count at row %d", num+1) + } + if valueCount == 0 && len(s.ColNames) > 0 { + // "insert into t (c1) values ()" is not valid. + return errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(s.ColNames), 0) + } else if valueCount > 0 && valueCount != len(cols) { + return errors.Errorf("INSERT INTO %s: expected %d value(s), have %d", s.TableIdent, len(cols), valueCount) + } + + return nil +} + +func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, int64, error) { + r := make([]interface{}, len(t.Cols())) + marked := make(map[int]struct{}, len(list)) + for i, expr := range list { + // For "insert into t values (default)" Default Eval. + m[expression.ExprEvalDefaultName] = cols[i].Name.O + + val, evalErr := expr.Eval(ctx, m) + if evalErr != nil { + return nil, -1, errors.Trace(evalErr) } + r[cols[i].Offset] = val + marked[cols[i].Offset] = struct{}{} + } - err = updateRecord(ctx, h, data, t, toUpdateColumns, toUpdateArgs, 0, true) - if err != nil { - return nil, errors.Trace(err) + // Clear last insert id. + variable.GetSessionVars(ctx).SetLastInsertID(0) + + err := s.initDefaultValues(ctx, t, r, marked) + if err != nil { + return nil, -1, errors.Trace(err) + } + if err = column.CastValues(ctx, r, cols); err != nil { + return nil, -1, errors.Trace(err) + } + if err = column.CheckNotNull(t.Cols(), r); err != nil { + return nil, -1, errors.Trace(err) + } + + // Notes: incompatible with mysql + // MySQL will set last insert id to the first row, as follows: + // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` + // `insert t (c1) values(1),(2),(3);` + // Last insert id will be 1, not 3. + h, err := t.AddRecord(ctx, r) + + return r, h, err +} + +func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]expression.Assignment) error { + // On duplicate key update the duplicate row. + // Evaluate the updated value. + // TODO: report rows affected and last insert id. + data, err := t.Row(ctx, h) + if err != nil { + return errors.Trace(err) + } + + toUpdateArgs := map[interface{}]interface{}{} + toUpdateArgs[expression.ExprEvalValuesFunc] = func(name string) (interface{}, error) { + c, err1 := findColumnByName(t, name) + if err1 != nil { + return nil, errors.Trace(err1) } + return row[c.Offset], nil } - return nil, nil + if err = updateRecord(ctx, h, data, t, cols, toUpdateArgs, 0, true); err != nil { + return errors.Trace(err) + } + + return nil } func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Table) (map[int]expression.Assignment, error) { @@ -336,10 +363,10 @@ func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Tab return m, nil } -func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, cols []*column.Col, row []interface{}, marked map[int]struct{}) error { +func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) error { var err error var defaultValueCols []*column.Col - for i, c := range cols { + for i, c := range t.Cols() { if row[i] != nil { // Column value is not nil, continue. continue From 79ecef38cbeb49114cd4206f55ce57772f24d3aa Mon Sep 17 00:00:00 2001 From: xia Date: Fri, 23 Oct 2015 17:16:20 +0800 Subject: [PATCH 06/15] stmt: rename addRecord to getRow --- stmt/stmts/insert.go | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index af936e0b04b36..0a10bad63d0b4 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -251,12 +251,23 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) if err = s.checkValueCount(insertValueCount, len(list), i, cols); err != nil { return nil, errors.Trace(err) } - row, h, err := s.addRecord(ctx, t, cols, list, m) + + row, err := s.getRow(ctx, t, cols, list, m) + if err != nil { + return nil, errors.Trace(err) + } + + // Notes: incompatible with mysql + // MySQL will set last insert id to the first row, as follows: + // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` + // `insert t (c1) values(1),(2),(3);` + // Last insert id will be 1, not 3. + h, err := t.AddRecord(ctx, row) if err == nil { continue } - if h == -1 || len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { + if len(s.OnDuplicate) == 0 || !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } if err = execOnDuplicateUpdate(ctx, t, row, h, toUpdateColumns); err != nil { @@ -286,7 +297,7 @@ func (s *InsertIntoStmt) checkValueCount(insertValueCount, valueCount, num int, return nil } -func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, int64, error) { +func (s *InsertIntoStmt) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) { r := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(list)) for i, expr := range list { @@ -295,7 +306,7 @@ func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*c val, evalErr := expr.Eval(ctx, m) if evalErr != nil { - return nil, -1, errors.Trace(evalErr) + return nil, errors.Trace(evalErr) } r[cols[i].Offset] = val marked[cols[i].Offset] = struct{}{} @@ -306,23 +317,16 @@ func (s *InsertIntoStmt) addRecord(ctx context.Context, t table.Table, cols []*c err := s.initDefaultValues(ctx, t, r, marked) if err != nil { - return nil, -1, errors.Trace(err) + return nil, errors.Trace(err) } if err = column.CastValues(ctx, r, cols); err != nil { - return nil, -1, errors.Trace(err) + return nil, errors.Trace(err) } if err = column.CheckNotNull(t.Cols(), r); err != nil { - return nil, -1, errors.Trace(err) + return nil, errors.Trace(err) } - // Notes: incompatible with mysql - // MySQL will set last insert id to the first row, as follows: - // `t(id int AUTO_INCREMENT, c1 int, PRIMARY KEY (id))` - // `insert t (c1) values(1),(2),(3);` - // Last insert id will be 1, not 3. - h, err := t.AddRecord(ctx, r) - - return r, h, err + return r, err } func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]expression.Assignment) error { From 84ae608915007184aea5f82762cef31ec9b91e5b Mon Sep 17 00:00:00 2001 From: xia Date: Fri, 23 Oct 2015 20:28:23 +0800 Subject: [PATCH 07/15] stmt: add removeExistRow --- stmt/stmts/replace.go | 106 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 101 insertions(+), 5 deletions(-) diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index 0494f9e72244c..486f79b8c06d8 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -14,13 +14,19 @@ package stmts import ( + "strings" + "github.com/juju/errors" + "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/format" + "github.com/pingcap/tidb/util/types" ) // ReplaceIntoStmt is a statement works exactly like insert, except that if @@ -32,7 +38,7 @@ type ReplaceIntoStmt struct { Lists [][]expression.Expression Sel plan.Planner TableIdent table.Ident - SetList []*expression.Assignment + Setlist []*expression.Assignment Priority int Text string @@ -61,7 +67,7 @@ func (s *ReplaceIntoStmt) SetText(text string) { // Exec implements the stmt.Statement Exec interface. func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { stmt := &InsertIntoStmt{ColNames: s.ColNames, Lists: s.Lists, Priority: s.Priority, - Sel: s.Sel, Setlist: s.SetList, TableIdent: s.TableIdent, Text: s.Text} + Sel: s.Sel, Setlist: s.Setlist, TableIdent: s.TableIdent, Text: s.Text} t, err := getTable(ctx, stmt.TableIdent) if err != nil { return nil, errors.Trace(err) @@ -76,7 +82,7 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error return stmt.execSelect(t, cols, ctx) } // Process `replace ... set x=y ...` - if err = stmt.getSetList(); err != nil { + if err = stmt.getSetlist(); err != nil { return nil, errors.Trace(err) } m, err := stmt.getDefaultValues(ctx, t.Cols()) @@ -89,11 +95,101 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error if err = stmt.checkValueCount(replaceValueCount, len(list), i, cols); err != nil { return nil, errors.Trace(err) } - // TODO: remove that has the same value as a new row for a PRIMARY KEY or a UNIQUE index. - if _, _, err = stmt.addRecord(ctx, t, cols, list, m); err != nil { + row, err := stmt.getRow(ctx, t, cols, list, m) + if err != nil { + return nil, errors.Trace(err) + } + + if err = removeExistRow(ctx, t, row); err != nil { + return nil, errors.Trace(err) + } + if _, err = t.AddRecord(ctx, row); err != nil { return nil, errors.Trace(err) } } return nil, nil } + +func removeExistRow(ctx context.Context, t table.Table, replaceRow []interface{}) error { + indices := make([]*column.IndexedCol, 0, len(t.Indices())) + for _, idx := range t.Indices() { + // TODO: handles the idx is composite primary key the affected rows. + if idx.Unique || idx.Primary && len(idx.Columns) >= 1 { + indices = append(indices, idx) + } + } + if len(indices) == 0 { + return nil + } + + txn, err := ctx.GetTxn(false) + if err != nil { + return errors.Trace(err) + } + it, err := txn.Seek([]byte(t.FirstKey())) + if err != nil { + return errors.Trace(err) + } + defer it.Close() + + prefix := t.KeyPrefix() + for it.Valid() && strings.HasPrefix(it.Key(), prefix) { + handle, err0 := util.DecodeHandleFromRowKey(it.Key()) + if err0 != nil { + return errors.Trace(err0) + } + row, err0 := t.Row(ctx, handle) + if err0 != nil { + return errors.Trace(err0) + } + for _, idx := range indices { + v, err1 := compareIndex(t.Cols(), idx, row, replaceRow) + if err1 != nil { + return errors.Trace(err1) + } + if v != 0 { + continue + } + if err = removeRow(ctx, t, handle, row); err != nil { + return errors.Trace(err) + } + } + + rk := t.RecordKey(handle, nil) + if it, err0 = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)); err0 != nil { + return errors.Trace(err0) + } + } + + return nil +} + +// compareIndex returns an integer comparing the idx old with new. +// old > new -> 1 +// old = new -> 0 +// old < new -> -1 +func compareIndex(cols []*column.Col, idx *column.IndexedCol, row, replaceRow []interface{}) (v int, err error) { + nulls := 0 + for _, idxCol := range idx.Columns { + col := column.FindCol(cols, idxCol.Name.L) + if col == nil { + return 0, errors.Errorf("No such column: %v", idx) + } + v, err = types.Compare(row[col.Offset], replaceRow[col.Offset]) + if err != nil { + return 0, errors.Trace(err) + } + if v != 0 { + break + } + if row[col.Offset] == nil { + nulls++ + } + } + if nulls == len(idx.Columns) { + v = -1 + } + + return v, nil +} From ccc0ffdfe939abc4fc85367d05a5077e10aeab45 Mon Sep 17 00:00:00 2001 From: xia Date: Fri, 23 Oct 2015 20:29:08 +0800 Subject: [PATCH 08/15] stmt: add replace primary key and unique index test --- stmt/stmts/replace_test.go | 65 +++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/stmt/stmts/replace_test.go b/stmt/stmts/replace_test.go index f622b68d77a1c..fed415c63685e 100644 --- a/stmt/stmts/replace_test.go +++ b/stmt/stmts/replace_test.go @@ -25,33 +25,32 @@ func (s *testStmtSuite) TestReplace(c *C) { replace replace_test (c1) values (1),(2),(NULL);` mustExec(c, s.testDB, testSQL) - errReplaceSelectSQL := `replace replace_test (c1) values ();` + errReplaceSQL := `replace replace_test (c1) values ();` tx := mustBegin(c, s.testDB) - _, err := tx.Exec(errReplaceSelectSQL) + _, err := tx.Exec(errReplaceSQL) c.Assert(err, NotNil) tx.Rollback() - errReplaceSelectSQL = `replace replace_test (c1, c2) values (1,2),(1);` + errReplaceSQL = `replace replace_test (c1, c2) values (1,2),(1);` tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errReplaceSelectSQL) + _, err = tx.Exec(errReplaceSQL) c.Assert(err, NotNil) tx.Rollback() - errReplaceSelectSQL = `replace replace_test (xxx) values (3);` + errReplaceSQL = `replace replace_test (xxx) values (3);` tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errReplaceSelectSQL) + _, err = tx.Exec(errReplaceSQL) c.Assert(err, NotNil) tx.Rollback() - errReplaceSelectSQL = `replace replace_test_xxx (c1) values ();` + errReplaceSQL = `replace replace_test_xxx (c1) values ();` tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errReplaceSelectSQL) + _, err = tx.Exec(errReplaceSQL) c.Assert(err, NotNil) tx.Rollback() replaceSetSQL := `replace replace_test set c1 = 3;` mustExec(c, s.testDB, replaceSetSQL) - stmtList, err := tidb.Compile(replaceSetSQL) c.Assert(err, IsNil) c.Assert(stmtList, HasLen, 1) @@ -61,28 +60,64 @@ func (s *testStmtSuite) TestReplace(c *C) { c.Assert(testStmt.IsDDL(), IsFalse) c.Assert(len(testStmt.OriginText()), Greater, 0) - errReplaceSelectSQL = `replace replace_test set c1 = 4, c1 = 5;` + errReplaceSetSQL := `replace replace_test set c1 = 4, c1 = 5;` tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errReplaceSelectSQL) + _, err = tx.Exec(errReplaceSetSQL) c.Assert(err, NotNil) tx.Rollback() - errReplaceSelectSQL = `replace replace_test set xxx = 6;` + errReplaceSetSQL = `replace replace_test set xxx = 6;` tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errReplaceSelectSQL) + _, err = tx.Exec(errReplaceSetSQL) c.Assert(err, NotNil) tx.Rollback() - replaceSelectSQL := `create table replace_test_1 (id int, c1 int); replace replace_test_1 select id, c1 from replace_test;` + replaceSelectSQL := `create table replace_test_1 (id int, c1 int); + replace replace_test_1 select id, c1 from replace_test;` mustExec(c, s.testDB, replaceSelectSQL) replaceSelectSQL = `create table replace_test_2 (id int, c1 int); replace replace_test_1 select id, c1 from replace_test union select id * 10, c1 * 10 from replace_test;` mustExec(c, s.testDB, replaceSelectSQL) - errReplaceSelectSQL = `replace replace_test_1 select c1 from replace_test;` + errReplaceSelectSQL := `replace replace_test_1 select c1 from replace_test;` tx = mustBegin(c, s.testDB) _, err = tx.Exec(errReplaceSelectSQL) c.Assert(err, NotNil) tx.Rollback() + + replaceUniqueIndexSQL := `create table replace_test_3 (c1 int, c2 int, UNIQUE INDEX (c2)); + replace into replace_test_3 set c2=1;` + mustExec(c, s.testDB, replaceUniqueIndexSQL) + replaceUniqueIndexSQL = `replace into replace_test_3 set c2=1;` + ret := mustExec(c, s.testDB, replaceUniqueIndexSQL) + rows, err := ret.RowsAffected() + c.Assert(err, IsNil) + c.Assert(rows, Equals, int64(2)) + + replaceUniqueIndexSQL = `replace into replace_test_3 set c2=NULL;` + mustExec(c, s.testDB, replaceUniqueIndexSQL) + replaceUniqueIndexSQL = `replace into replace_test_3 set c2=NULL;` + ret = mustExec(c, s.testDB, replaceUniqueIndexSQL) + rows, err = ret.RowsAffected() + c.Assert(err, IsNil) + c.Assert(rows, Equals, int64(1)) + + replaceUniqueIndexSQL = `create table replace_test_4 (c1 int, c2 int, c3 int, UNIQUE INDEX (c1, c2)); + replace into replace_test_4 set c2=NULL;` + mustExec(c, s.testDB, replaceUniqueIndexSQL) + replaceUniqueIndexSQL = `replace into replace_test_4 set c2=NULL;` + ret = mustExec(c, s.testDB, replaceUniqueIndexSQL) + rows, err = ret.RowsAffected() + c.Assert(err, IsNil) + c.Assert(rows, Equals, int64(1)) + + replacePrimaryKeySQL := `create table replace_test_5 (c1 int, c2 int, c3 int, PRIMARY KEY (c1, c2)); + replace into replace_test_5 set c1=1, c2=2;` + mustExec(c, s.testDB, replacePrimaryKeySQL) + replacePrimaryKeySQL = `replace into replace_test_5 set c1=1, c2=2;` + ret = mustExec(c, s.testDB, replacePrimaryKeySQL) + rows, err = ret.RowsAffected() + c.Assert(err, IsNil) + c.Assert(rows, Equals, int64(2)) } From 36a60807cbb401c59521c548d0c04510ea349d9f Mon Sep 17 00:00:00 2001 From: xia Date: Fri, 23 Oct 2015 20:31:14 +0800 Subject: [PATCH 09/15] stmt: rename addRecord to getRow and rename SetList to Setlist --- parser/parser.y | 2 +- stmt/stmts/insert.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/parser/parser.y b/parser/parser.y index 50aa5efd782a9..4eef2ac966ef0 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -1839,7 +1839,7 @@ ReplaceIntoStmt: x := &stmts.ReplaceIntoStmt{ColNames: $5.(*stmts.InsertIntoStmt).ColNames, Lists: $5.(*stmts.InsertIntoStmt).Lists, Sel: $5.(*stmts.InsertIntoStmt).Sel, - SetList: $5.(*stmts.InsertIntoStmt).Setlist} + Setlist: $5.(*stmts.InsertIntoStmt).Setlist} x.Priority = $2.(int) x.TableIdent = $4.(table.Ident) $$ = x diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index 0a10bad63d0b4..b49b292739411 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -200,7 +200,7 @@ func (s *InsertIntoStmt) getDefaultValues(ctx context.Context, cols []*column.Co return m, nil } -func (s *InsertIntoStmt) getSetList() error { +func (s *InsertIntoStmt) getSetlist() error { if len(s.Setlist) > 0 { if len(s.Lists) > 0 { return errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) @@ -233,7 +233,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) } // Process `insert ... set x=y...` - if err = s.getSetList(); err != nil { + if err = s.getSetlist(); err != nil { return nil, errors.Trace(err) } From bacb48996328667f863d20d38e98c33e466f59b4 Mon Sep 17 00:00:00 2001 From: xia Date: Fri, 23 Oct 2015 20:32:20 +0800 Subject: [PATCH 10/15] stmt: change remove method to function --- stmt/stmts/delete.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stmt/stmts/delete.go b/stmt/stmts/delete.go index d8a9a91b870d7..68c0d14938fde 100644 --- a/stmt/stmts/delete.go +++ b/stmt/stmts/delete.go @@ -99,7 +99,7 @@ func (s *DeleteStmt) plan(ctx context.Context) (plan.Plan, error) { return r, nil } -func (s *DeleteStmt) removeRow(ctx context.Context, t table.Table, h int64, data []interface{}) error { +func removeRow(ctx context.Context, t table.Table, h int64, data []interface{}) error { // remove row's all indexies if err := t.RemoveRowAllIndex(ctx, h, data); err != nil { return err @@ -186,7 +186,7 @@ func (s *DeleteStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { if err != nil { return nil, errors.Trace(err) } - err = s.removeRow(ctx, t, handle, data) + err = removeRow(ctx, t, handle, data) if err != nil { return nil, errors.Trace(err) } From b4562dc8f56f0a260faccf1b75f5157e1eace040 Mon Sep 17 00:00:00 2001 From: xia Date: Sat, 24 Oct 2015 16:19:10 +0800 Subject: [PATCH 11/15] stmt: update affected rows and test --- stmt/stmts/replace.go | 104 +++++++++---------------------------- stmt/stmts/replace_test.go | 7 ++- 2 files changed, 31 insertions(+), 80 deletions(-) diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index 486f79b8c06d8..e5890c852387e 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -14,17 +14,15 @@ package stmts import ( - "strings" - "github.com/juju/errors" - "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/rset" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/format" "github.com/pingcap/tidb/util/types" ) @@ -99,97 +97,45 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error if err != nil { return nil, errors.Trace(err) } - - if err = removeExistRow(ctx, t, row); err != nil { + h, err := t.AddRecord(ctx, row) + if err == nil { + continue + } + if err != nil && !errors2.ErrorEqual(err, kv.ErrKeyExists) { return nil, errors.Trace(err) } - if _, err = t.AddRecord(ctx, row); err != nil { + + // While the insertion fails because a duplicate-key error occurs for a primary key or unique index, + // a storage engine may perform the REPLACE as an update rather than a delete plus insert. + // See: http://dev.mysql.com/doc/refman/5.7/en/replace.html. + if err = replaceRow(ctx, t, h, row); err != nil { return nil, errors.Trace(err) } + variable.GetSessionVars(ctx).AddAffectedRows(1) } return nil, nil } -func removeExistRow(ctx context.Context, t table.Table, replaceRow []interface{}) error { - indices := make([]*column.IndexedCol, 0, len(t.Indices())) - for _, idx := range t.Indices() { - // TODO: handles the idx is composite primary key the affected rows. - if idx.Unique || idx.Primary && len(idx.Columns) >= 1 { - indices = append(indices, idx) - } - } - if len(indices) == 0 { - return nil - } - - txn, err := ctx.GetTxn(false) +func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []interface{}) error { + row, err := t.Row(ctx, handle) if err != nil { return errors.Trace(err) } - it, err := txn.Seek([]byte(t.FirstKey())) - if err != nil { - return errors.Trace(err) - } - defer it.Close() - - prefix := t.KeyPrefix() - for it.Valid() && strings.HasPrefix(it.Key(), prefix) { - handle, err0 := util.DecodeHandleFromRowKey(it.Key()) - if err0 != nil { - return errors.Trace(err0) - } - row, err0 := t.Row(ctx, handle) - if err0 != nil { - return errors.Trace(err0) - } - for _, idx := range indices { - v, err1 := compareIndex(t.Cols(), idx, row, replaceRow) - if err1 != nil { - return errors.Trace(err1) - } - if v != 0 { - continue - } - if err = removeRow(ctx, t, handle, row); err != nil { - return errors.Trace(err) - } - } - - rk := t.RecordKey(handle, nil) - if it, err0 = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)); err0 != nil { - return errors.Trace(err0) - } - } - - return nil -} - -// compareIndex returns an integer comparing the idx old with new. -// old > new -> 1 -// old = new -> 0 -// old < new -> -1 -func compareIndex(cols []*column.Col, idx *column.IndexedCol, row, replaceRow []interface{}) (v int, err error) { - nulls := 0 - for _, idxCol := range idx.Columns { - col := column.FindCol(cols, idxCol.Name.L) - if col == nil { - return 0, errors.Errorf("No such column: %v", idx) - } - v, err = types.Compare(row[col.Offset], replaceRow[col.Offset]) + for i, val := range row { + v, err := types.Compare(val, replaceRow[i]) if err != nil { - return 0, errors.Trace(err) + return errors.Trace(err) } if v != 0 { - break - } - if row[col.Offset] == nil { - nulls++ + touched := make([]bool, len(row)) + for i := 0; i < len(touched); i++ { + touched[i] = true + } + variable.GetSessionVars(ctx).AddAffectedRows(1) + return t.UpdateRecord(ctx, handle, row, replaceRow, touched) } } - if nulls == len(idx.Columns) { - v = -1 - } - return v, nil + return nil } diff --git a/stmt/stmts/replace_test.go b/stmt/stmts/replace_test.go index fed415c63685e..fb223491611d3 100644 --- a/stmt/stmts/replace_test.go +++ b/stmt/stmts/replace_test.go @@ -93,6 +93,11 @@ func (s *testStmtSuite) TestReplace(c *C) { ret := mustExec(c, s.testDB, replaceUniqueIndexSQL) rows, err := ret.RowsAffected() c.Assert(err, IsNil) + c.Assert(rows, Equals, int64(1)) + replaceUniqueIndexSQL = `replace into replace_test_3 set c1=1, c2=1;` + ret = mustExec(c, s.testDB, replaceUniqueIndexSQL) + rows, err = ret.RowsAffected() + c.Assert(err, IsNil) c.Assert(rows, Equals, int64(2)) replaceUniqueIndexSQL = `replace into replace_test_3 set c2=NULL;` @@ -119,5 +124,5 @@ func (s *testStmtSuite) TestReplace(c *C) { ret = mustExec(c, s.testDB, replacePrimaryKeySQL) rows, err = ret.RowsAffected() c.Assert(err, IsNil) - c.Assert(rows, Equals, int64(2)) + c.Assert(rows, Equals, int64(1)) } From f5f4b223a58709790d4fd6fbd02028b185944049 Mon Sep 17 00:00:00 2001 From: xia Date: Mon, 26 Oct 2015 10:48:22 +0800 Subject: [PATCH 12/15] *: Update format. --- parser/parser.y | 6 +----- stmt/stmts/insert.go | 10 +++++----- stmt/stmts/replace.go | 10 ++++++++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/parser/parser.y b/parser/parser.y index 4eef2ac966ef0..cf2df231bb224 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -1827,15 +1827,14 @@ OnDuplicateKeyUpdate: } /************************************************************************************ - * * Replace Statments + * See: https://dev.mysql.com/doc/refman/5.7/en/replace.html * * TODO: support PARTITION **********************************************************************************/ ReplaceIntoStmt: "REPLACE" ReplacePriority IntoOpt TableIdent InsertRest { - x := &stmts.ReplaceIntoStmt{ColNames: $5.(*stmts.InsertIntoStmt).ColNames, Lists: $5.(*stmts.InsertIntoStmt).Lists, Sel: $5.(*stmts.InsertIntoStmt).Sel, @@ -1843,9 +1842,6 @@ ReplaceIntoStmt: x.Priority = $2.(int) x.TableIdent = $4.(table.Ident) $$ = x - if yylex.(*lexer).root { - break - } } ReplacePriority: diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index b49b292739411..be351f17dfc5d 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -81,7 +81,7 @@ func (s *InsertIntoStmt) SetText(text string) { } // execExecSelect implements `insert table select ... from ...`. -func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (_ rset.Recordset, err error) { +func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) { r, err := s.Sel.Plan(ctx) if err != nil { return nil, errors.Trace(err) @@ -304,9 +304,9 @@ func (s *InsertIntoStmt) getRow(ctx context.Context, t table.Table, cols []*colu // For "insert into t values (default)" Default Eval. m[expression.ExprEvalDefaultName] = cols[i].Name.O - val, evalErr := expr.Eval(ctx, m) - if evalErr != nil { - return nil, errors.Trace(evalErr) + val, err := expr.Eval(ctx, m) + if err != nil { + return nil, errors.Trace(err) } r[cols[i].Offset] = val marked[cols[i].Offset] = struct{}{} @@ -326,7 +326,7 @@ func (s *InsertIntoStmt) getRow(ctx context.Context, t table.Table, cols []*colu return nil, errors.Trace(err) } - return r, err + return r, nil } func execOnDuplicateUpdate(ctx context.Context, t table.Table, row []interface{}, h int64, cols map[int]expression.Assignment) error { diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index e5890c852387e..5c5fbbd85f776 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -64,8 +64,14 @@ func (s *ReplaceIntoStmt) SetText(text string) { // Exec implements the stmt.Statement Exec interface. func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { - stmt := &InsertIntoStmt{ColNames: s.ColNames, Lists: s.Lists, Priority: s.Priority, - Sel: s.Sel, Setlist: s.Setlist, TableIdent: s.TableIdent, Text: s.Text} + stmt := &InsertIntoStmt{ColNames: s.ColNames, + Lists: s.Lists, + Priority: s.Priority, + Sel: s.Sel, + Setlist: s.Setlist, + TableIdent: s.TableIdent, + Text: s.Text} + t, err := getTable(ctx, stmt.TableIdent) if err != nil { return nil, errors.Trace(err) From 5a42db1ea4b3aa5a8b74fc4cf4150265b69fcd9b Mon Sep 17 00:00:00 2001 From: xia Date: Mon, 26 Oct 2015 18:14:51 +0800 Subject: [PATCH 13/15] *: rename getSetlist to fillValueList and add the InsertRest structure to make the code more clear --- parser/parser.y | 21 +++++++++------------ stmt/stmts/insert.go | 33 +++++++++++++++++++-------------- stmt/stmts/replace.go | 37 +++++++++++-------------------------- 3 files changed, 39 insertions(+), 52 deletions(-) diff --git a/parser/parser.y b/parser/parser.y index cf2df231bb224..84dfc0476f9c9 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -1719,7 +1719,7 @@ NotKeywordToken: InsertIntoStmt: "INSERT" Priority IgnoreOptional IntoOpt TableIdent InsertRest OnDuplicateKeyUpdate { - x := $6.(*stmts.InsertIntoStmt) + x := &stmts.InsertIntoStmt{InsertRest: $6.(stmts.InsertRest)} x.Priority = $2.(int) x.TableIdent = $5.(table.Ident) if $7 != nil { @@ -1741,33 +1741,33 @@ IntoOpt: InsertRest: '(' ColumnNameListOpt ')' ValueSym ExpressionListList { - $$ = &stmts.InsertIntoStmt{ + $$ = stmts.InsertRest{ ColNames: $2.([]string), Lists: $5.([][]expression.Expression)} } | '(' ColumnNameListOpt ')' SelectStmt { - $$ = &stmts.InsertIntoStmt{ColNames: $2.([]string), Sel: $4.(*stmts.SelectStmt)} + $$ = stmts.InsertRest{ColNames: $2.([]string), Sel: $4.(*stmts.SelectStmt)} } | '(' ColumnNameListOpt ')' UnionStmt { - $$ = &stmts.InsertIntoStmt{ColNames: $2.([]string), Sel: $4.(*stmts.UnionStmt)} + $$ = stmts.InsertRest{ColNames: $2.([]string), Sel: $4.(*stmts.UnionStmt)} } | ValueSym ExpressionListList %prec insertValues { - $$ = &stmts.InsertIntoStmt{Lists: $2.([][]expression.Expression)} + $$ = stmts.InsertRest{Lists: $2.([][]expression.Expression)} } | SelectStmt { - $$ = &stmts.InsertIntoStmt{Sel: $1.(*stmts.SelectStmt)} + $$ = stmts.InsertRest{Sel: $1.(*stmts.SelectStmt)} } | UnionStmt { - $$ = &stmts.InsertIntoStmt{Sel: $1.(*stmts.UnionStmt)} + $$ = stmts.InsertRest{Sel: $1.(*stmts.UnionStmt)} } | "SET" ColumnSetValueList { - $$ = &stmts.InsertIntoStmt{Setlist: $2.([]*expression.Assignment)} + $$ = stmts.InsertRest{Setlist: $2.([]*expression.Assignment)} } ValueSym: @@ -1835,10 +1835,7 @@ OnDuplicateKeyUpdate: ReplaceIntoStmt: "REPLACE" ReplacePriority IntoOpt TableIdent InsertRest { - x := &stmts.ReplaceIntoStmt{ColNames: $5.(*stmts.InsertIntoStmt).ColNames, - Lists: $5.(*stmts.InsertIntoStmt).Lists, - Sel: $5.(*stmts.InsertIntoStmt).Sel, - Setlist: $5.(*stmts.InsertIntoStmt).Setlist} + x := &stmts.ReplaceIntoStmt{InsertRest: $5.(stmts.InsertRest)} x.Priority = $2.(int) x.TableIdent = $4.(table.Ident) $$ = x diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index be351f17dfc5d..dfaeb90a74a40 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -46,15 +46,20 @@ const ( DelayedPriority ) +// InsertRest is the rest part of insert/replace into statement. +type InsertRest struct { + ColNames []string + Lists [][]expression.Expression + Sel plan.Planner + TableIdent table.Ident + Setlist []*expression.Assignment + Priority int +} + // InsertIntoStmt is a statement to insert new rows into an existing table. // See: https://dev.mysql.com/doc/refman/5.7/en/insert.html type InsertIntoStmt struct { - ColNames []string - Lists [][]expression.Expression - Sel plan.Planner - TableIdent table.Ident - Setlist []*expression.Assignment - Priority int + InsertRest OnDuplicate []expression.Assignment Text string @@ -81,7 +86,7 @@ func (s *InsertIntoStmt) SetText(text string) { } // execExecSelect implements `insert table select ... from ...`. -func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) { +func (s *InsertRest) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) { r, err := s.Sel.Plan(ctx) if err != nil { return nil, errors.Trace(err) @@ -144,7 +149,7 @@ func (s *InsertIntoStmt) execSelect(t table.Table, cols []*column.Col, ctx conte // 2 insert ... set x=y... --> set type column // 3 insert ... (select ..) --> name type column // See: https://dev.mysql.com/doc/refman/5.7/en/insert.html -func (s *InsertIntoStmt) getColumns(tableCols []*column.Col) ([]*column.Col, error) { +func (s *InsertRest) getColumns(tableCols []*column.Col) ([]*column.Col, error) { var cols []*column.Col var err error @@ -185,7 +190,7 @@ func (s *InsertIntoStmt) getColumns(tableCols []*column.Col) ([]*column.Col, err return cols, nil } -func (s *InsertIntoStmt) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) { +func (s *InsertRest) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) { m := map[interface{}]interface{}{} for _, v := range cols { if value, ok, err := getDefaultValue(ctx, v); ok { @@ -200,7 +205,7 @@ func (s *InsertIntoStmt) getDefaultValues(ctx context.Context, cols []*column.Co return m, nil } -func (s *InsertIntoStmt) getSetlist() error { +func (s *InsertRest) fillValueList() error { if len(s.Setlist) > 0 { if len(s.Lists) > 0 { return errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) @@ -233,7 +238,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) } // Process `insert ... set x=y...` - if err = s.getSetlist(); err != nil { + if err = s.fillValueList(); err != nil { return nil, errors.Trace(err) } @@ -278,7 +283,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) return nil, nil } -func (s *InsertIntoStmt) checkValueCount(insertValueCount, valueCount, num int, cols []*column.Col) error { +func (s *InsertRest) checkValueCount(insertValueCount, valueCount, num int, cols []*column.Col) error { if insertValueCount != valueCount { // "insert into t values (), ()" is valid. // "insert into t values (), (1)" is not valid. @@ -297,7 +302,7 @@ func (s *InsertIntoStmt) checkValueCount(insertValueCount, valueCount, num int, return nil } -func (s *InsertIntoStmt) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) { +func (s *InsertRest) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) { r := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(list)) for i, expr := range list { @@ -367,7 +372,7 @@ func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Tab return m, nil } -func (s *InsertIntoStmt) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) error { +func (s *InsertRest) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) error { var err error var defaultValueCols []*column.Col for i, c := range t.Cols() { diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index 5c5fbbd85f776..ece0a305a903f 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -16,9 +16,7 @@ package stmts import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" @@ -32,12 +30,7 @@ import ( // the old row is deleted before the new row is inserted. // See: https://dev.mysql.com/doc/refman/5.7/en/replace.html type ReplaceIntoStmt struct { - ColNames []string - Lists [][]expression.Expression - Sel plan.Planner - TableIdent table.Ident - Setlist []*expression.Assignment - Priority int + InsertRest Text string } @@ -64,42 +57,34 @@ func (s *ReplaceIntoStmt) SetText(text string) { // Exec implements the stmt.Statement Exec interface. func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { - stmt := &InsertIntoStmt{ColNames: s.ColNames, - Lists: s.Lists, - Priority: s.Priority, - Sel: s.Sel, - Setlist: s.Setlist, - TableIdent: s.TableIdent, - Text: s.Text} - - t, err := getTable(ctx, stmt.TableIdent) + t, err := getTable(ctx, s.TableIdent) if err != nil { return nil, errors.Trace(err) } - cols, err := stmt.getColumns(t.Cols()) + cols, err := s.getColumns(t.Cols()) if err != nil { return nil, errors.Trace(err) } // Process `replace ... (select ...)` - if stmt.Sel != nil { - return stmt.execSelect(t, cols, ctx) + if s.Sel != nil { + return s.execSelect(t, cols, ctx) } // Process `replace ... set x=y ...` - if err = stmt.getSetlist(); err != nil { + if err = s.fillValueList(); err != nil { return nil, errors.Trace(err) } - m, err := stmt.getDefaultValues(ctx, t.Cols()) + m, err := s.getDefaultValues(ctx, t.Cols()) if err != nil { return nil, errors.Trace(err) } - replaceValueCount := len(stmt.Lists[0]) + replaceValueCount := len(s.Lists[0]) - for i, list := range stmt.Lists { - if err = stmt.checkValueCount(replaceValueCount, len(list), i, cols); err != nil { + for i, list := range s.Lists { + if err = s.checkValueCount(replaceValueCount, len(list), i, cols); err != nil { return nil, errors.Trace(err) } - row, err := stmt.getRow(ctx, t, cols, list, m) + row, err := s.getRow(ctx, t, cols, list, m) if err != nil { return nil, errors.Trace(err) } From 376bdee59a2ad2b9f921ef032d9e14c424135799 Mon Sep 17 00:00:00 2001 From: xia Date: Tue, 27 Oct 2015 10:30:33 +0800 Subject: [PATCH 14/15] stmt: update only modify columns. --- stmt/stmts/replace.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index ece0a305a903f..9d885a4a5ec2d 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -113,18 +113,25 @@ func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []i if err != nil { return errors.Trace(err) } + + isReplace := false + touched := make([]bool, len(row)) for i, val := range row { v, err := types.Compare(val, replaceRow[i]) if err != nil { return errors.Trace(err) } if v != 0 { - touched := make([]bool, len(row)) - for i := 0; i < len(touched); i++ { - touched[i] = true + touched[i] = true + if !isReplace { + isReplace = true } - variable.GetSessionVars(ctx).AddAffectedRows(1) - return t.UpdateRecord(ctx, handle, row, replaceRow, touched) + } + } + if isReplace { + variable.GetSessionVars(ctx).AddAffectedRows(1) + if err = t.UpdateRecord(ctx, handle, row, replaceRow, touched); err != nil { + return errors.Trace(err) } } From 4f921732671be69c84a88334e2e92ec1a1ed74f4 Mon Sep 17 00:00:00 2001 From: xia Date: Tue, 27 Oct 2015 13:21:35 +0800 Subject: [PATCH 15/15] *: rename InsertRest to InsertValues and add TODOs --- parser/parser.y | 26 +++++++++++++------------- stmt/stmts/insert.go | 21 +++++++++++---------- stmt/stmts/replace.go | 13 ++++++------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/parser/parser.y b/parser/parser.y index 84dfc0476f9c9..391455b166f85 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -433,7 +433,7 @@ import ( IndexName "index name" IndexType "index type" InsertIntoStmt "INSERT INTO statement" - InsertRest "Rest part of INSERT/REPLACE INTO statement" + InsertValues "Rest part of INSERT/REPLACE INTO statement" IntoOpt "INTO or EmptyString" JoinTable "join table" JoinType "join type" @@ -1717,9 +1717,9 @@ NotKeywordToken: * TODO: support PARTITION **********************************************************************************/ InsertIntoStmt: - "INSERT" Priority IgnoreOptional IntoOpt TableIdent InsertRest OnDuplicateKeyUpdate + "INSERT" Priority IgnoreOptional IntoOpt TableIdent InsertValues OnDuplicateKeyUpdate { - x := &stmts.InsertIntoStmt{InsertRest: $6.(stmts.InsertRest)} + x := &stmts.InsertIntoStmt{InsertValues: $6.(stmts.InsertValues)} x.Priority = $2.(int) x.TableIdent = $5.(table.Ident) if $7 != nil { @@ -1738,36 +1738,36 @@ IntoOpt: { } -InsertRest: +InsertValues: '(' ColumnNameListOpt ')' ValueSym ExpressionListList { - $$ = stmts.InsertRest{ + $$ = stmts.InsertValues{ ColNames: $2.([]string), Lists: $5.([][]expression.Expression)} } | '(' ColumnNameListOpt ')' SelectStmt { - $$ = stmts.InsertRest{ColNames: $2.([]string), Sel: $4.(*stmts.SelectStmt)} + $$ = stmts.InsertValues{ColNames: $2.([]string), Sel: $4.(*stmts.SelectStmt)} } | '(' ColumnNameListOpt ')' UnionStmt { - $$ = stmts.InsertRest{ColNames: $2.([]string), Sel: $4.(*stmts.UnionStmt)} + $$ = stmts.InsertValues{ColNames: $2.([]string), Sel: $4.(*stmts.UnionStmt)} } | ValueSym ExpressionListList %prec insertValues { - $$ = stmts.InsertRest{Lists: $2.([][]expression.Expression)} + $$ = stmts.InsertValues{Lists: $2.([][]expression.Expression)} } | SelectStmt { - $$ = stmts.InsertRest{Sel: $1.(*stmts.SelectStmt)} + $$ = stmts.InsertValues{Sel: $1.(*stmts.SelectStmt)} } | UnionStmt { - $$ = stmts.InsertRest{Sel: $1.(*stmts.UnionStmt)} + $$ = stmts.InsertValues{Sel: $1.(*stmts.UnionStmt)} } | "SET" ColumnSetValueList { - $$ = stmts.InsertRest{Setlist: $2.([]*expression.Assignment)} + $$ = stmts.InsertValues{Setlist: $2.([]*expression.Assignment)} } ValueSym: @@ -1833,9 +1833,9 @@ OnDuplicateKeyUpdate: * TODO: support PARTITION **********************************************************************************/ ReplaceIntoStmt: - "REPLACE" ReplacePriority IntoOpt TableIdent InsertRest + "REPLACE" ReplacePriority IntoOpt TableIdent InsertValues { - x := &stmts.ReplaceIntoStmt{InsertRest: $5.(stmts.InsertRest)} + x := &stmts.ReplaceIntoStmt{InsertValues: $5.(stmts.InsertValues)} x.Priority = $2.(int) x.TableIdent = $4.(table.Ident) $$ = x diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index dfaeb90a74a40..d9096fa372afb 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -46,8 +46,8 @@ const ( DelayedPriority ) -// InsertRest is the rest part of insert/replace into statement. -type InsertRest struct { +// InsertValues is the rest part of insert/replace into statement. +type InsertValues struct { ColNames []string Lists [][]expression.Expression Sel plan.Planner @@ -59,7 +59,7 @@ type InsertRest struct { // InsertIntoStmt is a statement to insert new rows into an existing table. // See: https://dev.mysql.com/doc/refman/5.7/en/insert.html type InsertIntoStmt struct { - InsertRest + InsertValues OnDuplicate []expression.Assignment Text string @@ -86,7 +86,7 @@ func (s *InsertIntoStmt) SetText(text string) { } // execExecSelect implements `insert table select ... from ...`. -func (s *InsertRest) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) { +func (s *InsertValues) execSelect(t table.Table, cols []*column.Col, ctx context.Context) (rset.Recordset, error) { r, err := s.Sel.Plan(ctx) if err != nil { return nil, errors.Trace(err) @@ -149,7 +149,7 @@ func (s *InsertRest) execSelect(t table.Table, cols []*column.Col, ctx context.C // 2 insert ... set x=y... --> set type column // 3 insert ... (select ..) --> name type column // See: https://dev.mysql.com/doc/refman/5.7/en/insert.html -func (s *InsertRest) getColumns(tableCols []*column.Col) ([]*column.Col, error) { +func (s *InsertValues) getColumns(tableCols []*column.Col) ([]*column.Col, error) { var cols []*column.Col var err error @@ -190,7 +190,7 @@ func (s *InsertRest) getColumns(tableCols []*column.Col) ([]*column.Col, error) return cols, nil } -func (s *InsertRest) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) { +func (s *InsertValues) getDefaultValues(ctx context.Context, cols []*column.Col) (map[interface{}]interface{}, error) { m := map[interface{}]interface{}{} for _, v := range cols { if value, ok, err := getDefaultValue(ctx, v); ok { @@ -205,7 +205,7 @@ func (s *InsertRest) getDefaultValues(ctx context.Context, cols []*column.Col) ( return m, nil } -func (s *InsertRest) fillValueList() error { +func (s *InsertValues) fillValueList() error { if len(s.Setlist) > 0 { if len(s.Lists) > 0 { return errors.Errorf("INSERT INTO %s: set type should not use values", s.TableIdent) @@ -233,6 +233,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) } // Process `insert ... (select ..) ` + // TODO: handles the duplicate-key in a primary key or a unique index. if s.Sel != nil { return s.execSelect(t, cols, ctx) } @@ -283,7 +284,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) return nil, nil } -func (s *InsertRest) checkValueCount(insertValueCount, valueCount, num int, cols []*column.Col) error { +func (s *InsertValues) checkValueCount(insertValueCount, valueCount, num int, cols []*column.Col) error { if insertValueCount != valueCount { // "insert into t values (), ()" is valid. // "insert into t values (), (1)" is not valid. @@ -302,7 +303,7 @@ func (s *InsertRest) checkValueCount(insertValueCount, valueCount, num int, cols return nil } -func (s *InsertRest) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) { +func (s *InsertValues) getRow(ctx context.Context, t table.Table, cols []*column.Col, list []expression.Expression, m map[interface{}]interface{}) ([]interface{}, error) { r := make([]interface{}, len(t.Cols())) marked := make(map[int]struct{}, len(list)) for i, expr := range list { @@ -372,7 +373,7 @@ func getOnDuplicateUpdateColumns(assignList []expression.Assignment, t table.Tab return m, nil } -func (s *InsertRest) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) error { +func (s *InsertValues) initDefaultValues(ctx context.Context, t table.Table, row []interface{}, marked map[int]struct{}) error { var err error var defaultValueCols []*column.Col for i, c := range t.Cols() { diff --git a/stmt/stmts/replace.go b/stmt/stmts/replace.go index 9d885a4a5ec2d..86c5ea41ee722 100644 --- a/stmt/stmts/replace.go +++ b/stmt/stmts/replace.go @@ -30,7 +30,7 @@ import ( // the old row is deleted before the new row is inserted. // See: https://dev.mysql.com/doc/refman/5.7/en/replace.html type ReplaceIntoStmt struct { - InsertRest + InsertValues Text string } @@ -67,6 +67,7 @@ func (s *ReplaceIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error } // Process `replace ... (select ...)` + // TODO: handles the duplicate-key in a primary key or a unique index. if s.Sel != nil { return s.execSelect(t, cols, ctx) } @@ -117,15 +118,13 @@ func replaceRow(ctx context.Context, t table.Table, handle int64, replaceRow []i isReplace := false touched := make([]bool, len(row)) for i, val := range row { - v, err := types.Compare(val, replaceRow[i]) - if err != nil { - return errors.Trace(err) + v, err1 := types.Compare(val, replaceRow[i]) + if err1 != nil { + return errors.Trace(err1) } if v != 0 { touched[i] = true - if !isReplace { - isReplace = true - } + isReplace = true } } if isReplace {