diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 78ae01b441..5d40e7cbc7 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -21,6 +21,7 @@ package flightsql_test import ( "context" + "encoding/json" "errors" "fmt" "net" @@ -41,6 +42,7 @@ import ( "github.com/apache/arrow/go/v16/arrow/flight" "github.com/apache/arrow/go/v16/arrow/flight/flightsql" "github.com/apache/arrow/go/v16/arrow/flight/flightsql/schema_ref" + flightproto "github.com/apache/arrow/go/v16/arrow/flight/gen/flight" "github.com/apache/arrow/go/v16/arrow/memory" "github.com/golang/protobuf/ptypes/wrappers" "github.com/stretchr/testify/suite" @@ -134,6 +136,10 @@ func TestMultiTable(t *testing.T) { suite.Run(t, &MultiTableTests{}) } +func TestSessionOptions(t *testing.T) { + suite.Run(t, &SessionOptionTests{}) +} + // ---- AuthN Tests -------------------- type AuthnTestServer struct { @@ -1654,3 +1660,217 @@ func (suite *MultiTableTests) TestGetTableSchema() { expectedSchema := arrow.NewSchema([]arrow.Field{{Name: "b", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) suite.Equal(expectedSchema, actualSchema) } + +// ---- Session Option Tests -------------------- + +// TODO: configure cookies +type SessionOptionTestServer struct { + flightsql.BaseServer + options map[string]interface{} +} + +func (server *SessionOptionTestServer) GetSessionOptions(ctx context.Context, req *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) { + options := make(map[string]*flight.SessionOptionValue) + for k, v := range server.options { + switch s := v.(type) { + case bool: + options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_BoolValue{BoolValue: s}} + case float64: + options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_DoubleValue{DoubleValue: s}} + case int64: + options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_Int64Value{Int64Value: s}} + case string: + options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_StringValue{StringValue: s}} + case []string: + options[k] = &flight.SessionOptionValue{OptionValue: &flightproto.SessionOptionValue_StringListValue_{StringListValue: &flightproto.SessionOptionValue_StringListValue{Values: s}}} + case nil: + options[k] = &flight.SessionOptionValue{} + default: + panic("not implemented") + } + } + return &flight.GetSessionOptionsResult{ + SessionOptions: options, + }, nil +} + +func (server *SessionOptionTestServer) SetSessionOptions(ctx context.Context, req *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) { + errors := map[string]*flightproto.SetSessionOptionsResult_Error{} + for k, v := range req.SessionOptions { + switch k { + case "bad name": + errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_INVALID_NAME} + continue + case "bad value": + errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_INVALID_VALUE} + continue + case "error": + errors[k] = &flightproto.SetSessionOptionsResult_Error{Value: flightproto.SetSessionOptionsResult_ERROR} + continue + } + switch s := v.GetOptionValue().(type) { + case *flightproto.SessionOptionValue_BoolValue: + server.options[k] = s.BoolValue + case *flightproto.SessionOptionValue_DoubleValue: + server.options[k] = s.DoubleValue + case *flightproto.SessionOptionValue_Int64Value: + server.options[k] = s.Int64Value + case *flightproto.SessionOptionValue_StringValue: + server.options[k] = s.StringValue + case *flightproto.SessionOptionValue_StringListValue_: + server.options[k] = s.StringListValue.Values + case nil: + delete(server.options, k) + default: + return nil, status.Error(codes.InvalidArgument, "invalid option type") + } + } + return &flight.SetSessionOptionsResult{Errors: errors}, nil +} + +func (server *SessionOptionTestServer) CloseSession(ctx context.Context, req *flight.CloseSessionRequest) (*flight.CloseSessionResult, error) { + return &flight.CloseSessionResult{ + Status: flight.CloseSessionResultClosed, + }, nil +} + +type SessionOptionTests struct { + ServerBasedTests +} + +func (suite *SessionOptionTests) SetupSuite() { + suite.DoSetupSuite(&SessionOptionTestServer{ + options: map[string]interface{}{ + "string": "expected", + "bool": true, + "float64": float64(1.5), + "int64": int64(20), + "catalog": "main", + "schema": "session", + "stringlist": []string{"a", "b", "c"}, + "nilopt": nil, + }, + }, nil, map[string]string{ + driver.OptionCookieMiddleware: "true", + }) +} + +func (suite *SessionOptionTests) TestGetAllOptions() { + val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(driver.OptionSessionOptions) + suite.NoError(err) + + options := make(map[string]interface{}) + suite.NoError(json.Unmarshal([]byte(val), &options)) + // XXX: because Go decodes ints to strings by default. Should we use + // an alternate representation? What happens to int64max? + suite.Equal(float64(20), options["int64"]) + suite.Equal("expected", options["string"]) + // Bit of a hack, but lets servers send "this option exists, but is + // not set" by returning a nil/unset value + suite.Nil(options["nilopt"]) +} + +func (suite *SessionOptionTests) TestGetAllOptionsByte() { + val, err := suite.cnxn.(adbc.GetSetOptions).GetOptionBytes(driver.OptionSessionOptions) + suite.NoError(err) + + options := make(map[string]interface{}) + // XXX: maybe we can return the underlying proto repr here? + suite.NoError(json.Unmarshal(val, &options)) + suite.Equal(float64(20), options["int64"]) + suite.Equal("expected", options["string"]) +} + +func (suite *SessionOptionTests) TestGetSetCatalog() { + val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog) + suite.NoError(err) + suite.Equal("main", val) + + suite.NoError(suite.cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog, "postgres")) + val, err = suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog) + suite.NoError(err) + suite.Equal("postgres", val) +} + +func (suite *SessionOptionTests) TestGetSetSchema() { + val, err := suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema) + suite.NoError(err) + suite.Equal("session", val) + + suite.NoError(suite.cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentDbSchema, "public")) + val, err = suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema) + suite.NoError(err) + suite.Equal("public", val) +} + +func (suite *SessionOptionTests) TestGetSetBool() { + o := suite.cnxn.(adbc.GetSetOptions) + val, err := o.GetOption(driver.OptionBoolSessionOptionPrefix + "bool") + suite.NoError(err) + suite.Equal("true", val) + + suite.NoError(o.SetOption(driver.OptionBoolSessionOptionPrefix+"bool", "false")) + val, err = o.GetOption(driver.OptionBoolSessionOptionPrefix + "bool") + suite.NoError(err) + suite.Equal("false", val) +} + +func (suite *SessionOptionTests) TestGetSetFloat64() { + o := suite.cnxn.(adbc.GetSetOptions) + val, err := o.GetOptionDouble(driver.OptionSessionOptionPrefix + "float64") + suite.NoError(err) + suite.Equal(1.5, val) + + suite.NoError(o.SetOptionDouble(driver.OptionSessionOptionPrefix+"float64", -42.0)) + val, err = o.GetOptionDouble(driver.OptionSessionOptionPrefix + "float64") + suite.NoError(err) + suite.Equal(-42.0, val) +} + +func (suite *SessionOptionTests) TestGetSetInt64() { + o := suite.cnxn.(adbc.GetSetOptions) + val, err := o.GetOptionInt(driver.OptionSessionOptionPrefix + "int64") + suite.NoError(err) + suite.Equal(int64(20), val) + + suite.NoError(o.SetOptionInt(driver.OptionSessionOptionPrefix+"int64", 128)) + val, err = o.GetOptionInt(driver.OptionSessionOptionPrefix + "int64") + suite.NoError(err) + suite.Equal(int64(128), val) +} + +func (suite *SessionOptionTests) TestGetSetString() { + o := suite.cnxn.(adbc.GetSetOptions) + _, err := o.GetOption(driver.OptionSessionOptionPrefix + "unknown") + suite.ErrorContains(err, "unknown session option 'unknown'") + + suite.NoError(o.SetOption(driver.OptionSessionOptionPrefix+"unknown", "42")) + val, err := o.GetOption(driver.OptionSessionOptionPrefix + "unknown") + suite.NoError(err) + suite.Equal("42", val) + + suite.NoError(o.SetOption(driver.OptionEraseSessionOptionPrefix+"unknown", "")) + _, err = o.GetOption(driver.OptionSessionOptionPrefix + "unknown") + suite.ErrorContains(err, "unknown session option 'unknown'") + + suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"bad name", ""), "Could not set option(s) 'bad name' (invalid name)") + suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"bad value", ""), "Could not set option(s) 'bad value' (invalid value)") + suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"error", ""), "Could not set option(s) 'error' (error setting option)") +} + +func (suite *SessionOptionTests) TestGetSetStringList() { + o := suite.cnxn.(adbc.GetSetOptions) + val, err := o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist") + suite.NoError(err) + suite.Equal(`["a","b","c"]`, val) + + suite.NoError(o.SetOption(driver.OptionStringListSessionOptionPrefix+"stringlist", `["foo", "bar"]`)) + val, err = o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist") + suite.NoError(err) + suite.Equal(`["foo","bar"]`, val) + + suite.NoError(o.SetOption(driver.OptionStringListSessionOptionPrefix+"stringlist", `[]`)) + val, err = o.GetOption(driver.OptionStringListSessionOptionPrefix + "stringlist") + suite.NoError(err) + suite.Equal(`[]`, val) +} diff --git a/go/adbc/driver/flightsql/flightsql_connection.go b/go/adbc/driver/flightsql/flightsql_connection.go index 2b4ce93e9a..815d1e52df 100644 --- a/go/adbc/driver/flightsql/flightsql_connection.go +++ b/go/adbc/driver/flightsql/flightsql_connection.go @@ -20,6 +20,7 @@ package flightsql import ( "bytes" "context" + "encoding/json" "fmt" "io" "math" @@ -32,6 +33,7 @@ import ( "github.com/apache/arrow/go/v16/arrow/flight" "github.com/apache/arrow/go/v16/arrow/flight/flightsql" "github.com/apache/arrow/go/v16/arrow/flight/flightsql/schema_ref" + flightproto "github.com/apache/arrow/go/v16/arrow/flight/gen/flight" "github.com/apache/arrow/go/v16/arrow/ipc" "github.com/bluele/gcache" "google.golang.org/grpc" @@ -85,6 +87,156 @@ func doGet(ctx context.Context, cl *flightsql.Client, endpoint *flight.FlightEnd return nil, err } +func (c *cnxn) getSessionOptions(ctx context.Context) (map[string]interface{}, error) { + ctx = metadata.NewOutgoingContext(ctx, c.hdrs) + var header, trailer metadata.MD + rawOptions, err := c.cl.GetSessionOptions(ctx, &flight.GetSessionOptionsRequest{}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) + if err != nil { + return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "GetSessionOptions") + } + options := make(map[string]interface{}, len(rawOptions.SessionOptions)) + for k, rawValue := range rawOptions.SessionOptions { + switch v := rawValue.OptionValue.(type) { + case *flightproto.SessionOptionValue_BoolValue: + options[k] = v.BoolValue + case *flightproto.SessionOptionValue_DoubleValue: + options[k] = v.DoubleValue + case *flightproto.SessionOptionValue_Int64Value: + options[k] = v.Int64Value + case *flightproto.SessionOptionValue_StringValue: + options[k] = v.StringValue + case *flightproto.SessionOptionValue_StringListValue_: + if v.StringListValue.Values == nil { + options[k] = make([]string, 0) + } else { + options[k] = v.StringListValue.Values + } + case nil: + options[k] = nil + default: + return nil, adbc.Error{ + Code: adbc.StatusNotImplemented, + Msg: fmt.Sprintf("[FlightSQL] Unknown session option type %#v", rawValue), + } + } + } + return options, nil +} + +type unsetSessionOption struct{} + +func (c *cnxn) setSessionOptions(ctx context.Context, key string, val interface{}) error { + req := flight.SetSessionOptionsRequest{} + + req.SessionOptions = make(map[string]*flight.SessionOptionValue) + switch v := val.(type) { + case bool: + req.SessionOptions[key] = &flight.SessionOptionValue{ + OptionValue: &flightproto.SessionOptionValue_BoolValue{BoolValue: v}, + } + case float64: + req.SessionOptions[key] = &flight.SessionOptionValue{ + OptionValue: &flightproto.SessionOptionValue_DoubleValue{DoubleValue: v}, + } + case int64: + req.SessionOptions[key] = &flight.SessionOptionValue{ + OptionValue: &flightproto.SessionOptionValue_Int64Value{Int64Value: v}, + } + case string: + req.SessionOptions[key] = &flight.SessionOptionValue{ + OptionValue: &flightproto.SessionOptionValue_StringValue{StringValue: v}, + } + case []string: + req.SessionOptions[key] = &flight.SessionOptionValue{ + OptionValue: &flightproto.SessionOptionValue_StringListValue_{StringListValue: &flightproto.SessionOptionValue_StringListValue{Values: v}}, + } + case unsetSessionOption: + req.SessionOptions[key] = &flight.SessionOptionValue{} + default: + panic("unimplemented case") + } + + var header, trailer metadata.MD + errors, err := c.cl.SetSessionOptions(ctx, &req, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) + if err != nil { + return adbcFromFlightStatusWithDetails(err, header, trailer, "GetSessionOptions") + } + if len(errors.Errors) > 0 { + msg := strings.Builder{} + _, err := msg.WriteString("[Flight SQL] Could not set option(s) ") + if err != nil { + return err + } + + first := true + for k, v := range errors.Errors { + if !first { + _, err = msg.WriteString(", ") + if err != nil { + return err + } + } + first = false + + _, err = msg.WriteString("'") + if err != nil { + return err + } + + _, err = msg.WriteString(k) + if err != nil { + return err + } + + _, err = msg.WriteString("' (") + if err != nil { + return err + } + errmsg := "unknown error" + switch v.Value { + case flightproto.SetSessionOptionsResult_INVALID_NAME: + errmsg = "invalid name" + case flightproto.SetSessionOptionsResult_INVALID_VALUE: + errmsg = "invalid value" + case flightproto.SetSessionOptionsResult_ERROR: + errmsg = "error setting option" + } + _, err = msg.WriteString(errmsg) + if err != nil { + return err + } + _, err = msg.WriteString(")") + if err != nil { + return err + } + } + + return adbc.Error{ + Msg: msg.String(), + Code: adbc.StatusInvalidArgument, + } + } + return nil +} + +func getSessionOption[T any](options map[string]interface{}, key string, defaultVal T, valueType string) (T, error) { + rawValue, ok := options[key] + if !ok { + return defaultVal, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] unknown session option '%s'", key), + Code: adbc.StatusNotFound, + } + } + value, ok := rawValue.(T) + if !ok { + return defaultVal, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] session option %s=%#v is not %s value", key, rawValue, valueType), + Code: adbc.StatusNotFound, + } + } + return value, nil +} + func (c *cnxn) GetOption(key string) (string, error) { if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) { name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix) @@ -114,16 +266,97 @@ func (c *cnxn) GetOption(key string) (string, error) { return adbc.OptionValueEnabled, nil } case adbc.OptionKeyCurrentCatalog: + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + if catalog, ok := options["catalog"]; ok { + if val, ok := catalog.(string); ok { + return val, nil + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[FlightSQL] Server returned non-string catalog %#v", catalog), + Code: adbc.StatusInternal, + } + } return "", adbc.Error{ - Msg: "[Flight SQL] current catalog not supported", + Msg: "[FlightSQL] current catalog not supported", Code: adbc.StatusNotFound, } case adbc.OptionKeyCurrentDbSchema: + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + if schema, ok := options["schema"]; ok { + if val, ok := schema.(string); ok { + return val, nil + } + return "", adbc.Error{ + Msg: fmt.Sprintf("[FlightSQL] Server returned non-string schema %#v", schema), + Code: adbc.StatusInternal, + } + } return "", adbc.Error{ - Msg: "[Flight SQL] current schema not supported", + Msg: "[FlightSQL] current schema not supported", Code: adbc.StatusNotFound, } + case OptionSessionOptions: + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + encoded, err := json.Marshal(options) + if err != nil { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Could not encode option values: %s", err.Error()), + Code: adbc.StatusInternal, + } + } + return string(encoded), nil + } + if strings.HasPrefix(key, OptionSessionOptionPrefix) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + name := key[len(OptionSessionOptionPrefix):] + return getSessionOption(options, name, "", "a string") + } + if strings.HasPrefix(key, OptionBoolSessionOptionPrefix) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + name := key[len(OptionBoolSessionOptionPrefix):] + v, err := getSessionOption(options, name, false, "a boolean") + if err != nil { + return "", err + } + if v { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil + } + if strings.HasPrefix(key, OptionStringListSessionOptionPrefix) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return "", err + } + name := key[len(OptionStringListSessionOptionPrefix):] + v, err := getSessionOption[[]string](options, name, nil, "a string list") + if err != nil { + return "", err + } + encoded, err := json.Marshal(v) + if err != nil { + return "", adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Could not encode option value: %s", err.Error()), + Code: adbc.StatusInternal, + } + } + return string(encoded), nil } return "", adbc.Error{ @@ -133,6 +366,22 @@ func (c *cnxn) GetOption(key string) (string, error) { } func (c *cnxn) GetOptionBytes(key string) ([]byte, error) { + switch key { + case OptionSessionOptions: + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return nil, err + } + encoded, err := json.Marshal(options) + if err != nil { + return nil, adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] Could not encode option values: %s", err.Error()), + Code: adbc.StatusInternal, + } + } + return encoded, nil + } + return nil, adbc.Error{ Msg: "[Flight SQL] unknown connection option", Code: adbc.StatusNotFound, @@ -152,6 +401,14 @@ func (c *cnxn) GetOptionInt(key string) (int64, error) { } return int64(val), nil } + if strings.HasPrefix(key, OptionSessionOptionPrefix) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return 0, err + } + name := key[len(OptionSessionOptionPrefix):] + return getSessionOption(options, name, int64(0), "an integer") + } return 0, adbc.Error{ Msg: "[Flight SQL] unknown connection option", @@ -168,6 +425,14 @@ func (c *cnxn) GetOptionDouble(key string) (float64, error) { case OptionTimeoutUpdate: return c.timeouts.updateTimeout.Seconds(), nil } + if strings.HasPrefix(key, OptionSessionOptionPrefix) { + options, err := c.getSessionOptions(context.Background()) + if err != nil { + return 0, err + } + name := key[len(OptionSessionOptionPrefix):] + return getSessionOption(options, name, float64(0.0), "a floating-point") + } return 0.0, adbc.Error{ Msg: "[Flight SQL] unknown connection option", @@ -235,13 +500,50 @@ func (c *cnxn) SetOption(key, value string) error { } } return nil + case adbc.OptionKeyCurrentCatalog: + return c.setSessionOptions(context.Background(), "catalog", value) + case adbc.OptionKeyCurrentDbSchema: + return c.setSessionOptions(context.Background(), "schema", value) + } - default: - return adbc.Error{ - Msg: "[Flight SQL] unknown connection option", - Code: adbc.StatusNotImplemented, + if strings.HasPrefix(key, OptionSessionOptionPrefix) { + name := key[len(OptionSessionOptionPrefix):] + return c.setSessionOptions(context.Background(), name, value) + } + if strings.HasPrefix(key, OptionBoolSessionOptionPrefix) { + name := key[len(OptionBoolSessionOptionPrefix):] + switch value { + case adbc.OptionValueEnabled: + return c.setSessionOptions(context.Background(), name, true) + case adbc.OptionValueDisabled: + return c.setSessionOptions(context.Background(), name, false) + default: + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid boolean session option value %s=%s", name, value), + Code: adbc.StatusNotImplemented, + } } } + if strings.HasPrefix(key, OptionStringListSessionOptionPrefix) { + name := key[len(OptionStringListSessionOptionPrefix):] + stringlist := make([]string, 0) + if err := json.Unmarshal([]byte(value), &stringlist); err != nil { + return adbc.Error{ + Msg: fmt.Sprintf("[Flight SQL] invalid string list session option value %s=%s: %s", name, value, err.Error()), + Code: adbc.StatusNotImplemented, + } + } + return c.setSessionOptions(context.Background(), name, stringlist) + } + if strings.HasPrefix(key, OptionEraseSessionOptionPrefix) { + name := key[len(OptionEraseSessionOptionPrefix):] + return c.setSessionOptions(context.Background(), name, unsetSessionOption{}) + } + + return adbc.Error{ + Msg: "[Flight SQL] unknown connection option", + Code: adbc.StatusNotImplemented, + } } func (c *cnxn) SetOptionBytes(key string, value []byte) error { @@ -256,6 +558,10 @@ func (c *cnxn) SetOptionInt(key string, value int64) error { case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate: return c.timeouts.setTimeout(key, float64(value)) } + if strings.HasPrefix(key, OptionSessionOptionPrefix) { + name := key[len(OptionSessionOptionPrefix):] + return c.setSessionOptions(context.Background(), name, value) + } return adbc.Error{ Msg: "[Flight SQL] unknown connection option", @@ -272,6 +578,10 @@ func (c *cnxn) SetOptionDouble(key string, value float64) error { case OptionTimeoutUpdate: return c.timeouts.setTimeout(key, value) } + if strings.HasPrefix(key, OptionSessionOptionPrefix) { + name := key[len(OptionSessionOptionPrefix):] + return c.setSessionOptions(context.Background(), name, value) + } return adbc.Error{ Msg: "[Flight SQL] unknown connection option", @@ -927,7 +1237,18 @@ func (c *cnxn) Close() error { } } - err := c.cl.Close() + ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs) + var header, trailer metadata.MD + _, err := c.cl.CloseSession(ctx, &flight.CloseSessionRequest{}, grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts) + if err != nil { + grpcStatus := grpcstatus.Convert(err) + // Ignore unimplemented + if grpcStatus.Code() != grpccodes.Unimplemented { + return adbcFromFlightStatusWithDetails(err, header, trailer, "CloseSession") + } + } + + err = c.cl.Close() c.cl = nil return adbcFromFlightStatus(err, "Close") } diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go index df1ae688b4..d437f0829b 100644 --- a/go/adbc/driver/flightsql/flightsql_driver.go +++ b/go/adbc/driver/flightsql/flightsql_driver.go @@ -45,23 +45,28 @@ import ( ) const ( - OptionAuthority = "adbc.flight.sql.client_option.authority" - OptionMTLSCertChain = "adbc.flight.sql.client_option.mtls_cert_chain" - OptionMTLSPrivateKey = "adbc.flight.sql.client_option.mtls_private_key" - OptionSSLOverrideHostname = "adbc.flight.sql.client_option.tls_override_hostname" - OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify" - OptionSSLRootCerts = "adbc.flight.sql.client_option.tls_root_certs" - OptionWithBlock = "adbc.flight.sql.client_option.with_block" - OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size" - OptionAuthorizationHeader = "adbc.flight.sql.authorization_header" - OptionTimeoutConnect = "adbc.flight.sql.rpc.timeout_seconds.connect" - OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch" - OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query" - OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update" - OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header." - OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware" - OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info" - infoDriverName = "ADBC Flight SQL Driver - Go" + OptionAuthority = "adbc.flight.sql.client_option.authority" + OptionMTLSCertChain = "adbc.flight.sql.client_option.mtls_cert_chain" + OptionMTLSPrivateKey = "adbc.flight.sql.client_option.mtls_private_key" + OptionSSLOverrideHostname = "adbc.flight.sql.client_option.tls_override_hostname" + OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify" + OptionSSLRootCerts = "adbc.flight.sql.client_option.tls_root_certs" + OptionWithBlock = "adbc.flight.sql.client_option.with_block" + OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size" + OptionAuthorizationHeader = "adbc.flight.sql.authorization_header" + OptionTimeoutConnect = "adbc.flight.sql.rpc.timeout_seconds.connect" + OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch" + OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query" + OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update" + OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header." + OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware" + OptionSessionOptions = "adbc.flight.sql.session.options" + OptionSessionOptionPrefix = "adbc.flight.sql.session.option." + OptionEraseSessionOptionPrefix = "adbc.flight.sql.session.optionerase." + OptionBoolSessionOptionPrefix = "adbc.flight.sql.session.optionbool." + OptionStringListSessionOptionPrefix = "adbc.flight.sql.session.optionstringlist." + OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info" + infoDriverName = "ADBC Flight SQL Driver - Go" ) var ( diff --git a/go/adbc/go.mod b/go/adbc/go.mod index 7e7b605ea6..93c96af4b8 100644 --- a/go/adbc/go.mod +++ b/go/adbc/go.mod @@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc go 1.19 require ( - github.com/apache/arrow/go/v16 v16.0.0-20240129203910-c2ca9bcedeb0 + github.com/apache/arrow/go/v16 v16.0.0-20240307132415-1c9a3122c603 github.com/bluele/gcache v0.0.2 github.com/golang/protobuf v1.5.3 github.com/google/uuid v1.6.0 diff --git a/go/adbc/go.sum b/go/adbc/go.sum index 75377ffe63..0911d39f76 100644 --- a/go/adbc/go.sum +++ b/go/adbc/go.sum @@ -27,6 +27,8 @@ github.com/apache/arrow/go/v14 v14.0.2 h1:N8OkaJEOfI3mEZt07BIkvo4sC6XDbL+48MBPWO github.com/apache/arrow/go/v14 v14.0.2/go.mod h1:u3fgh3EdgN/YQ8cVQRguVW3R+seMybFg8QBQ5LU+eBY= github.com/apache/arrow/go/v16 v16.0.0-20240129203910-c2ca9bcedeb0 h1:ooLFCCZ/sq3KDyrcFBxWweB1wTr1oAIgjj1+Zl3WsRw= github.com/apache/arrow/go/v16 v16.0.0-20240129203910-c2ca9bcedeb0/go.mod h1:+HkSDKotr3KDBxj7gTVgj8Egy18Y1ECzQdnY5XsXwlQ= +github.com/apache/arrow/go/v16 v16.0.0-20240307132415-1c9a3122c603 h1:UOXjIpzPxFAsxrtqUa+e8yuVdhMklFi+Uyo6oB+sDK4= +github.com/apache/arrow/go/v16 v16.0.0-20240307132415-1c9a3122c603/go.mod h1:+HkSDKotr3KDBxj7gTVgj8Egy18Y1ECzQdnY5XsXwlQ= github.com/apache/thrift v0.17.0 h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo= github.com/apache/thrift v0.17.0/go.mod h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q= github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU=