Skip to content

Commit

Permalink
feat(go/adbc/driver/snowflake): support parameter binding
Browse files Browse the repository at this point in the history
Fixes #1144.
  • Loading branch information
lidavidm committed May 3, 2024
1 parent d6ddc01 commit e1553dd
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 8 deletions.
2 changes: 1 addition & 1 deletion c/driver/snowflake/snowflake_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
bool supports_metadata_current_catalog() const override { return false; }
bool supports_metadata_current_db_schema() const override { return false; }
bool supports_partitioned_data() const override { return false; }
bool supports_dynamic_parameter_binding() const override { return false; }
bool supports_dynamic_parameter_binding() const override { return true; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
std::string db_schema() const override { return schema_; }
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"net"
"net/textproto"
"os"
Expand All @@ -48,7 +49,6 @@ import (
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/stretchr/testify/suite"
"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/driver/flightsql/flightsql_driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
package flightsql

import (
"maps"
"net/url"
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v17/arrow/memory"
"golang.org/x/exp/maps"
"google.golang.org/grpc/metadata"
)

Expand Down
2 changes: 1 addition & 1 deletion go/adbc/driver/flightsql/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import (
"context"
"io"
"log/slog"
"maps"
"time"

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
Expand Down
138 changes: 138 additions & 0 deletions go/adbc/driver/snowflake/binding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
"database/sql"
"database/sql/driver"
"fmt"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
)

func convertArrowToNamedValue(batch arrow.Record, index int) ([]driver.NamedValue, error) {
// see goTypeToSnowflake in gosnowflake
// technically, snowflake can bind an array of values at once, but
// only for INSERT, so we can't take advantage of that without
// analyzing the query ourselves
params := make([]driver.NamedValue, batch.NumCols())
for i, field := range batch.Schema().Fields() {
rawColumn := batch.Column(i)
params[i].Ordinal = i + 1
switch column := rawColumn.(type) {
case *array.Boolean:
params[i].Value = sql.NullBool{
Bool: column.Value(index),
Valid: column.IsValid(index),
}
case *array.Float32:
// Snowflake only recognizes float64
params[i].Value = sql.NullFloat64{
Float64: float64(column.Value(index)),
Valid: column.IsValid(index),
}
case *array.Float64:
params[i].Value = sql.NullFloat64{
Float64: column.Value(index),
Valid: column.IsValid(index),
}
case *array.Int8:
// Snowflake only recognizes int64
params[i].Value = sql.NullInt64{
Int64: int64(column.Value(index)),
Valid: column.IsValid(index),
}
case *array.Int16:
params[i].Value = sql.NullInt64{
Int64: int64(column.Value(index)),
Valid: column.IsValid(index),
}
case *array.Int32:
params[i].Value = sql.NullInt64{
Int64: int64(column.Value(index)),
Valid: column.IsValid(index),
}
case *array.Int64:
params[i].Value = sql.NullInt64{
Int64: column.Value(index),
Valid: column.IsValid(index),
}
case *array.String:
params[i].Value = sql.NullString{
String: column.Value(index),
Valid: column.IsValid(index),
}
case *array.LargeString:
params[i].Value = sql.NullString{
String: column.Value(index),
Valid: column.IsValid(index),
}
default:
return nil, adbc.Error{
Code: adbc.StatusNotImplemented,
Msg: fmt.Sprintf("[Snowflake] Unsupported bind param '%s' type %s", field.Name, field.Type.String()),
}
}
}
return params, nil
}

type snowflakeBindReader struct {
doQuery func([]driver.NamedValue) (array.RecordReader, error)
currentBatch arrow.Record
nextIndex int64
// may be nil if we bound only a batch
stream array.RecordReader
}

func (r *snowflakeBindReader) Release() {
if r.currentBatch != nil {
r.currentBatch.Release()
}
if r.stream != nil {
r.stream.Release()
}
}

func (r *snowflakeBindReader) Next() (array.RecordReader, error) {
for r.currentBatch == nil || r.nextIndex >= r.currentBatch.NumRows() {
if r.stream != nil && r.stream.Next() {
if r.currentBatch != nil {
r.currentBatch.Release()
}
r.currentBatch = r.stream.Record()
r.nextIndex = 0
continue
} else if r.stream != nil && r.stream.Err() != nil {
return nil, r.stream.Err()
} else {
// end-of-stream
return nil, nil
}
}

params, err := convertArrowToNamedValue(r.currentBatch, int(r.nextIndex))
if err != nil {
return nil, err
}
r.nextIndex++

return r.doQuery(params)
}
102 changes: 102 additions & 0 deletions go/adbc/driver/snowflake/concat_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package snowflake

import (
"sync/atomic"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
)

type readerIter interface {
Release()

Next() (array.RecordReader, error)
}

type concatReader struct {
refCount atomic.Int64
readers readerIter
currentReader array.RecordReader
schema *arrow.Schema
err error
}

func (r *concatReader) nextReader() {
if r.currentReader != nil {
r.currentReader.Release()
r.currentReader = nil
}
reader, err := r.readers.Next()
if err != nil {
r.err = err
} else {
// May be nil
r.currentReader = reader
}
}
func (r *concatReader) Init(readers readerIter) error {
r.readers = readers
r.refCount.Store(1)
r.nextReader()
if r.err != nil {
return r.err
} else if r.currentReader == nil {
r.err = adbc.Error{
Code: adbc.StatusInternal,
Msg: "[Snowflake] No data in this stream",
}
return r.err
}
r.schema = r.currentReader.Schema()
return nil
}
func (r *concatReader) Retain() {
r.refCount.Add(1)
}
func (r *concatReader) Release() {
if r.refCount.Add(-1) == 0 {
r.readers.Release()
if r.currentReader != nil {
r.currentReader.Release()
}
}
}
func (r *concatReader) Schema() *arrow.Schema {
if r.schema == nil {
panic("did not call concatReader.Init")
}
return r.schema
}
func (r *concatReader) Next() bool {
for r.currentReader != nil && !r.currentReader.Next() {
r.nextReader()
}
if r.currentReader == nil || r.err != nil {
return false
}
return true
}
func (r *concatReader) Record() arrow.Record {
return r.currentReader.Record()
}
func (r *concatReader) Err() error {
return r.err
}
2 changes: 1 addition & 1 deletion go/adbc/driver/snowflake/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package snowflake

import (
"errors"
"maps"
"runtime/debug"
"strings"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/exp/maps"
)

const (
Expand Down
23 changes: 20 additions & 3 deletions go/adbc/driver/snowflake/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package snowflake

import (
"context"
"database/sql/driver"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -463,10 +464,26 @@ func (st *statement) ExecuteQuery(ctx context.Context) (array.RecordReader, int6
// concatenate RecordReaders which doesn't exist yet. let's put
// that off for now.
if st.streamBind != nil || st.bound != nil {
return nil, -1, adbc.Error{
Msg: "executing non-bulk ingest with bound params not yet implemented",
Code: adbc.StatusNotImplemented,
bind := snowflakeBindReader{
doQuery: func(params []driver.NamedValue) (array.RecordReader, error) {
loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query, params...)
if err != nil {
return nil, errToAdbcErr(adbc.StatusInternal, err)
}
return newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision)
},
currentBatch: st.bound,
stream: st.streamBind,
}
st.bound = nil
st.streamBind = nil

rdr := concatReader{}
err := rdr.Init(&bind)
if err != nil {
return nil, -1, err
}
return &rdr, -1, nil
}

loader, err := st.cnxn.cn.QueryArrowStream(ctx, st.query)
Expand Down

0 comments on commit e1553dd

Please sign in to comment.