Skip to content

Commit

Permalink
feat:multi delete sql
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiaoxiong committed Oct 30, 2022
1 parent a98e390 commit 1a44802
Show file tree
Hide file tree
Showing 9 changed files with 277 additions and 21 deletions.
10 changes: 5 additions & 5 deletions pkg/datasource/sql/exec/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ import (
)

func init() {
undo.RegistrUndoLogBuilder(types.UpdateExecutor, builder.GetMySQLUpdateUndoLogBuilder)
undo.RegistrUndoLogBuilder(types.MultiExecutor, builder.GetMySQLMultiUndoLogBuilder)
undo.RegisterUndoLogBuilder(types.UpdateExecutor, builder.GetMySQLUpdateUndoLogBuilder)
undo.RegisterUndoLogBuilder(types.MultiExecutor, builder.GetMySQLMultiUndoLogBuilder)
}

// executorSolts
Expand Down Expand Up @@ -131,7 +131,7 @@ func (e *BaseExecutor) Interceptors(interceptors []SQLHook) {
// ExecWithNamedValue
func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.ExecContext, f CallbackWithNamedValue) (types.ExecResult, error) {
for i := range e.is {
e.is[i].Before(ctx, execCtx)
_ = e.is[i].Before(ctx, execCtx)
}

var (
Expand All @@ -151,7 +151,7 @@ func (e *BaseExecutor) ExecWithNamedValue(ctx context.Context, execCtx *types.Ex

defer func() {
for i := range e.is {
e.is[i].After(ctx, execCtx)
_ = e.is[i].After(ctx, execCtx)
}
}()

Expand Down Expand Up @@ -199,7 +199,7 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon

defer func() {
for i := range e.is {
e.is[i].After(ctx, execCtx)
_ = e.is[i].After(ctx, execCtx)
}
}()

Expand Down
1 change: 1 addition & 0 deletions pkg/datasource/sql/types/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ const (
DeleteExecutor
ReplaceIntoExecutor
MultiExecutor
MultiDeleteExecutor
InsertOnDuplicateExecutor
)

Expand Down
13 changes: 4 additions & 9 deletions pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
)

func init() {
undo.RegistrUndoLogBuilder(types.DeleteExecutor, GetMySQLDeleteUndoLogBuilder)
undo.RegisterUndoLogBuilder(types.DeleteExecutor, GetMySQLDeleteUndoLogBuilder)
}

type MySQLDeleteUndoLogBuilder struct {
Expand Down Expand Up @@ -96,16 +96,11 @@ func (u *MySQLDeleteUndoLogBuilder) buildBeforeImageSQL(query string, args []dri
return "", nil, fmt.Errorf("invalid delete stmt")
}

fields := []*ast.SelectField{}
fields = append(fields, &ast.SelectField{
WildCard: &ast.WildCardField{},
})

selStmt := ast.SelectStmt{
SelectStmtOpts: &ast.SelectStmtOpts{},
From: p.DeleteStmt.TableRefs,
Where: p.DeleteStmt.Where,
Fields: &ast.FieldList{Fields: fields},
Fields: &ast.FieldList{Fields: []*ast.SelectField{{WildCard: &ast.WildCardField{}}}},
OrderBy: p.DeleteStmt.Order,
Limit: p.DeleteStmt.Limit,
TableHints: p.DeleteStmt.TableHints,
Expand All @@ -115,9 +110,9 @@ func (u *MySQLDeleteUndoLogBuilder) buildBeforeImageSQL(query string, args []dri
}

b := bytes.NewByteBuffer([]byte{})
selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b))
sql := string(b.Bytes())
log.Infof("build select sql by delete sourceQuery, sql {}", sql)
log.Infof("build select sql by delete sourceQuery, sql {%s}", sql)

return sql, u.buildSelectArgs(&selStmt, args), nil
}
Expand Down
194 changes: 194 additions & 0 deletions pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package builder

import (
"bytes"
"context"
"database/sql/driver"
"strings"

"github.com/arana-db/parser/ast"
"github.com/arana-db/parser/format"
"github.com/seata/seata-go/pkg/datasource/sql/parser"
"github.com/seata/seata-go/pkg/datasource/sql/types"
"github.com/seata/seata-go/pkg/datasource/sql/undo"
"github.com/seata/seata-go/pkg/util/log"
)

func init() {
undo.RegisterUndoLogBuilder(types.MultiDeleteExecutor, GetMySQLMultiDeleteUndoLogBuilder)
}

type multiDelete struct {
sql string
clear bool
}

type MySQLMultiDeleteUndoLogBuilder struct {
BasicUndoLogBuilder
}

func GetMySQLMultiDeleteUndoLogBuilder() undo.UndoLogBuilder {
return &MySQLMultiDeleteUndoLogBuilder{BasicUndoLogBuilder: BasicUndoLogBuilder{}}
}

func (u *MySQLMultiDeleteUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) {
deletes := strings.Split(execCtx.Query, ";")
if len(deletes) == 1 {
return GetMySQLDeleteUndoLogBuilder().BeforeImage(ctx, execCtx)
}

values := make([]driver.Value, 0, len(execCtx.NamedValues)*2)
if execCtx.Values == nil {
for n, param := range execCtx.NamedValues {
values[n] = param.Value
}
}

multiQuery, args, err := u.buildBeforeImageSQL(deletes, values)
if err != nil {
return nil, err
}

var (
stmt driver.Stmt
rows driver.Rows

record *types.RecordImage
records []*types.RecordImage

meDataMap = execCtx.MetaDataMap[execCtx.ParseContext.DeleteStmt.
TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O]
)

for _, sql := range multiQuery {
stmt, err = execCtx.Conn.Prepare(sql)
if err != nil {
log.Errorf("build prepare stmt: %+v", err)
return nil, err
}

rows, err = stmt.Query(args)
if err != nil {
log.Errorf("stmt query: %+v", err)
return nil, err
}

record, err = u.buildRecordImages(rows, meDataMap)
if err != nil {
log.Errorf("record images : %+v", err)
return nil, err
}
records = append(records, record)
}

return records, nil
}

func (u *MySQLMultiDeleteUndoLogBuilder) AfterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) {
return nil, nil
}

// buildBeforeImageSQL build delete sql from delete sql
func (u *MySQLMultiDeleteUndoLogBuilder) buildBeforeImageSQL(multiQuery []string, args []driver.Value) ([]string, []driver.Value, error) {
var (
err error
buf, param bytes.Buffer
p *types.ParseContext
tableName string
tables = make(map[string]multiDelete, len(multiQuery))
)

for _, query := range multiQuery {
p, err = parser.DoParser(query)
if err != nil {
return nil, nil, err
}

tableName = p.DeleteStmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O

v, ok := tables[tableName]
if ok && v.clear {
continue
}

buf.WriteString("delete from ")
buf.WriteString(tableName)

if p.DeleteStmt.Where == nil {
tables[tableName] = multiDelete{sql: buf.String(), clear: true}
buf.Reset()
continue
} else {
buf.WriteString(" where ")
}

_ = p.DeleteStmt.Where.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, &param))

v, ok = tables[tableName]
if ok {
buf.Reset()
buf.WriteString(v.sql)
buf.WriteString(" or ")
}

buf.Write(param.Bytes())
tables[tableName] = multiDelete{sql: buf.String()}

buf.Reset()
param.Reset()
}

var (
items = make([]string, 0, len(tables))
values = make([]driver.Value, 0, len(tables))
selStmt = ast.SelectStmt{
SelectStmtOpts: &ast.SelectStmtOpts{},
From: p.DeleteStmt.TableRefs,
Where: p.DeleteStmt.Where,
Fields: &ast.FieldList{Fields: []*ast.SelectField{{WildCard: &ast.WildCardField{}}}},
OrderBy: p.DeleteStmt.Order,
Limit: p.DeleteStmt.Limit,
TableHints: p.DeleteStmt.TableHints,
LockInfo: &ast.SelectLockInfo{LockType: ast.SelectLockForUpdate},
}
)

for _, table := range tables {
p, _ = parser.DoParser(table.sql)

selStmt.From = p.DeleteStmt.TableRefs
selStmt.Where = p.DeleteStmt.Where

_ = selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, &buf))
items = append(items, buf.String())
buf.Reset()
if table.clear {
values = append(values, u.buildSelectArgs(&selStmt, nil)...)
} else {
values = append(values, u.buildSelectArgs(&selStmt, args)...)
}
}

return items, values, nil
}

func (u *MySQLMultiDeleteUndoLogBuilder) GetExecutorType() types.ExecutorType {
return types.MultiDeleteExecutor
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package builder

import (
"testing"

"database/sql/driver"
"github.com/stretchr/testify/assert"
)

func TestBuildSelectSQLByMultiDelete(t *testing.T) {
tests := []struct {
name string
sourceQuery []string
sourceQueryArgs []driver.Value
expectQuery string
expectQueryArgs []driver.Value
}{
{
sourceQuery: []string{"delete from table_update_executor_test where id = ?", "delete from table_update_executor_test"},
sourceQueryArgs: []driver.Value{3},
expectQuery: "SELECT SQL_NO_CACHE * FROM table_update_executor_test FOR UPDATE",
expectQueryArgs: []driver.Value{},
},
{
sourceQuery: []string{"delete from table_update_executor_test2 where id = ?", "delete from table_update_executor_test2 where id = ?"},
sourceQueryArgs: []driver.Value{3, 2},
expectQuery: "SELECT SQL_NO_CACHE * FROM table_update_executor_test2 where id =? or id=? FOR UPDATE",
expectQueryArgs: []driver.Value{3, 2},
},
{
sourceQuery: []string{"delete from table_update_executor_test2 where id = ?", "delete from table_update_executor_test2 where name = ? and age = ?"},
sourceQueryArgs: []driver.Value{3, "seata-go", 4},
expectQuery: "SELECT SQL_NO_CACHE * FROM table_update_executor_test2 where id =? or id=? and age=? FOR UPDATE",
expectQueryArgs: []driver.Value{3, "seata-go", 4},
},
}

var builder = MySQLMultiDeleteUndoLogBuilder{}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
items, args, err := builder.buildBeforeImageSQL(tt.sourceQuery, tt.sourceQueryArgs)
assert.Nil(t, err)
assert.Equal(t, 1, len(items))
assert.Equal(t, tt.expectQueryArgs, args)
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
)

func init() {
undo.RegistrUndoLogBuilder(types.MultiExecutor, GetMySQLMultiUndoLogBuilder)
undo.RegisterUndoLogBuilder(types.MultiExecutor, GetMySQLMultiUndoLogBuilder)
}

type MySQLMultiUndoLogBuilder struct {
Expand Down Expand Up @@ -68,6 +68,7 @@ func (u *MySQLMultiUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *typ
break
case types.DeleteExecutor:
// todo use MultiDeleteExecutor
tmpImages, err = GetMySQLMultiDeleteUndoLogBuilder().BeforeImage(ctx, execCtx)
break
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
)

func init() {
undo.RegistrUndoLogBuilder(types.UpdateExecutor, GetMySQLMultiUpdateUndoLogBuilder)
undo.RegisterUndoLogBuilder(types.UpdateExecutor, GetMySQLMultiUpdateUndoLogBuilder)
}

type updateVisitor struct {
Expand Down
Loading

0 comments on commit 1a44802

Please sign in to comment.