From 3939e5c2c470c0810a06a6d128a87fe5c9e248c2 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Thu, 9 Dec 2021 11:24:39 -0500 Subject: [PATCH] Add `--cluster` flag to all `tsh db` subcommands, Add "--diag_addr" flag to `teleport db/app start` (#9220) * add diag to teleport db/app start * db --cluster flag supports * add some ut and fix issue ~/.tsh get removed during test * working mongodb * fix logout * fix ut * code review comment * fix mysql --- lib/client/api.go | 78 +++++++++++++++------ lib/client/db/profile.go | 2 +- lib/client/db/profile_test.go | 2 +- lib/client/keyagent.go | 30 ++++++-- lib/client/keyagent_test.go | 36 +++++++++- lib/client/keystore.go | 7 +- lib/client/keystore_test.go | 21 ++++++ tool/teleport/common/teleport.go | 2 + tool/tsh/db.go | 115 +++++++++++++++++++++---------- tool/tsh/db_test.go | 43 +++++++++--- tool/tsh/proxy.go | 2 +- tool/tsh/proxy_test.go | 9 +-- tool/tsh/tsh.go | 33 +++++---- tool/tsh/tsh_test.go | 83 +++++++++++----------- 14 files changed, 321 insertions(+), 142 deletions(-) diff --git a/lib/client/api.go b/lib/client/api.go index cfb92a98e5136..ad6ec235c841d 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -428,12 +428,18 @@ func (p *ProfileStatus) KeyPath() string { return keypaths.UserKeyPath(p.Dir, p.Name, p.Username) } -// DatabaseCertPath returns path to the specified database access certificate -// for this profile. +// DatabaseCertPathForCluster returns path to the specified database access +// certificate for this profile, for the specified cluster. // // It's kept in /keys//-db//-x509.pem -func (p *ProfileStatus) DatabaseCertPath(name string) string { - return keypaths.DatabaseCertPath(p.Dir, p.Name, p.Username, p.Cluster, name) +// +// If the input cluster name is an empty string, the selected cluster in the +// profile will be used. +func (p *ProfileStatus) DatabaseCertPathForCluster(clusterName string, databaseName string) string { + if clusterName == "" { + clusterName = p.Cluster + } + return keypaths.DatabaseCertPath(p.Dir, p.Name, p.Username, clusterName, databaseName) } // AppCertPath returns path to the specified app access certificate @@ -460,6 +466,30 @@ func (p *ProfileStatus) DatabaseServices() (result []string) { return result } +// DatabasesForCluster returns a list of databases for this profile, for the +// specified cluster name. +func (p *ProfileStatus) DatabasesForCluster(clusterName string) ([]tlsca.RouteToDatabase, error) { + if clusterName == "" || clusterName == p.Cluster { + return p.Databases, nil + } + + idx := KeyIndex{ + ProxyHost: p.Name, + Username: p.Username, + ClusterName: clusterName, + } + + store, err := NewFSLocalKeyStore(p.Dir) + if err != nil { + return nil, trace.Wrap(err) + } + key, err := store.GetKey(idx, WithDBCerts{}) + if err != nil { + return nil, trace.Wrap(err) + } + return findActiveDatabases(key) +} + // AppNames returns a list of app names this profile is logged into. func (p *ProfileStatus) AppNames() (result []string) { for _, app := range p.Apps { @@ -596,25 +626,10 @@ func readProfile(profileDir string, profileName string) (*ProfileStatus, error) return nil, trace.Wrap(err) } - dbCerts, err := key.DBTLSCertificates() + databases, err := findActiveDatabases(key) if err != nil { return nil, trace.Wrap(err) } - var databases []tlsca.RouteToDatabase - for _, cert := range dbCerts { - tlsID, err := tlsca.FromSubject(cert.Subject, time.Time{}) - if err != nil { - return nil, trace.Wrap(err) - } - // If the cert expiration time is less than 5s consider cert as expired and don't add - // it to the user profile as an active database. - if time.Until(cert.NotAfter) < 5*time.Second { - continue - } - if tlsID.RouteToDatabase.ServiceName != "" { - databases = append(databases, tlsID.RouteToDatabase) - } - } appCerts, err := key.AppTLSCertificates() if err != nil { @@ -3323,3 +3338,26 @@ func playSession(sessionEvents []events.EventFields, stream []byte) error { return trace.Wrap(err) } } + +func findActiveDatabases(key *Key) ([]tlsca.RouteToDatabase, error) { + dbCerts, err := key.DBTLSCertificates() + if err != nil { + return nil, trace.Wrap(err) + } + var databases []tlsca.RouteToDatabase + for _, cert := range dbCerts { + tlsID, err := tlsca.FromSubject(cert.Subject, time.Time{}) + if err != nil { + return nil, trace.Wrap(err) + } + // If the cert expiration time is less than 5s consider cert as expired and don't add + // it to the user profile as an active database. + if time.Until(cert.NotAfter) < 5*time.Second { + continue + } + if tlsID.RouteToDatabase.ServiceName != "" { + databases = append(databases, tlsID.RouteToDatabase) + } + } + return databases, nil +} diff --git a/lib/client/db/profile.go b/lib/client/db/profile.go index 7fd29f90dbc87..978a826c2971e 100644 --- a/lib/client/db/profile.go +++ b/lib/client/db/profile.go @@ -86,7 +86,7 @@ func New(tc *client.TeleportClient, db tlsca.RouteToDatabase, clientProfile clie Database: db.Database, Insecure: tc.InsecureSkipVerify, CACertPath: clientProfile.CACertPath(), - CertPath: clientProfile.DatabaseCertPath(db.ServiceName), + CertPath: clientProfile.DatabaseCertPathForCluster(tc.SiteName, db.ServiceName), KeyPath: clientProfile.KeyPath(), } } diff --git a/lib/client/db/profile_test.go b/lib/client/db/profile_test.go index 773663b1b4219..16b073b2839ba 100644 --- a/lib/client/db/profile_test.go +++ b/lib/client/db/profile_test.go @@ -112,7 +112,7 @@ func TestAddProfile(t *testing.T) { Host: test.profileHostOut, Port: test.profilePortOut, CACertPath: ps.CACertPath(), - CertPath: ps.DatabaseCertPath(db.ServiceName), + CertPath: ps.DatabaseCertPathForCluster(tc.SiteName, db.ServiceName), KeyPath: ps.KeyPath(), }, actual) }) diff --git a/lib/client/keyagent.go b/lib/client/keyagent.go index e1dbc0fdf757d..c69e7c7b42263 100644 --- a/lib/client/keyagent.go +++ b/lib/client/keyagent.go @@ -437,8 +437,27 @@ func (a *LocalKeyAgent) defaultHostPromptFunc(host string, key ssh.PublicKey, wr // AddKey activates a new signed session key by adding it into the keystore and also // by loading it into the SSH agent. func (a *LocalKeyAgent) AddKey(key *Key) (*agent.AddedKey, error) { + if err := a.addKey(key); err != nil { + return nil, trace.Wrap(err) + } + + // Load key into the teleport agent and system agent. + return a.LoadKey(*key) +} + +// AddDatabaseKey activates a new signed database key by adding it into the keystore. +// key must contain at least one db cert. ssh cert is not required. +func (a *LocalKeyAgent) AddDatabaseKey(key *Key) error { + if len(key.DBTLSCerts) == 0 { + return trace.BadParameter("key must contains at least one database access certificate") + } + return a.addKey(key) +} + +// addKey activates a new signed session key by adding it into the keystore. +func (a *LocalKeyAgent) addKey(key *Key) error { if key == nil { - return nil, trace.BadParameter("key is nil") + return trace.BadParameter("key is nil") } if key.ProxyHost == "" { key.ProxyHost = a.proxyHost @@ -453,23 +472,22 @@ func (a *LocalKeyAgent) AddKey(key *Key) (*agent.AddedKey, error) { storedKey, err := a.keyStore.GetKey(key.KeyIndex) if err != nil { if !trace.IsNotFound(err) { - return nil, trace.Wrap(err) + return trace.Wrap(err) } } else { if subtle.ConstantTimeCompare(storedKey.Priv, key.Priv) == 0 { a.log.Debugf("Deleting obsolete stored key with index %+v.", storedKey.KeyIndex) if err := a.keyStore.DeleteKey(storedKey.KeyIndex); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } } } // Save the new key to the keystore (usually into ~/.tsh). if err := a.keyStore.AddKey(key); err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } - // Load key into the teleport agent and system agent. - return a.LoadKey(*key) + return nil } // DeleteKey removes the key with all its certs from the key store diff --git a/lib/client/keyagent_test.go b/lib/client/keyagent_test.go index 94ad023443fc8..e5a0a20555a4a 100644 --- a/lib/client/keyagent_test.go +++ b/lib/client/keyagent_test.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keypaths" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" @@ -54,6 +55,7 @@ type KeyAgentTestSuite struct { hostname string clusterName string tlsca *tlsca.CertAuthority + tlscaCert auth.TrustedCerts } func makeSuite(t *testing.T) *KeyAgentTestSuite { @@ -72,7 +74,7 @@ func makeSuite(t *testing.T) *KeyAgentTestSuite { pemBytes, ok := fixtures.PEMBytes["rsa"] require.True(t, ok) - s.tlsca, _, err = newSelfSignedCA(pemBytes) + s.tlsca, s.tlscaCert, err = newSelfSignedCA(pemBytes) require.NoError(t, err) s.key, err = s.makeKey(s.username, []string{s.username}, 1*time.Minute) @@ -431,6 +433,38 @@ func TestDefaultHostPromptFunc(t *testing.T) { } } +func TestLocalKeyAgent_AddDatabaseKey(t *testing.T) { + s := makeSuite(t) + + // make a new local agent + keystore, err := NewFSLocalKeyStore(s.keyDir) + require.NoError(t, err) + lka, err := NewLocalAgent( + LocalAgentConfig{ + Keystore: keystore, + ProxyHost: s.hostname, + Username: s.username, + KeysOption: AddKeysToAgentAuto, + }) + require.NoError(t, err) + + t.Run("no database cert", func(t *testing.T) { + require.Error(t, lka.AddDatabaseKey(s.key)) + }) + + t.Run("success", func(t *testing.T) { + // modify key to have db cert + addKey := *s.key + addKey.DBTLSCerts = map[string][]byte{"some-db": addKey.TLSCert} + require.NoError(t, lka.SaveTrustedCerts([]auth.TrustedCerts{s.tlscaCert})) + require.NoError(t, lka.AddDatabaseKey(&addKey)) + + getKey, err := lka.GetKey(addKey.ClusterName, WithDBCerts{}) + require.NoError(t, err) + require.Contains(t, getKey.DBTLSCerts, "some-db") + }) +} + func (s *KeyAgentTestSuite) makeKey(username string, allowedLogins []string, ttl time.Duration) (*Key, error) { keygen := testauthority.New() diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 7d499e83d22bc..9ae99650a63a0 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -138,9 +138,12 @@ func (fs *FSLocalKeyStore) AddKey(key *Key) error { } // Store per-cluster key data. - if err := fs.writeBytes(key.Cert, fs.sshCertPath(key.KeyIndex)); err != nil { - return trace.Wrap(err) + if len(key.Cert) > 0 { + if err := fs.writeBytes(key.Cert, fs.sshCertPath(key.KeyIndex)); err != nil { + return trace.Wrap(err) + } } + // TODO(awly): unit test this. for kubeCluster, cert := range key.KubeTLSCerts { // Prevent directory traversal via a crafted kubernetes cluster name. diff --git a/lib/client/keystore_test.go b/lib/client/keystore_test.go index 4b2861e767a4e..b896183a6e3a9 100644 --- a/lib/client/keystore_test.go +++ b/lib/client/keystore_test.go @@ -353,6 +353,27 @@ hRdXE63PXwAfzj0P/H4qWsFfwdeCo/fuIQIDAQAB require.NoError(t, err) } +func TestAddKey_withoutSSHCert(t *testing.T) { + s, cleanup := newTest(t) + defer cleanup() + + // without ssh cert, db certs only + idx := KeyIndex{"host.a", "bob", "root"} + key := s.makeSignedKey(t, idx, false) + key.Cert = nil + require.NoError(t, s.addKey(key)) + + // ssh cert path should NOT exist + sshCertPath := s.store.sshCertPath(key.KeyIndex) + _, err := os.Stat(sshCertPath) + require.ErrorIs(t, err, os.ErrNotExist) + + // check db certs + keyCopy, err := s.store.GetKey(idx, WithDBCerts{}) + require.NoError(t, err) + require.Len(t, keyCopy.DBTLSCerts, 1) +} + type keyStoreTest struct { storeDir string store *FSLocalKeyStore diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index 810f2fc7113f6..255d72d14b517 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -175,6 +175,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con appStartCmd.Flag("name", "Name of the application to start.").StringVar(&ccf.AppName) appStartCmd.Flag("uri", "Internal address of the application to proxy.").StringVar(&ccf.AppURI) appStartCmd.Flag("public-addr", "Public address of the application to proxy.").StringVar(&ccf.AppPublicAddr) + appStartCmd.Flag("diag-addr", "Start diagnostic prometheus and healthz endpoint.").Hidden().StringVar(&ccf.DiagnosticAddr) appStartCmd.Alias(appUsageExamples) // We're using "alias" section to display usage examples. // "teleport db" command and its subcommands @@ -200,6 +201,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con dbStartCmd.Flag("aws-rds-cluster-id", "(Only for Aurora) Aurora cluster identifier.").StringVar(&ccf.DatabaseAWSRDSClusterID) dbStartCmd.Flag("gcp-project-id", "(Only for Cloud SQL) GCP Cloud SQL project identifier.").StringVar(&ccf.DatabaseGCPProjectID) dbStartCmd.Flag("gcp-instance-id", "(Only for Cloud SQL) GCP Cloud SQL instance identifier.").StringVar(&ccf.DatabaseGCPInstanceID) + dbStartCmd.Flag("diag-addr", "Start diagnostic prometheus and healthz endpoint.").Hidden().StringVar(&ccf.DiagnosticAddr) dbStartCmd.Alias(dbUsageExamples) // We're using "alias" section to display usage examples. // define a hidden 'scp' command (it implements server-side implementation of handling diff --git a/tool/tsh/db.go b/tool/tsh/db.go index a5e1ff996e8b4..da7c23632a559 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -24,7 +24,6 @@ import ( "sort" "strconv" "strings" - "text/template" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" @@ -61,7 +60,12 @@ func onListDatabases(cf *CLIConf) error { sort.Slice(databases, func(i, j int) bool { return databases[i].GetName() < databases[j].GetName() }) - showDatabases(tc.SiteName, databases, profile.Databases, cf.Verbose) + + activeDatabases, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + showDatabases(cf.SiteName, databases, activeDatabases, cf.Verbose) return nil } @@ -116,7 +120,7 @@ func databaseLogin(cf *CLIConf, tc *client.TeleportClient, db tlsca.RouteToDatab }); err != nil { return trace.Wrap(err) } - if _, err = tc.LocalAgent().AddKey(key); err != nil { + if err = tc.LocalAgent().AddDatabaseKey(key); err != nil { return trace.Wrap(err) } @@ -132,7 +136,8 @@ func databaseLogin(cf *CLIConf, tc *client.TeleportClient, db tlsca.RouteToDatab } // Print after-connect message. if !quiet { - return connectMessage.Execute(os.Stdout, db) + fmt.Println(formatDatabaseConnnectMessage(cf.SiteName, db)) + return nil } return nil } @@ -147,12 +152,17 @@ func onDatabaseLogout(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } + activeDatabases, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return trace.Wrap(err) + } + var logout []tlsca.RouteToDatabase // If database name wasn't given on the command line, log out of all. if cf.DatabaseService == "" { - logout = profile.Databases + logout = activeDatabases } else { - for _, db := range profile.Databases { + for _, db := range activeDatabases { if db.ServiceName == cf.DatabaseService { logout = append(logout, db) } @@ -256,7 +266,7 @@ Key: %v `, database.ServiceName, host, port, database.Username, database.Database, profile.CACertPath(), - profile.DatabaseCertPath(database.ServiceName), profile.KeyPath()) + profile.DatabaseCertPathForCluster(tc.SiteName, database.ServiceName), profile.KeyPath()) } return nil } @@ -374,14 +384,19 @@ func getDatabase(cf *CLIConf, tc *client.TeleportClient, dbName string) (types.D } if len(databases) == 0 { return nil, trace.NotFound( - "database %q not found, use 'tsh db ls' to see registered databases", dbName) + "database %q not found, use '%v' to see registered databases", dbName, formatDatabaseListCommand(cf.SiteName)) } return databases[0], nil } func needRelogin(cf *CLIConf, tc *client.TeleportClient, database *tlsca.RouteToDatabase, profile *client.ProfileStatus) (bool, error) { found := false - for _, v := range profile.Databases { + activeDatabases, err := profile.DatabasesForCluster(tc.SiteName) + if err != nil { + return false, trace.Wrap(err) + } + + for _, v := range activeDatabases { if v.ServiceName == database.ServiceName { found = true } @@ -437,19 +452,28 @@ func pickActiveDatabase(cf *CLIConf) (*tlsca.RouteToDatabase, error) { if err != nil { return nil, trace.Wrap(err) } - if len(profile.Databases) == 0 { + activeDatabases, err := profile.DatabasesForCluster(cf.SiteName) + if err != nil { + return nil, trace.Wrap(err) + } + + if len(activeDatabases) == 0 { return nil, trace.NotFound("Please login using 'tsh db login' first") } + name := cf.DatabaseService if name == "" { - services := profile.DatabaseServices() - if len(services) > 1 { + if len(activeDatabases) > 1 { + var services []string + for _, database := range activeDatabases { + services = append(services, database.ServiceName) + } return nil, trace.BadParameter("Multiple databases are available (%v), please specify one using CLI argument", strings.Join(services, ", ")) } - name = services[0] + name = activeDatabases[0].ServiceName } - for _, db := range profile.Databases { + for _, db := range activeDatabases { if db.ServiceName == name { // If database user or name were provided on the CLI, // override the default ones. @@ -502,10 +526,10 @@ func getConnectCommand(cf *CLIConf, tc *client.TeleportClient, profile *client.P return getCockroachCommand(tc, profile, db, host, port, options), nil case defaults.ProtocolMySQL: - return getMySQLCommand(profile, db, options), nil + return getMySQLCommand(tc, profile, db, options), nil case defaults.ProtocolMongoDB: - return getMongoCommand(profile, db, host, port, options), nil + return getMongoCommand(tc, profile, db, host, port, options), nil } return nil, trace.BadParameter("unsupported database protocol: %v", db) @@ -528,8 +552,8 @@ func getCockroachCommand(tc *client.TeleportClient, profile *client.ProfileStatu postgres.GetConnString(dbprofile.New(tc, *db, *profile, host, port))) } -func getMySQLCommand(profile *client.ProfileStatus, db *tlsca.RouteToDatabase, options connectionCommandOpts) *exec.Cmd { - args := []string{fmt.Sprintf("--defaults-group-suffix=_%v-%v", profile.Cluster, db.ServiceName)} +func getMySQLCommand(tc *client.TeleportClient, profile *client.ProfileStatus, db *tlsca.RouteToDatabase, options connectionCommandOpts) *exec.Cmd { + args := []string{fmt.Sprintf("--defaults-group-suffix=_%v-%v", tc.SiteName, db.ServiceName)} if db.Username != "" { args = append(args, "--user", db.Username) } @@ -550,12 +574,12 @@ func getMySQLCommand(profile *client.ProfileStatus, db *tlsca.RouteToDatabase, o return exec.Command(mysqlBin, args...) } -func getMongoCommand(profile *client.ProfileStatus, db *tlsca.RouteToDatabase, host string, port int, options connectionCommandOpts) *exec.Cmd { +func getMongoCommand(tc *client.TeleportClient, profile *client.ProfileStatus, db *tlsca.RouteToDatabase, host string, port int, options connectionCommandOpts) *exec.Cmd { args := []string{ "--host", host, "--port", strconv.Itoa(port), "--ssl", - "--sslPEMKeyFile", profile.DatabaseCertPath(db.ServiceName), + "--sslPEMKeyFile", profile.DatabaseCertPathForCluster(tc.SiteName, db.ServiceName), } if options.caPath != "" { @@ -569,6 +593,41 @@ func getMongoCommand(profile *client.ProfileStatus, db *tlsca.RouteToDatabase, h return exec.Command(mongoBin, args...) } +func formatDatabaseListCommand(clusterFlag string) string { + if clusterFlag == "" { + return "tsh db ls" + } + return fmt.Sprintf("tsh db ls --cluster=%v", clusterFlag) +} + +func formatDatabaseConfigCommand(clusterFlag string, db tlsca.RouteToDatabase) string { + if clusterFlag == "" { + return fmt.Sprintf("tsh db config --format=cmd %v", db.ServiceName) + } + return fmt.Sprintf("tsh db config --cluster=%v --format=cmd %v", clusterFlag, db.ServiceName) +} + +func formatDatabaseConnnectMessage(clusterFlag string, db tlsca.RouteToDatabase) string { + connectCommand := formatConnectCommand(clusterFlag, db) + configCommand := formatDatabaseConfigCommand(clusterFlag, db) + + return fmt.Sprintf(` +Connection information for database "%v" has been saved. + +You can now connect to it using the following command: + + %v + +Or view the connect command for the native database CLI client: + + %v + +`, + db.ServiceName, + utils.Color(utils.Yellow, connectCommand), + utils.Color(utils.Yellow, configCommand)) +} + const ( // dbFormatText prints database configuration in text format. dbFormatText = "text" @@ -586,19 +645,3 @@ const ( // mongoBin is the Mongo client binary name. mongoBin = "mongo" ) - -// connectMessage is printed after successful login to a database. -var connectMessage = template.Must(template.New("").Parse(fmt.Sprintf(` -Connection information for database "{{.ServiceName}}" has been saved. - -You can now connect to it using the following command: - - %v - -Or view the connect command for the native database CLI client: - - %v - -`, - utils.Color(utils.Yellow, "tsh db connect {{.ServiceName}}"), - utils.Color(utils.Yellow, "tsh db config --format=cmd {{.ServiceName}}")))) diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index 171ab6f9d5f22..4bd097b8cba10 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -24,12 +24,12 @@ import ( "time" apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/service" + "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" @@ -38,10 +38,7 @@ import ( // TestDatabaseLogin verifies "tsh db login" command. func TestDatabaseLogin(t *testing.T) { - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + tmpHomePath := t.TempDir() connector := mockConnector(t) @@ -69,24 +66,24 @@ func TestDatabaseLogin(t *testing.T) { // Log into Teleport cluster. err = Run([]string{ "login", "--insecure", "--debug", "--auth", connector.GetName(), "--proxy", proxyAddr.String(), - }, cliOption(func(cf *CLIConf) error { + }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) return nil })) require.NoError(t, err) // Fetch the active profile. - profile, err := client.StatusFor("", proxyAddr.Host(), alice.GetName()) + profile, err := client.StatusFor(tmpHomePath, proxyAddr.Host(), alice.GetName()) require.NoError(t, err) // Log into test Postgres database. err = Run([]string{ "db", "login", "--debug", "postgres", - }) + }, setHomePath(tmpHomePath)) require.NoError(t, err) // Verify Postgres identity file contains certificate. - certs, keys, err := decodePEM(profile.DatabaseCertPath("postgres")) + certs, keys, err := decodePEM(profile.DatabaseCertPathForCluster("", "postgres")) require.NoError(t, err) require.Len(t, certs, 1) require.Len(t, keys, 0) @@ -94,16 +91,40 @@ func TestDatabaseLogin(t *testing.T) { // Log into test Mongo database. err = Run([]string{ "db", "login", "--debug", "--db-user", "admin", "mongo", - }) + }, setHomePath(tmpHomePath)) require.NoError(t, err) // Verify Mongo identity file contains both certificate and key. - certs, keys, err = decodePEM(profile.DatabaseCertPath("mongo")) + certs, keys, err = decodePEM(profile.DatabaseCertPathForCluster("", "mongo")) require.NoError(t, err) require.Len(t, certs, 1) require.Len(t, keys, 1) } +func TestFormatDatabaseListCommand(t *testing.T) { + t.Run("default", func(t *testing.T) { + require.Equal(t, "tsh db ls", formatDatabaseListCommand("")) + }) + + t.Run("with cluster flag", func(t *testing.T) { + require.Equal(t, "tsh db ls --cluster=leaf", formatDatabaseListCommand("leaf")) + }) +} + +func TestFormatConfigCommand(t *testing.T) { + db := tlsca.RouteToDatabase{ + ServiceName: "example-db", + } + + t.Run("default", func(t *testing.T) { + require.Equal(t, "tsh db config --format=cmd example-db", formatDatabaseConfigCommand("", db)) + }) + + t.Run("with cluster flag", func(t *testing.T) { + require.Equal(t, "tsh db config --cluster=leaf --format=cmd example-db", formatDatabaseConfigCommand("leaf", db)) + }) +} + func makeTestDatabaseServer(t *testing.T, auth *service.TeleportProcess, proxy *service.TeleportProcess, dbs ...service.Database) (db *service.TeleportProcess) { // Proxy uses self-signed certificates in tests. lib.SetInsecureDevMode(true) diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index d89647b1ed561..1ef8c6edc3b66 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -179,7 +179,7 @@ func onProxyCommandDB(cf *CLIConf) error { "database": database.ServiceName, "address": listener.Addr().String(), "ca": profile.CACertPath(), - "cert": profile.DatabaseCertPath(database.ServiceName), + "cert": profile.DatabaseCertPathForCluster(cf.SiteName, database.ServiceName), "key": profile.KeyPath(), }) if err != nil { diff --git a/tool/tsh/proxy_test.go b/tool/tsh/proxy_test.go index 6baa941636409..ab8af3bee64e7 100644 --- a/tool/tsh/proxy_test.go +++ b/tool/tsh/proxy_test.go @@ -49,10 +49,7 @@ import ( func TestProxySSHDial(t *testing.T) { createAgent(t) - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + tmpHomePath := t.TempDir() connector := mockConnector(t) sshLoginRole, err := types.NewRole("ssh-login", types.RoleSpecV4{ @@ -85,7 +82,7 @@ func TestProxySSHDial(t *testing.T) { "--debug", "--auth", connector.GetName(), "--proxy", proxyAddr.String(), - }, func(cf *CLIConf) error { + }, setHomePath(tmpHomePath), func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) return nil }) @@ -100,7 +97,7 @@ func TestProxySSHDial(t *testing.T) { // as communication channels but in unit test there is no easy way to mock this behavior. err = Run([]string{ "proxy", "ssh", unreachableSubsystem, - }) + }, setHomePath(tmpHomePath)) require.Contains(t, err.Error(), "subsystem request failed") } diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 1fa5f61083d97..8d9f0886d19a7 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -412,9 +412,9 @@ func Run(args []string, opts ...cliOption) error { // Databases. db := app.Command("db", "View and control proxied databases.") + db.Flag("cluster", clusterHelp).StringVar(&cf.SiteName) dbList := db.Command("ls", "List all available databases.") dbList.Flag("verbose", "Show extra database fields.").Short('v').BoolVar(&cf.Verbose) - dbList.Flag("cluster", clusterHelp).StringVar(&cf.SiteName) dbLogin := db.Command("login", "Retrieve credentials for a database.") dbLogin.Arg("db", "Database to retrieve credentials for. Can be obtained from 'tsh db ls' output.").Required().StringVar(&cf.DatabaseService) dbLogin.Flag("db-user", "Optional database user to configure as default.").StringVar(&cf.DatabaseUser) @@ -1426,7 +1426,7 @@ func showApps(apps []types.Application, active []tlsca.RouteToApp, verbose bool) } } -func showDatabases(cluster string, databases []types.Database, active []tlsca.RouteToDatabase, verbose bool) { +func showDatabases(clusterFlag string, databases []types.Database, active []tlsca.RouteToDatabase, verbose bool) { if verbose { t := asciitable.MakeTable([]string{"Name", "Description", "Protocol", "Type", "URI", "Labels", "Connect", "Expires"}) for _, database := range databases { @@ -1435,7 +1435,7 @@ func showDatabases(cluster string, databases []types.Database, active []tlsca.Ro for _, a := range active { if a.ServiceName == name { name = formatActiveDB(a) - connect = formatConnectCommand(cluster, a) + connect = formatConnectCommand(clusterFlag, a) } } t.AddRow([]string{ @@ -1458,7 +1458,7 @@ func showDatabases(cluster string, databases []types.Database, active []tlsca.Ro for _, a := range active { if a.ServiceName == name { name = formatActiveDB(a) - connect = formatConnectCommand(cluster, a) + connect = formatConnectCommand(clusterFlag, a) } } t.AddRow([]string{ @@ -1481,16 +1481,21 @@ func formatDatabaseLabels(database types.Database) string { // formatConnectCommand formats an appropriate database connection command // for a user based on the provided database parameters. -func formatConnectCommand(cluster string, active tlsca.RouteToDatabase) string { - switch { - case active.Username != "" && active.Database != "": - return fmt.Sprintf("tsh db connect %v", active.ServiceName) - case active.Username != "": - return fmt.Sprintf("tsh db connect --db-name= %v", active.ServiceName) - case active.Database != "": - return fmt.Sprintf("tsh db connect --db-user= %v", active.ServiceName) +func formatConnectCommand(clusterFlag string, active tlsca.RouteToDatabase) string { + cmdTokens := []string{"tsh", "db", "connect"} + + if clusterFlag != "" { + cmdTokens = append(cmdTokens, fmt.Sprintf("--cluster=%s", clusterFlag)) } - return fmt.Sprintf("tsh db connect --db-user= --db-name= %v", active.ServiceName) + if active.Username == "" { + cmdTokens = append(cmdTokens, "--db-user=") + } + if active.Database == "" { + cmdTokens = append(cmdTokens, "--db-name=") + } + + cmdTokens = append(cmdTokens, active.ServiceName) + return strings.Join(cmdTokens, " ") } func formatActiveDB(active tlsca.RouteToDatabase) string { @@ -2295,7 +2300,7 @@ func reissueWithRequests(cf *CLIConf, tc *client.TeleportClient, reqIDs ...strin if err := tc.ReissueUserCerts(cf.Context, client.CertCacheDrop, params); err != nil { return trace.Wrap(err) } - if err := tc.SaveProfile("", true); err != nil { + if err := tc.SaveProfile(cf.HomePath, true); err != nil { return trace.Wrap(err) } if err := updateKubeConfig(cf, tc, ""); err != nil { diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go index a05d0d447dd4d..98316be0b6b84 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -34,7 +34,6 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" @@ -115,10 +114,7 @@ func (p *cliModules) IsBoringBinary() bool { } func TestFailedLogin(t *testing.T) { - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + tmpHomePath := t.TempDir() connector := mockConnector(t) @@ -139,7 +135,7 @@ func TestFailedLogin(t *testing.T) { "--debug", "--auth", connector.GetName(), "--proxy", proxyAddr.String(), - }, cliOption(func(cf *CLIConf) error { + }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = client.SSOLoginFunc(ssoLogin) return nil })) @@ -147,10 +143,7 @@ func TestFailedLogin(t *testing.T) { } func TestOIDCLogin(t *testing.T) { - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + tmpHomePath := t.TempDir() modules.SetModules(&cliModules{}) @@ -226,7 +219,7 @@ func TestOIDCLogin(t *testing.T) { "--auth", connector.GetName(), "--proxy", proxyAddr.String(), "--user", "alice", // explicitly use wrong name - }, cliOption(func(cf *CLIConf) error { + }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) cf.SiteName = "localhost" return nil @@ -245,10 +238,7 @@ func TestOIDCLogin(t *testing.T) { // TestLoginIdentityOut makes sure that "tsh login --out " command // writes identity credentials to the specified path. func TestLoginIdentityOut(t *testing.T) { - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + tmpHomePath := t.TempDir() connector := mockConnector(t) @@ -273,7 +263,7 @@ func TestLoginIdentityOut(t *testing.T) { "--auth", connector.GetName(), "--proxy", proxyAddr.String(), "--out", identPath, - }, cliOption(func(cf *CLIConf) error { + }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) return nil })) @@ -284,10 +274,7 @@ func TestLoginIdentityOut(t *testing.T) { } func TestRelogin(t *testing.T) { - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + tmpHomePath := t.TempDir() connector := mockConnector(t) @@ -309,7 +296,7 @@ func TestRelogin(t *testing.T) { "--debug", "--auth", connector.GetName(), "--proxy", proxyAddr.String(), - }, cliOption(func(cf *CLIConf) error { + }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) return nil })) @@ -321,10 +308,10 @@ func TestRelogin(t *testing.T) { "--debug", "--proxy", proxyAddr.String(), "localhost", - }) + }, setHomePath(tmpHomePath)) require.NoError(t, err) - err = Run([]string{"logout"}) + err = Run([]string{"logout"}, setHomePath(tmpHomePath)) require.NoError(t, err) err = Run([]string{ @@ -334,7 +321,7 @@ func TestRelogin(t *testing.T) { "--auth", connector.GetName(), "--proxy", proxyAddr.String(), "localhost", - }, cliOption(func(cf *CLIConf) error { + }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) return nil })) @@ -342,12 +329,8 @@ func TestRelogin(t *testing.T) { } func TestMakeClient(t *testing.T) { - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) - var conf CLIConf + conf.HomePath = t.TempDir() // empty config won't work: tc, err := makeClient(&conf, true) @@ -452,10 +435,7 @@ func TestMakeClient(t *testing.T) { } func TestAccessRequestOnLeaf(t *testing.T) { - os.RemoveAll(profile.FullProfilePath("")) - t.Cleanup(func() { - os.RemoveAll(profile.FullProfilePath("")) - }) + tmpHomePath := t.TempDir() isInsecure := lib.IsInsecureDevMode() lib.SetInsecureDevMode(true) @@ -513,7 +493,7 @@ func TestAccessRequestOnLeaf(t *testing.T) { "--debug", "--auth", connector.GetName(), "--proxy", rootProxyAddr.String(), - }, cliOption(func(cf *CLIConf) error { + }, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error { cf.mockSSOLogin = mockSSOLogin(t, rootAuthServer, alice) return nil })) @@ -525,7 +505,7 @@ func TestAccessRequestOnLeaf(t *testing.T) { "--debug", "--proxy", rootProxyAddr.String(), "leafcluster", - }) + }, setHomePath(tmpHomePath)) require.NoError(t, err) err = Run([]string{ @@ -534,7 +514,7 @@ func TestAccessRequestOnLeaf(t *testing.T) { "--debug", "--proxy", rootProxyAddr.String(), "localhost", - }) + }, setHomePath(tmpHomePath)) require.NoError(t, err) err = Run([]string{ @@ -543,7 +523,7 @@ func TestAccessRequestOnLeaf(t *testing.T) { "--debug", "--proxy", rootProxyAddr.String(), "leafcluster", - }) + }, setHomePath(tmpHomePath)) require.NoError(t, err) errChan := make(chan error) @@ -555,7 +535,7 @@ func TestAccessRequestOnLeaf(t *testing.T) { "--debug", "--proxy", rootProxyAddr.String(), "--roles=access", - }) + }, setHomePath(tmpHomePath)) }() var request types.AccessRequest @@ -784,11 +764,11 @@ func TestOptions(t *testing.T) { } func TestFormatConnectCommand(t *testing.T) { - cluster := "root" tests := []struct { - comment string - db tlsca.RouteToDatabase - command string + clusterFlag string + comment string + db tlsca.RouteToDatabase + command string }{ { comment: "no default user/database are specified", @@ -826,10 +806,20 @@ func TestFormatConnectCommand(t *testing.T) { }, command: `tsh db connect test`, }, + { + comment: "extra cluster flag", + clusterFlag: "leaf", + db: tlsca.RouteToDatabase{ + ServiceName: "test", + Protocol: defaults.ProtocolPostgres, + Database: "postgres", + }, + command: `tsh db connect --cluster=leaf --db-user= test`, + }, } for _, test := range tests { t.Run(test.comment, func(t *testing.T) { - require.Equal(t, test.command, formatConnectCommand(cluster, test.db)) + require.Equal(t, test.command, formatConnectCommand(test.clusterFlag, test.db)) }) } } @@ -1259,3 +1249,10 @@ func mockSSOLogin(t *testing.T, authServer *auth.Server, user types.User) client }, nil } } + +func setHomePath(path string) cliOption { + return func(cf *CLIConf) error { + cf.HomePath = path + return nil + } +}