From 9c5a420018b739f4508c7b07f97f2b121e035f4d Mon Sep 17 00:00:00 2001 From: Blake Williams Date: Sun, 8 May 2022 20:26:05 -0400 Subject: [PATCH] Add Queryable interface When writing applications it's useful to pass a sqlx.DB and a sqlx.Tx interchangeably so that you can compose functions that are runnable in isolation or as part of a transaction. This introduces the Queryable interface, which includes the common exportable methods shared between sqlx.DB and sqlx.Tx so users of the package don't have to implement it themselves. This also adds tests that validate any new shared methods are added to the interface. --- sqlx.go | 30 ++++++++++++++++++++++ sqlx_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/sqlx.go b/sqlx.go index f7b28768..99582e73 100644 --- a/sqlx.go +++ b/sqlx.go @@ -1,6 +1,7 @@ package sqlx import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -237,6 +238,35 @@ func (r *Row) Err() error { return r.err } +// Queryable includes all methods shared by sqlx.DB and sqlx.Tx, allowing +// either type to be used interchangeably. +type Queryable interface { + Ext + ExecerContext + PreparerContext + QueryerContext + Preparer + + GetContext(context.Context, interface{}, string, ...interface{}) error + SelectContext(context.Context, interface{}, string, ...interface{}) error + Get(interface{}, string, ...interface{}) error + MustExecContext(context.Context, string, ...interface{}) sql.Result + PreparexContext(context.Context, string) (*Stmt, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row + Select(interface{}, string, ...interface{}) error + QueryRow(string, ...interface{}) *sql.Row + PrepareNamedContext(context.Context, string) (*NamedStmt, error) + PrepareNamed(string) (*NamedStmt, error) + Preparex(string) (*Stmt, error) + NamedExec(string, interface{}) (sql.Result, error) + NamedExecContext(context.Context, string, interface{}) (sql.Result, error) + MustExec(string, ...interface{}) sql.Result + NamedQuery(string, interface{}) (*Rows, error) +} + +var _ Queryable = (*DB)(nil) +var _ Queryable = (*Tx)(nil) + // DB is a wrapper around sql.DB which keeps track of the driverName upon Open, // used mostly to automatically bind named queries using the right bindvars. type DB struct { diff --git a/sqlx_test.go b/sqlx_test.go index 1d4aa20d..8cd605bc 100644 --- a/sqlx_test.go +++ b/sqlx_test.go @@ -1924,3 +1924,74 @@ func TestSelectReset(t *testing.T) { } }) } + +func TestQueryable(t *testing.T) { + sqlDBType := reflect.TypeOf(&sql.DB{}) + dbType := reflect.TypeOf(&DB{}) + sqlTxType := reflect.TypeOf(&sql.Tx{}) + txType := reflect.TypeOf(&Tx{}) + + dbMethods := exportableMethods(sqlDBType) + for k, v := range exportableMethods(dbType) { + dbMethods[k] = v + } + + txMethods := exportableMethods(sqlTxType) + for k, v := range exportableMethods(txType) { + txMethods[k] = v + } + + sharedMethods := make([]string, 0) + + for name, dbMethod := range dbMethods { + if txMethod, ok := txMethods[name]; ok { + if methodsEqual(dbMethod.Type, txMethod.Type) { + sharedMethods = append(sharedMethods, name) + } + } + } + + queryableType := reflect.TypeOf((*Queryable)(nil)).Elem() + queryableMethods := exportableMethods(queryableType) + + for _, sharedMethodName := range sharedMethods { + if _, ok := queryableMethods[sharedMethodName]; !ok { + t.Errorf("Queryable does not include shared DB/Tx method: %s", sharedMethodName) + } + } +} + +func exportableMethods(t reflect.Type) map[string]reflect.Method { + methods := make(map[string]reflect.Method) + + for i := 0; i < t.NumMethod(); i++ { + method := t.Method(i) + + if method.IsExported() { + methods[method.Name] = method + } + } + + return methods +} + +func methodsEqual(t reflect.Type, ot reflect.Type) bool { + if t.NumIn() != ot.NumIn() || t.NumOut() != ot.NumOut() || t.IsVariadic() != ot.IsVariadic() { + return false + } + + // Start at 1 to avoid comparing receiver argument + for i := 1; i < t.NumIn(); i++ { + if t.In(i) != ot.In(i) { + return false + } + } + + for i := 0; i < t.NumOut(); i++ { + if t.Out(i) != ot.Out(i) { + return false + } + } + + return true +}