diff --git a/pkg/datasource/sql/conn.go b/pkg/datasource/sql/conn.go index 9e441ccfc..96f754461 100644 --- a/pkg/datasource/sql/conn.go +++ b/pkg/datasource/sql/conn.go @@ -35,6 +35,7 @@ type Conn struct { txCtx *types.TransactionContext targetConn driver.Conn autoCommit bool + dbName string } // ResetSession is called prior to executing a query on the connection @@ -93,7 +94,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e }, nil } -// Exec +// Exec warning: if you want to use global transaction, please use ExecContext function func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { conn, ok := c.targetConn.(driver.Execer) if !ok { @@ -113,7 +114,7 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { Conn: c.targetConn, } - return executor.ExecWithValue(context.Background(), execCtx, + return executor.ExecWithValue(context.Background(), execCtx, // todo context传的不对 func(ctx context.Context, query string, args []driver.Value) (types.ExecResult, error) { ret, err := conn.Exec(query, args) if err != nil { @@ -132,7 +133,7 @@ func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { // ExecContext func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { targetConn, ok := c.targetConn.(driver.ExecerContext) - if ok { + if !ok { values := make([]driver.Value, 0, len(args)) for i := range args { @@ -153,6 +154,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.Name Query: query, NamedValues: args, Conn: c.targetConn, + DBName: c.dbName, } ret, err := executor.ExecWithNamedValue(ctx, execCtx, diff --git a/pkg/datasource/sql/conn_at.go b/pkg/datasource/sql/conn_at.go index 3caafad41..b309276ad 100644 --- a/pkg/datasource/sql/conn_at.go +++ b/pkg/datasource/sql/conn_at.go @@ -71,7 +71,7 @@ func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, c.txCtx.DBType = c.res.dbType c.txCtx.TxOpt = opts - if IsGlobalTx(ctx) { + if tm.IsGlobalTx(ctx) { c.txCtx.XaID = tm.GetXID(ctx) c.txCtx.TransType = types.ATMode } @@ -85,13 +85,15 @@ func (c *ATConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, } func (c *ATConn) createOnceTxContext(ctx context.Context) bool { - onceTx := IsGlobalTx(ctx) && c.autoCommit + onceTx := tm.IsGlobalTx(ctx) && c.autoCommit if onceTx { c.txCtx = types.NewTxCtx() c.txCtx.DBType = c.res.dbType c.txCtx.XaID = tm.GetXID(ctx) + c.txCtx.XID = tm.GetXID(ctx) c.txCtx.TransType = types.ATMode + c.txCtx.GlobalLockRequire = true } return onceTx diff --git a/pkg/datasource/sql/conn_at_test.go b/pkg/datasource/sql/conn_at_test.go index c962ffb28..600158560 100644 --- a/pkg/datasource/sql/conn_at_test.go +++ b/pkg/datasource/sql/conn_at_test.go @@ -39,7 +39,7 @@ func initAtConnTestResource(t *testing.T) (*gomock.Controller, *sql.DB, *mockSQL mockMgr := initMockResourceManager(t, ctrl) _ = mockMgr - db, err := sql.Open("seata-at-mysql", "root:seata_go@tcp(127.0.0.1:3306)/seata_go_test?multiStatements=true") + db, err := sql.Open(SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true") if err != nil { t.Fatal(err) } diff --git a/pkg/datasource/sql/conn_xa.go b/pkg/datasource/sql/conn_xa.go index 1924bd654..017c36b32 100644 --- a/pkg/datasource/sql/conn_xa.go +++ b/pkg/datasource/sql/conn_xa.go @@ -72,7 +72,7 @@ func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, c.txCtx.DBType = c.res.dbType c.txCtx.TxOpt = opts - if IsGlobalTx(ctx) { + if tm.IsGlobalTx(ctx) { c.txCtx.TransType = types.XAMode c.txCtx.XaID = tm.GetXID(ctx) } @@ -86,7 +86,7 @@ func (c *XAConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, } func (c *XAConn) createOnceTxContext(ctx context.Context) bool { - onceTx := IsGlobalTx(ctx) && c.autoCommit + onceTx := tm.IsGlobalTx(ctx) && c.autoCommit if onceTx { c.txCtx = types.NewTxCtx() diff --git a/pkg/datasource/sql/connector.go b/pkg/datasource/sql/connector.go index fa1588a47..7d3ea22e6 100644 --- a/pkg/datasource/sql/connector.go +++ b/pkg/datasource/sql/connector.go @@ -22,6 +22,8 @@ import ( "database/sql/driver" "sync" + "github.com/go-sql-driver/mysql" + "github.com/seata/seata-go/pkg/datasource/sql/types" ) @@ -92,6 +94,7 @@ type seataConnector struct { once sync.Once driver driver.Driver target driver.Connector + cfg *mysql.Config } // Connect returns a connection to the database. @@ -118,6 +121,7 @@ func (c *seataConnector) Connect(ctx context.Context) (driver.Conn, error) { res: c.res, txCtx: types.NewTxCtx(), autoCommit: true, + dbName: c.cfg.DBName, }, nil } diff --git a/pkg/datasource/sql/datasource/base/meta_cache.go b/pkg/datasource/sql/datasource/base/meta_cache.go index bc0eeb1b5..d778a3a3b 100644 --- a/pkg/datasource/sql/datasource/base/meta_cache.go +++ b/pkg/datasource/sql/datasource/base/meta_cache.go @@ -19,7 +19,7 @@ package base import ( "context" - "database/sql" + "database/sql/driver" "errors" "sync" "time" @@ -30,7 +30,7 @@ import ( type ( // trigger trigger interface { - LoadOne(ctx context.Context, dbName string, table string, conn *sql.Conn) (*types.TableMeta, error) + LoadOne(ctx context.Context, dbName string, table string, conn driver.Conn) (*types.TableMeta, error) LoadAll() ([]types.TableMeta, error) } @@ -47,14 +47,13 @@ type BaseTableMetaCache struct { expireDuration time.Duration capity int32 size int32 - dbName string cache map[string]*entry cancel context.CancelFunc trigger trigger } // NewBaseCache -func NewBaseCache(capity int32, dbName string, expireDuration time.Duration, trigger trigger) *BaseTableMetaCache { +func NewBaseCache(capity int32, expireDuration time.Duration, trigger trigger) *BaseTableMetaCache { ctx, cancel := context.WithCancel(context.Background()) c := &BaseTableMetaCache{ @@ -63,7 +62,6 @@ func NewBaseCache(capity int32, dbName string, expireDuration time.Duration, tri size: 0, expireDuration: expireDuration, cache: map[string]*entry{}, - dbName: dbName, cancel: cancel, trigger: trigger, } @@ -136,13 +134,13 @@ func (c *BaseTableMetaCache) scanExpire(ctx context.Context) { } // GetTableMeta -func (c *BaseTableMetaCache) GetTableMeta(ctx context.Context, tableName string, conn *sql.Conn) (types.TableMeta, error) { +func (c *BaseTableMetaCache) GetTableMeta(ctx context.Context, dbName, tableName string, conn driver.Conn) (types.TableMeta, error) { c.lock.Lock() defer c.lock.Unlock() v, ok := c.cache[tableName] if !ok { - meta, err := c.trigger.LoadOne(ctx, c.dbName, tableName, conn) + meta, err := c.trigger.LoadOne(ctx, dbName, tableName, conn) if err != nil { return types.TableMeta{}, err } diff --git a/pkg/datasource/sql/datasource/datasource_manager.go b/pkg/datasource/sql/datasource/datasource_manager.go index d4e03a47f..59f66f1dd 100644 --- a/pkg/datasource/sql/datasource/datasource_manager.go +++ b/pkg/datasource/sql/datasource/datasource_manager.go @@ -20,6 +20,7 @@ package datasource import ( "context" "database/sql" + "database/sql/driver" "errors" "sync" @@ -61,6 +62,7 @@ func GetDataSourceManager(b branch.BranchType) DataSourceManager { return nil } +// todo implements ResourceManagerOutbound interface // DataSourceManager type DataSourceManager interface { // Register a Resource to be managed by Resource Manager @@ -176,7 +178,7 @@ type TableMetaCache interface { // Init Init(ctx context.Context, conn *sql.DB) error // GetTableMeta - GetTableMeta(ctx context.Context, table string, conn *sql.Conn) (*types.TableMeta, error) + GetTableMeta(ctx context.Context, dbName, table string, conn driver.Conn) (*types.TableMeta, error) // Destroy Destroy() error } diff --git a/pkg/datasource/sql/datasource/mysql/default.go b/pkg/datasource/sql/datasource/mysql/default.go index 37d84f1ec..28b573f8f 100644 --- a/pkg/datasource/sql/datasource/mysql/default.go +++ b/pkg/datasource/sql/datasource/mysql/default.go @@ -25,6 +25,6 @@ import ( // todo func init() { datasource.RegisterTableCache(types.DBTypeMySQL, func() datasource.TableMetaCache { - return &tableMetaCache{} + return &TableMetaCache{} }) } diff --git a/pkg/datasource/sql/datasource/mysql/meta_cache.go b/pkg/datasource/sql/datasource/mysql/meta_cache.go index e3e517ec5..1ccae84b0 100644 --- a/pkg/datasource/sql/datasource/mysql/meta_cache.go +++ b/pkg/datasource/sql/datasource/mysql/meta_cache.go @@ -20,6 +20,7 @@ package mysql import ( "context" "database/sql" + "database/sql/driver" "sync" "time" @@ -32,20 +33,19 @@ import ( var ( capacity int32 = 1024 EexpireTime = 15 * time.Minute - tableMetaInstance *tableMetaCache + tableMetaInstance *TableMetaCache tableMetaOnce sync.Once - DBName = "seata" ) -type tableMetaCache struct { +type TableMetaCache struct { tableMetaCache *base.BaseTableMetaCache } -func GetTableMetaInstance() *tableMetaCache { +func GetTableMetaInstance() *TableMetaCache { // Todo constant.DBName get from config tableMetaOnce.Do(func() { - tableMetaInstance = &tableMetaCache{ - tableMetaCache: base.NewBaseCache(capacity, DBName, EexpireTime, NewMysqlTrigger()), + tableMetaInstance = &TableMetaCache{ + tableMetaCache: base.NewBaseCache(capacity, EexpireTime, NewMysqlTrigger()), } }) @@ -53,17 +53,17 @@ func GetTableMetaInstance() *tableMetaCache { } // Init -func (c *tableMetaCache) Init(ctx context.Context, conn *sql.DB) error { +func (c *TableMetaCache) Init(ctx context.Context, conn *sql.DB) error { return nil } // GetTableMeta get table info from cache or information schema -func (c *tableMetaCache) GetTableMeta(ctx context.Context, tableName string, conn *sql.Conn) (*types.TableMeta, error) { +func (c *TableMetaCache) GetTableMeta(ctx context.Context, dbName, tableName string, conn driver.Conn) (*types.TableMeta, error) { if tableName == "" { return nil, errors.New("TableMeta cannot be fetched without tableName") } - tableMeta, err := c.tableMetaCache.GetTableMeta(ctx, tableName, conn) + tableMeta, err := c.tableMetaCache.GetTableMeta(ctx, dbName, tableName, conn) if err != nil { return nil, err } @@ -72,6 +72,6 @@ func (c *tableMetaCache) GetTableMeta(ctx context.Context, tableName string, con } // Destroy -func (c *tableMetaCache) Destroy() error { +func (c *TableMetaCache) Destroy() error { return nil } diff --git a/pkg/datasource/sql/datasource/mysql/meta_cache_test.go b/pkg/datasource/sql/datasource/mysql/meta_cache_test.go index 598d1b7d7..63e89b3a0 100644 --- a/pkg/datasource/sql/datasource/mysql/meta_cache_test.go +++ b/pkg/datasource/sql/datasource/mysql/meta_cache_test.go @@ -42,9 +42,8 @@ func TestGetTableMeta(t *testing.T) { defer db.Close() ctx := context.Background() - conn, _ := db.Conn(ctx) - tableMeta, err := metaInstance.GetTableMeta(ctx, "undo_log", conn) + tableMeta, err := metaInstance.GetTableMeta(ctx, "seata_client", "undo_log", nil) assert.NilError(t, err) t.Logf("%+v", tableMeta) diff --git a/pkg/datasource/sql/datasource/mysql/trigger.go b/pkg/datasource/sql/datasource/mysql/trigger.go index 5f428f088..9d2946070 100644 --- a/pkg/datasource/sql/datasource/mysql/trigger.go +++ b/pkg/datasource/sql/datasource/mysql/trigger.go @@ -19,7 +19,8 @@ package mysql import ( "context" - "database/sql" + "database/sql/driver" + "io" "strings" "github.com/pkg/errors" @@ -27,6 +28,11 @@ import ( "github.com/seata/seata-go/pkg/datasource/sql/undo/executor" ) +const ( + columnMetaSql = "SELECT `TABLE_NAME`, `TABLE_SCHEMA`, `COLUMN_NAME`, `DATA_TYPE`, `COLUMN_TYPE`, `COLUMN_KEY`, `IS_NULLABLE`, `EXTRA` FROM INFORMATION_SCHEMA.COLUMNS WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + indexMetaSql = "SELECT `INDEX_NAME`, `COLUMN_NAME`, `NON_UNIQUE`, `INDEX_TYPE`, `COLLATION`, `CARDINALITY` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +) + type mysqlTrigger struct { } @@ -35,7 +41,7 @@ func NewMysqlTrigger() *mysqlTrigger { } // LoadOne get table meta column and index -func (m *mysqlTrigger) LoadOne(ctx context.Context, dbName string, tableName string, conn *sql.Conn) (*types.TableMeta, error) { +func (m *mysqlTrigger) LoadOne(ctx context.Context, dbName string, tableName string, conn driver.Conn) (*types.TableMeta, error) { tableMeta := types.TableMeta{ Name: tableName, Columns: make(map[string]types.ColumnMeta), @@ -63,6 +69,7 @@ func (m *mysqlTrigger) LoadOne(ctx context.Context, dbName string, tableName str idx, ok := tableMeta.Indexs[index.Name] if ok { idx.Values = append(idx.Values, col) + tableMeta.Indexs[index.Name] = idx } else { index.Values = append(index.Values, col) tableMeta.Indexs[index.Name] = index @@ -81,52 +88,42 @@ func (m *mysqlTrigger) LoadAll() ([]types.TableMeta, error) { } // getColumns get tableMeta column -func (m *mysqlTrigger) getColumns(ctx context.Context, dbName string, table string, conn *sql.Conn) ([]types.ColumnMeta, error) { +func (m *mysqlTrigger) getColumns(ctx context.Context, dbName string, table string, conn driver.Conn) ([]types.ColumnMeta, error) { table = executor.DelEscape(table, types.DBTypeMySQL) + var columnMetas []types.ColumnMeta - var result []types.ColumnMeta - - columnSchemaSql := "select TABLE_CATALOG, TABLE_NAME, TABLE_SCHEMA, COLUMN_NAME, DATA_TYPE, COLUMN_TYPE, COLUMN_KEY, " + - " IS_NULLABLE, EXTRA from INFORMATION_SCHEMA.COLUMNS where `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - - stmt, err := conn.PrepareContext(ctx, columnSchemaSql) + stmt, err := conn.Prepare(columnMetaSql) if err != nil { return nil, err } - rows, err := stmt.QueryContext(ctx, dbName, table) + rowsi, err := stmt.Query([]driver.Value{dbName, table}) if err != nil { return nil, err } - for rows.Next() { + for { + vals := make([]driver.Value, 8) + err = rowsi.Next(vals) + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + var ( - tableCatalog string - tableName string - tableSchema string - columnName string - dataType string - columnType string - columnKey string - isNullable string - extra string + tableName = string(vals[0].([]uint8)) + tableSchema = string(vals[1].([]uint8)) + columnName = string(vals[2].([]uint8)) + dataType = string(vals[3].([]uint8)) + columnType = string(vals[4].([]uint8)) + columnKey = string(vals[5].([]uint8)) + isNullable = string(vals[6].([]uint8)) + extra = string(vals[7].([]uint8)) ) col := types.ColumnMeta{} - - if err = rows.Scan( - &tableCatalog, - &tableName, - &tableSchema, - &columnName, - &dataType, - &columnType, - &columnKey, - &isNullable, - &extra); err != nil { - return nil, err - } - col.Schema = tableSchema col.Table = tableName col.ColumnName = strings.Trim(columnName, "` ") @@ -141,61 +138,52 @@ func (m *mysqlTrigger) getColumns(ctx context.Context, dbName string, table stri col.Extra = extra col.Autoincrement = strings.Contains(strings.ToLower(extra), "auto_increment") - result = append(result, col) + columnMetas = append(columnMetas, col) } - if err = rows.Err(); err != nil { - return nil, err - } - - if err = rows.Close(); err != nil { - return nil, err - } - - if len(result) == 0 { + if len(columnMetas) == 0 { return nil, errors.New("can't find column") } - return result, nil + return columnMetas, nil } // getIndex get tableMetaIndex -func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName string, conn *sql.Conn) ([]types.IndexMeta, error) { +func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName string, conn driver.Conn) ([]types.IndexMeta, error) { tableName = executor.DelEscape(tableName, types.DBTypeMySQL) - result := make([]types.IndexMeta, 0) - indexSchemaSql := "SELECT `INDEX_NAME`, `COLUMN_NAME`, `NON_UNIQUE`, `INDEX_TYPE`, `COLLATION`, `CARDINALITY` " + - "FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - - stmt, err := conn.PrepareContext(ctx, indexSchemaSql) + stmt, err := conn.Prepare(indexMetaSql) if err != nil { return nil, err } - rows, err := stmt.QueryContext(ctx, dbName, tableName) + rowsi, err := stmt.Query([]driver.Value{dbName, tableName}) if err != nil { return nil, err } - defer rows.Close() - - for rows.Next() { - var ( - indexName, columnName, nonUnique, indexType, collation string - cardinality int - ) + defer rowsi.Close() - if err = rows.Scan( - &indexName, - &columnName, - &nonUnique, - &indexType, - &collation, - &cardinality); err != nil { + for { + vals := make([]driver.Value, 6) + err = rowsi.Next(vals) + if err == io.EOF { + break + } + if err != nil { return nil, err } + var ( + indexName = string(vals[0].([]uint8)) + columnName = string(vals[1].([]uint8)) + nonUnique = vals[2].(int64) + //indexType = string(vals[3].([]uint8)) + //collation = string(vals[4].([]uint8)) + //cardinality = int(vals[6].([]uint8)) + ) + index := types.IndexMeta{ Schema: dbName, Table: tableName, @@ -204,12 +192,12 @@ func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName Values: make([]types.ColumnMeta, 0), } - if nonUnique == "1" || "yes" == strings.ToLower(nonUnique) { + if nonUnique == 1 { index.NonUnique = true } if "primary" == strings.ToLower(indexName) { - index.IType = types.IndexPrimary + index.IType = types.IndexTypePrimaryKey } else if !index.NonUnique { index.IType = types.IndexUnique } else { @@ -219,9 +207,5 @@ func (m *mysqlTrigger) getIndexes(ctx context.Context, dbName string, tableName result = append(result, index) } - if err = rows.Err(); err != nil { - return nil, err - } - return result, nil } diff --git a/pkg/datasource/sql/driver.go b/pkg/datasource/sql/driver.go index fdecca077..d2df87ada 100644 --- a/pkg/datasource/sql/driver.go +++ b/pkg/datasource/sql/driver.go @@ -65,6 +65,8 @@ func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err erro _connector, _ := connector.(*seataConnector) _connector.transType = types.ATMode + cfg, _ := mysql.ParseDSN(name) + _connector.cfg = cfg return &seataATConnector{ seataConnector: _connector, @@ -83,6 +85,8 @@ func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err erro _connector, _ := connector.(*seataConnector) _connector.transType = types.XAMode + cfg, _ := mysql.ParseDSN(name) + _connector.cfg = cfg return &seataXAConnector{ seataConnector: _connector, @@ -119,7 +123,7 @@ func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName()) } - proxy, err := registerResource(c, d.transType, dbType, sql.OpenDB(c), name) + proxy, err := getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) if err != nil { log.Errorf("register resource: %w", err) return nil, err @@ -145,7 +149,7 @@ func (t *dsnConnector) Driver() driver.Driver { return t.driver } -func registerResource(connector driver.Connector, txType types.TransactionType, dbType types.DBType, db *sql.DB, +func getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, db *sql.DB, dataSourceName string, opts ...seataOption) (driver.Connector, error) { conf := loadConfig() for i := range opts { @@ -175,11 +179,13 @@ func registerResource(connector driver.Connector, txType types.TransactionType, log.Errorf("regisiter resource: %w", err) return nil, err } + cfg, _ := mysql.ParseDSN(dataSourceName) return &seataConnector{ res: res, target: connector, conf: conf, + cfg: cfg, }, nil } diff --git a/pkg/datasource/sql/exec/executor.go b/pkg/datasource/sql/exec/executor.go index 3b0fadac2..6d1f28314 100644 --- a/pkg/datasource/sql/exec/executor.go +++ b/pkg/datasource/sql/exec/executor.go @@ -224,7 +224,7 @@ func (e *BaseExecutor) ExecWithValue(ctx context.Context, execCtx *types.ExecCon } func (h *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecContext) ([]*types.RecordImage, error) { - if !tm.IsTransactionOpened(ctx) { + if !tm.IsGlobalTx(ctx) { return nil, nil } @@ -235,6 +235,7 @@ func (h *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecConte if !pc.HasValidStmt() { return nil, nil } + execCtx.ParseContext = pc builder := undo.GetUndologBuilder(pc.ExecutorType) if builder == nil { @@ -245,7 +246,7 @@ func (h *BaseExecutor) beforeImage(ctx context.Context, execCtx *types.ExecConte // After func (h *BaseExecutor) afterImage(ctx context.Context, execCtx *types.ExecContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { - if !tm.IsTransactionOpened(ctx) { + if !tm.IsGlobalTx(ctx) { return nil, nil } pc, err := parser.DoParser(execCtx.Query) @@ -255,6 +256,7 @@ func (h *BaseExecutor) afterImage(ctx context.Context, execCtx *types.ExecContex if !pc.HasValidStmt() { return nil, nil } + execCtx.ParseContext = pc builder := undo.GetUndologBuilder(pc.ExecutorType) if builder == nil { return nil, nil diff --git a/pkg/datasource/sql/exec/select_for_update_executor.go b/pkg/datasource/sql/exec/select_for_update_executor.go index c70109eb4..541c88bdf 100644 --- a/pkg/datasource/sql/exec/select_for_update_executor.go +++ b/pkg/datasource/sql/exec/select_for_update_executor.go @@ -325,7 +325,7 @@ func (s SelectForUpdateExecutor) buildLockKey(rows driver.Rows, meta types.Table lockKeys.WriteString(meta.Schema) lockKeys.WriteString(":") - ss := s.GetScanSlice(meta.GetPrimaryKeyOnlyName(), meta) + ss := s.GetScanSlice(meta.GetPrimaryKeyOnlyName(), &meta) for { err := rows.Next(ss) if err == io.EOF { diff --git a/pkg/datasource/sql/hook/undo_log_hook.go b/pkg/datasource/sql/hook/undo_log_hook.go index 327bf7b93..a8ff79b55 100644 --- a/pkg/datasource/sql/hook/undo_log_hook.go +++ b/pkg/datasource/sql/hook/undo_log_hook.go @@ -40,7 +40,7 @@ func (h *undoLogSQLHook) Type() types.SQLType { // Before func (h *undoLogSQLHook) Before(ctx context.Context, execCtx *types.ExecContext) error { - if !tm.IsTransactionOpened(ctx) { + if !tm.IsGlobalTx(ctx) { return nil } @@ -66,7 +66,7 @@ func (h *undoLogSQLHook) Before(ctx context.Context, execCtx *types.ExecContext) // After func (h *undoLogSQLHook) After(ctx context.Context, execCtx *types.ExecContext) error { - if !tm.IsTransactionOpened(ctx) { + if !tm.IsGlobalTx(ctx) { return nil } return nil diff --git a/pkg/datasource/sql/stmt.go b/pkg/datasource/sql/stmt.go index f1ebe072a..e62d76cdb 100644 --- a/pkg/datasource/sql/stmt.go +++ b/pkg/datasource/sql/stmt.go @@ -193,6 +193,9 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive return types.NewResult(types.WithResult(ret)), nil }) + if err != nil { + return nil, err + } return ret.GetResult(), err } diff --git a/pkg/datasource/sql/types/types.go b/pkg/datasource/sql/types/types.go index 8767b4779..7e6292509 100644 --- a/pkg/datasource/sql/types/types.go +++ b/pkg/datasource/sql/types/types.go @@ -38,6 +38,8 @@ type ( const ( IndexTypeNull IndexType = 0 IndexTypePrimaryKey IndexType = 1 + IndexUnique IndexType = 2 + IndexNormal IndexType = 3 ) const ( @@ -55,9 +57,9 @@ const ( // IndexPrimary primary index type. IndexPrimary = 0 // IndexNormal normal index type. - IndexNormal = 1 + //IndexNormal = 1 // IndexUnique unique index type. - IndexUnique = 2 + //IndexUnique = 2 // IndexFullText full text index type. IndexFullText = 3 ) @@ -98,7 +100,7 @@ type TransactionContext struct { // BranchID transaction branch unique id BranchID uint64 // XaID XA id - XaID string + XaID string // todo delete // XID global transaction id XID string // GlobalLockRequire @@ -114,8 +116,10 @@ type ExecContext struct { ParseContext *ParseContext NamedValues []driver.NamedValue Values []driver.Value - MetaDataMap map[string]TableMeta - Conn driver.Conn + // todo 待删除 + MetaDataMap map[string]TableMeta + Conn driver.Conn + DBName string // todo set values for these 4 param IsAutoCommit bool IsSupportsSavepoints bool diff --git a/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go b/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go index 226f59ff4..e0dde1a6a 100644 --- a/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/basic_undo_log_builder.go @@ -33,7 +33,7 @@ import ( type BasicUndoLogBuilder struct{} // getScanSlice get the column type for scann -func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta types.TableMeta) []driver.Value { +func (*BasicUndoLogBuilder) GetScanSlice(columnNames []string, tableMeta *types.TableMeta) []driver.Value { scanSlice := make([]driver.Value, 0, len(columnNames)) for _, columnNmae := range columnNames { var ( @@ -152,7 +152,7 @@ func (b *BasicUndoLogBuilder) traversalArgs(node ast.Node, argsIndex *[]int32) { } } -func (b *BasicUndoLogBuilder) buildRecordImages(rowsi driver.Rows, tableMetaData types.TableMeta) (*types.RecordImage, error) { +func (b *BasicUndoLogBuilder) buildRecordImages(rowsi driver.Rows, tableMetaData *types.TableMeta) (*types.RecordImage, error) { // select column names columnNames := rowsi.Columns() rowImages := make([]types.RowImage, 0) @@ -163,6 +163,9 @@ func (b *BasicUndoLogBuilder) buildRecordImages(rowsi driver.Rows, tableMetaData if err == io.EOF { break } + if err != nil { + return nil, err + } columns := make([]types.ColumnImage, 0) // build record image diff --git a/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go index 4a566cf48..abd4299e7 100644 --- a/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_delete_undo_log_builder.go @@ -72,7 +72,7 @@ func (u *MySQLDeleteUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *ty tableName := execCtx.ParseContext.DeleteStmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O metaData := execCtx.MetaDataMap[tableName] - image, err := u.buildRecordImages(rows, metaData) + image, err := u.buildRecordImages(rows, &metaData) if err != nil { return nil, err } diff --git a/pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder.go index 946a780d4..8af43bbea 100644 --- a/pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder.go @@ -32,10 +32,6 @@ import ( "github.com/seata/seata-go/pkg/util/log" ) -func init() { - undo.RegisterUndoLogBuilder(types.MultiDeleteExecutor, GetMySQLMultiDeleteUndoLogBuilder) -} - type multiDelete struct { sql string clear bool @@ -91,7 +87,7 @@ func (u *MySQLMultiDeleteUndoLogBuilder) BeforeImage(ctx context.Context, execCt return nil, err } - record, err = u.buildRecordImages(rows, meDataMap) + record, err = u.buildRecordImages(rows, &meDataMap) if err != nil { log.Errorf("record images : %+v", err) return nil, err diff --git a/pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder_test.go b/pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder_test.go index 9ea8d426e..633635ee6 100644 --- a/pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder_test.go +++ b/pkg/datasource/sql/undo/builder/mysql_multi_delete_undo_log_builder_test.go @@ -21,6 +21,7 @@ import ( "testing" "database/sql/driver" + "github.com/stretchr/testify/assert" ) diff --git a/pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go index 30e639a14..cee165cae 100644 --- a/pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go @@ -34,9 +34,9 @@ import ( "github.com/seata/seata-go/pkg/util/log" ) -func init() { - undo.RegisterUndoLogBuilder(types.UpdateExecutor, GetMySQLMultiUpdateUndoLogBuilder) -} +//func init() { +// undo.RegisterUndoLogBuilder(types.MultiExecutor, GetMySQLMultiUpdateUndoLogBuilder) +//} type updateVisitor struct { stmt *ast.UpdateStmt @@ -95,7 +95,7 @@ func (u *MySQLMultiUpdateUndoLogBuilder) BeforeImage(ctx context.Context, execCt tableName := execCtx.ParseContext.UpdateStmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O metaData := execCtx.MetaDataMap[tableName] - image, err := u.buildRecordImages(rows, metaData) + image, err := u.buildRecordImages(rows, &metaData) if err != nil { return nil, err } @@ -125,7 +125,7 @@ func (u *MySQLMultiUpdateUndoLogBuilder) AfterImage(ctx context.Context, execCtx return nil, err } - image, err := u.buildRecordImages(rows, metaData) + image, err := u.buildRecordImages(rows, &metaData) if err != nil { return nil, err } diff --git a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go index 406bc85c3..87e86f153 100644 --- a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go +++ b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder.go @@ -23,6 +23,9 @@ import ( "fmt" "strings" + "github.com/arana-db/parser/model" + "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" + "github.com/arana-db/parser/ast" "github.com/arana-db/parser/format" @@ -57,12 +60,19 @@ func (u *MySQLUpdateUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *ty vals := execCtx.Values if vals == nil { - for n, param := range execCtx.NamedValues { - vals[n] = param.Value + vals = make([]driver.Value, 0) + for _, param := range execCtx.NamedValues { + vals = append(vals, param.Value) } } // use - selectSQL, selectArgs, err := u.buildBeforeImageSQL(execCtx.ParseContext.UpdateStmt, vals) + selectSQL, selectArgs, err := u.buildBeforeImageSQL(ctx, execCtx, vals) + if err != nil { + return nil, err + } + + tableName, _ := execCtx.ParseContext.GteTableName() + metaData, err := mysql.GetTableMetaInstance().GetTableMeta(ctx, execCtx.DBName, tableName, execCtx.Conn) if err != nil { return nil, err } @@ -79,9 +89,6 @@ func (u *MySQLUpdateUndoLogBuilder) BeforeImage(ctx context.Context, execCtx *ty return nil, err } - tableName := execCtx.ParseContext.UpdateStmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O - metaData := execCtx.MetaDataMap[tableName] - image, err := u.buildRecordImages(rows, metaData) if err != nil { return nil, err @@ -100,8 +107,11 @@ func (u *MySQLUpdateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *typ beforeImage = beforeImages[0] } - tableName := execCtx.ParseContext.UpdateStmt.TableRefs.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).Name.O - metaData := execCtx.MetaDataMap[tableName] + tableName, _ := execCtx.ParseContext.GteTableName() + metaData, err := mysql.GetTableMetaInstance().GetTableMeta(ctx, execCtx.DBName, tableName, execCtx.Conn) + if err != nil { + return nil, err + } selectSQL, selectArgs := u.buildAfterImageSQL(beforeImage, metaData) stmt, err := execCtx.Conn.Prepare(selectSQL) @@ -124,17 +134,18 @@ func (u *MySQLUpdateUndoLogBuilder) AfterImage(ctx context.Context, execCtx *typ return []*types.RecordImage{image}, nil } -func (u *MySQLUpdateUndoLogBuilder) buildAfterImageSQL(beforeImage *types.RecordImage, meta types.TableMeta) (string, []driver.Value) { +func (u *MySQLUpdateUndoLogBuilder) buildAfterImageSQL(beforeImage *types.RecordImage, meta *types.TableMeta) (string, []driver.Value) { sb := strings.Builder{} // todo use ONLY_CARE_UPDATE_COLUMNS to judge select all columns or not - sb.WriteString("SELECT * FROM " + meta.Name + " ") + sb.WriteString("SELECT * FROM " + meta.Name + " WHERE ") whereSQL := u.buildWhereConditionByPKs(meta.GetPrimaryKeyOnlyName(), len(beforeImage.Rows), "mysql", maxInSize) sb.WriteString(" " + whereSQL + " ") return sb.String(), u.buildPKParams(beforeImage.Rows, meta.GetPrimaryKeyOnlyName()) } // buildSelectSQLByUpdate build select sql from update sql -func (u *MySQLUpdateUndoLogBuilder) buildBeforeImageSQL(updateStmt *ast.UpdateStmt, args []driver.Value) (string, []driver.Value, error) { +func (u *MySQLUpdateUndoLogBuilder) buildBeforeImageSQL(ctx context.Context, execCtx *types.ExecContext, args []driver.Value) (string, []driver.Value, error) { + updateStmt := execCtx.ParseContext.UpdateStmt if updateStmt == nil { log.Errorf("invalid update stmt") return "", nil, fmt.Errorf("invalid update stmt") @@ -151,6 +162,25 @@ func (u *MySQLUpdateUndoLogBuilder) buildBeforeImageSQL(updateStmt *ast.UpdateSt }) } + // select indexes columns + tableName, _ := execCtx.ParseContext.GteTableName() + metaData, err := mysql.GetTableMetaInstance().GetTableMeta(ctx, execCtx.DBName, tableName, execCtx.Conn) + if err != nil { + return "", nil, err + } + for _, columnName := range metaData.GetPrimaryKeyOnlyName() { + fields = append(fields, &ast.SelectField{ + Expr: &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Name: model.CIStr{ + O: columnName, + L: columnName, + }, + }, + }, + }) + } + selStmt := ast.SelectStmt{ SelectStmtOpts: &ast.SelectStmtOpts{}, From: updateStmt.TableRefs, diff --git a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder_test.go b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder_test.go index f1618aa82..ebd476073 100644 --- a/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder_test.go +++ b/pkg/datasource/sql/undo/builder/mysql_update_undo_log_builder_test.go @@ -18,7 +18,12 @@ package builder import ( + "context" "database/sql/driver" + "github.com/agiledragon/gomonkey" + "github.com/seata/seata-go/pkg/datasource/sql/datasource/mysql" + "github.com/seata/seata-go/pkg/datasource/sql/types" + "reflect" "testing" "github.com/seata/seata-go/pkg/datasource/sql/parser" @@ -32,6 +37,22 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { var ( builder = MySQLUpdateUndoLogBuilder{} ) + //stub := gomonkey.ApplyMethod(reflect.TypeOf(mysql.GetTableMetaInstance()), "GetTableMeta", func(_ *datasource.TableMetaCache, ctx context.Context, dbName, tableName string, conn driver.Conn) (*types.TableMeta, error) { + // return &types.TableMeta{ + // + // }, nil + //}) + stub := gomonkey.ApplyMethod(reflect.TypeOf(mysql.GetTableMetaInstance()), "GetTableMeta", func(_ *mysql.TableMetaCache, ctx context.Context, dbName, tableName string, conn driver.Conn) (*types.TableMeta, error) { + return &types.TableMeta{ + Indexs: map[string]types.IndexMeta{ + "id": types.IndexMeta{ + ColumnName: "id", + IType: types.IndexTypePrimaryKey, + }, + }, + }, nil + }) + defer stub.Reset() tests := []struct { name string @@ -43,25 +64,25 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { { sourceQuery: "update t_user set name = ?, age = ? where id = ?", sourceQueryArgs: []driver.Value{"Jack", 1, 100}, - expectQuery: "SELECT SQL_NO_CACHE name,age FROM t_user WHERE id=? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? FOR UPDATE", expectQueryArgs: []driver.Value{100}, }, { sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age between ? and ?", sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, - expectQuery: "SELECT SQL_NO_CACHE name,age FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age BETWEEN ? AND ? FOR UPDATE", expectQueryArgs: []driver.Value{100, 18, 28}, }, { sourceQuery: "update t_user set name = ?, age = ? where id = ? and name = 'Jack' and age in (?,?)", sourceQueryArgs: []driver.Value{"Jack", 1, 100, 18, 28}, - expectQuery: "SELECT SQL_NO_CACHE name,age FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE id=? AND name=_UTF8MB4Jack AND age IN (?,?) FOR UPDATE", expectQueryArgs: []driver.Value{100, 18, 28}, }, { sourceQuery: "update t_user set name = ?, age = ? where kk between ? and ? and id = ? and addr in(?,?) and age > ? order by name desc limit ?", sourceQueryArgs: []driver.Value{"Jack", 1, 10, 20, 17, "Beijing", "Guangzhou", 18, 2}, - expectQuery: "SELECT SQL_NO_CACHE name,age FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", + expectQuery: "SELECT SQL_NO_CACHE name,age,id FROM t_user WHERE kk BETWEEN ? AND ? AND id=? AND addr IN (?,?) AND age>? ORDER BY name DESC LIMIT ? FOR UPDATE", expectQueryArgs: []driver.Value{10, 20, 17, "Beijing", "Guangzhou", 18, 2}, }, } @@ -69,7 +90,7 @@ func TestBuildSelectSQLByUpdate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c, err := parser.DoParser(tt.sourceQuery) assert.Nil(t, err) - query, args, err := builder.buildBeforeImageSQL(c.UpdateStmt, tt.sourceQueryArgs) + query, args, err := builder.buildBeforeImageSQL(context.Background(), &types.ExecContext{ParseContext: c}, tt.sourceQueryArgs) assert.Nil(t, err) assert.Equal(t, tt.expectQuery, query) assert.Equal(t, tt.expectQueryArgs, args) diff --git a/pkg/datasource/sql/undo/executor/sql.go b/pkg/datasource/sql/undo/executor/sql.go index d9b0a48d6..a79965cba 100644 --- a/pkg/datasource/sql/undo/executor/sql.go +++ b/pkg/datasource/sql/undo/executor/sql.go @@ -35,7 +35,6 @@ func DelEscape(colName string, dbType types.DBType) string { if dbType == types.DBTypeMySQL { newColName = delEscape(newColName, EscapeMysql) } - return newColName } diff --git a/pkg/remoting/getty/getty_remoting.go b/pkg/remoting/getty/getty_remoting.go index 1eaf81577..d8c7a1dc7 100644 --- a/pkg/remoting/getty/getty_remoting.go +++ b/pkg/remoting/getty/getty_remoting.go @@ -18,6 +18,7 @@ package getty import ( + "fmt" "sync" "time" @@ -80,7 +81,7 @@ func (g *GettyRemoting) sendAsync(session getty.Session, msg message.RpcMessage, var err error if session == nil || session.IsClosed() { log.Warn("sendAsyncRequestWithResponse nothing, caused by null channel.") - return nil, err + return nil, fmt.Errorf("session is closed") } resp := message.NewMessageFuture(msg) g.futures.Store(msg.ID, resp) diff --git a/pkg/rm/tcc/tcc_service.go b/pkg/rm/tcc/tcc_service.go index 7774f761a..bdc7918b8 100644 --- a/pkg/rm/tcc/tcc_service.go +++ b/pkg/rm/tcc/tcc_service.go @@ -77,7 +77,7 @@ func (t *TCCServiceProxy) Reference() string { } func (t *TCCServiceProxy) Prepare(ctx context.Context, params interface{}) (interface{}, error) { - if tm.IsTransactionOpened(ctx) { + if tm.IsGlobalTx(ctx) { err := t.registeBranch(ctx, params) if err != nil { return nil, err @@ -91,7 +91,7 @@ func (t *TCCServiceProxy) Prepare(ctx context.Context, params interface{}) (inte // registeBranch send register branch transaction request func (t *TCCServiceProxy) registeBranch(ctx context.Context, params interface{}) error { - if !tm.IsTransactionOpened(ctx) { + if !tm.IsGlobalTx(ctx) { err := errors.New("BranchRegister error, transaction should be opened") log.Errorf(err.Error()) return err diff --git a/pkg/tm/context.go b/pkg/tm/context.go index 9c14eb6cc..ab1ee18ac 100644 --- a/pkg/tm/context.go +++ b/pkg/tm/context.go @@ -117,7 +117,7 @@ func SetTransactionRole(ctx context.Context, role GlobalTransactionRole) { } } -func IsTransactionOpened(ctx context.Context) bool { +func IsGlobalTx(ctx context.Context) bool { variable := ctx.Value(seataContextVariable) if variable == nil { return false diff --git a/pkg/tm/context_test.go b/pkg/tm/context_test.go index adc6776db..3066cde7b 100644 --- a/pkg/tm/context_test.go +++ b/pkg/tm/context_test.go @@ -112,10 +112,10 @@ func TestGetXID(t *testing.T) { func TestIsTransactionOpened(t *testing.T) { ctx := InitSeataContext(context.Background()) - assert.False(t, IsTransactionOpened(ctx)) + assert.False(t, IsGlobalTx(ctx)) xid := "12345" SetXID(ctx, xid) - assert.True(t, IsTransactionOpened(ctx)) + assert.True(t, IsGlobalTx(ctx)) } func TestSetXIDCopy(t *testing.T) { diff --git a/pkg/tm/transaction_executor.go b/pkg/tm/transaction_executor.go index d94c83f00..da383f1fe 100644 --- a/pkg/tm/transaction_executor.go +++ b/pkg/tm/transaction_executor.go @@ -74,7 +74,7 @@ func begin(ctx context.Context, name string) (rc context.Context, re error) { } var tx *GlobalTransaction - if IsTransactionOpened(ctx) { + if IsGlobalTx(ctx) { tx = &GlobalTransaction{ Xid: GetXID(ctx), Status: message.GlobalStatusBegin, diff --git a/sample/at/basic/main.go b/sample/at/basic/main.go new file mode 100644 index 000000000..e5765b682 --- /dev/null +++ b/sample/at/basic/main.go @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "context" + "fmt" + "time" + + "github.com/seata/seata-go/pkg/client" + "github.com/seata/seata-go/pkg/tm" +) + +type Foo struct { + id int + name string +} + +func main() { + client.Init() + initService() + tm.WithGlobalTx(context.Background(), &tm.TransactionInfo{ + Name: "ATSampleLocalGlobalTx", + }, updateData) + <-make(chan struct{}) +} + +func selectData() { + var foo Foo + row := db.QueryRow("select * from foo where id = ? ", 1) + err := row.Scan(&foo.id, &foo.name) + if err != nil { + panic(err) + } + fmt.Println(foo) +} + +func updateData(ctx context.Context) error { + sql := "update foo set name=? where id=?" + ret, err := db.ExecContext(ctx, sql, fmt.Sprintf("Zhangsan-%d", time.Now().UnixMilli()), 1) + if err != nil { + fmt.Printf("update failed, err:%v\n", err) + return nil + } + rows, err := ret.RowsAffected() + if err != nil { + fmt.Printf("update failed, err:%v\n", err) + return nil + } + fmt.Printf("update success: %d.\n", rows) + return nil +} + +func insertData(ctx context.Context) error { + sqlStr := "insert into foo(name) values (?)" + ret, err := db.ExecContext(ctx, sqlStr, fmt.Sprintf("Zhangsan-%d", time.Now().UnixMilli())) + if err != nil { + fmt.Printf("insert failed, err:%v\n", err) + return err + } + theID, err := ret.LastInsertId() + if err != nil { + fmt.Printf("get lastinsert ID failed, err:%v\n", err) + return err + } + fmt.Printf("insert success, the id is %d.\n", theID) + return nil +} diff --git a/pkg/datasource/sql/context.go b/sample/at/basic/service.go similarity index 71% rename from pkg/datasource/sql/context.go rename to sample/at/basic/service.go index a7a5c57fd..6e5ca6542 100644 --- a/pkg/datasource/sql/context.go +++ b/sample/at/basic/service.go @@ -15,15 +15,22 @@ * limitations under the License. */ -package sql +package main import ( - "context" + "database/sql" - "github.com/seata/seata-go/pkg/tm" + sql2 "github.com/seata/seata-go/pkg/datasource/sql" ) -// IsGlobalTx check is open global transactions -func IsGlobalTx(ctx context.Context) bool { - return tm.IsTransactionOpened(ctx) +var ( + db *sql.DB +) + +func initService() { + var err error + db, err = sql.Open(sql2.SeataATMySQLDriver, "root:12345678@tcp(127.0.0.1:3306)/seata_client?multiStatements=true&interpolateParams=true") + if err != nil { + panic("init service error") + } }