Skip to content

Commit

Permalink
fix: do register resource where execute OpenConnector function (apach…
Browse files Browse the repository at this point in the history
…e#237)

* fix: register resource where execute OpenConnector function

* remove chinese comment
  • Loading branch information
luky116 authored Aug 24, 2022
1 parent fb985f5 commit 739821f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 23 deletions.
42 changes: 22 additions & 20 deletions pkg/datasource/sql/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -52,7 +51,7 @@ type SeataDriver struct {
func (d *SeataDriver) Open(name string) (driver.Conn, error) {
conn, err := d.target.Open(name)
if err != nil {
log.Errorf("open connection: %w", err)
log.Errorf("open target connection: %w", err)
return nil, err
}

Expand All @@ -62,35 +61,38 @@ func (d *SeataDriver) Open(name string) (driver.Conn, error) {
}

field := v.FieldByName("connector")

connector, _ := GetUnexportedField(field).(driver.Connector)

dbType := types.ParseDBType(d.getTargetDriverName())
if dbType == types.DBTypeUnknown {
return nil, errors.New("unsupport conn type")
}

c, err := d.OpenConnector(name)
proxy, err := d.OpenConnector(name)
if err != nil {
log.Errorf("open connector: %w", err)
return nil, fmt.Errorf("open connector error: %v", err.Error())
}

proxy, err := registerResource(connector, dbType, sql.OpenDB(c), name)
if err != nil {
log.Errorf("register resource: %w", err)
return nil, err
}

SetUnexportedField(field, proxy)
return conn, nil
}

func (d *SeataDriver) OpenConnector(dataSourceName string) (driver.Connector, error) {
func (d *SeataDriver) OpenConnector(name string) (c driver.Connector, err error) {
c = &dsnConnector{dsn: name, driver: d.target}
if driverCtx, ok := d.target.(driver.DriverContext); ok {
return driverCtx.OpenConnector(dataSourceName)
c, err = driverCtx.OpenConnector(name)
if err != nil {
log.Errorf("open connector: %w", err)
return nil, err
}
}
return &dsnConnector{dsn: dataSourceName, driver: d.target}, nil

dbType := types.ParseDBType(d.getTargetDriverName())
if dbType == types.DBTypeUnknown {
return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName())
}

proxy, err := registerResource(c, dbType, sql.OpenDB(c), name)
if err != nil {
log.Errorf("register resource: %w", err)
return nil, err
}

return proxy, nil
}

func (d *SeataDriver) getTargetDriverName() string {
Expand Down
6 changes: 5 additions & 1 deletion pkg/datasource/sql/exec/hook/logger_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ func (h *loggerSQLHook) Type() types.SQLType {

// Before
func (h *loggerSQLHook) Before(ctx context.Context, execCtx *exec.ExecContext) {
var txID string
if execCtx.TxCtx != nil {
txID = execCtx.TxCtx.LocalTransID
}
fields := []zap.Field{
zap.String("tx-id", execCtx.TxCtx.LocalTransID),
zap.String("tx-id", txID),
zap.String("sql", execCtx.Query),
}

Expand Down
36 changes: 34 additions & 2 deletions pkg/datasource/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,18 @@ import (
"testing"

_ "github.com/go-sql-driver/mysql"
"github.com/seata/seata-go/pkg/client"
"github.com/seata/seata-go/pkg/common/log"
)

var db *sql.DB

func Test_SQLOpen(t *testing.T) {
client.Init()
t.SkipNow()

db, err := sql.Open(SeataMySQLDriver, "root:polaris@tcp(127.0.0.1:3306)/polaris_server?multiStatements=true")
log.Info("begin test")
var err error
db, err = sql.Open(SeataMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/polaris_server?multiStatements=true")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -84,4 +90,30 @@ func Test_SQLOpen(t *testing.T) {
})

wait.Wait()
queryMultiRow()
}

func queryMultiRow() {
sqlStr := "select id, name from foo where id > ?"
rows, err := db.Query(sqlStr, 0)
if err != nil {
fmt.Printf("query failed, err:%v\n", err)
return
}
defer rows.Close()

for rows.Next() {
var u user
err := rows.Scan(&u.id, &u.name)
if err != nil {
fmt.Printf("scan failed, err:%v\n", err)
return
}
fmt.Printf("id:%d username:%s password:%s\n", u.id, u.name, u.name)
}
}

type user struct {
id int
name string
}

0 comments on commit 739821f

Please sign in to comment.