Skip to content

Commit

Permalink
fix(go/adbc/driver/snowflake): handling of integer values sent for NU…
Browse files Browse the repository at this point in the history
…MBER columns (apache#1267)

Fixes apache#1242.
  • Loading branch information
CurtHagenlocher authored and vleslief-ms committed Nov 9, 2023
1 parent 69d1060 commit b9fe679
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 9 deletions.
94 changes: 94 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"encoding/base64"
"encoding/pem"
"fmt"
"math"
"os"
"runtime"
"strconv"
Expand All @@ -38,6 +39,7 @@ import (
"github.com/apache/arrow-adbc/go/adbc/validation"
"github.com/apache/arrow/go/v14/arrow"
"github.com/apache/arrow/go/v14/arrow/array"
"github.com/apache/arrow/go/v14/arrow/decimal128"
"github.com/apache/arrow/go/v14/arrow/memory"
"github.com/google/uuid"
"github.com/snowflakedb/gosnowflake"
Expand Down Expand Up @@ -679,6 +681,98 @@ func (suite *SnowflakeTests) TestUseHighPrecision() {
suite.Equal(9876543210.99, rec.Column(1).(*array.Float64).Value(1))
}

func (suite *SnowflakeTests) TestDecimalHighPrecision() {
for sign := 0; sign <= 1; sign++ {
for scale := 0; scale <= 2; scale++ {
for precision := 3; precision <= 38; precision++ {
numberString := strings.Repeat("9", precision-scale) + "." + strings.Repeat("9", scale)
if sign == 1 {
numberString = "-" + numberString
}
query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale)
number, err := decimal128.FromString(numberString, int32(precision), int32(scale))
suite.NoError(err)

suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled))
suite.Require().NoError(suite.stmt.SetSqlQuery(query))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()

suite.EqualValues(1, n)
suite.Truef(arrow.TypeEqual(&arrow.Decimal128Type{Precision: int32(precision), Scale: int32(scale)}, rdr.Schema().Field(0).Type), "expected decimal(%d, %d), got %s", precision, scale, rdr.Schema().Field(0).Type)
suite.True(rdr.Next())
rec := rdr.Record()

suite.Equal(number, rec.Column(0).(*array.Decimal128).Value(0))
}
}
}
}

func (suite *SnowflakeTests) TestNonIntDecimalLowPrecision() {
for sign := 0; sign <= 1; sign++ {
for precision := 3; precision <= 38; precision++ {
scale := 2
numberString := strings.Repeat("9", precision-scale) + ".99"
if sign == 1 {
numberString = "-" + numberString
}
query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale)
decimalNumber, err := decimal128.FromString(numberString, int32(precision), int32(scale))
suite.NoError(err)
number := decimalNumber.ToFloat64(int32(scale))

suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled))
suite.Require().NoError(suite.stmt.SetSqlQuery(query))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()

suite.EqualValues(1, n)
suite.Truef(arrow.TypeEqual(arrow.PrimitiveTypes.Float64, rdr.Schema().Field(0).Type), "expected float64, got %s", rdr.Schema().Field(0).Type)
suite.True(rdr.Next())
rec := rdr.Record()

value := rec.Column(0).(*array.Float64).Value(0)
difference := math.Abs(number - value)
suite.Truef(difference < 1e-13, "expected %f, got %f", number, value)
}
}
}

func (suite *SnowflakeTests) TestIntDecimalLowPrecision() {
for sign := 0; sign <= 1; sign++ {
for precision := 3; precision <= 38; precision++ {
scale := 0
numberString := strings.Repeat("9", precision-scale)
if sign == 1 {
numberString = "-" + numberString
}
query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale)
decimalNumber, err := decimal128.FromString(numberString, int32(precision), int32(scale))
suite.NoError(err)
// The current behavior of the driver for decimal128 values too large to fit into 64 bits is to simply
// return the low 64 bits of the value.
number := int64(decimalNumber.LowBits())

suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled))
suite.Require().NoError(suite.stmt.SetSqlQuery(query))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()

suite.EqualValues(1, n)
suite.Truef(arrow.TypeEqual(arrow.PrimitiveTypes.Int64, rdr.Schema().Field(0).Type), "expected int64, got %s", rdr.Schema().Field(0).Type)
suite.True(rdr.Next())
rec := rdr.Record()

value := rec.Column(0).(*array.Int64).Value(0)
suite.Equal(number, value)
}
}
}

func (suite *SnowflakeTests) TestDescribeOnly() {
suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled))
suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT CAST('9999.99' AS NUMBER(6, 2)) AS RESULT"))
Expand Down
55 changes: 46 additions & 9 deletions go/adbc/driver/snowflake/record_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,36 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighP
}
f.Type = dt
transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
return compute.CastArray(ctx, a, compute.SafeCastOptions(dt))
return integerToDecimal128(ctx, a, dt)
}
} else {
if srcMeta.Scale != 0 {
f.Type = arrow.PrimitiveTypes.Float64
transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true},
&compute.ArrayDatum{Value: a.Data()},
compute.NewDatum(math.Pow10(int(srcMeta.Scale))))
if err != nil {
return nil, err
// For precisions of 16, 17 and 18, a conversion from int64 to float64 fails with an error
// So for these precisions, we instead convert first to a decimal128 and then to a float64.
if srcMeta.Precision > 15 && srcMeta.Precision < 19 {
transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
result, err := integerToDecimal128(ctx, a, &arrow.Decimal128Type{
Precision: int32(srcMeta.Precision),
Scale: int32(srcMeta.Scale),
})
if err != nil {
return nil, err
}
return compute.CastArray(ctx, result, compute.UnsafeCastOptions(f.Type))
}
} else {
// For precisions less than 16, we can simply scale the integer value appropriately
transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true},
&compute.ArrayDatum{Value: a.Data()},
compute.NewDatum(math.Pow10(int(srcMeta.Scale))))
if err != nil {
return nil, err
}
defer result.Release()
return result.(*compute.ArrayDatum).MakeArray(), nil
}
defer result.Release()
return result.(*compute.ArrayDatum).MakeArray(), nil
}
} else {
f.Type = arrow.PrimitiveTypes.Int64
Expand Down Expand Up @@ -266,6 +282,27 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighP
return out, getRecTransformer(out, transformers)
}

func integerToDecimal128(ctx context.Context, a arrow.Array, dt *arrow.Decimal128Type) (arrow.Array, error) {
// We can't do a cast directly into the destination type because the numbers we get from Snowflake
// are scaled integers. So not only would the cast produce the wrong value, it also risks producing
// an error of precisions which e.g. can't hold every int64. To work around these problems, we instead
// cast into a decimal type of a precision and scale which we know will hold all values and won't
// require scaling, We then substitute the type on this array with the actual return type.

dt0 := &arrow.Decimal128Type{
Precision: int32(20),
Scale: int32(0),
}
result, err := compute.CastArray(ctx, a, compute.SafeCastOptions(dt0))
if err != nil {
return nil, err
}

data := result.Data()
result.Data().Reset(dt, data.Len(), data.Buffers(), data.Children(), data.NullN(), data.Offset())
return result, err
}

func rowTypesToArrowSchema(ctx context.Context, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, error) {
var loc *time.Location

Expand Down

0 comments on commit b9fe679

Please sign in to comment.