Skip to content

Commit

Permalink
[AB#1669514] fix array binding
Browse files Browse the repository at this point in the history
  • Loading branch information
ChronosMasterOfAllTime committed Nov 22, 2024
1 parent ca44e68 commit 3b145a8
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 33 deletions.
101 changes: 79 additions & 22 deletions bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"database/sql"
"fmt"
"log"
"math"
"math/big"
"math/rand"
"reflect"
Expand Down Expand Up @@ -70,7 +71,7 @@ func TestBindingFloat64(t *testing.T) {
dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected)
rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
if rows.Next() {
assertNilF(t, rows.Scan(&out))
Expand Down Expand Up @@ -203,14 +204,14 @@ func TestBindingTimestampTZ(t *testing.T) {
dbt.Fatal(err.Error())
}
defer func() {
assertNilF(t, stmt.Close())
assertNilF(t, stmt.Close())
}()
if _, err = stmt.Exec(DataTypeTimestampTz, expected); err != nil {
dbt.Fatal(err)
}
rows := dbt.mustQuery("SELECT tz FROM tztest WHERE id=?", 1)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
var v time.Time
if rows.Next() {
Expand Down Expand Up @@ -258,7 +259,7 @@ func TestBindingTimePtrInStruct(t *testing.T) {

rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
var v time.Time
if rows.Next() {
Expand Down Expand Up @@ -307,7 +308,7 @@ func TestBindingTimeInStruct(t *testing.T) {

rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
var v time.Time
if rows.Next() {
Expand All @@ -329,7 +330,7 @@ func TestBindingInterface(t *testing.T) {
rows := dbt.mustQueryContext(
WithHigherPrecision(context.Background()), selectVariousTypes)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
if !rows.Next() {
dbt.Error("failed to query")
Expand Down Expand Up @@ -357,7 +358,7 @@ func TestBindingInterfaceString(t *testing.T) {
runDBTest(t, func(dbt *DBTest) {
rows := dbt.mustQuery(selectVariousTypes)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
if !rows.Next() {
dbt.Error("failed to query")
Expand All @@ -382,6 +383,62 @@ func TestBindingInterfaceString(t *testing.T) {
})
}

func TestBulkArrayBindingUUID(t *testing.T) {
max := math.Pow10(5) // 100K because my power is maximum
uuids := make([]any, int(max))

createTable := "CREATE OR REPLACE TABLE TEST_PREP_STATEMENT (id INT autoincrement start 1 increment 1, uuid VARCHAR)"
insert := "INSERT INTO TEST_PREP_STATEMENT (uuid) VALUES (?)"

for i := range uuids {
uuids[i] = newTestUUID()
}

runDBTest(t, func(dbt *DBTest) {
var rows *RowsExtended
t.Cleanup(func() {
if rows != nil {
assertNilF(t, rows.Close())
}

dbt.exec(deleteTableSQL)
})

dbt.mustExec(createTable)

bound := Array(&uuids)
res := dbt.mustExec(insert, bound)

if affected, _ := res.RowsAffected(); affected != int64(max) {
t.Fatalf("failed to insert all rows. expected: %f.0, got: %v", max, affected)
}

rows = dbt.mustQuery("SELECT * FROM TEST_PREP_STATEMENT ORDER BY ID")

for i := 0; rows.Next(); i++ {
var (
id int
out testUUID
)
if err := rows.Scan(&id, &out); err != nil {
t.Fatal(err)
}

var found bool
for _, u := range uuids {
if u == out {
found = true
break
}
}
if !found {
t.Errorf("failed to find UUID. expected: %s, but it wasnt in the list", out)
}
}
})

}

func TestBulkArrayBindingInterfaceNil(t *testing.T) {
nilArray := make([]any, 1)

Expand All @@ -396,7 +453,7 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) {
Array(&nilArray, TimeType))
rows := dbt.mustQuery(selectAllSQL)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()

var v0 sql.NullInt32
Expand Down Expand Up @@ -481,7 +538,7 @@ func TestBulkArrayBindingInterface(t *testing.T) {
Array(&boolArray), Array(&strArray), Array(&byteArray), Array(&int64Array))
rows := dbt.mustQuery(selectAllSQLBulkArray)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()

var v0 sql.NullInt32
Expand Down Expand Up @@ -586,7 +643,7 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) {

rows := dbt.mustQuery(selectAllSQLBulkArrayDateTimeTimestamp)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()

var v0, v1, v2, v3, v4 sql.NullTime
Expand Down Expand Up @@ -695,7 +752,7 @@ func testBindingArray(t *testing.T, bulk bool) {
Array(&tmArray, TimeType))
rows := dbt.mustQuery(selectAllSQL)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()

var v0 int
Expand Down Expand Up @@ -777,7 +834,7 @@ func TestBulkArrayBinding(t *testing.T) {
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?)", dbname), Array(&intArr), Array(&strArr), Array(&ltzArr, TimestampLTZType), Array(&tzArr, TimestampTZType), Array(&ntzArr, TimestampNTZType), Array(&dateArr, DateType), Array(&timeArr, TimeType), Array(&binArr))
rows := dbt.mustQuery("select * from " + dbname + " order by c1")
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
cnt := 0
var i int
Expand Down Expand Up @@ -825,7 +882,7 @@ func TestBulkArrayBindingTimeWithPrecision(t *testing.T) {
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?)", dbname), Array(&secondsArr, TimeType), Array(&millisecondsArr, TimeType), Array(&microsecondsArr, TimeType), Array(&nanosecondsArr, TimeType))
rows := dbt.mustQuery("select * from " + dbname)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
cnt := 0
var s, ms, us, ns time.Time
Expand Down Expand Up @@ -866,7 +923,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) {
Array(&randomStrings))
rows := dbt.mustQuery("select count(*) from " + tempTableName)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
if rows.Next() {
var count int
Expand All @@ -878,7 +935,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) {

rows := dbt.mustQuery("select count(*) from " + tempTableName)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
if rows.Next() {
var count int
Expand Down Expand Up @@ -909,7 +966,7 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) {

rows := dbt.mustQuery("select * from binding_test order by c1")
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
cnt := startNum
var i int
Expand Down Expand Up @@ -959,7 +1016,7 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) {

rows := dbt.mustQuery("select * from binding_test order by c1,c2")
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
cnt := startNum
var i sql.NullInt32
Expand Down Expand Up @@ -1042,7 +1099,7 @@ func TestFunctionParameters(t *testing.T) {
t.Fatal(err)
}
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
if rows.Err() != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1144,7 +1201,7 @@ func TestVariousBindingModes(t *testing.T) {
t.Fatal(err)
}
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
if !rows.Next() {
t.Fatal("Expected to return a row")
Expand Down Expand Up @@ -1194,7 +1251,7 @@ func testLOBRetrieval(t *testing.T, useArrowFormat bool) {
rows, err := dbt.query(fmt.Sprintf("SELECT randstr(%v, 124)", testSize))
assertNilF(t, err)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
assertTrueF(t, rows.Next(), fmt.Sprintf("no rows returned for the LOB size %v", testSize))

Expand Down Expand Up @@ -1227,7 +1284,7 @@ func TestMaxLobSize(t *testing.T) {
rows, err := dbt.query("select randstr(20000000, random())")
assertNilF(t, err)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
})
})
Expand Down Expand Up @@ -1308,7 +1365,7 @@ func testInsertLOBData(t *testing.T, useArrowFormat bool, isLiteral bool) {
rows, err := dbt.query("SELECT * FROM lob_test_table")
assertNilF(t, err)
defer func() {
assertNilF(t, rows.Close())
assertNilF(t, rows.Close())
}()
assertTrueF(t, rows.Next(), fmt.Sprintf("%s: no rows returned", tc.testDesc))

Expand Down
33 changes: 22 additions & 11 deletions converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*stri
res = res[0:len(res)-1] + "]"
return bindingValue{&res, jsonFormatStr, &schemaForBytes}, nil
} else if isUUIDImplementer(v1) { // special case for UUIDs (snowflake type and other implementers)
stringer := v.(fmt.Stringer)
stringer := v.(fmt.Stringer) // we don't need to validate if it's a fmt.Stringer because we already checked if it's a UUID type with a stringer
value := stringer.String()
return bindingValue{&value, "", nil}, nil
} else if isSliceOfSlices(v) {
Expand Down Expand Up @@ -2696,52 +2696,45 @@ func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType ..
for i := 0; i < interfaceSlice.Len(); i++ {
val := interfaceSlice.Index(i)
if val.CanInterface() {
switch val.Interface().(type) {
v := val.Interface()

switch x := v.(type) {
case int:
t = fixedType
x := val.Interface().(int)
v := strconv.Itoa(x)
arr = append(arr, &v)
case int32:
t = fixedType
x := val.Interface().(int32)
v := strconv.Itoa(int(x))
arr = append(arr, &v)
case int64:
t = fixedType
x := val.Interface().(int64)
v := strconv.FormatInt(x, 10)
arr = append(arr, &v)
case float32:
t = realType
x := val.Interface().(float32)
v := fmt.Sprintf("%g", x)
arr = append(arr, &v)
case float64:
t = realType
x := val.Interface().(float64)
v := fmt.Sprintf("%g", x)
arr = append(arr, &v)
case bool:
t = booleanType
x := val.Interface().(bool)
v := strconv.FormatBool(x)
arr = append(arr, &v)
case string:
t = textType
x := val.Interface().(string)
arr = append(arr, &x)
case []byte:
t = binaryType
x := val.Interface().([]byte)
v := hex.EncodeToString(x)
arr = append(arr, &v)
case time.Time:
if len(tzType) < 1 {
return unSupportedType, nil
}

x := val.Interface().(time.Time)
switch tzType[0] {
case TimestampNTZType:
t = timestampNtzType
Expand Down Expand Up @@ -2781,8 +2774,26 @@ func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType ..
default:
return unSupportedType, nil
}
case driver.Valuer: // honor each driver's Valuer interface
if value, err := x.Value(); err == nil && value != nil {
// if the output value is a valid string, return that
if strVal, ok := value.(string); ok {
t = textType
arr = append(arr, &strVal)
}
} else if v != nil {
return unSupportedType, nil
} else {
arr = append(arr, nil)
}
default:
if val.Interface() != nil {
if isUUIDImplementer(val) {
t = textType
x := v.(fmt.Stringer).String()
arr = append(arr, &x)
continue
}
return unSupportedType, nil
}

Expand Down

0 comments on commit 3b145a8

Please sign in to comment.