Skip to content

Commit

Permalink
Merge pull request #1 from kradalby/namespace-mappings
Browse files Browse the repository at this point in the history
Implement namespace mappings
  • Loading branch information
unreality authored Oct 19, 2021
2 parents 710616f + 677bd9b commit 8fe72dc
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 60 deletions.
6 changes: 3 additions & 3 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("machine", m.Name).
Msg("Machine registration has expired. Sending a authurl to register")

if h.cfg.OIDCIssuer != "" {
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else {
Expand Down Expand Up @@ -225,7 +225,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDCIssuer != "" {
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
Expand Down Expand Up @@ -424,7 +424,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
db.Save(&m)

h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys?

pak.Used = true
db.Save(&pak)

Expand Down
26 changes: 15 additions & 11 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ package headscale
import (
"errors"
"fmt"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"net/http"
"os"
"sort"
"strings"
"sync"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"

"github.com/rs/zerolog/log"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -57,14 +58,19 @@ type Config struct {

DNSConfig *tailcfg.DNSConfig

OIDCIssuer string
OIDCClientID string
OIDCClientSecret string
OIDC OIDCConfig

MaxMachineRegistrationDuration time.Duration
DefaultMachineRegistrationDuration time.Duration
}

type OIDCConfig struct {
Issuer string
ClientID string
ClientSecret string
MatchMap map[string]string
}

// Headscale represents the base app of the service
type Headscale struct {
cfg Config
Expand Down Expand Up @@ -122,14 +128,14 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, err
}

if cfg.OIDCIssuer != "" {
if cfg.OIDC.Issuer != "" {
err = h.initOIDC()
if err != nil {
return nil, err
}
}
}

if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
if err != nil {
return nil, err
Expand Down Expand Up @@ -294,7 +300,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {

times = append(times, lastChange)
}

}

sort.Slice(times, func(i, j int) bool {
Expand All @@ -305,7 +310,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {

if len(times) == 0 {
return time.Now().UTC()

} else {
return times[0]
}
Expand Down
20 changes: 13 additions & 7 deletions cli_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
package headscale

import (
"time"

"gopkg.in/check.v1"
)

func (s *Suite) TestRegisterMachine(c *check.C) {
n, err := h.CreateNamespace("test")
c.Assert(err, check.IsNil)

now := time.Now().UTC()

m := Machine{
ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
IPAddress: "10.0.0.1",
ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
IPAddress: "10.0.0.1",
Expiry: &now,
RequestedExpiry: &now,
}
h.db.Save(&m)

Expand Down
24 changes: 20 additions & 4 deletions cmd/headscale/cli/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"os"
"path/filepath"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -73,7 +74,6 @@ func LoadConfig(path string) error {
} else {
return nil
}

}

func GetDNSConfig() (*tailcfg.DNSConfig, string) {
Expand Down Expand Up @@ -206,15 +206,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"),

OIDCIssuer: viper.GetString("oidc_issuer"),
OIDCClientID: viper.GetString("oidc_client_id"),
OIDCClientSecret: viper.GetString("oidc_client_secret"),
OIDC: headscale.OIDCConfig{
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: viper.GetString("oidc.client_secret"),
},

MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time
DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration

}

cfg.OIDC.MatchMap = loadOIDCMatchMap()

h, err := headscale.NewHeadscale(cfg)
if err != nil {
return nil, err
Expand Down Expand Up @@ -291,3 +295,15 @@ func HasJsonOutputFlag() bool {
}
return false
}

// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in
// the match map is valid regex strings.
func loadOIDCMatchMap() map[string]string {
strMap := viper.GetStringMapString("oidc.domain_map")

for oidcMatcher := range strMap {
_ = regexp.MustCompile(oidcMatcher)
}

return strMap
}
89 changes: 54 additions & 35 deletions oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"regexp"
"strings"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"net/http"
"strings"
"time"
)

type IDTokenClaims struct {
Expand All @@ -26,16 +28,16 @@ func (h *Headscale) initOIDC() error {
var err error
// grab oidc config if it hasn't been already
if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)

if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err
}

h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDCClientID,
ClientSecret: h.cfg.OIDCClientSecret,
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
Expand All @@ -62,7 +64,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {

b := make([]byte, 16)
_, err := rand.Read(b)

if err != nil {
log.Error().Msg("could not read 16 bytes from rand")
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
Expand All @@ -86,7 +87,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback
func (h *Headscale) OIDCCallback(c *gin.Context) {

code := c.Query("code")
state := c.Query("state")

Expand All @@ -109,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return
}

verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})

idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
Expand All @@ -131,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return
}

//retrieve machinekey from state cache
// retrieve machinekey from state cache
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)

if !mKeyFound {
Expand All @@ -149,7 +149,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {

// retrieve machine information
m, err := h.GetMachineByMachineKey(mKeyStr)

if err != nil {
log.Error().Msg("machine key not found in database")
c.String(http.StatusInternalServerError, "could not get machine info from database")
Expand All @@ -158,40 +157,40 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {

now := time.Now().UTC()

// register the machine if it's new
if !m.Registered {
nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new
if !m.Registered {

log.Debug().Msg("Registering new machine after successful callback")
log.Debug().Msg("Registering new machine after successful callback")

ns, err := h.GetNamespace(nsName)
if err != nil {
ns, err = h.CreateNamespace(nsName)
ns, err := h.GetNamespace(nsName)
if err != nil {
ns, err = h.CreateNamespace(nsName)

if err != nil {
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
c.String(http.StatusInternalServerError, "could not create new namespace")
return
}
}

ip, err := h.getAvailableIP()
if err != nil {
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
c.String(http.StatusInternalServerError, "could not create new namespace")
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
return
}
}

ip, err := h.getAvailableIP()
if err != nil {
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
return
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
}

m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
}
h.updateMachineExpiry(m)

h.updateMachineExpiry(m)

c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
<h1>headscale</h1>
Expand All @@ -202,4 +201,24 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
</html>
`, claims.Email)))

}

log.Error().
Str("email", claims.Email).
Str("username", claims.Username).
Str("machine", m.Name).
Msg("Email could not be mapped to a namespace")
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
}

func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) {
for match, namespace := range h.cfg.OIDC.MatchMap {
regex := regexp.MustCompile(match)
if regex.MatchString(email) {
return namespace, true
}
}

return "", false
}
Loading

0 comments on commit 8fe72dc

Please sign in to comment.