Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for iptables in nftables mode. #51

Merged
merged 1 commit into from
Aug 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 76 additions & 21 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ import (
// Adds the output of stderr to exec.ExitError
type Error struct {
exec.ExitError
cmd exec.Cmd
msg string
cmd exec.Cmd
msg string
exitStatus *int //for overriding
}

func (e *Error) ExitStatus() int {
if e.exitStatus != nil {
return *e.exitStatus
}
return e.Sys().(syscall.WaitStatus).ExitStatus()
}

Expand Down Expand Up @@ -65,6 +69,7 @@ type IPTables struct {
v1 int
v2 int
v3 int
mode string // the underlying iptables operating mode, e.g. nf_tables
}

// New creates a new IPTables.
Expand All @@ -81,12 +86,10 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
return nil, err
}
vstring, err := getIptablesVersionString(path)
v1, v2, v3, err := extractIptablesVersion(vstring)
v1, v2, v3, mode, err := extractIptablesVersion(vstring)

checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)

checkPresent, waitPresent, randomFullyPresent, err := getIptablesCommandSupport(v1, v2, v3)
if err != nil {
return nil, fmt.Errorf("error checking iptables version: %v", err)
}
ipt := IPTables{
path: path,
proto: proto,
Expand All @@ -96,6 +99,7 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
v1: v1,
v2: v2,
v3: v3,
mode: mode,
}
return &ipt, nil
}
Expand Down Expand Up @@ -266,10 +270,27 @@ func (ipt *IPTables) executeList(args []string) ([]string, error) {
}

rules := strings.Split(stdout.String(), "\n")

// strip trailing newline
if len(rules) > 0 && rules[len(rules)-1] == "" {
rules = rules[:len(rules)-1]
}

// nftables mode doesn't return an error code when listing a non-existent
// chain. Patch that up.
if len(rules) == 0 && ipt.mode == "nf_tables" {
v := 1
return nil, &Error{
cmd: exec.Cmd{Args: args},
msg: "iptables: No chain/target/match by that name.",
exitStatus: &v,
}
}

for i, rule := range rules {
rules[i] = filterRuleOutput(rule)
}

return rules, nil
}

Expand All @@ -284,11 +305,18 @@ func (ipt *IPTables) NewChain(table, chain string) error {
func (ipt *IPTables) ClearChain(table, chain string) error {
err := ipt.NewChain(table, chain)

// the exit code for "this table already exists" is different for
// different iptables modes
existsErr := 1
if ipt.mode == "nf_tables" {
existsErr = 4
}

eerr, eok := err.(*Error)
switch {
case err == nil:
return nil
case eok && eerr.ExitStatus() == 1:
case eok && eerr.ExitStatus() == existsErr:
// chain already exists. Flush (clear) it.
return ipt.run("-t", table, "-F", chain)
default:
Expand Down Expand Up @@ -357,7 +385,7 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
if err := cmd.Run(); err != nil {
switch e := err.(type) {
case *exec.ExitError:
return &Error{*e, cmd, stderr.String()}
return &Error{*e, cmd, stderr.String(), nil}
default:
return err
}
Expand All @@ -376,36 +404,40 @@ func getIptablesCommand(proto Protocol) string {
}

// Checks if iptables has the "-C" and "--wait" flag
func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, error) {

return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3), nil
func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) {
return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
}

// getIptablesVersion returns the first three components of the iptables version.
// e.g. "iptables v1.3.66" would return (1, 3, 66, nil)
func extractIptablesVersion(str string) (int, int, int, error) {
versionMatcher := regexp.MustCompile("v([0-9]+)\\.([0-9]+)\\.([0-9]+)")
// getIptablesVersion returns the first three components of the iptables version
// and the operating mode (e.g. nf_tables or legacy)
// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
func extractIptablesVersion(str string) (int, int, int, string, error) {
versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
result := versionMatcher.FindStringSubmatch(str)
if result == nil {
return 0, 0, 0, fmt.Errorf("no iptables version found in string: %s", str)
return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str)
}

v1, err := strconv.Atoi(result[1])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

v2, err := strconv.Atoi(result[2])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

v3, err := strconv.Atoi(result[3])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

return v1, v2, v3, nil
mode := "legacy"
if result[4] != "" {
mode = result[4]
}
return v1, v2, v3, mode, nil
}

// Runs "iptables --version" to get the version string
Expand Down Expand Up @@ -473,3 +505,26 @@ func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string
}
return strings.Contains(stdout.String(), rs), nil
}

// counterRegex is the regex used to detect nftables counter format
var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `)

// filterRuleOutput works around some inconsistencies in output.
// For example, when iptables is in legacy vs. nftables mode, it produces
// different results.
func filterRuleOutput(rule string) string {
out := rule

// work around an output difference in nftables mode where counters
// are output in iptables-save format, rather than iptables -S format
// The string begins with "[0:0]"
//
// Fixes #49
if groups := counterRegex.FindStringSubmatch(out); groups != nil {
// drop the brackets
out = out[len(groups[0]):]
out = fmt.Sprintf("%s -c %s %s", out, groups[1], groups[2])
}

return out
}
98 changes: 90 additions & 8 deletions iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ func mustTestableIptables() []*IPTables {
}

func TestChain(t *testing.T) {
for _, ipt := range mustTestableIptables() {
runChainTests(t, ipt)
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
runChainTests(t, ipt)
})
}
}

Expand Down Expand Up @@ -179,8 +181,10 @@ func runChainTests(t *testing.T, ipt *IPTables) {
}

func TestRules(t *testing.T) {
for _, ipt := range mustTestableIptables() {
runRulesTests(t, ipt)
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
runRulesTests(t, ipt)
})
}
}

Expand Down Expand Up @@ -265,12 +269,17 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
t.Fatalf("ListWithCounters failed: %v", err)
}

suffix := " -c 0 0 -j ACCEPT"
if ipt.mode == "nf_tables" {
suffix = " -j ACCEPT -c 0 0"
}

expected = []string{
"-N " + chain,
"-A " + chain + " -s " + subnet1 + " -d " + address1 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet2 + " -d " + address2 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet2 + " -d " + address1 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + address1 + " -d " + subnet2 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet1 + " -d " + address1 + suffix,
"-A " + chain + " -s " + subnet2 + " -d " + address2 + suffix,
"-A " + chain + " -s " + subnet2 + " -d " + address1 + suffix,
"-A " + chain + " -s " + address1 + " -d " + subnet2 + suffix,
}

if !reflect.DeepEqual(rules, expected) {
Expand Down Expand Up @@ -408,3 +417,76 @@ func TestIsNotExist(t *testing.T) {
t.Fatal("IsNotExist returned false, expected true")
}
}

func TestFilterRuleOutput(t *testing.T) {
testCases := []struct {
name string
in string
out string
}{
{
"legacy output",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
},
{
"nft output",
"[99:42] -A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT -c 99 42",
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
actual := filterRuleOutput(tt.in)
if actual != tt.out {
t.Fatalf("expect %s actual %s", tt.out, actual)
}
})
}
}

func TestExtractIptablesVersion(t *testing.T) {
testCases := []struct {
in string
v1, v2, v3 int
mode string
err bool
}{
{
"iptables v1.8.0 (nf_tables)",
1, 8, 0,
"nf_tables",
false,
},
{
"iptables v1.8.0 (legacy)",
1, 8, 0,
"legacy",
false,
},
{
"iptables v1.6.2",
1, 6, 2,
"legacy",
false,
},
}

for i, tt := range testCases {
t.Run(fmt.Sprint(i), func(t *testing.T) {
v1, v2, v3, mode, err := extractIptablesVersion(tt.in)
if err == nil && tt.err {
t.Fatal("expected err, got none")
} else if err != nil && !tt.err {
t.Fatalf("unexpected err %s", err)
}

if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
tt.v1, tt.v2, tt.v3, tt.mode,
v1, v2, v3, mode)
}
})
}
}
7 changes: 5 additions & 2 deletions test
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ split=(${TEST// / })
TEST=${split[@]/#/${REPO_PATH}/}

echo "Running tests..."
go test -i ${TEST}
bin=$(mktemp)

go test -c -o ${bin} ${COVER} -i ${TEST}
if [[ -z "$SUDO_PERMITTED" ]]; then
echo "Test aborted for safety reasons. Please set the SUDO_PERMITTED variable."
exit 1
fi

sudo -E bash -c "PATH=\$GOROOT/bin:\$PATH go test ${COVER} $@ ${TEST}"
sudo -E bash -c "${bin} $@ ${TEST}"
echo "Success"
rm "${bin}"