diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 8250910e2c..c77aeef304 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -4,10 +4,10 @@ import ( "fmt" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/grpc/status" + "tailscale.com/types/key" ) const ( @@ -93,11 +93,13 @@ var createNodeCmd = &cobra.Command{ return } - if !util.NodePublicKeyRegex.Match([]byte(machineKey)) { - err = errPreAuthKeyMalformed + + var mkey key.MachinePublic + err = mkey.UnmarshalText([]byte(machineKey)) + if err != nil { ErrorOutput( err, - fmt.Sprintf("Error: %s", err), + fmt.Sprintf("Failed to parse machine key from flag: %s", err), output, ) diff --git a/hscontrol/app.go b/hscontrol/app.go index bb67ffc421..f130764b66 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -452,7 +452,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet) h.addLegacyHandlers(router) - router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet) + router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet) router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). diff --git a/hscontrol/auth.go b/hscontrol/auth.go index ffe48949d1..5fc365fdf8 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -45,7 +45,7 @@ func (h *Headscale) handleRegister( // is that the client will hammer headscale with requests until it gets a // successful RegisterResponse. if registerRequest.Followup != "" { - if _, ok := h.registrationCache.Get(registerRequest.NodeKey.String()); ok { + if _, ok := h.registrationCache.Get(machineKey.String()); ok { log.Debug(). Caller(). Str("node", registerRequest.Hostinfo.Hostname). @@ -116,7 +116,7 @@ func (h *Headscale) handleRegister( } h.registrationCache.Set( - newNode.NodeKey.String(), + machineKey.String(), newNode, registerCacheExpiration, ) @@ -205,7 +205,7 @@ func (h *Headscale) handleRegister( // headscale-managed tailnets? node.NodeKey = registerRequest.NodeKey h.registrationCache.Set( - registerRequest.NodeKey.String(), + machineKey.String(), *node, registerCacheExpiration, ) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 87c174807d..bc122ee9eb 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -376,7 +376,7 @@ func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error { func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( cache *cache.Cache, - nodeKeyStr string, + mkey key.MachinePublic, userName string, nodeExpiry *time.Time, registrationMethod string, @@ -384,20 +384,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( hsdb.mu.Lock() defer hsdb.mu.Unlock() - nodeKey := key.NodePublic{} - err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) - if err != nil { - return nil, err - } - log.Debug(). - Str("nodeKey", nodeKey.ShortString()). + Str("machine_key", mkey.ShortString()). Str("userName", userName). Str("registrationMethod", registrationMethod). Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). Msg("Registering node from API/CLI or auth callback") - if nodeInterface, ok := cache.Get(nodeKey.String()); ok { + if nodeInterface, ok := cache.Get(mkey.String()); ok { if registrationNode, ok := nodeInterface.(types.Node); ok { user, err := hsdb.getUser(userName) if err != nil { @@ -425,7 +419,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( ) if err == nil { - cache.Delete(nodeKeyStr) + cache.Delete(mkey.String()) } return node, err diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 753d5c79b6..5c05146d3d 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -172,12 +172,18 @@ func (api headscaleV1APIServer) RegisterNode( ) (*v1.RegisterNodeResponse, error) { log.Trace(). Str("user", request.GetUser()). - Str("node_key", request.GetKey()). + Str("machine_key", request.GetKey()). Msg("Registering node") + var mkey key.MachinePublic + err := mkey.UnmarshalText([]byte(request.GetKey())) + if err != nil { + return nil, err + } + node, err := api.h.db.RegisterNodeFromAuthCallback( api.h.registrationCache, - request.GetKey(), + mkey, request.GetUser(), nil, util.RegisterMethodCLI, @@ -532,8 +538,11 @@ func (api headscaleV1APIServer) DebugCreateNode( return nil, err } + nodeKey := key.NewNode() + newNode := types.Node{ MachineKey: mkey, + NodeKey: nodeKey.Public(), Hostname: request.GetName(), GivenName: givenName, User: *user, @@ -544,14 +553,12 @@ func (api headscaleV1APIServer) DebugCreateNode( HostInfo: types.HostInfo(hostinfo), } - nodeKey := key.NodePublic{} - err = nodeKey.UnmarshalText([]byte(request.GetKey())) - if err != nil { - log.Panic().Msg("can not add node for debug. invalid node key") - } + log.Debug(). + Str("machine_key", mkey.ShortString()). + Msg("adding debug machine via CLI, appending to registration cache") api.h.registrationCache.Set( - nodeKey.String(), + mkey.String(), newNode, registerCacheExpiration, ) diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index a5ddd97373..568519fd51 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -90,42 +90,28 @@ func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.T // RegisterOIDC redirects to the OIDC provider for authentication // Puts NodeKey in cache so the callback can retrieve it using the oidc state param -// Listens in /oidc/register/:nKey. +// Listens in /oidc/register/:mKey. func (h *Headscale) RegisterOIDC( writer http.ResponseWriter, req *http.Request, ) { vars := mux.Vars(req) - nodeKeyStr, ok := vars["nkey"] + machineKeyStr, ok := vars["mkey"] log.Debug(). Caller(). - Str("node_key", nodeKeyStr). + Str("machine_key", machineKeyStr). Bool("ok", ok). Msg("Received oidc register call") - if !util.NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { - log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusUnauthorized) - _, err := writer.Write([]byte("Unauthorized")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return - } - // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - var nodeKey key.NodePublic - err := nodeKey.UnmarshalText( - []byte(nodeKeyStr), + var machineKey key.MachinePublic + err := machineKey.UnmarshalText( + []byte(machineKeyStr), ) - - if !ok || nodeKeyStr == "" || err != nil { + if err != nil { log.Warn(). Err(err). Msg("Failed to parse incoming nodekey in OIDC registration") @@ -154,7 +140,7 @@ func (h *Headscale) RegisterOIDC( // place the node key into the state cache, so it can be retrieved later h.registrationCache.Set( stateStr, - nodeKey, + machineKey, registerCacheExpiration, ) @@ -232,7 +218,7 @@ func (h *Headscale) OIDCCallback( return } - nodeKey, nodeExists, err := h.validateNodeForOIDCCallback( + machineKey, nodeExists, err := h.validateNodeForOIDCCallback( writer, state, claims, @@ -255,7 +241,7 @@ func (h *Headscale) OIDCCallback( return } - if err := h.registerNodeForOIDCCallback(writer, user, nodeKey, idTokenExpiry); err != nil { + if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil { return } @@ -462,10 +448,10 @@ func (h *Headscale) validateNodeForOIDCCallback( state string, claims *IDTokenClaims, expiry time.Time, -) (*key.NodePublic, bool, error) { +) (*key.MachinePublic, bool, error) { // retrieve nodekey from state cache - nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state) - if !nodeKeyFound { + machineKeyIf, machineKeyFound := h.registrationCache.Get(state) + if !machineKeyFound { log.Trace(). Msg("requested node state key expired before authorisation completed") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -478,11 +464,11 @@ func (h *Headscale) validateNodeForOIDCCallback( return nil, false, errOIDCNodeKeyMissing } - var nodeKey key.NodePublic - nodeKey, nodeKeyOK := nodeKeyIf.(key.NodePublic) - if !nodeKeyOK { + var machineKey key.MachinePublic + machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic) + if !machineKeyOK { log.Trace(). - Interface("got", nodeKeyIf). + Interface("got", machineKeyIf). Msg("requested node state key is not a nodekey") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) @@ -498,7 +484,7 @@ func (h *Headscale) validateNodeForOIDCCallback( // The error is not important, because if it does not // exist, then this is a new node and we will move // on to registration. - node, _ := h.db.GetNodeByNodeKey(nodeKey) + node, _ := h.db.GetNodeByMachineKey(machineKey) if node != nil { log.Trace(). @@ -553,7 +539,7 @@ func (h *Headscale) validateNodeForOIDCCallback( return nil, true, nil } - return &nodeKey, false, nil + return &machineKey, false, nil } func getUserName( @@ -624,13 +610,13 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback( func (h *Headscale) registerNodeForOIDCCallback( writer http.ResponseWriter, user *types.User, - nodeKey *key.NodePublic, + machineKey *key.MachinePublic, expiry time.Time, ) error { if _, err := h.db.RegisterNodeFromAuthCallback( // TODO(kradalby): find a better way to use the cache across modules h.registrationCache, - nodeKey.String(), + *machineKey, user.Name, &expiry, util.RegisterMethodOIDC, diff --git a/integration/cli_test.go b/integration/cli_test.go index 61439c3f97..ed1b3fc8af 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -683,8 +683,8 @@ func TestNodeTagCommand(t *testing.T) { assertNoErr(t, err) machineKeys := []string{ - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "nodekey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", + "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", + "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", } nodes := make([]*v1.Node, len(machineKeys)) assert.Nil(t, err) @@ -816,13 +816,13 @@ func TestNodeCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - // Randomly generated node keys + // Pregenerated machine keys machineKeys := []string{ - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "nodekey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "nodekey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "nodekey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "nodekey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", + "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", + "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", + "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", + "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", + "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", } nodes := make([]*v1.Node, len(machineKeys)) assert.Nil(t, err) @@ -898,8 +898,8 @@ func TestNodeCommand(t *testing.T) { assert.Equal(t, "node-5", listAll[4].Name) otherUserMachineKeys := []string{ - "nodekey:b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e", - "nodekey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", + "mkey:b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e", + "mkey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", } otherUserMachines := make([]*v1.Node, len(otherUserMachineKeys)) assert.Nil(t, err) @@ -1056,13 +1056,13 @@ func TestNodeExpireCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - // Randomly generated node keys + // Pregenerated machine keys machineKeys := []string{ - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "nodekey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "nodekey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "nodekey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "nodekey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", + "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", + "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", + "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", + "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", + "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", } nodes := make([]*v1.Node, len(machineKeys)) @@ -1183,13 +1183,13 @@ func TestNodeRenameCommand(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - // Randomly generated node keys + // Pregenerated machine keys machineKeys := []string{ - "nodekey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", - "nodekey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "nodekey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "nodekey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", + "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", + "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", + "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", + "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", + "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", } nodes := make([]*v1.Node, len(machineKeys)) assert.Nil(t, err) @@ -1210,7 +1210,7 @@ func TestNodeRenameCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assertNoErr(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1228,7 +1228,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &node, ) - assert.Nil(t, err) + assertNoErr(t, err) nodes[index] = &node } @@ -1350,7 +1350,7 @@ func TestNodeMoveCommand(t *testing.T) { assertNoErr(t, err) // Randomly generated node key - machineKey := "nodekey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa" + machineKey := "mkey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa" _, err = headscale.Execute( []string{