Skip to content

Commit

Permalink
apacheGH-37720: [Go][FlightSQL] Add prepared statement handle to DoPu…
Browse files Browse the repository at this point in the history
…t result (apache#40311)

### Rationale for this change
See discussion on apache#37720 and mailing list: https://lists.apache.org/thread/3kb82ypx99q96g84qv555l6x8r0bppyq

### What changes are included in this PR?

Changes the Go FlightSQL client and server implementations to support returning an updated prepared statement handle to the client as part of the `DoPut(PreparedStatement)` RPC call.

### Are these changes tested?

### Are there any user-facing changes?

See parent issue and docs PR apache#40243  for details of user facing changes.

**This PR includes breaking changes to public APIs.**

* GitHub Issue: apache#37720

Lead-authored-by: Adam Curtis <[email protected]>
Co-authored-by: David Li <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
2 people authored and tolleybot committed May 2, 2024
1 parent 4a576e3 commit 7f5797b
Show file tree
Hide file tree
Showing 8 changed files with 683 additions and 561 deletions.
36 changes: 26 additions & 10 deletions go/arrow/flight/flightsql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,6 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
if err != nil {
return nil, err
}

wr, err := p.writeBindParameters(pstream, desc)
if err != nil {
return nil, err
Expand All @@ -1133,9 +1132,7 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
return nil, err
}
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil && err != io.EOF {
if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1173,9 +1170,7 @@ func (p *PreparedStatement) ExecutePut(ctx context.Context, opts ...grpc.CallOpt
return err
}
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil && err != io.EOF {
if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
return err
}
}
Expand Down Expand Up @@ -1219,9 +1214,7 @@ func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor *fl
return nil, err
}
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil && err != io.EOF {
if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1313,6 +1306,29 @@ func (p *PreparedStatement) writeBindParameters(pstream pb.FlightService_DoPutCl
}
}

func (p *PreparedStatement) captureDoPutPreparedStatementHandle(pstream pb.FlightService_DoPutClient) error {
var (
result *pb.PutResult
preparedStatementResult pb.DoPutPreparedStatementResult
err error
)
if result, err = pstream.Recv(); err != nil && err != io.EOF {
return err
}
// skip if server does not provide a response (legacy server)
if result == nil {
return nil
}
if err = proto.Unmarshal(result.GetAppMetadata(), &preparedStatementResult); err != nil {
return err
}
handle := preparedStatementResult.GetPreparedStatementHandle()
if handle != nil {
p.handle = handle
}
return nil
}

// DatasetSchema may be nil if the server did not return it when creating the
// Prepared Statement.
func (p *PreparedStatement) DatasetSchema() *arrow.Schema { return p.datasetSchema }
Expand Down
35 changes: 24 additions & 11 deletions go/arrow/flight/flightsql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,24 +408,26 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() {

func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
const query = "query"
const handle = "handle"
const updatedHandle = "updated handle"

// create and close actions
cmd := &pb.ActionCreatePreparedStatementRequest{Query: query}
action := getAction(cmd)
action.Type = flightsql.CreatePreparedStatementActionType
closeAct := getAction(&pb.ActionClosePreparedStatementRequest{PreparedStatementHandle: []byte(query)})
closeAct := getAction(&pb.ActionClosePreparedStatementRequest{PreparedStatementHandle: []byte(updatedHandle)})
closeAct.Type = flightsql.ClosePreparedStatementActionType

// results from createprepared statement
result := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(query),
actionResult := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(handle),
}
schema := arrow.NewSchema([]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
result.ParameterSchema = flight.SerializeSchema(schema, memory.DefaultAllocator)
actionResult.ParameterSchema = flight.SerializeSchema(schema, memory.DefaultAllocator)

// mocked client stream
var out anypb.Any
out.MarshalFrom(result)
out.MarshalFrom(actionResult)
data, _ := proto.Marshal(&out)

createRsp := &mockDoActionClient{}
Expand All @@ -443,7 +445,12 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
s.mockClient.On("DoAction", flightsql.CreatePreparedStatementActionType, action.Body, s.callOpts).Return(createRsp, nil)
s.mockClient.On("DoAction", flightsql.ClosePreparedStatementActionType, closeAct.Body, s.callOpts).Return(closeRsp, nil)

expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)})
expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)})

// mocked DoPut result
doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)}
resdata, _ := proto.Marshal(doPutPreparedStatementResult)
putResult := &pb.PutResult{ AppMetadata: resdata }

// mocked client stream for DoPut
mockedPut := &mockDoPutClient{}
Expand All @@ -452,29 +459,30 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
return proto.Equal(expectedDesc, fd.FlightDescriptor)
})).Return(nil).Twice() // first sends schema message, second sends data
mockedPut.On("CloseSend").Return(nil)
mockedPut.On("Recv").Return((*pb.PutResult)(nil), nil)
mockedPut.On("Recv").Return(putResult, nil)

infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)}
infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)}
desc := getDesc(infoCmd)
s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil)

prepared, err := s.sqlClient.Prepare(context.TODO(), query, s.callOpts...)
s.NoError(err)
defer prepared.Close(context.TODO(), s.callOpts...)

s.Equal(string(prepared.Handle()), "query")
s.Equal(string(prepared.Handle()), handle)

paramSchema := prepared.ParameterSchema()
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, paramSchema, strings.NewReader(`[{"id": 1}]`))
s.NoError(err)
defer rec.Release()

s.Equal(string(prepared.Handle()), "query")
s.Equal(string(prepared.Handle()), handle)

prepared.SetParameters(rec)
info, err := prepared.Execute(context.TODO(), s.callOpts...)
s.NoError(err)
s.Equal(&emptyFlightInfo, info)
s.Equal(string(prepared.Handle()), updatedHandle)
}

func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {
Expand Down Expand Up @@ -516,6 +524,11 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {

expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)})

// mocked DoPut result
doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)}
resdata, _ := proto.Marshal(doPutPreparedStatementResult)
putResult := &pb.PutResult{ AppMetadata: resdata }

// mocked client stream for DoPut
mockedPut := &mockDoPutClient{}
s.mockClient.On("DoPut", s.callOpts).Return(mockedPut, nil)
Expand All @@ -528,7 +541,7 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {
return fd.FlightDescriptor == nil
})).Return(nil).Times(3)
mockedPut.On("CloseSend").Return(nil)
mockedPut.On("Recv").Return((*pb.PutResult)(nil), nil)
mockedPut.On("Recv").Return(putResult, nil)

infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)}
desc := getDesc(infoCmd)
Expand Down
10 changes: 5 additions & 5 deletions go/arrow/flight/flightsql/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1768,16 +1768,16 @@ func (s *MockServer) CreatePreparedStatement(ctx context.Context, req flightsql.
}, nil
}

func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) error {
func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) ([]byte, error) {
if s.ExpectedPreparedStatementSchema != nil {
if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) {
return errors.New("parameter schema: unexpected")
return nil, errors.New("parameter schema: unexpected")
}
return nil
return qry.GetPreparedStatementHandle(), nil
}

if s.PreparedStatementParameterSchema != nil && !s.PreparedStatementParameterSchema.Equal(r.Schema()) {
return fmt.Errorf("parameter schema: %w", arrow.ErrInvalid)
return nil, fmt.Errorf("parameter schema: %w", arrow.ErrInvalid)
}

// GH-35328: it's rare, but this function can complete execution and return
Expand All @@ -1791,7 +1791,7 @@ func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flight
for r.Next() {
}

return nil
return qry.GetPreparedStatementHandle(), nil
}

func (s *MockServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
Expand Down
8 changes: 4 additions & 4 deletions go/arrow/flight/flightsql/example/sqlite_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,21 +618,21 @@ func getParamsForStatement(rdr flight.MessageReader) (params [][]interface{}, er
return params, rdr.Err()
}

func (s *SQLiteFlightSQLServer) DoPutPreparedStatementQuery(_ context.Context, cmd flightsql.PreparedStatementQuery, rdr flight.MessageReader, _ flight.MetadataWriter) error {
func (s *SQLiteFlightSQLServer) DoPutPreparedStatementQuery(_ context.Context, cmd flightsql.PreparedStatementQuery, rdr flight.MessageReader, _ flight.MetadataWriter) ([]byte, error) {
val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle()))
if !ok {
return status.Error(codes.InvalidArgument, "prepared statement not found")
return nil, status.Error(codes.InvalidArgument, "prepared statement not found")
}

stmt := val.(Statement)
args, err := getParamsForStatement(rdr)
if err != nil {
return status.Errorf(codes.Internal, "error gathering parameters for prepared statement query: %s", err.Error())
return nil, status.Errorf(codes.Internal, "error gathering parameters for prepared statement query: %s", err.Error())
}

stmt.params = args
s.prepared.Store(string(cmd.GetPreparedStatementHandle()), stmt)
return nil
return cmd.GetPreparedStatementHandle(), nil
}

func (s *SQLiteFlightSQLServer) DoPutPreparedStatementUpdate(ctx context.Context, cmd flightsql.PreparedStatementUpdate, rdr flight.MessageReader) (int64, error) {
Expand Down
17 changes: 13 additions & 4 deletions go/arrow/flight/flightsql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ func (BaseServer) DoPutCommandSubstraitPlan(context.Context, StatementSubstraitP
return 0, status.Error(codes.Unimplemented, "DoPutCommandSubstraitPlan not implemented")
}

func (BaseServer) DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) error {
return status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery not implemented")
func (BaseServer) DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) ([]byte, error) {
return nil, status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery not implemented")
}

func (BaseServer) DoPutPreparedStatementUpdate(context.Context, PreparedStatementUpdate, flight.MessageReader) (int64, error) {
Expand Down Expand Up @@ -677,7 +677,7 @@ type Server interface {
// Currently anything written to the writer will be ignored. It is in the
// interface for potential future enhancements to avoid having to change
// the interface in the future.
DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) error
DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) ([]byte, error)
// DoPutPreparedStatementUpdate executes an update SQL Prepared statement
// for the specified statement handle. The reader allows providing a sequence
// of uploaded record batches to bind the parameters to. Returns the number
Expand Down Expand Up @@ -990,7 +990,16 @@ func (f *flightSqlServer) DoPut(stream flight.FlightService_DoPutServer) error {
}
return stream.Send(out)
case *pb.CommandPreparedStatementQuery:
return f.srv.DoPutPreparedStatementQuery(stream.Context(), cmd, rdr, &putMetadataWriter{stream})
handle, err := f.srv.DoPutPreparedStatementQuery(stream.Context(), cmd, rdr, &putMetadataWriter{stream})
if err != nil {
return err
}
result := pb.DoPutPreparedStatementResult{PreparedStatementHandle: handle}
out := &flight.PutResult{}
if out.AppMetadata, err = proto.Marshal(&result); err != nil {
return status.Errorf(codes.Internal, "failed to marshal PutResult: %s", err.Error())
}
return stream.Send(out)
case *pb.CommandPreparedStatementUpdate:
recordCount, err := f.srv.DoPutPreparedStatementUpdate(stream.Context(), cmd, rdr)
if err != nil {
Expand Down
Loading

0 comments on commit 7f5797b

Please sign in to comment.