diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index 0fe07ecbd7..a4d742491a 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -146,7 +146,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { bool supports_metadata_current_catalog() const override { return false; } bool supports_metadata_current_db_schema() const override { return false; } bool supports_partitioned_data() const override { return false; } - bool supports_dynamic_parameter_binding() const override { return false; } + bool supports_dynamic_parameter_binding() const override { return true; } bool supports_error_on_incompatible_schema() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } std::string db_schema() const override { return schema_; } diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 3bf695f0b2..7bfa08d90c 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -24,6 +24,7 @@ import ( "encoding/json" "errors" "fmt" + "maps" "net" "net/textproto" "os" @@ -48,7 +49,6 @@ import ( "github.com/apache/arrow/go/v17/arrow/memory" "github.com/golang/protobuf/ptypes/wrappers" "github.com/stretchr/testify/suite" - "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index db3e397724..175d685e4c 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -32,13 +32,13 @@ package flightsql import ( + "maps" "net/url" "time" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v17/arrow/memory" - "golang.org/x/exp/maps" "google.golang.org/grpc/metadata" ) diff --git a/go/adbc/driver/flightsql/logging.go b/go/adbc/driver/flightsql/logging.go index 4fb12c4112..187ac70783 100644 --- a/go/adbc/driver/flightsql/logging.go +++ b/go/adbc/driver/flightsql/logging.go @@ -21,9 +21,9 @@ import ( "context" "io" "log/slog" + "maps" "time" - "golang.org/x/exp/maps" "golang.org/x/exp/slices" "google.golang.org/grpc" "google.golang.org/grpc/metadata" diff --git a/go/adbc/driver/snowflake/binding.go b/go/adbc/driver/snowflake/binding.go new file mode 100644 index 0000000000..e79ecc8c42 --- /dev/null +++ b/go/adbc/driver/snowflake/binding.go @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package snowflake + +import ( + "database/sql" + "database/sql/driver" + "fmt" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +func convertArrowToNamedValue(batch arrow.Record, index int) ([]driver.NamedValue, error) { + // see goTypeToSnowflake in gosnowflake + // technically, snowflake can bind an array of values at once, but + // only for INSERT, so we can't take advantage of that without + // analyzing the query ourselves + params := make([]driver.NamedValue, batch.NumCols()) + for i, field := range batch.Schema().Fields() { + rawColumn := batch.Column(i) + params[i].Ordinal = i + 1 + switch column := rawColumn.(type) { + case *array.Boolean: + params[i].Value = sql.NullBool{ + Bool: column.Value(index), + Valid: column.IsValid(index), + } + case *array.Float32: + // Snowflake only recognizes float64 + params[i].Value = sql.NullFloat64{ + Float64: float64(column.Value(index)), + Valid: column.IsValid(index), + } + case *array.Float64: + params[i].Value = sql.NullFloat64{ + Float64: column.Value(index), + Valid: column.IsValid(index), + } + case *array.Int8: + // Snowflake only recognizes int64 + params[i].Value = sql.NullInt64{ + Int64: int64(column.Value(index)), + Valid: column.IsValid(index), + } + case *array.Int16: + params[i].Value = sql.NullInt64{ + Int64: int64(column.Value(index)), + Valid: column.IsValid(index), + } + case *array.Int32: + params[i].Value = sql.NullInt64{ + Int64: int64(column.Value(index)), + Valid: column.IsValid(index), + } + case *array.Int64: + params[i].Value = sql.NullInt64{ + Int64: column.Value(index), + Valid: column.IsValid(index), + } + case *array.String: + params[i].Value = sql.NullString{ + String: column.Value(index), + Valid: column.IsValid(index), + } + case *array.LargeString: + params[i].Value = sql.NullString{ + String: column.Value(index), + Valid: column.IsValid(index), + } + default: + return nil, adbc.Error{ + Code: adbc.StatusNotImplemented, + Msg: fmt.Sprintf("[Snowflake] Unsupported bind param '%s' type %s", field.Name, field.Type.String()), + } + } + } + return params, nil +} + +type snowflakeBindReader struct { + doQuery func([]driver.NamedValue) (array.RecordReader, error) + currentBatch arrow.Record + nextIndex int64 + // may be nil if we bound only a batch + stream array.RecordReader +} + +func (r *snowflakeBindReader) Release() { + if r.currentBatch != nil { + r.currentBatch.Release() + } + if r.stream != nil { + r.stream.Release() + } +} + +func (r *snowflakeBindReader) Next() (array.RecordReader, error) { + for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() { + if r.stream != nil && r.stream.Next() { + if r.currentBatch != nil { + r.currentBatch.Release() + } + r.currentBatch = r.stream.Record() + r.nextIndex = 0 + continue + } else if r.stream != nil && r.stream.Err() != nil { + return nil, r.stream.Err() + } else { + // end-of-stream + return nil, nil + } + } + + params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex)) + if err != nil { + return nil, err + } + r.nextIndex++ + + return r.doQuery(params) +} diff --git a/go/adbc/driver/snowflake/concat_reader.go b/go/adbc/driver/snowflake/concat_reader.go new file mode 100644 index 0000000000..389bfb8862 --- /dev/null +++ b/go/adbc/driver/snowflake/concat_reader.go @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package snowflake + +import ( + "sync/atomic" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" +) + +type readerIter interface { + Release() + + Next() (array.RecordReader, error) +} + +type concatReader struct { + refCount atomic.Int64 + readers readerIter + currentReader array.RecordReader + schema *arrow.Schema + err error +} + +func (r *concatReader) nextReader() { + if r.currentReader != nil { + r.currentReader.Release() + r.currentReader = nil + } + reader, err := r.readers.Next() + if err != nil { + r.err = err + } else { + // May be nil + r.currentReader = reader + } +} +func (r *concatReader) Init(readers readerIter) error { + r.readers = readers + r.refCount.Store(1) + r.nextReader() + if r.err != nil { + return r.err + } else if r.currentReader == nil { + r.err = adbc.Error{ + Code: adbc.StatusInternal, + Msg: "[Snowflake] No data in this stream", + } + return r.err + } + r.schema = r.currentReader.Schema() + return nil +} +func (r *concatReader) Retain() { + r.refCount.Add(1) +} +func (r *concatReader) Release() { + if r.refCount.Add(-1) == 0 { + r.readers.Release() + if r.currentReader != nil { + r.currentReader.Release() + } + } +} +func (r *concatReader) Schema() *arrow.Schema { + if r.schema == nil { + panic("did not call concatReader.Init") + } + return r.schema +} +func (r *concatReader) Next() bool { + for r.currentReader != nil && !r.currentReader.Next() { + r.nextReader() + } + if r.currentReader == nil || r.err != nil { + return false + } + return true +} +func (r *concatReader) Record() arrow.Record { + return r.currentReader.Record() +} +func (r *concatReader) Err() error { + return r.err +} diff --git a/go/adbc/driver/snowflake/driver.go b/go/adbc/driver/snowflake/driver.go index da49a6097d..a49dd13b81 100644 --- a/go/adbc/driver/snowflake/driver.go +++ b/go/adbc/driver/snowflake/driver.go @@ -19,6 +19,7 @@ package snowflake import ( "errors" + "maps" "runtime/debug" "strings" @@ -26,7 +27,6 @@ import ( "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" "github.com/apache/arrow/go/v17/arrow/memory" "github.com/snowflakedb/gosnowflake" - "golang.org/x/exp/maps" ) const ( diff --git a/go/adbc/driver/snowflake/statement.go b/go/adbc/driver/snowflake/statement.go index f61db8f06a..283862ce8f 100644 --- a/go/adbc/driver/snowflake/statement.go +++ b/go/adbc/driver/snowflake/statement.go @@ -19,6 +19,7 @@ package snowflake import ( "context" + "database/sql/driver" "fmt" "strconv" "strings" @@ -463,10 +464,26 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6 // concatenate RecordReaders which doesn't exist yet. let's put // that off for now. if st.streamBind != nil || st.bound != nil { - return nil, -1, adbc.Error{ - Msg: "executing non-bulk ingest with bound params not yet implemented", - Code: adbc.StatusNotImplemented, + bind := snowflakeBindReader{ + doQuery: func(params []driver.NamedValue) (array.RecordReader, error) { + loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query, params...) + if err != nil { + return nil, errToAdbcErr(adbc.StatusInternal, err) + } + return newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision) + }, + currentBatch: st.bound, + stream: st.streamBind, + } + st.bound = nil + st.streamBind = nil + + rdr := concatReader{} + err := rdr.Init(&bind) + if err != nil { + return nil, -1, err } + return &rdr, -1, nil } loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)