diff --git a/integration/app_integration_test.go b/integration/app_integration_test.go index 4d71ab0eb2850..f518610f38fcc 100644 --- a/integration/app_integration_test.go +++ b/integration/app_integration_test.go @@ -24,7 +24,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/httptest" @@ -634,6 +633,92 @@ func TestAppAuditEvents(t *testing.T) { }) } +func TestAppServersHA(t *testing.T) { + testCases := map[string]struct { + publicAddr func(pack *pack) string + makeRequest func(pack *pack, inCookie string) (status int, err error) + }{ + "HTTPApp": { + publicAddr: func(pack *pack) string { return pack.rootAppPublicAddr }, + makeRequest: func(pack *pack, inCookie string) (int, error) { + status, _, err := pack.makeRequest(inCookie, http.MethodGet, "/") + return status, err + }, + }, + "WebSocketApp": { + publicAddr: func(pack *pack) string { return pack.rootWSPublicAddr }, + makeRequest: func(pack *pack, inCookie string) (int, error) { + _, err := pack.makeWebsocketRequest(inCookie, "/") + return 0, err + }, + }, + } + + // asserts that the response has error. + responseWithError := func(t *testing.T, status int, err error) { + if status > 0 { + require.NoError(t, err) + require.Equal(t, http.StatusInternalServerError, status) + return + } + + require.Error(t, err) + } + // asserts that the response has no errors. + responseWithoutError := func(t *testing.T, status int, err error) { + if status > 0 { + require.NoError(t, err) + require.Equal(t, http.StatusOK, status) + return + } + + require.NoError(t, err) + } + + for name, test := range testCases { + t.Run(name, func(t *testing.T) { + pack := setupWithOptions(t, appTestOptions{rootAppServersCount: 3}) + inCookie := pack.createAppSession(t, test.publicAddr(pack), pack.rootAppClusterName) + + status, err := test.makeRequest(pack, inCookie) + responseWithoutError(t, status, err) + + // Stop all root app servers. 
+ for i, appServer := range pack.rootAppServers { + appServer.Close() + + // issue a request right after a server is gone. + status, err = test.makeRequest(pack, inCookie) + if i == len(pack.rootAppServers)-1 { + // fails only when the last one is closed. + responseWithError(t, status, err) + } else { + // otherwise the request should be handled by another + // server. + responseWithoutError(t, status, err) + } + } + + servers := pack.startRootAppServers(t, 3, []service.App{}) + status, err = test.makeRequest(pack, inCookie) + responseWithoutError(t, status, err) + + // Start an additional app server and stop all currently running + // ones. + pack.startRootAppServers(t, 1, []service.App{}) + for _, appServer := range servers { + appServer.Close() + + // Every time an app server stops we issue a request to + // guarantee that the requests are going to be resolved by + // the remaining app servers. + status, err = test.makeRequest(pack, inCookie) + responseWithoutError(t, status, err) + } + }) + } +} + // pack contains identity as well as initialized Teleport clusters and instances. 
type pack struct { username string @@ -646,9 +731,9 @@ type pack struct { webCookie string webToken string - rootCluster *TeleInstance - rootAppServer *service.TeleportProcess - rootCertPool *x509.CertPool + rootCluster *TeleInstance + rootAppServers []*service.TeleportProcess + rootCertPool *x509.CertPool rootAppName string rootAppPublicAddr string @@ -702,12 +787,13 @@ type pack struct { } type appTestOptions struct { - extraRootApps []service.App - extraLeafApps []service.App - userLogins []string - userTraits map[string][]string - rootClusterPorts *InstancePorts - leafClusterPorts *InstancePorts + extraRootApps []service.App + extraLeafApps []service.App + userLogins []string + userTraits map[string][]string + rootClusterPorts *InstancePorts + leafClusterPorts *InstancePorts + rootAppServersCount int rootConfig func(config *service.Config) leafConfig func(config *service.Config) @@ -872,8 +958,7 @@ func setupWithOptions(t *testing.T, opts appTestOptions) *pack { rcConf := service.MakeDefaultConfig() rcConf.Console = nil rcConf.Log = log - rcConf.DataDir, err = ioutil.TempDir("", "cluster-"+p.rootCluster.Secrets.SiteName) - require.NoError(t, err) + rcConf.DataDir = t.TempDir() t.Cleanup(func() { os.RemoveAll(rcConf.DataDir) }) rcConf.Auth.Enabled = true rcConf.Auth.Preference.SetSecondFactor("off") @@ -889,8 +974,7 @@ func setupWithOptions(t *testing.T, opts appTestOptions) *pack { lcConf := service.MakeDefaultConfig() lcConf.Console = nil lcConf.Log = log - lcConf.DataDir, err = ioutil.TempDir("", "cluster-"+p.leafCluster.Secrets.SiteName) - require.NoError(t, err) + lcConf.DataDir = t.TempDir() t.Cleanup(func() { os.RemoveAll(lcConf.DataDir) }) lcConf.Auth.Enabled = true lcConf.Auth.Preference.SetSecondFactor("off") @@ -919,64 +1003,17 @@ func setupWithOptions(t *testing.T, opts appTestOptions) *pack { p.rootCluster.StopAll() }) - raConf := service.MakeDefaultConfig() - raConf.Console = nil - raConf.Log = log - raConf.DataDir, err = ioutil.TempDir("", 
"app-server-"+p.rootCluster.Secrets.SiteName) - require.NoError(t, err) - t.Cleanup(func() { os.RemoveAll(raConf.DataDir) }) - raConf.Token = "static-token-value" - raConf.AuthServers = []utils.NetAddr{ - { - AddrNetwork: "tcp", - Addr: net.JoinHostPort(Loopback, p.rootCluster.GetPortWeb()), - }, + // At least one rootAppServer should start during the setup + rootAppServersCount := 1 + if opts.rootAppServersCount > 0 { + rootAppServersCount = opts.rootAppServersCount } - raConf.Auth.Enabled = false - raConf.Proxy.Enabled = false - raConf.SSH.Enabled = false - raConf.Apps.Enabled = true - raConf.Apps.Apps = append([]service.App{ - { - Name: p.rootAppName, - URI: rootServer.URL, - PublicAddr: p.rootAppPublicAddr, - }, - { - Name: p.rootWSAppName, - URI: rootWSServer.URL, - PublicAddr: p.rootWSPublicAddr, - }, - { - Name: p.rootWSSAppName, - URI: rootWSSServer.URL, - PublicAddr: p.rootWSSPublicAddr, - }, - { - Name: p.jwtAppName, - URI: jwtServer.URL, - PublicAddr: p.jwtAppPublicAddr, - }, - { - Name: p.headerAppName, - URI: headerServer.URL, - PublicAddr: p.headerAppPublicAddr, - }, - { - Name: p.flushAppName, - URI: flushServer.URL, - PublicAddr: p.flushAppPublicAddr, - }, - }, opts.extraRootApps...) 
- p.rootAppServer, err = p.rootCluster.StartApp(raConf) - require.NoError(t, err) - t.Cleanup(func() { p.rootAppServer.Close() }) + p.rootAppServers = p.startRootAppServers(t, rootAppServersCount, opts.extraRootApps) laConf := service.MakeDefaultConfig() laConf.Console = nil laConf.Log = log - laConf.DataDir, err = ioutil.TempDir("", "app-server-"+p.leafCluster.Secrets.SiteName) - require.NoError(t, err) + laConf.DataDir = t.TempDir() t.Cleanup(func() { os.RemoveAll(laConf.DataDir) }) laConf.Token = "static-token-value" laConf.AuthServers = []utils.NetAddr{ @@ -1324,7 +1361,7 @@ func (p *pack) makeWebsocketRequest(sessionCookie, endpoint string) (string, err return "", trace.Wrap(err) } defer conn.Close() - data, err := ioutil.ReadAll(conn) + data, err := io.ReadAll(conn) if err != nil { return "", trace.Wrap(err) } @@ -1365,7 +1402,7 @@ func (p *pack) sendRequest(req *http.Request, tlsConfig *tls.Config) (int, strin defer resp.Body.Close() // Read in response body. - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return 0, "", trace.Wrap(err) } @@ -1397,6 +1434,71 @@ func (p *pack) waitForLogout(appCookie string) (int, error) { } } +func (p *pack) startRootAppServers(t *testing.T, count int, extraApps []service.App) []*service.TeleportProcess { + log := utils.NewLoggerForTests() + + servers := make([]*service.TeleportProcess, count) + + for i := 0; i < count; i++ { + raConf := service.MakeDefaultConfig() + raConf.Console = nil + raConf.Log = log + raConf.DataDir = t.TempDir() + t.Cleanup(func() { os.RemoveAll(raConf.DataDir) }) + raConf.Token = "static-token-value" + raConf.AuthServers = []utils.NetAddr{ + { + AddrNetwork: "tcp", + Addr: net.JoinHostPort(Loopback, p.rootCluster.GetPortWeb()), + }, + } + raConf.Auth.Enabled = false + raConf.Proxy.Enabled = false + raConf.SSH.Enabled = false + raConf.Apps.Enabled = true + raConf.Apps.Apps = append([]service.App{ + { + Name: p.rootAppName, + URI: p.rootAppURI, + 
+ PublicAddr: p.rootAppPublicAddr, + }, + { + Name: p.rootWSAppName, + URI: p.rootWSAppURI, + PublicAddr: p.rootWSPublicAddr, + }, + { + Name: p.rootWSSAppName, + URI: p.rootWSSAppURI, + PublicAddr: p.rootWSSPublicAddr, + }, + { + Name: p.jwtAppName, + URI: p.jwtAppURI, + PublicAddr: p.jwtAppPublicAddr, + }, + { + Name: p.headerAppName, + URI: p.headerAppURI, + PublicAddr: p.headerAppPublicAddr, + }, + { + Name: p.flushAppName, + URI: p.flushAppURI, + PublicAddr: p.flushAppPublicAddr, + }, + }, extraApps...) + + appServer, err := p.rootCluster.StartApp(raConf) + require.NoError(t, err) + t.Cleanup(func() { appServer.Close() }) + + servers[i] = appServer + } + + return servers +} + var forwardedHeaderNames = []string{ teleport.AppJWTHeader, teleport.AppCFHeader, diff --git a/lib/web/app/handler.go b/lib/web/app/handler.go index c4ac69c92e2ab..a641d858dbff5 100644 --- a/lib/web/app/handler.go +++ b/lib/web/app/handler.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + oxyutils "github.com/gravitational/oxy/utils" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -140,29 +141,64 @@ func (h *Handler) handleForward(w http.ResponseWriter, r *http.Request, session return nil } +// handleForwardError when the forwarder has an error during the `ServeHTTP` it +// will call this function. This handler will then renew the session in order +// to get "fresh" app servers, and then will forward the request to the newly +// created session. +func (h *Handler) handleForwardError(w http.ResponseWriter, req *http.Request, err error) { + // if it is not an agent connection problem, return without creating a new + // session. 
+ if !trace.IsConnectionProblem(err) { + oxyutils.DefaultHandler.ServeHTTP(w, req, err) + return + } + + session, err := h.renewSession(req) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(http.StatusText(http.StatusInternalServerError))) + return + } + + session.fwd.ServeHTTP(w, req) +} + // authenticate will check if request carries a session cookie matching a // session in the backend. func (h *Handler) authenticate(ctx context.Context, r *http.Request) (*session, error) { - sessionID, err := h.extractSessionID(r) + ws, err := h.getAppSession(r) if err != nil { - h.log.Warnf("Failed to extract session id: %v.", err) + h.log.Warnf("Failed to fetch application session: %v.", err) return nil, trace.AccessDenied("invalid session") } - // Check that the session exists in the backend cache. This allows the user - // to logout and invalidate their application session immediately. This - // lookup should also be fast because it's in the local cache. - ws, err := h.c.AccessPoint.GetAppSession(ctx, types.GetAppSessionRequest{ - SessionID: sessionID, - }) + // Fetch a cached session or create one if this is the first request this + // process has seen. + session, err := h.getSession(ctx, ws) + if err != nil { + h.log.Warnf("Failed to get session: %v.", err) + return nil, trace.AccessDenied("invalid session") + } + + return session, nil +} + +// renewSession based on the request removes the session from cache (if present) +// and generates a new one using the `getSession` flow (same as in +// `authenticate`). +func (h *Handler) renewSession(r *http.Request) (*session, error) { + ws, err := h.getAppSession(r) if err != nil { h.log.Debugf("Failed to fetch application session: not found.") return nil, trace.AccessDenied("invalid session") } - // Fetch a cached session or create one if this is the first request this - // process has seen. 
- session, err := h.getSession(ctx, ws) + // Remove the session from the cache, this will force a new session to be + // generated and cached. + h.cache.remove(ws.GetName()) + + // Fetches a new session using the same flow as `authenticate`. + session, err := h.getSession(r.Context(), ws) if err != nil { h.log.Warnf("Failed to get session: %v.", err) return nil, trace.AccessDenied("invalid session") @@ -171,6 +207,23 @@ func (h *Handler) authenticate(ctx context.Context, r *http.Request) (*session, return session, nil } +// getAppSession retrieves the `types.WebSession` using the provided +// `http.Request`. +func (h *Handler) getAppSession(r *http.Request) (types.WebSession, error) { + sessionID, err := h.extractSessionID(r) + if err != nil { + h.log.Warnf("Failed to extract session id: %v.", err) + return nil, trace.AccessDenied("invalid session") + } + + // Check that the session exists in the backend cache. This allows the user + // to logout and invalidate their application session immediately. This + // lookup should also be fast because it's in the local cache. + return h.c.AccessPoint.GetAppSession(r.Context(), types.GetAppSessionRequest{ + SessionID: sessionID, + }) +} + // extractSessionID extracts application access session ID from either the // cookie or the client certificate of the provided request. 
func (h *Handler) extractSessionID(r *http.Request) (sessionID string, err error) { diff --git a/lib/web/app/match.go b/lib/web/app/match.go index 2517228546aa5..1be2446f34a90 100644 --- a/lib/web/app/match.go +++ b/lib/web/app/match.go @@ -26,6 +26,7 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/trace" ) @@ -39,17 +40,10 @@ type Getter interface { GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) } -// Match will match an application with the passed in matcher function. Matcher -// functions that can match on public address and name are available. -// -// Note that in the situation multiple applications match, a random selection -// is returned. This is done on purpose to support HA to allow multiple -// application proxy nodes to be run and if one is down, at least the -// application can be accessible on the other. -// -// In the future this function should be updated to keep state on application -// servers that are down and to not route requests to that server. -func Match(ctx context.Context, authClient Getter, fn Matcher) (types.AppServer, error) { +// Match will match a list of applications with the passed in matcher function. Matcher +// functions that can match on public address and name are available. The +// resulting list is shuffled before it is returned. 
+func Match(ctx context.Context, authClient Getter, fn Matcher) ([]types.AppServer, error) { servers, err := authClient.GetApplicationServers(ctx, defaults.Namespace) if err != nil { return nil, trace.Wrap(err) @@ -57,33 +51,55 @@ func Match(ctx context.Context, authClient Getter, fn Matcher) (types.AppServer, var as []types.AppServer for _, server := range servers { - if fn(server.GetApp()) { + if fn(server) { as = append(as, server) } } - if len(as) == 0 { - return nil, trace.NotFound("failed to match application") - } + rand.Shuffle(len(as), func(i, j int) { + as[i], as[j] = as[j], as[i] + }) - index := rand.Intn(len(as)) - return as[index], nil + return as, nil } // Matcher allows matching on different properties of an application. -type Matcher func(types.Application) bool +type Matcher func(types.AppServer) bool // MatchPublicAddr matches on the public address of an application. func MatchPublicAddr(publicAddr string) Matcher { - return func(app types.Application) bool { - return app.GetPublicAddr() == publicAddr + return func(appServer types.AppServer) bool { + return appServer.GetApp().GetPublicAddr() == publicAddr } } // MatchName matches on the name of an application. func MatchName(name string) Matcher { - return func(app types.Application) bool { - return app.GetName() == name + return func(appServer types.AppServer) bool { + return appServer.GetApp().GetName() == name + } +} + +// MatchHealthy tries to establish a connection with the server using the +// `dialAppServer` function. The app server is matched if the function call +// doesn't return any error. +func MatchHealthy(proxyClient reversetunnel.Tunnel, identity *tlsca.Identity) Matcher { + return func(appServer types.AppServer) bool { + _, err := dialAppServer(proxyClient, identity, appServer) + return err == nil + } +} + +// MatchAll matches if all the Matcher functions return true. 
+func MatchAll(matchers ...Matcher) Matcher { + return func(appServer types.AppServer) bool { + for _, fn := range matchers { + if !fn(appServer) { + return false + } + } + + return true } } @@ -97,13 +113,13 @@ func MatchName(name string) Matcher { // resolve an application. func ResolveFQDN(ctx context.Context, clt Getter, tunnel reversetunnel.Tunnel, proxyDNSNames []string, fqdn string) (types.AppServer, string, error) { // Try and match FQDN to public address of application within cluster. - server, err := Match(ctx, clt, MatchPublicAddr(fqdn)) - if err == nil { + servers, err := Match(ctx, clt, MatchPublicAddr(fqdn)) + if err == nil && len(servers) > 0 { clusterName, err := clt.GetClusterName() if err != nil { return nil, "", trace.Wrap(err) } - return server, clusterName.GetClusterName(), nil + return servers[0], clusterName.GetClusterName(), nil } // Extract the first subdomain from the FQDN and attempt to use this as the @@ -130,9 +146,9 @@ func ResolveFQDN(ctx context.Context, clt Getter, tunnel reversetunnel.Tunnel, p return nil, "", trace.Wrap(err) } - server, err = Match(ctx, authClient, MatchName(appName)) - if err == nil { - return server, clusterClient.GetName(), nil + servers, err = Match(ctx, authClient, MatchName(appName)) + if err == nil && len(servers) > 0 { + return servers[0], clusterClient.GetName(), nil } } diff --git a/lib/web/app/match_test.go b/lib/web/app/match_test.go new file mode 100644 index 0000000000000..fb2eda2fe416b --- /dev/null +++ b/lib/web/app/match_test.go @@ -0,0 +1,97 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package app + +import ( + "errors" + "net" + "testing" + + "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/stretchr/testify/require" +) + +func TestMatchAll(t *testing.T) { + falseMatcher := func(_ types.AppServer) bool { return false } + trueMatcher := func(_ types.AppServer) bool { return true } + + require.True(t, MatchAll(trueMatcher, trueMatcher, trueMatcher)(nil)) + require.False(t, MatchAll(trueMatcher, trueMatcher, falseMatcher)(nil)) + require.False(t, MatchAll(falseMatcher, falseMatcher, falseMatcher)(nil)) +} + +func TestMatchHealthy(t *testing.T) { + testCases := map[string]struct { + dialErr error + match bool + }{ + "WithHealthyApp": { + match: true, + }, + "WithUnhealthyApp": { + dialErr: errors.New("failed to connect"), + match: false, + }, + } + + for name, test := range testCases { + t.Run(name, func(t *testing.T) { + identity := &tlsca.Identity{RouteToApp: tlsca.RouteToApp{ClusterName: ""}} + match := MatchHealthy(&mockProxyClient{ + remoteSite: &mockRemoteSite{ + dialErr: test.dialErr, + }, + }, identity) + + app, err := types.NewAppV3( + types.Metadata{ + Name: "test-app", + Namespace: defaults.Namespace, + }, + types.AppSpecV3{ + URI: "https://app.localhost", + }, + ) + require.NoError(t, err) + + appServer, err := types.NewAppServerV3FromApp(app, "localhost", "123") + require.NoError(t, err) + require.Equal(t, test.match, match(appServer)) + }) + } +} + +type mockProxyClient 
struct { + reversetunnel.Tunnel + remoteSite *mockRemoteSite +} + +func (p *mockProxyClient) GetSite(_ string) (reversetunnel.RemoteSite, error) { + return p.remoteSite, nil +} + +type mockRemoteSite struct { + reversetunnel.RemoteSite + dialErr error +} + +func (r *mockRemoteSite) Dial(_ reversetunnel.DialParams) (net.Conn, error) { + return nil, r.dialErr +} diff --git a/lib/web/app/session.go b/lib/web/app/session.go index 72782f71ab0a7..736ce113dfd5f 100644 --- a/lib/web/app/session.go +++ b/lib/web/app/session.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/ttlmap" "github.com/gravitational/oxy/forward" + oxyutils "github.com/gravitational/oxy/utils" "github.com/sirupsen/logrus" ) @@ -62,30 +63,43 @@ func (h *Handler) newSession(ctx context.Context, ws types.WebSession) (*session if err != nil { return nil, trace.Wrap(err) } - server, err := Match(ctx, accessPoint, MatchPublicAddr(identity.RouteToApp.PublicAddr)) + + // Match healthy and PublicAddr servers. Having a list of only healthy + // servers helps the transport fail before the request is forwarded to a + // server (in cases where there are no healthy servers). This process might + // take an additional time to execute, but since it is cached, only a few + // requests need to perform it. + servers, err := Match(ctx, accessPoint, MatchAll(MatchHealthy(h.c.ProxyClient, identity), MatchPublicAddr(identity.RouteToApp.PublicAddr))) if err != nil { return nil, trace.Wrap(err) } + if len(servers) == 0 { + return nil, trace.NotFound("failed to match applications") + } + // Create a rewriting transport that will be used to forward requests. 
transport, err := newTransport(&transportConfig{ + log: h.log, proxyClient: h.c.ProxyClient, accessPoint: h.c.AccessPoint, cipherSuites: h.c.CipherSuites, identity: identity, - server: server, + servers: servers, ws: ws, clusterName: h.clusterName, }) if err != nil { return nil, trace.Wrap(err) } + fwd, err := forward.New( forward.FlushInterval(100*time.Millisecond), forward.RoundTripper(transport), forward.Logger(h.log), forward.PassHostHeader(true), - forward.WebsocketDial(transport.dialer), + forward.WebsocketDial(transport.DialWebsocket), + forward.ErrorHandler(oxyutils.ErrorHandlerFunc(h.handleForwardError)), ) if err != nil { return nil, trace.Wrap(err) @@ -152,6 +166,14 @@ func (s *sessionCache) set(key string, value *session, ttl time.Duration) error return nil } +// remove immediately removes a single session from the cache. +func (s *sessionCache) remove(key string) { + s.mu.Lock() + defer s.mu.Unlock() + + _, _ = s.cache.Remove(key) +} + // expireSessions ticks every second trying to close expired sessions. func (s *sessionCache) expireSessions() { ticker := time.NewTicker(time.Second) diff --git a/lib/web/app/transport.go b/lib/web/app/transport.go index b0e8d81dc0efb..b8a63e00d5d72 100644 --- a/lib/web/app/transport.go +++ b/lib/web/app/transport.go @@ -22,6 +22,7 @@ import ( "fmt" "net" "net/http" + "sync" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" @@ -32,6 +33,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + "github.com/sirupsen/logrus" "github.com/gravitational/oxy/forward" "github.com/gravitational/trace" @@ -43,13 +45,14 @@ type transportConfig struct { accessPoint auth.ReadProxyAccessPoint cipherSuites []uint16 identity *tlsca.Identity - server types.AppServer + servers []types.AppServer ws types.WebSession clusterName string + log *logrus.Entry } // Check validates configuration. 
-func (c transportConfig) Check() error { +func (c *transportConfig) Check() error { if c.proxyClient == nil { return trace.BadParameter("proxy client missing") } @@ -62,8 +65,8 @@ func (c transportConfig) Check() error { if c.identity == nil { return trace.BadParameter("identity missing") } - if c.server == nil { - return trace.BadParameter("server missing") + if len(c.servers) == 0 { + return trace.BadParameter("servers missing") } if c.ws == nil { return trace.BadParameter("web session missing") @@ -83,32 +86,43 @@ type transport struct { // tr is used for forwarding http connections. tr http.RoundTripper - // dialer is used for forwarding websocket connections. - dialer forward.Dialer + // clientTLSConfig is the TLS config used for mutual authentication. + clientTLSConfig *tls.Config + + // servers is the list of servers that the transport can connect to + // organized in a map where the key is the server ID, and the value is the + // `types.AppServer`. + servers *sync.Map } // newTransport creates a new transport. func newTransport(c *transportConfig) (*transport, error) { + var err error if err := c.Check(); err != nil { return nil, trace.Wrap(err) } - // Clone and configure the transport. - tr, err := defaults.Transport() + t := &transport{c: c, servers: &sync.Map{}} + + t.clientTLSConfig, err = configureTLS(c) if err != nil { return nil, trace.Wrap(err) } - tr.DialContext = dialFunc(c) - tr.TLSClientConfig, err = configureTLS(c) + + // Clone and configure the transport. 
+ tr, err := defaults.Transport() if err != nil { return nil, trace.Wrap(err) } + tr.DialContext = t.DialContext + tr.TLSClientConfig = t.clientTLSConfig + + for _, server := range t.c.servers { + t.servers.Store(server.GetResourceID(), server) + } - return &transport{ - c: c, - tr: tr, - dialer: websocketsDialer(tr), - }, nil + t.tr = tr + return t, nil } // RoundTrip will rewrite the request, forward the request to the target @@ -159,39 +173,79 @@ func (t *transport) rewriteRequest(r *http.Request) error { return nil } -// websocketsDialer returns a function that dials a websocket connection -// over the transport's reverse tunnel. -func websocketsDialer(tr *http.Transport) forward.Dialer { - return func(network, address string) (net.Conn, error) { - conn, err := tr.DialContext(context.Background(), network, address) - if err != nil { - return nil, trace.Wrap(err) - } - // App access connections over reverse tunnel use mutual TLS. - return tls.Client(conn, tr.TLSClientConfig), nil - } -} +// DialContext dials and connect to the application service over the reverse +// tunnel subsystem. +func (t *transport) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { + var err error + var conn net.Conn -// dialFunc returns a function that can Dial and connect to the application -// service over the reverse tunnel subsystem. 
-func dialFunc(c *transportConfig) func(ctx context.Context, network string, addr string) (net.Conn, error) { - return func(ctx context.Context, network string, addr string) (net.Conn, error) { - clusterClient, err := c.proxyClient.GetSite(c.identity.RouteToApp.ClusterName) - if err != nil { - return nil, trace.Wrap(err) + t.servers.Range(func(serverID, appServerInterface interface{}) bool { + appServer, ok := appServerInterface.(types.AppServer) + if !ok { + t.c.log.Warnf("Failed to load AppServer, invalid type %T", appServerInterface) + return true } - conn, err := clusterClient.Dial(reversetunnel.DialParams{ - From: &utils.NetAddr{AddrNetwork: "tcp", Addr: "@web-proxy"}, - To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnel.LocalNode}, - ServerID: fmt.Sprintf("%v.%v", c.server.GetHostID(), c.identity.RouteToApp.ClusterName), - ConnType: types.AppTunnel, - }) - if err != nil { - return nil, trace.Wrap(err) + var dialErr error + conn, dialErr = dialAppServer(t.c.proxyClient, t.c.identity, appServer) + if dialErr != nil { + // Connection problem with the server. + if trace.IsConnectionProblem(dialErr) { + t.c.log.Warnf("Failed to connect to application server %q: %v.", serverID, dialErr) + t.servers.Delete(serverID) + // Only goes for the next server if the error returned is a + // connection problem. Otherwise, stop iterating over the + // servers and return the error. + return true + } } + + // "save" dial error to return as the function error. + err = dialErr + return false + }) + + if err != nil { + return nil, trace.Wrap(err) + } + + if conn != nil { return conn, nil } + + return nil, trace.ConnectionProblem(nil, "no application servers remaining to connect") +} + +// DialWebsocket dials a websocket connection over the transport's reverse +// tunnel. 
+func (t *transport) DialWebsocket(network, address string) (net.Conn, error) { + conn, err := t.DialContext(context.Background(), network, address) + if err != nil { + return nil, trace.Wrap(err) + } + // App access connections over reverse tunnel use mutual TLS. + return tls.Client(conn, t.clientTLSConfig), nil +} + +// dialAppServer dial and connect to the application service over the reverse +// tunnel subsystem. +func dialAppServer(proxyClient reversetunnel.Tunnel, identity *tlsca.Identity, server types.AppServer) (net.Conn, error) { + clusterClient, err := proxyClient.GetSite(identity.RouteToApp.ClusterName) + if err != nil { + return nil, trace.Wrap(err) + } + + conn, err := clusterClient.Dial(reversetunnel.DialParams{ + From: &utils.NetAddr{AddrNetwork: "tcp", Addr: "@web-proxy"}, + To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnel.LocalNode}, + ServerID: fmt.Sprintf("%v.%v", server.GetHostID(), identity.RouteToApp.ClusterName), + ConnType: types.AppTunnel, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return conn, nil } // configureTLS creates and configures a *tls.Config that will be used for diff --git a/lib/web/apps.go b/lib/web/apps.go index 6231793be75af..378b8744d71d9 100644 --- a/lib/web/apps.go +++ b/lib/web/apps.go @@ -306,12 +306,16 @@ func (h *Handler) resolveDirect(ctx context.Context, proxy reversetunnel.Tunnel, return nil, "", trace.Wrap(err) } - server, err := app.Match(ctx, authClient, app.MatchPublicAddr(publicAddr)) + servers, err := app.Match(ctx, authClient, app.MatchPublicAddr(publicAddr)) if err != nil { return nil, "", trace.Wrap(err) } - return server, clusterName, nil + if len(servers) == 0 { + return nil, "", trace.NotFound("failed to match applications with public addr %s", publicAddr) + } + + return servers[0], clusterName, nil } // resolveFQDN makes a best effort attempt to resolve FQDN to an application