Skip to content

Commit

Permalink
Merge branch 'main' into docker-distroless
Browse files Browse the repository at this point in the history
  • Loading branch information
ItalyPaleAle authored Nov 13, 2021
2 parents 7d77acd + ba65092 commit 5ec7158
Show file tree
Hide file tree
Showing 51 changed files with 10,372 additions and 1,699 deletions.
61 changes: 35 additions & 26 deletions acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ import (
"tailscale.com/tailcfg"
)

const errorEmptyPolicy = Error("empty policy")
const errorInvalidAction = Error("invalid action")
const errorInvalidUserSection = Error("invalid user section")
const errorInvalidGroup = Error("invalid group")
const errorInvalidTag = Error("invalid tag")
const errorInvalidNamespace = Error("invalid namespace")
const errorInvalidPortFormat = Error("invalid port format")
const (
errorEmptyPolicy = Error("empty policy")
errorInvalidAction = Error("invalid action")
errorInvalidUserSection = Error("invalid user section")
errorInvalidGroup = Error("invalid group")
errorInvalidTag = Error("invalid tag")
errorInvalidNamespace = Error("invalid namespace")
errorInvalidPortFormat = Error("invalid port format")
)

// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules
func (h *Headscale) LoadACLPolicy(path string) error {
Expand All @@ -36,7 +38,14 @@ func (h *Headscale) LoadACLPolicy(path string) error {
if err != nil {
return err
}
err = hujson.Unmarshal(b, &policy)

ast, err := hujson.Parse(b)
if err != nil {
return err
}
ast.Standardize()
b = ast.Pack()
err = json.Unmarshal(b, &policy)
if err != nil {
return err
}
Expand All @@ -53,7 +62,7 @@ func (h *Headscale) LoadACLPolicy(path string) error {
return nil
}

func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}

for i, a := range h.aclPolicy.ACLs {
Expand All @@ -71,7 +80,7 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
Msgf("Error parsing ACL %d, User %d", i, j)
return nil, err
}
srcIPs = append(srcIPs, *srcs...)
srcIPs = append(srcIPs, srcs...)
}
r.SrcIPs = srcIPs

Expand All @@ -83,7 +92,7 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
Msgf("Error parsing ACL %d, Port %d", i, j)
return nil, err
}
destPorts = append(destPorts, *dests...)
destPorts = append(destPorts, dests...)
}

rules = append(rules, tailcfg.FilterRule{
Expand All @@ -92,14 +101,14 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
})
}

return &rules, nil
return rules, nil
}

func (h *Headscale) generateACLPolicySrcIP(u string) (*[]string, error) {
func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
return h.expandAlias(u)
}

func (h *Headscale) generateACLPolicyDestPorts(d string) (*[]tailcfg.NetPortRange, error) {
func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":")
if len(tokens) < 2 || len(tokens) > 3 {
return nil, errorInvalidPortFormat
Expand Down Expand Up @@ -128,7 +137,7 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) (*[]tailcfg.NetPortRang
}

dests := []tailcfg.NetPortRange{}
for _, d := range *expanded {
for _, d := range expanded {
for _, p := range *ports {
pr := tailcfg.NetPortRange{
IP: d,
Expand All @@ -137,12 +146,12 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) (*[]tailcfg.NetPortRang
dests = append(dests, pr)
}
}
return &dests, nil
return dests, nil
}

func (h *Headscale) expandAlias(s string) (*[]string, error) {
func (h *Headscale) expandAlias(s string) ([]string, error) {
if s == "*" {
return &[]string{"*"}, nil
return []string{"*"}, nil
}

if strings.HasPrefix(s, "group:") {
Expand All @@ -155,11 +164,11 @@ func (h *Headscale) expandAlias(s string) (*[]string, error) {
if err != nil {
return nil, errorInvalidNamespace
}
for _, node := range *nodes {
for _, node := range nodes {
ips = append(ips, node.IPAddress)
}
}
return &ips, nil
return ips, nil
}

if strings.HasPrefix(s, "tag:") {
Expand Down Expand Up @@ -195,7 +204,7 @@ func (h *Headscale) expandAlias(s string) (*[]string, error) {
}
}
}
return &ips, nil
return ips, nil
}

n, err := h.GetNamespace(s)
Expand All @@ -205,24 +214,24 @@ func (h *Headscale) expandAlias(s string) (*[]string, error) {
return nil, err
}
ips := []string{}
for _, n := range *nodes {
for _, n := range nodes {
ips = append(ips, n.IPAddress)
}
return &ips, nil
return ips, nil
}

if h, ok := h.aclPolicy.Hosts[s]; ok {
return &[]string{h.String()}, nil
return []string{h.String()}, nil
}

ip, err := netaddr.ParseIP(s)
if err == nil {
return &[]string{ip.String()}, nil
return []string{ip.String()}, nil
}

cidr, err := netaddr.ParseIPPrefix(s)
if err == nil {
return &[]string{cidr.String()}, nil
return []string{cidr.String()}, nil
}

return nil, errorInvalidUserSection
Expand Down
49 changes: 24 additions & 25 deletions acls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ func (s *Suite) TestWrongPath(c *check.C) {
func (s *Suite) TestBrokenHuJson(c *check.C) {
err := h.LoadACLPolicy("./tests/acls/broken.hujson")
c.Assert(err, check.NotNil)

}

func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
Expand Down Expand Up @@ -57,10 +56,10 @@ func (s *Suite) TestPortRange(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

c.Assert(*rules, check.HasLen, 1)
c.Assert((*rules)[0].DstPorts, check.HasLen, 1)
c.Assert((*rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(5400))
c.Assert((*rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(5500))
c.Assert(rules, check.HasLen, 1)
c.Assert((rules)[0].DstPorts, check.HasLen, 1)
c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(5400))
c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(5500))
}

func (s *Suite) TestPortWildcard(c *check.C) {
Expand All @@ -71,12 +70,12 @@ func (s *Suite) TestPortWildcard(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

c.Assert(*rules, check.HasLen, 1)
c.Assert((*rules)[0].DstPorts, check.HasLen, 1)
c.Assert((*rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((*rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((*rules)[0].SrcIPs, check.HasLen, 1)
c.Assert((*rules)[0].SrcIPs[0], check.Equals, "*")
c.Assert(rules, check.HasLen, 1)
c.Assert((rules)[0].DstPorts, check.HasLen, 1)
c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((rules)[0].SrcIPs, check.HasLen, 1)
c.Assert((rules)[0].SrcIPs[0], check.Equals, "*")
}

func (s *Suite) TestPortNamespace(c *check.C) {
Expand Down Expand Up @@ -110,13 +109,13 @@ func (s *Suite) TestPortNamespace(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

c.Assert(*rules, check.HasLen, 1)
c.Assert((*rules)[0].DstPorts, check.HasLen, 1)
c.Assert((*rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((*rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((*rules)[0].SrcIPs, check.HasLen, 1)
c.Assert((*rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert((*rules)[0].SrcIPs[0], check.Equals, ip.String())
c.Assert(rules, check.HasLen, 1)
c.Assert((rules)[0].DstPorts, check.HasLen, 1)
c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((rules)[0].SrcIPs, check.HasLen, 1)
c.Assert((rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert((rules)[0].SrcIPs[0], check.Equals, ip.String())
}

func (s *Suite) TestPortGroup(c *check.C) {
Expand Down Expand Up @@ -150,11 +149,11 @@ func (s *Suite) TestPortGroup(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)

c.Assert(*rules, check.HasLen, 1)
c.Assert((*rules)[0].DstPorts, check.HasLen, 1)
c.Assert((*rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((*rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((*rules)[0].SrcIPs, check.HasLen, 1)
c.Assert((*rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert((*rules)[0].SrcIPs[0], check.Equals, ip.String())
c.Assert(rules, check.HasLen, 1)
c.Assert((rules)[0].DstPorts, check.HasLen, 1)
c.Assert((rules)[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert((rules)[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert((rules)[0].SrcIPs, check.HasLen, 1)
c.Assert((rules)[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert((rules)[0].SrcIPs[0], check.Equals, ip.String())
}
9 changes: 8 additions & 1 deletion acls_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package headscale

import (
"encoding/json"
"strings"

"github.com/tailscale/hujson"
Expand Down Expand Up @@ -43,7 +44,13 @@ type ACLTest struct {
func (h *Hosts) UnmarshalJSON(data []byte) error {
hosts := Hosts{}
hs := make(map[string]string)
err := hujson.Unmarshal(data, &hs)
ast, err := hujson.Parse(data)
if err != nil {
return err
}
ast.Standardize()
data = ast.Pack()
err = json.Unmarshal(data, &hs)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
<p>
<code>
<b>headscale -n NAMESPACE nodes register -k %s</b>
<b>headscale -n NAMESPACE nodes register --key %s</b>
</code>
</p>
Expand Down Expand Up @@ -306,7 +306,7 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
Peers: nodePeers,
DNSConfig: dnsConfig,
Domain: h.cfg.BaseDomain,
PacketFilter: *h.aclRules,
PacketFilter: h.aclRules,
DERPMap: h.DERPMap,
UserProfiles: profiles,
}
Expand Down
Loading

0 comments on commit 5ec7158

Please sign in to comment.