Skip to content

Commit

Permalink
*: Make 'IF NOT EXISTS' great again in 'CREATE TABLE IF NOT EXISTS LI…
Browse files Browse the repository at this point in the history
…KE' syntax (#6896) (#6928)
  • Loading branch information
zz-jason authored Jul 3, 2018
1 parent b056caa commit 9cefb33
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ type DDL interface {
CreateSchema(ctx sessionctx.Context, name model.CIStr, charsetInfo *ast.CharsetOpt) error
DropSchema(ctx sessionctx.Context, schema model.CIStr) error
CreateTable(ctx sessionctx.Context, stmt *ast.CreateTableStmt) error
CreateTableWithLike(ctx sessionctx.Context, ident, referIdent ast.Ident) error
CreateTableWithLike(ctx sessionctx.Context, ident, referIdent ast.Ident, ifNotExists bool) error
DropTable(ctx sessionctx.Context, tableIdent ast.Ident) (err error)
CreateIndex(ctx sessionctx.Context, tableIdent ast.Ident, unique bool, indexName model.CIStr,
columnNames []*ast.IndexColName, indexOption *ast.IndexOption) error
Expand Down
9 changes: 7 additions & 2 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ func buildTableInfo(ctx sessionctx.Context, d *ddl, tableName model.CIStr, cols
return
}

func (d *ddl) CreateTableWithLike(ctx sessionctx.Context, ident, referIdent ast.Ident) error {
func (d *ddl) CreateTableWithLike(ctx sessionctx.Context, ident, referIdent ast.Ident, ifNotExists bool) error {
is := d.GetInformationSchema()
_, ok := is.SchemaByName(referIdent.Schema)
if !ok {
Expand All @@ -726,6 +726,10 @@ func (d *ddl) CreateTableWithLike(ctx sessionctx.Context, ident, referIdent ast.
return infoschema.ErrDatabaseNotExists.GenByArgs(ident.Schema)
}
if is.TableExists(ident.Schema, ident.Name) {
if ifNotExists {
ctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrTableExists.GenByArgs(ident))
return nil
}
return infoschema.ErrTableExists.GenByArgs(ident)
}

Expand Down Expand Up @@ -754,7 +758,7 @@ func (d *ddl) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) (err e
ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name}
if s.ReferTable != nil {
referIdent := ast.Ident{Schema: s.ReferTable.Schema, Name: s.ReferTable.Name}
return d.CreateTableWithLike(ctx, ident, referIdent)
return d.CreateTableWithLike(ctx, ident, referIdent, s.IfNotExists)
}
colDefs := s.Cols
is := d.GetInformationSchema()
Expand All @@ -764,6 +768,7 @@ func (d *ddl) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) (err e
}
if is.TableExists(ident.Schema, ident.Name) {
if s.IfNotExists {
ctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrTableExists.GenByArgs(ident))
return nil
}
return infoschema.ErrTableExists.GenByArgs(ident)
Expand Down
100 changes: 100 additions & 0 deletions ddl/integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package ddl_test

import (
"fmt"

"github.com/juju/errors"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
)

var _ = Suite(&testIntegrationSuite{})

type testIntegrationSuite struct {
store kv.Storage
dom *domain.Domain
ctx sessionctx.Context
}

func (s *testIntegrationSuite) TearDownTest(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
r := tk.MustQuery("show tables")
for _, tb := range r.Rows() {
tableName := tb[0]
tk.MustExec(fmt.Sprintf("drop table %v", tableName))
}
}

func (s *testIntegrationSuite) SetUpSuite(c *C) {
var err error
testleak.BeforeTest()
s.store, s.dom, err = newStoreWithBootstrap()
c.Assert(err, IsNil)
s.ctx = mock.NewContext()
}

func (s *testIntegrationSuite) TearDownSuite(c *C) {
s.dom.Close()
s.store.Close()
testleak.AfterTest(c)()
}

// for issue #6879
func (s *testIntegrationSuite) TestCreateTableIfNotExists(c *C) {
tk := testkit.NewTestKit(c, s.store)

tk.MustExec("USE test;")

tk.MustExec("create table t1(a bigint)")
tk.MustExec("create table t(a bigint)")

// Test duplicate create-table with `LIKE` clause
tk.MustExec("create table if not exists t like t1;")
warnings := tk.Se.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), GreaterEqual, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(infoschema.ErrTableExists, lastWarn.Err), IsTrue)
c.Assert(lastWarn.Level, Equals, stmtctx.WarnLevelNote)

// Test duplicate create-table without `LIKE` clause
tk.MustExec("create table if not exists t(b bigint, c varchar(60));")
warnings = tk.Se.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), GreaterEqual, 1)
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(infoschema.ErrTableExists, lastWarn.Err), IsTrue)
}

func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) {
store, err := mockstore.NewMockTikvStore()
if err != nil {
return nil, nil, errors.Trace(err)
}
session.SetSchemaLease(0)
session.SetStatsLease(0)
dom, err := session.BootstrapSession(store)
return store, dom, errors.Trace(err)
}
8 changes: 4 additions & 4 deletions executor/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -696,14 +696,14 @@ func (e *ShowExec) fetchShowPlugins() error {

func (e *ShowExec) fetchShowWarnings() error {
warns := e.ctx.GetSessionVars().StmtCtx.GetWarnings()
for _, warn := range warns {
warn = errors.Cause(warn)
for _, w := range warns {
warn := errors.Cause(w.Err)
switch x := warn.(type) {
case *terror.Error:
sqlErr := x.ToSQLError()
e.appendRow([]interface{}{"Warning", int64(sqlErr.Code), sqlErr.Message})
e.appendRow([]interface{}{w.Level, int64(sqlErr.Code), sqlErr.Message})
default:
e.appendRow([]interface{}{"Warning", int64(mysql.ErrUnknown), warn.Error()})
e.appendRow([]interface{}{w.Level, int64(mysql.ErrUnknown), warn.Error()})
}
}
return nil
Expand Down
12 changes: 6 additions & 6 deletions expression/builtin_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (s *testEvaluatorSuite) TestCast(c *C) {

warnings := sc.GetWarnings()
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn), IsTrue)
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn.Err), IsTrue)

f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("-1"), RetType: types.NewFieldType(mysql.TypeString)}, tp1)
res, err = f.Eval(nil)
Expand All @@ -103,7 +103,7 @@ func (s *testEvaluatorSuite) TestCast(c *C) {

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrCastNegIntAsUnsigned, lastWarn), IsTrue)
c.Assert(terror.ErrorEqual(types.ErrCastNegIntAsUnsigned, lastWarn.Err), IsTrue)

f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("-18446744073709551616"), RetType: types.NewFieldType(mysql.TypeString)}, tp1)
res, err = f.Eval(nil)
Expand All @@ -114,7 +114,7 @@ func (s *testEvaluatorSuite) TestCast(c *C) {

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn), IsTrue)
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn.Err), IsTrue)

// cast('18446744073709551616' as signed);
mask := ^mysql.UnsignedFlag
Expand All @@ -126,7 +126,7 @@ func (s *testEvaluatorSuite) TestCast(c *C) {

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn), IsTrue)
c.Assert(terror.ErrorEqual(types.ErrTruncatedWrongVal, lastWarn.Err), IsTrue)

// cast('18446744073709551614' as signed);
f = BuildCastFunction(ctx, &Constant{Value: types.NewDatum("18446744073709551614"), RetType: types.NewFieldType(mysql.TypeString)}, tp1)
Expand All @@ -136,7 +136,7 @@ func (s *testEvaluatorSuite) TestCast(c *C) {

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrCastAsSignedOverflow, lastWarn), IsTrue)
c.Assert(terror.ErrorEqual(types.ErrCastAsSignedOverflow, lastWarn.Err), IsTrue)

// create table t1(s1 time);
// insert into t1 values('11:11:11');
Expand All @@ -158,7 +158,7 @@ func (s *testEvaluatorSuite) TestCast(c *C) {

warnings = sc.GetWarnings()
lastWarn = warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(types.ErrOverflow, lastWarn), IsTrue)
c.Assert(terror.ErrorEqual(types.ErrOverflow, lastWarn.Err), IsTrue)
sc = origSc

// cast(bad_string as decimal)
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_encryption_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ func (s *testEvaluatorSuite) TestPassword(c *C) {
c.Assert(len(warnings), Equals, warnCount+1)

lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errDeprecatedSyntaxNoReplacement, lastWarn), IsTrue)
c.Assert(terror.ErrorEqual(errDeprecatedSyntaxNoReplacement, lastWarn.Err), IsTrue)

warnCount = len(warnings)
} else {
Expand Down
34 changes: 28 additions & 6 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ import (
"github.com/pingcap/tidb/util/memory"
)

const (
// WarnLevelWarning represents level "Warning" for 'SHOW WARNINGS' syntax.
WarnLevelWarning = "Warning"
// WarnLevelNote represents level "Note" for 'SHOW WARNINGS' syntax.
WarnLevelNote = "Note"
)

// SQLWarn relates a sql warning and it's level.
type SQLWarn struct {
Level string
Err error
}

// StatementContext contains variables for a statement.
// It should be reset before executing a statement.
type StatementContext struct {
Expand All @@ -45,7 +58,7 @@ type StatementContext struct {
sync.Mutex
affectedRows uint64
foundRows uint64
warnings []error
warnings []SQLWarn
histogramsNotLoad bool
}

Expand Down Expand Up @@ -89,9 +102,9 @@ func (sc *StatementContext) AddFoundRows(rows uint64) {
}

// GetWarnings gets warnings.
func (sc *StatementContext) GetWarnings() []error {
func (sc *StatementContext) GetWarnings() []SQLWarn {
sc.mu.Lock()
warns := make([]error, len(sc.mu.warnings))
warns := make([]SQLWarn, len(sc.mu.warnings))
copy(warns, sc.mu.warnings)
sc.mu.Unlock()
return warns
Expand All @@ -109,17 +122,26 @@ func (sc *StatementContext) WarningCount() uint16 {
}

// SetWarnings sets warnings.
func (sc *StatementContext) SetWarnings(warns []error) {
func (sc *StatementContext) SetWarnings(warns []SQLWarn) {
sc.mu.Lock()
sc.mu.warnings = warns
sc.mu.Unlock()
}

// AppendWarning appends a warning.
// AppendWarning appends a warning with level 'Warning'.
func (sc *StatementContext) AppendWarning(warn error) {
sc.mu.Lock()
if len(sc.mu.warnings) < math.MaxUint16 {
sc.mu.warnings = append(sc.mu.warnings, warn)
sc.mu.warnings = append(sc.mu.warnings, SQLWarn{WarnLevelWarning, warn})
}
sc.mu.Unlock()
}

// AppendNote appends a warning with level 'Note'.
func (sc *StatementContext) AppendNote(warn error) {
sc.mu.Lock()
if len(sc.mu.warnings) < math.MaxUint16 {
sc.mu.warnings = append(sc.mu.warnings, SQLWarn{WarnLevelNote, warn})
}
sc.mu.Unlock()
}
Expand Down
4 changes: 2 additions & 2 deletions store/mockstore/mocktikv/cop_handler_dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ func (mock *mockCopStreamClient) readBlockFromExecutor() (tipb.Chunk, bool, *cop
return chunk, finish, &ran, mock.exec.Counts(), nil
}

func buildResp(chunks []tipb.Chunk, counts []int64, err error, warnings []error) *coprocessor.Response {
func buildResp(chunks []tipb.Chunk, counts []int64, err error, warnings []stmtctx.SQLWarn) *coprocessor.Response {
resp := &coprocessor.Response{}
selResp := &tipb.SelectResponse{
Error: toPBError(err),
Expand All @@ -542,7 +542,7 @@ func buildResp(chunks []tipb.Chunk, counts []int64, err error, warnings []error)
if len(warnings) > 0 {
selResp.Warnings = make([]*tipb.Error, 0, len(warnings))
for i := range warnings {
selResp.Warnings = append(selResp.Warnings, toPBError(warnings[i]))
selResp.Warnings = append(selResp.Warnings, toPBError(warnings[i].Err))
}
}
if err != nil {
Expand Down

0 comments on commit 9cefb33

Please sign in to comment.