From 3def8f82f4608bbcefed7ab4c3223442cefacb9c Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Sat, 13 Oct 2018 00:43:44 +0000 Subject: [PATCH] Validate single row count for Row.Scan Validate that a single row only is returned by queries used for Row.Scan. This avoids unexpected results when the query has an issue such as a missing join criteria or limit in conjunction with functions which expect only on row returned e.g. Get(...). Also: * Fixed missing \n's for test output of ConnectAll. --- sqlx.go | 25 +++++++++++++++++++------ sqlx_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/sqlx.go b/sqlx.go index 4385c3fa..cfb4db0f 100644 --- a/sqlx.go +++ b/sqlx.go @@ -15,6 +15,12 @@ import ( "github.com/jmoiron/sqlx/reflectx" ) +// ErrMultiRows is returned by functions which are expected to work with result sets +// that only contain a single row but multiple rows where returned. +// This typically indicates an issue with the query such as a missing join criteria or +// limit condition or the use of Get(...) when Select(...) was intended. +var ErrMultiRows = errors.New("sql: multiple rows returned") + // Although the NameMapper is convenient, in practice it should not // be relied on except for application code. If you are writing a library // that uses sqlx, you should be aware that the name mappings you expect @@ -177,6 +183,7 @@ type Row struct { // Scan is a fixed implementation of sql.Row.Scan, which does not discard the // underlying error from the internal rows object if it exists. +// Returns ErrMultiRows if the result set contains more than one row. func (r *Row) Scan(dest ...interface{}) error { if r.err != nil { return r.err @@ -208,10 +215,16 @@ func (r *Row) Scan(dest ...interface{}) error { } return sql.ErrNoRows } - err := r.rows.Scan(dest...) - if err != nil { + if err := r.rows.Scan(dest...); err != nil { + return err + } + + if r.rows.Next() { + return ErrMultiRows + } else if err := r.rows.Err(); err != nil { return err } + // Make sure the query can be processed to completion with no errors. if err := r.rows.Close(); err != nil { return err @@ -323,7 +336,7 @@ func (db *DB) Select(dest interface{}, query string, args ...interface{}) error // Get using this DB. // Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. +// An error is returned if the result set is empty or contains more than one row. func (db *DB) Get(dest interface{}, query string, args ...interface{}) error { return Get(db, dest, query, args...) } @@ -446,7 +459,7 @@ func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row { // Get within a transaction. // Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. +// An error is returned if the result set is empty or contains more than one row. func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error { return Get(tx, dest, query, args...) } @@ -516,7 +529,7 @@ func (s *Stmt) Select(dest interface{}, args ...interface{}) error { // Get using the prepared statement. // Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. +// An error is returned if the result set is empty or contains more than one row. func (s *Stmt) Get(dest interface{}, args ...interface{}) error { return Get(&qStmt{s}, dest, "", args...) } @@ -682,7 +695,7 @@ func Select(q Queryer, dest interface{}, query string, args ...interface{}) erro // to dest. If dest is scannable, the result must only have one column. Otherwise, // StructScan is used. Get will return sql.ErrNoRows like row.Scan would. // Any placeholder parameters are replaced with supplied args. -// An error is returned if the result set is empty. +// An error is returned if the result set is empty or contains more than one row. func Get(q Queryer, dest interface{}, query string, args ...interface{}) error { r := q.QueryRowx(query, args...) return r.scanAny(dest, false) diff --git a/sqlx_test.go b/sqlx_test.go index e26c9807..b3977065 100644 --- a/sqlx_test.go +++ b/sqlx_test.go @@ -75,7 +75,7 @@ func ConnectAll() { if TestMysql { mysqldb, err = Connect("mysql", mydsn) if err != nil { - fmt.Printf("Disabling MySQL tests:\n %v", err) + fmt.Printf("Disabling MySQL tests:\n %v\n", err) TestMysql = false } } else { @@ -85,7 +85,7 @@ func ConnectAll() { if TestSqlite { sldb, err = Connect("sqlite3", sqdsn) if err != nil { - fmt.Printf("Disabling SQLite:\n %v", err) + fmt.Printf("Disabling SQLite:\n %v\n", err) TestSqlite = false } } else { @@ -1708,6 +1708,46 @@ func TestEmbeddedLiterals(t *testing.T) { }) } +// TestGet tests to ensure that Get behaves correctly for +// single row and multi row results. +func TestGet(t *testing.T) { + var schema = Schema{ + create: `CREATE TABLE tst (v integer);`, + drop: `drop table tst;`, + } + + RunWithSchema(schema, t, func(db *DB, t *testing.T) { + for _, v := range []int{1, 2} { + _, err := db.Exec(db.Rebind("INSERT INTO tst (v) VALUES (?)"), v) + if err != nil { + t.Error(err) + } + } + + tests := []struct { + name string + val int + err bool + }{ + {"multi-rows", 1, true}, + {"single-row", 2, false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var v int + err := db.Get(&v, db.Rebind("SELECT v FROM tst WHERE v >= ?"), tc.val) + if tc.err { + if err == nil { + t.Error("expected error but got nil") + } + } else if err != nil { + t.Error("unexpected error:", err) + } + }) + } + }) +} + func BenchmarkBindStruct(b *testing.B) { b.StopTimer() q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`