Skip to content

Commit

Permalink
move to use tailscfg types over strings/custom types (#1612)
Browse files Browse the repository at this point in the history
* rename database only fields

Signed-off-by: Kristoffer Dalby <[email protected]>

* use correct endpoint type over string list

Signed-off-by: Kristoffer Dalby <[email protected]>

* remove HostInfo wrapper

Signed-off-by: Kristoffer Dalby <[email protected]>

* wrap errors in database hooks

Signed-off-by: Kristoffer Dalby <[email protected]>

---------

Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby authored Nov 21, 2023
1 parent ed4e199 commit b918aa0
Show file tree
Hide file tree
Showing 13 changed files with 147 additions and 154 deletions.
3 changes: 2 additions & 1 deletion hscontrol/db/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)

Expand Down Expand Up @@ -593,7 +594,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
},
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/db/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
}

advertisedRoutes := map[netip.Prefix]bool{}
for _, prefix := range node.HostInfo.RoutableIPs {
for _, prefix := range node.Hostinfo.RoutableIPs {
advertisedRoutes[prefix] = false
}

Expand Down
18 changes: 9 additions & 9 deletions hscontrol/db/routes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
Hostinfo: &hostInfo,
}
db.db.Save(&node)

Expand Down Expand Up @@ -81,7 +81,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
Hostinfo: &hostInfo,
}
db.db.Save(&node)

Expand Down Expand Up @@ -152,7 +152,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo1),
Hostinfo: &hostInfo1,
}
db.db.Save(&node1)

Expand All @@ -174,7 +174,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo2),
Hostinfo: &hostInfo2,
}
db.db.Save(&node2)

Expand Down Expand Up @@ -232,7 +232,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo1),
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)
Expand Down Expand Up @@ -266,7 +266,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo2),
Hostinfo: &hostInfo2,
LastSeen: &now,
}
db.db.Save(&node2)
Expand Down Expand Up @@ -313,9 +313,9 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1)

node2.HostInfo = types.HostInfo(tailcfg.Hostinfo{
node2.Hostinfo = &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
})
}
err = db.db.Save(&node2).Error
c.Assert(err, check.IsNil)

Expand Down Expand Up @@ -368,7 +368,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo1),
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/grpcv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
Expiry: &time.Time{},
LastSeen: &time.Time{},

HostInfo: types.HostInfo(hostinfo),
Hostinfo: &hostinfo,
}

log.Debug().
Expand Down
2 changes: 1 addition & 1 deletion hscontrol/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{node.Hostname},
"device_model": []string{node.HostInfo.OS},
"device_model": []string{node.Hostinfo.OS},
}

if len(node.IPAddresses) > 0 {
Expand Down
9 changes: 3 additions & 6 deletions hscontrol/mapper/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,7 @@ func Test_fullMapResponse(t *testing.T) {
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
HostInfo: types.HostInfo{},
Endpoints: []string{},
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{
{
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
Expand Down Expand Up @@ -267,8 +266,7 @@ func Test_fullMapResponse(t *testing.T) {
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
HostInfo: types.HostInfo{},
Endpoints: []string{},
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{},
CreatedAt: created,
}
Expand Down Expand Up @@ -324,8 +322,7 @@ func Test_fullMapResponse(t *testing.T) {
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
HostInfo: types.HostInfo{},
Endpoints: []string{},
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{},
CreatedAt: created,
}
Expand Down
15 changes: 4 additions & 11 deletions hscontrol/mapper/tail.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func tailNode(
}

var derp string
if node.HostInfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", node.HostInfo.NetInfo.PreferredDERP)
if node.Hostinfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
} else {
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
}
Expand All @@ -90,18 +90,11 @@ func tailNode(
return nil, err
}

hostInfo := node.GetHostInfo()

online := node.IsOnline()

tags, _ := pol.TagsOfNode(node)
tags = lo.Uniq(append(tags, node.ForcedTags...))

endpoints, err := node.EndpointsToAddrPort()
if err != nil {
return nil, err
}

tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
Expand All @@ -118,9 +111,9 @@ func tailNode(
DiscoKey: node.DiscoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
Endpoints: endpoints,
Endpoints: node.Endpoints,
DERP: derp,
Hostinfo: hostInfo.View(),
Hostinfo: node.Hostinfo.View(),
Created: node.CreatedAt,

Tags: tags,
Expand Down
9 changes: 5 additions & 4 deletions hscontrol/mapper/tail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ func TestTailNode(t *testing.T) {
wantErr bool
}{
{
name: "empty-node",
node: &types.Node{},
name: "empty-node",
node: &types.Node{
Hostinfo: &tailcfg.Hostinfo{},
},
pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "",
Expand Down Expand Up @@ -102,8 +104,7 @@ func TestTailNode(t *testing.T) {
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
HostInfo: types.HostInfo{},
Endpoints: []string{},
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{
{
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
Expand Down
18 changes: 12 additions & 6 deletions hscontrol/policy/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,13 @@ func excludeCorrectlyTaggedNodes(
}
// for each node if tag is in tags list, don't append it.
for _, node := range nodes {
hi := node.GetHostInfo()

found := false
for _, t := range hi.RequestTags {

if node.Hostinfo == nil {
continue
}

for _, t := range node.Hostinfo.RequestTags {
if util.StringOrPrefixListContains(tags, t) {
found = true

Expand Down Expand Up @@ -787,8 +790,11 @@ func (pol *ACLPolicy) expandIPsFromTag(
for _, user := range owners {
nodes := filterNodesByUser(nodes, user)
for _, node := range nodes {
hi := node.GetHostInfo()
if util.StringOrPrefixListContains(hi.RequestTags, alias) {
if node.Hostinfo == nil {
continue
}

if util.StringOrPrefixListContains(node.Hostinfo.RequestTags, alias) {
node.IPAddresses.AppendToIPSet(&build)
}
}
Expand Down Expand Up @@ -882,7 +888,7 @@ func (pol *ACLPolicy) TagsOfNode(

validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool)
for _, tag := range node.HostInfo.RequestTags {
for _, tag := range node.Hostinfo.RequestTags {
owners, err := expandOwnersFromTag(pol, tag)
if errors.Is(err, ErrInvalidTag) {
invalidTagMap[tag] = true
Expand Down
Loading

0 comments on commit b918aa0

Please sign in to comment.