Skip to content

Commit

Permalink
# rule: fix 32-bit platforms don't support adding rules with a mark v…
Browse files Browse the repository at this point in the history
…alue of 0x80000000/0xF0000000 ~ 0xF0000000/0xF0000000

 The maximum value for an `int` type on a 32-bit platform is 0x7FFFFFFF. Since 0xF0000000 exceeds this limit, we need to use `uint` instead of `int` to handle these values.
  • Loading branch information
qxoqx authored and aboch committed Aug 5, 2024
1 parent d13535d commit 8f96fd8
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 33 deletions.
4 changes: 2 additions & 2 deletions route_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -1521,7 +1521,7 @@ type RouteGetOptions struct {
VrfName string
SrcAddr net.IP
UID *uint32
Mark int
Mark uint32
FIBMatch bool
}

Expand Down Expand Up @@ -1630,7 +1630,7 @@ func (h *Handle) RouteGetWithOptions(destination net.IP, options *RouteGetOption

if options.Mark > 0 {
b := make([]byte, 4)
native.PutUint32(b, uint32(options.Mark))
native.PutUint32(b, options.Mark)

req.AddData(nl.NewRtAttr(unix.RTA_MARK, b))
}
Expand Down
80 changes: 65 additions & 15 deletions route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2449,27 +2449,41 @@ func TestRouteFWMarkOption(t *testing.T) {
}

// a table different than unix.RT_TABLE_MAIN
testtable := 1000
testTable0 := 254
testTable1 := 1000
testTable2 := 1001

gw1 := net.IPv4(192, 168, 1, 254)
gw2 := net.IPv4(192, 168, 2, 254)
gw0 := net.IPv4(192, 168, 1, 254)
gw1 := net.IPv4(192, 168, 2, 254)
gw2 := net.IPv4(192, 168, 3, 254)

// add default route via gw1 (in main route table by default)
// add default route via gw0 (in main route table by default)
defaultRouteMain := Route{
Dst: nil,
Gw: gw1,
Dst: nil,
Gw: gw0,
Table: testTable0,
}
if err := RouteAdd(&defaultRouteMain); err != nil {
t.Fatal(err)
}

// add default route via gw1 in test route table
defaultRouteTest1 := Route{
Dst: nil,
Gw: gw1,
Table: testTable1,
}
if err := RouteAdd(&defaultRouteTest1); err != nil {
t.Fatal(err)
}

// add default route via gw2 in test route table
defaultRouteTest := Route{
defaultRouteTest2 := Route{
Dst: nil,
Gw: gw2,
Table: testtable,
Table: testTable2,
}
if err := RouteAdd(&defaultRouteTest); err != nil {
if err := RouteAdd(&defaultRouteTest2); err != nil {
t.Fatal(err)
}

Expand All @@ -2481,34 +2495,70 @@ func TestRouteFWMarkOption(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if len(routes) != 2 || routes[0].Table == routes[1].Table {
if len(routes) != 3 || routes[0].Table == routes[1].Table || routes[1].Table == routes[2].Table ||
routes[0].Table == routes[2].Table {
t.Fatal("Routes not added properly")
}

// add a rule that fwmark match should result in route lookup of test table
fwmark := 1000
fwmark1 := uint32(0xAFFFFFFF)
fwmark2 := uint32(0xBFFFFFFF)

rule := NewRule()
rule.Mark = fwmark
rule.Mask = 0xFFFFFFFF
rule.Table = testtable
rule.Mark = fwmark1
rule.Mask = &[]uint32{0xFFFFFFFF}[0]

rule.Table = testTable1
if err := RuleAdd(rule); err != nil {
t.Fatal(err)
}

rule = NewRule()
rule.Mark = fwmark2
rule.Mask = &[]uint32{0xFFFFFFFF}[0]
rule.Table = testTable2
if err := RuleAdd(rule); err != nil {
t.Fatal(err)
}

rules, err := RuleListFiltered(FAMILY_V4, &Rule{Mark: fwmark1}, RT_FILTER_MARK)
if err != nil {
t.Fatal(err)
}
if len(rules) != 1 || rules[0].Table != testTable1 || rules[0].Mark != fwmark1 {
t.Fatal("Rules not added properly")
}

rules, err = RuleListFiltered(FAMILY_V4, &Rule{Mark: fwmark2}, RT_FILTER_MARK)
if err != nil {
t.Fatal(err)
}
if len(rules) != 1 || rules[0].Table != testTable2 || rules[0].Mark != fwmark2 {
t.Fatal("Rules not added properly")
}

dstIP := net.IPv4(10, 1, 1, 1)

// check getting route without FWMark option
routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{})
if err != nil {
t.Fatal(err)
}
if len(routes) != 1 || !routes[0].Gw.Equal(gw0) {
t.Fatal(routes)
}

// check getting route with FWMark option
routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark1})
if err != nil {
t.Fatal(err)
}
if len(routes) != 1 || !routes[0].Gw.Equal(gw1) {
t.Fatal(routes)
}

// check getting route with FWMark option
routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark})
routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark2})
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 4 additions & 4 deletions rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ type Rule struct {
Priority int
Family int
Table int
Mark int
Mask int
Mark uint32
Mask *uint32
Tos uint
TunID uint
Goto int
Expand Down Expand Up @@ -51,8 +51,8 @@ func NewRule() *Rule {
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: -1,
Mark: -1,
Mask: -1,
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
}
Expand Down
25 changes: 18 additions & 7 deletions rule_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error {
native.PutUint32(b, uint32(rule.Priority))
req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b))
}
if rule.Mark >= 0 {
if rule.Mark != 0 || rule.Mask != nil {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Mark))
native.PutUint32(b, rule.Mark)
req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b))
}
if rule.Mask >= 0 {
if rule.Mask != nil {
b := make([]byte, 4)
native.PutUint32(b, uint32(rule.Mask))
native.PutUint32(b, *rule.Mask)
req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b))
}
if rule.Flow >= 0 {
Expand Down Expand Up @@ -242,9 +242,10 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) (
Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
}
case nl.FRA_FWMARK:
rule.Mark = int(native.Uint32(attrs[j].Value[0:4]))
rule.Mark = native.Uint32(attrs[j].Value[0:4])
case nl.FRA_FWMASK:
rule.Mask = int(native.Uint32(attrs[j].Value[0:4]))
mask := native.Uint32(attrs[j].Value[0:4])
rule.Mask = &mask
case nl.FRA_TUN_ID:
rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
case nl.FRA_IIFNAME:
Expand Down Expand Up @@ -297,7 +298,7 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) (
continue
case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark:
continue
case filterMask&RT_FILTER_MASK != 0 && rule.Mask != filter.Mask:
case filterMask&RT_FILTER_MASK != 0 && !ptrEqual(rule.Mask, filter.Mask):
continue
}
}
Expand All @@ -321,3 +322,13 @@ func (pr *RuleUIDRange) toRtAttrData() []byte {
native.PutUint32(b[1], pr.End)
return bytes.Join(b, []byte{})
}

func ptrEqual(a, b *uint32) bool {
if a == b {
return true
}
if (a == nil) || (b == nil) {
return false
}
return *a == *b
}
Loading

0 comments on commit 8f96fd8

Please sign in to comment.