From 065d524f62d3194f0e9b83ace377ef9e7ad7c722 Mon Sep 17 00:00:00 2001 From: Jim Bishopp Date: Tue, 1 Feb 2022 07:53:41 -0800 Subject: [PATCH] [v8] Client Certificate Authentication for GCP Cloud SQL (#10059) * Vendor go-mysql update * Client Certificate Authentication for GCP Cloud SQL (#9991) Allow users to secure GCP Cloud SQL instances by setting "Allow only SSL connections", which enforces client certificate authentication. This implementation does not require any configuration changes for Teleport users. Teleport will detect whether client certificate authentication is required and handle either case automatically. Client certificates are ephemeral. They are created for every connection by calling the GCP Cloud SQL API's GenerateEphemeralCert function. Certificates are only created when the destination Cloud SQL instance is configured to require client certificate authentication. The configuration is detected by requesting instance settings from the GCP Cloud SQL API on every connection attempt. A special case was implemented for MySQL. MySQL servers in GCP Cloud SQL do not trust the ephemeral certificate's CA but GCP Cloud Proxy does. To work around this issue, the implementation will connect to the MySQL Cloud Proxy port using a TLS dialer instead of the default MySQL port when client certificate authentication is required. The common.CloudClients interface and implementation now return an interface (GCPSQLAdminClient) from the GetGCPSQLAdminClient function instead of the GCP client's sqladmin.Service. Returning an interface simplified calling code and allowed for the client to be mocked for testing. Existing GCP Cloud SQL tests are configured to not require client certificate authentication by default. A new test named TestGCPRequireSSL was created to simulate client certificate authentication for both Postgres and MySQL. This required some minor changes to the test server code. A new ConnectWithDialer function was added to the github.com/gravitational/go-mysql fork. This function is available upstream in v1.4.0 but other changes upstream resulted in a number of errors and a panic processing network packets. So instead of upgrading, the dialer function was copied to the Teleport fork and a custom version was created instead: v1.1.1-teleport.1. --- go.mod | 2 +- go.sum | 4 +- lib/srv/db/access_test.go | 119 ++++++++++++++++ lib/srv/db/cloud/gcp.go | 68 ++++++++++ lib/srv/db/cloud/mocks.go | 25 ++++ lib/srv/db/common/auth.go | 9 +- lib/srv/db/common/cloud.go | 18 ++- lib/srv/db/common/gcp.go | 127 ++++++++++++++++++ lib/srv/db/common/test.go | 19 ++- lib/srv/db/mysql/engine.go | 69 +++++++++- lib/srv/db/mysql/test.go | 30 +++-- lib/srv/db/postgres/engine.go | 34 ++++- lib/srv/db/server.go | 24 ++-- .../siddontang/go-mysql/client/teleport.go | 49 +++++++ .../siddontang/go-mysql/server/conn.go | 12 +- .../go-mysql/server/handshake_resp.go | 2 +- .../go-mysql/server/initial_handshake.go | 2 +- .../siddontang/go-mysql/server/resp.go | 10 +- .../siddontang/go-mysql/server/teleport.go | 25 ++++ vendor/modules.txt | 2 +- 20 files changed, 580 insertions(+), 70 deletions(-) create mode 100644 lib/srv/db/cloud/gcp.go create mode 100644 lib/srv/db/common/gcp.go create mode 100644 vendor/github.com/siddontang/go-mysql/client/teleport.go create mode 100644 vendor/github.com/siddontang/go-mysql/server/teleport.go diff --git a/go.mod b/go.mod index 11a45fba59e71..bee4818246fb5 100644 --- a/go.mod +++ b/go.mod @@ -213,6 +213,6 @@ replace ( github.com/coreos/go-oidc => github.com/gravitational/go-oidc v0.0.5 github.com/gogo/protobuf => github.com/gravitational/protobuf v1.3.2-0.20201123192827-2b9fcfaffcbf github.com/gravitational/teleport/api => ./api - github.com/siddontang/go-mysql v1.1.0 => github.com/gravitational/go-mysql v1.1.1-0.20210212011549-886316308a77 + github.com/siddontang/go-mysql v1.1.0 => github.com/gravitational/go-mysql v1.1.1-teleport.1 github.com/sirupsen/logrus => github.com/gravitational/logrus v1.4.4-0.20210817004754-047e20245621 ) diff --git a/go.sum b/go.sum index 2cb6c96452aeb..97860e8a963cc 100644 --- a/go.sum +++ b/go.sum @@ -394,8 +394,8 @@ github.com/gravitational/configure v0.0.0-20180808141939-c3428bd84c23 h1:havbccu github.com/gravitational/configure v0.0.0-20180808141939-c3428bd84c23/go.mod h1:XL9nebvlfNVvRzRPWdDcWootcyA0l7THiH/A+W1233g= github.com/gravitational/form v0.0.0-20151109031454-c4048f792f70 h1:To76nCJtM3DI0mdq3nGLzXqTV1wNOJByxv01+u9/BxM= github.com/gravitational/form v0.0.0-20151109031454-c4048f792f70/go.mod h1:88hFR45MpUd23d2vNWE/dYtesU50jKsbz0I9kH7UaBY= -github.com/gravitational/go-mysql v1.1.1-0.20210212011549-886316308a77 h1:ivambM2XeST8qfxeSm+0Y8CP/DlNbS3o/9tSF2KtGFk= -github.com/gravitational/go-mysql v1.1.1-0.20210212011549-886316308a77/go.mod h1:re0JQZ1Cy5dVlIDGq0YksfDIla/GRZlxqOoC0XPSSGE= +github.com/gravitational/go-mysql v1.1.1-teleport.1 h1:062V8u0juCyUvpYMdkYch8JDDw7wf5rdhKaIfhnojDg= +github.com/gravitational/go-mysql v1.1.1-teleport.1/go.mod h1:re0JQZ1Cy5dVlIDGq0YksfDIla/GRZlxqOoC0XPSSGE= github.com/gravitational/go-oidc v0.0.5 h1:kxsCknoOZ+KqIAoYLLdHuQcvcc+SrQlnT7xxIM8oo6o= github.com/gravitational/go-oidc v0.0.5/go.mod h1:SevmOUNdOB0aD9BAIgjptZ6oHkKxMZZgA70nwPfgU/w= github.com/gravitational/kingpin v2.1.11-0.20190130013101-742f2714c145+incompatible h1:CfyZl3nyo9K5lLqOmqvl9/IElY1UCnOWKZiQxJ8HKdA= diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 6d185866b09b7..8e4286ec5e748 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -55,6 +55,7 @@ import ( "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" + sqladmin "google.golang.org/api/sqladmin/v1beta4" ) func TestMain(m *testing.M) { @@ -357,6 +358,70 @@ func TestAccessMySQLServerPacket(t *testing.T) { require.NoError(t, err) } +// TestGCPRequireSSL tests connecting to GCP Cloud SQL Postgres and MySQL +// databases with an ephemeral client certificate. +func TestGCPRequireSSL(t *testing.T) { + ctx := context.Background() + user := "alice" + testCtx := setupTestContext(ctx, t) + testCtx.createUserAndRole(ctx, t, user, "admin", []string{types.Wildcard}, []string{types.Wildcard}) + + // Generate ephemeral cert returned from mock GCP API. + ephemeralCert, err := common.MakeTestClientTLSCert(common.TestClientConfig{ + AuthClient: testCtx.authClient, + AuthServer: testCtx.authServer, + Cluster: testCtx.clusterName, + Username: user, + }) + require.NoError(t, err) + + // Setup database servers for Postgres and MySQL with a mock GCP API that + // will require SSL and return the ephemeral certificate created above. + testCtx.server = testCtx.setupDatabaseServer(ctx, t, agentParams{ + Databases: []types.Database{ + withCloudSQLPostgres("postgres", cloudSQLAuthToken)(t, ctx, testCtx), + withCloudSQLMySQLTLS("mysql", user, cloudSQLPassword)(t, ctx, testCtx), + }, + GCPSQL: &cloud.GCPSQLAdminClientMock{ + EphemeralCert: ephemeralCert, + DatabaseInstance: &sqladmin.DatabaseInstance{ + Settings: &sqladmin.Settings{ + IpConfiguration: &sqladmin.IpConfiguration{ + RequireSsl: true, + }, + }, + }, + }, + }) + go testCtx.startHandlingConnections() + + // Try to connect to postgres. + pgConn, err := testCtx.postgresClient(ctx, user, "postgres", "postgres", "postgres") + require.NoError(t, err) + + // Execute a query. + pgResult, err := pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) + require.Equal(t, []*pgconn.Result{postgres.TestQueryResponse}, pgResult) + + // Disconnect. + err = pgConn.Close(ctx) + require.NoError(t, err) + + // Try to connect to MySQL. + mysqlConn, err := testCtx.mysqlClient(user, "mysql", user) + require.NoError(t, err) + + // Execute a query. + mysqlResult, err := mysqlConn.Execute("select 1") + require.NoError(t, err) + require.Equal(t, mysql.TestQueryResponse, mysqlResult) + + // Disconnect. + err = mysqlConn.Close() + require.NoError(t, err) +} + // TestAccessMongoDB verifies access scenarios to a MongoDB database based // on the configured RBAC rules. func TestAccessMongoDB(t *testing.T) { @@ -985,12 +1050,25 @@ type agentParams struct { OnReconcile func(types.Databases) // NoStart indicates server should not be started. NoStart bool + // GCPSQL defines the GCP Cloud SQL mock to use for GCP API calls. + GCPSQL *cloud.GCPSQLAdminClientMock } func (p *agentParams) setDefaults(c *testContext) { if p.HostID == "" { p.HostID = c.hostID } + if p.GCPSQL == nil { + p.GCPSQL = &cloud.GCPSQLAdminClientMock{ + DatabaseInstance: &sqladmin.DatabaseInstance{ + Settings: &sqladmin.Settings{ + IpConfiguration: &sqladmin.IpConfiguration{ + RequireSsl: false, + }, + }, + }, + } + } } func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p agentParams) *Server { @@ -1056,6 +1134,7 @@ func (c *testContext) setupDatabaseServer(ctx context.Context, t *testing.T, p a RDS: &cloud.RDSMock{}, Redshift: &cloud.RedshiftMock{}, IAM: &cloud.IAMMock{}, + GCPSQL: p.GCPSQL, }, }) require.NoError(t, err) @@ -1317,6 +1396,46 @@ func withCloudSQLMySQL(name, authUser, authToken string) withDatabaseOption { } } +// withCloudSQLMySQLTLS creates a test MySQL server that simulates GCP Cloud SQL +// and requires client authentication using an ephemeral client certificate. +func withCloudSQLMySQLTLS(name, authUser, authToken string) withDatabaseOption { + return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { + mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{ + Name: name, + AuthClient: testCtx.authClient, + AuthUser: authUser, + AuthToken: authToken, + // Cloud SQL presented certificate must have : + // in its CN. + CN: "project-1:instance-1", + // Enable TLS listener. + ListenTLS: true, + }) + require.NoError(t, err) + go mysqlServer.Serve() + t.Cleanup(func() { mysqlServer.Close() }) + database, err := types.NewDatabaseV3(types.Metadata{ + Name: name, + }, types.DatabaseSpecV3{ + Protocol: defaults.ProtocolMySQL, + URI: net.JoinHostPort("localhost", mysqlServer.Port()), + DynamicLabels: dynamicLabels, + GCP: types.GCPCloudSQL{ + ProjectID: "project-1", + InstanceID: "instance-1", + }, + // Set CA cert to pass cert validation. + CACert: string(testCtx.hostCA.GetActiveKeys().TLS[0].Cert), + }) + require.NoError(t, err) + testCtx.mysql[name] = testMySQL{ + db: mysqlServer, + resource: database, + } + return database + } +} + func withAzureMySQL(name, authUser, authToken string) withDatabaseOption { return func(t *testing.T, ctx context.Context, testCtx *testContext) types.Database { mysqlServer, err := mysql.NewTestServer(common.TestServerConfig{ diff --git a/lib/srv/db/cloud/gcp.go b/lib/srv/db/cloud/gcp.go new file mode 100644 index 0000000000000..d1e01211b6242 --- /dev/null +++ b/lib/srv/db/cloud/gcp.go @@ -0,0 +1,68 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cloud + +import ( + "context" + "crypto/tls" + + "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/trace" +) + +// GetGCPRequireSSL requests settings for the project/instance in session from GCP +// and returns true when the instance requires SSL. An access denied error is +// returned when an unauthorized error is returned from GCP. +func GetGCPRequireSSL(ctx context.Context, sessionCtx *common.Session, gcpClient common.GCPSQLAdminClient) (requireSSL bool, err error) { + dbi, err := gcpClient.GetDatabaseInstance(ctx, sessionCtx) + if err != nil { + err = common.ConvertError(err) + if trace.IsAccessDenied(err) { + return false, trace.Wrap(err, `Could not get GCP database instance settings: + + %v + +Make sure Teleport db service has "Cloud SQL Admin" GCP IAM role, +or "cloudsql.instances.get" IAM permission.`, err) + } + return false, trace.Wrap(err, "Failed to get Cloud SQL instance information for %q.", common.GCPServerName(sessionCtx)) + } else if dbi.Settings == nil || dbi.Settings.IpConfiguration == nil { + return false, trace.BadParameter("Failed to find Cloud SQL settings for %q. GCP returned %+v.", common.GCPServerName(sessionCtx), dbi) + } + return dbi.Settings.IpConfiguration.RequireSsl, nil +} + +// AppendGCPClientCert calls the GCP API to generate an ephemeral certificate +// and adds it to the TLS config. An access denied error is returned when the +// generate call fails. +func AppendGCPClientCert(ctx context.Context, sessionCtx *common.Session, gcpClient common.GCPSQLAdminClient, tlsConfig *tls.Config) error { + cert, err := gcpClient.GenerateEphemeralCert(ctx, sessionCtx) + if err != nil { + err = common.ConvertError(err) + if trace.IsAccessDenied(err) { + return trace.Wrap(err, `Cloud not generate GCP ephemeral client certificate: + + %v + +Make sure Teleport db service has "Cloud SQL Admin" GCP IAM role, +or "cloudsql.sslCerts.createEphemeral" IAM permission.`, err) + } + return trace.Wrap(err, "Failed to generate GCP ephemeral client certificate for %q.", common.GCPServerName(sessionCtx)) + } + tlsConfig.Certificates = []tls.Certificate{*cert} + return nil +} diff --git a/lib/srv/db/cloud/mocks.go b/lib/srv/db/cloud/mocks.go index b27521141af97..099493bc11ed7 100644 --- a/lib/srv/db/cloud/mocks.go +++ b/lib/srv/db/cloud/mocks.go @@ -17,6 +17,9 @@ limitations under the License. package cloud import ( + "context" + "crypto/tls" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/iam" @@ -27,7 +30,9 @@ import ( "github.com/aws/aws-sdk-go/service/redshift/redshiftiface" "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/trace" + sqladmin "google.golang.org/api/sqladmin/v1beta4" ) // STSMock mocks AWS STS API. @@ -307,3 +312,23 @@ func (m *IAMMockUnauth) GetUserPolicyWithContext(ctx aws.Context, input *iam.Get func (m *IAMMockUnauth) PutUserPolicyWithContext(ctx aws.Context, input *iam.PutUserPolicyInput, options ...request.Option) (*iam.PutUserPolicyOutput, error) { return nil, trace.AccessDenied("unauthorized") } + +// GCPSQLAdminClientMock implements the common.GCPSQLAdminClient interface for tests. +type GCPSQLAdminClientMock struct { + // DatabaseInstance is returned from GetDatabaseInstance. + DatabaseInstance *sqladmin.DatabaseInstance + // EphemeralCert is returned from GenerateEphemeralCert. + EphemeralCert *tls.Certificate +} + +func (g *GCPSQLAdminClientMock) UpdateUser(ctx context.Context, sessionCtx *common.Session, user *sqladmin.User) error { + return nil +} + +func (g *GCPSQLAdminClientMock) GetDatabaseInstance(ctx context.Context, sessionCtx *common.Session) (*sqladmin.DatabaseInstance, error) { + return g.DatabaseInstance, nil +} + +func (g *GCPSQLAdminClientMock) GenerateEphemeralCert(ctx context.Context, sessionCtx *common.Session) (*tls.Certificate, error) { + return g.EphemeralCert, nil +} diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index 2f3fa663f768d..9828762d5951c 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -251,11 +251,8 @@ func (a *dbAuth) GetCloudSQLPassword(ctx context.Context, sessionCtx *Session) ( } // updateCloudSQLUser makes a request to Cloud SQL API to update the provided user. -func (a *dbAuth) updateCloudSQLUser(ctx context.Context, sessionCtx *Session, gcpCloudSQL *sqladmin.Service, user *sqladmin.User) error { - _, err := gcpCloudSQL.Users.Update( - sessionCtx.Database.GetGCP().ProjectID, - sessionCtx.Database.GetGCP().InstanceID, - user).Name(sessionCtx.DatabaseUser).Host("%").Context(ctx).Do() +func (a *dbAuth) updateCloudSQLUser(ctx context.Context, sessionCtx *Session, gcpCloudSQL GCPSQLAdminClient, user *sqladmin.User) error { + err := gcpCloudSQL.UpdateUser(ctx, sessionCtx, user) if err != nil { return trace.AccessDenied(`Could not update Cloud SQL user %q password: @@ -354,7 +351,7 @@ func (a *dbAuth) getTLSConfigVerifyFull(ctx context.Context, sessionCtx *Session // Cloud SQL server presented certificates encode instance names as // ":" in CommonName. This is verified against // the ServerName in a custom connection verification step (see below). - tlsConfig.ServerName = fmt.Sprintf("%v:%v", sessionCtx.Database.GetGCP().ProjectID, sessionCtx.Database.GetGCP().InstanceID) + tlsConfig.ServerName = GCPServerName(sessionCtx) // This just disables default verification. tlsConfig.InsecureSkipVerify = true // This will verify CN and cert chain on each connection. diff --git a/lib/srv/db/common/cloud.go b/lib/srv/db/common/cloud.go index 4ec87e0e5b939..3eda03edea84a 100644 --- a/lib/srv/db/common/cloud.go +++ b/lib/srv/db/common/cloud.go @@ -39,7 +39,6 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" "google.golang.org/api/option" - sqladmin "google.golang.org/api/sqladmin/v1beta4" "google.golang.org/grpc" ) @@ -58,7 +57,7 @@ type CloudClients interface { // GetGCPIAMClient returns GCP IAM client. GetGCPIAMClient(context.Context) (*gcpcredentials.IamCredentialsClient, error) // GetGCPSQLAdminClient returns GCP Cloud SQL Admin client. - GetGCPSQLAdminClient(context.Context) (*sqladmin.Service, error) + GetGCPSQLAdminClient(context.Context) (GCPSQLAdminClient, error) // GetAzureCredential returns Azure default token credential chain. GetAzureCredential() (azcore.TokenCredential, error) // Closer closes all initialized clients. @@ -78,7 +77,7 @@ type cloudClients struct { // gcpIAM is the cached GCP IAM client. gcpIAM *gcpcredentials.IamCredentialsClient // gcpSQLAdmin is the cached GCP Cloud SQL Admin client. - gcpSQLAdmin *sqladmin.Service + gcpSQLAdmin GCPSQLAdminClient // azureCredential is the cached Azure credential. azureCredential azcore.TokenCredential // mtx is used for locking. @@ -144,7 +143,7 @@ func (c *cloudClients) GetGCPIAMClient(ctx context.Context) (*gcpcredentials.Iam } // GetGCPSQLAdminClient returns GCP Cloud SQL Admin client. -func (c *cloudClients) GetGCPSQLAdminClient(ctx context.Context) (*sqladmin.Service, error) { +func (c *cloudClients) GetGCPSQLAdminClient(ctx context.Context) (GCPSQLAdminClient, error) { c.mtx.RLock() if c.gcpSQLAdmin != nil { defer c.mtx.RUnlock() @@ -211,14 +210,14 @@ func (c *cloudClients) initGCPIAMClient(ctx context.Context) (*gcpcredentials.Ia return gcpIAM, nil } -func (c *cloudClients) initGCPSQLAdminClient(ctx context.Context) (*sqladmin.Service, error) { +func (c *cloudClients) initGCPSQLAdminClient(ctx context.Context) (GCPSQLAdminClient, error) { c.mtx.Lock() defer c.mtx.Unlock() if c.gcpSQLAdmin != nil { // If some other thread already got here first. return c.gcpSQLAdmin, nil } logrus.Debug("Initializing GCP Cloud SQL Admin client.") - gcpSQLAdmin, err := sqladmin.NewService(ctx) + gcpSQLAdmin, err := NewGCPSQLAdminClient(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -248,6 +247,7 @@ type TestCloudClients struct { Redshift redshiftiface.RedshiftAPI IAM iamiface.IAMAPI STS stsiface.STSAPI + GCPSQL GCPSQLAdminClient } // GetAWSSession returns AWS session for the specified region. @@ -286,10 +286,8 @@ func (c *TestCloudClients) GetGCPIAMClient(ctx context.Context) (*gcpcredentials } // GetGCPSQLAdminClient returns GCP Cloud SQL Admin client. -func (c *TestCloudClients) GetGCPSQLAdminClient(ctx context.Context) (*sqladmin.Service, error) { - return sqladmin.NewService(ctx, - option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), // Insecure must be set for unauth client. - option.WithoutAuthentication()) +func (c *TestCloudClients) GetGCPSQLAdminClient(ctx context.Context) (GCPSQLAdminClient, error) { + return c.GCPSQL, nil } // GetAzureCredential returns default Azure token credential chain. diff --git a/lib/srv/db/common/gcp.go b/lib/srv/db/common/gcp.go new file mode 100644 index 0000000000000..37adec6ccc8b3 --- /dev/null +++ b/lib/srv/db/common/gcp.go @@ -0,0 +1,127 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package common + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "time" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/trace" + + sqladmin "google.golang.org/api/sqladmin/v1beta4" +) + +// GCPServerName returns the GCP database project and instance as ":". +func GCPServerName(sessionCtx *Session) string { + gcp := sessionCtx.Database.GetGCP() + return fmt.Sprintf("%s:%s", gcp.ProjectID, gcp.InstanceID) +} + +// GCPSQLAdminClient defines an interface providing access to the GCP Cloud SQL API. +type GCPSQLAdminClient interface { + // UpdateUser updates an existing user for the project/instance configured in a session. + UpdateUser(ctx context.Context, sessionCtx *Session, user *sqladmin.User) error + // GetDatabaseInstance returns database instance details for the project/instance + // configured in a session. + GetDatabaseInstance(ctx context.Context, sessionCtx *Session) (*sqladmin.DatabaseInstance, error) + // GenerateEphemeralCert returns a new client certificate with RSA key for the + // project/instance configured in a session. + GenerateEphemeralCert(ctx context.Context, sessionCtx *Session) (*tls.Certificate, error) +} + +// NewGCPSQLAdminClient returns a GCPSQLAdminClient interface wrapping sqladmin.Service. +func NewGCPSQLAdminClient(ctx context.Context) (GCPSQLAdminClient, error) { + service, err := sqladmin.NewService(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + return &gcpSQLAdminClient{service: service}, nil +} + +// gcpSQLAdminClient implements the GCPSQLAdminClient interface by wrapping +// sqladmin.Service. +type gcpSQLAdminClient struct { + service *sqladmin.Service +} + +// UpdateUser updates an existing user in a Cloud SQL for the project/instance +// configured in a session. +func (g *gcpSQLAdminClient) UpdateUser(ctx context.Context, sessionCtx *Session, user *sqladmin.User) error { + _, err := g.service.Users.Update( + sessionCtx.Database.GetGCP().ProjectID, + sessionCtx.Database.GetGCP().InstanceID, + user).Name(sessionCtx.DatabaseUser).Host("%").Context(ctx).Do() + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// GetDatabaseInstance returns database instance details from Cloud SQL for the +// project/instance configured in a session. +func (g *gcpSQLAdminClient) GetDatabaseInstance(ctx context.Context, sessionCtx *Session) (*sqladmin.DatabaseInstance, error) { + gcp := sessionCtx.Database.GetGCP() + dbi, err := g.service.Instances.Get(gcp.ProjectID, gcp.InstanceID).Context(ctx).Do() + if err != nil { + return nil, trace.Wrap(err) + } + return dbi, nil + +} + +// GenerateEphemeralCert returns a new client certificate with RSA key created +// using the GenerateEphemeralCertRequest Cloud SQL API. Client certificates are +// required when enabling SSL in Cloud SQL. +func (g *gcpSQLAdminClient) GenerateEphemeralCert(ctx context.Context, sessionCtx *Session) (*tls.Certificate, error) { + // TODO(jimbishopp): cache database certificates to avoid expensive generate + // operation on each connection. + + // Generate RSA private key, x509 encoded public key, and append to certificate request. + pkey, err := rsa.GenerateKey(rand.Reader, constants.RSAKeySize) + if err != nil { + return nil, trace.Wrap(err) + } + pkix, err := x509.MarshalPKIXPublicKey(pkey.Public()) + if err != nil { + return nil, trace.Wrap(err) + } + + // Make API call. + gcp := sessionCtx.Database.GetGCP() + req := g.service.Connect.GenerateEphemeralCert(gcp.ProjectID, gcp.InstanceID, &sqladmin.GenerateEphemeralCertRequest{ + PublicKey: string(pem.EncodeToMemory(&pem.Block{Bytes: pkix, Type: "RSA PUBLIC KEY"})), + ValidDuration: fmt.Sprintf("%ds", int(time.Until(sessionCtx.Identity.Expires).Seconds())), + }) + resp, err := req.Context(ctx).Do() + if err != nil { + return nil, trace.Wrap(err) + } + + // Create TLS certificate from returned ephemeral certificate and private key. + cert, err := tls.X509KeyPair([]byte(resp.EphemeralCert.Cert), tlsca.MarshalPrivateKeyPEM(pkey)) + if err != nil { + return nil, trace.Wrap(err) + } + return &cert, nil +} diff --git a/lib/srv/db/common/test.go b/lib/srv/db/common/test.go index 079962bc065de..a274d917dd5f6 100644 --- a/lib/srv/db/common/test.go +++ b/lib/srv/db/common/test.go @@ -51,6 +51,9 @@ type TestServerConfig struct { // Used when simulating test Cloud SQL database which should contains // : in its certificate. CN string + // ListenTLS creates a TLS listener when true instead of using a net listener. + // This is used to simulate MySQL connections through the GCP Cloud SQL Proxy. + ListenTLS bool } // MakeTestServerTLSConfig returns TLS config suitable for configuring test @@ -111,9 +114,9 @@ type TestClientConfig struct { RouteToDatabase tlsca.RouteToDatabase } -// MakeTestClientTLSConfig returns TLS config suitable for configuring test +// MakeTestClientCert returns TLS certificate suitable for configuring test // database Postgres/MySQL clients. -func MakeTestClientTLSConfig(config TestClientConfig) (*tls.Config, error) { +func MakeTestClientTLSCert(config TestClientConfig) (*tls.Certificate, error) { key, err := client.NewKey() if err != nil { return nil, trace.Wrap(err) @@ -132,6 +135,16 @@ func MakeTestClientTLSConfig(config TestClientConfig) (*tls.Config, error) { if err != nil { return nil, trace.Wrap(err) } + return &tlsCert, nil +} + +// MakeTestClientTLSConfig returns TLS config suitable for configuring test +// database Postgres/MySQL clients. +func MakeTestClientTLSConfig(config TestClientConfig) (*tls.Config, error) { + tlsCert, err := MakeTestClientTLSCert(config) + if err != nil { + return nil, trace.Wrap(err) + } ca, err := config.AuthClient.GetCertAuthority(types.CertAuthID{ Type: types.HostCA, DomainName: config.Cluster, @@ -145,7 +158,7 @@ func MakeTestClientTLSConfig(config TestClientConfig) (*tls.Config, error) { } return &tls.Config{ RootCAs: pool, - Certificates: []tls.Certificate{tlsCert}, + Certificates: []tls.Certificate{*tlsCert}, InsecureSkipVerify: true, }, nil } diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index 48a5ba4e6f867..0eedc38bf27a4 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -18,6 +18,7 @@ package mysql import ( "context" + "crypto/tls" "fmt" "net" "time" @@ -26,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/db/cloud" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/srv/db/mysql/protocol" @@ -56,6 +58,8 @@ type Engine struct { Context context.Context // Clock is the clock interface. Clock clockwork.Clock + // CloudClients provides access to cloud API clients. + CloudClients common.CloudClients // Log is used for logging. Log logrus.FieldLogger } @@ -153,6 +157,11 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*clie return nil, trace.Wrap(err) } user := sessionCtx.DatabaseUser + connectOpt := func(conn *client.Conn) { + conn.SetTLSConfig(tlsConfig) + } + + var dialer client.Dialer var password string switch { case sessionCtx.Database.IsRDS(): @@ -184,6 +193,28 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*clie if err != nil { return nil, trace.Wrap(err) } + // Get the client once for subsequent calls (it acquires a read lock). + gcpClient, err := e.CloudClients.GetGCPSQLAdminClient(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + // Detect whether the instance is set to require SSL. + // Fallback to not requiring SSL for access denied errors. + requireSSL, err := cloud.GetGCPRequireSSL(ctx, sessionCtx, gcpClient) + if err != nil && !trace.IsAccessDenied(err) { + return nil, trace.Wrap(err) + } + // Create ephemeral certificate and append to TLS config when + // the instance requires SSL. Also use a TLS dialer instead of + // the default net dialer when GCP requires SSL. + if requireSSL { + err = cloud.AppendGCPClientCert(ctx, sessionCtx, gcpClient, tlsConfig) + if err != nil { + return nil, trace.Wrap(err) + } + connectOpt = func(*client.Conn) {} + dialer = e.newGCPTLSDialer(tlsConfig) + } case sessionCtx.Database.IsAzure(): password, err = e.Auth.GetAzureAccessToken(ctx, sessionCtx) if err != nil { @@ -193,14 +224,20 @@ func (e *Engine) connect(ctx context.Context, sessionCtx *common.Session) (*clie // alice@mysql-server-name. user = fmt.Sprintf("%v@%v", user, sessionCtx.Database.GetAzure().Name) } + + // Use default net dialer unless it is already initialized. + if dialer == nil { + var nd net.Dialer + dialer = nd.DialContext + } + // TODO(r0mant): Set CLIENT_INTERACTIVE flag on the client? - conn, err := client.Connect(sessionCtx.Database.GetURI(), + conn, err := client.ConnectWithDialer(ctx, "tcp", sessionCtx.Database.GetURI(), user, password, sessionCtx.DatabaseName, - func(conn *client.Conn) { - conn.SetTLSConfig(tlsConfig) - }) + dialer, + connectOpt) if err != nil { if trace.IsAccessDenied(common.ConvertError(err)) && sessionCtx.Database.IsRDS() { return nil, trace.AccessDenied(`Could not connect to database: @@ -323,3 +360,27 @@ func (e *Engine) makeAcquireSemaphoreConfig(sessionCtx *common.Session) services }, } } + +// newGCPTLSDialer returns a TLS dialer configured to connect to the Cloud Proxy +// port rather than the default MySQL port. +func (e *Engine) newGCPTLSDialer(tlsConfig *tls.Config) client.Dialer { + return func(ctx context.Context, network, address string) (net.Conn, error) { + // Workaround issue generating ephemeral certificates for secure connections + // by creating a TLS connection to the Cloud Proxy port overridding the + // MySQL client's connection. MySQL on the default port does not trust + // the ephemeral certificate's CA but Cloud Proxy does. + host, port, err := net.SplitHostPort(address) + if err == nil && port == gcpSQLListenPort { + address = net.JoinHostPort(host, gcpSQLProxyListenPort) + } + tlsDialer := tls.Dialer{Config: tlsConfig} + return tlsDialer.DialContext(ctx, network, address) + } +} + +const ( + // gcpSQLListenPort is the port used by Cloud SQL MySQL instances. + gcpSQLListenPort = "3306" + // gcpSQLProxyListenPort is the port used by Cloud Proxy for MySQL instances. + gcpSQLProxyListenPort = "3307" +) diff --git a/lib/srv/db/mysql/test.go b/lib/srv/db/mysql/test.go index 55f6836570077..18b4ae22db6c7 100644 --- a/lib/srv/db/mysql/test.go +++ b/lib/srv/db/mysql/test.go @@ -70,15 +70,20 @@ func NewTestServer(config common.TestServerConfig) (*TestServer, error) { if config.Address != "" { address = config.Address } - listener, err := net.Listen("tcp", address) + tlsConfig, err := common.MakeTestServerTLSConfig(config) if err != nil { return nil, trace.Wrap(err) } - _, port, err := net.SplitHostPort(listener.Addr().String()) + var listener net.Listener + if config.ListenTLS { + listener, err = tls.Listen("tcp", address, tlsConfig) + } else { + listener, err = net.Listen("tcp", address) + } if err != nil { return nil, trace.Wrap(err) } - tlsConfig, err := common.MakeTestServerTLSConfig(config) + _, port, err := net.SplitHostPort(listener.Addr().String()) if err != nil { return nil, trace.Wrap(err) } @@ -86,14 +91,17 @@ func NewTestServer(config common.TestServerConfig) (*TestServer, error) { trace.Component: defaults.ProtocolMySQL, "name": config.Name, }) - return &TestServer{ - cfg: config, - listener: listener, - port: port, - tlsConfig: tlsConfig, - log: log, - handler: &testHandler{log: log}, - }, nil + server := &TestServer{ + cfg: config, + listener: listener, + port: port, + log: log, + handler: &testHandler{log: log}, + } + if !config.ListenTLS { + server.tlsConfig = tlsConfig + } + return server, nil } // Serve starts serving client connections. diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index 02239fd90f5ed..f9c29cf7965f4 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -24,6 +24,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/db/cloud" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/utils" @@ -50,6 +51,8 @@ type Engine struct { Context context.Context // Clock is the clock interface. Clock clockwork.Clock + // CloudClients provides access to cloud API clients. + CloudClients common.CloudClients // Log is used for logging. Log logrus.FieldLogger } @@ -382,6 +385,12 @@ func (e *Engine) getConnectConfig(ctx context.Context, sessionCtx *common.Sessio if err != nil { return nil, trace.Wrap(err) } + // TLS config will use client certificate for an onprem database or + // will contain RDS root certificate for RDS/Aurora. + config.TLSConfig, err = e.Auth.GetTLSConfig(ctx, sessionCtx) + if err != nil { + return nil, trace.Wrap(err) + } config.User = sessionCtx.DatabaseUser config.Database = sessionCtx.DatabaseName // Pgconn adds fallbacks to retry connection without TLS if the TLS @@ -408,6 +417,25 @@ func (e *Engine) getConnectConfig(ctx context.Context, sessionCtx *common.Sessio if err != nil { return nil, trace.Wrap(err) } + // Get the client once for subsequent calls (it acquires a read lock). + gcpClient, err := e.CloudClients.GetGCPSQLAdminClient(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + // Detect whether the instance is set to require SSL. + // Fallback to not requiring SSL for access denied errors. + requireSSL, err := cloud.GetGCPRequireSSL(ctx, sessionCtx, gcpClient) + if err != nil && !trace.IsAccessDenied(err) { + return nil, trace.Wrap(err) + } + // Create ephemeral certificate and append to TLS config when + // the instance requires SSL. + if requireSSL { + err = cloud.AppendGCPClientCert(ctx, sessionCtx, gcpClient, config.TLSConfig) + if err != nil { + return nil, trace.Wrap(err) + } + } case types.DatabaseTypeAzure: config.Password, err = e.Auth.GetAzureAccessToken(ctx, sessionCtx) if err != nil { @@ -417,12 +445,6 @@ func (e *Engine) getConnectConfig(ctx context.Context, sessionCtx *common.Sessio // alice@postgres-server-name. config.User = fmt.Sprintf("%v@%v", config.User, sessionCtx.Database.GetAzure().Name) } - // TLS config will use client certificate for an onprem database or - // will contain RDS root certificate for RDS/Aurora. - config.TLSConfig, err = e.Auth.GetTLSConfig(ctx, sessionCtx) - if err != nil { - return nil, trace.Wrap(err) - } return config, nil } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 0c7ace13c531e..41d4aa2791304 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -703,20 +703,22 @@ func (s *Server) dispatch(sessionCtx *common.Session, streamWriter events.Stream switch sessionCtx.Database.GetProtocol() { case defaults.ProtocolPostgres, defaults.ProtocolCockroachDB: return &postgres.Engine{ - Auth: s.cfg.Auth, - Audit: audit, - Context: s.closeContext, - Clock: s.cfg.Clock, - Log: sessionCtx.Log, + Auth: s.cfg.Auth, + Audit: audit, + Context: s.closeContext, + Clock: s.cfg.Clock, + CloudClients: s.cfg.CloudClients, + Log: sessionCtx.Log, }, nil case defaults.ProtocolMySQL: return &mysql.Engine{ - Auth: s.cfg.Auth, - Audit: audit, - AuthClient: s.cfg.AuthClient, - Context: s.closeContext, - Clock: s.cfg.Clock, - Log: sessionCtx.Log, + Auth: s.cfg.Auth, + Audit: audit, + AuthClient: s.cfg.AuthClient, + Context: s.closeContext, + Clock: s.cfg.Clock, + CloudClients: s.cfg.CloudClients, + Log: sessionCtx.Log, }, nil case defaults.ProtocolMongoDB: return &mongodb.Engine{ diff --git a/vendor/github.com/siddontang/go-mysql/client/teleport.go b/vendor/github.com/siddontang/go-mysql/client/teleport.go new file mode 100644 index 0000000000000..ceac69c871f39 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/client/teleport.go @@ -0,0 +1,49 @@ +package client + +import ( + "context" + "net" + + "github.com/pingcap/errors" + . "github.com/siddontang/go-mysql/mysql" + "github.com/siddontang/go-mysql/packet" +) + +// Dialer connects to the address on the named network using the provided context. +type Dialer func(ctx context.Context, network, address string) (net.Conn, error) + +// Connect to a MySQL server using the given Dialer. +func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) { + c := new(Conn) + + var err error + conn, err := dialer(ctx, network, addr) + if err != nil { + return nil, errors.Trace(err) + } + + if c.tlsConfig != nil { + c.Conn = packet.NewTLSConn(conn) + } else { + c.Conn = packet.NewConn(conn) + } + + c.user = user + c.password = password + c.db = dbName + c.proto = network + + // use default charset here, utf-8 + c.charset = DEFAULT_CHARSET + + // Apply configuration functions. + for i := range options { + options[i](c) + } + + if err = c.handshake(); err != nil { + return nil, errors.Trace(err) + } + + return c, nil +} diff --git a/vendor/github.com/siddontang/go-mysql/server/conn.go b/vendor/github.com/siddontang/go-mysql/server/conn.go index ec65e18753b34..b25e4f4a59599 100644 --- a/vendor/github.com/siddontang/go-mysql/server/conn.go +++ b/vendor/github.com/siddontang/go-mysql/server/conn.go @@ -109,19 +109,19 @@ func NewCustomizedConn(conn net.Conn, serverConf *Server, p CredentialProvider, } func (c *Conn) handshake() error { - if err := c.WriteInitialHandshake(); err != nil { + if err := c.writeInitialHandshake(); err != nil { return err } - if err := c.ReadHandshakeResponse(); err != nil { + if err := c.readHandshakeResponse(); err != nil { if err == ErrAccessDenied { err = NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.LocalAddr().String(), "Yes") } - c.WriteError(err) + c.writeError(err) return err } - if err := c.WriteOK(nil); err != nil { + if err := c.writeOK(nil); err != nil { return err } @@ -143,10 +143,6 @@ func (c *Conn) GetUser() string { return c.user } -func (c *Conn) GetDatabase() string { - return c.db -} - func (c *Conn) ConnectionID() uint32 { return c.connectionID } diff --git a/vendor/github.com/siddontang/go-mysql/server/handshake_resp.go b/vendor/github.com/siddontang/go-mysql/server/handshake_resp.go index 17ba9f786cbe9..6fba45ab7b8ad 100644 --- a/vendor/github.com/siddontang/go-mysql/server/handshake_resp.go +++ b/vendor/github.com/siddontang/go-mysql/server/handshake_resp.go @@ -9,7 +9,7 @@ import ( . "github.com/siddontang/go-mysql/mysql" ) -func (c *Conn) ReadHandshakeResponse() error { +func (c *Conn) readHandshakeResponse() error { data, pos, err := c.readFirstPart() if err != nil { return err diff --git a/vendor/github.com/siddontang/go-mysql/server/initial_handshake.go b/vendor/github.com/siddontang/go-mysql/server/initial_handshake.go index e782cdc5fc27d..312ac2b683b1e 100644 --- a/vendor/github.com/siddontang/go-mysql/server/initial_handshake.go +++ b/vendor/github.com/siddontang/go-mysql/server/initial_handshake.go @@ -1,7 +1,7 @@ package server // see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html -func (c *Conn) WriteInitialHandshake() error { +func (c *Conn) writeInitialHandshake() error { data := make([]byte, 4) //min version 10 diff --git a/vendor/github.com/siddontang/go-mysql/server/resp.go b/vendor/github.com/siddontang/go-mysql/server/resp.go index 685f9f968ef2f..db863239473ba 100644 --- a/vendor/github.com/siddontang/go-mysql/server/resp.go +++ b/vendor/github.com/siddontang/go-mysql/server/resp.go @@ -6,7 +6,7 @@ import ( . "github.com/siddontang/go-mysql/mysql" ) -func (c *Conn) WriteOK(r *Result) error { +func (c *Conn) writeOK(r *Result) error { if r == nil { r = &Result{} } @@ -28,7 +28,7 @@ func (c *Conn) WriteOK(r *Result) error { return c.WritePacket(data) } -func (c *Conn) WriteError(e error) error { +func (c *Conn) writeError(e error) error { var m *MyError var ok bool if m, ok = e.(*MyError); !ok { @@ -176,14 +176,14 @@ func (c *Conn) writeValue(value interface{}) error { case noResponse: return nil case error: - return c.WriteError(v) + return c.writeError(v) case nil: - return c.WriteOK(nil) + return c.writeOK(nil) case *Result: if v != nil && v.Resultset != nil { return c.writeResultset(v.Resultset) } else { - return c.WriteOK(v) + return c.writeOK(v) } case []*Field: return c.writeFieldList(v) diff --git a/vendor/github.com/siddontang/go-mysql/server/teleport.go b/vendor/github.com/siddontang/go-mysql/server/teleport.go new file mode 100644 index 0000000000000..54a6dcd980dd0 --- /dev/null +++ b/vendor/github.com/siddontang/go-mysql/server/teleport.go @@ -0,0 +1,25 @@ +package server + +import ( + . "github.com/siddontang/go-mysql/mysql" +) + +func (c *Conn) WriteInitialHandshake() error { + return c.writeInitialHandshake() +} + +func (c *Conn) ReadHandshakeResponse() error { + return c.readHandshakeResponse() +} + +func (c *Conn) GetDatabase() string { + return c.db +} + +func (c *Conn) WriteOK(r *Result) error { + return c.writeOK(r) +} + +func (c *Conn) WriteError(e error) error { + return c.writeError(e) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 2e3e7cc50916a..f14f8f5214fb5 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -666,7 +666,7 @@ github.com/siddontang/go/sync2 ## explicit github.com/siddontang/go-log/log github.com/siddontang/go-log/loggers -# github.com/siddontang/go-mysql v1.1.0 => github.com/gravitational/go-mysql v1.1.1-0.20210212011549-886316308a77 +# github.com/siddontang/go-mysql v1.1.0 => github.com/gravitational/go-mysql v1.1.1-teleport.1 ## explicit; go 1.15 github.com/siddontang/go-mysql/client github.com/siddontang/go-mysql/mysql