From 631c68a3068ab36487469d9aaebe4ab9bc4cbff7 Mon Sep 17 00:00:00 2001 From: Sven Rebhan <36194019+srebhan@users.noreply.github.com> Date: Wed, 12 Apr 2023 20:21:56 +0200 Subject: [PATCH] GH-35042: [Go][FlightSQL driver] Add TLS configuration (#35051) This PR adds the ability to enable and customize TLS parameters for Golang's `database/sql` driver. * Closes: #35042 Lead-authored-by: Sven Rebhan Co-authored-by: Sven Rebhan <36194019+srebhan@users.noreply.github.com> Co-authored-by: Kemal <223029+disq@users.noreply.github.com> Signed-off-by: Matt Topol --- go/arrow/flight/flightsql/driver/README.md | 81 +++- go/arrow/flight/flightsql/driver/config.go | 90 +++- .../flight/flightsql/driver/config_test.go | 427 ++++++++++++++++++ go/arrow/flight/flightsql/driver/driver.go | 7 - go/arrow/flight/flightsql/driver/errors.go | 26 ++ 5 files changed, 618 insertions(+), 13 deletions(-) create mode 100644 go/arrow/flight/flightsql/driver/config_test.go create mode 100644 go/arrow/flight/flightsql/driver/errors.go diff --git a/go/arrow/flight/flightsql/driver/README.md b/go/arrow/flight/flightsql/driver/README.md index cfb33ba2c6a5d..f81cb9250e1c9 100644 --- a/go/arrow/flight/flightsql/driver/README.md +++ b/go/arrow/flight/flightsql/driver/README.md @@ -33,9 +33,9 @@ connection pooling, transactions combined with ease of use (see (#usage)). --------------------------------------- -## Prerequisits +## Prerequisites -* Go 1.19+ +* Go 1.17+ * Installation via `go get -u github.com/apache/arrow/go/v12/arrow/flight/flightsql` * Backend speaking FlightSQL @@ -111,6 +111,23 @@ to limit the maximum time an operation can take. This prevents calls that wait forever, e.g. if the backend is down or a query is taking very long. When not set, the driver will use an _infinite_ timeout. +#### `tls` + +The `tls` parameter allows to enable and customize Transport-Layer-Security +settings. There are some special values for the parameters: + +* `disabled` or `false` will disable TLS for this server connection. In this + case all other settings are ignored. +* `enabled` or `true` will force TLS for this server connection. In this case + the system settings for trusted CAs etc will be used. +* `skip-verify` will enable TLS for this server connection but will not verify + the server certificate. **This is a security risk and should not be used!** + +Any other value will be interpreted as the name of a custom configuration. Those +configurations must be registered either by +[creating the DSN from configuration](#driver-config-usage) or by calling +`RegisterTLSConfig()` (see [TLS setup](#tls-setup) for details). + ## Driver config usage Alternatively to specifying the DSN directly you can fill the `DriverConfig` @@ -148,4 +165,62 @@ func main() { ## TLS setup -Currently TLS is not yet supported and will be added later. +By specifying the [`tls` parameter](#tls) you can enable +Transport-Layer-Security. Using `tls=enabled` the system settings are used for +verifying the server's certificate. Custom TLS configurations, e.g. when using +self-signed certificates, are referenced by a user-selected name. The underlying +TLS configuration needs to be registered (using the same name) in two ways. + +### TLS setup using `DriverConfig` + +The first way is to create a `DriverConfig` with the `TLSConfig` field set to +the custom config and `TLSConfigName` set to the chosen name. For example + +```golang + ... + + config := flightsql.DriverConfig{ + Address: "localhost:12345", + TLSEnabled: true, + TLSConfigName: "myconfig", + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + } + dsn := config.DSN() + + ... +``` + +will enable TLS forcing the minimum TLS version to 1.2. This custom config will +be registered with the name `myconfig` and the resulting DSN reads + +```text +flightsql://localhost:12345?tls=myconfig` +``` + +If the `TLSConfigName` is omitted a random unique name (UUID) is generated and +referenced in the DSN. This prevents errors from using an already registered +name leading to errors. + +### TLS setup using manual registration + +The second alternative is the manual registration of the custom TLS +configuration. In this case you need to call `RegisterTLSConfig()` in your code + +```golang + myconfig := &tls.Config{MinVersion: tls.VersionTLS12} + if err := flightsql.RegisterTLSConfig("myconfig", myconfig); err != nil { + ... + } + dsn := "flightsql://localhost:12345?tls=myconfig" + + ... +``` + +This will register the custom configuration, constraining the minimim TLS +version, as `myconfig` and then references the registered configuration by +name in the DSN. You can reuse the same TLS configuration by registering once +and then reference in multiple DSNs. Registering multiple configurations with +the same name will throw an error to prevent unintended side-effects due to the +driver-global registry. diff --git a/go/arrow/flight/flightsql/driver/config.go b/go/arrow/flight/flightsql/driver/config.go index d4a785dc6b760..9f1d56a31d582 100644 --- a/go/arrow/flight/flightsql/driver/config.go +++ b/go/arrow/flight/flightsql/driver/config.go @@ -19,9 +19,53 @@ import ( "crypto/tls" "fmt" "net/url" + "sync" "time" + + "github.com/google/uuid" +) + +// TLS configuration registry +var ( + tlsConfigRegistry = map[string]*tls.Config{ + "skip-verify": {InsecureSkipVerify: true}, + } + tlsRegistryMutex sync.Mutex ) +func RegisterTLSConfig(name string, cfg *tls.Config) error { + tlsRegistryMutex.Lock() + defer tlsRegistryMutex.Unlock() + + // Prevent name collisions + if _, found := tlsConfigRegistry[name]; found { + return ErrRegistryEntryExists + } + tlsConfigRegistry[name] = cfg + + return nil +} + +func UnregisterTLSConfig(name string) error { + tlsRegistryMutex.Lock() + defer tlsRegistryMutex.Unlock() + + if _, found := tlsConfigRegistry[name]; !found { + return ErrRegistryNoEntry + } + + delete(tlsConfigRegistry, name) + return nil +} + +func GetTLSConfig(name string) (*tls.Config, bool) { + tlsRegistryMutex.Lock() + defer tlsRegistryMutex.Unlock() + + cfg, found := tlsConfigRegistry[name] + return cfg, found +} + type DriverConfig struct { Address string Username string @@ -30,14 +74,15 @@ type DriverConfig struct { Timeout time.Duration Params map[string]string - TLSEnabled bool - TLSConfig *tls.Config + TLSEnabled bool + TLSConfigName string + TLSConfig *tls.Config } func NewDriverConfigFromDSN(dsn string) (*DriverConfig, error) { u, err := url.Parse(dsn) if err != nil { - return nil, err + return nil, fmt.Errorf("invalid URL: %w", err) } // Sanity checks on the given connection string @@ -83,6 +128,21 @@ func NewDriverConfigFromDSN(dsn string) (*DriverConfig, error) { if err != nil { return nil, err } + case "tls": + switch v { + case "true", "enabled": + config.TLSEnabled = true + case "false", "disabled": + config.TLSEnabled = false + default: + config.TLSEnabled = true + config.TLSConfigName = v + cfg, found := GetTLSConfig(config.TLSConfigName) + if !found { + return nil, fmt.Errorf("%q TLS %w", config.TLSConfigName, ErrRegistryNoEntry) + } + config.TLSConfig = cfg + } default: config.Params[key] = v } @@ -112,6 +172,30 @@ func (config *DriverConfig) DSN() string { if config.Timeout > 0 { values.Add("timeout", config.Timeout.String()) } + if config.TLSEnabled { + switch config.TLSConfigName { + case "skip-verify": + values.Add("tls", "skip-verify") + case "": + // Use system defaults if no config is given + if config.TLSConfig == nil { + values.Add("tls", "enabled") + break + } + // We got a custom TLS configuration but no name, create a unique one + config.TLSConfigName = uuid.NewString() + fallthrough + default: + values.Add("tls", config.TLSConfigName) + if config.TLSConfig != nil { + // Ignore the returned error as we do not care if the config + // was registered before. If this fails and the config is not + // yet registered, the driver will error out when parsing the + // DSN. + _ = RegisterTLSConfig(config.TLSConfigName, config.TLSConfig) + } + } + } for k, v := range config.Params { values.Add(k, v) } diff --git a/go/arrow/flight/flightsql/driver/config_test.go b/go/arrow/flight/flightsql/driver/config_test.go new file mode 100644 index 0000000000000..802cf82b82d46 --- /dev/null +++ b/go/arrow/flight/flightsql/driver/config_test.go @@ -0,0 +1,427 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package driver_test + +import ( + "crypto/tls" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver" +) + +func TestConfigTLSRegistry(t *testing.T) { + const cfgname = "bananarama" + + // Check if the 'skip-verify' entry exists + expected := &tls.Config{InsecureSkipVerify: true} + actual, found := driver.GetTLSConfig("skip-verify") + require.True(t, found) + require.EqualValues(t, expected, actual) + + // Make sure the testing entry does not exist + _, found = driver.GetTLSConfig(cfgname) + require.False(t, found) + + // Register a new expected config and check it contains the right config + expected = &tls.Config{ + ServerName: "myserver.company.org", + MinVersion: tls.VersionTLS12, + } + require.NoError(t, driver.RegisterTLSConfig(cfgname, expected)) + actual, found = driver.GetTLSConfig(cfgname) + require.True(t, found) + require.EqualValues(t, expected, actual) + + // Registering the config again will fail + require.ErrorIs(t, driver.RegisterTLSConfig(cfgname, expected), driver.ErrRegistryEntryExists) + + // Unregister the config + require.NoError(t, driver.UnregisterTLSConfig(cfgname)) + _, found = driver.GetTLSConfig(cfgname) + require.False(t, found) + + // Unregistering a non-existing config fails + require.ErrorIs(t, driver.UnregisterTLSConfig(cfgname), driver.ErrRegistryNoEntry) +} + +func TestConfigFromDSNInvalid(t *testing.T) { + testcases := []struct { + name string + dsn string + expected string + }{ + { + name: "empty config", + expected: "invalid scheme", + }, + { + name: "invalid url", + dsn: "flightsql://my host", + expected: "invalid URL", + }, + { + name: "invalid path", + dsn: "flightsql://127.0.0.1/someplace", + expected: "unexpected path", + }, + { + name: "invalid timeout", + dsn: "flightsql://127.0.0.1?timeout=2", + expected: "missing unit in duration", + }, + { + name: "multiple parameters (timeout)", + dsn: "flightsql://127.0.0.1:12345?timeout=123s&timeout=4s", + expected: "too many values", + }, + { + name: "multiple parameters (other)", + dsn: "flightsql://127.0.0.1:12345?foo=1&bar=true&foo=yes", + expected: "too many values", + }, + { + name: "TLS unregistered config", + dsn: "flightsql://127.0.0.1:12345?tls=mycfg", + expected: "TLS entry not registered", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + actual, err := driver.NewDriverConfigFromDSN(tt.dsn) + require.ErrorContains(t, err, tt.expected) + require.Nil(t, actual) + }) + } +} + +func TestConfigFromDSN(t *testing.T) { + // Register a custom TLS config for testing + tlscfg := &tls.Config{ + ServerName: "myserver.company.org", + MinVersion: tls.VersionTLS12, + } + require.NoError(t, driver.RegisterTLSConfig("mycfg", tlscfg)) + + // Define the test-cases + testcases := []struct { + name string + dsn string + expected *driver.DriverConfig + }{ + { + name: "no authentication", + dsn: "flightsql://127.0.0.1:12345", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Params: make(map[string]string), + }, + }, + { + name: "username only authentication", + dsn: "flightsql://peter@127.0.0.1:12345", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Username: "peter", + Params: make(map[string]string), + }, + }, + { + name: "username and password authentication", + dsn: "flightsql://peter:parker@127.0.0.1:12345", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Username: "peter", + Password: "parker", + Params: make(map[string]string), + }, + }, + { + name: "token authentication", + dsn: "flightsql://127.0.0.1:12345?token=012345abcde6789fgh", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Token: "012345abcde6789fgh", + Params: make(map[string]string), + }, + }, + { + name: "timeout", + dsn: "flightsql://127.0.0.1:12345?timeout=123s", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Timeout: 123 * time.Second, + Params: make(map[string]string), + }, + }, + { + name: "custom parameters", + dsn: "flightsql://127.0.0.1:12345?timeout=200ms&database=mydb&pi=3.14", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Timeout: 200 * time.Millisecond, + Params: map[string]string{ + "database": "mydb", + "pi": "3.14", + }, + }, + }, + { + name: "TLS explicitly disabled", + dsn: "flightsql://127.0.0.1:12345?tls=disabled", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Params: make(map[string]string), + }, + }, + { + name: "TLS explicitly disabled (false)", + dsn: "flightsql://127.0.0.1:12345?tls=false", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Params: make(map[string]string), + }, + }, + { + name: "TLS system settings", + dsn: "flightsql://127.0.0.1:12345?tls=enabled", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + Params: make(map[string]string), + }, + }, + { + name: "TLS system settings (true)", + dsn: "flightsql://127.0.0.1:12345?tls=true", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + Params: make(map[string]string), + }, + }, + { + name: "TLS insecure skip-verify", + dsn: "flightsql://127.0.0.1:12345?tls=skip-verify", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + TLSConfigName: "skip-verify", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + Params: make(map[string]string), + }, + }, + { + name: "TLS custom config", + dsn: "flightsql://127.0.0.1:12345?tls=mycfg", + expected: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + TLSConfigName: "mycfg", + TLSConfig: tlscfg, + Params: make(map[string]string), + }, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + actual, err := driver.NewDriverConfigFromDSN(tt.dsn) + require.NoError(t, err) + require.EqualValues(t, tt.expected, actual) + }) + } +} + +func TestDSNFromConfig(t *testing.T) { + // Define the test-cases + testcases := []struct { + name string + expected string + drvcfg *driver.DriverConfig + }{ + { + name: "no authentication", + expected: "flightsql://127.0.0.1:12345", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Params: make(map[string]string), + }, + }, + { + name: "username only authentication", + expected: "flightsql://peter@127.0.0.1:12345", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Username: "peter", + Params: make(map[string]string), + }, + }, + { + name: "username and password authentication", + expected: "flightsql://peter:parker@127.0.0.1:12345", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Username: "peter", + Password: "parker", + Params: make(map[string]string), + }, + }, + { + name: "token authentication", + expected: "flightsql://127.0.0.1:12345?token=012345abcde6789fgh", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Token: "012345abcde6789fgh", + Params: make(map[string]string), + }, + }, + { + name: "timeout", + expected: "flightsql://127.0.0.1:12345?timeout=3s", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Timeout: 3 * time.Second, + Params: make(map[string]string), + }, + }, + { + name: "custom parameters", + expected: "flightsql://127.0.0.1:12345?database=mydb&pi=3.14&timeout=20ms", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Timeout: 20 * time.Millisecond, + Params: map[string]string{ + "database": "mydb", + "pi": "3.14", + }, + }, + }, + { + name: "TLS disabled", + expected: "flightsql://127.0.0.1:12345", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + Params: make(map[string]string), + }, + }, + { + name: "TLS system settings", + expected: "flightsql://127.0.0.1:12345?tls=enabled", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + Params: make(map[string]string), + }, + }, + { + name: "TLS insecure skip-verify", + expected: "flightsql://127.0.0.1:12345?tls=skip-verify", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + TLSConfigName: "skip-verify", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + Params: make(map[string]string), + }, + }, + { + name: "TLS disabled", + expected: "flightsql://127.0.0.1:12345", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: false, + TLSConfigName: "a random cfg", + TLSConfig: &tls.Config{ + ServerName: "myserver.company.org", + MinVersion: tls.VersionTLS12, + }, + Params: make(map[string]string), + }, + }, + { + name: "TLS custom config", + expected: "flightsql://127.0.0.1:12345?tls=mycfg", + drvcfg: &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + TLSConfigName: "mycfg", + TLSConfig: &tls.Config{ + ServerName: "myserver.company.org", + MinVersion: tls.VersionTLS12, + }, + Params: make(map[string]string), + }, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + actual := tt.drvcfg.DSN() + require.Equal(t, tt.expected, actual) + }) + } +} + +func TestDSNFromConfigCustomTLS(t *testing.T) { + expected := "flightsql://127.0.0.1:12345?tls=mycustomcfg" + + tlscfg := &tls.Config{ + ServerName: "myserver.company.org", + MinVersion: tls.VersionTLS12, + } + + drvcfg := &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + TLSConfigName: "mycustomcfg", + TLSConfig: tlscfg, + Params: make(map[string]string), + } + + require.Equal(t, expected, drvcfg.DSN()) + cfg, found := driver.GetTLSConfig("mycustomcfg") + require.True(t, found) + require.EqualValues(t, tlscfg, cfg) +} + +func TestDSNFromConfigUnnamedCustomTLS(t *testing.T) { + expected := "flightsql://127.0.0.1:12345?tls=" + + tlscfg := &tls.Config{ + ServerName: "myserver.company.org", + MinVersion: tls.VersionTLS12, + } + + drvcfg := &driver.DriverConfig{ + Address: "127.0.0.1:12345", + TLSEnabled: true, + TLSConfig: tlscfg, + Params: make(map[string]string), + } + + actual := drvcfg.DSN() + require.NotEmpty(t, drvcfg.TLSConfigName) + // Get the generated UUID and add it to the expected DSN + expected += drvcfg.TLSConfigName + require.Equal(t, expected, actual) + cfg, found := driver.GetTLSConfig(drvcfg.TLSConfigName) + require.True(t, found) + require.EqualValues(t, tlscfg, cfg) +} diff --git a/go/arrow/flight/flightsql/driver/driver.go b/go/arrow/flight/flightsql/driver/driver.go index 1bdbfc029ab93..72d6ea344e481 100644 --- a/go/arrow/flight/flightsql/driver/driver.go +++ b/go/arrow/flight/flightsql/driver/driver.go @@ -19,7 +19,6 @@ import ( "context" "database/sql" "database/sql/driver" - "errors" "fmt" "io" "sort" @@ -35,12 +34,6 @@ import ( "google.golang.org/grpc/credentials/insecure" ) -var ( - ErrNotSupported = errors.New("not supported") - ErrOutOfRange = errors.New("index out of range") - ErrTransactionInProgress = errors.New("transaction still in progress") -) - type Rows struct { schema *arrow.Schema records []arrow.Record diff --git a/go/arrow/flight/flightsql/driver/errors.go b/go/arrow/flight/flightsql/driver/errors.go new file mode 100644 index 0000000000000..908dde4c3edc4 --- /dev/null +++ b/go/arrow/flight/flightsql/driver/errors.go @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package driver + +import "errors" + +var ( + ErrNotSupported = errors.New("not supported") + ErrOutOfRange = errors.New("index out of range") + ErrTransactionInProgress = errors.New("transaction still in progress") + ErrRegistryEntryExists = errors.New("entry already exists") + ErrRegistryNoEntry = errors.New("entry not registered") +)