Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
policy manager
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
kradalby committed Oct 25, 2024

Verified

This commit was signed with the committer’s verified signature.
renovate-bot Mend Renovate
1 parent 9375836 commit 2e68455
Showing 9 changed files with 455 additions and 70 deletions.
4 changes: 3 additions & 1 deletion hscontrol/app.go
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@ import (
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policyv2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog"
@@ -88,7 +89,8 @@ type Headscale struct {
DERPMap *tailcfg.DERPMap
DERPServer *derpServer.DERPServer

ACLPolicy *policy.ACLPolicy
ACLPolicy *policy.ACLPolicy
PolicyManager *policyv2.PolicyManager

mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
105 changes: 103 additions & 2 deletions hscontrol/policyv2/filter.go
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ package policyv2
import (
"errors"
"fmt"
"time"

"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
@@ -16,6 +17,7 @@ var (
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) CompileFilterRules(
users types.Users,
nodes types.Nodes,
) ([]tailcfg.FilterRule, error) {
if pol == nil {
@@ -29,7 +31,7 @@ func (pol *Policy) CompileFilterRules(
return nil, ErrInvalidAction
}

srcIPs, err := acl.Sources.Resolve(pol, nodes)
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
if err != nil {
return nil, fmt.Errorf("resolving source ips: %w", err)
}
@@ -43,7 +45,7 @@ func (pol *Policy) CompileFilterRules(

var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, nodes)
ips, err := dest.Alias.Resolve(pol, users, nodes)
if err != nil {
return nil, err
}
@@ -69,6 +71,105 @@ func (pol *Policy) CompileFilterRules(
return rules, nil
}

func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
return tailcfg.SSHAction{
Reject: !accept,
Accept: accept,
SessionDuration: duration,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
}
}

func (pol *Policy) CompileSSHPolicy(
users types.Users,
node types.Node,
nodes types.Nodes,
) (*tailcfg.SSHPolicy, error) {
if pol == nil {
return nil, nil
}

var rules []*tailcfg.SSHRule

for index, rule := range pol.SSHs {
var dest netipx.IPSetBuilder
for _, src := range rule.Destinations {
ips, err := src.Resolve(pol, users, nodes)
if err != nil {
return nil, err
}
dest.AddSet(ips)
}

destSet, err := dest.IPSet()
if err != nil {
return nil, err
}

if !node.InIPSet(destSet) {
continue
}

var action tailcfg.SSHAction
switch rule.Action {
case "accept":
action = sshAction(true, 0)
case "check":
action = sshAction(true, rule.CheckPeriod)
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}

var principals []*tailcfg.SSHPrincipal
for _, src := range rule.Sources {
if isWildcard(rawSrc) {
principals = append(principals, &tailcfg.SSHPrincipal{
Any: true,
})
} else if isGroup(rawSrc) {
users, err := pol.expandUsersFromGroup(rawSrc)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err)
}

for _, user := range users {
principals = append(principals, &tailcfg.SSHPrincipal{
UserLogin: user,
})
}
} else {
expandedSrcs, err := pol.ExpandAlias(
peers,
rawSrc,
)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
}
for _, expandedSrc := range expandedSrcs.Prefixes() {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc.Addr().String(),
})
}
}
}

userMap := make(map[string]string, len(rule.Users))
for _, user := range rule.Users {
userMap[user] = "="
}
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}

return &tailcfg.SSHPolicy{
Rules: rules,
}, nil
}

func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
var out []string

97 changes: 54 additions & 43 deletions hscontrol/policyv2/filter_test.go
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
)
@@ -17,6 +18,9 @@ import (
// Move it here, run it against both old and new CompileFilterRules

func TestParsing(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser@"},
}
tests := []struct {
name string
format string
@@ -340,7 +344,7 @@ func TestParsing(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pol, err := PolicyFromBytes([]byte(tt.acl))
pol, err := policyFromBytes([]byte(tt.acl))
if tt.wantErr && err == nil {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)

@@ -355,18 +359,18 @@ func TestParsing(t *testing.T) {
return
}

rules, err := pol.CompileFilterRules(types.Nodes{
&types.Node{
IPv4: ap("100.100.100.100"),
},
&types.Node{
IPv4: ap("200.200.200.200"),
User: types.User{
Name: "testuser@",
rules, err := pol.CompileFilterRules(
users,
types.Nodes{
&types.Node{
IPv4: ap("100.100.100.100"),
},
Hostinfo: &tailcfg.Hostinfo{},
},
})
&types.Node{
IPv4: ap("200.200.200.200"),
User: users[0],
Hostinfo: &tailcfg.Hostinfo{},
},
})

if (err != nil) != tt.wantErr {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
@@ -435,6 +439,14 @@ var hsExitNodeDestForTest = []tailcfg.NetPortRange{
}

func TestReduceFilterRules(t *testing.T) {
users := types.Users{
types.User{Model: gorm.Model{ID: 1}, Name: "mickael"},
types.User{Model: gorm.Model{ID: 2}, Name: "user1@"},
types.User{Model: gorm.Model{ID: 3}, Name: "user2@"},
types.User{Model: gorm.Model{ID: 4}, Name: "user100@"},
types.User{Model: gorm.Model{ID: 5}, Name: "user3@"},
}

tests := []struct {
name string
node *types.Node
@@ -463,13 +475,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: types.User{Name: "mickael"},
User: users[0],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: types.User{Name: "mickael"},
User: users[0],
},
},
want: []tailcfg.FilterRule{},
@@ -510,7 +522,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"},
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
@@ -521,7 +533,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user1@"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@@ -600,19 +612,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"},
User: users[1],
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"},
User: users[2],
},
// "internal" exit node
&types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@@ -661,7 +673,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@@ -670,12 +682,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"},
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@@ -768,7 +780,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@@ -777,12 +789,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"},
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@@ -809,9 +821,11 @@ func TestReduceFilterRules(t *testing.T) {
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
{IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny},
{IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny},
{IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny},
// This should not be included I believe, seems like
// this is a bug in the v1 code.
// {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny},
// {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny},
// {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
@@ -881,7 +895,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
@@ -890,12 +904,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"},
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@@ -969,7 +983,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
@@ -978,12 +992,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"},
User: users[2],
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@@ -1046,7 +1060,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
@@ -1056,7 +1070,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@@ -1090,22 +1104,19 @@ func TestReduceFilterRules(t *testing.T) {
filterV1, _ := polV1.CompileFilterRules(
append(tt.peers, tt.node),
)
polV2, err := PolicyFromBytes([]byte(tt.pol))
pm, err := NewPolicyManager([]byte(tt.pol), users, append(tt.peers, tt.node))
if err != nil {
t.Fatalf("parsing policy: %s", err)
}
filterV2, _ := polV2.CompileFilterRules(
append(tt.peers, tt.node),
)

if diff := cmp.Diff(filterV1, filterV2); diff != "" {
log.Trace().Interface("got", filterV2).Msg("result")
if diff := cmp.Diff(filterV1, pm.Filter()); diff != "" {
log.Trace().Interface("got", pm.Filter()).Msg("result")
t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff)
}

// TODO(kradalby): Move this from v1, or
// rewrite.
filterV2 = policy.ReduceFilterRules(tt.node, filterV2)
filterV2 := policy.ReduceFilterRules(tt.node, pm.Filter())

if diff := cmp.Diff(tt.want, filterV2); diff != "" {
log.Trace().Interface("got", filterV2).Msg("result")
80 changes: 80 additions & 0 deletions hscontrol/policyv2/policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package policyv2

import (
"fmt"
"sync"

"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
)

type PolicyManager struct {
mu sync.Mutex
pol *Policy
users []types.User
nodes types.Nodes

filter []tailcfg.FilterRule

// TODO(kradalby): Implement SSH policy
sshPolicy *tailcfg.SSHPolicy
}

// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
// It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
policy, err := policyFromBytes(b)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}

pm := PolicyManager{
pol: policy,
users: users,
nodes: nodes,
}

err = pm.updateLocked()
if err != nil {
return nil, err
}

return &pm, nil
}

// Filter returns the current filter rules for the entire tailnet.
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}

// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManager) updateLocked() error {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return fmt.Errorf("compiling filter rules: %w", err)
}

pm.filter = filter

return nil
}

// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetUsers(users []types.User) error {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}

// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) error {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
58 changes: 58 additions & 0 deletions hscontrol/policyv2/policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package policyv2

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)

func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
return &types.Node{
ID: 0,
Hostname: name,
IPv4: ap(ipv4),
IPv6: ap(ipv6),
User: user,
UserID: user.ID,
Hostinfo: hostinfo,
}
}

func TestPolicyManager(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
}

tests := []struct {
name string
pol string
nodes types.Nodes
wantFilter []tailcfg.FilterRule
}{
{
name: "empty-policy",
pol: "{}",
nodes: types.Nodes{},
wantFilter: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
require.NoError(t, err)

filter := pm.Filter()
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
}

// TODO(kradalby): Test SSH Policy
})
}
}
134 changes: 120 additions & 14 deletions hscontrol/policyv2/types.go
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ import (
"net/netip"
"strconv"
"strings"
"time"

"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@@ -64,7 +65,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error {
return nil
}

func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder

ips.AddPrefix(tsaddr.AllIPv4())
@@ -99,15 +100,47 @@ func (u Username) CanBeTagOwner() bool {
return true
}

func (u Username) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
func (u Username) resolveUser(users types.Users) (*types.User, error) {
var potentialUsers types.Users
for _, user := range users {
if user.ProviderIdentifier == string(u) {
potentialUsers = append(potentialUsers, user)

break
}
if user.Email == string(u) {
potentialUsers = append(potentialUsers, user)
}
if user.Name == string(u) {
potentialUsers = append(potentialUsers, user)
}
}

if len(potentialUsers) > 1 {
return nil, fmt.Errorf("unable to resolve user identifier to distinct: %s matched multiple %s", u, potentialUsers)
} else if len(potentialUsers) == 0 {
return nil, fmt.Errorf("unable to resolve user identifier, no user found: %s not in %s", u, users)
}

user := potentialUsers[0]

return &user, nil
}

func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder

user, err := u.resolveUser(users)
if err != nil {
return nil, err
}

for _, node := range nodes {
if node.IsTagged() {
continue
}

if node.User.Username() == string(u) {
if node.User.ID == user.ID {
node.AppendToIPSet(&ips)
}
}
@@ -137,11 +170,11 @@ func (g Group) CanBeTagOwner() bool {
return true
}

func (g Group) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder

for _, user := range p.Groups[g] {
uips, err := user.Resolve(nil, nodes)
uips, err := user.Resolve(nil, users, nodes)
if err != nil {
return nil, err
}
@@ -170,7 +203,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
return nil
}

func (t Tag) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
func (t Tag) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder

for _, node := range nodes {
@@ -197,7 +230,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
return nil
}

func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder

pref, ok := p.Hosts[h]
@@ -208,11 +241,26 @@ func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
if err != nil {
return nil, err
}

// If the IP is a single host, look for a node to ensure we add all the IPs of
// the node to the IPSet.
appendIfNodeHasIP(nodes, &ips, pref)
ips.AddPrefix(netip.Prefix(pref))

return ips.IPSet()
}

func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref Prefix) {
if netip.Prefix(pref).IsSingleIP() {
addr := netip.Prefix(pref).Addr()
for _, node := range nodes {
if node.HasIP(addr) {
node.AppendToIPSet(ips)
}
}
}
}

type Prefix netip.Prefix

func (p Prefix) Validate() error {
@@ -261,9 +309,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
return nil
}

func (p Prefix) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder

appendIfNodeHasIP(nodes, &ips, p)
ips.AddPrefix(netip.Prefix(p))

return ips.IPSet()
@@ -296,7 +345,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
return nil
}

func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) {
switch ag {
case AutoGroupInternet:
return theInternet(), nil
@@ -308,7 +357,7 @@ func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
type Alias interface {
Validate() error
UnmarshalJSON([]byte) error
Resolve(*Policy, types.Nodes) (*netipx.IPSet, error)
Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
}

type AliasWithPorts struct {
@@ -428,11 +477,11 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
return nil
}

func (a Aliases) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder

for _, alias := range a {
aips, err := alias.Resolve(p, nodes)
aips, err := alias.Resolve(p, users, nodes)
if err != nil {
return nil, err
}
@@ -530,10 +579,67 @@ type Policy struct {
TagOwners TagOwners `json:"tagOwners"`
ACLs []ACL `json:"acls"`
AutoApprovers AutoApprovers `json:"autoApprovers"`
// SSHs []SSH `json:"ssh"`
SSHs []SSH `json:"ssh"`
}

// SSH controls who can ssh into which machines.
type SSH struct {
Action string `json:"action"`
Sources SSHSrcAliases `json:"src"`
Destinations SSHDstAliases `json:"dst"`
Users []SSHUser `json:"users"`
CheckPeriod time.Duration `json:"checkPeriod,omitempty"`
}

// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
// It can be a list of usernames, groups, tags or autogroups.
type SSHSrcAliases []Alias

func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}

*a = make([]Alias, len(aliases))
for i, alias := range aliases {
switch alias.Alias.(type) {
case *Username, *Group, *Tag, *AutoGroup:
(*a)[i] = alias.Alias
default:
return fmt.Errorf("type %T not supported", alias.Alias)
}
}
return nil
}

func PolicyFromBytes(b []byte) (*Policy, error) {
// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule.
// It can be a list of usernames, tags or autogroups.
type SSHDstAliases []Alias

func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}

*a = make([]Alias, len(aliases))
for i, alias := range aliases {
switch alias.Alias.(type) {
case *Username, *Tag, *AutoGroup:
(*a)[i] = alias.Alias
default:
return fmt.Errorf("type %T not supported", alias.Alias)
}
}
return nil
}

type SSHUser string

func policyFromBytes(b []byte) (*Policy, error) {
var policy Policy
ast, err := hujson.Parse(b)
if err != nil {
22 changes: 12 additions & 10 deletions hscontrol/policyv2/types_test.go
Original file line number Diff line number Diff line change
@@ -173,7 +173,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{
{
Alias: ptr.To(Username("otheruser@headscale.net")),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
@@ -186,7 +186,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{
{
Alias: gp("group:other"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
@@ -199,7 +199,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{
{
Alias: pp("100.101.102.104/32"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
@@ -212,7 +212,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{
{
Alias: pp("172.16.0.0/16"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
@@ -225,7 +225,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{
{
Alias: hp("host-1"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}},
Ports: []tailcfg.PortRange{{First: 80, Last: 88}},
},
},
},
@@ -239,8 +239,8 @@ func TestUnmarshalPolicy(t *testing.T) {
{
Alias: tp("tag:user"),
Ports: []tailcfg.PortRange{
tailcfg.PortRange{First: 80, Last: 80},
tailcfg.PortRange{First: 443, Last: 443},
{First: 80, Last: 80},
{First: 443, Last: 443},
},
},
},
@@ -255,7 +255,7 @@ func TestUnmarshalPolicy(t *testing.T) {
{
Alias: agp("autogroup:internet"),
Ports: []tailcfg.PortRange{
tailcfg.PortRange{First: 80, Last: 80},
{First: 80, Last: 80},
},
},
},
@@ -341,7 +341,7 @@ func TestUnmarshalPolicy(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy, err := PolicyFromBytes([]byte(tt.input))
policy, err := policyFromBytes([]byte(tt.input))
// TODO(kradalby): This error checking is broken,
// but so is my brain, #longflight
if err == nil {
@@ -538,7 +538,9 @@ func TestResolvePolicy(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ips, err := tt.toResolve.Resolve(tt.pol, tt.nodes)
ips, err := tt.toResolve.Resolve(tt.pol,
types.Users{},
tt.nodes)
if err != nil {
t.Fatalf("failed to resolve: %s", err)
}
10 changes: 10 additions & 0 deletions hscontrol/types/node.go
Original file line number Diff line number Diff line change
@@ -130,6 +130,16 @@ func (node *Node) IPs() []netip.Addr {
return ret
}

// HasIP reports if a node has a given IP address.
func (node *Node) HasIP(i netip.Addr) bool {
for _, ip := range node.IPs() {
if ip.Compare(i) == 0 {
return true
}
}
return false
}

// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
15 changes: 15 additions & 0 deletions hscontrol/types/users.go
Original file line number Diff line number Diff line change
@@ -2,7 +2,9 @@ package types

import (
"cmp"
"fmt"
"strconv"
"strings"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
@@ -13,6 +15,19 @@ import (

type UserID uint64

type Users []User

func (u Users) String() string {
var sb strings.Builder
sb.WriteString("[ ")
for _, user := range u {
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
}
sb.WriteString(" ]")

return sb.String()
}

// User is the way Headscale implements the concept of users in Tailscale
//
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users

0 comments on commit 2e68455

Please sign in to comment.