
Commit

reject, not replace
Signed-off-by: Kristoffer Dalby <[email protected]>
kradalby committed Apr 12, 2024
1 parent 4c949e3 commit ae56239
Showing 9 changed files with 126 additions and 116 deletions.
71 changes: 42 additions & 29 deletions hscontrol/app.go
@@ -503,7 +503,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}

// Serve launches a GIN server with the Headscale API.
// Serve launches the HTTP and gRPC servers that serve Headscale and its API.
func (h *Headscale) Serve() error {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
@@ -532,7 +532,7 @@ func (h *Headscale) Serve() error {

region, err := h.DERPServer.GenerateRegion()
if err != nil {
return err
return fmt.Errorf("generating DERP region for embedded server: %w", err)
}

if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
@@ -607,14 +607,14 @@ func (h *Headscale) Serve() error {
}...,
)
if err != nil {
return err
return fmt.Errorf("setting up gRPC gateway via socket: %w", err)
}

// Connect to the gRPC server over localhost to skip
// the authentication.
err = v1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn)
if err != nil {
return err
return fmt.Errorf("registering Headscale API service to gRPC: %w", err)
}

// Start the local gRPC server without TLS and without authentication
@@ -635,9 +635,7 @@ func (h *Headscale) Serve() error {

tlsConfig, err := h.getTLSSettings()
if err != nil {
log.Error().Err(err).Msg("Failed to set up TLS configuration")

return err
return fmt.Errorf("configuring TLS settings: %w", err)
}

//
@@ -702,15 +700,11 @@ func (h *Headscale) Serve() error {
httpServer := &http.Server{
Addr: h.cfg.Addr,
Handler: router,
ReadTimeout: types.HTTPReadTimeout,
// Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to
// keep this at unlimited and be careful to clean up connections
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
// TODO(kradalby): this timeout can now be set per handler with http.ResponseController:
// https://www.alexedwards.net/blog/how-to-use-the-http-responsecontroller-type
// replace this so only the longpoller has no timeout.
WriteTimeout: 0,
ReadTimeout: types.HTTPTimeout,

// Long polling should not have any timeout, this is overridden
// further down the chain.
WriteTimeout: types.HTTPTimeout,
}

var httpListener net.Listener
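
The removed comment above pointed at http.ResponseController for per-handler timeouts, and the new comment says the write timeout is "overridden further down the chain". A minimal sketch of such a per-handler override, assuming Go 1.20+; the handler name and address are placeholders, not the commit's actual long-poll handler:

package main

import (
    "net/http"
    "time"
)

// longPollHandler clears its own write deadline so the server-wide
// WriteTimeout does not cut off a streaming response; every other handler
// keeps the deadline set on the http.Server.
func longPollHandler(w http.ResponseWriter, r *http.Request) {
    rc := http.NewResponseController(w)

    // A zero time.Time removes the write deadline for this response only.
    if err := rc.SetWriteDeadline(time.Time{}); err != nil {
        http.Error(w, "streaming unsupported", http.StatusInternalServerError)
        return
    }

    // ... write map responses / keep-alives here ...
}

func main() {
    srv := &http.Server{
        Addr:         "127.0.0.1:8080", // hypothetical address
        ReadTimeout:  30 * time.Second,
        WriteTimeout: 30 * time.Second,
        Handler:      http.HandlerFunc(longPollHandler),
    }
    _ = srv.ListenAndServe()
}
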
@@ -729,27 +723,46 @@ func (h *Headscale) Serve() error {
log.Info().
Msgf("listening and serving HTTP on: %s", h.cfg.Addr)

promMux := http.NewServeMux()
promMux.Handle("/metrics", promhttp.Handler())
debugMux := http.NewServeMux()
debugMux.HandleFunc("/debug/notifier", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.nodeNotifier.String()))

promHTTPServer := &http.Server{
return
})
debugMux.HandleFunc("/debug/mapresp", func(w http.ResponseWriter, r *http.Request) {
h.mapSessionMu.Lock()
defer h.mapSessionMu.Unlock()

var b strings.Builder
b.WriteString("mapresponders:\n")
for k, v := range h.mapSessions {
fmt.Fprintf(&b, "\t%d: %p\n", k, v)
}

w.WriteHeader(http.StatusOK)
w.Write([]byte(b.String()))

return
})
debugMux.Handle("/metrics", promhttp.Handler())

debugHTTPServer := &http.Server{
Addr: h.cfg.MetricsAddr,
Handler: promMux,
ReadTimeout: types.HTTPReadTimeout,
Handler: debugMux,
ReadTimeout: types.HTTPTimeout,
WriteTimeout: 0,
}

var promHTTPListener net.Listener
promHTTPListener, err = net.Listen("tcp", h.cfg.MetricsAddr)

debugHTTPListener, err := net.Listen("tcp", h.cfg.MetricsAddr)
if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err)
}

errorGroup.Go(func() error { return promHTTPServer.Serve(promHTTPListener) })
errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) })

log.Info().
Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr)
Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr)

var tailsqlContext context.Context
if tailsqlEnabled {
@@ -815,7 +828,7 @@ func (h *Headscale) Serve() error {
context.Background(),
types.HTTPShutdownTimeout,
)
if err := promHTTPServer.Shutdown(ctx); err != nil {
if err := debugHTTPServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown prometheus http")
}
if err := httpServer.Shutdown(ctx); err != nil {
@@ -833,7 +846,7 @@ func (h *Headscale) Serve() error {
}

// Close network listeners
promHTTPListener.Close()
debugHTTPListener.Close()
httpListener.Close()
grpcGatewayConn.Close()

@@ -898,7 +911,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
server := &http.Server{
Addr: h.cfg.TLS.LetsEncrypt.Listen,
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
ReadTimeout: types.HTTPReadTimeout,
ReadTimeout: types.HTTPTimeout,
}

go func() {
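
With this change the metrics listener also serves the new debug endpoints. A small sketch of querying them over HTTP; the base address is an assumption, in a real deployment it comes from h.cfg.MetricsAddr:

package main

import (
    "fmt"
    "io"
    "net/http"
)

func main() {
    base := "http://127.0.0.1:9090" // hypothetical metrics listen address

    // Fetch each debug endpoint exposed by the debug mux and print its body.
    for _, path := range []string{"/debug/notifier", "/debug/mapresp", "/metrics"} {
        resp, err := http.Get(base + path)
        if err != nil {
            fmt.Println(path, "error:", err)
            continue
        }
        body, _ := io.ReadAll(resp.Body)
        resp.Body.Close()
        fmt.Printf("--- %s ---\n%s\n", path, body)
    }
}
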
4 changes: 2 additions & 2 deletions hscontrol/derp/derp.go
@@ -31,7 +31,7 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
}

func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
ctx, cancel := context.WithTimeout(context.Background(), types.HTTPReadTimeout)
ctx, cancel := context.WithTimeout(context.Background(), types.HTTPTimeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil)
@@ -40,7 +40,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
}

client := http.Client{
Timeout: types.HTTPReadTimeout,
Timeout: types.HTTPTimeout,
}

resp, err := client.Do(req)
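
The derp.go hunks only swap types.HTTPReadTimeout for the consolidated types.HTTPTimeout. A hypothetical sketch of such constants; the real definitions and values in hscontrol/types are not shown in this diff:

package types

import "time"

// Hypothetical values, for illustration only: one timeout shared by
// non-streaming HTTP servers and outbound requests, plus a shutdown grace.
const (
    HTTPTimeout         = 30 * time.Second
    HTTPShutdownTimeout = 3 * time.Second
)
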
92 changes: 41 additions & 51 deletions hscontrol/noise.go
@@ -3,7 +3,6 @@ package hscontrol
import (
"encoding/binary"
"encoding/json"
"errors"
"io"
"net/http"

@@ -12,7 +11,6 @@ import (
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"gorm.io/gorm"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp"
"tailscale.com/tailcfg"
@@ -103,12 +101,12 @@ func (h *Headscale) NoiseUpgradeHandler(
router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler)

server := http.Server{
ReadTimeout: types.HTTPReadTimeout,
ReadTimeout: types.HTTPTimeout,
}

noiseServer.httpBaseConfig = &http.Server{
Handler: router,
ReadHeaderTimeout: types.HTTPReadTimeout,
ReadHeaderTimeout: types.HTTPTimeout,
}
noiseServer.http2Server = &http2.Server{}

Expand Down Expand Up @@ -225,15 +223,6 @@ func (ns *noiseServer) NoisePollNetMapHandler(
key.NodePublic{},
)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "NoisePollNetMap").
Uint64("node.id", node.ID.Uint64()).
Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusNotFound)

return
}
log.Error().
Str("handler", "NoisePollNetMap").
Uint64("node.id", node.ID.Uint64()).
@@ -242,58 +231,59 @@

return
}
log.Debug().
Str("handler", "NoisePollNetMap").
Str("node", node.Hostname).
Int("cap_ver", int(mapRequest.Version)).
Uint64("node.id", node.ID.Uint64()).
Msg("A node sending a MapRequest with Noise protocol")
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node)

session := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node)
sess.tracef("a node sending a MapRequest with Noise protocol")

// If a streaming mapSession exists for this node, close it
// and start a new one.
if session.isStreaming() {
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Aquiring lock to check stream")
if sess.isStreaming() {
sess.tracef("aquiring lock to check stream")

ns.headscale.mapSessionMu.Lock()
if oldSession, ok := ns.headscale.mapSessions[node.ID]; ok {
log.Info().
Caller().
Uint64("node.id", node.ID.Uint64()).
Msg("Node has an open streaming session, replacing")
oldSession.close()
if _, ok := ns.headscale.mapSessions[node.ID]; ok {
// NOTE/TODO(kradalby): From how I understand the protocol, when
// a client connects with stream=true, and already has a streaming
// connection open, the correct way is to close the current channel
// and replace it. However, I cannot manage to get that working without
// running into some sort of lock/block on the cancelCh in the streaming
// session.
// Replacing the channel without closing it puts us in a weird state,
// which keeps a ghost stream open, receiving keep-alives but no updates.
//
// Typically a new connection is opened while one already exists when an
// already-authenticated client reconnects (e.g. it goes down, then comes
// back up). The client will start auth and streaming at the same time, and
// then cancel the streaming when the auth has finished successfully,
// opening a new connection.
//
// As a work-around to not replacing, we abuse the client's "resilience"
// by rejecting the new connection, which will cause the client to immediately
// reconnect and "fix" the issue, as the other connection typically has been
// closed, meaning there is nothing to replace.
//
// sess.infof("node has an open stream(%p), replacing with %p", oldSession, sess)
// oldSession.close()

defer ns.headscale.mapSessionMu.Unlock()

sess.infof("node has an open stream(%p), rejecting new stream", sess)
return
}

ns.headscale.mapSessions[node.ID] = session
ns.headscale.mapSessions[node.ID] = sess
ns.headscale.mapSessionMu.Unlock()
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Releasing lock to check stream")
sess.tracef("releasing lock to check stream")
}

session.serve()
sess.serve()

if session.isStreaming() {
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Aquiring lock to remove stream")
if sess.isStreaming() {
sess.tracef("aquiring lock to remove stream")
ns.headscale.mapSessionMu.Lock()
defer ns.headscale.mapSessionMu.Unlock()

delete(ns.headscale.mapSessions, node.ID)

ns.headscale.mapSessionMu.Unlock()
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Releasing lock to remove stream")
sess.tracef("releasing lock to remove stream")
}
}
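
The behavioural change this commit is named for: when a node that already has an open streaming map session reconnects with stream=true, the new stream is rejected rather than replacing the old one, and the client's own retry logic is expected to reconnect once the stale stream has gone away. A simplified, self-contained sketch of that pattern; the types and names are stand-ins, not the actual Headscale structs:

package main

import (
    "fmt"
    "sync"
)

type nodeID uint64

type session struct{ streaming bool }

type server struct {
    mu       sync.Mutex
    sessions map[nodeID]*session
}

// handle sketches the reject-not-replace flow: if a streaming session is
// already registered for this node, the new one is turned away; otherwise it
// is registered, served, and removed again afterwards.
func (s *server) handle(id nodeID, sess *session) error {
    if sess.streaming {
        s.mu.Lock()
        if _, ok := s.sessions[id]; ok {
            s.mu.Unlock()
            // Rejecting makes the client reconnect shortly; by then the old
            // stream has usually been torn down, so there is nothing to replace.
            return fmt.Errorf("node %d already has an open stream, rejecting", id)
        }
        s.sessions[id] = sess
        s.mu.Unlock()

        defer func() {
            s.mu.Lock()
            delete(s.sessions, id)
            s.mu.Unlock()
        }()
    }

    // serve the (possibly streaming) map session here
    return nil
}

func main() {
    s := &server{sessions: map[nodeID]*session{}}

    // Simulate a node that already has an open streaming session.
    s.sessions[1] = &session{streaming: true}

    // A second streaming connection from the same node is rejected,
    // not swapped in for the existing one.
    if err := s.handle(1, &session{streaming: true}); err != nil {
        fmt.Println(err)
    }
}
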
14 changes: 11 additions & 3 deletions hscontrol/notifier/notifier.go
@@ -172,11 +172,19 @@ func (n *Notifier) String() string {
n.l.RLock()
defer n.l.RUnlock()

str := []string{"Notifier, in map:\n"}
var b strings.Builder
b.WriteString("chans:\n")

for k, v := range n.nodes {
str = append(str, fmt.Sprintf("\t%d: %v\n", k, v))
fmt.Fprintf(&b, "\t%d: %p\n", k, v)
}

return strings.Join(str, "")
b.WriteString("\n")
b.WriteString("connected:\n")

for k, v := range n.connected {
fmt.Fprintf(&b, "\t%d: %t\n", k, v)
}

return b.String()
}
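
For reference, a toy example of the same strings.Builder dump pattern used by the new String(); the IDs and values here are made up, and channel addresses and map iteration order will differ at runtime:

package main

import (
    "fmt"
    "strings"
)

func main() {
    // Stand-ins for the notifier's per-node channels and connected maps.
    chans := map[uint64]chan struct{}{
        1: make(chan struct{}),
        2: make(chan struct{}),
    }
    connected := map[uint64]bool{1: true, 2: false}

    var b strings.Builder
    b.WriteString("chans:\n")
    for k, v := range chans {
        fmt.Fprintf(&b, "\t%d: %p\n", k, v)
    }
    b.WriteString("\nconnected:\n")
    for k, v := range connected {
        fmt.Fprintf(&b, "\t%d: %t\n", k, v)
    }
    fmt.Print(b.String())
}
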