Skip to content

Commit

Permalink
nftables: merge pull request #1017 from inemajo/nftables_sets
Browse files Browse the repository at this point in the history
Draft: Nftables sets

Bug-Url: #1017
Bug-Url: #1013
  • Loading branch information
svinota authored Oct 27, 2022
2 parents 642f0a4 + 4a107fa commit dc05a0e
Show file tree
Hide file tree
Showing 3 changed files with 504 additions and 5 deletions.
53 changes: 53 additions & 0 deletions examples/nftables_sets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import time

from pyroute2.netlink.nfnetlink.nftsocket import NFPROTO_IPV4
from pyroute2.nftables.main import NFTables
from pyroute2.nftables.main import NFTSetElem


def test_ipv4_addr_set():
with NFTables(nfgen_family=NFPROTO_IPV4) as nft:
nft.table("add", name="filter")
my_set = nft.sets("add", table="filter", name="test0", key_type="ipv4_addr",
comment="my test set", timeout=0)

# With str
nft.set_elems(
"add",
table="filter",
set="test0",
elements={"10.2.3.4", "10.4.3.2"},
)

# With NFTSet & NFTSetElem classes
nft.set_elems(
"add",
set=my_set,
elements={NFTSetElem(value="9.9.9.9", timeout=1000)},
)

try:
assert {e.value for e in nft.set_elems("get", table="filter", set="test0")} == {
"10.2.3.4",
"10.4.3.2",
"9.9.9.9",
}
assert nft.sets("get", table="filter", name="test0").comment == b"my test set"

time.sleep(1.2)
# timeout for elem 9.9.9.9 (1000ms)
assert {e.value for e in nft.set_elems("get", table="filter", set="test0")} == {
"10.2.3.4",
"10.4.3.2",
}
finally:
nft.sets("del", table="filter", name="test0")
nft.table("del", name="filter")


def main():
test_ipv4_addr_set()


if __name__ == "__main__":
main()
174 changes: 169 additions & 5 deletions pyroute2/netlink/nfnetlink/nftsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
See also: pyroute2.nftables
"""

import struct
import threading

from pyroute2.netlink import (
Expand All @@ -16,6 +17,8 @@
NLM_F_REPLACE,
NLM_F_REQUEST,
nla,
nla_base_string,
nlmsg_atoms,
)
from pyroute2.netlink.nfnetlink import NFNL_SUBSYS_NFTABLES, nfgen_msg
from pyroute2.netlink.nlsocket import NetlinkSocket
Expand Down Expand Up @@ -46,6 +49,120 @@
NFT_MSG_GETFLOWTABLE = 23
NFT_MSG_DELFLOWTABLE = 24

# from nftables/include/datatype.h
DATA_TYPE_INVALID = 0
DATA_TYPE_VERDICT = 1
DATA_TYPE_NFPROTO = 2
DATA_TYPE_BITMASK = 3
DATA_TYPE_INTEGER = 4
DATA_TYPE_STRING = 5
DATA_TYPE_LLADDR = 6
DATA_TYPE_IPADDR = 7
DATA_TYPE_IP6ADDR = 8
DATA_TYPE_ETHERADDR = 9
DATA_TYPE_ETHERTYPE = 10
DATA_TYPE_ARPOP = 11
DATA_TYPE_INET_PROTOCOL = 12
DATA_TYPE_INET_SERVICE = 13
DATA_TYPE_ICMP_TYPE = 14
DATA_TYPE_TCP_FLAG = 15
DATA_TYPE_DCCP_PKTTYPE = 16
DATA_TYPE_MH_TYPE = 17
DATA_TYPE_TIME = 18
DATA_TYPE_MARK = 19
DATA_TYPE_IFINDEX = 20
DATA_TYPE_ARPHRD = 21
DATA_TYPE_REALM = 22
DATA_TYPE_CLASSID = 23
DATA_TYPE_UID = 24
DATA_TYPE_GID = 25
DATA_TYPE_CT_STATE = 26
DATA_TYPE_CT_DIR = 27
DATA_TYPE_CT_STATUS = 28
DATA_TYPE_ICMP6_TYPE = 29
DATA_TYPE_CT_LABEL = 30
DATA_TYPE_PKTTYPE = 31
DATA_TYPE_ICMP_CODE = 32
DATA_TYPE_ICMPV6_CODE = 33
DATA_TYPE_ICMPX_CODE = 34
DATA_TYPE_DEVGROUP = 35
DATA_TYPE_DSCP = 36
DATA_TYPE_ECN = 37
DATA_TYPE_FIB_ADDR = 38
DATA_TYPE_BOOLEAN = 39
DATA_TYPE_CT_EVENTBIT = 40
DATA_TYPE_IFNAME = 41
DATA_TYPE_IGMP_TYPE = 42
DATA_TYPE_TIME_DATE = 43
DATA_TYPE_TIME_HOUR = 44
DATA_TYPE_TIME_DAY = 45
DATA_TYPE_CGROUPV2 = 46

# from include/uapi/linux/netfilter.h
NFPROTO_INET = 1
NFPROTO_IPV4 = 2
NFPROTO_ARP = 3
NFPROTO_NETDEV = 5
NFPROTO_BRIDGE = 7
NFPROTO_IPV6 = 10


class nftnl_udata(nla_base_string):
# TLV structures:
# nftnl_udata
# <-------- HEADER --------> <------ PAYLOAD ------>
# +------------+-------------+- - - - - - - - - - - -+
# | type | len | value |
# | (1 byte) | (1 byte) | |
# +--------------------------+- - - - - - - - - - - -+
# <-- sizeof(nftnl_udata) -> <-- nftnl_udata->len -->
__slots__ = ()

@classmethod
def udata_decode(cls, data):
offset = 0
result = []
while offset + 2 < len(data):
utype = data[offset]
ulen = data[offset + 1]
offset += 2
if offset + ulen > len(data):
return None # bad decode
try:
type_name = cls.udata_types[utype]
except IndexError:
return None # bad decode

value = data[offset : offset + ulen]
if type_name.endswith("_COMMENT") and value[-1] == 0:
value = value[:-1] # remove \x00 c *str
result.append((type_name, value))
offset += ulen
return result

@classmethod
def udata_encode(cls, values):
value = b""
for type_name, udata in values:
if isinstance(udata, str):
udata = udata.encode()
if type_name.endswith("_COMMENT") and udata[-1] != 0:
udata = udata + b"\x00"
utype = cls.udata_types.index(type_name)
value += struct.pack("BB", utype, len(udata)) + udata
return value

def decode(self):
nla_base_string.decode(self)
value = self.udata_decode(self['value'])
if value is not None:
self.value = value

def encode(self):
if not isinstance(self.value, (bytes, str)):
self['value'] = self.udata_encode(self.value)
nla_base_string.encode(self)


class nft_map_uint8(nla):
ops = {}
Expand Down Expand Up @@ -74,6 +191,14 @@ def decode(self):
o for i, o in enumerate(self.ops) if self['value'] & 1 << i
)

def encode(self):
value = 0
for i, name in enumerate(self.ops):
if name in self.value:
value |= 1 << i
self["value"] = value
nla.encode(self)


class nft_flags_be16(nla):
fields = [('value', '>H')]
Expand Down Expand Up @@ -849,15 +974,28 @@ class nft_set_msg(nfgen_msg, nft_contains_expr):
('NFTA_SET_POLICY', 'set_policy'),
('NFTA_SET_DESC', 'set_desc'),
('NFTA_SET_ID', 'be32'),
('NFTA_SET_TIMEOUT', 'be32'),
('NFTA_SET_TIMEOUT', 'be64'),
('NFTA_SET_GC_INTERVAL', 'be32'),
('NFTA_SET_USERDATA', 'hex'),
('NFTA_SET_USERDATA', 'set_udata'),
('NFTA_SET_PAD', 'hex'),
('NFTA_SET_OBJ_TYPE', 'be32'),
('NFTA_SET_HANDLE', 'be64'),
('NFTA_SET_EXPR', '*nft_expr'),
('NFTA_SET_EXPR', 'nft_expr'),
('NFTA_SET_EXPRESSIONS', '*nft_expr'),
)

class set_udata(nftnl_udata):
udata_types = (
"NFTNL_UDATA_SET_KEYBYTEORDER",
"NFTNL_UDATA_SET_DATABYTEORDER",
"NFTNL_UDATA_SET_MERGE_ELEMENTS",
"NFTNL_UDATA_SET_KEY_TYPEOF",
"NFTNL_UDATA_SET_DATA_TYPEOF",
"NFTNL_UDATA_SET_EXPR",
"NFTNL_UDATA_SET_DATA_INTERVAL",
"NFTNL_UDATA_SET_COMMENT",
)

class set_flags(nft_flags_be32):
ops = (
'NFT_SET_ANONYMOUS',
Expand Down Expand Up @@ -924,13 +1062,20 @@ class set_elem(nla, nft_contains_expr):
('NFTA_SET_ELEM_FLAGS', 'set_elem_flags'),
('NFTA_SET_ELEM_TIMEOUT', 'be64'),
('NFTA_SET_ELEM_EXPIRATION', 'be64'),
('NFTA_SET_ELEM_USERDATA', 'binary'),
('NFTA_SET_ELEM_EXPR', '*nft_expr'),
('NFTA_SET_ELEM_USERDATA', 'set_elem_udata'),
('NFTA_SET_ELEM_EXPR', 'nft_expr'),
('NFTA_SET_ELEM_PAD', 'hex'),
('NFTA_SET_ELEM_OBJREF', 'asciiz'),
('NFTA_SET_ELEM_KEY_END', 'data_attributes'),
('NFTA_SET_ELEM_EXPRESSIONS', '*nft_expr'),
)

class set_elem_udata(nftnl_udata):
udata_types = (
"NFTNL_UDATA_SET_ELEM_COMMENT",
"NFTNL_UDATA_SET_ELEM_FLAGS",
)

class set_elem_flags(nft_flags_be32):
ops = {1: 'NFT_SET_ELEM_INTERVAL_END'}

Expand Down Expand Up @@ -1155,3 +1300,22 @@ def _command(self, msg_class, commands, cmd, kwarg):
# do not return anything.
else:
return self.request_get(msg, msg['header']['type'], flags)[0]


# call nft describe "data_type" for more informations
DATA_TYPE_NAME_TO_INFO = {
"verdict": (DATA_TYPE_VERDICT, 4, nft_data.nfta_data.verdict.verdict_code),
"nf_proto": (DATA_TYPE_NFPROTO, 1, nlmsg_atoms.uint8),
"bitmask": (DATA_TYPE_BITMASK, 4, nlmsg_atoms.uint32),
"integer": (DATA_TYPE_INTEGER, 4, nlmsg_atoms.int32),
"string": (DATA_TYPE_STRING, 0, nlmsg_atoms.asciiz),
"lladdr": (DATA_TYPE_LLADDR, 0, nlmsg_atoms.lladdr),
"ipv4_addr": (DATA_TYPE_IPADDR, 4, nlmsg_atoms.ip4addr),
"ipv6_addr": (DATA_TYPE_IP6ADDR, 16, nlmsg_atoms.ip6addr),
"ether_addr": (DATA_TYPE_ETHERADDR, 6, nlmsg_atoms.l2addr),
"ether_type": (DATA_TYPE_ETHERADDR, 2, nlmsg_atoms.uint16),
"inet_proto": (DATA_TYPE_INET_PROTOCOL, 1, nlmsg_atoms.uint8),
}
DATA_TYPE_ID_TO_NAME = {
value[0]: key for key, value in DATA_TYPE_NAME_TO_INFO.items()
}
Loading

0 comments on commit dc05a0e

Please sign in to comment.