Skip to content

Commit

Permalink
wrap policy in policy manager interface (#2255)
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby authored Nov 26, 2024
1 parent 2c1ad6d commit f7b0cbb
Show file tree
Hide file tree
Showing 16 changed files with 741 additions and 370 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/test-integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ jobs:
- TestPreAuthKeyCorrectUserLoggedInCommand
- TestApiKeyCommand
- TestNodeTagCommand
- TestNodeAdvertiseTagNoACLCommand
- TestNodeAdvertiseTagWithACLCommand
- TestNodeAdvertiseTagCommand
- TestNodeCommand
- TestNodeExpireCommand
- TestNodeRenameCommand
Expand Down
163 changes: 113 additions & 50 deletions hscontrol/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ type Headscale struct {
DERPMap *tailcfg.DERPMap
DERPServer *derpServer.DERPServer

ACLPolicy *policy.ACLPolicy
polManOnce sync.Once
polMan policy.PolicyManager

mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
Expand Down Expand Up @@ -153,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
}
})

if err = app.loadPolicyManager(); err != nil {
return nil, fmt.Errorf("failed to load ACL policy: %w", err)
}

var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" {
Expand All @@ -165,6 +170,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app.db,
app.nodeNotifier,
app.ipAlloc,
app.polMan,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
Expand Down Expand Up @@ -475,6 +481,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}

// usersChangedHook lists all users from the database and passes them to the
// policy manager via SetUsers. If SetUsers reports a change, every node is
// sent a full state update through the notifier. Any error from listing the
// users or from SetUsers is returned unchanged.
//
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
	allUsers, err := db.ListUsers()
	if err != nil {
		return err
	}

	policyChanged, err := polMan.SetUsers(allUsers)
	if err != nil {
		return err
	}

	// Nothing to broadcast when the policy manager reports no change.
	if !policyChanged {
		return nil
	}

	notifyCtx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
	fullUpdate := types.StateUpdate{Type: types.StateFullUpdate}
	notif.NotifyAll(notifyCtx, fullUpdate)

	return nil
}

// nodesChangedHook lists all nodes from the database and passes them to the
// policy manager via SetNodes. If SetNodes reports a change, every node is
// sent a full state update through the notifier. Any error from listing the
// nodes or from SetNodes is returned unchanged.
//
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
	allNodes, err := db.ListNodes()
	if err != nil {
		return err
	}

	policyChanged, err := polMan.SetNodes(allNodes)
	if err != nil {
		return err
	}

	// Nothing to broadcast when the policy manager reports no change.
	if !policyChanged {
		return nil
	}

	notifyCtx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
	fullUpdate := types.StateUpdate{Type: types.StateFullUpdate}
	notif.NotifyAll(notifyCtx, fullUpdate)

	return nil
}

// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
if profilingEnabled {
Expand All @@ -490,19 +542,13 @@ func (h *Headscale) Serve() error {
}
}

var err error

if err = h.loadACLPolicy(); err != nil {
return fmt.Errorf("failed to load ACL policy: %w", err)
}

if dumpConfig {
spew.Dump(h.cfg)
}

// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)

if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server
Expand Down Expand Up @@ -772,12 +818,21 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()).
Msg("Received SIGHUP, reloading ACL and Config")

// TODO(kradalby): Reload config on SIGHUP
if err := h.loadACLPolicy(); err != nil {
log.Error().Err(err).Msg("failed to reload ACL policy")
if err := h.loadPolicyManager(); err != nil {
log.Error().Err(err).Msg("failed to reload Policy")
}

if h.ACLPolicy != nil {
pol, err := h.policyBytes()
if err != nil {
log.Error().Err(err).Msg("failed to get policy blob")
}

changed, err := h.polMan.SetPolicy(pol)
if err != nil {
log.Error().Err(err).Msg("failed to set new policy")
}

if changed {
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")

Expand Down Expand Up @@ -996,27 +1051,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
return &machineKey, nil
}

func (h *Headscale) loadACLPolicy() error {
var (
pol *policy.ACLPolicy
err error
)

// policyBytes returns the appropriate policy for the
// current configuration as a []byte array.
func (h *Headscale) policyBytes() ([]byte, error) {
switch h.cfg.Policy.Mode {
case types.PolicyModeFile:
path := h.cfg.Policy.Path

// It is fine to start headscale without a policy file.
if len(path) == 0 {
return nil
return nil, nil
}

absPath := util.AbsolutePathFromConfigPath(path)
pol, err = policy.LoadACLPolicyFromPath(absPath)
policyFile, err := os.Open(absPath)
if err != nil {
return fmt.Errorf("failed to load ACL policy from file: %w", err)
return nil, err
}
defer policyFile.Close()

return io.ReadAll(policyFile)

case types.PolicyModeDB:
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil, nil
}

return nil, err
}

return []byte(p.Data), err
}

return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
}

func (h *Headscale) loadPolicyManager() error {
var errOut error
h.polManOnce.Do(func() {
// Validate and reject configuration that would error when applied
// when creating a map response. This requires nodes, so there is still
// a scenario where they might be allowed if the server has no nodes
Expand All @@ -1027,46 +1101,35 @@ func (h *Headscale) loadACLPolicy() error {
// allowed to be written to the database.
nodes, err := h.db.ListNodes()
if err != nil {
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err)
return
}
users, err := h.db.ListUsers()
if err != nil {
return fmt.Errorf("loading users from database to validate policy: %w", err)
errOut = fmt.Errorf("loading users from database to validate policy: %w", err)
return
}

_, err = pol.CompileFilterRules(users, nodes)
pol, err := h.policyBytes()
if err != nil {
return fmt.Errorf("verifying policy rules: %w", err)
}

if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
if err != nil {
return fmt.Errorf("verifying SSH rules: %w", err)
}
errOut = fmt.Errorf("loading policy bytes: %w", err)
return
}

case types.PolicyModeDB:
p, err := h.db.GetPolicy()
h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil
}

return fmt.Errorf("failed to get policy from database: %w", err)
errOut = fmt.Errorf("creating policy manager: %w", err)
return
}

pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data))
if err != nil {
return fmt.Errorf("failed to parse policy: %w", err)
if len(nodes) > 0 {
_, err = h.polMan.SSHPolicy(nodes[0])
if err != nil {
errOut = fmt.Errorf("verifying SSH rules: %w", err)
return
}
}
default:
log.Fatal().
Str("mode", string(h.cfg.Policy.Mode)).
Msg("Unknown ACL policy mode")
}

h.ACLPolicy = pol
})

return nil
return errOut
}
7 changes: 7 additions & 0 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey(

return
}

err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
if err != nil {
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}

}

err = h.db.Write(func(tx *gorm.DB) error {
Expand Down
13 changes: 11 additions & 2 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ func TestAutoApproveRoutes(t *testing.T) {
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))

require.NoError(t, err)
assert.NotNil(t, pol)
require.NotNil(t, pol)

user, err := adb.CreateUser("test")
require.NoError(t, err)
Expand Down Expand Up @@ -600,8 +600,17 @@ func TestAutoApproveRoutes(t *testing.T) {
node0ByID, err := adb.GetNodeByID(0)
require.NoError(t, err)

users, err := adb.ListUsers()
assert.NoError(t, err)

nodes, err := adb.ListNodes()
assert.NoError(t, err)

pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
assert.NoError(t, err)

// TODO(kradalby): Check state update
err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
require.NoError(t, err)

enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
Expand Down
20 changes: 5 additions & 15 deletions hscontrol/db/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,18 +598,18 @@ func failoverRoute(
}

func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy,
polMan policy.PolicyManager,
node *types.Node,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
return EnableAutoApprovedRoutes(tx, polMan, node)
})
}

// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func EnableAutoApprovedRoutes(
tx *gorm.DB,
aclPolicy *policy.ACLPolicy,
polMan policy.PolicyManager,
node *types.Node,
) error {
if node.IPv4 == nil && node.IPv6 == nil {
Expand All @@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes(
continue
}

routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
netip.Prefix(advertisedRoute.Prefix),
)
if err != nil {
return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err)
}
routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))

log.Trace().
Str("node", node.Hostname).
Expand All @@ -648,13 +643,8 @@ func EnableAutoApprovedRoutes(
if approvedAlias == node.User.Username() {
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
users, err := ListUsers(tx)
if err != nil {
return fmt.Errorf("looking up users to expand route alias: %w", err)
}

// TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
approvedIps, err := polMan.ExpandAlias(approvedAlias)
if err != nil {
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
}
Expand Down
Loading

0 comments on commit f7b0cbb

Please sign in to comment.