From 62fe977b985e5db4593e8aa90c58a41f79207630 Mon Sep 17 00:00:00 2001 From: Roman Tkachenko Date: Thu, 7 Apr 2022 11:33:39 -0700 Subject: [PATCH] Add auth'd tunnel mode to tsh proxy db command (#11720) --- lib/client/db/postgres/connstring.go | 5 +- lib/client/db/postgres/connstring_test.go | 2 +- lib/srv/db/access_test.go | 113 +++++++++++++++++++ lib/srv/db/local_proxy_test.go | 131 ++++++++++++++++++++++ lib/srv/db/mysql/proxy.go | 17 ++- lib/srv/db/postgres/proxy.go | 23 ++-- lib/srv/db/proxyserver.go | 30 ++++- lib/srv/db/sqlserver/proxy.go | 2 - tool/tsh/db.go | 12 ++ tool/tsh/db_test.go | 112 +++++++++++++++++- tool/tsh/dbcmd.go | 91 +++++++++------ tool/tsh/proxy.go | 93 ++++++++++++--- tool/tsh/tsh.go | 3 + 13 files changed, 561 insertions(+), 73 deletions(-) create mode 100644 lib/srv/db/local_proxy_test.go diff --git a/lib/client/db/postgres/connstring.go b/lib/client/db/postgres/connstring.go index 985d73f58a979..b3dcd91664eb1 100644 --- a/lib/client/db/postgres/connstring.go +++ b/lib/client/db/postgres/connstring.go @@ -27,7 +27,7 @@ import ( ) // GetConnString returns formatted Postgres connection string for the profile. -func GetConnString(c *profile.ConnectProfile) string { +func GetConnString(c *profile.ConnectProfile, noTLS bool) string { connStr := "postgres://" if c.User != "" { // Username may contain special characters in which case it should @@ -39,6 +39,9 @@ func GetConnString(c *profile.ConnectProfile) string { if c.Database != "" { connStr += "/" + c.Database } + if noTLS { + return connStr + } params := []string{ fmt.Sprintf("sslrootcert=%v", c.CACertPath), fmt.Sprintf("sslcert=%v", c.CertPath), diff --git a/lib/client/db/postgres/connstring_test.go b/lib/client/db/postgres/connstring_test.go index 089300d68c826..9b7e4031d5f5c 100644 --- a/lib/client/db/postgres/connstring_test.go +++ b/lib/client/db/postgres/connstring_test.go @@ -84,7 +84,7 @@ func TestConnString(t *testing.T) { CACertPath: caPath, CertPath: certPath, KeyPath: keyPath, - })) + }, false)) }) } } diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 5d06696388270..37d57b4ec0e31 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -1355,6 +1355,30 @@ func (c *testContext) postgresClientWithAddr(ctx context.Context, address, telep }) } +// postgresClientLocalProxy connects to test Postgres through local ALPN proxy. +func (c *testContext) postgresClientLocalProxy(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*pgconn.PgConn, *alpnproxy.LocalProxy, error) { + route := tlsca.RouteToDatabase{ + ServiceName: dbService, + Protocol: defaults.ProtocolPostgres, + Username: dbUser, + Database: dbName, + } + + // Start local proxy which client will connect to. + proxy, err := c.startLocalALPNProxy(ctx, c.webListener.Addr().String(), teleportUser, route) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + // Client connects to the local proxy without TLS. + conn, err := pgconn.Connect(ctx, fmt.Sprintf("postgres://%v@%v/%v", dbUser, proxy.GetAddr(), dbName)) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return conn, proxy, nil +} + // mysqlClient connects to test MySQL through database access as a specified // Teleport user and database account. func (c *testContext) mysqlClient(teleportUser, dbService, dbUser string) (*mysqlclient.Conn, error) { @@ -1377,6 +1401,29 @@ func (c *testContext) mysqlClientWithAddr(address, teleportUser, dbService, dbUs }) } +// mysqlClientLocalProxy connects to test MySQL through local ALPN proxy. +func (c *testContext) mysqlClientLocalProxy(ctx context.Context, teleportUser, dbService, dbUser string) (*mysqlclient.Conn, *alpnproxy.LocalProxy, error) { + route := tlsca.RouteToDatabase{ + ServiceName: dbService, + Protocol: defaults.ProtocolMySQL, + Username: dbUser, + } + + // Start local proxy which client will connect to. + proxy, err := c.startLocalALPNProxy(ctx, c.webListener.Addr().String(), teleportUser, route) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + // Client connects to the local proxy without TLS. + conn, err := mysqlclient.Connect(proxy.GetAddr(), dbUser, "", "") + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return conn, proxy, nil +} + // mongoClient connects to test MongoDB through database access as a // specified Teleport user and database account. func (c *testContext) mongoClient(ctx context.Context, teleportUser, dbService, dbUser string, opts ...*options.ClientOptions) (*mongo.Client, error) { @@ -1399,6 +1446,41 @@ func (c *testContext) mongoClientWithAddr(ctx context.Context, address, teleport }, opts...) } +// mongoClientLocalProxy connects to test MongoDB through local ALPN proxy. +func (c *testContext) mongoClientLocalProxy(ctx context.Context, teleportUser, dbService, dbUser string) (*mongo.Client, *alpnproxy.LocalProxy, error) { + route := tlsca.RouteToDatabase{ + ServiceName: dbService, + Protocol: defaults.ProtocolMongoDB, + Username: dbUser, + } + + // Start local proxy which client will connect to. + proxy, err := c.startLocalALPNProxy(ctx, c.webListener.Addr().String(), teleportUser, route) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + // Client connects to the local proxy without TLS. + client, err := mongo.Connect(ctx, options.Client(). + ApplyURI("mongodb://"+proxy.GetAddr()). + SetHeartbeatInterval(500*time.Millisecond). + SetServerSelectionTimeout(5*time.Second)) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + // Ping to make sure it connected successfully. + errPing := client.Ping(ctx, nil) + if errPing != nil { + if err := client.Disconnect(ctx); err != nil { + return nil, nil, trace.NewAggregate(errPing, err) + } + return nil, nil, trace.Wrap(errPing) + } + + return client, proxy, nil +} + // redisClient connects to test Redis through database access as a specified Teleport user and database account. func (c *testContext) redisClient(ctx context.Context, teleportUser, dbService, dbUser string, opts ...redis.ClientOptions) (*redis.Client, error) { return c.redisClientWithAddr(ctx, c.webListener.Addr().String(), teleportUser, dbService, dbUser, opts...) @@ -1420,6 +1502,37 @@ func (c *testContext) redisClientWithAddr(ctx context.Context, proxyAddress, tel }, opts...) } +// redisClientLocalProxy connects to test Redis through local ALPN proxy. +func (c *testContext) redisClientLocalProxy(ctx context.Context, teleportUser, dbService, dbUser string) (*redis.Client, *alpnproxy.LocalProxy, error) { + route := tlsca.RouteToDatabase{ + ServiceName: dbService, + Protocol: defaults.ProtocolRedis, + Username: dbUser, + } + + // Start local proxy which client will connect to. + proxy, err := c.startLocalALPNProxy(ctx, c.webListener.Addr().String(), teleportUser, route) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + // Client connects to the local proxy without TLS. + client := goredis.NewClient(&goredis.Options{ + Addr: proxy.GetAddr(), + }) + + // Ping to make sure connection is successful. + errPing := client.Ping(ctx).Err() + if errPing != nil { + if err := client.Close(); err != nil { + return nil, nil, trace.NewAggregate(errPing, err) + } + return nil, nil, trace.Wrap(errPing) + } + + return client, proxy, nil +} + // sqlServerClient connects to the specified SQL Server address. func (c *testContext) sqlServerClient(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*mssql.Conn, *alpnproxy.LocalProxy, error) { route := tlsca.RouteToDatabase{ diff --git a/lib/srv/db/local_proxy_test.go b/lib/srv/db/local_proxy_test.go new file mode 100644 index 0000000000000..bafa37f0163c2 --- /dev/null +++ b/lib/srv/db/local_proxy_test.go @@ -0,0 +1,131 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package db + +import ( + "context" + "testing" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/srv/db/postgres" + + "github.com/jackc/pgconn" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson" +) + +// TestLocalProxyPostgres verifies connecting to a Postgres database +// through the local authenticated ALPN proxy. +func TestLocalProxyPostgres(t *testing.T) { + ctx := context.Background() + testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres")) + go testCtx.startHandlingConnections() + + // Create test user/role. + testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{types.Wildcard}, []string{types.Wildcard}) + + // Try to connect to the database as this user. + conn, proxy, err := testCtx.postgresClientLocalProxy(ctx, "alice", "postgres", "postgres", "postgres") + require.NoError(t, err) + + // Close connection and local proxy after the test. + t.Cleanup(func() { + require.NoError(t, conn.Close(ctx)) + require.NoError(t, proxy.Close()) + }) + + // Execute a query. + result, err := conn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) + require.Equal(t, []*pgconn.Result{postgres.TestQueryResponse}, result) +} + +// TestLocalProxyMySQL verifies connecting to a MySQL database +// through the local authenticated ALPN proxy. +func TestLocalProxyMySQL(t *testing.T) { + ctx := context.Background() + testCtx := setupTestContext(ctx, t, withSelfHostedMySQL("mysql")) + go testCtx.startHandlingConnections() + + // Create test user/role. + testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{types.Wildcard}, []string{types.Wildcard}) + + // Connect to the database as this user. + conn, proxy, err := testCtx.mysqlClientLocalProxy(ctx, "alice", "mysql", "alice") + require.NoError(t, err) + + // Close connection and local proxy after the test. + t.Cleanup(func() { + require.NoError(t, conn.Close()) + require.NoError(t, proxy.Close()) + }) + + // Execute a query. + _, err = conn.Execute("select 1") + require.NoError(t, err) +} + +// TestLocalProxyMongoDB verifies connecting to a MongoDB database +// through the local authenticated ALPN proxy. +func TestLocalProxyMongoDB(t *testing.T) { + ctx := context.Background() + testCtx := setupTestContext(ctx, t, withSelfHostedMongo("mongo")) + go testCtx.startHandlingConnections() + + // Create test user/role. + testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{types.Wildcard}, []string{types.Wildcard}) + + // Connect to the database as this user. + client, proxy, err := testCtx.mongoClientLocalProxy(ctx, "alice", "mongo", "admin") + require.NoError(t, err) + + // Close connection and local proxy after the test. + t.Cleanup(func() { + require.NoError(t, client.Disconnect(ctx)) + require.NoError(t, proxy.Close()) + }) + + // Execute a query. + _, err = client.Database("admin").Collection("test").Find(ctx, bson.M{}) + require.NoError(t, err) +} + +// TestLocalProxyRedis verifies connecting to a Redis database +// through the local authenticated ALPN proxy. +func TestLocalProxyRedis(t *testing.T) { + ctx := context.Background() + testCtx := setupTestContext(ctx, t, withSelfHostedRedis("redis")) + go testCtx.startHandlingConnections() + + // Create test user/role. + testCtx.createUserAndRole(ctx, t, "alice", "admin", []string{types.Wildcard}, []string{types.Wildcard}) + + // Connect to the database as this user. + client, proxy, err := testCtx.redisClientLocalProxy(ctx, "alice", "redis", "admin") + require.NoError(t, err) + + // Close connection and local proxy after the test. + t.Cleanup(func() { + require.NoError(t, client.Close()) + require.NoError(t, proxy.Close()) + }) + + // Execute a query. + result := client.Echo(ctx, "ping") + require.NoError(t, result.Err()) + require.Equal(t, "ping", result.Val()) +} diff --git a/lib/srv/db/mysql/proxy.go b/lib/srv/db/mysql/proxy.go index c4938a6663d43..15f16c264a4d5 100644 --- a/lib/srv/db/mysql/proxy.go +++ b/lib/srv/db/mysql/proxy.go @@ -141,6 +141,8 @@ func (p *Proxy) makeServer(clientConn net.Conn) *server.Conn { mysql.DEFAULT_COLLATION_ID, mysql.AUTH_NATIVE_PASSWORD, nil, + // TLS config can actually be nil if the client is connecting + // through local TLS proxy without TLS. p.TLSConfig), &credentialProvider{}, server.EmptyHandler{}) @@ -170,11 +172,18 @@ func (p *Proxy) performHandshake(conn *multiplexer.Conn, server *server.Conn) (* // First part of the handshake completed and the connection has been // upgraded to TLS so now we can look at the client certificate and // see which database service to route the connection to. - tlsConn, ok := server.Conn.Conn.(*tls.Conn) - if !ok { - return nil, trace.BadParameter("expected TLS connection") + switch c := server.Conn.Conn.(type) { + case *tls.Conn: + return c, nil + case *multiplexer.Conn: + tlsConn, ok := c.Conn.(*tls.Conn) + if !ok { + return nil, trace.BadParameter("expected TLS connection, got: %T", c.Conn) + } + return tlsConn, nil } - return tlsConn, nil + return nil, trace.BadParameter("expected *tls.Conn or *multiplexer.Conn, got: %T", + server.Conn.Conn) } // maybeReadProxyLine peeks into the connection to see if instead of regular diff --git a/lib/srv/db/postgres/proxy.go b/lib/srv/db/postgres/proxy.go index 9404a10f06cdc..5c784792c8567 100644 --- a/lib/srv/db/postgres/proxy.go +++ b/lib/srv/db/postgres/proxy.go @@ -125,14 +125,23 @@ func (p *Proxy) handleStartup(ctx context.Context, clientConn net.Conn) (*pgprot // https://www.postgresql.org/docs/13/protocol-flow.html#id-1.10.5.7.11 switch m := startupMessage.(type) { case *pgproto3.SSLRequest: - // Send 'S' back to indicate TLS support to the client. - _, err := clientConn.Write([]byte("S")) - if err != nil { - return nil, nil, nil, trace.Wrap(err) + if p.TLSConfig == nil { + // Send 'N' back to make the client connect without TLS. Happens + // when client connects through the local TLS proxy. + _, err := clientConn.Write([]byte("N")) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + } else { + // Send 'S' back to indicate TLS support to the client. + _, err := clientConn.Write([]byte("S")) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + // Upgrade the connection to TLS and wait for the next message + // which should be of the StartupMessage type. + clientConn = tls.Server(clientConn, p.TLSConfig) } - // Upgrade the connection to TLS and wait for the next message - // which should be of the StartupMessage type. - clientConn = tls.Server(clientConn, p.TLSConfig) return p.handleStartup(ctx, clientConn) case *pgproto3.StartupMessage: // TLS connection between the client and this proxy has been diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index 3c277360bff00..bb6946d89d399 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -322,6 +322,10 @@ func (s *ProxyServer) handleConnection(conn net.Conn) error { return trace.Wrap(err) } switch proxyCtx.Identity.RouteToDatabase.Protocol { + case defaults.ProtocolPostgres: + return s.PostgresProxyNoTLS().HandleConnection(s.closeCtx, tlsConn) + case defaults.ProtocolMySQL: + return s.MySQLProxyNoTLS().HandleConnection(s.closeCtx, tlsConn) case defaults.ProtocolSQLServer: return s.SQLServerProxy().HandleConnection(s.closeCtx, proxyCtx, tlsConn) } @@ -348,6 +352,16 @@ func (s *ProxyServer) PostgresProxy() *postgres.Proxy { } } +// PostgresProxyNoTLS returns a new instance of the non-TLS Postgres proxy. +func (s *ProxyServer) PostgresProxyNoTLS() *postgres.Proxy { + return &postgres.Proxy{ + Middleware: s.middleware, + Service: s, + Limiter: s.cfg.Limiter, + Log: s.log, + } +} + // MySQLProxy returns a new instance of the MySQL protocol aware proxy. func (s *ProxyServer) MySQLProxy() *mysql.Proxy { return &mysql.Proxy{ @@ -359,15 +373,19 @@ func (s *ProxyServer) MySQLProxy() *mysql.Proxy { } } +// MySQLProxyNoTLS returns a new instance of the non-TLS MySQL proxy. +func (s *ProxyServer) MySQLProxyNoTLS() *mysql.Proxy { + return &mysql.Proxy{ + Middleware: s.middleware, + Service: s, + Limiter: s.cfg.Limiter, + Log: s.log, + } +} + // SQLServerProxy returns a new instance of the SQL Server protocol aware proxy. func (s *ProxyServer) SQLServerProxy() *sqlserver.Proxy { - // SQL Server clients don't support client certificates, connections - // come over TLS routing tunnel. - tlsConf := s.cfg.TLSConfig.Clone() - tlsConf.ClientAuth = tls.NoClientCert - tlsConf.GetConfigForClient = nil return &sqlserver.Proxy{ - TLSConfig: tlsConf, Middleware: s.middleware, Service: s, Log: s.log, diff --git a/lib/srv/db/sqlserver/proxy.go b/lib/srv/db/sqlserver/proxy.go index 790914806c807..606480148509e 100644 --- a/lib/srv/db/sqlserver/proxy.go +++ b/lib/srv/db/sqlserver/proxy.go @@ -31,8 +31,6 @@ import ( // Proxy accepts connections from SQL Server clients, performs a Pre-Login // handshake and then forwards the connection to the database service agent. type Proxy struct { - // TLSConfig is the proxy TLS configuration. - TLSConfig *tls.Config // Middleware is the auth middleware. Middleware *auth.Middleware // Service is used to connect to a remote database service. diff --git a/tool/tsh/db.go b/tool/tsh/db.go index ea034ad7ffef5..43bb6686be39c 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -578,6 +578,7 @@ type connectionCommandOpts struct { localProxyPort int localProxyHost string caPath string + noTLS bool } type ConnectCommandFunc func(*connectionCommandOpts) @@ -590,6 +591,17 @@ func WithLocalProxy(host string, port int, caPath string) ConnectCommandFunc { } } +// WithNoTLS is the connect command option that makes the command connect +// without TLS. +// +// It is used when connecting through the local proxy that was started in +// mutual TLS mode (i.e. with a client certificate). +func WithNoTLS() ConnectCommandFunc { + return func(opts *connectionCommandOpts) { + opts.noTLS = true + } +} + func formatDatabaseListCommand(clusterFlag string) string { if clusterFlag == "" { return "tsh db ls" diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index 382f8e58cea8a..10bfea69a10d1 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -373,6 +373,7 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { databaseName string execer *fakeExec cmd []string + noTLS bool wantErr bool }{ { @@ -385,6 +386,15 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { "sslkey=/tmp/keys/example.com/bob&sslmode=verify-full"}, wantErr: false, }, + { + name: "postgres no TLS", + dbProtocol: defaults.ProtocolPostgres, + databaseName: "mydb", + noTLS: true, + cmd: []string{"psql", + "postgres://myUser@localhost:12345/mydb"}, + wantErr: false, + }, { name: "cockroach", dbProtocol: defaults.ProtocolCockroachDB, @@ -400,6 +410,20 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { "sslkey=/tmp/keys/example.com/bob&sslmode=verify-full"}, wantErr: false, }, + { + name: "cockroach no TLS", + dbProtocol: defaults.ProtocolCockroachDB, + databaseName: "mydb", + noTLS: true, + execer: &fakeExec{ + execOutput: map[string][]byte{ + "cockroach": []byte(""), + }, + }, + cmd: []string{"cockroach", "sql", "--url", + "postgres://myUser@localhost:12345/mydb"}, + wantErr: false, + }, { name: "cockroach psql fallback", dbProtocol: defaults.ProtocolCockroachDB, @@ -432,6 +456,24 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { "--ssl-verify-server-cert"}, wantErr: false, }, + { + name: "mariadb no TLS", + dbProtocol: defaults.ProtocolMySQL, + databaseName: "mydb", + noTLS: true, + execer: &fakeExec{ + execOutput: map[string][]byte{ + "mariadb": []byte(""), + }, + }, + cmd: []string{"mariadb", + "--user", "myUser", + "--database", "mydb", + "--port", "12345", + "--host", "localhost", + "--protocol", "TCP"}, + wantErr: false, + }, { name: "mysql by mariadb", dbProtocol: defaults.ProtocolMySQL, @@ -471,6 +513,24 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { "--protocol", "TCP"}, wantErr: false, }, + { + name: "mysql no TLS", + dbProtocol: defaults.ProtocolMySQL, + databaseName: "mydb", + noTLS: true, + execer: &fakeExec{ + execOutput: map[string][]byte{ + "mysql": []byte("Ver 8.0.27-0ubuntu0.20.04.1 for Linux on x86_64 ((Ubuntu))"), + }, + }, + cmd: []string{"mysql", + "--user", "myUser", + "--database", "mydb", + "--port", "12345", + "--host", "localhost", + "--protocol", "TCP"}, + wantErr: false, + }, { name: "no mysql nor mariadb", dbProtocol: defaults.ProtocolMySQL, @@ -496,6 +556,20 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { "mydb"}, wantErr: false, }, + { + name: "mongodb no TLS", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "mydb", + noTLS: true, + execer: &fakeExec{ + execOutput: map[string][]byte{}, + }, + cmd: []string{"mongo", + "--host", "localhost", + "--port", "12345", + "mydb"}, + wantErr: false, + }, { name: "mongosh", dbProtocol: defaults.ProtocolMongoDB, @@ -510,6 +584,22 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { "--port", "12345", "--tls", "--tlsCertificateKeyFile", "/tmp/keys/example.com/bob-db/db.example.com/mysql-x509.pem", + "--tlsUseSystemCA", + "mydb"}, + }, + { + name: "mongosh no TLS", + dbProtocol: defaults.ProtocolMongoDB, + databaseName: "mydb", + noTLS: true, + execer: &fakeExec{ + execOutput: map[string][]byte{ + "mongosh": []byte("1.1.6"), + }, + }, + cmd: []string{"mongosh", + "--host", "localhost", + "--port", "12345", "mydb"}, }, { @@ -528,9 +618,9 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { name: "redis-cli", dbProtocol: defaults.ProtocolRedis, cmd: []string{"redis-cli", - "--tls", "-h", "localhost", "-p", "12345", + "--tls", "--key", "/tmp/keys/example.com/bob", "--cert", "/tmp/keys/example.com/bob-db/db.example.com/mysql-x509.pem"}, wantErr: false, @@ -540,14 +630,23 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { dbProtocol: defaults.ProtocolRedis, databaseName: "2", cmd: []string{"redis-cli", - "--tls", "-h", "localhost", "-p", "12345", + "--tls", "--key", "/tmp/keys/example.com/bob", "--cert", "/tmp/keys/example.com/bob-db/db.example.com/mysql-x509.pem", "-n", "2"}, wantErr: false, }, + { + name: "redis-cli no TLS", + dbProtocol: defaults.ProtocolRedis, + noTLS: true, + cmd: []string{"redis-cli", + "-h", "localhost", + "-p", "12345"}, + wantErr: false, + }, } for _, tt := range tests { @@ -562,7 +661,14 @@ func TestCliCommandBuilderGetConnectCommand(t *testing.T) { ServiceName: "mysql", } - c := newCmdBuilder(tc, profile, database, "root", WithLocalProxy("localhost", 12345, "")) + opts := []ConnectCommandFunc{ + WithLocalProxy("localhost", 12345, ""), + } + if tt.noTLS { + opts = append(opts, WithNoTLS()) + } + + c := newCmdBuilder(tc, profile, database, "root", opts...) c.uid = utils.NewFakeUID() c.exe = tt.execer got, err := c.getConnectCommand() diff --git a/tool/tsh/dbcmd.go b/tool/tsh/dbcmd.go index 78c7b8f92c4bc..be5a2788b6c26 100644 --- a/tool/tsh/dbcmd.go +++ b/tool/tsh/dbcmd.go @@ -145,7 +145,7 @@ func (c *cliCommandBuilder) getConnectCommand() (*exec.Cmd, error) { func (c *cliCommandBuilder) getPostgresCommand() *exec.Cmd { return exec.Command(postgresBin, - postgres.GetConnString(db.New(c.tc, *c.db, *c.profile, c.rootCluster, c.host, c.port))) + postgres.GetConnString(db.New(c.tc, *c.db, *c.profile, c.rootCluster, c.host, c.port), c.options.noTLS)) } func (c *cliCommandBuilder) getCockroachCommand() *exec.Cmd { @@ -154,10 +154,10 @@ func (c *cliCommandBuilder) getCockroachCommand() *exec.Cmd { log.Debugf("Couldn't find %q client in PATH, falling back to %q: %v.", cockroachBin, postgresBin, err) return exec.Command(postgresBin, - postgres.GetConnString(db.New(c.tc, *c.db, *c.profile, c.rootCluster, c.host, c.port))) + postgres.GetConnString(db.New(c.tc, *c.db, *c.profile, c.rootCluster, c.host, c.port), c.options.noTLS)) } return exec.Command(cockroachBin, "sql", "--url", - postgres.GetConnString(db.New(c.tc, *c.db, *c.profile, c.rootCluster, c.host, c.port))) + postgres.GetConnString(db.New(c.tc, *c.db, *c.profile, c.rootCluster, c.host, c.port), c.options.noTLS)) } // getMySQLCommonCmdOpts returns common command line arguments for mysql and mariadb. @@ -188,6 +188,11 @@ func (c *cliCommandBuilder) getMySQLCommonCmdOpts() []string { // between Oracle and MariaDB version are covered by getMySQLCommonCmdOpts(). func (c *cliCommandBuilder) getMariaDBArgs() []string { args := c.getMySQLCommonCmdOpts() + + if c.options.noTLS { + return args + } + sslCertPath := c.profile.DatabaseCertPathForCluster(c.tc.SiteName, c.db.ServiceName) args = append(args, []string{"--ssl-key", c.profile.KeyPath()}...) @@ -208,6 +213,10 @@ func (c *cliCommandBuilder) getMariaDBArgs() []string { func (c *cliCommandBuilder) getMySQLOracleCommand() *exec.Cmd { args := c.getMySQLCommonCmdOpts() + if c.options.noTLS { + return exec.Command(mysqlBin, args...) + } + // defaults-group-suffix must be first. groupSuffix := []string{fmt.Sprintf("--defaults-group-suffix=_%v-%v", c.tc.SiteName, c.db.ServiceName)} args = append(groupSuffix, args...) @@ -287,33 +296,45 @@ func (c *cliCommandBuilder) getMongoCommand() *exec.Cmd { // look for `mongosh` hasMongosh := c.isMongoshBinAvailable() - // Starting with Mongo 4.2 there is an updated set of flags. - // We are using them with `mongosh` as otherwise warnings will get displayed. - type tlsFlags struct { - tls string - tlsCertKeyFile string - tlsCAFile string - } - - var flags tlsFlags - - if hasMongosh { - flags = tlsFlags{tls: "--tls", tlsCertKeyFile: "--tlsCertificateKeyFile", tlsCAFile: "--tlsCAFile"} - } else { - flags = tlsFlags{tls: "--ssl", tlsCertKeyFile: "--sslPEMKeyFile", tlsCAFile: "--sslCAFile"} - } - args := []string{ "--host", c.host, "--port", strconv.Itoa(c.port), - flags.tls, - flags.tlsCertKeyFile, c.profile.DatabaseCertPathForCluster(c.tc.SiteName, c.db.ServiceName), } - if c.options.caPath != "" { - // caPath is set only if mongo connects to the Teleport Proxy via ALPN SNI Local Proxy - // and connection is terminated by proxy identity certificate. - args = append(args, []string{flags.tlsCAFile, c.options.caPath}...) + if !c.options.noTLS { + // Starting with Mongo 4.2 there is an updated set of flags. + // We are using them with `mongosh` as otherwise warnings will get displayed. + type tlsFlags struct { + tls string + tlsCertKeyFile string + tlsCAFile string + } + + var flags tlsFlags + + if hasMongosh { + flags = tlsFlags{tls: "--tls", tlsCertKeyFile: "--tlsCertificateKeyFile", tlsCAFile: "--tlsCAFile"} + } else { + flags = tlsFlags{tls: "--ssl", tlsCertKeyFile: "--sslPEMKeyFile", tlsCAFile: "--sslCAFile"} + } + + args = append(args, + flags.tls, + flags.tlsCertKeyFile, + c.profile.DatabaseCertPathForCluster(c.tc.SiteName, c.db.ServiceName)) + + // mongosh does not load system CAs by default which will cause issues if + // the proxy presents a certificate signed by a non-recognized authority + // which your system trusts (e.g. mkcert). + if hasMongosh { + args = append(args, "--tlsUseSystemCA") + } + + if c.options.caPath != "" { + // caPath is set only if mongo connects to the Teleport Proxy via ALPN SNI Local Proxy + // and connection is terminated by proxy identity certificate. + args = append(args, []string{flags.tlsCAFile, c.options.caPath}...) + } } if c.db.Database != "" { @@ -333,19 +354,23 @@ func (c *cliCommandBuilder) getMongoCommand() *exec.Cmd { func (c *cliCommandBuilder) getRedisCommand() *exec.Cmd { // TODO(jakub): Add "-3" when Teleport adds support for Redis RESP3 protocol. args := []string{ - "--tls", "-h", c.host, "-p", strconv.Itoa(c.port), - "--key", c.profile.KeyPath(), - "--cert", c.profile.DatabaseCertPathForCluster(c.tc.SiteName, c.db.ServiceName), } - if c.tc.InsecureSkipVerify { - args = append(args, "--insecure") - } + if !c.options.noTLS { + args = append(args, + "--tls", + "--key", c.profile.KeyPath(), + "--cert", c.profile.DatabaseCertPathForCluster(c.tc.SiteName, c.db.ServiceName)) - if c.options.caPath != "" { - args = append(args, []string{"--cacert", c.options.caPath}...) + if c.tc.InsecureSkipVerify { + args = append(args, "--insecure") + } + + if c.options.caPath != "" { + args = append(args, []string{"--cacert", c.options.caPath}...) + } } // append database number if provided diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index c109c1b414fbb..ea799e6e2618c 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/utils/keypaths" libclient "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/alpnproxy" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/utils" @@ -151,6 +152,10 @@ func onProxyCommandDB(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } + profile, err := libclient.StatusCurrent(cf.HomePath, cf.Proxy) + if err != nil { + return trace.Wrap(err) + } addr := "localhost:0" if cf.LocalProxyPort != "" { @@ -165,13 +170,25 @@ func onProxyCommandDB(cf *CLIConf) error { log.WithError(err).Warnf("Failed to close listener.") } }() + + // If user requested no client auth, open an authenticated tunnel using + // client cert/key of the database. + certFile := cf.LocalProxyCertFile + if certFile == "" && cf.LocalProxyTunnel { + certFile = profile.DatabaseCertPathForCluster(cf.SiteName, database.ServiceName) + } + keyFile := cf.LocalProxyKeyFile + if keyFile == "" && cf.LocalProxyTunnel { + keyFile = profile.KeyPath() + } + lp, err := mkLocalProxy(cf.Context, localProxyOpts{ proxyAddr: client.WebProxyAddr, protocol: database.Protocol, listener: listener, insecure: cf.InsecureSkipVerify, - certFile: cf.LocalProxyCertFile, - keyFile: cf.LocalProxyKeyFile, + certFile: certFile, + keyFile: keyFile, }) if err != nil { return trace.Wrap(err) @@ -181,20 +198,38 @@ func onProxyCommandDB(cf *CLIConf) error { lp.Close() }() - profile, err := libclient.StatusCurrent(cf.HomePath, cf.Proxy) - if err != nil { - return trace.Wrap(err) - } - - err = dbProxyTpl.Execute(os.Stdout, map[string]string{ - "database": database.ServiceName, - "address": listener.Addr().String(), - "ca": profile.CACertPathForCluster(rootCluster), - "cert": profile.DatabaseCertPathForCluster(cf.SiteName, database.ServiceName), - "key": profile.KeyPath(), - }) - if err != nil { - return trace.Wrap(err) + if cf.LocalProxyTunnel { + addr, err := utils.ParseAddr(lp.GetAddr()) + if err != nil { + return trace.Wrap(err) + } + cmd, err := newCmdBuilder(client, profile, database, cf.SiteName, + WithLocalProxy("localhost", addr.Port(0), ""), + WithNoTLS()).getConnectCommand() + if err != nil { + return trace.Wrap(err) + } + err = dbProxyAuthTpl.Execute(os.Stdout, map[string]string{ + "database": database.ServiceName, + "type": dbProtocolToText(database.Protocol), + "cluster": profile.Cluster, + "command": cmd.String(), + "address": listener.Addr().String(), + }) + if err != nil { + return trace.Wrap(err) + } + } else { + err = dbProxyTpl.Execute(os.Stdout, map[string]string{ + "database": database.ServiceName, + "address": listener.Addr().String(), + "ca": profile.CACertPathForCluster(rootCluster), + "cert": profile.DatabaseCertPathForCluster(cf.SiteName, database.ServiceName), + "key": profile.KeyPath(), + }) + if err != nil { + return trace.Wrap(err) + } } defer lp.Close() @@ -263,3 +298,29 @@ Use following credentials to connect to the {{.database}} proxy: cert_file={{.cert}} key_file={{.key}} `)) + +func dbProtocolToText(protocol string) string { + switch protocol { + case defaults.ProtocolPostgres: + return "PostgreSQL" + case defaults.ProtocolCockroachDB: + return "CockroachDB" + case defaults.ProtocolMySQL: + return "MySQL" + case defaults.ProtocolMongoDB: + return "MongoDB" + case defaults.ProtocolRedis: + return "Redis" + case defaults.ProtocolSQLServer: + return "SQL Server" + } + return "" +} + +// dbProxyAuthTpl is the message that's printed for an authenticated db proxy. +var dbProxyAuthTpl = template.Must(template.New("").Parse( + `Started authenticated tunnel for the {{.type}} database "{{.database}}" in cluster "{{.cluster}}" on {{.address}}. + +Use the following command to connect to the database: + $ {{.command}} +`)) diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 6ea5c4c798b96..810c8b15eee36 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -275,6 +275,8 @@ type CLIConf struct { LocalProxyCertFile string // LocalProxyKeyFile is the client key used by local proxy. LocalProxyKeyFile string + // LocalProxyTunnel specifies whether local proxy will open auth'd tunnel. + LocalProxyTunnel bool // ConfigProxyTarget is the node which should be connected to in `tsh config-proxy`. ConfigProxyTarget string @@ -467,6 +469,7 @@ func Run(args []string, opts ...cliOption) error { proxyDB.Flag("port", " Specifies the source port used by proxy db listener").Short('p').StringVar(&cf.LocalProxyPort) proxyDB.Flag("cert-file", "Certificate file for proxy client TLS configuration").StringVar(&cf.LocalProxyCertFile) proxyDB.Flag("key-file", "Key file for proxy client TLS configuration").StringVar(&cf.LocalProxyKeyFile) + proxyDB.Flag("tunnel", "Open authenticated tunnel using database's client certificate so clients don't need to authenticate").BoolVar(&cf.LocalProxyTunnel) // Databases. db := app.Command("db", "View and control proxied databases.")