Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix(go/adbc/driver/snowflake): handling of integer values sent for NUMBER columns #1267

Merged
merged 9 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"encoding/base64"
"encoding/pem"
"fmt"
"math"
"os"
"runtime"
"strconv"
Expand All @@ -38,6 +39,7 @@ import (
"github.com/apache/arrow-adbc/go/adbc/validation"
"github.com/apache/arrow/go/v14/arrow"
"github.com/apache/arrow/go/v14/arrow/array"
"github.com/apache/arrow/go/v14/arrow/decimal128"
"github.com/apache/arrow/go/v14/arrow/memory"
"github.com/google/uuid"
"github.com/snowflakedb/gosnowflake"
Expand Down Expand Up @@ -679,6 +681,96 @@ func (suite *SnowflakeTests) TestUseHighPrecision() {
suite.Equal(9876543210.99, rec.Column(1).(*array.Float64).Value(1))
}

// TestDecimalHighPrecision verifies that NUMBER(p, s) columns come back as
// Arrow decimal128 values (with the matching precision and scale) when the
// high-precision option is enabled, across all supported precisions, a sample
// of scales, and both signs.
func (suite *SnowflakeTests) TestDecimalHighPrecision() {
	for sign := 0; sign <= 1; sign++ {
		for scale := 0; scale <= 2; scale++ {
			for precision := 3; precision <= 38; precision++ {
				// Largest-magnitude value for this precision/scale, e.g. "999.99".
				numberString := strings.Repeat("9", precision-scale) + "." + strings.Repeat("9", scale)
				if sign == 1 {
					numberString = "-" + numberString
				}
				query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale)
				number, err := decimal128.FromString(numberString, int32(precision), int32(scale))
				suite.NoError(err)

				suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled))
				suite.Require().NoError(suite.stmt.SetSqlQuery(query))
				rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
				suite.Require().NoError(err)

				suite.EqualValues(1, n)
				suite.Truef(arrow.TypeEqual(&arrow.Decimal128Type{Precision: int32(precision), Scale: int32(scale)}, rdr.Schema().Field(0).Type), "expected decimal(%d, %d), got %s", precision, scale, rdr.Schema().Field(0).Type)
				suite.True(rdr.Next())
				rec := rdr.Record()

				suite.Equal(number, rec.Column(0).(*array.Decimal128).Value(0))
				// Release per iteration: a defer here would keep every reader
				// (~2 * 3 * 36 of them) alive until the test function returns.
				rdr.Release()
			}
		}
	}
}

// TestNonIntDecimalLowPrecision verifies that NUMBER(p, 2) columns come back
// as float64 when the high-precision option is disabled, and that the value
// round-trips within a tiny tolerance.
func (suite *SnowflakeTests) TestNonIntDecimalLowPrecision() {
	for sign := 0; sign <= 1; sign++ {
		for precision := 3; precision <= 38; precision++ {
			scale := 2
			numberString := strings.Repeat("9", precision-scale) + ".99"
			if sign == 1 {
				numberString = "-" + numberString
			}
			query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale)
			decimalNumber, err := decimal128.FromString(numberString, int32(precision), int32(scale))
			suite.NoError(err)
			number := decimalNumber.ToFloat64(int32(scale))

			suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled))
			suite.Require().NoError(suite.stmt.SetSqlQuery(query))
			rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
			suite.Require().NoError(err)

			suite.EqualValues(1, n)
			suite.Truef(arrow.TypeEqual(arrow.PrimitiveTypes.Float64, rdr.Schema().Field(0).Type), "expected float64, got %s", rdr.Schema().Field(0).Type)
			suite.True(rdr.Next())
			rec := rdr.Record()

			value := rec.Column(0).(*array.Float64).Value(0)
			// Expected and actual both pass through a float64 conversion, so
			// they should agree to well within this absolute tolerance.
			difference := math.Abs(number - value)
			suite.Truef(difference < 1e-13, "expected %f, got %f", number, value)
			// Release per iteration: a defer inside the loop would accumulate
			// one pending release per reader until the test function returns.
			rdr.Release()
		}
	}
}

// TestIntDecimalLowPrecision verifies that NUMBER(p, 0) columns come back as
// int64 when the high-precision option is disabled, matching the driver's
// current conversion behavior (including its truncation for p > 18).
func (suite *SnowflakeTests) TestIntDecimalLowPrecision() {
for sign := 0; sign <= 1; sign++ {
for precision := 3; precision <= 38; precision++ {
scale := 0
// Largest-magnitude integer for this precision, e.g. "999".
numberString := strings.Repeat("9", precision-scale)
if sign == 1 {
numberString = "-" + numberString
}
query := "SELECT CAST('" + numberString + fmt.Sprintf("' AS NUMBER(%d, %d)) AS RESULT", precision, scale)
decimalNumber, err := decimal128.FromString(numberString, int32(precision), int32(scale))
suite.NoError(err)
// For precision > 18 the decimal value overflows int64, so taking the low
// 64 bits yields a truncated value. The driver currently performs the same
// lossy conversion, which is what this test pins down; erroring instead is
// tracked in #1277.
number := int64(decimalNumber.LowBits())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is great behavior for the driver, but it is the current behavior.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment here as to why the resulting value is equal to the low bits?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind filing an issue to update this? It sounds like we must error here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed #1277


suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled))
suite.Require().NoError(suite.stmt.SetSqlQuery(query))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
// NOTE(review): this defer is inside a loop, so all ~72 readers stay alive
// until the test function returns; consider releasing per iteration.
defer rdr.Release()

suite.EqualValues(1, n)
suite.Truef(arrow.TypeEqual(arrow.PrimitiveTypes.Int64, rdr.Schema().Field(0).Type), "expected int64, got %s", rdr.Schema().Field(0).Type)
suite.True(rdr.Next())
rec := rdr.Record()

value := rec.Column(0).(*array.Int64).Value(0)
suite.Equal(number, value)
}
}
}

func (suite *SnowflakeTests) TestDescribeOnly() {
suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled))
suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT CAST('9999.99' AS NUMBER(6, 2)) AS RESULT"))
Expand Down
52 changes: 43 additions & 9 deletions go/adbc/driver/snowflake/record_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,36 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighP
}
f.Type = dt
transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
return compute.CastArray(ctx, a, compute.SafeCastOptions(dt))
return integerToDecimal128(ctx, a, dt)
}
} else {
if srcMeta.Scale != 0 {
f.Type = arrow.PrimitiveTypes.Float64
transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true},
&compute.ArrayDatum{Value: a.Data()},
compute.NewDatum(math.Pow10(int(srcMeta.Scale))))
if err != nil {
return nil, err
// For precisions of 16, 17 and 18, a conversion from int64 to float64 fails with an error
// So for these precisions, we instead convert first to a decimal128 and then to a float64.
Comment on lines +109 to +110
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the error? Should those precisions instead work and we should push a fix upstream to the Arrow lib?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message is "invalid: integer value 99999999999999999 not in range: -9007199254740992 to 9007199254740992". I do think that it makes sense to allow a lossy conversion in Arrow from int64 to float64 and that would avoid the need for this special case. This may require some design work in Arrow -- for instance, having the Divide kernel take a CastOptions or adding AllowFloatTruncate to ArithmeticOptions.

if srcMeta.Precision > 15 && srcMeta.Precision < 19 {
transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
result, err := integerToDecimal128(ctx, a, &arrow.Decimal128Type{
Precision: int32(srcMeta.Precision),
Scale: int32(srcMeta.Scale),
})
if err != nil {
return nil, err
}
return compute.CastArray(ctx, result, compute.UnsafeCastOptions(f.Type))
}
} else {
// For precisions less than 16, we can simply scale the integer value appropriately
transformers[i] = func(ctx context.Context, a arrow.Array) (arrow.Array, error) {
result, err := compute.Divide(ctx, compute.ArithmeticOptions{NoCheckOverflow: true},
&compute.ArrayDatum{Value: a.Data()},
compute.NewDatum(math.Pow10(int(srcMeta.Scale))))
if err != nil {
return nil, err
}
defer result.Release()
return result.(*compute.ArrayDatum).MakeArray(), nil
}
defer result.Release()
return result.(*compute.ArrayDatum).MakeArray(), nil
}
} else {
f.Type = arrow.PrimitiveTypes.Int64
Expand Down Expand Up @@ -266,6 +282,24 @@ func getTransformer(sc *arrow.Schema, ld gosnowflake.ArrowStreamLoader, useHighP
return out, getRecTransformer(out, transformers)
}

// integerToDecimal128 reinterprets an array of scaled integers received from
// Snowflake as a decimal128 array of the destination type dt, without
// rescaling the underlying values.
//
// We can't do a cast directly into the destination type because the numbers we
// get from Snowflake are scaled integers. So not only would the cast produce
// the wrong value, it also risks producing an error for precisions which e.g.
// can't hold every int64. To work around these problems, we instead cast into
// a decimal type of a precision and scale which we know will hold all values
// and won't require scaling. We then substitute the type on this array with
// the actual return type.
Comment on lines +286 to +290
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hate this so much, but it makes sense that we have to do this because of how snowflake works.


// decimal128(20, 0) is wide enough for any 64-bit integer (at most 19 digits).
dt0 := &arrow.Decimal128Type{
Precision: int32(20),
Scale: int32(0),
}
result, err := compute.CastArray(ctx, a, compute.SafeCastOptions(dt0))
if err == nil {
// Swap in the caller-requested type in place; the buffers holding the
// (already correctly scaled) digits are reused unchanged.
result.Data().Reset(dt, result.Data().Len(), result.Data().Buffers(), result.Data().Children(), result.Data().NullN(), result.Data().Offset())
}
return result, err
CurtHagenlocher marked this conversation as resolved.
Show resolved Hide resolved
}

func rowTypesToArrowSchema(ctx context.Context, ld gosnowflake.ArrowStreamLoader, useHighPrecision bool) (*arrow.Schema, error) {
var loc *time.Location

Expand Down
Loading