From 3b145a89ccdaed869bbaf886f8a07612621200a8 Mon Sep 17 00:00:00 2001 From: Efstathios Chouliaris Date: Thu, 21 Nov 2024 18:03:37 -0600 Subject: [PATCH] [AB#1669514] fix array binding --- bindings_test.go | 101 ++++++++++++++++++++++++++++++++++++----------- converter.go | 33 ++++++++++------ 2 files changed, 101 insertions(+), 33 deletions(-) diff --git a/bindings_test.go b/bindings_test.go index 91530dc5e..ef5998800 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -8,6 +8,7 @@ import ( "database/sql" "fmt" "log" + "math" "math/big" "math/rand" "reflect" @@ -70,7 +71,7 @@ func TestBindingFloat64(t *testing.T) { dbt.mustExec("INSERT INTO test VALUES (1, ?)", expected) rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Next() { assertNilF(t, rows.Scan(&out)) @@ -203,14 +204,14 @@ func TestBindingTimestampTZ(t *testing.T) { dbt.Fatal(err.Error()) } defer func() { - assertNilF(t, stmt.Close()) + assertNilF(t, stmt.Close()) }() if _, err = stmt.Exec(DataTypeTimestampTz, expected); err != nil { dbt.Fatal(err) } rows := dbt.mustQuery("SELECT tz FROM tztest WHERE id=?", 1) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v time.Time if rows.Next() { @@ -258,7 +259,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v time.Time if rows.Next() { @@ -307,7 +308,7 @@ func TestBindingTimeInStruct(t *testing.T) { rows := dbt.mustQuery("SELECT tz FROM timeStructTest WHERE id=?", &expectedID) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v time.Time if rows.Next() { @@ -329,7 +330,7 @@ func TestBindingInterface(t *testing.T) { rows := dbt.mustQueryContext( WithHigherPrecision(context.Background()), selectVariousTypes) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if !rows.Next() { dbt.Error("failed to query") @@ -357,7 +358,7 @@ func TestBindingInterfaceString(t *testing.T) { runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQuery(selectVariousTypes) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if !rows.Next() { dbt.Error("failed to query") @@ -382,6 +383,62 @@ func TestBindingInterfaceString(t *testing.T) { }) } +func TestBulkArrayBindingUUID(t *testing.T) { + max := math.Pow10(5) // 100K because my power is maximum + uuids := make([]any, int(max)) + + createTable := "CREATE OR REPLACE TABLE TEST_PREP_STATEMENT (id INT autoincrement start 1 increment 1, uuid VARCHAR)" + insert := "INSERT INTO TEST_PREP_STATEMENT (uuid) VALUES (?)" + + for i := range uuids { + uuids[i] = newTestUUID() + } + + runDBTest(t, func(dbt *DBTest) { + var rows *RowsExtended + t.Cleanup(func() { + if rows != nil { + assertNilF(t, rows.Close()) + } + + dbt.exec(deleteTableSQL) + }) + + dbt.mustExec(createTable) + + bound := Array(&uuids) + res := dbt.mustExec(insert, bound) + + if affected, _ := res.RowsAffected(); affected != int64(max) { + t.Fatalf("failed to insert all rows. expected: %f.0, got: %v", max, affected) + } + + rows = dbt.mustQuery("SELECT * FROM TEST_PREP_STATEMENT ORDER BY ID") + + for i := 0; rows.Next(); i++ { + var ( + id int + out testUUID + ) + if err := rows.Scan(&id, &out); err != nil { + t.Fatal(err) + } + + var found bool + for _, u := range uuids { + if u == out { + found = true + break + } + } + if !found { + t.Errorf("failed to find UUID. expected: %s, but it wasnt in the list", out) + } + } + }) + +} + func TestBulkArrayBindingInterfaceNil(t *testing.T) { nilArray := make([]any, 1) @@ -396,7 +453,7 @@ func TestBulkArrayBindingInterfaceNil(t *testing.T) { Array(&nilArray, TimeType)) rows := dbt.mustQuery(selectAllSQL) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v0 sql.NullInt32 @@ -481,7 +538,7 @@ func TestBulkArrayBindingInterface(t *testing.T) { Array(&boolArray), Array(&strArray), Array(&byteArray), Array(&int64Array)) rows := dbt.mustQuery(selectAllSQLBulkArray) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v0 sql.NullInt32 @@ -586,7 +643,7 @@ func TestBulkArrayBindingInterfaceDateTimeTimestamp(t *testing.T) { rows := dbt.mustQuery(selectAllSQLBulkArrayDateTimeTimestamp) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v0, v1, v2, v3, v4 sql.NullTime @@ -695,7 +752,7 @@ func testBindingArray(t *testing.T, bulk bool) { Array(&tmArray, TimeType)) rows := dbt.mustQuery(selectAllSQL) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() var v0 int @@ -777,7 +834,7 @@ func TestBulkArrayBinding(t *testing.T) { dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?)", dbname), Array(&intArr), Array(&strArr), Array(<zArr, TimestampLTZType), Array(&tzArr, TimestampTZType), Array(&ntzArr, TimestampNTZType), Array(&dateArr, DateType), Array(&timeArr, TimeType), Array(&binArr)) rows := dbt.mustQuery("select * from " + dbname + " order by c1") defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() cnt := 0 var i int @@ -825,7 +882,7 @@ func TestBulkArrayBindingTimeWithPrecision(t *testing.T) { dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?)", dbname), Array(&secondsArr, TimeType), Array(&millisecondsArr, TimeType), Array(µsecondsArr, TimeType), Array(&nanosecondsArr, TimeType)) rows := dbt.mustQuery("select * from " + dbname) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() cnt := 0 var s, ms, us, ns time.Time @@ -866,7 +923,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { Array(&randomStrings)) rows := dbt.mustQuery("select count(*) from " + tempTableName) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Next() { var count int @@ -878,7 +935,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { rows := dbt.mustQuery("select count(*) from " + tempTableName) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Next() { var count int @@ -909,7 +966,7 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) { rows := dbt.mustQuery("select * from binding_test order by c1") defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() cnt := startNum var i int @@ -959,7 +1016,7 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { rows := dbt.mustQuery("select * from binding_test order by c1,c2") defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() cnt := startNum var i sql.NullInt32 @@ -1042,7 +1099,7 @@ func TestFunctionParameters(t *testing.T) { t.Fatal(err) } defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if rows.Err() != nil { t.Fatal(err) @@ -1144,7 +1201,7 @@ func TestVariousBindingModes(t *testing.T) { t.Fatal(err) } defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() if !rows.Next() { t.Fatal("Expected to return a row") @@ -1194,7 +1251,7 @@ func testLOBRetrieval(t *testing.T, useArrowFormat bool) { rows, err := dbt.query(fmt.Sprintf("SELECT randstr(%v, 124)", testSize)) assertNilF(t, err) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() assertTrueF(t, rows.Next(), fmt.Sprintf("no rows returned for the LOB size %v", testSize)) @@ -1227,7 +1284,7 @@ func TestMaxLobSize(t *testing.T) { rows, err := dbt.query("select randstr(20000000, random())") assertNilF(t, err) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() }) }) @@ -1308,7 +1365,7 @@ func testInsertLOBData(t *testing.T, useArrowFormat bool, isLiteral bool) { rows, err := dbt.query("SELECT * FROM lob_test_table") assertNilF(t, err) defer func() { - assertNilF(t, rows.Close()) + assertNilF(t, rows.Close()) }() assertTrueF(t, rows.Next(), fmt.Sprintf("%s: no rows returned", tc.testDesc)) diff --git a/converter.go b/converter.go index f9f0fefdf..d6d40c1a7 100644 --- a/converter.go +++ b/converter.go @@ -423,7 +423,7 @@ func arrayToString(v driver.Value, tsmode snowflakeType, params map[string]*stri res = res[0:len(res)-1] + "]" return bindingValue{&res, jsonFormatStr, &schemaForBytes}, nil } else if isUUIDImplementer(v1) { // special case for UUIDs (snowflake type and other implementers) - stringer := v.(fmt.Stringer) + stringer := v.(fmt.Stringer) // we don't need to validate if it's a fmt.Stringer because we already checked if it's a UUID type with a stringer value := stringer.String() return bindingValue{&value, "", nil}, nil } else if isSliceOfSlices(v) { @@ -2696,44 +2696,38 @@ func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType .. for i := 0; i < interfaceSlice.Len(); i++ { val := interfaceSlice.Index(i) if val.CanInterface() { - switch val.Interface().(type) { + v := val.Interface() + + switch x := v.(type) { case int: t = fixedType - x := val.Interface().(int) v := strconv.Itoa(x) arr = append(arr, &v) case int32: t = fixedType - x := val.Interface().(int32) v := strconv.Itoa(int(x)) arr = append(arr, &v) case int64: t = fixedType - x := val.Interface().(int64) v := strconv.FormatInt(x, 10) arr = append(arr, &v) case float32: t = realType - x := val.Interface().(float32) v := fmt.Sprintf("%g", x) arr = append(arr, &v) case float64: t = realType - x := val.Interface().(float64) v := fmt.Sprintf("%g", x) arr = append(arr, &v) case bool: t = booleanType - x := val.Interface().(bool) v := strconv.FormatBool(x) arr = append(arr, &v) case string: t = textType - x := val.Interface().(string) arr = append(arr, &x) case []byte: t = binaryType - x := val.Interface().([]byte) v := hex.EncodeToString(x) arr = append(arr, &v) case time.Time: @@ -2741,7 +2735,6 @@ func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType .. return unSupportedType, nil } - x := val.Interface().(time.Time) switch tzType[0] { case TimestampNTZType: t = timestampNtzType @@ -2781,8 +2774,26 @@ func interfaceSliceToString(interfaceSlice reflect.Value, stream bool, tzType .. default: return unSupportedType, nil } + case driver.Valuer: // honor each driver's Valuer interface + if value, err := x.Value(); err == nil && value != nil { + // if the output value is a valid string, return that + if strVal, ok := value.(string); ok { + t = textType + arr = append(arr, &strVal) + } + } else if v != nil { + return unSupportedType, nil + } else { + arr = append(arr, nil) + } default: if val.Interface() != nil { + if isUUIDImplementer(val) { + t = textType + x := v.(fmt.Stringer).String() + arr = append(arr, &x) + continue + } return unSupportedType, nil }