diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index dfcf34d2e7acd..5e0797d955a0f 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -33,12 +33,12 @@ import ( "testing" "time" - "github.com/gorilla/websocket" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/services" @@ -47,7 +47,6 @@ import ( "github.com/gravitational/teleport/lib/utils" apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/lib/kube/proxy/streamproto" kubeutils "github.com/gravitational/teleport/lib/kube/utils" "github.com/gravitational/trace" @@ -1481,29 +1480,25 @@ func kubeExec(kubeConfig *rest.Config, args kubeExecArgs) error { return executor.Stream(opts) } -func kubeJoin(kubeConfig kubeProxyConfig, sessionID string) (*streamproto.SessionStream, error) { +func kubeJoin(kubeConfig kubeProxyConfig, tc *client.TeleportClient, sessionID string) (*client.KubeSession, error) { tlsConfig, err := kubeProxyTLSConfig(kubeConfig) if err != nil { return nil, trace.Wrap(err) } - dialer := &websocket.Dialer{ - TLSClientConfig: tlsConfig, - } - - endpoint := "wss://" + kubeConfig.t.Config.Proxy.Kube.ListenAddr.Addr + "/api/v1/teleport/join/" + sessionID - ws, resp, err := dialer.Dial(endpoint, nil) + meta, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ + SessionID: sessionID, + }) if err != nil { return nil, trace.Wrap(err) } - defer resp.Body.Close() - stream, err := streamproto.NewSessionStream(ws, streamproto.ClientHandshake{Mode: types.SessionObserverMode}) + sess, err := client.NewKubeSession(context.TODO(), tc, meta, kubeConfig.t.Config.Proxy.Kube.ListenAddr.Addr, "", types.SessionPeerMode, tlsConfig) if err != nil { return nil, trace.Wrap(err) } - return stream, nil + return sess, nil } // testKubeJoin tests that that joining an interactive exec session works. @@ -1579,12 +1574,21 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) { // created. Sadly though the k8s API doesn't give us much indication of when that is. time.Sleep(time.Second * 5) + participantStdinR, participantStdinW, err := os.Pipe() + participantStdoutR, participantStdoutW, err := os.Pipe() + + tc, err := teleport.NewClient(ClientConfig{}) + require.NoError(t, err) + + tc.Stdin = participantStdinR + tc.Stdout = participantStdoutW + stream, err := kubeJoin(kubeProxyConfig{ t: teleport, username: participantUsername, kubeUsers: kubeUsers, kubeGroups: kubeGroups, - }, "") + }, tc, "") require.NoError(t, err) defer stream.Close() @@ -1593,13 +1597,17 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) { // new IO streams of the second client. time.Sleep(time.Second * 5) + // sent a test message from the participant + participantStdinW.WriteString("\aecho hi2\n\r") + // lets type "echo hi" followed by "enter" and then "exit" + "enter": term.Type("\aecho hi\n\r") // Terminate the session after a moment to allow for the IO to reach the second client. time.AfterFunc(5*time.Second, func() { term.Type("\aexit\n\r\a") }) - participantOutput, err := io.ReadAll(stream) + participantOutput, err := io.ReadAll(participantStdoutR) require.NoError(t, err) require.Contains(t, participantOutput, []byte("echo hi")) + require.Contains(t, out.String(), []byte("echo hi2")) } diff --git a/lib/client/kubesession.go b/lib/client/kubesession.go index b29d9c542ac60..1b07c0a41d29d 100644 --- a/lib/client/kubesession.go +++ b/lib/client/kubesession.go @@ -18,6 +18,7 @@ package client import ( "context" + "crypto/tls" "fmt" "io" "sync" @@ -46,15 +47,9 @@ type KubeSession struct { } // NewKubeSession joins a live kubernetes session. -func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionTracker, key *Key, kubeAddr string, tlsServer string, mode types.SessionParticipantMode) (*KubeSession, error) { +func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionTracker, kubeAddr string, tlsServer string, mode types.SessionParticipantMode, tlsConfig *tls.Config) (*KubeSession, error) { closeWait := &sync.WaitGroup{} joinEndpoint := "wss://" + kubeAddr + "/api/v1/teleport/join/" + meta.GetSessionID() - kubeCluster := meta.GetKubeCluster() - ciphers := utils.DefaultCipherSuites() - tlsConfig, err := key.KubeClientTLSConfig(ciphers, kubeCluster) - if err != nil { - return nil, trace.Wrap(err) - } if tlsServer != "" { tlsConfig.ServerName = tlsServer @@ -201,12 +196,18 @@ func (s *KubeSession) pipeInOut(stdout io.Writer, mode types.SessionParticipantM go func() { defer s.cancel() - handleNonPeerControls(mode, s.term, func() { - err := s.stream.ForceTerminate() - if err != nil { - fmt.Printf("\n\rError while sending force termination request: %v\n\r", err.Error()) - } - }) + switch mode { + case types.SessionPeerMode: + handlePeerControls(s.term, s.stream) + default: + handleNonPeerControls(mode, s.term, func() { + err := s.stream.ForceTerminate() + if err != nil { + log.Debugf("Error sending force termination request: %v", err) + fmt.Print("\n\rError while sending force termination request\n\r") + } + }) + } }() } diff --git a/lib/client/session.go b/lib/client/session.go index e690f7ccdfa59..000140077b030 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -644,6 +644,35 @@ func handleNonPeerControls(mode types.SessionParticipantMode, term *terminal.Ter } } +// handlePeerControls streams the terminal input to the remote shell's standard input. +// Escape sequences for stopping the stream on the client side are supported via `escape.NewReader`. +func handlePeerControls(term *terminal.Terminal, remoteStdin io.Writer) { + stdin := term.Stdin() + if term.IsAttached() { + // escape.NewReader is used to enable manual disconnect sequences as those supported + // by tsh. These can be used to force a client disconnect since CTRL-C is merely passed + // to the other end and not interpreted as an exit request locally + stdin = escape.NewReader(stdin, term.Stderr(), func(err error) { + log.Debugf("escape.NewReader error: %v", err) + + switch err { + case escape.ErrDisconnect: + fmt.Fprint(term.Stderr(), "\r\nDisconnected\r\n") + case escape.ErrTooMuchBufferedData: + fmt.Fprint(term.Stderr(), "\r\nRemote peer may be unreachable, check your connectivity\r\n") + default: + fmt.Fprintf(term.Stderr(), "\r\nunknown error: %v\r\n", err.Error()) + } + }) + } + + _, err := io.Copy(remoteStdin, stdin) + if err != nil { + log.Debugf("Error copying data to remote peer: %v", err) + fmt.Fprint(term.Stderr(), "\r\nError copying data to remote peer\r\n") + } +} + // pipeInOut launches two goroutines: one to pipe the local input into the remote shell, // and another to pipe the output of the remote shell into the local output func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser, mode types.SessionParticipantMode, sess *ssh.Session) { @@ -671,39 +700,7 @@ func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser, mode types.SessionPar case types.SessionPeerMode: // copy from the local input to the remote shell: go func() { - defer ns.closer.Close() - buf := make([]byte, 1024) - - stdin := ns.terminal.Stdin() - if ns.terminal.IsAttached() && ns.enableEscapeSequences { - stdin = escape.NewReader(stdin, ns.terminal.Stderr(), func(err error) { - switch err { - case escape.ErrDisconnect: - fmt.Fprintf(ns.terminal.Stderr(), "\r\n%v\r\n", err) - case escape.ErrTooMuchBufferedData: - fmt.Fprintf(ns.terminal.Stderr(), "\r\nerror: %v\r\nremote peer may be unreachable, check your connectivity\r\n", trace.Wrap(err)) - default: - fmt.Fprintf(ns.terminal.Stderr(), "\r\nerror: %v\r\n", trace.Wrap(err)) - } - ns.closer.Close() - }) - } - - for { - n, err := stdin.Read(buf) - if n > 0 { - _, err = shell.Write(buf[:n]) - if err != nil { - ns.ExitMsg = err.Error() - return - } - } - - if err != nil { - fmt.Fprintf(ns.terminal.Stderr(), "\r\n%v\r\n", trace.Wrap(err)) - return - } - } + handlePeerControls(ns.terminal, shell) }() } } diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 3601f6ebed0d4..f61f72bb71a7d 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -801,11 +801,6 @@ func (f *Forwarder) join(ctx *authContext, w http.ResponseWriter, req *http.Requ } <-party.closeC - - if err != nil { - return nil, trace.Wrap(err) - } - return nil, nil } diff --git a/tool/tsh/kube.go b/tool/tsh/kube.go index ed2968010ef8f..25f249c523edc 100644 --- a/tool/tsh/kube.go +++ b/tool/tsh/kube.go @@ -192,7 +192,13 @@ func (c *kubeJoinCommand) run(cf *CLIConf) error { return trace.Wrap(err) } - session, err := client.NewKubeSession(cf.Context, tc, meta, k, tc.KubeProxyAddr, kubeStatus.tlsServerName, types.SessionParticipantMode(c.mode)) + ciphers := utils.DefaultCipherSuites() + tlsConfig, err := k.KubeClientTLSConfig(ciphers, kubeCluster) + if err != nil { + return trace.Wrap(err) + } + + session, err := client.NewKubeSession(cf.Context, tc, meta, tc.KubeProxyAddr, kubeStatus.tlsServerName, types.SessionParticipantMode(c.mode), tlsConfig) if err != nil { return trace.Wrap(err) }