Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate single row count for Row.Scan #460

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions sqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ import (
"github.com/jmoiron/sqlx/reflectx"
)

// ErrMultiRows is returned by functions which are expected to work with result sets
// that only contain a single row but multiple rows were returned.
// This typically indicates an issue with the query such as a missing join criteria or
// limit condition or the use of Get(...) when Select(...) was intended.
var ErrMultiRows = errors.New("sql: multiple rows returned")

// Although the NameMapper is convenient, in practice it should not
// be relied on except for application code. If you are writing a library
// that uses sqlx, you should be aware that the name mappings you expect
Expand Down Expand Up @@ -177,6 +183,7 @@ type Row struct {

// Scan is a fixed implementation of sql.Row.Scan, which does not discard the
// underlying error from the internal rows object if it exists.
// Returns ErrMultiRows if the result set contains more than one row.
func (r *Row) Scan(dest ...interface{}) error {
if r.err != nil {
return r.err
Expand Down Expand Up @@ -208,10 +215,16 @@ func (r *Row) Scan(dest ...interface{}) error {
}
return sql.ErrNoRows
}
err := r.rows.Scan(dest...)
if err != nil {
if err := r.rows.Scan(dest...); err != nil {
return err
}

if r.rows.Next() {
return ErrMultiRows
} else if err := r.rows.Err(); err != nil {
return err
}

// Make sure the query can be processed to completion with no errors.
if err := r.rows.Close(); err != nil {
return err
Expand Down Expand Up @@ -323,7 +336,7 @@ func (db *DB) Select(dest interface{}, query string, args ...interface{}) error

// Get using this DB.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
// An error is returned if the result set is empty or contains more than one row.
func (db *DB) Get(dest interface{}, query string, args ...interface{}) error {
return Get(db, dest, query, args...)
}
Expand Down Expand Up @@ -446,7 +459,7 @@ func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {

// Get within a transaction.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
// An error is returned if the result set is empty or contains more than one row.
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
return Get(tx, dest, query, args...)
}
Expand Down Expand Up @@ -516,7 +529,7 @@ func (s *Stmt) Select(dest interface{}, args ...interface{}) error {

// Get using the prepared statement.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
// An error is returned if the result set is empty or contains more than one row.
func (s *Stmt) Get(dest interface{}, args ...interface{}) error {
return Get(&qStmt{s}, dest, "", args...)
}
Expand Down Expand Up @@ -682,7 +695,7 @@ func Select(q Queryer, dest interface{}, query string, args ...interface{}) erro
// to dest. If dest is scannable, the result must only have one column. Otherwise,
// StructScan is used. Get will return sql.ErrNoRows like row.Scan would.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
// An error is returned if the result set is empty or contains more than one row.
func Get(q Queryer, dest interface{}, query string, args ...interface{}) error {
r := q.QueryRowx(query, args...)
return r.scanAny(dest, false)
Expand Down
44 changes: 42 additions & 2 deletions sqlx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func ConnectAll() {
if TestMysql {
mysqldb, err = Connect("mysql", mydsn)
if err != nil {
fmt.Printf("Disabling MySQL tests:\n %v", err)
fmt.Printf("Disabling MySQL tests:\n %v\n", err)
TestMysql = false
}
} else {
Expand All @@ -85,7 +85,7 @@ func ConnectAll() {
if TestSqlite {
sldb, err = Connect("sqlite3", sqdsn)
if err != nil {
fmt.Printf("Disabling SQLite:\n %v", err)
fmt.Printf("Disabling SQLite:\n %v\n", err)
TestSqlite = false
}
} else {
Expand Down Expand Up @@ -1708,6 +1708,46 @@ func TestEmbeddedLiterals(t *testing.T) {
})
}

// TestGet tests to ensure that Get behaves correctly for
// single row and multi row results.
func TestGet(t *testing.T) {
var schema = Schema{
create: `CREATE TABLE tst (v integer);`,
drop: `drop table tst;`,
}

RunWithSchema(schema, t, func(db *DB, t *testing.T) {
for _, v := range []int{1, 2} {
_, err := db.Exec(db.Rebind("INSERT INTO tst (v) VALUES (?)"), v)
if err != nil {
t.Error(err)
}
}

tests := []struct {
name string
val int
err bool
}{
{"multi-rows", 1, true},
{"single-row", 2, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var v int
err := db.Get(&v, db.Rebind("SELECT v FROM tst WHERE v >= ?"), tc.val)
if tc.err {
if err == nil {
t.Error("expected error but got nil")
}
} else if err != nil {
t.Error("unexpected error:", err)
}
})
}
})
}

func BenchmarkBindStruct(b *testing.B) {
b.StopTimer()
q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`
Expand Down