From 27f48952dded7c92fc04b3f21cc898523f467975 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Mon, 8 Aug 2022 20:09:19 -0400 Subject: [PATCH] simple server tests --- go/arrow/flight/flightsql/server.go | 2 +- go/arrow/flight/flightsql/server_test.go | 138 ++++++++++++++++++++++- 2 files changed, 138 insertions(+), 2 deletions(-) diff --git a/go/arrow/flight/flightsql/server.go b/go/arrow/flight/flightsql/server.go index dfea046dd1018..ebaad4fa433ab 100644 --- a/go/arrow/flight/flightsql/server.go +++ b/go/arrow/flight/flightsql/server.go @@ -210,7 +210,7 @@ func (BaseServer) DoGetDBSchemas(context.Context, GetDBSchemas) (*arrow.Schema, } func (BaseServer) GetFlightInfoTables(context.Context, GetTables, *flight.FlightDescriptor) (*flight.FlightInfo, error) { - return nil, status.Errorf(codes.Unimplemented, "GetTables not implemented") + return nil, status.Errorf(codes.Unimplemented, "GetFlightInfoTables not implemented") } func (BaseServer) DoGetTables(context.Context, GetTables) (*arrow.Schema, <-chan flight.StreamChunk, error) { diff --git a/go/arrow/flight/flightsql/server_test.go b/go/arrow/flight/flightsql/server_test.go index 6831594436887..ece7754bbbdb2 100644 --- a/go/arrow/flight/flightsql/server_test.go +++ b/go/arrow/flight/flightsql/server_test.go @@ -18,15 +18,20 @@ package flightsql_test import ( "context" + "strings" "testing" "github.com/apache/arrow/go/v10/arrow/flight" "github.com/apache/arrow/go/v10/arrow/flight/flightsql" + pb "github.com/apache/arrow/go/v10/arrow/flight/internal/flight" + "github.com/apache/arrow/go/v10/arrow/memory" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" ) var dialOpts = []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} @@ -62,15 +67,146 @@ func (s *UnimplementedFlightSqlServerSuite) TearDownSuite() { s.s.Shutdown() } +// the following test functions verify that the default base server will +// correctly route requests to the appropriate interface methods based on +// the descriptor types for DoPut/DoGet/DoAction + func (s *UnimplementedFlightSqlServerSuite) TestExecute() { info, err := s.cl.Execute(context.TODO(), "SELECT * FROM IRRELEVANT") st, ok := status.FromError(err) s.True(ok) s.Equal(codes.Unimplemented, st.Code()) - s.Equal("GetFlightInfoStatement not implemented", st.Message()) + s.Equal(st.Message(), "GetFlightInfoStatement not implemented") + s.Nil(info) +} + +func (s *UnimplementedFlightSqlServerSuite) TestGetTables() { + info, err := s.cl.GetTables(context.TODO(), &flightsql.GetTablesOpts{}) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoTables not implemented") + s.Nil(info) +} + +func (s *UnimplementedFlightSqlServerSuite) TestGetTableTypes() { + info, err := s.cl.GetTableTypes(context.TODO()) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoTableTypes not implemented") + s.Nil(info) +} + +func (s *UnimplementedFlightSqlServerSuite) TestGetPrimaryKeys() { + info, err := s.cl.GetPrimaryKeys(context.TODO(), flightsql.TableRef{}) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoPrimaryKeys not implemented") + s.Nil(info) +} + +func (s *UnimplementedFlightSqlServerSuite) TestGetExportedKeys() { + info, err := s.cl.GetExportedKeys(context.TODO(), flightsql.TableRef{}) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoExportedKeys not implemented") + s.Nil(info) +} + +func (s *UnimplementedFlightSqlServerSuite) TestGetImportedKeys() { + info, err := s.cl.GetImportedKeys(context.TODO(), flightsql.TableRef{}) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoImportedKeys not implemented") + s.Nil(info) +} + +func (s *UnimplementedFlightSqlServerSuite) TestGetCrossReference() { + info, err := s.cl.GetCrossReference(context.TODO(), flightsql.TableRef{}, flightsql.TableRef{}) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoCrossReference not implemented") s.Nil(info) } +func (s *UnimplementedFlightSqlServerSuite) TestGetCatalogs() { + info, err := s.cl.GetCatalogs(context.TODO()) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoCatalogs not implemented") + s.Nil(info) +} + +func (s *UnimplementedFlightSqlServerSuite) TestGetDBSchemas() { + info, err := s.cl.GetDBSchemas(context.TODO(), &flightsql.GetDBSchemasOpts{}) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoSchemas not implemented") + s.Nil(info) +} + +func (s *UnimplementedFlightSqlServerSuite) TestGetTypeInfo() { + info, err := s.cl.GetXdbcTypeInfo(context.TODO(), nil) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "GetFlightInfoXdbcTypeInfo not implemented") + s.Nil(info) +} + +func getTicket(cmd proto.Message) *flight.Ticket { + var anycmd anypb.Any + anycmd.MarshalFrom(cmd) + + data, _ := proto.Marshal(&anycmd) + return &flight.Ticket{ + Ticket: data, + } +} + +func (s *UnimplementedFlightSqlServerSuite) TestDoGet() { + tests := []struct { + name string + ticket proto.Message + }{ + {"DoGetStatement", &pb.TicketStatementQuery{}}, + {"DoGetPreparedStatement", &pb.CommandPreparedStatementQuery{}}, + {"DoGetCatalogs", &pb.CommandGetCatalogs{}}, + {"DoGetDBSchemas", &pb.CommandGetDbSchemas{}}, + {"DoGetTables", &pb.CommandGetTables{}}, + {"DoGetTableTypes", &pb.CommandGetTableTypes{}}, + {"DoGetXdbcTypeInfo", &pb.CommandGetXdbcTypeInfo{}}, + {"DoGetPrimaryKeys", &pb.CommandGetPrimaryKeys{}}, + {"DoGetExportedKeys", &pb.CommandGetExportedKeys{}}, + {"DoGetImportedKeys", &pb.CommandGetImportedKeys{}}, + {"DoGetCrossReference", &pb.CommandGetCrossReference{}}, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdr, err := s.cl.DoGet(context.TODO(), getTicket(tt.ticket)) + s.Nil(rdr) + s.True(strings.HasSuffix(err.Error(), tt.name+" not implemented"), err.Error()) + }) + } +} + +func (s *UnimplementedFlightSqlServerSuite) TestDoAction() { + prep, err := s.cl.Prepare(context.TODO(), memory.DefaultAllocator, "IRRELEVANT") + s.Nil(prep) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal(st.Message(), "CreatePreparedStatement not implemented") +} + func TestBaseServer(t *testing.T) { suite.Run(t, new(UnimplementedFlightSqlServerSuite)) }