From 5c15b20bd5af2fdb1f5b1d046c1c9125c7a20f87 Mon Sep 17 00:00:00 2001 From: Casey Callendrello Date: Thu, 2 Aug 2018 19:39:05 +0200 Subject: [PATCH] Add support for iptables in nftables mode. Iptables also has the ability to work in nftables mode, where it is supposed to act like iptables but use the nftables subsystem. Unfortunately, it isn't exactly the same. The biggest difference is that counter output is iptables-save style, rather than with "-c N N". Also, improve some tests. --- iptables/iptables.go | 97 +++++++++++++++++++++++++++++--------- iptables/iptables_test.go | 98 +++++++++++++++++++++++++++++++++++---- test | 7 ++- 3 files changed, 171 insertions(+), 31 deletions(-) diff --git a/iptables/iptables.go b/iptables/iptables.go index 3b62fe2..8db2597 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -29,11 +29,15 @@ import ( // Adds the output of stderr to exec.ExitError type Error struct { exec.ExitError - cmd exec.Cmd - msg string + cmd exec.Cmd + msg string + exitStatus *int //for overriding } func (e *Error) ExitStatus() int { + if e.exitStatus != nil { + return *e.exitStatus + } return e.Sys().(syscall.WaitStatus).ExitStatus() } @@ -65,6 +69,7 @@ type IPTables struct { v1 int v2 int v3 int + mode string // the underlying iptables operating mode, e.g. nf_tables } // New creates a new IPTables. @@ -81,12 +86,10 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) { return nil, err } vstring, err := getIptablesVersionString(path) - v1, v2, v3, err := extractIptablesVersion(vstring) + v1, v2, v3, mode, err := extractIptablesVersion(vstring) + + checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3) - checkPresent, waitPresent, randomFullyPresent, err := getIptablesCommandSupport(v1, v2, v3) - if err != nil { - return nil, fmt.Errorf("error checking iptables version: %v", err) - } ipt := IPTables{ path: path, proto: proto, @@ -96,6 +99,7 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) { v1: v1, v2: v2, v3: v3, + mode: mode, } return &ipt, nil } @@ -266,10 +270,27 @@ func (ipt *IPTables) executeList(args []string) ([]string, error) { } rules := strings.Split(stdout.String(), "\n") + + // strip trailing newline if len(rules) > 0 && rules[len(rules)-1] == "" { rules = rules[:len(rules)-1] } + // nftables mode doesn't return an error code when listing a non-existent + // chain. Patch that up. + if len(rules) == 0 && ipt.mode == "nf_tables" { + v := 1 + return nil, &Error{ + cmd: exec.Cmd{Args: args}, + msg: "iptables: No chain/target/match by that name.", + exitStatus: &v, + } + } + + for i, rule := range rules { + rules[i] = filterRuleOutput(rule) + } + return rules, nil } @@ -284,11 +305,18 @@ func (ipt *IPTables) NewChain(table, chain string) error { func (ipt *IPTables) ClearChain(table, chain string) error { err := ipt.NewChain(table, chain) + // the exit code for "this table already exists" is different for + // different iptables modes + existsErr := 1 + if ipt.mode == "nf_tables" { + existsErr = 4 + } + eerr, eok := err.(*Error) switch { case err == nil: return nil - case eok && eerr.ExitStatus() == 1: + case eok && eerr.ExitStatus() == existsErr: // chain already exists. Flush (clear) it. return ipt.run("-t", table, "-F", chain) default: @@ -357,7 +385,7 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { if err := cmd.Run(); err != nil { switch e := err.(type) { case *exec.ExitError: - return &Error{*e, cmd, stderr.String()} + return &Error{*e, cmd, stderr.String(), nil} default: return err } @@ -376,36 +404,40 @@ func getIptablesCommand(proto Protocol) string { } // Checks if iptables has the "-C" and "--wait" flag -func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, error) { - - return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3), nil +func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) { + return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3) } -// getIptablesVersion returns the first three components of the iptables version. -// e.g. "iptables v1.3.66" would return (1, 3, 66, nil) -func extractIptablesVersion(str string) (int, int, int, error) { - versionMatcher := regexp.MustCompile("v([0-9]+)\\.([0-9]+)\\.([0-9]+)") +// getIptablesVersion returns the first three components of the iptables version +// and the operating mode (e.g. nf_tables or legacy) +// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil) +func extractIptablesVersion(str string) (int, int, int, string, error) { + versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`) result := versionMatcher.FindStringSubmatch(str) if result == nil { - return 0, 0, 0, fmt.Errorf("no iptables version found in string: %s", str) + return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str) } v1, err := strconv.Atoi(result[1]) if err != nil { - return 0, 0, 0, err + return 0, 0, 0, "", err } v2, err := strconv.Atoi(result[2]) if err != nil { - return 0, 0, 0, err + return 0, 0, 0, "", err } v3, err := strconv.Atoi(result[3]) if err != nil { - return 0, 0, 0, err + return 0, 0, 0, "", err } - return v1, v2, v3, nil + mode := "legacy" + if result[4] != "" { + mode = result[4] + } + return v1, v2, v3, mode, nil } // Runs "iptables --version" to get the version string @@ -473,3 +505,26 @@ func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string } return strings.Contains(stdout.String(), rs), nil } + +// counterRegex is the regex used to detect nftables counter format +var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `) + +// filterRuleOutput works around some inconsistencies in output. +// For example, when iptables is in legacy vs. nftables mode, it produces +// different results. +func filterRuleOutput(rule string) string { + out := rule + + // work around an output difference in nftables mode where counters + // are output in iptables-save format, rather than iptables -S format + // The string begins with "[0:0]" + // + // Fixes #49 + if groups := counterRegex.FindStringSubmatch(out); groups != nil { + // drop the brackets + out = out[len(groups[0]):] + out = fmt.Sprintf("%s -c %s %s", out, groups[1], groups[2]) + } + + return out +} diff --git a/iptables/iptables_test.go b/iptables/iptables_test.go index 2c851b8..dcd996c 100644 --- a/iptables/iptables_test.go +++ b/iptables/iptables_test.go @@ -97,8 +97,10 @@ func mustTestableIptables() []*IPTables { } func TestChain(t *testing.T) { - for _, ipt := range mustTestableIptables() { - runChainTests(t, ipt) + for i, ipt := range mustTestableIptables() { + t.Run(fmt.Sprint(i), func(t *testing.T) { + runChainTests(t, ipt) + }) } } @@ -179,8 +181,10 @@ func runChainTests(t *testing.T, ipt *IPTables) { } func TestRules(t *testing.T) { - for _, ipt := range mustTestableIptables() { - runRulesTests(t, ipt) + for i, ipt := range mustTestableIptables() { + t.Run(fmt.Sprint(i), func(t *testing.T) { + runRulesTests(t, ipt) + }) } } @@ -265,12 +269,17 @@ func runRulesTests(t *testing.T, ipt *IPTables) { t.Fatalf("ListWithCounters failed: %v", err) } + suffix := " -c 0 0 -j ACCEPT" + if ipt.mode == "nf_tables" { + suffix = " -j ACCEPT -c 0 0" + } + expected = []string{ "-N " + chain, - "-A " + chain + " -s " + subnet1 + " -d " + address1 + " -c 0 0 -j ACCEPT", - "-A " + chain + " -s " + subnet2 + " -d " + address2 + " -c 0 0 -j ACCEPT", - "-A " + chain + " -s " + subnet2 + " -d " + address1 + " -c 0 0 -j ACCEPT", - "-A " + chain + " -s " + address1 + " -d " + subnet2 + " -c 0 0 -j ACCEPT", + "-A " + chain + " -s " + subnet1 + " -d " + address1 + suffix, + "-A " + chain + " -s " + subnet2 + " -d " + address2 + suffix, + "-A " + chain + " -s " + subnet2 + " -d " + address1 + suffix, + "-A " + chain + " -s " + address1 + " -d " + subnet2 + suffix, } if !reflect.DeepEqual(rules, expected) { @@ -408,3 +417,76 @@ func TestIsNotExist(t *testing.T) { t.Fatal("IsNotExist returned false, expected true") } } + +func TestFilterRuleOutput(t *testing.T) { + testCases := []struct { + name string + in string + out string + }{ + { + "legacy output", + "-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT", + "-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT", + }, + { + "nft output", + "[99:42] -A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT", + "-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT -c 99 42", + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + actual := filterRuleOutput(tt.in) + if actual != tt.out { + t.Fatalf("expect %s actual %s", tt.out, actual) + } + }) + } +} + +func TestExtractIptablesVersion(t *testing.T) { + testCases := []struct { + in string + v1, v2, v3 int + mode string + err bool + }{ + { + "iptables v1.8.0 (nf_tables)", + 1, 8, 0, + "nf_tables", + false, + }, + { + "iptables v1.8.0 (legacy)", + 1, 8, 0, + "legacy", + false, + }, + { + "iptables v1.6.2", + 1, 6, 2, + "legacy", + false, + }, + } + + for i, tt := range testCases { + t.Run(fmt.Sprint(i), func(t *testing.T) { + v1, v2, v3, mode, err := extractIptablesVersion(tt.in) + if err == nil && tt.err { + t.Fatal("expected err, got none") + } else if err != nil && !tt.err { + t.Fatalf("unexpected err %s", err) + } + + if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode { + t.Fatalf("expected %d %d %d %s, got %d %d %d %s", + tt.v1, tt.v2, tt.v3, tt.mode, + v1, v2, v3, mode) + } + }) + } +} diff --git a/test b/test index f4ff369..91bece8 100755 --- a/test +++ b/test @@ -45,11 +45,14 @@ split=(${TEST// / }) TEST=${split[@]/#/${REPO_PATH}/} echo "Running tests..." -go test -i ${TEST} +bin=$(mktemp) + +go test -c -o ${bin} ${COVER} -i ${TEST} if [[ -z "$SUDO_PERMITTED" ]]; then echo "Test aborted for safety reasons. Please set the SUDO_PERMITTED variable." exit 1 fi -sudo -E bash -c "PATH=\$GOROOT/bin:\$PATH go test ${COVER} $@ ${TEST}" +sudo -E bash -c "${bin} $@ ${TEST}" echo "Success" +rm "${bin}"