Skip to content

Commit

Permalink
get flight_sql integration up and running
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroshade committed Aug 9, 2022
1 parent 27f4895 commit ed9929b
Show file tree
Hide file tree
Showing 13 changed files with 616 additions and 28 deletions.
2 changes: 1 addition & 1 deletion dev/archery/archery/integration/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
Scenario(
"flight_sql",
description="Ensure Flight SQL protocol is working as expected.",
skip={"Rust", "Go"}
skip={"Rust"}
),
]

Expand Down
4 changes: 4 additions & 0 deletions go/arrow/array/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ func NewRecord(schema *arrow.Schema, cols []arrow.Array, nrows int64) *simpleRec
}

func (rec *simpleRecord) validate() error {
if rec.rows == 0 && len(rec.arrs) == 0 {
return nil
}

if len(rec.arrs) != len(rec.schema.Fields()) {
return fmt.Errorf("arrow/array: number of columns/fields mismatch")
}
Expand Down
4 changes: 2 additions & 2 deletions go/arrow/array/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func TestRecord(t *testing.T) {
{
schema: schema,
cols: nil,
rows: -1,
err: fmt.Errorf("arrow/array: number of columns/fields mismatch"),
rows: 0,
// err: fmt.Errorf("arrow/array: number of columns/fields mismatch"),
},
{
schema: schema,
Expand Down
5 changes: 3 additions & 2 deletions go/arrow/flight/flightsql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package flightsql
import (
"context"
"errors"
"io"

"github.com/apache/arrow/go/v10/arrow"
"github.com/apache/arrow/go/v10/arrow/array"
Expand Down Expand Up @@ -306,7 +307,7 @@ func (p *PreparedStatement) Execute(ctx context.Context) (*flight.FlightInfo, er
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil {
if _, err = pstream.Recv(); err != nil && err != io.EOF {
return nil, err
}
}
Expand Down Expand Up @@ -344,7 +345,7 @@ func (p *PreparedStatement) ExecuteUpdate(ctx context.Context) (nrecords int64,
}
} else {
schema := arrow.NewSchema([]arrow.Field{}, nil)
wr = flight.NewRecordWriter(pstream, ipc.WithSchema(p.paramBinding.Schema()))
wr = flight.NewRecordWriter(pstream, ipc.WithSchema(schema))
wr.SetFlightDescriptor(desc)
rec := array.NewRecord(schema, []arrow.Array{}, 0)
if err = wr.Write(rec); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion go/arrow/flight/flightsql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ func getAction(cmd proto.Message) *flight.Action {
return &flight.Action{Body: data}
}

func (s *FlightSqlClientSuite) SetupSuite() {
func (s *FlightSqlClientSuite) SetupTest() {
s.mockClient = FlightServiceClientMock{}
s.sqlClient.Client = &s.mockClient
}

Expand Down
4 changes: 2 additions & 2 deletions go/arrow/flight/flightsql/column_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ type ColumnMetadataBuilder struct {
keys, vals []string
}

func NewColumnMetadataBuilder() ColumnMetadataBuilder {
return ColumnMetadataBuilder{make([]string, 0), make([]string, 0)}
func NewColumnMetadataBuilder() *ColumnMetadataBuilder {
return &ColumnMetadataBuilder{make([]string, 0), make([]string, 0)}
}

func (c *ColumnMetadataBuilder) Build() ColumnMetadata {
Expand Down
6 changes: 4 additions & 2 deletions go/arrow/flight/flightsql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ func (BaseServer) DoPutPreparedStatementUpdate(context.Context, PreparedStatemen
}

type Server interface {
mustEmbedBaseServer()
GetFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error)
DoGetStatement(context.Context, StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error)
GetFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.FlightInfo, error)
Expand Down Expand Up @@ -307,6 +306,8 @@ type Server interface {
DoPutPreparedStatementUpdate(context.Context, PreparedStatementUpdate, flight.MessageReader) (int64, error)
CreatePreparedStatement(context.Context, ActionCreatePreparedStatementRequest) (ActionCreatePreparedStatementResult, error)
ClosePreparedStatement(context.Context, ActionClosePreparedStatementRequest) error

mustEmbedBaseServer()
}

func NewFlightServer(srv Server) flight.FlightServer {
Expand Down Expand Up @@ -438,10 +439,11 @@ func (p *putMetadataWriter) WriteMetadata(appMetadata []byte) error {
}

func (f *flightSqlServer) DoPut(stream flight.FlightService_DoPutServer) error {
rdr, err := flight.NewRecordReader(stream, ipc.WithAllocator(f.mem))
rdr, err := flight.NewRecordReader(stream, ipc.WithAllocator(f.mem), ipc.WithDelayReadSchema(true))
if err != nil {
return status.Errorf(codes.InvalidArgument, "failed to read input stream: %s", err.Error())
}
defer rdr.Release()

// flight descriptor should have come with the schema message
request := rdr.LatestFlightDescriptor()
Expand Down
10 changes: 10 additions & 0 deletions go/arrow/flight/flightsql/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package flightsql

import (
pb "github.com/apache/arrow/go/v10/arrow/flight/internal/flight"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
)

const (
Expand Down Expand Up @@ -64,6 +66,14 @@ func impkToTableRef(cmd *pb.CommandGetImportedKeys) TableRef {
}
}

func CreateStatementQueryTicket(handle []byte) ([]byte, error) {
query := &pb.TicketStatementQuery{StatementHandle: handle}
var ticket anypb.Any
ticket.MarshalFrom(query)

return proto.Marshal(&ticket)
}

type (
GetDBSchemasOpts = pb.CommandGetDbSchemas
GetTablesOpts = pb.CommandGetTables
Expand Down
26 changes: 24 additions & 2 deletions go/arrow/flight/record_batch_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type DataStreamReader interface {
type dataMessageReader struct {
rdr DataStreamReader

peeked *FlightData
refCount int64
msg *ipc.Message

Expand All @@ -46,7 +47,18 @@ type dataMessageReader struct {
}

func (d *dataMessageReader) Message() (*ipc.Message, error) {
fd, err := d.rdr.Recv()
var (
fd *FlightData
err error
)

if d.peeked != nil {
fd = d.peeked
d.peeked = nil
} else {
fd, err = d.rdr.Recv()
}

if err != nil {
if d.msg != nil {
// clear the previous message in the error case
Expand Down Expand Up @@ -135,8 +147,18 @@ func (r *Reader) Chunk() StreamChunk {
// as the source of the ipc messages, opts passed will be passed to the underlying
// ipc.Reader such as ipc.WithSchema and ipc.WithAllocator
func NewRecordReader(r DataStreamReader, opts ...ipc.Option) (*Reader, error) {
// peek the first message for a descriptor
data, err := r.Recv()
if err != nil {
return nil, err
}

rdr := &Reader{dmr: &dataMessageReader{rdr: r}}
var err error
rdr.dmr.descr = data.FlightDescriptor
if len(data.DataHeader) > 0 {
rdr.dmr.peeked = data
}

if rdr.Reader, err = ipc.NewReaderFromMessageReader(rdr.dmr, opts...); err != nil {
return nil, fmt.Errorf("arrow/flight: could not create flight reader: %w", err)
}
Expand Down
Loading

0 comments on commit ed9929b

Please sign in to comment.