Skip to content

Commit

Permalink
Add a tsh sessions ls command (#37740)
Browse files Browse the repository at this point in the history
This deprecates `tsh kube sessions` in favor of a command that
can be used to list all types of active sessions.

Closes #19152
  • Loading branch information
zmb3 authored Feb 8, 2024
1 parent b60e7a7 commit ab8a629
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 57 deletions.
59 changes: 3 additions & 56 deletions tool/tsh/common/kube.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ type kubeSessionsCommand struct {

func newKubeSessionsCommand(parent *kingpin.CmdClause) *kubeSessionsCommand {
c := &kubeSessionsCommand{
CmdClause: parent.Command("sessions", "Get a list of active Kubernetes sessions."),
CmdClause: parent.Command("sessions", "Get a list of active Kubernetes sessions. (DEPRECATED: use tsh sessions ls --kind=kube instead)"),
}
c.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&c.format, defaults.DefaultFormats...)
c.Flag("cluster", clusterHelp).Short('c').StringVar(&c.siteName)
Expand All @@ -539,61 +539,8 @@ func (c *kubeSessionsCommand) run(cf *CLIConf) error {
return trace.Wrap(err)
}

filteredSessions := make([]types.SessionTracker, 0)
for _, session := range sessions {
if session.GetSessionKind() == types.KubernetesSessionKind {
filteredSessions = append(filteredSessions, session)
}
}

sort.Slice(filteredSessions, func(i, j int) bool {
return filteredSessions[i].GetCreated().Before(filteredSessions[j].GetCreated())
})

format := strings.ToLower(c.format)
switch format {
case teleport.Text, "":
printSessions(cf.Stdout(), filteredSessions)
case teleport.JSON, teleport.YAML:
out, err := serializeKubeSessions(sessions, format)
if err != nil {
return trace.Wrap(err)
}
fmt.Fprintln(cf.Stdout(), out)
default:
return trace.BadParameter("unsupported format %q", c.format)
}
return nil
}

func serializeKubeSessions(sessions []types.SessionTracker, format string) (string, error) {
var out []byte
var err error
if format == teleport.JSON {
out, err = utils.FastMarshalIndent(sessions, "", " ")
} else {
out, err = yaml.Marshal(sessions)
}
return string(out), trace.Wrap(err)
}

func printSessions(output io.Writer, sessions []types.SessionTracker) {
table := asciitable.MakeTable([]string{"ID", "State", "Created", "Hostname", "Address", "Login", "Reason", "Command"})
for _, s := range sessions {
table.AddRow([]string{
s.GetSessionID(),
s.GetState().String(),
s.GetCreated().Format(time.RFC3339),
s.GetHostname(),
s.GetAddress(),
s.GetLogin(),
s.GetReason(),
strings.Join(s.GetCommand(), " "),
})
}

tableOutput := table.AsBuffer().String()
fmt.Fprintln(output, tableOutput)
filteredSessions := sortAndFilterSessions(sessions, []types.SessionKind{types.KubernetesSessionKind})
return trace.Wrap(serializeSessions(filteredSessions, strings.ToLower(c.format), cf.Stdout()))
}

type kubeCredentialsCommand struct {
Expand Down
98 changes: 98 additions & 0 deletions tool/tsh/common/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ type CLIConf struct {
// JoinMode is the participant mode someone is joining a session as.
JoinMode string

// SessionKinds is the kind of active sessions to list.
SessionKinds []string

// displayParticipantRequirements is set if verbose participant requirement information should be printed for moderated sessions.
displayParticipantRequirements bool

Expand Down Expand Up @@ -944,12 +947,19 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
ls.Flag("search", searchHelp).StringVar(&cf.SearchKeywords)
ls.Flag("query", queryHelp).StringVar(&cf.PredicateExpression)
ls.Flag("all", "List nodes from all clusters and proxies.").Short('R').BoolVar(&cf.ListAll)

// clusters
clusters := app.Command("clusters", "List available Teleport clusters.")
clusters.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...)
clusters.Flag("quiet", "Quiet mode").Short('q').BoolVar(&cf.Quiet)
clusters.Flag("verbose", "Verbose table output, shows full label output").Short('v').BoolVar(&cf.Verbose)

// sessions
sessions := app.Command("sessions", "Operate on active sessions.")
sessionsList := sessions.Command("ls", "List active sessions.")
sessionsList.Flag("format", defaults.FormatFlagDescription(defaults.DefaultFormats...)).Short('f').Default(teleport.Text).EnumVar(&cf.Format, defaults.DefaultFormats...)
sessionsList.Flag("kind", "Filter by session kind(s)").Default("ssh", "k8s", "db", "app", "desktop").EnumsVar(&cf.SessionKinds, "ssh", "k8s", "kube", "db", "app", "desktop")

// login logs in with remote proxy and obtains a "session certificate" which gets
// stored in ~/.tsh directory
login := app.Command("login", "Log in to a cluster and retrieve the session certificate.")
Expand Down Expand Up @@ -1355,6 +1365,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
err = onListNodes(&cf)
case clusters.FullCommand():
err = onListClusters(&cf)
case sessionsList.FullCommand():
err = onListSessions(&cf)
case login.FullCommand():
err = onLogin(&cf)
case logout.FullCommand():
Expand Down Expand Up @@ -3053,6 +3065,92 @@ func onListClusters(cf *CLIConf) error {
return nil
}

func onListSessions(cf *CLIConf) error {
tc, err := makeClient(cf)
if err != nil {
return trace.Wrap(err)
}

clt, err := tc.ConnectToCluster(cf.Context)
if err != nil {
return trace.Wrap(err)
}
defer clt.Close()

sessions, err := clt.AuthClient.GetActiveSessionTrackers(cf.Context)
if err != nil {
return trace.Wrap(err)
}

kinds := map[string]types.SessionKind{
"ssh": types.SSHSessionKind,
"db": types.DatabaseSessionKind,
"app": types.AppSessionKind,
"desktop": types.WindowsDesktopSessionKind,
"k8s": types.KubernetesSessionKind,
// tsh commands often use "kube" to mean kubernetes,
// so add an alias to make it more intuitive
"kube": types.KubernetesSessionKind,
}

var filter []types.SessionKind
for _, k := range cf.SessionKinds {
filter = append(filter, kinds[k])
}
sessions = sortAndFilterSessions(sessions, filter)
return trace.Wrap(serializeSessions(sessions, strings.ToLower(cf.Format), cf.Stdout()))
}

func sortAndFilterSessions(sessions []types.SessionTracker, kinds []types.SessionKind) []types.SessionTracker {
filtered := slices.DeleteFunc(sessions, func(st types.SessionTracker) bool {
return !slices.Contains(kinds, st.GetSessionKind())
})
sort.Slice(filtered, func(i, j int) bool {
return filtered[i].GetCreated().Before(filtered[j].GetCreated())
})
return filtered
}

func serializeSessions(sessions []types.SessionTracker, format string, w io.Writer) error {
switch format {
case teleport.Text, "":
printSessions(w, sessions)
case teleport.JSON:
out, err := utils.FastMarshalIndent(sessions, "", " ")
if err != nil {
return trace.Wrap(err)
}
fmt.Fprintln(w, string(out))
case teleport.YAML:
out, err := yaml.Marshal(sessions)
if err != nil {
return trace.Wrap(err)
}
fmt.Fprintln(w, string(out))
default:
return trace.BadParameter("unsupported format %q", format)
}
return nil
}

func printSessions(output io.Writer, sessions []types.SessionTracker) {
table := asciitable.MakeTable([]string{"ID", "Kind", "Created", "Hostname", "Address", "Login", "Command"})
for _, s := range sessions {
table.AddRow([]string{
s.GetSessionID(),
string(s.GetSessionKind()),
s.GetCreated().Format(time.RFC3339),
s.GetHostname(),
s.GetAddress(),
s.GetLogin(),
strings.Join(s.GetCommand(), " "),
})
}

tableOutput := table.AsBuffer().String()
fmt.Fprintln(output, tableOutput)
}

type clusterInfo struct {
ClusterName string `json:"cluster_name"`
Status string `json:"status"`
Expand Down
4 changes: 3 additions & 1 deletion tool/tsh/common/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4463,7 +4463,9 @@ func TestSerializeKubeSessions(t *testing.T) {
})
require.NoError(t, err)
testSerialization(t, expected, func(f string) (string, error) {
return serializeKubeSessions([]types.SessionTracker{tracker}, f)
var b bytes.Buffer
err := serializeSessions([]types.SessionTracker{tracker}, f, &b)
return b.String(), err
})
}

Expand Down

0 comments on commit ab8a629

Please sign in to comment.