diff --git a/examples/nftables_sets.py b/examples/nftables_sets.py new file mode 100644 index 000000000..670a147aa --- /dev/null +++ b/examples/nftables_sets.py @@ -0,0 +1,53 @@ +import time + +from pyroute2.netlink.nfnetlink.nftsocket import NFPROTO_IPV4 +from pyroute2.nftables.main import NFTables +from pyroute2.nftables.main import NFTSetElem + + +def test_ipv4_addr_set(): + with NFTables(nfgen_family=NFPROTO_IPV4) as nft: + nft.table("add", name="filter") + my_set = nft.sets("add", table="filter", name="test0", key_type="ipv4_addr", + comment="my test set", timeout=0) + + # With str + nft.set_elems( + "add", + table="filter", + set="test0", + elements={"10.2.3.4", "10.4.3.2"}, + ) + + # With NFTSet & NFTSetElem classes + nft.set_elems( + "add", + set=my_set, + elements={NFTSetElem(value="9.9.9.9", timeout=1000)}, + ) + + try: + assert {e.value for e in nft.set_elems("get", table="filter", set="test0")} == { + "10.2.3.4", + "10.4.3.2", + "9.9.9.9", + } + assert nft.sets("get", table="filter", name="test0").comment == b"my test set" + + time.sleep(1.2) + # timeout for elem 9.9.9.9 (1000ms) + assert {e.value for e in nft.set_elems("get", table="filter", set="test0")} == { + "10.2.3.4", + "10.4.3.2", + } + finally: + nft.sets("del", table="filter", name="test0") + nft.table("del", name="filter") + + +def main(): + test_ipv4_addr_set() + + +if __name__ == "__main__": + main() diff --git a/pyroute2/netlink/nfnetlink/nftsocket.py b/pyroute2/netlink/nfnetlink/nftsocket.py index 48ae9b845..0646be6b9 100644 --- a/pyroute2/netlink/nfnetlink/nftsocket.py +++ b/pyroute2/netlink/nfnetlink/nftsocket.py @@ -4,6 +4,7 @@ See also: pyroute2.nftables """ +import struct import threading from pyroute2.netlink import ( @@ -16,6 +17,8 @@ NLM_F_REPLACE, NLM_F_REQUEST, nla, + nla_base_string, + nlmsg_atoms, ) from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES, nfgen_msg from pyroute2.netlink.nlsocket import NetlinkSocket @@ -46,6 +49,120 @@ NFT_MSG_GETFLOWTABLE = 23 NFT_MSG_DELFLOWTABLE = 24 +# from nftables/include/datatype.h +DATA_TYPE_INVALID = 0 +DATA_TYPE_VERDICT = 1 +DATA_TYPE_NFPROTO = 2 +DATA_TYPE_BITMASK = 3 +DATA_TYPE_INTEGER = 4 +DATA_TYPE_STRING = 5 +DATA_TYPE_LLADDR = 6 +DATA_TYPE_IPADDR = 7 +DATA_TYPE_IP6ADDR = 8 +DATA_TYPE_ETHERADDR = 9 +DATA_TYPE_ETHERTYPE = 10 +DATA_TYPE_ARPOP = 11 +DATA_TYPE_INET_PROTOCOL = 12 +DATA_TYPE_INET_SERVICE = 13 +DATA_TYPE_ICMP_TYPE = 14 +DATA_TYPE_TCP_FLAG = 15 +DATA_TYPE_DCCP_PKTTYPE = 16 +DATA_TYPE_MH_TYPE = 17 +DATA_TYPE_TIME = 18 +DATA_TYPE_MARK = 19 +DATA_TYPE_IFINDEX = 20 +DATA_TYPE_ARPHRD = 21 +DATA_TYPE_REALM = 22 +DATA_TYPE_CLASSID = 23 +DATA_TYPE_UID = 24 +DATA_TYPE_GID = 25 +DATA_TYPE_CT_STATE = 26 +DATA_TYPE_CT_DIR = 27 +DATA_TYPE_CT_STATUS = 28 +DATA_TYPE_ICMP6_TYPE = 29 +DATA_TYPE_CT_LABEL = 30 +DATA_TYPE_PKTTYPE = 31 +DATA_TYPE_ICMP_CODE = 32 +DATA_TYPE_ICMPV6_CODE = 33 +DATA_TYPE_ICMPX_CODE = 34 +DATA_TYPE_DEVGROUP = 35 +DATA_TYPE_DSCP = 36 +DATA_TYPE_ECN = 37 +DATA_TYPE_FIB_ADDR = 38 +DATA_TYPE_BOOLEAN = 39 +DATA_TYPE_CT_EVENTBIT = 40 +DATA_TYPE_IFNAME = 41 +DATA_TYPE_IGMP_TYPE = 42 +DATA_TYPE_TIME_DATE = 43 +DATA_TYPE_TIME_HOUR = 44 +DATA_TYPE_TIME_DAY = 45 +DATA_TYPE_CGROUPV2 = 46 + +# from include/uapi/linux/netfilter.h +NFPROTO_INET = 1 +NFPROTO_IPV4 = 2 +NFPROTO_ARP = 3 +NFPROTO_NETDEV = 5 +NFPROTO_BRIDGE = 7 +NFPROTO_IPV6 = 10 + + +class nftnl_udata(nla_base_string): + # TLV structures: + # nftnl_udata + # <-------- HEADER --------> <------ PAYLOAD ------> + # +------------+-------------+- - - - - - - - - - - -+ + # | type | len | value | + # | (1 byte) | (1 byte) | | + # +--------------------------+- - - - - - - - - - - -+ + # <-- sizeof(nftnl_udata) -> <-- nftnl_udata->len --> + __slots__ = () + + @classmethod + def udata_decode(cls, data): + offset = 0 + result = [] + while offset + 2 < len(data): + utype = data[offset] + ulen = data[offset + 1] + offset += 2 + if offset + ulen > len(data): + return None # bad decode + try: + type_name = cls.udata_types[utype] + except IndexError: + return None # bad decode + + value = data[offset : offset + ulen] + if type_name.endswith("_COMMENT") and value[-1] == 0: + value = value[:-1] # remove \x00 c *str + result.append((type_name, value)) + offset += ulen + return result + + @classmethod + def udata_encode(cls, values): + value = b"" + for type_name, udata in values: + if isinstance(udata, str): + udata = udata.encode() + if type_name.endswith("_COMMENT") and udata[-1] != 0: + udata = udata + b"\x00" + utype = cls.udata_types.index(type_name) + value += struct.pack("BB", utype, len(udata)) + udata + return value + + def decode(self): + nla_base_string.decode(self) + value = self.udata_decode(self['value']) + if value is not None: + self.value = value + + def encode(self): + if not isinstance(self.value, (bytes, str)): + self['value'] = self.udata_encode(self.value) + nla_base_string.encode(self) + class nft_map_uint8(nla): ops = {} @@ -74,6 +191,14 @@ def decode(self): o for i, o in enumerate(self.ops) if self['value'] & 1 << i ) + def encode(self): + value = 0 + for i, name in enumerate(self.ops): + if name in self.value: + value |= 1 << i + self["value"] = value + nla.encode(self) + class nft_flags_be16(nla): fields = [('value', '>H')] @@ -849,15 +974,28 @@ class nft_set_msg(nfgen_msg, nft_contains_expr): ('NFTA_SET_POLICY', 'set_policy'), ('NFTA_SET_DESC', 'set_desc'), ('NFTA_SET_ID', 'be32'), - ('NFTA_SET_TIMEOUT', 'be32'), + ('NFTA_SET_TIMEOUT', 'be64'), ('NFTA_SET_GC_INTERVAL', 'be32'), - ('NFTA_SET_USERDATA', 'hex'), + ('NFTA_SET_USERDATA', 'set_udata'), ('NFTA_SET_PAD', 'hex'), ('NFTA_SET_OBJ_TYPE', 'be32'), ('NFTA_SET_HANDLE', 'be64'), - ('NFTA_SET_EXPR', '*nft_expr'), + ('NFTA_SET_EXPR', 'nft_expr'), + ('NFTA_SET_EXPRESSIONS', '*nft_expr'), ) + class set_udata(nftnl_udata): + udata_types = ( + "NFTNL_UDATA_SET_KEYBYTEORDER", + "NFTNL_UDATA_SET_DATABYTEORDER", + "NFTNL_UDATA_SET_MERGE_ELEMENTS", + "NFTNL_UDATA_SET_KEY_TYPEOF", + "NFTNL_UDATA_SET_DATA_TYPEOF", + "NFTNL_UDATA_SET_EXPR", + "NFTNL_UDATA_SET_DATA_INTERVAL", + "NFTNL_UDATA_SET_COMMENT", + ) + class set_flags(nft_flags_be32): ops = ( 'NFT_SET_ANONYMOUS', @@ -924,13 +1062,20 @@ class set_elem(nla, nft_contains_expr): ('NFTA_SET_ELEM_FLAGS', 'set_elem_flags'), ('NFTA_SET_ELEM_TIMEOUT', 'be64'), ('NFTA_SET_ELEM_EXPIRATION', 'be64'), - ('NFTA_SET_ELEM_USERDATA', 'binary'), - ('NFTA_SET_ELEM_EXPR', '*nft_expr'), + ('NFTA_SET_ELEM_USERDATA', 'set_elem_udata'), + ('NFTA_SET_ELEM_EXPR', 'nft_expr'), ('NFTA_SET_ELEM_PAD', 'hex'), ('NFTA_SET_ELEM_OBJREF', 'asciiz'), ('NFTA_SET_ELEM_KEY_END', 'data_attributes'), + ('NFTA_SET_ELEM_EXPRESSIONS', '*nft_expr'), ) + class set_elem_udata(nftnl_udata): + udata_types = ( + "NFTNL_UDATA_SET_ELEM_COMMENT", + "NFTNL_UDATA_SET_ELEM_FLAGS", + ) + class set_elem_flags(nft_flags_be32): ops = {1: 'NFT_SET_ELEM_INTERVAL_END'} @@ -1155,3 +1300,22 @@ def _command(self, msg_class, commands, cmd, kwarg): # do not return anything. else: return self.request_get(msg, msg['header']['type'], flags)[0] + + +# call nft describe "data_type" for more informations +DATA_TYPE_NAME_TO_INFO = { + "verdict": (DATA_TYPE_VERDICT, 4, nft_data.nfta_data.verdict.verdict_code), + "nf_proto": (DATA_TYPE_NFPROTO, 1, nlmsg_atoms.uint8), + "bitmask": (DATA_TYPE_BITMASK, 4, nlmsg_atoms.uint32), + "integer": (DATA_TYPE_INTEGER, 4, nlmsg_atoms.int32), + "string": (DATA_TYPE_STRING, 0, nlmsg_atoms.asciiz), + "lladdr": (DATA_TYPE_LLADDR, 0, nlmsg_atoms.lladdr), + "ipv4_addr": (DATA_TYPE_IPADDR, 4, nlmsg_atoms.ip4addr), + "ipv6_addr": (DATA_TYPE_IP6ADDR, 16, nlmsg_atoms.ip6addr), + "ether_addr": (DATA_TYPE_ETHERADDR, 6, nlmsg_atoms.l2addr), + "ether_type": (DATA_TYPE_ETHERADDR, 2, nlmsg_atoms.uint16), + "inet_proto": (DATA_TYPE_INET_PROTOCOL, 1, nlmsg_atoms.uint8), +} +DATA_TYPE_ID_TO_NAME = { + value[0]: key for key, value in DATA_TYPE_NAME_TO_INFO.items() +} diff --git a/pyroute2/nftables/main.py b/pyroute2/nftables/main.py index 6cee2d015..1e0e530df 100644 --- a/pyroute2/nftables/main.py +++ b/pyroute2/nftables/main.py @@ -2,23 +2,220 @@ ''' from pyroute2.netlink.nfnetlink import nfgen_msg from pyroute2.netlink.nfnetlink.nftsocket import ( + DATA_TYPE_ID_TO_NAME, + DATA_TYPE_NAME_TO_INFO, NFT_MSG_DELCHAIN, NFT_MSG_DELRULE, + NFT_MSG_DELSET, + NFT_MSG_DELSETELEM, NFT_MSG_DELTABLE, NFT_MSG_GETCHAIN, NFT_MSG_GETRULE, NFT_MSG_GETSET, + NFT_MSG_GETSETELEM, NFT_MSG_GETTABLE, NFT_MSG_NEWCHAIN, NFT_MSG_NEWRULE, + NFT_MSG_NEWSET, + NFT_MSG_NEWSETELEM, NFT_MSG_NEWTABLE, NFTSocket, nft_chain_msg, nft_rule_msg, + nft_set_elem_list_msg, + nft_set_msg, nft_table_msg, ) +class NFTSet: + __slots__ = ('table', 'name', 'key_type', 'timeout', 'counter', 'comment') + + def __init__(self, table, name, **kwargs): + self.table = table + self.name = name + + for attrname in self.__slots__: + if attrname in kwargs: + setattr(self, attrname, kwargs[attrname]) + elif attrname not in ("table", "name"): + setattr(self, attrname, None) + + def as_netlink(self): + attrs = {"NFTA_SET_TABLE": self.table, "NFTA_SET_NAME": self.name} + set_flags = set() + + if self.key_type is not None: + key_type, key_len, _ = DATA_TYPE_NAME_TO_INFO.get(self.key_type) + attrs["NFTA_SET_KEY_TYPE"] = key_type + attrs["NFTA_SET_KEY_LEN"] = key_len + + if self.timeout is not None: + set_flags.add("NFT_SET_TIMEOUT") + attrs["NFTA_SET_TIMEOUT"] = self.timeout + + if self.counter is True: + attrs["NFTA_SET_EXPR"] = {'attrs': [('NFTA_EXPR_NAME', 'counter')]} + + if self.comment is not None: + attrs["NFTA_SET_USERDATA"] = [ + ("NFTNL_UDATA_SET_COMMENT", self.comment) + ] + + # ID is used for bulk create, but not implemented + attrs['NFTA_SET_ID'] = 1 + attrs["NFTA_SET_FLAGS"] = set_flags + return attrs + + @classmethod + def from_netlink(cls, msg): + data_type_name = DATA_TYPE_ID_TO_NAME.get( + msg.get_attr("NFTA_SET_KEY_TYPE"), + msg.get_attr("NFTA_SET_KEY_TYPE"), # fallback to raw value + ) + + counter = False + expr = msg.get_attr('NFTA_SET_EXPR') + if expr: + expr = expr.get_attrs('NFTA_EXPR_NAME') + if expr and "counter" in expr: + counter = True + + comment = None + udata = msg.get_attr("NFTA_SET_USERDATA") + if udata: + for key, value in udata: + if key == "NFTNL_UDATA_SET_COMMENT": + comment = value + break + + return cls( + table=msg.get_attr('NFTA_SET_TABLE'), + name=msg.get_attr('NFTA_SET_NAME'), + key_type=data_type_name, + timeout=msg.get_attr('NFTA_SET_TIMEOUT'), + counter=counter, + comment=comment, + ) + + @classmethod + def from_dict(cls, d): + return cls( + **{ + name: value + for name, value in d.items() + if name in cls.__slots__ + } + ) + + def as_dict(self): + return {name: getattr(self, name) for name in self.__slots__} + + def __repr__(self): + return str(self.as_dict()) + + +class NFTSetElem: + __slots__ = ( + 'value', + 'timeout', + 'expiration', + 'counter_bytes', + 'counter_packets', + 'comment', + ) + + def __init__(self, value, **kwargs): + self.value = value + for name in self.__slots__: + if name in kwargs: + setattr(self, name, kwargs[name]) + elif name != "value": + setattr(self, name, None) + + @classmethod + def from_netlink(cls, msg, modifier): + value = msg.get_attr('NFTA_SET_ELEM_KEY').get_attr("NFTA_DATA_VALUE") + if modifier is not None: + # Need to find a better way + modifier.data = value + modifier.length = 4 + len(modifier.data) + modifier.decode() + value = modifier.value + + kwarg = { + "expiration": msg.get_attr('NFTA_SET_ELEM_EXPIRATION'), + "timeout": msg.get_attr('NFTA_SET_ELEM_TIMEOUT'), + } + + elem_expr = msg.get_attr('NFTA_SET_ELEM_EXPR') + if elem_expr: + if elem_expr.get_attr('NFTA_EXPR_NAME') == "counter": + elem_expr = elem_expr.get_attr("NFTA_EXPR_DATA") + kwarg.update( + { + "counter_bytes": elem_expr.get_attr( + "NFTA_COUNTER_BYTES" + ), + "counter_packets": elem_expr.get_attr( + "NFTA_COUNTER_PACKETS" + ), + } + ) + + udata = msg.get_attr('NFTA_SET_ELEM_USERDATA') + if udata: + for type_name, data in udata: + if type_name == "NFTNL_UDATA_SET_ELEM_COMMENT": + kwarg["comment"] = data + + return cls(value=value, **kwarg) + + def as_netlink(self, modifier): + if modifier is not None: + modifier.value = self.value + modifier.encode() + value = modifier["value"] + else: + value = self.value + + attrs = [ + ['NFTA_SET_ELEM_KEY', {'attrs': [('NFTA_DATA_VALUE', value)]}] + ] + + if self.timeout is not None: + attrs.append(['NFTA_SET_ELEM_TIMEOUT', self.timeout]) + + if self.expiration is not None: + attrs.append(['NFTA_SET_ELEM_EXPIRATION', self.expiration]) + + if self.comment is not None: + attrs.append( + [ + 'NFTA_SET_ELEM_USERDATA', + [("NFTNL_UDATA_SET_ELEM_COMMENT", self.comment)], + ] + ) + + return {'attrs': attrs} + + @classmethod + def from_dict(cls, d): + return cls( + **{ + name: value + for name, value in d.items() + if name in cls.__slots__ + } + ) + + def as_dict(self): + return {name: getattr(self, name) for name in self.__slots__} + + def __repr__(self): + return str(self.as_dict()) + + class NFTables(NFTSocket): # TODO: documentation @@ -125,3 +322,88 @@ def rule(self, cmd, **kwarg): expressions.extend(exp) kwarg['expressions'] = expressions return self._command(nft_rule_msg, commands, cmd, kwarg) + + def sets(self, cmd, **kwarg): + ''' + Example:: + nft.sets("add", table="filter", name="test0", key_type="ipv4_addr", + timeout=10000, counter=True, + comment="my comment max 252 bytes") + nft.sets("get", table="filter", name="test0") + nft.sets("del", table="filter", name="test0") + my_set = nft.sets("add", set=NFTSet(table="filter", name="test1", + key_type="ipv4_addr") + nft.sets("del", set=my_set) + ''' + commands = { + 'add': NFT_MSG_NEWSET, + 'get': NFT_MSG_GETSET, + 'del': NFT_MSG_DELSET, + } + + if "set" in kwarg: + nft_set = kwarg.pop("set") + else: + nft_set = NFTSet(**kwarg) + kwarg = nft_set.as_netlink() + msg = self._command(nft_set_msg, commands, cmd, kwarg) + if cmd == "get": + return NFTSet.from_netlink(msg) + return nft_set + + def set_elems(self, cmd, **kwarg): + ''' + Example:: + nft.set_elems("add", table="filter", set="test0", + elements={"10.2.3.4", "10.4.3.2"}) + nft.set_elems("add", set=NFTSet(table="filter", name="test0"), + elements=[{"value": "10.2.3.4", "timeout": 10000}]) + nft.set_elems("add", table="filter", set="test0", + elements=[NFTSetElem(value="10.2.3.4", + timeout=10000, + comment="hello world")]) + nft.set_elems("get", table="filter", set="test0") + nft.set_elems("del", table="filter", set="test0", + elements=["10.2.3.4"]) + ''' + commands = { + 'add': NFT_MSG_NEWSETELEM, + 'get': NFT_MSG_GETSETELEM, + 'del': NFT_MSG_DELSETELEM, + } + if isinstance(kwarg["set"], NFTSet): + nft_set = kwarg.pop("set") + kwarg["table"] = nft_set.table + kwarg["set"] = nft_set.name + else: + nft_set = self.sets("get", table=kwarg["table"], name=kwarg["set"]) + + found = DATA_TYPE_NAME_TO_INFO.get(nft_set.key_type) + if found: + _, _, modifier = found + modifier = modifier() + modifier.header = None + else: + modifier = None + + if cmd == "get": + msg = nft_set_elem_list_msg() + msg['attrs'] = [ + ["NFTA_SET_ELEM_LIST_TABLE", kwarg["table"]], + ["NFTA_SET_ELEM_LIST_SET", kwarg["set"]], + ] + msg = self.request_get(msg, NFT_MSG_GETSETELEM)[0] + elements = set() + for elem in msg.get_attr('NFTA_SET_ELEM_LIST_ELEMENTS'): + elements.add(NFTSetElem.from_netlink(elem, modifier)) + return elements + + elements = [] + for elem in kwarg.pop("elements"): + if isinstance(elem, dict): + elem = NFTSetElem.from_dict(elem) + elif not isinstance(elem, NFTSetElem): + elem = NFTSetElem(value=elem) + elements.append(elem.as_netlink(modifier)) + kwarg["elements"] = elements + return self._command(nft_set_elem_list_msg, commands, cmd, kwarg)