From 50b62ddfb3b81755e42b0fbb12b8b09cca8a2535 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 18:19:14 -0400 Subject: [PATCH] fix loading policy manager Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 112 ++++++++++++++++++++--------------------- hscontrol/policy/pm.go | 10 ++-- 2 files changed, 63 insertions(+), 59 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index b4d36caa85..a0e105cabb 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -88,7 +88,8 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - polMan policy.PolicyManager + polManOnce sync.Once + polMan policy.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier @@ -531,8 +532,7 @@ func (h *Headscale) Serve() error { } var err error - - if err = h.loadACLPolicy(); err != nil { + if err = h.loadPolicyManager(); err != nil { return fmt.Errorf("failed to load ACL policy: %w", err) } @@ -814,7 +814,7 @@ func (h *Headscale) Serve() error { // TODO(kradalby): Reload config on SIGHUP // TODO(kradalby): Only update if we set a new policy - if err := h.loadACLPolicy(); err != nil { + if err := h.loadPolicyManager(); err != nil { log.Error().Err(err).Msg("failed to reload ACL policy") } @@ -1037,22 +1037,9 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } -func (h *Headscale) loadACLPolicy() error { - var ( - pm policy.PolicyManager - ) - - switch h.cfg.Policy.Mode { - case types.PolicyModeFile: - path := h.cfg.Policy.Path - - // It is fine to start headscale without a policy file. - if len(path) == 0 { - return nil - } - - absPath := util.AbsolutePathFromConfigPath(path) - +func (h *Headscale) loadPolicyManager() error { + var errOut error + h.polManOnce.Do(func() { // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -1063,54 +1050,67 @@ func (h *Headscale) loadACLPolicy() error { // allowed to be written to the database. nodes, err := h.db.ListNodes() if err != nil { - return fmt.Errorf("loading nodes from database to validate policy: %w", err) + errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err) + return } users, err := h.db.ListUsers() if err != nil { - return fmt.Errorf("loading users from database to validate policy: %w", err) + errOut = fmt.Errorf("loading users from database to validate policy: %w", err) + return } - pm, err = policy.NewPolicyManagerFromPath(absPath, users, nodes) - if err != nil { - return fmt.Errorf("loading policy from file: %w", err) - } + switch h.cfg.Policy.Mode { + case types.PolicyModeFile: + path := h.cfg.Policy.Path - if len(nodes) > 0 { - _, err = pm.SSHPolicy(nodes[0]) + // It is fine to start headscale without a policy file. + if len(path) == 0 { + h.polMan, err = policy.NewPolicyManager(nil, users, nodes) + if err != nil { + errOut = fmt.Errorf("policy manager with no policy: %w", err) + } + + return + } + + absPath := util.AbsolutePathFromConfigPath(path) + + h.polMan, err = policy.NewPolicyManagerFromPath(absPath, users, nodes) if err != nil { - return fmt.Errorf("verifying SSH rules: %w", err) + errOut = fmt.Errorf("loading policy from file (%s): %w", absPath, err) + return } - } - case types.PolicyModeDB: - p, err := h.db.GetPolicy() - if err != nil { - if errors.Is(err, types.ErrPolicyNotFound) { - return nil + if len(nodes) > 0 { + _, err = h.polMan.SSHPolicy(nodes[0]) + if err != nil { + errOut = fmt.Errorf("verifying SSH rules: %w", err) + return + } } - return fmt.Errorf("failed to get policy from database: %w", err) - } + case types.PolicyModeDB: + p, err := h.db.GetPolicy() + if err != nil { + if errors.Is(err, types.ErrPolicyNotFound) { + return + } - nodes, err := h.db.ListNodes() - if err != nil { - return fmt.Errorf("loading nodes from database to validate policy: %w", err) - } - users, err := h.db.ListUsers() - if err != nil { - return fmt.Errorf("loading users from database to validate policy: %w", err) - } - pm, err = policy.NewPolicyManager([]byte(p.Data), users, nodes) - if err != nil { - return fmt.Errorf("loading policy from database: %w", err) - } - default: - log.Fatal(). - Str("mode", string(h.cfg.Policy.Mode)). - Msg("Unknown ACL policy mode") - } + errOut = fmt.Errorf("failed to get policy from database: %w", err) + return + } - h.polMan = pm + h.polMan, err = policy.NewPolicyManager([]byte(p.Data), users, nodes) + if err != nil { + errOut = fmt.Errorf("loading policy from database: %w", err) + return + } + default: + log.Fatal(). + Str("mode", string(h.cfg.Policy.Mode)). + Msg("Unknown ACL policy mode") + } + }) - return nil + return errOut } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 8ca9f1dbb9..a5d736bd7d 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -40,9 +40,13 @@ func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes } func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { - pol, err := LoadACLPolicyFromBytes(polB) - if err != nil { - return nil, fmt.Errorf("parsing policy: %w", err) + var pol *ACLPolicy + var err error + if polB != nil && len(polB) > 0 { + pol, err = LoadACLPolicyFromBytes(polB) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } } pm := PolicyManagerV1{