Skip to content

Commit

Permalink
use ipv4/6 everywhere instead of address list
Browse files Browse the repository at this point in the history
Updates juanfont#1828

Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Apr 15, 2024
1 parent 4af1460 commit c1a05f4
Show file tree
Hide file tree
Showing 16 changed files with 586 additions and 774 deletions.
5 changes: 2 additions & 3 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ func (h *Headscale) handleAuthKey(
ForcedTags: pak.Proto().GetAclTags(),
}

addrs, err := h.ipAlloc.Next()
ipv4, ipv6, err := h.ipAlloc.Next()
if err != nil {
log.Error().
Caller().
Expand All @@ -397,7 +397,7 @@ func (h *Headscale) handleAuthKey(

node, err = h.db.RegisterNode(
nodeToRegister,
addrs,
ipv4, ipv6,
)
if err != nil {
log.Error().
Expand Down Expand Up @@ -461,7 +461,6 @@ func (h *Headscale) handleAuthKey(

log.Info().
Str("node", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(node.IPAddresses.StringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey")
}

Expand Down
58 changes: 33 additions & 25 deletions hscontrol/db/ip.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"database/sql"
"errors"
"fmt"
"net/netip"
Expand Down Expand Up @@ -46,12 +47,24 @@ func NewIPAllocator(db *HSDatabase, prefix4, prefix6 *netip.Prefix) (*IPAllocato
prefix6: prefix6,
}

var addressesSlices []string
var v4s []sql.NullString
var v6s []sql.NullString

if db != nil {
db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices).Error
err := db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ipv4", &v4s).Error
})
if err != nil {
return nil, fmt.Errorf("reading IPv4 addresses from database: %w", err)
}

err = db.Read(func(rx *gorm.DB) error {
return rx.Model(&types.Node{}).Pluck("ipv6", &v6s).Error
})
if err != nil {
return nil, fmt.Errorf("reading IPv6 addresses from database: %w", err)
}

}

var ips netipx.IPSetBuilder
Expand Down Expand Up @@ -79,18 +92,14 @@ func NewIPAllocator(db *HSDatabase, prefix4, prefix6 *netip.Prefix) (*IPAllocato

// Fetch all the IP Addresses currently handed out from the Database
// and add them to the used IP set.
for _, slice := range addressesSlices {
var machineAddresses types.NodeAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return nil, fmt.Errorf(
"parsing IPs from database %v: %w", machineAddresses,
err,
)
}

for _, ip := range machineAddresses {
ips.Add(ip)
for _, addrStr := range append(v4s, v6s...) {
if addrStr.Valid {
addr, err := netip.ParseAddr(addrStr.String)
if err != nil {
return nil, fmt.Errorf("parsing IP address from database: %w", err)
}

ips.Add(addr)
}
}

Expand All @@ -108,31 +117,30 @@ func NewIPAllocator(db *HSDatabase, prefix4, prefix6 *netip.Prefix) (*IPAllocato
return &ret, nil
}

func (i *IPAllocator) Next() (types.NodeAddresses, error) {
func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) {
i.mu.Lock()
defer i.mu.Unlock()

var ret types.NodeAddresses
var err error
var ret4 *netip.Addr
var ret6 *netip.Addr

if i.prefix4 != nil {
v4, err := i.next(i.prev4, i.prefix4)
ret4, err = i.next(i.prev4, i.prefix4)
if err != nil {
return nil, fmt.Errorf("allocating IPv4 address: %w", err)
return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err)
}

ret = append(ret, *v4)
}

if i.prefix6 != nil {
v6, err := i.next(i.prev6, i.prefix6)
ret6, err = i.next(i.prev6, i.prefix6)
if err != nil {
return nil, fmt.Errorf("allocating IPv6 address: %w", err)
return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err)
}

ret = append(ret, *v6)
}

return ret, nil
return ret4, ret6, nil
}

var ErrCouldNotAllocateIP = errors.New("failed to allocate IP")
Expand Down
93 changes: 54 additions & 39 deletions hscontrol/db/ip_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"database/sql"
"net/netip"
"os"
"testing"
Expand Down Expand Up @@ -44,7 +45,8 @@ func TestIPAllocator(t *testing.T) {
prefix4 *netip.Prefix
prefix6 *netip.Prefix
getCount int
want []types.NodeAddresses
want4 []netip.Addr
want6 []netip.Addr
}{
{
name: "simple",
Expand All @@ -57,11 +59,11 @@ func TestIPAllocator(t *testing.T) {

getCount: 1,

want: []types.NodeAddresses{
{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
want4: []netip.Addr{
na("100.64.0.1"),
},
want6: []netip.Addr{
na("fd7a:115c:a1e0::1"),
},
},
{
Expand All @@ -74,10 +76,8 @@ func TestIPAllocator(t *testing.T) {

getCount: 1,

want: []types.NodeAddresses{
{
na("100.64.0.1"),
},
want4: []netip.Addr{
na("100.64.0.1"),
},
},
{
Expand All @@ -90,10 +90,8 @@ func TestIPAllocator(t *testing.T) {

getCount: 1,

want: []types.NodeAddresses{
{
na("fd7a:115c:a1e0::1"),
},
want6: []netip.Addr{
na("fd7a:115c:a1e0::1"),
},
},
{
Expand All @@ -102,9 +100,13 @@ func TestIPAllocator(t *testing.T) {
db := newDb()

db.DB.Save(&types.Node{
IPAddresses: types.NodeAddresses{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.1",
},
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::1",
},
})

Expand All @@ -116,11 +118,11 @@ func TestIPAllocator(t *testing.T) {

getCount: 1,

want: []types.NodeAddresses{
{
na("100.64.0.2"),
na("fd7a:115c:a1e0::2"),
},
want4: []netip.Addr{
na("100.64.0.2"),
},
want6: []netip.Addr{
na("fd7a:115c:a1e0::2"),
},
},
{
Expand All @@ -129,9 +131,13 @@ func TestIPAllocator(t *testing.T) {
db := newDb()

db.DB.Save(&types.Node{
IPAddresses: types.NodeAddresses{
na("100.64.0.2"),
na("fd7a:115c:a1e0::2"),
IPv4DatabaseField: sql.NullString{
Valid: true,
String: "100.64.0.2",
},
IPv6DatabaseField: sql.NullString{
Valid: true,
String: "fd7a:115c:a1e0::2",
},
})

Expand All @@ -143,15 +149,13 @@ func TestIPAllocator(t *testing.T) {

getCount: 2,

want: []types.NodeAddresses{
{
na("100.64.0.1"),
na("fd7a:115c:a1e0::1"),
},
{
na("100.64.0.3"),
na("fd7a:115c:a1e0::3"),
},
want4: []netip.Addr{
na("100.64.0.1"),
na("100.64.0.3"),
},
want6: []netip.Addr{
na("fd7a:115c:a1e0::1"),
na("fd7a:115c:a1e0::3"),
},
},
}
Expand All @@ -164,18 +168,29 @@ func TestIPAllocator(t *testing.T) {

spew.Dump(alloc)

var got []types.NodeAddresses
var got4s []netip.Addr
var got6s []netip.Addr

for range tt.getCount {
gotSet, err := alloc.Next()
got4, got6, err := alloc.Next()
if err != nil {
t.Fatalf("allocating next IP: %s", err)
}

got = append(got, gotSet)
if got4 != nil {
got4s = append(got4s, *got4)
}

if got6 != nil {
got6s = append(got6s, *got6)
}
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("IPAllocator unexpected result (-want +got):\n%s", diff)
if diff := cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" {
t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff)
}

if diff := cmp.Diff(tt.want6, got6s, util.Comparers...); diff != "" {
t.Errorf("IPAllocator 6s unexpected result (-want +got):\n%s", diff)
}
})
}
Expand Down
35 changes: 7 additions & 28 deletions hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,12 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
nodeKey := key.NewNode()
machineKey := key.NewMachine()

v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
node := types.Node{
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPAddresses: types.NodeAddresses{
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
},
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPv4: &v4,
Hostname: "testnode" + strconv.Itoa(index),
UserID: stor[index%2].user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Expand Down Expand Up @@ -301,27 +300,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
}

func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
input := types.NodeAddresses([]netip.Addr{
netip.MustParseAddr("192.0.2.1"),
netip.MustParseAddr("2001:db8::1"),
})
serialized, err := input.Value()
c.Assert(err, check.IsNil)
if serial, ok := serialized.(string); ok {
c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1")
}

var deserialized types.NodeAddresses
err = deserialized.Scan(serialized)
c.Assert(err, check.IsNil)

c.Assert(len(deserialized), check.Equals, len(input))
for i := range deserialized {
c.Assert(deserialized[i], check.Equals, input[i])
}
}

func (s *Suite) TestGenerateGivenName(c *check.C) {
user1, err := db.CreateUser("user-1")
c.Assert(err, check.IsNil)
Expand Down Expand Up @@ -561,6 +539,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
// Check if a subprefix of an autoapproved route is approved
route2 := netip.MustParsePrefix("10.11.0.0/24")

v4 := netip.MustParseAddr("100.64.0.1")
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
Expand All @@ -573,7 +552,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
},
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
IPv4: &v4,
}

db.DB.Save(&node)
Expand Down
4 changes: 2 additions & 2 deletions hscontrol/db/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ func EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy,
node *types.Node,
) error {
if len(node.IPAddresses) == 0 {
if node.IPv4 == nil && node.IPv6 == nil {
return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
}

Expand Down Expand Up @@ -652,7 +652,7 @@ func EnableAutoApprovedRoutes(
}

// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
if approvedIps.Contains(node.IPAddresses[0]) {
if approvedIps.Contains(*node.IPv4) {
approvedRoutes = append(approvedRoutes, advertisedRoute)
}
}
Expand Down
4 changes: 2 additions & 2 deletions hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err
}

addrs, err := api.h.ipAlloc.Next()
ipv4, ipv6, err := api.h.ipAlloc.Next()
if err != nil {
return nil, err
}
Expand All @@ -208,7 +208,7 @@ func (api headscaleV1APIServer) RegisterNode(
request.GetUser(),
nil,
util.RegisterMethodCLI,
addrs,
ipv4, ipv6,
)
})
if err != nil {
Expand Down
Loading

0 comments on commit c1a05f4

Please sign in to comment.