Skip to content

Commit

Permalink
Pipe terminal stdin to session in kubernetes peer mode (#11288) (#11918)
Browse files Browse the repository at this point in the history
  • Loading branch information
xacrimon authored Apr 20, 2022
1 parent 936047d commit e70e6c8
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 66 deletions.
36 changes: 22 additions & 14 deletions integration/kube_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ import (
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/services"
Expand All @@ -47,7 +47,6 @@ import (
"github.com/gravitational/teleport/lib/utils"

apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/lib/kube/proxy/streamproto"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -1481,29 +1480,25 @@ func kubeExec(kubeConfig *rest.Config, args kubeExecArgs) error {
return executor.Stream(opts)
}

func kubeJoin(kubeConfig kubeProxyConfig, sessionID string) (*streamproto.SessionStream, error) {
func kubeJoin(kubeConfig kubeProxyConfig, tc *client.TeleportClient, sessionID string) (*client.KubeSession, error) {
tlsConfig, err := kubeProxyTLSConfig(kubeConfig)
if err != nil {
return nil, trace.Wrap(err)
}

dialer := &websocket.Dialer{
TLSClientConfig: tlsConfig,
}

endpoint := "wss://" + kubeConfig.t.Config.Proxy.Kube.ListenAddr.Addr + "/api/v1/teleport/join/" + sessionID
ws, resp, err := dialer.Dial(endpoint, nil)
meta, err := types.NewSessionTracker(types.SessionTrackerSpecV1{
SessionID: sessionID,
})
if err != nil {
return nil, trace.Wrap(err)
}
defer resp.Body.Close()

stream, err := streamproto.NewSessionStream(ws, streamproto.ClientHandshake{Mode: types.SessionObserverMode})
sess, err := client.NewKubeSession(context.TODO(), tc, meta, kubeConfig.t.Config.Proxy.Kube.ListenAddr.Addr, "", types.SessionPeerMode, tlsConfig)
if err != nil {
return nil, trace.Wrap(err)
}

return stream, nil
return sess, nil
}

// testKubeJoin tests that that joining an interactive exec session works.
Expand Down Expand Up @@ -1579,12 +1574,21 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
// created. Sadly though the k8s API doesn't give us much indication of when that is.
time.Sleep(time.Second * 5)

participantStdinR, participantStdinW, err := os.Pipe()
participantStdoutR, participantStdoutW, err := os.Pipe()

tc, err := teleport.NewClient(ClientConfig{})
require.NoError(t, err)

tc.Stdin = participantStdinR
tc.Stdout = participantStdoutW

stream, err := kubeJoin(kubeProxyConfig{
t: teleport,
username: participantUsername,
kubeUsers: kubeUsers,
kubeGroups: kubeGroups,
}, "")
}, tc, "")
require.NoError(t, err)
defer stream.Close()

Expand All @@ -1593,13 +1597,17 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
// new IO streams of the second client.
time.Sleep(time.Second * 5)

// sent a test message from the participant
participantStdinW.WriteString("\aecho hi2\n\r")

// lets type "echo hi" followed by "enter" and then "exit" + "enter":
term.Type("\aecho hi\n\r")

// Terminate the session after a moment to allow for the IO to reach the second client.
time.AfterFunc(5*time.Second, func() { term.Type("\aexit\n\r\a") })

participantOutput, err := io.ReadAll(stream)
participantOutput, err := io.ReadAll(participantStdoutR)
require.NoError(t, err)
require.Contains(t, participantOutput, []byte("echo hi"))
require.Contains(t, out.String(), []byte("echo hi2"))
}
27 changes: 14 additions & 13 deletions lib/client/kubesession.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package client

import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
Expand Down Expand Up @@ -46,15 +47,9 @@ type KubeSession struct {
}

// NewKubeSession joins a live kubernetes session.
func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionTracker, key *Key, kubeAddr string, tlsServer string, mode types.SessionParticipantMode) (*KubeSession, error) {
func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionTracker, kubeAddr string, tlsServer string, mode types.SessionParticipantMode, tlsConfig *tls.Config) (*KubeSession, error) {
closeWait := &sync.WaitGroup{}
joinEndpoint := "wss://" + kubeAddr + "/api/v1/teleport/join/" + meta.GetSessionID()
kubeCluster := meta.GetKubeCluster()
ciphers := utils.DefaultCipherSuites()
tlsConfig, err := key.KubeClientTLSConfig(ciphers, kubeCluster)
if err != nil {
return nil, trace.Wrap(err)
}

if tlsServer != "" {
tlsConfig.ServerName = tlsServer
Expand Down Expand Up @@ -201,12 +196,18 @@ func (s *KubeSession) pipeInOut(stdout io.Writer, mode types.SessionParticipantM
go func() {
defer s.cancel()

handleNonPeerControls(mode, s.term, func() {
err := s.stream.ForceTerminate()
if err != nil {
fmt.Printf("\n\rError while sending force termination request: %v\n\r", err.Error())
}
})
switch mode {
case types.SessionPeerMode:
handlePeerControls(s.term, s.stream)
default:
handleNonPeerControls(mode, s.term, func() {
err := s.stream.ForceTerminate()
if err != nil {
log.Debugf("Error sending force termination request: %v", err)
fmt.Print("\n\rError while sending force termination request\n\r")
}
})
}
}()
}

Expand Down
63 changes: 30 additions & 33 deletions lib/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,35 @@ func handleNonPeerControls(mode types.SessionParticipantMode, term *terminal.Ter
}
}

// handlePeerControls streams the terminal input to the remote shell's standard input.
// Escape sequences for stopping the stream on the client side are supported via `escape.NewReader`.
func handlePeerControls(term *terminal.Terminal, remoteStdin io.Writer) {
stdin := term.Stdin()
if term.IsAttached() {
// escape.NewReader is used to enable manual disconnect sequences as those supported
// by tsh. These can be used to force a client disconnect since CTRL-C is merely passed
// to the other end and not interpreted as an exit request locally
stdin = escape.NewReader(stdin, term.Stderr(), func(err error) {
log.Debugf("escape.NewReader error: %v", err)

switch err {
case escape.ErrDisconnect:
fmt.Fprint(term.Stderr(), "\r\nDisconnected\r\n")
case escape.ErrTooMuchBufferedData:
fmt.Fprint(term.Stderr(), "\r\nRemote peer may be unreachable, check your connectivity\r\n")
default:
fmt.Fprintf(term.Stderr(), "\r\nunknown error: %v\r\n", err.Error())
}
})
}

_, err := io.Copy(remoteStdin, stdin)
if err != nil {
log.Debugf("Error copying data to remote peer: %v", err)
fmt.Fprint(term.Stderr(), "\r\nError copying data to remote peer\r\n")
}
}

// pipeInOut launches two goroutines: one to pipe the local input into the remote shell,
// and another to pipe the output of the remote shell into the local output
func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser, mode types.SessionParticipantMode, sess *ssh.Session) {
Expand Down Expand Up @@ -671,39 +700,7 @@ func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser, mode types.SessionPar
case types.SessionPeerMode:
// copy from the local input to the remote shell:
go func() {
defer ns.closer.Close()
buf := make([]byte, 1024)

stdin := ns.terminal.Stdin()
if ns.terminal.IsAttached() && ns.enableEscapeSequences {
stdin = escape.NewReader(stdin, ns.terminal.Stderr(), func(err error) {
switch err {
case escape.ErrDisconnect:
fmt.Fprintf(ns.terminal.Stderr(), "\r\n%v\r\n", err)
case escape.ErrTooMuchBufferedData:
fmt.Fprintf(ns.terminal.Stderr(), "\r\nerror: %v\r\nremote peer may be unreachable, check your connectivity\r\n", trace.Wrap(err))
default:
fmt.Fprintf(ns.terminal.Stderr(), "\r\nerror: %v\r\n", trace.Wrap(err))
}
ns.closer.Close()
})
}

for {
n, err := stdin.Read(buf)
if n > 0 {
_, err = shell.Write(buf[:n])
if err != nil {
ns.ExitMsg = err.Error()
return
}
}

if err != nil {
fmt.Fprintf(ns.terminal.Stderr(), "\r\n%v\r\n", trace.Wrap(err))
return
}
}
handlePeerControls(ns.terminal, shell)
}()
}
}
Expand Down
5 changes: 0 additions & 5 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -801,11 +801,6 @@ func (f *Forwarder) join(ctx *authContext, w http.ResponseWriter, req *http.Requ
}

<-party.closeC

if err != nil {
return nil, trace.Wrap(err)
}

return nil, nil
}

Expand Down
8 changes: 7 additions & 1 deletion tool/tsh/kube.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,13 @@ func (c *kubeJoinCommand) run(cf *CLIConf) error {
return trace.Wrap(err)
}

session, err := client.NewKubeSession(cf.Context, tc, meta, k, tc.KubeProxyAddr, kubeStatus.tlsServerName, types.SessionParticipantMode(c.mode))
ciphers := utils.DefaultCipherSuites()
tlsConfig, err := k.KubeClientTLSConfig(ciphers, kubeCluster)
if err != nil {
return trace.Wrap(err)
}

session, err := client.NewKubeSession(cf.Context, tc, meta, tc.KubeProxyAddr, kubeStatus.tlsServerName, types.SessionParticipantMode(c.mode), tlsConfig)
if err != nil {
return trace.Wrap(err)
}
Expand Down

0 comments on commit e70e6c8

Please sign in to comment.