diff --git a/.gitignore b/.gitignore index db4561e..66b01cd 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,5 @@ docs/_build/ # PyBuilder target/ + +.idea/ diff --git a/examples/simple_stateful.py b/examples/simple_stateful.py index 9a89289..4429468 100644 --- a/examples/simple_stateful.py +++ b/examples/simple_stateful.py @@ -1,24 +1,24 @@ -from pyptables import default_tables, restore -from pyptables.rules import Rule, Accept - -# get a default set of tables and chains -tables = default_tables() - -# get the forward chain of the filter tables -forward = tables['filter']['FORWARD'] - -# any packet matching an established connection should be allowed -forward.append(Accept(match='conntrack', ctstate='ESTABLISHED')) - -# add rules to the forward chain for DNS, HTTP and HTTPS ports -forward.append(Accept(proto='tcp', dport='53')) -forward.append(Accept(proto='tcp', dport='80')) -forward.append(Accept(proto='tcp', dport='443')) - -# any packet not matching a rules will be dropped -forward.policy = Rule.DROP - -# write the rules into the kernel -restore(tables) - -print tables.to_iptables() \ No newline at end of file +from pyptables import default_tables, restore +from pyptables.rules import Rule, Accept + +# get a default set of tables and chains +tables = default_tables() + +# get the forward chain of the filter tables +forward = tables['filter']['FORWARD'] + +# any packet matching an established connection should be allowed +forward.append(Accept(match='conntrack', ctstate='ESTABLISHED')) + +# add rules to the forward chain for DNS, HTTP and HTTPS ports +forward.append(Accept(proto='tcp', dport='53')) +forward.append(Accept(proto='tcp', dport='80')) +forward.append(Accept(proto='tcp', dport='443')) + +# any packet not matching a rules will be dropped +forward.policy = Rule.DROP + +# write the rules into the kernel +restore(tables) + +print(tables.to_iptables()) diff --git a/pyptables/__init__.py b/pyptables/__init__.py index be94bd8..59cbaaf 100644 --- a/pyptables/__init__.py +++ b/pyptables/__init__.py @@ -1,10 +1,11 @@ import re import subprocess -from tables import Tables, Table -from chains import BuiltinChain, UserChain -from rules import Rule, Accept, Drop, Jump, Redirect, Return, Log, CustomRule -from rules.matches import Match +from pyptables.tables import Tables, Table +from pyptables.chains import BuiltinChain, UserChain +from pyptables.rules import Rule, Accept, Drop, Jump, Redirect, Return, Log, CustomRule +from pyptables.rules.matches import Match + def default_tables(): """Generate a set of iptables containing all the default tables and chains""" diff --git a/pyptables/__main__.py b/pyptables/__main__.py index 4a23995..b57d382 100644 --- a/pyptables/__main__.py +++ b/pyptables/__main__.py @@ -22,4 +22,4 @@ if '--line-numbers' in sys.argv: output = add_line_numbers(output) -print output +print(output) diff --git a/pyptables/chains.py b/pyptables/chains.py index c80f8a2..efcc91d 100644 --- a/pyptables/chains.py +++ b/pyptables/chains.py @@ -1,7 +1,8 @@ import re from collections import namedtuple -from base import DebugObject +from pyptables.base import DebugObject + class AbstractChain(DebugObject, list): """Represents an iptables Chain. Holds a number of Rule objects in a list-like fashion""" @@ -27,7 +28,7 @@ def to_iptables(self): 'rules': rule_output, }, ) - except Exception, e: # pragma: no cover + except Exception as e: # pragma: no cover e.iptables_path = getattr(e, 'iptables_path', []) e.iptables_path.insert(0, self.name) raise diff --git a/pyptables/rules/__init__.py b/pyptables/rules/__init__.py index bdcb061..95d4e8f 100644 --- a/pyptables/rules/__init__.py +++ b/pyptables/rules/__init__.py @@ -2,7 +2,7 @@ rules for iptables. """ -from base import AbstractRule, CustomRule, Rule, CompositeRule +from pyptables.rules.base import AbstractRule, CustomRule, Rule, CompositeRule from pyptables.chains import AbstractChain as _AbstractChain diff --git a/pyptables/rules/arguments.py b/pyptables/rules/arguments.py index a07a40c..bcdf179 100644 --- a/pyptables/rules/arguments.py +++ b/pyptables/rules/arguments.py @@ -222,7 +222,7 @@ def _update_args(self, args, kwargs): def __iter__(self): kwargs = dict(self.kwargs) # duplicate dictionary, as it is modified below for argument in self.known_args: - for key in kwargs.keys(): + for key in kwargs: if argument.matches(key): value = kwargs.pop(key) yield argument.bind(key, value) diff --git a/pyptables/rules/base.py b/pyptables/rules/base.py index c4e32b1..e4bfe6e 100644 --- a/pyptables/rules/base.py +++ b/pyptables/rules/base.py @@ -1,9 +1,10 @@ import itertools -from ..base import DebugObject +from pyptables.base import DebugObject + +from pyptables.rules.arguments import UnboundArgument, ArgumentList +from pyptables.rules.matches import Match -from arguments import UnboundArgument, ArgumentList -from matches import Match class AbstractRule(DebugObject): """Represents an iptables rule""" @@ -19,7 +20,7 @@ def to_iptables(self, prefix=''): 'header': self._header(), 'rules': self._rule_definition(prefix), } - except Exception, e: # pragma: no cover + except Exception as e: # pragma: no cover e.iptables_path = getattr(e, 'iptables_path', []) e.iptables_path.insert(0, "Rule:\n created: %s\n comment: %s" % (self.debug_info(), self.comment)) raise diff --git a/pyptables/rules/forwarding/__init__.py b/pyptables/rules/forwarding/__init__.py index 01b3b67..6675769 100644 --- a/pyptables/rules/forwarding/__init__.py +++ b/pyptables/rules/forwarding/__init__.py @@ -2,4 +2,4 @@ rules for iptables. """ -from base import ForwardingRule +from pyptables.rules.forwarding.base import ForwardingRule diff --git a/pyptables/rules/forwarding/locations.py b/pyptables/rules/forwarding/locations.py index 0738a15..5adf903 100644 --- a/pyptables/rules/forwarding/locations.py +++ b/pyptables/rules/forwarding/locations.py @@ -3,9 +3,10 @@ Locations represent a network location. """ -from ...base import DebugObject -from ..arguments import ArgumentList -from hosts import Hosts +from pyptables.base import DebugObject +from pyptables.rules.arguments import ArgumentList +from pyptables.rules.forwarding.hosts import Hosts + class Location(DebugObject): """Represents a network location""" diff --git a/pyptables/rules/input/__init__.py b/pyptables/rules/input/__init__.py index 10656c1..22438c1 100644 --- a/pyptables/rules/input/__init__.py +++ b/pyptables/rules/input/__init__.py @@ -2,4 +2,4 @@ rules for iptables. """ -from base import InputRule +from pyptables.rules.input.base import InputRule diff --git a/pyptables/rules/marks.py b/pyptables/rules/marks.py index 4446ee6..1493ece 100644 --- a/pyptables/rules/marks.py +++ b/pyptables/rules/marks.py @@ -3,7 +3,7 @@ from functools import partial from random import Random -from . import Rule +from pyptables.rules import Rule _marks = [] class Mark(Rule): diff --git a/pyptables/tables.py b/pyptables/tables.py index 1a67a68..cda3bcd 100644 --- a/pyptables/tables.py +++ b/pyptables/tables.py @@ -1,6 +1,7 @@ from collections import OrderedDict -from base import DebugObject +from pyptables.base import DebugObject + class Tables(DebugObject, OrderedDict): """Dictionary like top-level container of iptables, holds a number of Table objects.""" @@ -19,7 +20,7 @@ def to_iptables(self): 'header': header, 'tables': table_output, } - except Exception, e: #pragma: no cover + except Exception as e: # pragma: no cover e.iptables_path = getattr(e, 'iptables_path', []) e.iptables_path.insert(0, "Tables") e.message = "Iptables error at:\n %s\n\nError message: %s" % ("\n".join(e.iptables_path).replace('\n', '\n '), e.message) @@ -64,7 +65,7 @@ def to_iptables(self): 'rules': "\n\n".join([result.rules for result in chain_results]), 'footer': 'COMMIT' } - except Exception, e: #pragma: no cover + except Exception as e: # pragma: no cover e.iptables_path = getattr(e, 'iptables_path', []) e.iptables_path.insert(0, self.name) raise @@ -78,4 +79,4 @@ def append(self, chain): return chain def __repr__(self): - return "" % (self.name, self.values()) + return "" % (self.name, list(self.values())) diff --git a/pyptables/test/__init__.py b/pyptables/test/__init__.py index de73d01..487873c 100644 --- a/pyptables/test/__init__.py +++ b/pyptables/test/__init__.py @@ -1,139 +1,152 @@ -import itertools -import os.path -import unittest - -from StringIO import StringIO - -from pyptables import default_tables, Rule, UserChain, Jump, CustomRule -from pyptables.rules import CompositeRule -from pyptables.rules.arguments import ArgumentList, CustomArgument -from pyptables.rules.marks import Mark, random_mark, Marked -from pyptables.rules.input import InputRule -from pyptables.rules.forwarding import ForwardingRule -from pyptables.rules.forwarding.hosts import HostList, HostRange -from pyptables.rules.forwarding.ipsets import IPSet -from pyptables.rules.forwarding.locations import Location -from pyptables.rules.forwarding.zones import Zone -from pyptables.rules.forwarding.channels import TCPChannel, UDPChannel, ICMPChannel - -def compare(a, b): - for line_no, lines in enumerate(itertools.izip_longest(a, b, fillvalue='')): - lines = [line.strip() for line in lines] - if all(line.startswith('#') for line in lines): - continue - for i, chars in enumerate(itertools.izip_longest(*lines)): - if chars[0] != chars[1]: - raise ValueError("line %s doesn't match:\n\t%s\n\t%s\n\t%s^" % (line_no, lines[0], lines[1], i*'_')) - return True - - -class MainTest(unittest.TestCase): - - def test(self): - self.assertRaisesRegexp(ValueError, "Only 'not' is supported", lambda: Rule(s__invalid=None).to_iptables()) - self.assertRaisesRegexp(ValueError, "This argument is not invertable", lambda: Rule(f__not=None).to_iptables()) - self.assertRaisesRegexp(ValueError, "Only 'not' is supported", lambda: Rule(custom__invalid=None).to_iptables()) - self.assertRaisesRegexp(ValueError, "badly formatted argument name", lambda: Rule(custom__not__invalid=None).to_iptables()) - Rule(custom__not=None).to_iptables() - arg_list = ArgumentList(custom='1') - self.assertIsInstance(arg_list['custom'], CustomArgument) - self.assertRaises(KeyError, lambda: arg_list['missing']) - arg_list = arg_list(another=None, args=(ArgumentList(custom='2'),)) - self.assertTrue('custom' in arg_list) - self.assertFalse('missing' in arg_list) - self.assertIsInstance(arg_list['custom'], CustomArgument) - self.assertIsInstance(arg_list['another'], CustomArgument) - self.assertEqual(arg_list.to_iptables(), "--another --custom 1 --custom 2") - repr(Rule(j='DROP').arguments) - tables = default_tables() - chain = UserChain('test_chain', comment='A user chain') - repr(chain) - repr(Rule()) - Accept = Rule(j='ACCEPT') - chain.append(Rule(i='eth0', s='1.1.2.1', d__not='1.1.1.2', jump='DROP', comment='A Rule')) - def tables_set(): tables['filter'] = None - self.assertRaises(TypeError, tables_set) - print tables - def table_set(): tables['filter']['INPUT'] = None - self.assertRaises(TypeError, table_set) - print tables['filter'] - tables['filter'].append(chain) - tables['filter']['INPUT'].append(Jump(chain)) - tables['filter']['INPUT'].append(Jump('string_chain')) - tables['filter']['OUTPUT'].append(CustomRule('a random string')) - tables['filter']['OUTPUT'].append(CustomRule('a random string', comment='this is a custom rule with a comment')) - tables['filter']['OUTPUT'].append(CustomRule('a random string', comment='this is a custom rule with a comment')) - tables['filter']['OUTPUT'].append(CompositeRule([Accept(s='1.1.1.1'), Rule(j='DROP')])) - tables['mangle']['OUTPUT'].append(Mark(123)) - random = random_mark() - - Log = Rule(j='LOG') - simple_zone = Zone('A zone', 'eth0') - repr(simple_zone) - simple_location = Location('A location3', Zone('A zone', 'br0', physdev='eth0')) - repr(simple_location) - ip_set = IPSet('a_set') - repr(ip_set) - list_location = Location.from_ip_list('A location', None, '1.1.1.1,2.2.2.2') - range_location = Location.from_ip_list('A location2', simple_zone, '3.1.1.1-3.2.2.2') - tables['mangle']['OUTPUT'].append(Marked(random)) - tables['mangle']['INPUT'].append(InputRule(policy='DROP', - channels=[], - sources=itertools.chain(list_location, - range_location, - [simple_location, ip_set], - ), - log=True, - log_cls=Log, - )) - tcp = TCPChannel(sports='1', dports='2') - udp = UDPChannel(states='ESTABLISHED') - icmp1 = ICMPChannel(icmp_type='1') - icmp2 = ICMPChannel() - host_list = HostList(['1.1.1.1', '2.2.2.2']) - repr(host_list) - str(host_list) - host_range = HostRange('1.1.1.1-2.2.2.2') - repr(host_range) - str(host_range) - tables['mangle']['INPUT'].append(InputRule('ACCEPT', - channels=[tcp, udp, icmp1, icmp2], - )) - tables['mangle']['INPUT'].append(InputRule('REJECT')) - tables['mangle']['INPUT'].append(InputRule('NONE')) - tables['filter']['FORWARD'].append(ForwardingRule(policy='DROP', - sources=[], - destinations=[simple_location, ip_set], - )) - tables['filter']['FORWARD'].append(ForwardingRule(policy='ACCEPT', - sources=list_location, - destinations=[], - )) - tables['filter']['FORWARD'].append(ForwardingRule(policy='REJECT', - sources=range_location, - destinations=list_location, - )) - tables['filter']['FORWARD'].append(ForwardingRule(policy='REJECT', - sources=[], - destinations=range_location, - )) - tables['filter']['FORWARD'].append(ForwardingRule(policy='NONE', - sources=[], - destinations=[], - channels=[tcp, udp, icmp1, icmp2], - args=[host_list.as_input(), host_range.as_input()], - log=True, - log_cls=Log, - )) - self.assertRaises(ValueError, lambda: InputRule('BAD').to_iptables()) - self.assertRaises(ValueError, lambda: ForwardingRule(policy='BAD', sources=[], destinations=[]).to_iptables()) - self.assertRaisesRegexp(ValueError, "Argument must be of type.*", lambda: Rule(p=1).to_iptables()) - result = tables.to_iptables() - fixture_file = os.path.join(os.path.dirname(__file__), 'test.dat') - with open(fixture_file, 'w') as fixture: - fixture.write(result) - with open(fixture_file) as fixture: - try: - compare(fixture, StringIO(result)) - except ValueError, e: - self.fail(str(e)) +import itertools + +import six +from six.moves import zip_longest +import os.path +import unittest + +from io import StringIO + +from pyptables import default_tables, Rule, UserChain, Jump, CustomRule +from pyptables.rules import CompositeRule +from pyptables.rules.arguments import ArgumentList, CustomArgument +from pyptables.rules.marks import Mark, random_mark, Marked +from pyptables.rules.input import InputRule +from pyptables.rules.forwarding import ForwardingRule +from pyptables.rules.forwarding.hosts import HostList, HostRange +from pyptables.rules.forwarding.ipsets import IPSet +from pyptables.rules.forwarding.locations import Location +from pyptables.rules.forwarding.zones import Zone +from pyptables.rules.forwarding.channels import TCPChannel, UDPChannel, ICMPChannel + +def compare(a, b): + for line_no, lines in enumerate(zip_longest(a, b, fillvalue='')): + lines = [line.strip() for line in lines] + if all(line.startswith('#') for line in lines): + continue + for i, chars in enumerate(zip_longest(*lines)): + if chars[0] != chars[1]: + raise ValueError("line %s doesn't match:\n\t%s\n\t%s\n\t%s^" % (line_no, lines[0], lines[1], i*'_')) + return True + + +class MainTest(unittest.TestCase): + def test(self): + with six.assertRaisesRegex(self, ValueError, "Only 'not' is supported"): + Rule(s__invalid=None).to_iptables() + with six.assertRaisesRegex(self, ValueError, "This argument is not invertable"): + Rule(f__not=None).to_iptables() + with six.assertRaisesRegex(self, ValueError, "Only 'not' is supported"): + Rule(custom__invalid=None).to_iptables() + with six.assertRaisesRegex(self, ValueError, "badly formatted argument name"): + Rule(custom__not__invalid=None).to_iptables() + Rule(custom__not=None).to_iptables() + arg_list = ArgumentList(custom='1') + self.assertIsInstance(arg_list['custom'], CustomArgument) + with self.assertRaises(KeyError): + _ = arg_list['missing'] + arg_list = arg_list(another=None, args=(ArgumentList(custom='2'),)) + self.assertTrue('custom' in arg_list) + self.assertFalse('missing' in arg_list) + self.assertIsInstance(arg_list['custom'], CustomArgument) + self.assertIsInstance(arg_list['another'], CustomArgument) + self.assertEqual(arg_list.to_iptables(), "--another --custom 1 --custom 2") + repr(Rule(j='DROP').arguments) + tables = default_tables() + chain = UserChain('test_chain', comment='A user chain') + repr(chain) + repr(Rule()) + Accept = Rule(j='ACCEPT') + chain.append(Rule(i='eth0', s='1.1.2.1', d__not='1.1.1.2', jump='DROP', comment='A Rule')) + + with self.assertRaises(TypeError): + tables['filter'] = None + print(tables) + + with self.assertRaises(TypeError): + tables['filter']['INPUT'] = None + print(tables['filter']) + + tables['filter'].append(chain) + tables['filter']['INPUT'].append(Jump(chain)) + tables['filter']['INPUT'].append(Jump('string_chain')) + tables['filter']['OUTPUT'].append(CustomRule('a random string')) + tables['filter']['OUTPUT'].append(CustomRule('a random string', comment='this is a custom rule with a comment')) + tables['filter']['OUTPUT'].append(CustomRule('a random string', comment='this is a custom rule with a comment')) + tables['filter']['OUTPUT'].append(CompositeRule([Accept(s='1.1.1.1'), Rule(j='DROP')])) + tables['mangle']['OUTPUT'].append(Mark(123)) + random = random_mark() + + Log = Rule(j='LOG') + simple_zone = Zone('A zone', 'eth0') + repr(simple_zone) + simple_location = Location('A location3', Zone('A zone', 'br0', physdev='eth0')) + repr(simple_location) + ip_set = IPSet('a_set') + repr(ip_set) + list_location = Location.from_ip_list('A location', None, '1.1.1.1,2.2.2.2') + range_location = Location.from_ip_list('A location2', simple_zone, '3.1.1.1-3.2.2.2') + tables['mangle']['OUTPUT'].append(Marked(random)) + tables['mangle']['INPUT'].append(InputRule(policy='DROP', + channels=[], + sources=itertools.chain(list_location, + range_location, + [simple_location, ip_set], + ), + log=True, + log_cls=Log, + )) + tcp = TCPChannel(sports='1', dports='2') + udp = UDPChannel(states='ESTABLISHED') + icmp1 = ICMPChannel(icmp_type='1') + icmp2 = ICMPChannel() + host_list = HostList(['1.1.1.1', '2.2.2.2']) + repr(host_list) + str(host_list) + host_range = HostRange('1.1.1.1-2.2.2.2') + repr(host_range) + str(host_range) + tables['mangle']['INPUT'].append(InputRule('ACCEPT', + channels=[tcp, udp, icmp1, icmp2], + )) + tables['mangle']['INPUT'].append(InputRule('REJECT')) + tables['mangle']['INPUT'].append(InputRule('NONE')) + tables['filter']['FORWARD'].append(ForwardingRule(policy='DROP', + sources=[], + destinations=[simple_location, ip_set], + )) + tables['filter']['FORWARD'].append(ForwardingRule(policy='ACCEPT', + sources=list_location, + destinations=[], + )) + tables['filter']['FORWARD'].append(ForwardingRule(policy='REJECT', + sources=range_location, + destinations=list_location, + )) + tables['filter']['FORWARD'].append(ForwardingRule(policy='REJECT', + sources=[], + destinations=range_location, + )) + tables['filter']['FORWARD'].append(ForwardingRule(policy='NONE', + sources=[], + destinations=[], + channels=[tcp, udp, icmp1, icmp2], + args=[host_list.as_input(), host_range.as_input()], + log=True, + log_cls=Log, + )) + with self.assertRaises(ValueError): + InputRule('BAD').to_iptables() + with self.assertRaises(ValueError): + ForwardingRule(policy='BAD', sources=[], destinations=[]).to_iptables() + with six.assertRaisesRegex(self, ValueError, "Argument must be of type.*"): + Rule(p=1).to_iptables() + result = tables.to_iptables() + fixture_file = os.path.join(os.path.dirname(__file__), 'test.dat') + with open(fixture_file, 'w') as fixture: + fixture.write(result) + with open(fixture_file) as fixture: + try: + compare(fixture, StringIO(six.u(result))) + except ValueError as e: # pragma: nocover + self.fail(str(e)) diff --git a/setup.py b/setup.py index ead0884..5b67f48 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,9 @@ license='LICENSE.txt', description='Python package for generating Linux iptables configurations.', long_description=open('README.rst').read(), + requires=[ + 'six' + ], install_requires=[ "nose", "coverage",