From f7b0cbbbea77d27203ece5eb3ba2f893c47e806d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 26 Nov 2024 15:16:06 +0100 Subject: [PATCH] wrap policy in policy manager interface (#2255) Signed-off-by: Kristoffer Dalby --- .github/workflows/test-integration.yaml | 3 +- hscontrol/app.go | 163 +++++++--- hscontrol/auth.go | 7 + hscontrol/db/node_test.go | 13 +- hscontrol/db/routes.go | 20 +- hscontrol/grpcv1.go | 50 +-- hscontrol/mapper/mapper.go | 60 ++-- hscontrol/mapper/mapper_test.go | 5 +- hscontrol/mapper/tail.go | 8 +- hscontrol/mapper/tail_test.go | 5 +- hscontrol/oidc.go | 14 + hscontrol/policy/pm.go | 181 +++++++++++ hscontrol/policy/pm_test.go | 158 ++++++++++ hscontrol/poll.go | 19 +- integration/cli_test.go | 401 +++++++++++------------- integration/hsic/hsic.go | 4 + 16 files changed, 741 insertions(+), 370 deletions(-) create mode 100644 hscontrol/policy/pm.go create mode 100644 hscontrol/policy/pm_test.go diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 1e514f24b2..1584862462 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -31,8 +31,7 @@ jobs: - TestPreAuthKeyCorrectUserLoggedInCommand - TestApiKeyCommand - TestNodeTagCommand - - TestNodeAdvertiseTagNoACLCommand - - TestNodeAdvertiseTagWithACLCommand + - TestNodeAdvertiseTagCommand - TestNodeCommand - TestNodeExpireCommand - TestNodeRenameCommand diff --git a/hscontrol/app.go b/hscontrol/app.go index 62877df2a2..1651b8f211 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -88,7 +88,8 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - ACLPolicy *policy.ACLPolicy + polManOnce sync.Once + polMan policy.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier @@ -153,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } }) + if err = app.loadPolicyManager(); err != nil { + return nil, fmt.Errorf("failed to load ACL policy: %w", err) + } + var authProvider AuthProvider authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { @@ -165,6 +170,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.db, app.nodeNotifier, app.ipAlloc, + app.polMan, ) if err != nil { if cfg.OIDC.OnlyStartIfOIDCIsAvailable { @@ -475,6 +481,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } +// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. +// Maybe we should attempt a new in memory state and not go via the DB? +func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + users, err := db.ListUsers() + if err != nil { + return err + } + + changed, err := polMan.SetUsers(users) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + +// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. +// Maybe we should attempt a new in memory state and not go via the DB? 
+func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + nodes, err := db.ListNodes() + if err != nil { + return err + } + + changed, err := polMan.SetNodes(nodes) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { if profilingEnabled { @@ -490,19 +542,13 @@ func (h *Headscale) Serve() error { } } - var err error - - if err = h.loadACLPolicy(); err != nil { - return fmt.Errorf("failed to load ACL policy: %w", err) - } - if dumpConfig { spew.Dump(h.cfg) } // Fetch an initial DERP Map before we start serving h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier) + h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan) if h.cfg.DERP.ServerEnabled { // When embedded DERP is enabled we always need a STUN server @@ -772,12 +818,21 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received SIGHUP, reloading ACL and Config") - // TODO(kradalby): Reload config on SIGHUP - if err := h.loadACLPolicy(); err != nil { - log.Error().Err(err).Msg("failed to reload ACL policy") + if err := h.loadPolicyManager(); err != nil { + log.Error().Err(err).Msg("failed to reload Policy") } - if h.ACLPolicy != nil { + pol, err := h.policyBytes() + if err != nil { + log.Error().Err(err).Msg("failed to get policy blob") + } + + changed, err := h.polMan.SetPolicy(pol) + if err != nil { + log.Error().Err(err).Msg("failed to set new policy") + } + + if changed { log.Info(). Msg("ACL policy successfully reloaded, notifying nodes of change") @@ -996,27 +1051,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } -func (h *Headscale) loadACLPolicy() error { - var ( - pol *policy.ACLPolicy - err error - ) - +// policyBytes returns the appropriate policy for the +// current configuration as a []byte array. +func (h *Headscale) policyBytes() ([]byte, error) { switch h.cfg.Policy.Mode { case types.PolicyModeFile: path := h.cfg.Policy.Path // It is fine to start headscale without a policy file. if len(path) == 0 { - return nil + return nil, nil } absPath := util.AbsolutePathFromConfigPath(path) - pol, err = policy.LoadACLPolicyFromPath(absPath) + policyFile, err := os.Open(absPath) if err != nil { - return fmt.Errorf("failed to load ACL policy from file: %w", err) + return nil, err } + defer policyFile.Close() + + return io.ReadAll(policyFile) + case types.PolicyModeDB: + p, err := h.db.GetPolicy() + if err != nil { + if errors.Is(err, types.ErrPolicyNotFound) { + return nil, nil + } + + return nil, err + } + + return []byte(p.Data), err + } + + return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode) +} + +func (h *Headscale) loadPolicyManager() error { + var errOut error + h.polManOnce.Do(func() { // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -1027,46 +1101,35 @@ func (h *Headscale) loadACLPolicy() error { // allowed to be written to the database. 
nodes, err := h.db.ListNodes() if err != nil { - return fmt.Errorf("loading nodes from database to validate policy: %w", err) + errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err) + return } users, err := h.db.ListUsers() if err != nil { - return fmt.Errorf("loading users from database to validate policy: %w", err) + errOut = fmt.Errorf("loading users from database to validate policy: %w", err) + return } - _, err = pol.CompileFilterRules(users, nodes) + pol, err := h.policyBytes() if err != nil { - return fmt.Errorf("verifying policy rules: %w", err) - } - - if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) - if err != nil { - return fmt.Errorf("verifying SSH rules: %w", err) - } + errOut = fmt.Errorf("loading policy bytes: %w", err) + return } - case types.PolicyModeDB: - p, err := h.db.GetPolicy() + h.polMan, err = policy.NewPolicyManager(pol, users, nodes) if err != nil { - if errors.Is(err, types.ErrPolicyNotFound) { - return nil - } - - return fmt.Errorf("failed to get policy from database: %w", err) + errOut = fmt.Errorf("creating policy manager: %w", err) + return } - pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data)) - if err != nil { - return fmt.Errorf("failed to parse policy: %w", err) + if len(nodes) > 0 { + _, err = h.polMan.SSHPolicy(nodes[0]) + if err != nil { + errOut = fmt.Errorf("verifying SSH rules: %w", err) + return + } } - default: - log.Fatal(). - Str("mode", string(h.cfg.Policy.Mode)). - Msg("Unknown ACL policy mode") - } - - h.ACLPolicy = pol + }) - return nil + return errOut } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 675450319d..2b23aad3d7 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey( return } + + err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier) + if err != nil { + http.Error(writer, "Internal server error", http.StatusInternalServerError) + return + } + } err = h.db.Write(func(tx *gorm.DB) error { diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index e3dd376e80..7c83c1be42 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -563,7 +563,7 @@ func TestAutoApproveRoutes(t *testing.T) { pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) require.NoError(t, err) - assert.NotNil(t, pol) + require.NotNil(t, pol) user, err := adb.CreateUser("test") require.NoError(t, err) @@ -600,8 +600,17 @@ func TestAutoApproveRoutes(t *testing.T) { node0ByID, err := adb.GetNodeByID(0) require.NoError(t, err) + users, err := adb.ListUsers() + assert.NoError(t, err) + + nodes, err := adb.ListNodes() + assert.NoError(t, err) + + pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes) + assert.NoError(t, err) + // TODO(kradalby): Check state update - err = adb.EnableAutoApprovedRoutes(pol, node0ByID) + err = adb.EnableAutoApprovedRoutes(pm, node0ByID) require.NoError(t, err) enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 1c07ed9dbb..0a72c4278e 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -598,18 +598,18 @@ func failoverRoute( } func (hsdb *HSDatabase) EnableAutoApprovedRoutes( - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { return hsdb.Write(func(tx *gorm.DB) error { - return EnableAutoApprovedRoutes(tx, aclPolicy, node) + return EnableAutoApprovedRoutes(tx, polMan, node) }) } // EnableAutoApprovedRoutes enables any routes advertised by a node that 
match the ACL autoApprovers policy. func EnableAutoApprovedRoutes( tx *gorm.DB, - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { if node.IPv4 == nil && node.IPv6 == nil { @@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes( continue } - routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( - netip.Prefix(advertisedRoute.Prefix), - ) - if err != nil { - return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err) - } + routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix)) log.Trace(). Str("node", node.Hostname). @@ -648,13 +643,8 @@ func EnableAutoApprovedRoutes( if approvedAlias == node.User.Username() { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { - users, err := ListUsers(tx) - if err != nil { - return fmt.Errorf("looking up users to expand route alias: %w", err) - } - // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) + approvedIps, err := polMan.ExpandAlias(approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index d66bda2e17..3e9fcb5e78 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -21,7 +21,6 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" ) @@ -58,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.CreateUserResponse{User: user.Proto()}, nil } @@ -97,6 +101,11 @@ func (api headscaleV1APIServer) DeleteUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.DeleteUserResponse{}, nil } @@ -241,6 +250,11 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } + err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using node: %w", err) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } @@ -480,10 +494,7 @@ func (api headscaleV1APIServer) ListNodes( resp.Online = true } - validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - node, - ) - resp.InvalidTags = invalidTags + validTags := api.h.polMan.Tags(node) resp.ValidTags = validTags response[index] = resp } @@ -759,11 +770,6 @@ func (api headscaleV1APIServer) SetPolicy( p := request.GetPolicy() - pol, err := policy.LoadACLPolicyFromBytes([]byte(p)) - if err != nil { - return nil, fmt.Errorf("loading ACL policy file: %w", err) - } - // Validate and reject configuration that would error when applied // when creating a map response. 
This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -773,18 +779,13 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } - users, err := api.h.db.ListUsers() - if err != nil { - return nil, fmt.Errorf("loading users from database to validate policy: %w", err) - } - - _, err = pol.CompileFilterRules(users, nodes) + changed, err := api.h.polMan.SetPolicy([]byte(p)) if err != nil { - return nil, fmt.Errorf("verifying policy rules: %w", err) + return nil, fmt.Errorf("setting policy: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) + _, err = api.h.polMan.SSHPolicy(nodes[0]) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } @@ -795,12 +796,13 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - api.h.ACLPolicy = pol - - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + // Only send update if the packet filter has changed. + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-update", "na") + api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } response := &v1.SetPolicyResponse{ Policy: updated.Data, diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5205a1125c..51c96f8c87 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -55,6 +55,7 @@ type Mapper struct { cfg *types.Config derpMap *tailcfg.DERPMap notif *notifier.Notifier + polMan policy.PolicyManager uid string created time.Time @@ -71,6 +72,7 @@ func NewMapper( cfg *types.Config, derpMap *tailcfg.DERPMap, notif *notifier.Notifier, + polMan policy.PolicyManager, ) *Mapper { uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) @@ -79,6 +81,7 @@ func NewMapper( cfg: cfg, derpMap: derpMap, notif: notif, + polMan: polMan, uid: uid, created: time.Now(), @@ -153,11 +156,9 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, - users []types.User, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, capVer) + resp, err := m.baseWithConfigMapResponse(node, capVer) if err != nil { return nil, err } @@ -165,11 +166,9 @@ func (m *Mapper) fullMapResponse( err = appendPeerChanges( resp, true, // full change - pol, + m.polMan, node, capVer, - users, - peers, peers, m.cfg, ) @@ -184,19 +183,14 @@ func (m *Mapper) fullMapResponse( func (m *Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { peers, err := m.ListPeers(node.ID) if err != nil { return nil, err } - users, err := m.db.ListUsers() - if err != nil { - return nil, err - } - resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, mapRequest.Version) if err != nil { return nil, err } @@ -210,10 +204,9 @@ func (m *Mapper) FullMapResponse( func (m *Mapper) ReadOnlyMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) + resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) if 
err != nil { return nil, err } @@ -249,7 +242,6 @@ func (m *Mapper) PeerChangedResponse( node *types.Node, changed map[types.NodeID]bool, patches []*tailcfg.PeerChange, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { resp := m.baseMapResponse() @@ -259,11 +251,6 @@ func (m *Mapper) PeerChangedResponse( return nil, err } - users, err := m.db.ListUsers() - if err != nil { - return nil, fmt.Errorf("listing users for map response: %w", err) - } - var removedIDs []tailcfg.NodeID var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { @@ -284,11 +271,9 @@ func (m *Mapper) PeerChangedResponse( err = appendPeerChanges( &resp, false, // partial change - pol, + m.polMan, node, mapRequest.Version, - users, - peers, changedNodes, m.cfg, ) @@ -315,7 +300,7 @@ func (m *Mapper) PeerChangedResponse( // Add the node itself, it might have changed, and particularly // if there are no patches or changes, this is a self update. - tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg) + tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg) if err != nil { return nil, err } @@ -330,7 +315,6 @@ func (m *Mapper) PeerChangedPatchResponse( mapRequest tailcfg.MapRequest, node *types.Node, changed []*tailcfg.PeerChange, - pol *policy.ACLPolicy, ) ([]byte, error) { resp := m.baseMapResponse() resp.PeersChangedPatch = changed @@ -459,12 +443,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse { // incremental. func (m *Mapper) baseWithConfigMapResponse( node *types.Node, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, pol, m.cfg) + tailnode, err := tailNode(node, capVer, m.polMan, m.cfg) if err != nil { return nil, err } @@ -517,35 +500,30 @@ func appendPeerChanges( resp *tailcfg.MapResponse, fullChange bool, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, capVer tailcfg.CapabilityVersion, - users []types.User, - peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) - if err != nil { - return err - } + filter := polMan.Filter() - sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) + sshPolicy, err := polMan.SSHPolicy(node) if err != nil { return err } // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. - if len(packetFilter) > 0 { - changed = policy.FilterNodesByACL(node, changed, packetFilter) + if len(filter) > 0 { + changed = policy.FilterNodesByACL(node, changed, filter) } profiles := generateUserProfiles(node, changed) dnsConfig := generateDNSConfig(cfg, node) - tailPeers, err := tailNodes(changed, capVer, pol, cfg) + tailPeers, err := tailNodes(changed, capVer, polMan, cfg) if err != nil { return err } @@ -570,7 +548,7 @@ func appendPeerChanges( // new PacketFilters field and "base" allows us to send a full update when we // have to send an empty list, avoiding the hack in the else block. resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node, packetFilter), + "base": policy.ReduceFilterRules(node, filter), } } else { // This is a hack to avoid sending an empty list of packet filters. @@ -578,11 +556,11 @@ func appendPeerChanges( // be omitted, causing the client to consider it unchanged, keeping the // previous packet filter. 
Worst case, this can cause a node that previously // has access to a node to _not_ loose access if an empty (allow none) is sent. - reduced := policy.ReduceFilterRules(node, packetFilter) + reduced := policy.ReduceFilterRules(node, filter) if len(reduced) > 0 { resp.PacketFilter = reduced } else { - resp.PacketFilter = packetFilter + resp.PacketFilter = filter } } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 8dd5180815..4ee8c6444e 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -461,18 +461,19 @@ func Test_fullMapResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node)) + mappy := NewMapper( nil, tt.cfg, tt.derpMap, nil, + polMan, ) got, err := mappy.fullMapResponse( tt.node, tt.peers, - []types.User{user1, user2}, - tt.pol, 0, ) diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 24c521dc04..4082df2b45 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -14,7 +14,7 @@ import ( func tailNodes( nodes types.Nodes, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) ([]*tailcfg.Node, error) { tNodes := make([]*tailcfg.Node, len(nodes)) @@ -23,7 +23,7 @@ func tailNodes( node, err := tailNode( node, capVer, - pol, + polMan, cfg, ) if err != nil { @@ -40,7 +40,7 @@ func tailNodes( func tailNode( node *types.Node, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) (*tailcfg.Node, error) { addrs := node.Prefixes() @@ -81,7 +81,7 @@ func tailNode( return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) } - tags, _ := pol.TagsOfNode(node) + tags := polMan.Tags(node) tags = lo.Uniq(append(tags, node.ForcedTags...)) tNode := tailcfg.Node{ diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index b6692c16fb..9d7f1fedfb 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node}) cfg := &types.Config{ BaseDomain: tt.baseDomain, DNSConfig: tt.dnsConfig, @@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) { got, err := tailNode( tt.node, 0, - tt.pol, + polMan, cfg, ) @@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) { tn, err := tailNode( node, 0, - &policy.ACLPolicy{}, + &policy.PolicyManagerV1{}, &types.Config{}, ) if err != nil { diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index e8461967ee..1db1ec079f 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -53,6 +54,7 @@ type AuthProviderOIDC struct { registrationCache *zcache.Cache[string, key.MachinePublic] notifier *notifier.Notifier ipAlloc *db.IPAllocator + polMan policy.PolicyManager oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -65,6 +67,7 @@ func NewAuthProviderOIDC( db *db.HSDatabase, notif *notifier.Notifier, ipAlloc *db.IPAllocator, + polMan policy.PolicyManager, ) 
(*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already @@ -96,6 +99,7 @@ func NewAuthProviderOIDC( registrationCache: registrationCache, notifier: notif, ipAlloc: ipAlloc, + polMan: polMan, oidcProvider: oidcProvider, oauth2Config: oauth2Config, @@ -478,6 +482,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return nil, fmt.Errorf("creating or updating user: %w", err) } + err = usersChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return user, nil } @@ -501,6 +510,11 @@ func (a *AuthProviderOIDC) registerNode( return fmt.Errorf("could not register node: %w", err) } + err = nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return fmt.Errorf("updating resources using node: %w", err) + } + return nil } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go new file mode 100644 index 0000000000..7dbaed33c9 --- /dev/null +++ b/hscontrol/policy/pm.go @@ -0,0 +1,181 @@ +package policy + +import ( + "fmt" + "io" + "net/netip" + "os" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/util/deephash" +) + +type PolicyManager interface { + Filter() []tailcfg.FilterRule + SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) + Tags(*types.Node) []string + ApproversForRoute(netip.Prefix) []string + ExpandAlias(string) (*netipx.IPSet, error) + SetPolicy([]byte) (bool, error) + SetUsers(users []types.User) (bool, error) + SetNodes(nodes types.Nodes) (bool, error) +} + +func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { + policyFile, err := os.Open(path) + if err != nil { + return nil, err + } + defer policyFile.Close() + + policyBytes, err := io.ReadAll(policyFile) + if err != nil { + return nil, err + } + + return NewPolicyManager(policyBytes, users, nodes) +} + +func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { + var pol *ACLPolicy + var err error + if polB != nil && len(polB) > 0 { + pol, err = LoadACLPolicyFromBytes(polB) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + } + + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + _, err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) { + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + _, err := pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +type PolicyManagerV1 struct { + mu sync.Mutex + pol *ACLPolicy + + users []types.User + nodes types.Nodes + + filterHash deephash.Sum + filter []tailcfg.FilterRule +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. 
+func (pm *PolicyManagerV1) updateLocked() (bool, error) { + filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("compiling filter rules: %w", err) + } + + filterHash := deephash.Hash(&filter) + if filterHash == pm.filterHash { + return false, nil + } + + pm.filter = filter + pm.filterHash = filterHash + + return true, nil +} + +func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.filter +} + +func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) +} + +func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) { + pol, err := LoadACLPolicyFromBytes(polB) + if err != nil { + return false, fmt.Errorf("parsing policy: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.pol = pol + + return pm.updateLocked() +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.users = users + return pm.updateLocked() +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. +func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.nodes = nodes + return pm.updateLocked() +} + +func (pm *PolicyManagerV1) Tags(node *types.Node) []string { + if pm == nil { + return nil + } + + tags, _ := pm.pol.TagsOfNode(node) + return tags +} + +func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string { + // TODO(kradalby): This can be a parse error of the address in the policy, + // in the new policy this will be typed and not a problem, in this policy + // we will just return empty list + if pm.pol == nil { + return nil + } + approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route) + return approvers +} + +func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) { + ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias) + if err != nil { + return nil, err + } + return ips, nil +} diff --git a/hscontrol/policy/pm_test.go b/hscontrol/policy/pm_test.go new file mode 100644 index 0000000000..24b78e4d28 --- /dev/null +++ b/hscontrol/policy/pm_test.go @@ -0,0 +1,158 @@ +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func TestPolicySetChange(t *testing.T) { + users := []types.User{ + { + Model: gorm.Model{ID: 1}, + Name: "testuser", + }, + } + tests := []struct { + name string + users []types.User + nodes types.Nodes + policy []byte + wantUsersChange bool + wantNodesChange bool + wantPolicyChange bool + wantFilter []tailcfg.FilterRule + }{ + { + name: "set-nodes", + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantNodesChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users", + users: users, + wantUsersChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users-and-node", + users: users, + nodes: types.Nodes{ + { + IPv4: 
iap("100.64.0.2"), + User: users[0], + }, + }, + wantUsersChange: false, + wantNodesChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-policy", + policy: []byte(` +{ +"acls": [ + { + "action": "accept", + "src": [ + "100.64.0.61", + ], + "dst": [ + "100.64.0.62:*", + ], + }, + ], +} + `), + wantPolicyChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.61/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol := ` +{ + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.64.0.1", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +` + pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{}) + require.NoError(t, err) + + if tt.policy != nil { + change, err := pm.SetPolicy(tt.policy) + require.NoError(t, err) + + assert.Equal(t, tt.wantPolicyChange, change) + } + + if tt.users != nil { + change, err := pm.SetUsers(tt.users) + require.NoError(t, err) + + assert.Equal(t, tt.wantUsersChange, change) + } + + if tt.nodes != nil { + change, err := pm.SetNodes(tt.nodes) + require.NoError(t, err) + + assert.Equal(t, tt.wantNodesChange, change) + } + + if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { + t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/poll.go b/hscontrol/poll.go index a8ae01f44f..e6047d4550 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() { switch update.Type { case types.StateFullUpdate: m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) + data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) case types.StatePeerChanged: changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) @@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() { lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "change" case types.StatePeerChangedPatch: m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) + data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches) updateType = "patch" case types.StatePeerRemoved: changed := make(map[types.NodeID]bool, len(update.Removed)) @@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() { changed[nodeID] = false } m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "remove" case 
types.StateSelfUpdate: lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) // create the map so an empty (self) update is sent - data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) updateType = "remove" case types.StateDERPUpdated: m.tracef("Sending DERPUpdate MapResponse") @@ -488,9 +488,12 @@ func (m *mapSession) handleEndpointUpdate() { return } - if m.h.ACLPolicy != nil { + // TODO(kradalby): Only update the node that has actually changed + nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier) + + if m.h.polMan != nil { // update routes with peer information - err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) + err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node) if err != nil { m.errf(err, "Error running auto approved routes") mapResponseEndpointUpdates.WithLabelValues("error").Inc() @@ -544,7 +547,7 @@ func (m *mapSession) handleEndpointUpdate() { func (m *mapSession) handleReadOnlyRequest() { m.tracef("Client asked for a lite update, responding without peers") - mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy) + mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node) if err != nil { m.errf(err, "Failed to create MapResponse") http.Error(m.w, "", http.StatusInternalServerError) diff --git a/integration/cli_test.go b/integration/cli_test.go index 2e152deb1c..9def16f7ba 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -13,7 +13,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" ) func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { @@ -35,7 +35,7 @@ func TestUserCommand(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -44,10 +44,10 @@ func TestUserCommand(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) var listUsers []v1.User err = executeAndUnmarshal(headscale, @@ -60,7 +60,7 @@ func TestUserCommand(t *testing.T) { }, &listUsers, ) - require.NoError(t, err) + assertNoErr(t, err) result := []string{listUsers[0].GetName(), listUsers[1].GetName()} sort.Strings(result) @@ -82,7 +82,7 @@ func TestUserCommand(t *testing.T) { "newname", }, ) - require.NoError(t, err) + assertNoErr(t, err) var listAfterRenameUsers []v1.User err = executeAndUnmarshal(headscale, @@ -95,7 +95,7 @@ func TestUserCommand(t *testing.T) { }, &listAfterRenameUsers, ) - require.NoError(t, err) + assertNoErr(t, err) result = []string{listAfterRenameUsers[0].GetName(), listAfterRenameUsers[1].GetName()} sort.Strings(result) @@ -115,7 +115,7 @@ func TestPreAuthKeyCommand(t *testing.T) { count := 3 scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -123,13 +123,13 @@ func TestPreAuthKeyCommand(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, 
[]tsic.Option{}, hsic.WithTestName("clipak")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) keys := make([]*v1.PreAuthKey, count) - require.NoError(t, err) + assertNoErr(t, err) for index := 0; index < count; index++ { var preAuthKey v1.PreAuthKey @@ -151,7 +151,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &preAuthKey, ) - require.NoError(t, err) + assertNoErr(t, err) keys[index] = &preAuthKey } @@ -172,7 +172,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &listedPreAuthKeys, ) - require.NoError(t, err) + assertNoErr(t, err) // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 4) @@ -213,9 +213,7 @@ func TestPreAuthKeyCommand(t *testing.T) { continue } - tags := listedPreAuthKeys[index].GetAclTags() - sort.Strings(tags) - assert.Equal(t, []string{"tag:test1", "tag:test2"}, tags) + assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"}) } // Test key expiry @@ -229,7 +227,7 @@ func TestPreAuthKeyCommand(t *testing.T) { listedPreAuthKeys[1].GetKey(), }, ) - require.NoError(t, err) + assertNoErr(t, err) var listedPreAuthKeysAfterExpire []v1.PreAuthKey err = executeAndUnmarshal( @@ -245,7 +243,7 @@ func TestPreAuthKeyCommand(t *testing.T) { }, &listedPreAuthKeysAfterExpire, ) - require.NoError(t, err) + assertNoErr(t, err) assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now())) @@ -259,7 +257,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { user := "pre-auth-key-without-exp-user" scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -267,10 +265,10 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) var preAuthKey v1.PreAuthKey err = executeAndUnmarshal( @@ -287,7 +285,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { }, &preAuthKey, ) - require.NoError(t, err) + assertNoErr(t, err) var listedPreAuthKeys []v1.PreAuthKey err = executeAndUnmarshal( @@ -303,7 +301,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { }, &listedPreAuthKeys, ) - require.NoError(t, err) + assertNoErr(t, err) // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 2) @@ -322,7 +320,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { user := "pre-auth-key-reus-ephm-user" scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -330,10 +328,10 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) var preAuthReusableKey v1.PreAuthKey err = executeAndUnmarshal( @@ -350,7 +348,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { }, &preAuthReusableKey, ) - require.NoError(t, err) + assertNoErr(t, err) var preAuthEphemeralKey v1.PreAuthKey err = 
executeAndUnmarshal( @@ -367,7 +365,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { }, &preAuthEphemeralKey, ) - require.NoError(t, err) + assertNoErr(t, err) assert.True(t, preAuthEphemeralKey.GetEphemeral()) assert.False(t, preAuthEphemeralKey.GetReusable()) @@ -386,7 +384,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { }, &listedPreAuthKeys, ) - require.NoError(t, err) + assertNoErr(t, err) // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 3) @@ -400,7 +398,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { user2 := "user2" scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -416,10 +414,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { hsic.WithTLS(), hsic.WithHostnameAsServerURL(), ) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) var user2Key v1.PreAuthKey @@ -441,10 +439,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { }, &user2Key, ) - require.NoError(t, err) + assertNoErr(t, err) allClients, err := scenario.ListTailscaleClients() - require.NoError(t, err) + assertNoErrListClients(t, err) assert.Len(t, allClients, 1) @@ -452,22 +450,22 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { // Log out from user1 err = client.Logout() - require.NoError(t, err) + assertNoErr(t, err) err = scenario.WaitForTailscaleLogout() - require.NoError(t, err) + assertNoErr(t, err) status, err := client.Status() - require.NoError(t, err) + assertNoErr(t, err) if status.BackendState == "Starting" || status.BackendState == "Running" { t.Fatalf("expected node to be logged out, backend state: %s", status.BackendState) } err = client.Login(headscale.GetEndpoint(), user2Key.GetKey()) - require.NoError(t, err) + assertNoErr(t, err) status, err = client.Status() - require.NoError(t, err) + assertNoErr(t, err) if status.BackendState != "Running" { t.Fatalf("expected node to be logged in, backend state: %s", status.BackendState) } @@ -488,7 +486,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { }, &listNodes, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listNodes, 1) assert.Equal(t, "user2", listNodes[0].GetUser().GetName()) @@ -501,7 +499,7 @@ func TestApiKeyCommand(t *testing.T) { count := 5 scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -510,10 +508,10 @@ func TestApiKeyCommand(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) keys := make([]string, count) @@ -529,7 +527,7 @@ func TestApiKeyCommand(t *testing.T) { "json", }, ) - require.NoError(t, err) + assert.Nil(t, err) assert.NotEmpty(t, apiResult) keys[idx] = apiResult @@ -548,7 +546,7 @@ func TestApiKeyCommand(t *testing.T) { }, &listedAPIKeys, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listedAPIKeys, 5) @@ -604,7 +602,7 @@ func TestApiKeyCommand(t *testing.T) { listedAPIKeys[idx].GetPrefix(), }, ) - require.NoError(t, err) + assert.Nil(t, err) expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true } @@ -620,7 +618,7 @@ func TestApiKeyCommand(t *testing.T) { 
}, &listedAfterExpireAPIKeys, ) - require.NoError(t, err) + assert.Nil(t, err) for index := range listedAfterExpireAPIKeys { if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok { @@ -646,7 +644,7 @@ func TestApiKeyCommand(t *testing.T) { "--prefix", listedAPIKeys[0].GetPrefix(), }) - require.NoError(t, err) + assert.Nil(t, err) var listedAPIKeysAfterDelete []v1.ApiKey err = executeAndUnmarshal(headscale, @@ -659,7 +657,7 @@ func TestApiKeyCommand(t *testing.T) { }, &listedAPIKeysAfterDelete, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listedAPIKeysAfterDelete, 4) } @@ -669,7 +667,7 @@ func TestNodeTagCommand(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -677,17 +675,17 @@ func TestNodeTagCommand(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) machineKeys := []string{ "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", } nodes := make([]*v1.Node, len(machineKeys)) - require.NoError(t, err) + assert.Nil(t, err) for index, machineKey := range machineKeys { _, err := headscale.Execute( @@ -705,7 +703,7 @@ func TestNodeTagCommand(t *testing.T) { "json", }, ) - require.NoError(t, err) + assert.Nil(t, err) var node v1.Node err = executeAndUnmarshal( @@ -723,7 +721,7 @@ func TestNodeTagCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) nodes[index] = &node } @@ -742,7 +740,7 @@ func TestNodeTagCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) @@ -756,7 +754,7 @@ func TestNodeTagCommand(t *testing.T) { "--output", "json", }, ) - require.ErrorContains(t, err, "tag must start with the string 'tag:'") + assert.ErrorContains(t, err, "tag must start with the string 'tag:'") // Test list all nodes after added seconds resultMachines := make([]*v1.Node, len(machineKeys)) @@ -770,7 +768,7 @@ func TestNodeTagCommand(t *testing.T) { }, &resultMachines, ) - require.NoError(t, err) + assert.Nil(t, err) found := false for _, node := range resultMachines { if node.GetForcedTags() != nil { @@ -781,84 +779,30 @@ func TestNodeTagCommand(t *testing.T) { } } } - assert.True( + assert.Equal( t, + true, found, "should find a node with the tag 'tag:test' in the list of nodes", ) } -func TestNodeAdvertiseTagNoACLCommand(t *testing.T) { +func TestNodeAdvertiseTagCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv( - spec, - []tsic.Option{tsic.WithTags([]string{"tag:test"})}, - hsic.WithTestName("cliadvtags"), - ) - require.NoError(t, err) - - headscale, err := scenario.Headscale() - require.NoError(t, err) - - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", + tests := []struct { + name string + policy *policy.ACLPolicy + wantTag bool + }{ + { + name: 
"no-policy", + wantTag: false, }, - &resultMachines, - ) - require.NoError(t, err) - found := false - for _, node := range resultMachines { - if node.GetInvalidTags() != nil { - for _, tag := range node.GetInvalidTags() { - if tag == "tag:test" { - found = true - } - } - } - } - assert.True( - t, - found, - "should not find a node with the tag 'tag:test' in the list of nodes", - ) -} - -func TestNodeAdvertiseTagWithACLCommand(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv( - spec, - []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, - hsic.WithTestName("cliadvtags"), - hsic.WithACLPolicy( - &policy.ACLPolicy{ + { + name: "with-policy", + policy: &policy.ACLPolicy{ ACLs: []policy.ACL{ { Action: "accept", @@ -867,45 +811,61 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) { }, }, TagOwners: map[string][]string{ - "tag:exists": {"user1"}, + "tag:test": {"user1"}, }, }, - ), - ) - require.NoError(t, err) + wantTag: true, + }, + } - headscale, err := scenario.Headscale() - require.NoError(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + // defer scenario.ShutdownAssertNoPanics(t) - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", - }, - &resultMachines, - ) - require.NoError(t, err) - found := false - for _, node := range resultMachines { - if node.GetValidTags() != nil { - for _, tag := range node.GetValidTags() { - if tag == "tag:exists" { - found = true + spec := map[string]int{ + "user1": 1, + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{tsic.WithTags([]string{"tag:test"})}, + hsic.WithTestName("cliadvtags"), + hsic.WithACLPolicy(tt.policy), + ) + assertNoErr(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + // Test list all nodes after added seconds + resultMachines := make([]*v1.Node, spec["user1"]) + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--tags", + "--output", "json", + }, + &resultMachines, + ) + assert.Nil(t, err) + found := false + for _, node := range resultMachines { + if tags := node.GetValidTags(); tags != nil { + found = slices.Contains(tags, "tag:test") } } - } + assert.Equalf( + t, + tt.wantTag, + found, + "'tag:test' found(%t) is the list of nodes, expected %t", found, tt.wantTag, + ) + }) } - assert.True( - t, - found, - "should not find a node with the tag 'tag:exists' in the list of nodes", - ) } func TestNodeCommand(t *testing.T) { @@ -913,7 +873,7 @@ func TestNodeCommand(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -922,10 +882,10 @@ func TestNodeCommand(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) // Pregenerated machine keys machineKeys := []string{ @@ -936,7 +896,7 @@ func TestNodeCommand(t *testing.T) { 
"mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", } nodes := make([]*v1.Node, len(machineKeys)) - require.NoError(t, err) + assert.Nil(t, err) for index, machineKey := range machineKeys { _, err := headscale.Execute( @@ -954,7 +914,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - require.NoError(t, err) + assert.Nil(t, err) var node v1.Node err = executeAndUnmarshal( @@ -972,7 +932,7 @@ func TestNodeCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) nodes[index] = &node } @@ -992,7 +952,7 @@ func TestNodeCommand(t *testing.T) { }, &listAll, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listAll, 5) @@ -1013,7 +973,7 @@ func TestNodeCommand(t *testing.T) { "mkey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", } otherUserMachines := make([]*v1.Node, len(otherUserMachineKeys)) - require.NoError(t, err) + assert.Nil(t, err) for index, machineKey := range otherUserMachineKeys { _, err := headscale.Execute( @@ -1031,7 +991,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - require.NoError(t, err) + assert.Nil(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1049,7 +1009,7 @@ func TestNodeCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) otherUserMachines[index] = &node } @@ -1069,7 +1029,7 @@ func TestNodeCommand(t *testing.T) { }, &listAllWithotherUser, ) - require.NoError(t, err) + assert.Nil(t, err) // All nodes, nodes + otherUser assert.Len(t, listAllWithotherUser, 7) @@ -1095,7 +1055,7 @@ func TestNodeCommand(t *testing.T) { }, &listOnlyotherUserMachineUser, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listOnlyotherUserMachineUser, 2) @@ -1127,7 +1087,7 @@ func TestNodeCommand(t *testing.T) { "--force", }, ) - require.NoError(t, err) + assert.Nil(t, err) // Test: list main user after node is deleted var listOnlyMachineUserAfterDelete []v1.Node @@ -1144,7 +1104,7 @@ func TestNodeCommand(t *testing.T) { }, &listOnlyMachineUserAfterDelete, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listOnlyMachineUserAfterDelete, 4) } @@ -1154,7 +1114,7 @@ func TestNodeExpireCommand(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -1162,10 +1122,10 @@ func TestNodeExpireCommand(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) // Pregenerated machine keys machineKeys := []string{ @@ -1193,7 +1153,7 @@ func TestNodeExpireCommand(t *testing.T) { "json", }, ) - require.NoError(t, err) + assert.Nil(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1211,7 +1171,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) nodes[index] = &node } @@ -1230,7 +1190,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &listAll, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listAll, 5) @@ -1250,7 +1210,7 @@ func TestNodeExpireCommand(t *testing.T) { fmt.Sprintf("%d", listAll[idx].GetId()), }, ) - require.NoError(t, err) + assert.Nil(t, err) } var listAllAfterExpiry []v1.Node @@ -1265,7 +1225,7 @@ func TestNodeExpireCommand(t *testing.T) { }, &listAllAfterExpiry, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listAllAfterExpiry, 5) @@ 
-1281,7 +1241,7 @@ func TestNodeRenameCommand(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -1289,10 +1249,10 @@ func TestNodeRenameCommand(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) // Pregenerated machine keys machineKeys := []string{ @@ -1303,7 +1263,7 @@ func TestNodeRenameCommand(t *testing.T) { "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", } nodes := make([]*v1.Node, len(machineKeys)) - require.NoError(t, err) + assert.Nil(t, err) for index, machineKey := range machineKeys { _, err := headscale.Execute( @@ -1321,7 +1281,7 @@ func TestNodeRenameCommand(t *testing.T) { "json", }, ) - require.NoError(t, err) + assertNoErr(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1339,7 +1299,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assertNoErr(t, err) nodes[index] = &node } @@ -1358,7 +1318,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAll, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listAll, 5) @@ -1379,7 +1339,7 @@ func TestNodeRenameCommand(t *testing.T) { fmt.Sprintf("newnode-%d", idx+1), }, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Contains(t, res, "Node renamed") } @@ -1396,7 +1356,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAllAfterRename, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listAllAfterRename, 5) @@ -1417,7 +1377,7 @@ func TestNodeRenameCommand(t *testing.T) { strings.Repeat("t", 64), }, ) - require.ErrorContains(t, err, "not be over 63 chars") + assert.ErrorContains(t, err, "not be over 63 chars") var listAllAfterRenameAttempt []v1.Node err = executeAndUnmarshal( @@ -1431,7 +1391,7 @@ func TestNodeRenameCommand(t *testing.T) { }, &listAllAfterRenameAttempt, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, listAllAfterRenameAttempt, 5) @@ -1447,7 +1407,7 @@ func TestNodeMoveCommand(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -1456,10 +1416,10 @@ func TestNodeMoveCommand(t *testing.T) { } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) // Randomly generated node key machineKey := "mkey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa" @@ -1479,7 +1439,7 @@ func TestNodeMoveCommand(t *testing.T) { "json", }, ) - require.NoError(t, err) + assert.Nil(t, err) var node v1.Node err = executeAndUnmarshal( @@ -1497,11 +1457,11 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Equal(t, uint64(1), node.GetId()) assert.Equal(t, "nomad-node", node.GetName()) - assert.Equal(t, "old-user", node.GetUser().GetName()) + assert.Equal(t, node.GetUser().GetName(), "old-user") nodeID := fmt.Sprintf("%d", node.GetId()) @@ -1520,9 +1480,9 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) - assert.Equal(t, "new-user", node.GetUser().GetName()) + 
assert.Equal(t, node.GetUser().GetName(), "new-user") var allNodes []v1.Node err = executeAndUnmarshal( @@ -1536,13 +1496,13 @@ func TestNodeMoveCommand(t *testing.T) { }, &allNodes, ) - require.NoError(t, err) + assert.Nil(t, err) assert.Len(t, allNodes, 1) assert.Equal(t, allNodes[0].GetId(), node.GetId()) assert.Equal(t, allNodes[0].GetUser(), node.GetUser()) - assert.Equal(t, "new-user", allNodes[0].GetUser().GetName()) + assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user") _, err = headscale.Execute( []string{ @@ -1557,12 +1517,12 @@ func TestNodeMoveCommand(t *testing.T) { "json", }, ) - require.ErrorContains( + assert.ErrorContains( t, err, "user not found", ) - assert.Equal(t, "new-user", node.GetUser().GetName()) + assert.Equal(t, node.GetUser().GetName(), "new-user") err = executeAndUnmarshal( headscale, @@ -1579,9 +1539,9 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) - assert.Equal(t, "old-user", node.GetUser().GetName()) + assert.Equal(t, node.GetUser().GetName(), "old-user") err = executeAndUnmarshal( headscale, @@ -1598,9 +1558,9 @@ func TestNodeMoveCommand(t *testing.T) { }, &node, ) - require.NoError(t, err) + assert.Nil(t, err) - assert.Equal(t, "old-user", node.GetUser().GetName()) + assert.Equal(t, node.GetUser().GetName(), "old-user") } func TestPolicyCommand(t *testing.T) { @@ -1608,7 +1568,7 @@ func TestPolicyCommand(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -1623,10 +1583,10 @@ func TestPolicyCommand(t *testing.T) { "HEADSCALE_POLICY_MODE": "database", }), ) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) p := policy.ACLPolicy{ ACLs: []policy.ACL{ @@ -1646,7 +1606,7 @@ func TestPolicyCommand(t *testing.T) { policyFilePath := "/etc/headscale/policy.json" err = headscale.WriteFile(policyFilePath, pBytes) - require.NoError(t, err) + assertNoErr(t, err) // No policy is present at this time. // Add a new policy from a file. @@ -1660,7 +1620,7 @@ func TestPolicyCommand(t *testing.T) { }, ) - require.NoError(t, err) + assertNoErr(t, err) // Get the current policy and check // if it is the same as the one we set. 
@@ -1676,11 +1636,11 @@ func TestPolicyCommand(t *testing.T) { }, &output, ) - require.NoError(t, err) + assertNoErr(t, err) assert.Len(t, output.TagOwners, 1) assert.Len(t, output.ACLs, 1) - assert.Equal(t, []string{"policy-user"}, output.TagOwners["tag:exists"]) + assert.Equal(t, output.TagOwners["tag:exists"], []string{"policy-user"}) } func TestPolicyBrokenConfigCommand(t *testing.T) { @@ -1688,7 +1648,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { t.Parallel() scenario, err := NewScenario(dockertestMaxWait()) - require.NoError(t, err) + assertNoErr(t, err) defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ @@ -1703,10 +1663,10 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { "HEADSCALE_POLICY_MODE": "database", }), ) - require.NoError(t, err) + assertNoErr(t, err) headscale, err := scenario.Headscale() - require.NoError(t, err) + assertNoErr(t, err) p := policy.ACLPolicy{ ACLs: []policy.ACL{ @@ -1728,7 +1688,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath := "/etc/headscale/policy.json" err = headscale.WriteFile(policyFilePath, pBytes) - require.NoError(t, err) + assertNoErr(t, err) // No policy is present at this time. // Add a new policy from a file. @@ -1741,7 +1701,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath, }, ) - require.ErrorContains(t, err, "verifying policy rules: invalid action") + assert.ErrorContains(t, err, "compiling filter rules: invalid action") // The new policy was invalid, the old one should still be in place, which // is none. @@ -1754,5 +1714,6 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { "json", }, ) - require.ErrorContains(t, err, "acl policy not found") + assert.ErrorContains(t, err, "acl policy not found") } + diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index cd725f31cb..a008d9d5d2 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -79,6 +79,10 @@ type Option = func(c *HeadscaleInContainer) // HeadscaleInContainer instance. func WithACLPolicy(acl *policy.ACLPolicy) Option { return func(hsic *HeadscaleInContainer) { + if acl == nil { + return + } + // TODO(kradalby): Move somewhere appropriate hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath
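
As a usage illustration of the PolicyManager interface this patch introduces in hscontrol/policy/pm.go (a sketch, not part of the diff itself): it mirrors how loadPolicyManager and the usersChangedHook/nodesChangedHook functions in app.go use NewPolicyManager, Filter and SetUsers. The package paths, types and method signatures are taken from the patch; the sample policy body and the "testuser" value are hypothetical placeholders.

package main

import (
	"fmt"
	"log"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	// Hypothetical HuJSON policy body; any valid headscale ACL works here.
	polBytes := []byte(`{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`)

	// NewPolicyManager parses the policy and compiles the filter once,
	// as loadPolicyManager does in app.go.
	polMan, err := policy.NewPolicyManager(polBytes, []types.User{}, types.Nodes{})
	if err != nil {
		log.Fatalf("creating policy manager: %v", err)
	}

	// Callers such as the mapper now read the pre-compiled filter instead of
	// compiling rules themselves.
	fmt.Println("filter rules:", len(polMan.Filter()))

	// SetUsers/SetNodes/SetPolicy report whether the compiled filter changed;
	// the changed hooks use this to decide whether to notify nodes.
	changed, err := polMan.SetUsers([]types.User{{Name: "testuser"}})
	if err != nil {
		log.Fatalf("setting users: %v", err)
	}
	if changed {
		fmt.Println("filter changed, a full map update would be sent")
	}
}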