Skip to content

Commit

Permalink
make sure machinekey is concistently used
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Nov 17, 2023
1 parent d9e5b4a commit 09954c1
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 86 deletions.
10 changes: 6 additions & 4 deletions cmd/headscale/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"fmt"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"tailscale.com/types/key"
)

const (
Expand Down Expand Up @@ -93,11 +93,13 @@ var createNodeCmd = &cobra.Command{

return
}
if !util.NodePublicKeyRegex.Match([]byte(machineKey)) {
err = errPreAuthKeyMalformed

var mkey key.MachinePublic
err = mkey.UnmarshalText([]byte(machineKey))
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error: %s", err),
fmt.Sprintf("Failed to parse machine key from flag: %s", err),
output,
)

Expand Down
2 changes: 1 addition & 1 deletion hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet)
h.addLegacyHandlers(router)

router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
Expand Down
6 changes: 3 additions & 3 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (h *Headscale) handleRegister(
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if registerRequest.Followup != "" {
if _, ok := h.registrationCache.Get(registerRequest.NodeKey.String()); ok {
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
log.Debug().
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Expand Down Expand Up @@ -116,7 +116,7 @@ func (h *Headscale) handleRegister(
}

h.registrationCache.Set(
newNode.NodeKey.String(),
machineKey.String(),
newNode,
registerCacheExpiration,
)
Expand Down Expand Up @@ -205,7 +205,7 @@ func (h *Headscale) handleRegister(
// headscale-managed tailnets?
node.NodeKey = registerRequest.NodeKey
h.registrationCache.Set(
registerRequest.NodeKey.String(),
machineKey.String(),
*node,
registerCacheExpiration,
)
Expand Down
14 changes: 4 additions & 10 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,28 +376,22 @@ func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error {

func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
cache *cache.Cache,
nodeKeyStr string,
mkey key.MachinePublic,
userName string,
nodeExpiry *time.Time,
registrationMethod string,
) (*types.Node, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()

nodeKey := key.NodePublic{}
err := nodeKey.UnmarshalText([]byte(nodeKeyStr))
if err != nil {
return nil, err
}

log.Debug().
Str("nodeKey", nodeKey.ShortString()).
Str("machine_key", mkey.ShortString()).
Str("userName", userName).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")

if nodeInterface, ok := cache.Get(nodeKey.String()); ok {
if nodeInterface, ok := cache.Get(mkey.String()); ok {
if registrationNode, ok := nodeInterface.(types.Node); ok {
user, err := hsdb.getUser(userName)
if err != nil {
Expand Down Expand Up @@ -425,7 +419,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
)

if err == nil {
cache.Delete(nodeKeyStr)
cache.Delete(mkey.String())
}

return node, err
Expand Down
23 changes: 15 additions & 8 deletions hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,18 @@ func (api headscaleV1APIServer) RegisterNode(
) (*v1.RegisterNodeResponse, error) {
log.Trace().
Str("user", request.GetUser()).
Str("node_key", request.GetKey()).
Str("machine_key", request.GetKey()).
Msg("Registering node")

var mkey key.MachinePublic
err := mkey.UnmarshalText([]byte(request.GetKey()))
if err != nil {
return nil, err
}

node, err := api.h.db.RegisterNodeFromAuthCallback(
api.h.registrationCache,
request.GetKey(),
mkey,
request.GetUser(),
nil,
util.RegisterMethodCLI,
Expand Down Expand Up @@ -532,8 +538,11 @@ func (api headscaleV1APIServer) DebugCreateNode(
return nil, err
}

nodeKey := key.NewNode()

newNode := types.Node{
MachineKey: mkey,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
GivenName: givenName,
User: *user,
Expand All @@ -544,14 +553,12 @@ func (api headscaleV1APIServer) DebugCreateNode(
HostInfo: types.HostInfo(hostinfo),
}

nodeKey := key.NodePublic{}
err = nodeKey.UnmarshalText([]byte(request.GetKey()))
if err != nil {
log.Panic().Msg("can not add node for debug. invalid node key")
}
log.Debug().
Str("machine_key", mkey.ShortString()).
Msg("adding debug machine via CLI, appending to registration cache")

api.h.registrationCache.Set(
nodeKey.String(),
mkey.String(),
newNode,
registerCacheExpiration,
)
Expand Down
56 changes: 21 additions & 35 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,42 +90,28 @@ func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.T

// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:nKey.
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]
machineKeyStr, ok := vars["mkey"]

log.Debug().
Caller().
Str("node_key", nodeKeyStr).
Str("machine_key", machineKeyStr).
Bool("ok", ok).
Msg("Received oidc register call")

if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
util.LogErr(err, "Failed to write response")
}

return
}

// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText(
[]byte(nodeKeyStr),
var machineKey key.MachinePublic
err := machineKey.UnmarshalText(
[]byte(machineKeyStr),
)

if !ok || nodeKeyStr == "" || err != nil {
if err != nil {
log.Warn().
Err(err).
Msg("Failed to parse incoming nodekey in OIDC registration")
Expand Down Expand Up @@ -154,7 +140,7 @@ func (h *Headscale) RegisterOIDC(
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(
stateStr,
nodeKey,
machineKey,
registerCacheExpiration,
)

Expand Down Expand Up @@ -232,7 +218,7 @@ func (h *Headscale) OIDCCallback(
return
}

nodeKey, nodeExists, err := h.validateNodeForOIDCCallback(
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
writer,
state,
claims,
Expand All @@ -255,7 +241,7 @@ func (h *Headscale) OIDCCallback(
return
}

if err := h.registerNodeForOIDCCallback(writer, user, nodeKey, idTokenExpiry); err != nil {
if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil {
return
}

Expand Down Expand Up @@ -462,10 +448,10 @@ func (h *Headscale) validateNodeForOIDCCallback(
state string,
claims *IDTokenClaims,
expiry time.Time,
) (*key.NodePublic, bool, error) {
) (*key.MachinePublic, bool, error) {
// retrieve nodekey from state cache
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !nodeKeyFound {
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound {
log.Trace().
Msg("requested node state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
Expand All @@ -478,11 +464,11 @@ func (h *Headscale) validateNodeForOIDCCallback(
return nil, false, errOIDCNodeKeyMissing
}

var nodeKey key.NodePublic
nodeKey, nodeKeyOK := nodeKeyIf.(key.NodePublic)
if !nodeKeyOK {
var machineKey key.MachinePublic
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
if !machineKeyOK {
log.Trace().
Interface("got", nodeKeyIf).
Interface("got", machineKeyIf).
Msg("requested node state key is not a nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
Expand All @@ -498,7 +484,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := h.db.GetNodeByNodeKey(nodeKey)
node, _ := h.db.GetNodeByMachineKey(machineKey)

if node != nil {
log.Trace().
Expand Down Expand Up @@ -553,7 +539,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
return nil, true, nil
}

return &nodeKey, false, nil
return &machineKey, false, nil
}

func getUserName(
Expand Down Expand Up @@ -624,13 +610,13 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
func (h *Headscale) registerNodeForOIDCCallback(
writer http.ResponseWriter,
user *types.User,
nodeKey *key.NodePublic,
machineKey *key.MachinePublic,
expiry time.Time,
) error {
if _, err := h.db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules
h.registrationCache,
nodeKey.String(),
*machineKey,
user.Name,
&expiry,
util.RegisterMethodOIDC,
Expand Down
Loading

0 comments on commit 09954c1

Please sign in to comment.