rework mapsession #1791

Merged · 4 commits · Apr 15, 2024
3 changes: 2 additions & 1 deletion .github/workflows/test-integration.yaml
@@ -43,7 +43,8 @@ jobs:
- TestTaildrop
- TestResolveMagicDNS
- TestExpireNode
- TestNodeOnlineLastSeenStatus
- TestNodeOnlineStatus
- TestPingAllByIPManyUpDown
- TestEnablingRoutes
- TestHASubnetRouterFailover
- TestEnableDisableAutoApprovedRoute
146 changes: 79 additions & 67 deletions hscontrol/app.go
@@ -28,6 +28,7 @@ import (
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/derp"
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
@@ -77,6 +78,11 @@ const (
registerCacheCleanup = time.Minute * 20
)

// func init() {
// deadlock.Opts.DeadlockTimeout = 15 * time.Second
// deadlock.Opts.PrintAllCurrentGoroutines = true
// }

// Headscale represents the base app of the service.
type Headscale struct {
cfg *types.Config
@@ -89,15 +95,18 @@ type Headscale struct {

ACLPolicy *policy.ACLPolicy

mapper *mapper.Mapper
nodeNotifier *notifier.Notifier

oidcProvider *oidc.Provider
oauth2Config *oauth2.Config

registrationCache *cache.Cache

shutdownChan chan struct{}
pollNetMapStreamWG sync.WaitGroup

mapSessions map[types.NodeID]*mapSession
mapSessionMu sync.Mutex
}

var (
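For orientation: the new mapSessions field keeps at most one active mapSession per node, guarded by mapSessionMu. A minimal sketch of how such a registry is usually maintained (the helper names below are illustrative, not functions added by this PR):

// Sketch only: not part of this PR's diff.
func (h *Headscale) setMapSession(id types.NodeID, s *mapSession) {
	h.mapSessionMu.Lock()
	defer h.mapSessionMu.Unlock()
	// Replace any previous session for this node; the newest long-poll wins.
	h.mapSessions[id] = s
}

func (h *Headscale) removeMapSession(id types.NodeID) {
	h.mapSessionMu.Lock()
	defer h.mapSessionMu.Unlock()
	delete(h.mapSessions, id)
}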
@@ -129,6 +138,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(),
mapSessions: make(map[types.NodeID]*mapSession),
}

app.db, err = db.NewHeadscaleDatabase(
@@ -199,26 +209,37 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, target, http.StatusFound)
}

// expireEphemeralNodes deletes ephemeral node records that have not been
// deleteExpireEphemeralNodes deletes ephemeral node records that have not been
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)

var update types.StateUpdate
var changed bool
for range ticker.C {
var removed []types.NodeID
var changed []types.NodeID
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)

return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
continue
}

if changed && update.Valid() {
if removed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, update)
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: removed,
})
}

if changed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changed,
})
}
}
}
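The expiry loop now emits two separate updates instead of one combined update: a StatePeerRemoved update carrying the IDs of deleted ephemeral nodes, and a StatePeerChanged update for nodes whose state changed. A consumer of the notifier would branch on the type roughly as follows (illustrative sketch; handlePeersRemoved and handlePeersChanged are placeholders, not code from this PR):

switch update.Type {
case types.StatePeerRemoved:
	// Drop the listed peers from the client's netmap.
	handlePeersRemoved(update.Removed)
case types.StatePeerChanged:
	// Re-send full peer details for the changed nodes.
	handlePeersChanged(update.ChangeNodes)
}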
@@ -243,8 +264,9 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
continue
}

log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes")
if changed && update.Valid() {
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")

ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
@@ -272,14 +294,11 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
h.DERPMap.Regions[region.RegionID] = &region
}

stateUpdate := types.StateUpdate{
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: h.DERPMap,
}
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
})
}
}
}
@@ -303,11 +322,6 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,

meta, ok := metadata.FromIncomingContext(ctx)
if !ok {
log.Error().
Caller().
Str("client_address", client.Addr.String()).
Msg("Retrieving metadata is failed")

return ctx, status.Errorf(
codes.InvalidArgument,
"Retrieving metadata is failed",
@@ -316,11 +330,6 @@

authHeader, ok := meta["authorization"]
if !ok {
log.Error().
Caller().
Str("client_address", client.Addr.String()).
Msg("Authorization token is not supplied")

return ctx, status.Errorf(
codes.Unauthenticated,
"Authorization token is not supplied",
@@ -330,11 +339,6 @@
token := authHeader[0]

if !strings.HasPrefix(token, AuthPrefix) {
log.Error().
Caller().
Str("client_address", client.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`)

return ctx, status.Error(
codes.Unauthenticated,
`missing "Bearer " prefix in "Authorization" header`,
@@ -343,12 +347,6 @@

valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix))
if err != nil {
log.Error().
Caller().
Err(err).
Str("client_address", client.Addr.String()).
Msg("failed to validate token")

return ctx, status.Error(codes.Internal, "failed to validate token")
}

@@ -483,7 +481,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}

// Serve launches a GIN server with the Headscale API.
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile {
if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok {
@@ -502,6 +500,7 @@ func (h *Headscale) Serve() error {

// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap())

if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server
@@ -511,7 +510,7 @@

region, err := h.DERPServer.GenerateRegion()
if err != nil {
return err
return fmt.Errorf("generating DERP region for embedded server: %w", err)
}

if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
@@ -533,7 +532,7 @@

// TODO(kradalby): These should have cancel channels and be cleaned
// up on shutdown.
go h.expireEphemeralNodes(updateInterval)
go h.deleteExpireEphemeralNodes(updateInterval)
go h.expireExpiredMachines(updateInterval)

if zl.GlobalLevel() == zl.TraceLevel {
@@ -586,14 +585,14 @@ func (h *Headscale) Serve() error {
}...,
)
if err != nil {
return err
return fmt.Errorf("setting up gRPC gateway via socket: %w", err)
}

// Connect to the gRPC server over localhost to skip
// the authentication.
err = v1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn)
if err != nil {
return err
return fmt.Errorf("registering Headscale API service to gRPC: %w", err)
}

// Start the local gRPC server without TLS and without authentication
@@ -614,9 +613,7 @@

tlsConfig, err := h.getTLSSettings()
if err != nil {
log.Error().Err(err).Msg("Failed to set up TLS configuration")

return err
return fmt.Errorf("configuring TLS settings: %w", err)
}

//
@@ -681,12 +678,11 @@ func (h *Headscale) Serve() error {
httpServer := &http.Server{
Addr: h.cfg.Addr,
Handler: router,
ReadTimeout: types.HTTPReadTimeout,
// Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to
// keep this at unlimited and be careful to clean up connections
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
WriteTimeout: 0,
ReadTimeout: types.HTTPTimeout,

// Long polling should not have any timeout, this is overriden
// further down the chain
WriteTimeout: types.HTTPTimeout,
}

var httpListener net.Listener
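The comment in the hunk above notes that the write timeout for long-polling handlers is overridden further down the chain. Since Go 1.20 this can be done per request with http.ResponseController; the sketch below is one way to do it and is not necessarily what headscale itself does (requires net/http and time):

// Sketch only: clear the server-level WriteTimeout for a streaming handler.
func longPollHandler(w http.ResponseWriter, r *http.Request) {
	rc := http.NewResponseController(w)
	// A zero deadline removes the write deadline inherited from WriteTimeout.
	if err := rc.SetWriteDeadline(time.Time{}); err != nil {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	// ... stream map responses for as long as the client stays connected.
}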
@@ -705,27 +701,46 @@
log.Info().
Msgf("listening and serving HTTP on: %s", h.cfg.Addr)

promMux := http.NewServeMux()
promMux.Handle("/metrics", promhttp.Handler())
debugMux := http.NewServeMux()
debugMux.HandleFunc("/debug/notifier", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.nodeNotifier.String()))

promHTTPServer := &http.Server{
return
})
debugMux.HandleFunc("/debug/mapresp", func(w http.ResponseWriter, r *http.Request) {
h.mapSessionMu.Lock()
defer h.mapSessionMu.Unlock()

var b strings.Builder
b.WriteString("mapresponders:\n")
for k, v := range h.mapSessions {
fmt.Fprintf(&b, "\t%d: %p\n", k, v)
}

w.WriteHeader(http.StatusOK)
w.Write([]byte(b.String()))

return
})
debugMux.Handle("/metrics", promhttp.Handler())

debugHTTPServer := &http.Server{
Addr: h.cfg.MetricsAddr,
Handler: promMux,
ReadTimeout: types.HTTPReadTimeout,
Handler: debugMux,
ReadTimeout: types.HTTPTimeout,
WriteTimeout: 0,
}

var promHTTPListener net.Listener
promHTTPListener, err = net.Listen("tcp", h.cfg.MetricsAddr)

debugHTTPListener, err := net.Listen("tcp", h.cfg.MetricsAddr)
if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err)
}

errorGroup.Go(func() error { return promHTTPServer.Serve(promHTTPListener) })
errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) })

log.Info().
Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr)
Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr)

var tailsqlContext context.Context
if tailsqlEnabled {
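With this change, /metrics, /debug/notifier and /debug/mapresp are all served from the listener on cfg.MetricsAddr. A small self-contained example of fetching the map-session debug page (the address is an assumption matching headscale's default metrics_listen_addr; adjust to your config):

package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Assumed default metrics/debug address; change to match metrics_listen_addr.
	resp, err := http.Get("http://127.0.0.1:9090/debug/mapresp")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Print(string(body))
}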
@@ -742,7 +757,6 @@ func (h *Headscale) Serve() error {
}

// Handle common process-killing signals so we can gracefully shut down:
h.shutdownChan = make(chan struct{})
sigc := make(chan os.Signal, 1)
signal.Notify(sigc,
syscall.SIGHUP,
@@ -785,16 +799,14 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully")

close(h.shutdownChan)

h.pollNetMapStreamWG.Wait()

// Gracefully shut down servers
ctx, cancel := context.WithTimeout(
context.Background(),
types.HTTPShutdownTimeout,
)
if err := promHTTPServer.Shutdown(ctx); err != nil {
if err := debugHTTPServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown prometheus http")
}
if err := httpServer.Shutdown(ctx); err != nil {
Expand All @@ -812,7 +824,7 @@ func (h *Headscale) Serve() error {
}

// Close network listeners
promHTTPListener.Close()
debugHTTPListener.Close()
httpListener.Close()
grpcGatewayConn.Close()

@@ -877,7 +889,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
server := &http.Server{
Addr: h.cfg.TLS.LetsEncrypt.Listen,
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
ReadTimeout: types.HTTPReadTimeout,
ReadTimeout: types.HTTPTimeout,
}

go func() {