Skip to content

Commit

Permalink
nftables: merge pull request #902 from blechschmidt/master
Browse files Browse the repository at this point in the history
nftables: Add more operations and raise kernel errors

Bug-Url: #902
  • Loading branch information
svinota authored Apr 24, 2022
2 parents 8d6498d + 40edd81 commit 39c3e04
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 17 deletions.
22 changes: 22 additions & 0 deletions pyroute2.core/pr2modules/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,28 @@ def alloc(self):
else:
raise KeyError('no free address available')

def alloc_multi(self, count):
with self.lock:
addresses = []
raised = False
try:
for _ in range(count):
addr = self.alloc()
try:
addresses.append(addr)
except:
# In case of a MemoryError during appending,
# the finally block would not free the address.
self.free(addr)
return addresses
except:
raised = True
raise
finally:
if raised:
for addr in addresses:
self.free(addr)

def locate(self, addr):
if self.reverse:
addr = self.maxaddr - addr
Expand Down
44 changes: 40 additions & 4 deletions pyroute2.core/pr2modules/netlink/nfnetlink/nftsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
"""

import threading
from pr2modules.netlink import NLM_F_REQUEST
from pr2modules.netlink import (
NLM_F_REQUEST,
NLM_F_ACK,
NLM_F_CREATE,
NLM_F_APPEND,
NLM_F_EXCL,
NLM_F_REPLACE,
)
from pr2modules.netlink import NLM_F_DUMP
from pr2modules.netlink import NETLINK_NETFILTER
from pr2modules.netlink import nla
Expand Down Expand Up @@ -1096,7 +1103,17 @@ def request_put(self, msg, msg_type, msg_flags=NLM_F_REQUEST):
if one_shot:
self.commit()

def _command(self, msg_class, commands, cmd, kwarg, flags=NLM_F_REQUEST):
def _command(self, msg_class, commands, cmd, kwarg):
flags = kwarg.pop('flags', NLM_F_ACK)
cmd_name = cmd
cmd_flags = {
'add': NLM_F_CREATE | NLM_F_APPEND,
'create': NLM_F_CREATE | NLM_F_APPEND | NLM_F_EXCL,
'insert': NLM_F_CREATE,
'replace': NLM_F_REPLACE,
}
flags |= cmd_flags.get(cmd, 0)
flags |= NLM_F_REQUEST
cmd = commands[cmd]
msg = msg_class()
msg['attrs'] = []
Expand All @@ -1115,5 +1132,24 @@ def _command(self, msg_class, commands, cmd, kwarg, flags=NLM_F_REQUEST):
for key, value in kwarg.items():
nla = msg_class.name2nla(key)
msg['attrs'].append([nla, value])
#
return self.request_put(msg, msg_type=cmd, msg_flags=flags)
msg['header']['type'] = (NFNL_SUBSYS_NFTABLES << 8) | cmd
msg['header']['flags'] = flags | NLM_F_REQUEST
msg['nfgen_family'] = self._nfgen_family

if cmd_name != 'get':
trans_start = nfgen_msg()
trans_start['res_id'] = NFNL_SUBSYS_NFTABLES
trans_start['header']['type'] = 0x10
trans_start['header']['flags'] = NLM_F_REQUEST

trans_end = nfgen_msg()
trans_end['res_id'] = NFNL_SUBSYS_NFTABLES
trans_end['header']['type'] = 0x11
trans_end['header']['flags'] = NLM_F_REQUEST

messages = [trans_start, msg, trans_end]
self.nlm_request_batch(messages, noraise=(flags & NLM_F_ACK) == 0)
# Only throw an error when the request fails. For now,
# do not return anything.
else:
return self.request_get(msg, msg['header']['type'], flags)[0]
81 changes: 77 additions & 4 deletions pyroute2.core/pr2modules/netlink/nlsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
from pr2modules.config import AF_NETLINK
from pr2modules.common import AddrPool
from pr2modules.common import DEFAULT_RCVBUF
from pr2modules.netlink import nlmsg
from pr2modules.netlink import nlmsg, NLM_F_ACK
from pr2modules.netlink import nlmsgerr
from pr2modules.netlink import mtypes
from pr2modules.netlink import NLMSG_ERROR
Expand Down Expand Up @@ -412,6 +412,12 @@ def get(*argv, **kwarg):
self.nlm_request = nlm_request
self.get = get

def nlm_request_batch(*argv, **kwarg):
return tuple(self._genlm_request_batch(*argv, **kwarg))

self._genlm_request_batch = self.nlm_request_batch
self.nlm_request_batch = nlm_request_batch

# Set defaults
self.post_init()

Expand Down Expand Up @@ -597,6 +603,21 @@ def async_recv(self):
else:
return

def _send_batch(self, msgs, addr=(0, 0)):
with self.backlog_lock:
for msg in msgs:
self.backlog[msg['header']['sequence_number']] = []
# We have locked the message locks in the caller already.
data = bytearray()
for msg in msgs:
if not isinstance(msg, nlmsg):
msg_class = self.marshal.msg_map[msg['header']['type']]
msg = msg_class(msg)
msg.reset()
msg.encode()
data += msg.data
self._sock.sendto(data, addr)

def put(
self,
msg,
Expand Down Expand Up @@ -655,7 +676,12 @@ def sendto_gate(self, msg, addr):
raise NotImplementedError()

def get(
self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None, callback=None
self,
bufsize=DEFAULT_RCVBUF,
msg_seq=0,
terminate=None,
callback=None,
noraise=False,
):
'''
Get parsed messages list. If `msg_seq` is given, return
Expand All @@ -670,6 +696,9 @@ def get(
the network data
- 0: bufsize will be calculated from SO_RCVBUF sockopt
- int >= 0: just a bufsize
If `noraise` is true, error messages will be treated as any
other message.
'''
ctime = time.time()

Expand Down Expand Up @@ -728,7 +757,10 @@ def get(
self.backlog[msg_seq].remove(msg)

# If there is an error, raise exception
if msg['header']['error'] is not None:
if (
msg['header']['error'] is not None
and not noraise
):
# reschedule all the remaining messages,
# including errors and acks, into a
# separate deque
Expand Down Expand Up @@ -893,6 +925,47 @@ def get(
if backlog_acquired:
self.backlog_lock.release()

def nlm_request_batch(self, msgs, noraise=False):
"""
This function is for messages which are expected to have side effects.
Do not blindly retry in case of errors as this might duplicate them.
"""
expected_responses = []
acquired = 0
seqs = self.addr_pool.alloc_multi(len(msgs))
try:
for seq in seqs:
self.lock[seq].acquire()
acquired += 1
for seq, msg in zip(seqs, msgs):
msg['header']['sequence_number'] = seq
if 'pid' not in msg['header']:
msg['header']['pid'] = self.epid or os.getpid()
if (msg['header']['flags'] & NLM_F_ACK) or (
msg['header']['flags'] & NLM_F_DUMP
):
expected_responses.append(seq)
self._send_batch(msgs)

for seq in expected_responses:
for msg in self.get(msg_seq=seq, noraise=noraise):
if msg['header']['flags'] & NLM_F_DUMP_INTR:
# Leave error handling to the caller
raise NetlinkDumpInterrupted()
yield msg
finally:
# Release locks in reverse order.
for seq in seqs[acquired - 1 :: -1]:
self.lock[seq].release()

with self.backlog_lock:
for seq in seqs:
# Clear the backlog. We may have raised an error
# causing the backlog to not be consumed entirely.
if seq in self.backlog:
del self.backlog[seq]
self.addr_pool.free(seq, ban=0xFF)

def nlm_request(
self,
msg,
Expand Down Expand Up @@ -924,7 +997,7 @@ def nlm_request(
yield msg
break
except NetlinkError as e:
if e.code != 16:
if e.code != errno.EBUSY:
raise
if retry_count >= 30:
raise
Expand Down
30 changes: 21 additions & 9 deletions pyroute2.nftables/pr2modules/nftables/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ def table(self, cmd, **kwarg):
nft.table('add', name='test0')
'''
commands = {'add': NFT_MSG_NEWTABLE, 'del': NFT_MSG_DELTABLE}
# fix default kwargs
if 'flags' not in kwarg:
kwarg['flags'] = 0
commands = {
'add': NFT_MSG_NEWTABLE,
'create': NFT_MSG_NEWTABLE,
'del': NFT_MSG_DELTABLE,
'get': NFT_MSG_GETTABLE,
}
return self._command(nft_table_msg, commands, cmd, kwarg)

def chain(self, cmd, **kwarg):
Expand All @@ -69,7 +71,12 @@ def chain(self, cmd, **kwarg):
type='filter',
policy=0)
'''
commands = {'add': NFT_MSG_NEWCHAIN, 'del': NFT_MSG_DELCHAIN}
commands = {
'add': NFT_MSG_NEWCHAIN,
'create': NFT_MSG_NEWCHAIN,
'del': NFT_MSG_DELCHAIN,
'get': NFT_MSG_GETCHAIN,
}
# TODO: What about 'ingress' (netdev family)?
hooks = {
'prerouting': 0,
Expand Down Expand Up @@ -103,13 +110,18 @@ def rule(self, cmd, **kwarg):
expressions=(ipv4addr(src='192.168.0.0/24'),
verdict(code=1)))
'''
# TODO: more operations
commands = {'add': NFT_MSG_NEWRULE, 'del': NFT_MSG_DELRULE}
commands = {
'add': NFT_MSG_NEWRULE,
'create': NFT_MSG_NEWRULE,
'insert': NFT_MSG_NEWRULE,
'replace': NFT_MSG_NEWRULE,
'del': NFT_MSG_DELRULE,
'get': NFT_MSG_GETRULE,
}

if 'expressions' in kwarg:
expressions = []
for exp in kwarg['expressions']:
expressions.extend(exp)
kwarg['expressions'] = expressions
# FIXME: flags!!!
return self._command(nft_rule_msg, commands, cmd, kwarg, flags=3585)
return self._command(nft_rule_msg, commands, cmd, kwarg)

0 comments on commit 39c3e04

Please sign in to comment.