diff --git a/sqlx.go b/sqlx.go index f7b28768..99582e73 100644 --- a/sqlx.go +++ b/sqlx.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -237,6 +238,35 @@ func (r *Row) Err() error { return r.err } +// Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing +// either type to be used interchangeably. +type Queryable interface { + Ext + ExecerContext + PreparerContext + QueryerContext + Preparer + + GetContext(context.Context, interface{}, string, ...interface{}) error + SelectContext(context.Context, interface{}, string, ...interface{}) error + Get(interface{}, string, ...interface{}) error + MustExecContext(context.Context, string, ...interface{}) sql.Result + PreparexContext(context.Context, string) (*Stmt, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row + Select(interface{}, string, ...interface{}) error + QueryRow(string, ...interface{}) *sql.Row + PrepareNamedContext(context.Context, string) (*NamedStmt, error) + PrepareNamed(string) (*NamedStmt, error) + Preparex(string) (*Stmt, error) + NamedExec(string, interface{}) (sql.Result, error) + NamedExecContext(context.Context, string, interface{}) (sql.Result, error) + MustExec(string, ...interface{}) sql.Result + NamedQuery(string, interface{}) (*Rows, error) +} + +var _ Queryable = (*DB)(nil) +var _ Queryable = (*Tx)(nil) + // DB is a wrapper around sql.DB which keeps track of the driverName upon Open, // used mostly to automatically bind named queries using the right bindvars. type DB struct { diff --git a/sqlx_test.go b/sqlx_test.go index 1d4aa20d..8cd605bc 100644 --- a/sqlx_test.go +++ b/sqlx_test.go @@ -1924,3 +1924,74 @@ func TestSelectReset(t *testing.T) { } }) } + +func TestQueryable(t *testing.T) { + sqlDBType := reflect.TypeOf(&sql.DB{}) + dbType := reflect.TypeOf(&DB{}) + sqlTxType := reflect.TypeOf(&sql.Tx{}) + txType := reflect.TypeOf(&Tx{}) + + dbMethods := exportableMethods(sqlDBType) + for k, v := range exportableMethods(dbType) { + dbMethods[k] = v + } + + txMethods := exportableMethods(sqlTxType) + for k, v := range exportableMethods(txType) { + txMethods[k] = v + } + + sharedMethods := make([]string, 0) + + for name, dbMethod := range dbMethods { + if txMethod, ok := txMethods[name]; ok { + if methodsEqual(dbMethod.Type, txMethod.Type) { + sharedMethods = append(sharedMethods, name) + } + } + } + + queryableType := reflect.TypeOf((*Queryable)(nil)).Elem() + queryableMethods := exportableMethods(queryableType) + + for _, sharedMethodName := range sharedMethods { + if _, ok := queryableMethods[sharedMethodName]; !ok { + t.Errorf("Queryable does not include shared DB/Tx method: %s", sharedMethodName) + } + } +} + +func exportableMethods(t reflect.Type) map[string]reflect.Method { + methods := make(map[string]reflect.Method) + + for i := 0; i < t.NumMethod(); i++ { + method := t.Method(i) + + if method.IsExported() { + methods[method.Name] = method + } + } + + return methods +} + +func methodsEqual(t reflect.Type, ot reflect.Type) bool { + if t.NumIn() != ot.NumIn() || t.NumOut() != ot.NumOut() || t.IsVariadic() != ot.IsVariadic() { + return false + } + + // Start at 1 to avoid comparing receiver argument + for i := 1; i < t.NumIn(); i++ { + if t.In(i) != ot.In(i) { + return false + } + } + + for i := 0; i < t.NumOut(); i++ { + if t.Out(i) != ot.Out(i) { + return false + } + } + + return true +}