From a5370d58fc3b7eab9939c412376522c2a82387b1 Mon Sep 17 00:00:00 2001 From: Anton Miniailo Date: Tue, 7 Mar 2023 13:55:02 +0000 Subject: [PATCH] Allow node to handle old and new way of client IP propagation on the same listener (#22572) * Allow node to handle old and new way of client IP propagation on same listener With addition of signed PROXY headers, node was listening on multiplexer, but because of that it couldn't processing incoming connection from older proxies when ProxyHelloSignature was used, because both ends were waiting for the other side to send data first. Here we integrate ability to handle PROXY headers into connection itself, so we can start ssh server without waiting for multiplexer to detect connection * Remove unneeded code * Improve comment's wording Co-authored-by: Michael Wilson * Improve comment's wording Co-authored-by: Michael Wilson * Fix imports * Unexport function * Add godoc * Remove unneeded comment * Move ProxyHelloSignature to constants * Check that proxyline is verified * Use Warn() instead of Warnf() * Move ProxyHelloSignature to api/constants * Add timeout for getting host CA during proxyline verification. * Rearrange conditions Co-authored-by: Edoardo Spadolini * Clarify comment. --------- Co-authored-by: Michael Wilson Co-authored-by: Edoardo Spadolini --- api/constants/constants.go | 8 ++ api/observability/tracing/ssh/ssh.go | 3 +- api/utils/sshutils/ssh.go | 9 -- integration/integration_test.go | 5 +- lib/multiplexer/multiplexer.go | 20 ++-- lib/multiplexer/multiplexer_test.go | 132 +++++++++++++-------------- lib/multiplexer/proxyline.go | 22 +++-- lib/multiplexer/proxyline_test.go | 18 ++-- lib/reversetunnel/emit_conn.go | 10 +- lib/service/service.go | 37 ++++---- lib/srv/regular/proxy.go | 9 +- lib/srv/regular/sshserver.go | 17 ++++ lib/sshutils/server.go | 106 ++++++++++++++++++--- lib/sshutils/server_test.go | 124 +++++++++++++++++++++++++ 14 files changed, 374 insertions(+), 146 deletions(-) diff --git a/api/constants/constants.go b/api/constants/constants.go index 7d2b7e9414be1..adfd5e4932ef5 100644 --- a/api/constants/constants.go +++ b/api/constants/constants.go @@ -380,6 +380,14 @@ const ( // allowed GCP service accounts. TraitGCPServiceAccounts = "gcp_service_accounts" ) +const ( + // ProxyHelloSignature is a string which Teleport proxy will send + // right after the initial SSH "handshake/version" message if it detects + // talking to a Teleport server. + // + // This is also leveraged by tsh to propagate its tracing span ID. + ProxyHelloSignature = "Teleport-Proxy" +) const ( // TimeoutGetClusterAlerts is the timeout for grabbing cluster alerts from tctl and tsh diff --git a/api/observability/tracing/ssh/ssh.go b/api/observability/tracing/ssh/ssh.go index cab328923e948..8352366a3578b 100644 --- a/api/observability/tracing/ssh/ssh.go +++ b/api/observability/tracing/ssh/ssh.go @@ -29,6 +29,7 @@ import ( oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/utils/sshutils" ) @@ -136,7 +137,7 @@ func NewClientConn(ctx context.Context, conn net.Conn, addr string, config *ssh. if len(hp.TracingContext) > 0 { payloadJSON, err := json.Marshal(hp) if err == nil { - payload := fmt.Sprintf("%s%s\x00", sshutils.ProxyHelloSignature, payloadJSON) + payload := fmt.Sprintf("%s%s\x00", constants.ProxyHelloSignature, payloadJSON) if _, err := conn.Write([]byte(payload)); err != nil { log.WithError(err).Warnf("Failed to pass along tracing context to proxy %v", addr) } diff --git a/api/utils/sshutils/ssh.go b/api/utils/sshutils/ssh.go index 2cf117831325b..d268d8cd6e538 100644 --- a/api/utils/sshutils/ssh.go +++ b/api/utils/sshutils/ssh.go @@ -32,15 +32,6 @@ import ( "github.com/gravitational/teleport/api/defaults" ) -const ( - // ProxyHelloSignature is a string which Teleport proxy will send - // right after the initial SSH "handshake/version" message if it detects - // talking to a Teleport server. - // - // This is also leveraged by tsh to propagate its tracing span ID. - ProxyHelloSignature = "Teleport-Proxy" -) - // HandshakePayload structure is sent as a JSON blob by the teleport // proxy to every SSH server who identifies itself as Teleport server // diff --git a/integration/integration_test.go b/integration/integration_test.go index 2e24d8e47986b..22867a42b5df0 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -4201,8 +4201,11 @@ func testProxyHostKeyCheck(t *testing.T, suite *integrationTestSuite) { instance := suite.NewTeleportWithConfig(makeConfig()) defer instance.StopAll() + caGetter := func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) { + return instance.Process.GetAuthServer().Cache.GetCertAuthority(ctx, id, loadKeys) + } proxyEnabledListener, err := helpers.CreatePROXYEnabledListener(context.Background(), t, net.JoinHostPort(Host, strconv.Itoa(nodePort)), - instance.Process.GetAuthServer().Cache, instance.Secrets.SiteName) + caGetter, instance.Secrets.SiteName) require.NoError(t, err) sshNode, err := helpers.NewDiscardServer(hostSigner, proxyEnabledListener) diff --git a/lib/multiplexer/multiplexer.go b/lib/multiplexer/multiplexer.go index 59267e6c514d0..c56f8d50818ce 100644 --- a/lib/multiplexer/multiplexer.go +++ b/lib/multiplexer/multiplexer.go @@ -39,11 +39,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/jwt" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -54,9 +53,8 @@ var ( ) // CertAuthorityGetter allows to get cluster's host CA for verification of signed PROXY headers. -type CertAuthorityGetter interface { - GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) -} +// We define our own version to not create dependency on the 'services' package, which causes circular references +type CertAuthorityGetter = func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) // Config is a multiplexer config type Config struct { @@ -422,7 +420,7 @@ func (m *Mux) detect(conn net.Conn) (*Conn, error) { continue } - if m.CertAuthorityGetter != nil && newProxyLine.isSigned() && !newProxyLine.IsVerified { + if m.CertAuthorityGetter != nil && newProxyLine.IsSigned() && !newProxyLine.IsVerified { return nil, trace.BadParameter("could not verify proxy line signature") } @@ -501,10 +499,10 @@ func (p Protocol) String() string { var ( proxyPrefix = []byte{'P', 'R', 'O', 'X', 'Y'} - proxyV2Prefix = []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + ProxyV2Prefix = []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} sshPrefix = []byte{'S', 'S', 'H'} tlsPrefix = []byte{0x16} - proxyHelloPrefix = []byte(sshutils.ProxyHelloSignature) + proxyHelloPrefix = []byte(constants.ProxyHelloSignature) ) // This section defines Postgres wire protocol messages detected by Teleport: @@ -564,15 +562,15 @@ func detectProto(r *bufio.Reader) (Protocol, error) { switch { case bytes.HasPrefix(in, proxyPrefix): return ProtoProxy, nil - case bytes.HasPrefix(in, proxyV2Prefix[:8]): + case bytes.HasPrefix(in, ProxyV2Prefix[:8]): // if the first 8 bytes matches the first 8 bytes of the proxy // protocol v2 magic bytes, read more of the connection so we can // ensure all magic bytes match - in, err = r.Peek(len(proxyV2Prefix)) + in, err = r.Peek(len(ProxyV2Prefix)) if err != nil { return ProtoUnknown, trace.Wrap(err, "failed to peek connection") } - if bytes.HasPrefix(in, proxyV2Prefix) { + if bytes.HasPrefix(in, ProxyV2Prefix) { return ProtoProxyV2, nil } case bytes.HasPrefix(in, proxyHelloPrefix[:8]): diff --git a/lib/multiplexer/multiplexer_test.go b/lib/multiplexer/multiplexer_test.go index cb1d914ef55de..c66142b2694c6 100644 --- a/lib/multiplexer/multiplexer_test.go +++ b/lib/multiplexer/multiplexer_test.go @@ -23,6 +23,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" "io" @@ -36,6 +37,7 @@ import ( "github.com/jackc/pgproto3/v2" "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "google.golang.org/grpc" @@ -49,8 +51,6 @@ import ( "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/multiplexer/test" - "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/cert" @@ -93,36 +93,20 @@ func TestMux(t *testing.T) { backend1.StartTLS() defer backend1.Close() - called := false - sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, _ *sshutils.ConnectionContext, nch ssh.NewChannel) { - called = true - err := nch.Reject(ssh.Prohibited, "nothing to see here") - require.NoError(t, err) - }) + go startSSHServer(t, mux.SSH()) - srv, err := sshutils.NewServer( - "test", - utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, - sshHandler, - []ssh.Signer{signer}, - sshutils.AuthMethods{Password: pass("abc123")}, - ) - require.NoError(t, err) - go srv.Serve(mux.SSH()) - defer srv.Close() clt, err := ssh.Dial("tcp", listener.Addr().String(), &ssh.ClientConfig{ - Auth: []ssh.AuthMethod{ssh.Password("abc123")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: time.Second, - HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()), }) require.NoError(t, err) defer clt.Close() - // call new session to initiate opening new channel - _, err = clt.NewSession() - require.NotNil(t, err) - // make sure the channel handler was called OK - require.True(t, called) + // Make sure the SSH connection works correctly + ok, response, err := clt.SendRequest("echo", true, []byte("beep")) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "beep", string(response)) client := testClient(backend1) re, err := client.Get(backend1.URL) @@ -424,36 +408,20 @@ func TestMux(t *testing.T) { backend1.StartTLS() defer backend1.Close() - called := false - sshHandler := sshutils.NewChanHandlerFunc(func(_ context.Context, _ *sshutils.ConnectionContext, nch ssh.NewChannel) { - called = true - err := nch.Reject(ssh.Prohibited, "nothing to see here") - require.NoError(t, err) - }) + go startSSHServer(t, mux.SSH()) - srv, err := sshutils.NewServer( - "test", - utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, - sshHandler, - []ssh.Signer{signer}, - sshutils.AuthMethods{Password: pass("abc123")}, - ) - require.NoError(t, err) - go srv.Serve(mux.SSH()) - defer srv.Close() clt, err := ssh.Dial("tcp", listener.Addr().String(), &ssh.ClientConfig{ - Auth: []ssh.AuthMethod{ssh.Password("abc123")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: time.Second, - HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()), }) require.NoError(t, err) defer clt.Close() - // call new session to initiate opening new channel - _, err = clt.NewSession() - require.NotNil(t, err) - // make sure the channel handler was called OK - require.Equal(t, called, true) + // Make sure the SSH connection works correctly + ok, response, err := clt.SendRequest("echo", true, []byte("beep")) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "beep", string(response)) client := testClient(backend1) re, err := client.Get(backend1.URL) @@ -973,14 +941,6 @@ func TestMux(t *testing.T) { }) } -type mockCAsGetter struct { - HostCA types.CertAuthority -} - -func (m *mockCAsGetter) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) { - return m.HostCA, nil -} - func TestProtocolString(t *testing.T) { for i := -1; i < len(protocolStrings)+1; i++ { got := Protocol(i).String() @@ -1027,15 +987,6 @@ func testClient(srv *httptest.Server) *http.Client { } } -func pass(need string) sshutils.PasswordFunc { - return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - if string(password) == need { - return nil, nil - } - return nil, fmt.Errorf("passwords don't match") - } -} - type noopListener struct { addr net.Addr } @@ -1096,8 +1047,10 @@ func getTestCertCAsGetterAndSigner(t testing.TB, clusterName string) ([]byte, Ce }, }) require.NoError(t, err) - mockCAsGetter := &mockCAsGetter{HostCA: ca} + mockCAGetter := func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) { + return ca, nil + } proxyPriv, err := rsa.GenerateKey(rand.Reader, constants.RSAKeySize) require.NoError(t, err) @@ -1120,22 +1073,61 @@ func getTestCertCAsGetterAndSigner(t testing.TB, clusterName string) ([]byte, Ce tlsProxyCertPEM, err := tlsCA.GenerateCertificate(certReq) require.NoError(t, err) clock := clockwork.NewFakeClockAt(time.Now()) - jwtSigner, err := services.GetJWTSigner(proxyPriv, clusterName, clock) + jwtSigner, err := jwt.New(&jwt.Config{ + Clock: clock, + Algorithm: defaults.ApplicationTokenAlgorithm, + ClusterName: clusterName, + PrivateKey: proxyPriv, + }) require.NoError(t, err) tlsProxyCertDER, err := tlsca.ParseCertificatePEM(tlsProxyCertPEM) require.NoError(t, err) - return tlsProxyCertDER.Raw, mockCAsGetter, jwtSigner + return tlsProxyCertDER.Raw, mockCAGetter, jwtSigner +} + +func startSSHServer(t *testing.T, listener net.Listener) { + nConn, err := listener.Accept() + assert.NoError(t, err) + + t.Cleanup(func() { nConn.Close() }) + + block, _ := pem.Decode(fixtures.LocalhostKey) + pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + assert.NoError(t, err) + + signer, err := ssh.NewSignerFromKey(pkey) + assert.NoError(t, err) + + config := &ssh.ServerConfig{NoClientAuth: true} + config.AddHostKey(signer) + + conn, _, reqs, err := ssh.NewServerConn(nConn, config) + assert.NoError(t, err) + if err != nil { + return + } + t.Cleanup(func() { conn.Close() }) + + go func() { + for newReq := range reqs { + if newReq.Type == "echo" { + newReq.Reply(true, newReq.Payload) + } + err := newReq.Reply(false, nil) + assert.NoError(t, err) + } + }() } func BenchmarkMux_ProxyV2Signature(b *testing.B) { const clusterName = "test-teleport" clock := clockwork.NewFakeClockAt(time.Now()) - tlsProxyCert, casGetter, jwtSigner := getTestCertCAsGetterAndSigner(b, clusterName) + tlsProxyCert, caGetter, jwtSigner := getTestCertCAsGetterAndSigner(b, clusterName) - ca, err := casGetter.GetCertAuthority(context.Background(), types.CertAuthID{ + ca, err := caGetter(context.Background(), types.CertAuthID{ Type: types.HostCA, DomainName: clusterName, }, false) diff --git a/lib/multiplexer/proxyline.go b/lib/multiplexer/proxyline.go index ba2372291f7eb..2c7959724d181 100644 --- a/lib/multiplexer/proxyline.go +++ b/lib/multiplexer/proxyline.go @@ -40,7 +40,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/jwt" - "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tlsca" ) @@ -116,7 +115,7 @@ func (p *ProxyLine) String() string { func (p *ProxyLine) Bytes() ([]byte, error) { b := &bytes.Buffer{} header := proxyV2Header{VersionCommand: (Version2 << 4) | ProxyCommand} - copy(header.Signature[:], proxyV2Prefix) + copy(header.Signature[:], ProxyV2Prefix) var addr interface{} if p.Source.Port < 0 || p.Destination.Port < 0 || p.Source.Port > math.MaxUint16 || p.Destination.Port > math.MaxUint16 { @@ -278,7 +277,7 @@ func ReadProxyLineV2(reader *bufio.Reader) (*ProxyLine, error) { if err := binary.Read(reader, binary.BigEndian, &header); err != nil { return nil, trace.Wrap(err) } - if !bytes.Equal(header.Signature[:], proxyV2Prefix) { + if !bytes.Equal(header.Signature[:], ProxyV2Prefix) { return nil, trace.BadParameter("unrecognized signature %s", hex.EncodeToString(header.Signature[:])) } cmd, ver := header.VersionCommand&0xF, header.VersionCommand>>4 @@ -427,9 +426,9 @@ func (p *ProxyLine) AddSignature(signature, signingCert []byte) error { return nil } -// isSigned returns true if proxy line's TLV contains signature. +// IsSigned returns true if proxy line's TLV contains signature. // Does not take into account if signature is valid or not, just the presence of it. -func (p *ProxyLine) isSigned() bool { +func (p *ProxyLine) IsSigned() bool { token, proxyCert, _ := p.getSignatureAndSigningCert() return len(token) > 0 || proxyCert != nil } @@ -489,14 +488,14 @@ func (p *ProxyLine) VerifySignature(ctx context.Context, caGetter CertAuthorityG identity.TeleportCluster, localClusterName) } - hostCA, err := caGetter.GetCertAuthority(ctx, types.CertAuthID{ + hostCA, err := caGetter(ctx, types.CertAuthID{ Type: types.HostCA, DomainName: localClusterName, }, false) if err != nil { return trace.Wrap(ErrNoHostCA, "CA cluster name: %s", localClusterName) } - hostCACerts := services.GetTLSCerts(hostCA) + hostCACerts := getTLSCerts(hostCA) roots := x509.NewCertPool() for _, cert := range hostCACerts { @@ -541,6 +540,15 @@ func (p *ProxyLine) VerifySignature(ctx context.Context, caGetter CertAuthorityG return nil } +func getTLSCerts(ca types.CertAuthority) [][]byte { + pairs := ca.GetTrustedTLSKeyPairs() + out := make([][]byte, len(pairs)) + for i, pair := range pairs { + out[i] = append([]byte{}, pair.Cert...) + } + return out +} + func checkForSystemRole(identity *tlsca.Identity, roleToFind types.SystemRole) bool { findRole := func(roles []string) bool { for _, role := range roles { diff --git a/lib/multiplexer/proxyline_test.go b/lib/multiplexer/proxyline_test.go index e9d8201c25142..a316e6a6dce1b 100644 --- a/lib/multiplexer/proxyline_test.go +++ b/lib/multiplexer/proxyline_test.go @@ -46,13 +46,13 @@ var ( // source=127.0.0.1:12345 destination=127.0.0.2:42 sampleIPv4Addresses = []byte{0x7F, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x02, 0x30, 0x39, 0x00, 0x2A} // {0x21, 0x11, 0x00, 0x0C} - 4 bits version, 4 bits command, 4 bits address family, 4 bits protocol, 16 bits length - sampleProxyV2Line = bytes.Join([][]byte{proxyV2Prefix, {0x21, 0x11, 0x00, 0x0C}, sampleIPv4Addresses}, nil) + sampleProxyV2Line = bytes.Join([][]byte{ProxyV2Prefix, {0x21, 0x11, 0x00, 0x0C}, sampleIPv4Addresses}, nil) // Proxy line with LOCAL command - sampleProxyV2LineLocal = bytes.Join([][]byte{proxyV2Prefix, {0x20, 0x11, 0x00, 0x00}}, nil) + sampleProxyV2LineLocal = bytes.Join([][]byte{ProxyV2Prefix, {0x20, 0x11, 0x00, 0x00}}, nil) sampleTLV = []byte{byte(PP2TypeTeleport), 0x00, 0x03, 0x01, 0x02, 0x03} sampleEmptyTLV = []byte{byte(PP2TypeTeleport), 0x00, 0x00} - sampleProxyV2LineTLV = bytes.Join([][]byte{proxyV2Prefix, {0x21, 0x11, 0x00, 0x12}, sampleIPv4Addresses, sampleTLV}, nil) - sampleProxyV2LineEmptyTLV = bytes.Join([][]byte{proxyV2Prefix, {0x21, 0x11, 0x00, 0x0F}, sampleIPv4Addresses, sampleEmptyTLV}, nil) + sampleProxyV2LineTLV = bytes.Join([][]byte{ProxyV2Prefix, {0x21, 0x11, 0x00, 0x12}, sampleIPv4Addresses, sampleTLV}, nil) + sampleProxyV2LineEmptyTLV = bytes.Join([][]byte{ProxyV2Prefix, {0x21, 0x11, 0x00, 0x0F}, sampleIPv4Addresses, sampleEmptyTLV}, nil) ) func TestReadProxyLine(t *testing.T) { @@ -94,7 +94,7 @@ func TestReadProxyLineV2(t *testing.T) { require.ErrorContains(t, err, "unrecognized signature") }) t.Run("malformed PROXY v2 header", func(t *testing.T) { - _, err := ReadProxyLineV2(bufio.NewReader(bytes.NewReader(append(proxyV2Prefix, []byte("JIBBERISH")...)))) + _, err := ReadProxyLineV2(bufio.NewReader(bytes.NewReader(append(ProxyV2Prefix, []byte("JIBBERISH")...)))) require.ErrorContains(t, err, "unsupported version") }) t.Run("successfully read proxy v2 line without TLV", func(t *testing.T) { @@ -444,7 +444,7 @@ func TestProxyLine_VerifySignature(t *testing.T) { }) require.NoError(t, err) - ca, err := casGetter.GetCertAuthority(context.Background(), types.CertAuthID{ + ca, err := casGetter(context.Background(), types.CertAuthID{ Type: types.HostCA, DomainName: clusterName, }, false) @@ -570,10 +570,12 @@ func TestProxyLine_VerifySignature(t *testing.T) { }, }, }) - mockCAGetter := &mockCAsGetter{HostCA: ca} - require.NoError(t, err) + mockCAGetter := func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) { + return ca, nil + } + err = pl.VerifySignature(context.Background(), mockCAGetter, tt.localClusterName, clock) if tt.wantErr != "" { require.ErrorContains(t, err, tt.wantErr) diff --git a/lib/reversetunnel/emit_conn.go b/lib/reversetunnel/emit_conn.go index 4f7c407aaafc6..25a8b1044f17e 100644 --- a/lib/reversetunnel/emit_conn.go +++ b/lib/reversetunnel/emit_conn.go @@ -22,8 +22,8 @@ import ( "net" "sync" + "github.com/gravitational/teleport/api/constants" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/events" ) @@ -59,7 +59,7 @@ func (conn *emitConn) Read(p []byte) (int, error) { n, err := conn.Conn.Read(p) // Skip buffering if already could have emitted or will never emit. - if err != nil || conn.buffer.Len() == len(sshutils.ProxyHelloSignature) || conn.serverID == "" { + if err != nil || conn.buffer.Len() == len(constants.ProxyHelloSignature) || conn.serverID == "" { conn.mu.RUnlock() return n, err } @@ -68,15 +68,15 @@ func (conn *emitConn) Read(p []byte) (int, error) { conn.mu.Lock() defer conn.mu.Unlock() - remaining := len(sshutils.ProxyHelloSignature) - conn.buffer.Len() + remaining := len(constants.ProxyHelloSignature) - conn.buffer.Len() _, err = conn.buffer.Write(p[:min(n, remaining)]) if err != nil { return n, err } // Only emit when we don't see the proxy hello signature in the first few bytes. - if conn.buffer.Len() == len(sshutils.ProxyHelloSignature) && - !bytes.HasPrefix(conn.buffer.Bytes(), []byte(sshutils.ProxyHelloSignature)) { + if conn.buffer.Len() == len(constants.ProxyHelloSignature) && + !bytes.HasPrefix(conn.buffer.Bytes(), []byte(constants.ProxyHelloSignature)) { event := &apievents.SessionConnect{ Metadata: apievents.Metadata{ Type: events.SessionConnectEvent, diff --git a/lib/service/service.go b/lib/service/service.go index 70aa72278e2b5..17261e8342f18 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1673,12 +1673,15 @@ func (process *TeleportProcess) initAuthService() error { log.Infof("Starting Auth service with external PROXY protocol support.") } + muxCAGetter := func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) { + return authServer.GetCertAuthority(ctx, id, loadKeys) + } // use multiplexer to leverage support for proxy protocol. mux, err := multiplexer.New(multiplexer.Config{ EnableExternalProxyProtocol: cfg.Auth.EnableProxyProtocol, Listener: listener, ID: teleport.Component(process.id), - CertAuthorityGetter: authServer, + CertAuthorityGetter: muxCAGetter, LocalClusterName: clusterName, }) if err != nil { @@ -2311,6 +2314,10 @@ func (process *TeleportProcess) initSSH() error { return trace.Wrap(err) } + caGetter := func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) { + return authClient.GetCertAuthority(ctx, id, loadKeys) + } + s, err := regular.New( process.ExitContext(), cfg.SSH.Addr, @@ -2347,6 +2354,7 @@ func (process *TeleportProcess) initSSH() error { regular.SetInventoryControlHandle(process.inventoryHandle), regular.SetTracerProvider(process.TracingProvider), regular.SetSessionController(sessionController), + regular.SetCAGetter(caGetter), ) if err != nil { return trace.Wrap(err) @@ -2366,20 +2374,7 @@ func (process *TeleportProcess) initSSH() error { utils.Consolef(cfg.Console, log, teleport.ComponentNode, "Service %s:%s is starting on %v.", teleport.Version, teleport.Gitref, cfg.SSH.Addr.Addr) - sshListener, err := multiplexer.NewPROXYEnabledListener(multiplexer.Config{ - Listener: listener, - Context: process.ExitContext(), - EnableExternalProxyProtocol: false, - ID: teleport.Component(teleport.ComponentNode, fmt.Sprintf("%d", time.Now().Unix())), - CertAuthorityGetter: conn.Client, - LocalClusterName: clusterName.GetClusterName(), - }) - if err != nil { - return trace.Wrap(err) - } - // Start the SSH server. This kicks off updating labels, starting the - // heartbeat, and accepting connections. - go s.Serve(limiter.WrapListener(sshListener)) + go s.Serve(limiter.WrapListener(listener)) } else { // Start the SSH server. This kicks off updating labels and starting the // heartbeat. @@ -3076,6 +3071,10 @@ func (process *TeleportProcess) setupProxyListeners(networkingConfig types.Clust var err error var listeners proxyListeners + muxCAGetter := func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) { + return accessPoint.GetCertAuthority(ctx, id, loadKeys) + } + if !cfg.Proxy.SSHAddr.IsEmpty() { l, err := process.importOrCreateListener(ListenerProxySSH, cfg.Proxy.SSHAddr.Addr) if err != nil { @@ -3086,7 +3085,7 @@ func (process *TeleportProcess) setupProxyListeners(networkingConfig types.Clust Listener: l, EnableExternalProxyProtocol: cfg.Proxy.EnableProxyProtocol, ID: teleport.Component(teleport.ComponentProxy, "ssh"), - CertAuthorityGetter: accessPoint, + CertAuthorityGetter: muxCAGetter, LocalClusterName: clusterName, }) if err != nil { @@ -3180,7 +3179,7 @@ func (process *TeleportProcess) setupProxyListeners(networkingConfig types.Clust EnableExternalProxyProtocol: cfg.Proxy.EnableProxyProtocol, Listener: listener, ID: teleport.Component(teleport.ComponentProxy, "tunnel", "web", process.id), - CertAuthorityGetter: accessPoint, + CertAuthorityGetter: muxCAGetter, LocalClusterName: clusterName, }) if err != nil { @@ -3210,7 +3209,7 @@ func (process *TeleportProcess) setupProxyListeners(networkingConfig types.Clust EnableExternalProxyProtocol: cfg.Proxy.EnableProxyProtocol, Listener: listener, ID: teleport.Component(teleport.ComponentProxy, "web", process.id), - CertAuthorityGetter: accessPoint, + CertAuthorityGetter: muxCAGetter, LocalClusterName: clusterName, }) if err != nil { @@ -3264,7 +3263,7 @@ func (process *TeleportProcess) setupProxyListeners(networkingConfig types.Clust EnableExternalProxyProtocol: cfg.Proxy.EnableProxyProtocol, Listener: listener, ID: teleport.Component(teleport.ComponentProxy, "web", process.id), - CertAuthorityGetter: accessPoint, + CertAuthorityGetter: muxCAGetter, LocalClusterName: clusterName, }) if err != nil { diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 179c3e8c83d9a..174eaa92bcb2f 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -30,8 +30,10 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/observability/tracing" + "github.com/gravitational/teleport/api/types" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/srv" @@ -44,6 +46,11 @@ type PROXYHeaderSigner interface { SignPROXYHeader(source, destination net.Addr) ([]byte, error) } +// CertAuthorityGetter allows to get cluster's host CA for verification of signed PROXY headers. +// We define our own version to avoid circular dependencies in multiplexer package (it can't depend on 'services'), +// where this function is used. +type CertAuthorityGetter = func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) + // proxySubsys implements an SSH subsystem for proxying listening sockets from // remote hosts to a proxy client (AKA port mapping) type proxySubsys struct { @@ -302,7 +309,7 @@ func (t *proxySubsys) doHandshake(ctx context.Context, clientAddr net.Addr, clie t.log.Error(err) } else { // send a JSON payload sandwiched between 'teleport proxy signature' and 0x00: - payload := fmt.Sprintf("%s%s\x00", apisshutils.ProxyHelloSignature, payloadJSON) + payload := fmt.Sprintf("%s%s\x00", constants.ProxyHelloSignature, payloadJSON) _, err = serverConn.Write([]byte(payload)) if err != nil { t.log.Error(err) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index c72103952e9fa..22a15e89f6dfc 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -229,6 +229,8 @@ type Server struct { // proxySigner is used to generate signed PROXYv2 header so we can securely propagate client IP proxySigner PROXYHeaderSigner + // caGetter is used to get host CA of the cluster to verify signed PROXY headers + caGetter CertAuthorityGetter } // TargetMetadata returns metadata about the server. @@ -691,6 +693,14 @@ func SetPROXYSigner(proxySigner PROXYHeaderSigner) ServerOption { } } +// SetCAGetter sets the cert authority getter +func SetCAGetter(caGetter CertAuthorityGetter) ServerOption { + return func(s *Server) error { + s.caGetter = caGetter + return nil + } +} + // New returns an unstarted server func New( ctx context.Context, @@ -813,6 +823,11 @@ func New( SessionRegistry: s.reg, } + clusterName, err := s.GetAccessPoint().GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + server, err := sshutils.NewServer( component, addr, s, signers, @@ -826,6 +841,8 @@ func New( sshutils.SetFIPS(s.fips), sshutils.SetClock(s.clock), sshutils.SetIngressReporter(s.ingressService, s.ingressReporter), + sshutils.SetCAGetter(s.caGetter), + sshutils.SetClusterName(clusterName.GetClusterName()), ) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go index e4bab6b530f8a..7c96aac2fc385 100644 --- a/lib/sshutils/server.go +++ b/lib/sshutils/server.go @@ -19,6 +19,7 @@ limitations under the License. package sshutils import ( + "bufio" "bytes" "context" "encoding/json" @@ -26,6 +27,7 @@ import ( "fmt" "io" "net" + "strings" "sync" "sync/atomic" "time" @@ -39,11 +41,14 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/limiter" + "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/utils" @@ -102,6 +107,10 @@ type Server struct { // clock is used to control time. clock clockwork.Clock + caGetter CertAuthorityGetter + + clusterName string + // ingressReporter reports new and active connections. ingressReporter *ingress.Reporter // ingressService the service name passed to the ingress reporter. @@ -122,6 +131,10 @@ const ( // TrueClientAddrVar environment variable is used by the web UI to pass // the remote IP (user's IP) from the browser/HTTP session into an SSH session TrueClientAddrVar = "TELEPORT_CLIENT_ADDR" + + // caGetterTimeout is the timeout on getting host cert authority, that is used in + // signed PROXY headers verification. + caGetterTimeout = 5 * time.Second ) // ServerOption is a functional argument for server @@ -184,6 +197,21 @@ func SetClock(clock clockwork.Clock) ServerOption { } } +// SetCAGetter sets the cert authority getter +func SetCAGetter(caGetter CertAuthorityGetter) ServerOption { + return func(s *Server) error { + s.caGetter = caGetter + return nil + } +} + +func SetClusterName(clusterName string) ServerOption { + return func(s *Server) error { + s.clusterName = clusterName + return nil + } +} + func NewServer( component string, a utils.NetAddr, @@ -461,7 +489,7 @@ func (s *Server) HandleConnection(conn net.Conn) { // create a new SSH server which handles the handshake (and pass the custom // payload structure which will be populated only when/if this connection // comes from another Teleport proxy): - wrappedConn := WrapConnection(wconn, s.log) + wrappedConn := wrapConnection(wconn, s.caGetter, s.clusterName, s.clock, s.log) sconn, chans, reqs, err := ssh.NewServerConn(wrappedConn, &s.cfg) if err != nil { // Ignore EOF as these are triggered by loadbalancer health checks @@ -745,13 +773,13 @@ type ClusterDetails struct { FIPSEnabled bool } -// ConnectionWrapper allows the SSH server to perform custom handshake which +// connectionWrapper allows the SSH server to perform custom handshake which // lets teleport proxy servers to relay a true remote client IP address // to the SSH server. // // (otherwise connection.RemoteAddr (client IP) will always point to a proxy IP // instead of a true client IP) -type ConnectionWrapper struct { +type connectionWrapper struct { net.Conn logger logrus.FieldLogger @@ -766,16 +794,22 @@ type ConnectionWrapper struct { // traceContext is the tracing context that was passed across the // connection, used to correlate spans. traceContext tracing.PropagationContext + + caGetter CertAuthorityGetter + clusterName string + clock clockwork.Clock } // RemoteAddr returns the behind-the-proxy client address -func (c *ConnectionWrapper) RemoteAddr() net.Addr { +func (c *connectionWrapper) RemoteAddr() net.Addr { return c.clientAddr } // Read implements io.Read() part of net.Connection which allows us // peek at the beginning of SSH handshake (that's why we're wrapping the connection) -func (c *ConnectionWrapper) Read(b []byte) (int, error) { +// DELETE IN 14.0: we need to keep it for compatibility purposes, but in 14.0 we can remove this +// connection wrapper and instead use multiplexer listener for SSH node. +func (c *connectionWrapper) Read(b []byte) (int, error) { // handshake already took place, forward upstream: if c.upstreamReader != nil { return c.upstreamReader.Read(b) @@ -796,12 +830,12 @@ func (c *ConnectionWrapper) Read(b []byte) (int, error) { skip := 0 // are we reading from a Teleport proxy? - if bytes.HasPrefix(buff, []byte(sshutils.ProxyHelloSignature)) { + if bytes.HasPrefix(buff, []byte(constants.ProxyHelloSignature)) { // the JSON payload ends with a binary zero: payloadBoundary := bytes.IndexByte(buff, 0x00) if payloadBoundary > 0 { var hp sshutils.HandshakePayload - payload := buff[len(sshutils.ProxyHelloSignature):payloadBoundary] + payload := buff[len(constants.ProxyHelloSignature):payloadBoundary] if err = json.Unmarshal(payload, &hp); err != nil { c.logger.Error(err) } else { @@ -810,16 +844,60 @@ func (c *ConnectionWrapper) Read(b []byte) (int, error) { skip = payloadBoundary + 1 } } - c.upstreamReader = io.MultiReader(bytes.NewBuffer(buff[skip:]), c.Conn) + if bytes.HasPrefix(buff, multiplexer.ProxyV2Prefix) { + reader := bufio.NewReader(io.MultiReader(bytes.NewBuffer(buff), c.Conn)) + + proxyLine, err := multiplexer.ReadProxyLineV2(reader) + if err != nil { + return 0, trace.Wrap(err) + } + if c.caGetter != nil && proxyLine != nil && proxyLine.IsSigned() { + ctx, cancel := context.WithTimeout(context.Background(), caGetterTimeout) + defer cancel() + + err = proxyLine.VerifySignature(ctx, c.caGetter, c.clusterName, c.clock) + // NOTE(anton): Temporarily using string comparison here to not create circular references. + // Will be refactored after #21835 is resolved. + if err != nil { + if strings.Contains(err.Error(), "could not get specified host CA to verify signed PROXY header") { + c.logger.WithFields(logrus.Fields{ + "src_addr": c.Conn.RemoteAddr(), + "dst_addr": c.Conn.LocalAddr(), + }).Warn("Could not verify PROXY signature for connection - could not get host CA") + } else if strings.Contains(err.Error(), "signing certificate is not signed by local cluster CA") { + c.logger.WithFields(logrus.Fields{ + "src_addr": c.Conn.RemoteAddr(), + "dst_addr": c.Conn.LocalAddr(), + }).Warn("Could not verify PROXY signature for connection - signed by non local cluster") + } else { + return 0, trace.Wrap(err) + } + } + if proxyLine.IsVerified { + c.clientAddr = &proxyLine.Source + } + } + + c.upstreamReader = reader + } + + if c.upstreamReader == nil { + c.upstreamReader = io.MultiReader(bytes.NewBuffer(buff[skip:]), c.Conn) + } return c.upstreamReader.Read(b) } -// WrapConnection takes a network connection, wraps it into ConnectionWrapper +type CertAuthorityGetter = func(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) + +// wrapConnection takes a network connection, wraps it into connectionWrapper // object (which overrides Read method) and returns the wrapper. -func WrapConnection(conn net.Conn, logger logrus.FieldLogger) *ConnectionWrapper { - return &ConnectionWrapper{ - Conn: conn, - clientAddr: conn.RemoteAddr(), - logger: logger, +func wrapConnection(conn net.Conn, caGetter CertAuthorityGetter, localClusterName string, clock clockwork.Clock, logger logrus.FieldLogger) *connectionWrapper { + return &connectionWrapper{ + Conn: conn, + clientAddr: conn.RemoteAddr(), + caGetter: caGetter, + clusterName: localClusterName, + clock: clock, + logger: logger, } } diff --git a/lib/sshutils/server_test.go b/lib/sshutils/server_test.go index 7a2546885de7f..ca56b6c6955b3 100644 --- a/lib/sshutils/server_test.go +++ b/lib/sshutils/server_test.go @@ -17,17 +17,28 @@ limitations under the License. package sshutils import ( + "bytes" "context" + "crypto/x509" + "encoding/json" + "encoding/pem" "fmt" + "net" "os" "testing" "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/observability/tracing" + apisshutils "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/cert" ) @@ -252,6 +263,119 @@ func TestHostSignerFIPS(t *testing.T) { } } +// TestConnectionWrapper_Read makes sure connectionWrapper can correctly process ProxyHelloSignature and PROXY protocol +// on the wire. +func TestConnectionWrapper_Read(t *testing.T) { + testCases := []struct { + desc string + sendData []byte + }{ + { + desc: "Plain connection without any special headers", + sendData: nil, + }, + { + desc: "Sending ProxyHelloSignature", + sendData: getProxyHelloSignaturePayload(t), + }, + { + desc: "Sending PROXY header", + sendData: getPROXYProtocolPayload(), + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { listener.Close() }) + + go startSSHServer(t, listener) + + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + + _, err = conn.Write(tc.sendData) + require.NoError(t, err) + + sconn, nc, r, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: time.Second, + }) + require.NoError(t, err) + require.Equal(t, "SSH-2.0-Go", string(sconn.ServerVersion())) + + client := ssh.NewClient(sconn, nc, r) + require.NoError(t, err) + + // Make sure SSH connection works correctly + ok, response, err := client.SendRequest("echo", true, []byte("beep")) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "beep", string(response)) + }) + } +} + +func getPROXYProtocolPayload() []byte { + proxyV2Prefix := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + // source=127.0.0.1:12345 destination=127.0.0.2:42 + sampleIPv4Addresses := []byte{0x7F, 0x00, 0x00, 0x01, 0x7F, 0x00, 0x00, 0x02, 0x30, 0x39, 0x00, 0x2A} + // {0x21, 0x11, 0x00, 0x0C} - 4 bits version, 4 bits command, 4 bits address family, 4 bits protocol, 16 bits length + sampleProxyV2Line := bytes.Join([][]byte{proxyV2Prefix, {0x21, 0x11, 0x00, 0x0C}, sampleIPv4Addresses}, nil) + + return sampleProxyV2Line +} + +func getProxyHelloSignaturePayload(t *testing.T) []byte { + t.Helper() + + hp := &apisshutils.HandshakePayload{ + ClientAddr: "127.0.0.1:12345", + TracingContext: tracing.PropagationContextFromContext(context.Background()), + } + payloadJSON, err := json.Marshal(hp) + require.NoError(t, err) + + return []byte(fmt.Sprintf("%s%s\x00", constants.ProxyHelloSignature, payloadJSON)) +} + +func startSSHServer(t *testing.T, listener net.Listener) { + nConn, err := listener.Accept() + assert.NoError(t, err) + + t.Cleanup(func() { nConn.Close() }) + + wConn := wrapConnection(nConn, nil, "", clockwork.NewRealClock(), logrus.New()) + + block, _ := pem.Decode(fixtures.LocalhostKey) + pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + assert.NoError(t, err) + + signer, err := ssh.NewSignerFromKey(pkey) + assert.NoError(t, err) + + config := &ssh.ServerConfig{NoClientAuth: true} + config.AddHostKey(signer) + + conn, _, reqs, err := ssh.NewServerConn(wConn, config) + assert.NoError(t, err) + if err != nil { + return + } + t.Cleanup(func() { conn.Close() }) + + go func() { + for newReq := range reqs { + if newReq.Type == "echo" { + newReq.Reply(true, newReq.Payload) + } + err := newReq.Reply(false, nil) + assert.NoError(t, err) + } + }() +} + func pass(need string) PasswordFunc { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == need {