diff --git a/connection.go b/connection.go index d495939..a792817 100644 --- a/connection.go +++ b/connection.go @@ -98,7 +98,11 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name ctx = driverctx.NewContextWithConnId(ctx, c.id) if len(args) > 0 { - return nil, errors.New(ErrParametersNotSupported) + q, err := SubstituteArgs(query, args) + if err != nil { + return nil, err + } + query = q } exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args) @@ -140,7 +144,11 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam ctx = driverctx.NewContextWithConnId(ctx, c.id) if len(args) > 0 { - return nil, errors.New(ErrParametersNotSupported) + q, err := SubstituteArgs(query, args) + if err != nil { + return nil, err + } + query = q } // first we try to get the results synchronously. // at any point in time that the context is done we must cancel and return diff --git a/util.go b/util.go new file mode 100644 index 0000000..80a6ee2 --- /dev/null +++ b/util.go @@ -0,0 +1,87 @@ +package dbsql + +import ( + "bytes" + "database/sql/driver" + "fmt" + "strings" + "time" +) + +func EscapeArgs(args []driver.NamedValue) (_ []string, err error) { + escaped := make([]string, len(args)) + + for i, arg := range args { + escaped[i], err = escapearg(arg.Value) + if err != nil { + return nil, err + } + } + return escaped, nil +} + +func escapearg(val interface{}) (string, error) { + if vb, isBytes := val.([]byte); isBytes { + val = string(vb) + } + switch v := val.(type) { + case string: + return "'" + strings.ReplaceAll(strings.ReplaceAll(v, "'", "''"), "\\", "\\\\") + "'", nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + return fmt.Sprintf("%v", v), nil + case bool: + return fmt.Sprintf("%v", v), nil + case nil: + return "NULL", nil + case time.Time: + return "'" + v.Format("2006-01-02 15:04:05") + "'", nil + case []interface{}: + var err error + nested := make([]string, len(v)) + for j, vv := range v { + nested[j], err = escapearg(vv) + if err != nil { + return "", err + } + } + if err != nil { + return "", err + } + return fmt.Sprintf("(%s)", strings.Join(nested, ", ")), nil + default: + return "", fmt.Errorf("unsupported type %T", v) + } +} + +func SubstituteArgs(sql string, args []driver.NamedValue) (string, error) { + escaped, err := EscapeArgs(args) + if err != nil { + return "", err + } + + buf := &bytes.Buffer{} + i := 0 + for { + p := strings.Index(sql, "?") + if p == -1 { + break + } + + if len(sql[p:]) > 1 && sql[p:p+2] == "??" { // escape ?? => ? + buf.WriteString(sql[:p]) + buf.WriteString("?") + if len(sql[p:]) == 1 { + break + } + sql = sql[p+2:] + } else { + buf.WriteString(sql[:p]) + fmt.Fprint(buf, escaped[i]) + i++ + sql = sql[p+1:] + } + } + + buf.WriteString(sql) + return buf.String(), nil +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..399277a --- /dev/null +++ b/util_test.go @@ -0,0 +1,101 @@ +package dbsql + +import ( + "database/sql/driver" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestEscapeArgs(t *testing.T) { + thyme := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + tests := map[string]struct { + args []driver.NamedValue + expected []string + expectError bool + }{ + "strings": { + args: []driver.NamedValue{{Ordinal: 1, Value: "foo"}, {Ordinal: 2, Value: "bar"}}, + expected: []string{ + "'foo'", + "'bar'", + }, + }, + "strings with quotes": { + args: []driver.NamedValue{{Ordinal: 1, Value: "f'oo"}, {Ordinal: 2, Value: "bar"}}, + expected: []string{ + "'f''oo'", + "'bar'", + }, + }, + "lists": { + args: []driver.NamedValue{{Ordinal: 1, Value: []interface{}{"foo", "bar"}}}, + expected: []string{ + "('foo', 'bar')", + }, + }, + "time": { + args: []driver.NamedValue{{Ordinal: 1, Value: thyme}}, + expected: []string{ + "'2020-01-01 00:00:00'", + }, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + actual, err := EscapeArgs(test.args) + if test.expectError { + assert.Error(t, err) + } else { + assert.Equal(t, test.expected, actual) + } + }) + } +} + +func TestSubstituteArgs(t *testing.T) { + tests := map[string]struct { + sql string + args []driver.NamedValue + expected string + expectError bool + }{ + "no args": { + sql: "SELECT * FROM foo", + expected: "SELECT * FROM foo", + }, + "one arg": { + sql: "SELECT * FROM foo WHERE bar = ?", + args: []driver.NamedValue{{Ordinal: 1, Value: "baz"}}, + expected: "SELECT * FROM foo WHERE bar = 'baz'", + }, + "two args": { + sql: "SELECT * FROM foo WHERE bar = ? AND baz = ?", + args: []driver.NamedValue{{Ordinal: 1, Value: "baz"}, {Ordinal: 2, Value: "qux"}}, + expected: "SELECT * FROM foo WHERE bar = 'baz' AND baz = 'qux'", + }, + "two args with list": { + sql: "SELECT * FROM foo WHERE bar = ? AND baz IN ?", + args: []driver.NamedValue{{Ordinal: 1, Value: "baz"}, {Ordinal: 2, Value: []interface{}{"qux", "quux"}}}, + expected: "SELECT * FROM foo WHERE bar = 'baz' AND baz IN ('qux', 'quux')", + }, + "three args with list and time": { + sql: "SELECT * FROM foo WHERE bar = ? AND baz IN ? AND qux = ?", + args: []driver.NamedValue{{Ordinal: 1, Value: "baz"}, {Ordinal: 2, Value: []interface{}{"qux", "quux"}}, {Ordinal: 3, Value: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)}}, + expected: "SELECT * FROM foo WHERE bar = 'baz' AND baz IN ('qux', 'quux') AND qux = '2020-01-01 00:00:00'", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + actual, err := SubstituteArgs(test.sql, test.args) + if test.expectError { + assert.Error(t, err) + } else { + assert.Equal(t, test.expected, actual) + } + }) + } +}