Skip to content

Commit

Permalink
use dedicated registration ID for auth flow (#2337)
Browse files Browse the repository at this point in the history
  • Loading branch information
kradalby authored Jan 26, 2025
1 parent 97e5d95 commit 4c8e847
Show file tree
Hide file tree
Showing 26 changed files with 583 additions and 583 deletions.
9 changes: 4 additions & 5 deletions cmd/headscale/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"fmt"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"tailscale.com/types/key"
)

const (
Expand Down Expand Up @@ -79,7 +79,7 @@ var createNodeCmd = &cobra.Command{
)
}

machineKey, err := cmd.Flags().GetString("key")
registrationID, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(
err,
Expand All @@ -88,8 +88,7 @@ var createNodeCmd = &cobra.Command{
)
}

var mkey key.MachinePublic
err = mkey.UnmarshalText([]byte(machineKey))
_, err = types.RegistrationIDFromString(registrationID)
if err != nil {
ErrorOutput(
err,
Expand All @@ -108,7 +107,7 @@ var createNodeCmd = &cobra.Command{
}

request := &v1.DebugCreateNodeRequest{
Key: machineKey,
Key: registrationID,
Name: name,
User: user,
Routes: routes,
Expand Down
4 changes: 2 additions & 2 deletions cmd/headscale/cli/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ var registerNodeCmd = &cobra.Command{
defer cancel()
defer conn.Close()

machineKey, err := cmd.Flags().GetString("key")
registrationID, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(
err,
Expand All @@ -132,7 +132,7 @@ var registerNodeCmd = &cobra.Command{
}

request := &v1.RegisterNodeRequest{
Key: machineKey,
Key: registrationID,
User: user,
}

Expand Down
6 changes: 3 additions & 3 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ type Headscale struct {
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier

registrationCache *zcache.Cache[string, types.Node]
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]

authProvider AuthProvider

Expand All @@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}

registrationCache := zcache.New[string, types.Node](
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration,
registerCacheCleanup,
)
Expand Down Expand Up @@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {

router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)

if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
Expand Down
135 changes: 89 additions & 46 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/juanfont/headscale/hscontrol/db"
Expand All @@ -20,16 +22,18 @@ import (

type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(key.MachinePublic) string
AuthURL(types.RegistrationID) string
}

func logAuthFunc(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) (func(string), func(string), func(error, string)) {
return func(msg string) {
log.Info().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -41,6 +45,7 @@ func logAuthFunc(
func(msg string) {
log.Trace().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -52,6 +57,7 @@ func logAuthFunc(
func(err error, msg string) {
log.Error().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Expand All @@ -63,16 +69,64 @@ func logAuthFunc(
}
}

// waitForFollowup holds a followup register request open until the
// interactive registration it refers to has completed, or until the client
// goes away.
//
// The registration ID is taken from the path of regReq.Followup, which is
// expected to have the form /register/{registration_id}. If the URL cannot be
// parsed, the path does not carry a valid registration ID, or the ID is not
// present in the registration cache, the function returns immediately.
//
// It does not write a response; it only blocks and reports progress through
// logTrace.
func (h *Headscale) waitForFollowup(
	req *http.Request,
	regReq tailcfg.RegisterRequest,
	logTrace func(string),
) {
	logTrace("register request is a followup")

	fu, err := url.Parse(regReq.Followup)
	if err != nil {
		logTrace("failed to parse followup URL")
		return
	}

	// Strip only the leading route prefix. Removing every occurrence of
	// "/register/" (as ReplaceAll would) could splice unrelated path
	// fragments together into something that looks like an ID.
	followupReg, err := types.RegistrationIDFromString(strings.TrimPrefix(fu.Path, "/register/"))
	if err != nil {
		logTrace("followup URL does not contains a valid registration ID")
		return
	}

	logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg))

	reg, ok := h.registrationCache.Get(followupReg)
	if !ok {
		// Nothing to wait for; make the miss visible when tracing.
		logTrace("registration ID from followup URL not found in cache")
		return
	}

	logTrace("Node is waiting for interactive login")

	// Block until either the client disconnects or the Registered channel
	// of the cached entry becomes ready.
	select {
	case <-req.Context().Done():
		logTrace("node went away before it was registered")
	case <-reg.Registered:
		logTrace("node has successfully registered")
	}
}

// handleRegister is the logic for registering a client.
func (h *Headscale) handleRegister(
writer http.ResponseWriter,
req *http.Request,
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
registrationId, err := types.NewRegistrationID()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to generate registration ID")
http.Error(writer, "Internal server error", http.StatusInternalServerError)

return
}

logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")

// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
logTrace("handleRegister database lookup has returned")
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand All @@ -84,27 +138,9 @@ func (h *Headscale) handleRegister(
}

// Check if the node is waiting for interactive login.
//
// TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the node).
// This is not implemented yet, as it is not strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if regReq.Followup != "" {
logTrace("register request is a followup")
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
logTrace("Node is waiting for interactive login")

select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
h.handleNewNode(writer, regReq, machineKey)

return
}
}
h.waitForFollowup(req, regReq, logTrace)
return
}

logInfo("Node not found in database, creating new")
Expand All @@ -113,25 +149,28 @@ func (h *Headscale) handleRegister(
// that we rely on a method that calls back somehow (OpenID or CLI)
// We create the node and then keep it around until a callback
// happens
newNode := types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
newNode := types.RegisterNode{
Node: types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
},
Registered: make(chan struct{}),
}

if !regReq.Expiry.IsZero() {
logTrace("Non-zero expiry time requested")
newNode.Expiry = &regReq.Expiry
newNode.Node.Expiry = &regReq.Expiry
}

h.registrationCache.Set(
machineKey.String(),
registrationId,
newNode,
)

h.handleNewNode(writer, regReq, machineKey)
h.handleNewNode(writer, regReq, registrationId)

return
}
Expand Down Expand Up @@ -206,27 +245,28 @@ func (h *Headscale) handleRegister(
}

if regReq.Followup != "" {
select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
}
h.waitForFollowup(req, regReq, logTrace)
return
}

// The node has expired or it is logged out
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey)
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId)

// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
node.Expiry = &time.Time{}

// TODO(kradalby): do we need to rethink this as part of authflow?
// If we are here it means the client needs to be reauthorized,
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
node.NodeKey = regReq.NodeKey
h.registrationCache.Set(
machineKey.String(),
*node,
registrationId,
types.RegisterNode{
Node: *node,
Registered: make(chan struct{}),
},
)

return
Expand Down Expand Up @@ -296,6 +336,8 @@ func (h *Headscale) handleAuthKey(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
// key refreshes. This will allow us to remove the machineKey from the registration request.
node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if node != nil {
log.Trace().
Expand Down Expand Up @@ -444,16 +486,16 @@ func (h *Headscale) handleAuthKey(
func (h *Headscale) handleNewNode(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId)

resp := tailcfg.RegisterResponse{}

// The node registration is new, redirect the client to the registration URL
logTrace("The node seems to be new, sending auth url")
logTrace("The node is new, sending auth url")

resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)

respBody, err := json.Marshal(resp)
if err != nil {
Expand Down Expand Up @@ -660,6 +702,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
regReq tailcfg.RegisterRequest,
node types.Node,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
resp := tailcfg.RegisterResponse{}

Expand All @@ -673,12 +716,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
log.Trace().
Caller().
Str("node", node.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Msg("Node registration has expired or logged out. Sending a auth url to register")

resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)

respBody, err := json.Marshal(resp)
if err != nil {
Expand All @@ -703,7 +746,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(

log.Trace().
Caller().
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Str("node", node.Hostname).
Expand Down
Loading

0 comments on commit 4c8e847

Please sign in to comment.