From 721c71f7bdf3bf40cfa3cfd7cfac610878df1cc0 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 19 Oct 2021 13:16:59 +0900 Subject: [PATCH 1/2] improve context interfaces for Exec and Query --- xraysql/conn.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xraysql/conn.go b/xraysql/conn.go index 7f22a06..fe1d619 100644 --- a/xraysql/conn.go +++ b/xraysql/conn.go @@ -126,7 +126,8 @@ func (conn *driverConn) Exec(query string, args []driver.Value) (driver.Result, func (conn *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { execer, ok := conn.Conn.(driver.Execer) - if !ok { + execerCtx, okCtx := conn.Conn.(driver.ExecerContext) + if !ok && !okCtx { return nil, driver.ErrSkip } @@ -135,7 +136,7 @@ func (conn *driverConn) ExecContext(ctx context.Context, query string, args []dr var err error var result driver.Result - if execerCtx, ok := conn.Conn.(driver.ExecerContext); ok { + if okCtx { result, err = execerCtx.ExecContext(ctx, query, args) } else { select { @@ -168,7 +169,8 @@ func (conn *driverConn) Query(query string, args []driver.Value) (driver.Rows, e func (conn *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { queryer, ok := conn.Conn.(driver.Queryer) - if !ok { + queryerCtx, okCtx := conn.Conn.(driver.QueryerContext) + if !ok && okCtx { return nil, driver.ErrSkip } @@ -177,7 +179,7 @@ func (conn *driverConn) QueryContext(ctx context.Context, query string, args []d var err error var rows driver.Rows - if queryerCtx, ok := conn.Conn.(driver.QueryerContext); ok { + if okCtx { rows, err = queryerCtx.QueryContext(ctx, query, args) } else { select { From 65c571c16f93be0257af6618bb3eba9e9e1c212b Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 19 Oct 2021 13:43:36 +0900 Subject: [PATCH 2/2] fix the condition of using QueryerContext --- xraysql/conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xraysql/conn.go b/xraysql/conn.go index fe1d619..49bd0bd 100644 --- a/xraysql/conn.go +++ b/xraysql/conn.go @@ -170,7 +170,7 @@ func (conn *driverConn) Query(query string, args []driver.Value) (driver.Rows, e func (conn *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { queryer, ok := conn.Conn.(driver.Queryer) queryerCtx, okCtx := conn.Conn.(driver.QueryerContext) - if !ok && okCtx { + if !ok && !okCtx { return nil, driver.ErrSkip }