Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nftables: Add more operations and raise kernel errors #902

Merged
merged 4 commits into from
Apr 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pyroute2.core/pr2modules/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,28 @@ def alloc(self):
else:
raise KeyError('no free address available')

def alloc_multi(self, count):
with self.lock:
addresses = []
raised = False
try:
for _ in range(count):
addr = self.alloc()
try:
addresses.append(addr)
except:
# In case of a MemoryError during appending,
# the finally block would not free the address.
self.free(addr)
return addresses
except:
raised = True
raise
finally:
if raised:
for addr in addresses:
self.free(addr)

def locate(self, addr):
if self.reverse:
addr = self.maxaddr - addr
Expand Down
44 changes: 40 additions & 4 deletions pyroute2.core/pr2modules/netlink/nfnetlink/nftsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
"""

import threading
from pr2modules.netlink import NLM_F_REQUEST
from pr2modules.netlink import (
NLM_F_REQUEST,
NLM_F_ACK,
NLM_F_CREATE,
NLM_F_APPEND,
NLM_F_EXCL,
NLM_F_REPLACE,
)
from pr2modules.netlink import NLM_F_DUMP
from pr2modules.netlink import NETLINK_NETFILTER
from pr2modules.netlink import nla
Expand Down Expand Up @@ -1096,7 +1103,17 @@ def request_put(self, msg, msg_type, msg_flags=NLM_F_REQUEST):
if one_shot:
self.commit()

def _command(self, msg_class, commands, cmd, kwarg, flags=NLM_F_REQUEST):
def _command(self, msg_class, commands, cmd, kwarg):
flags = kwarg.pop('flags', NLM_F_ACK)
cmd_name = cmd
cmd_flags = {
'add': NLM_F_CREATE | NLM_F_APPEND,
'create': NLM_F_CREATE | NLM_F_APPEND | NLM_F_EXCL,
'insert': NLM_F_CREATE,
'replace': NLM_F_REPLACE,
}
flags |= cmd_flags.get(cmd, 0)
flags |= NLM_F_REQUEST
cmd = commands[cmd]
msg = msg_class()
msg['attrs'] = []
Expand All @@ -1115,5 +1132,24 @@ def _command(self, msg_class, commands, cmd, kwarg, flags=NLM_F_REQUEST):
for key, value in kwarg.items():
nla = msg_class.name2nla(key)
msg['attrs'].append([nla, value])
#
return self.request_put(msg, msg_type=cmd, msg_flags=flags)
msg['header']['type'] = (NFNL_SUBSYS_NFTABLES << 8) | cmd
msg['header']['flags'] = flags | NLM_F_REQUEST
msg['nfgen_family'] = self._nfgen_family

if cmd_name != 'get':
trans_start = nfgen_msg()
trans_start['res_id'] = NFNL_SUBSYS_NFTABLES
trans_start['header']['type'] = 0x10
trans_start['header']['flags'] = NLM_F_REQUEST

trans_end = nfgen_msg()
trans_end['res_id'] = NFNL_SUBSYS_NFTABLES
trans_end['header']['type'] = 0x11
trans_end['header']['flags'] = NLM_F_REQUEST

messages = [trans_start, msg, trans_end]
self.nlm_request_batch(messages, noraise=(flags & NLM_F_ACK) == 0)
# Only throw an error when the request fails. For now,
# do not return anything.
else:
return self.request_get(msg, msg['header']['type'], flags)[0]
81 changes: 77 additions & 4 deletions pyroute2.core/pr2modules/netlink/nlsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
from pr2modules.config import AF_NETLINK
from pr2modules.common import AddrPool
from pr2modules.common import DEFAULT_RCVBUF
from pr2modules.netlink import nlmsg
from pr2modules.netlink import nlmsg, NLM_F_ACK
from pr2modules.netlink import nlmsgerr
from pr2modules.netlink import mtypes
from pr2modules.netlink import NLMSG_ERROR
Expand Down Expand Up @@ -412,6 +412,12 @@ def get(*argv, **kwarg):
self.nlm_request = nlm_request
self.get = get

def nlm_request_batch(*argv, **kwarg):
return tuple(self._genlm_request_batch(*argv, **kwarg))

self._genlm_request_batch = self.nlm_request_batch
self.nlm_request_batch = nlm_request_batch

# Set defaults
self.post_init()

Expand Down Expand Up @@ -597,6 +603,21 @@ def async_recv(self):
else:
return

def _send_batch(self, msgs, addr=(0, 0)):
with self.backlog_lock:
for msg in msgs:
self.backlog[msg['header']['sequence_number']] = []
# We have locked the message locks in the caller already.
data = bytearray()
for msg in msgs:
if not isinstance(msg, nlmsg):
msg_class = self.marshal.msg_map[msg['header']['type']]
msg = msg_class(msg)
msg.reset()
msg.encode()
data += msg.data
self._sock.sendto(data, addr)

def put(
self,
msg,
Expand Down Expand Up @@ -655,7 +676,12 @@ def sendto_gate(self, msg, addr):
raise NotImplementedError()

def get(
self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None, callback=None
self,
bufsize=DEFAULT_RCVBUF,
msg_seq=0,
terminate=None,
callback=None,
noraise=False,
):
'''
Get parsed messages list. If `msg_seq` is given, return
Expand All @@ -670,6 +696,9 @@ def get(
the network data
- 0: bufsize will be calculated from SO_RCVBUF sockopt
- int >= 0: just a bufsize

If `noraise` is true, error messages will be treated as any
other message.
'''
ctime = time.time()

Expand Down Expand Up @@ -728,7 +757,10 @@ def get(
self.backlog[msg_seq].remove(msg)

# If there is an error, raise exception
if msg['header']['error'] is not None:
if (
msg['header']['error'] is not None
and not noraise
):
# reschedule all the remaining messages,
# including errors and acks, into a
# separate deque
Expand Down Expand Up @@ -893,6 +925,47 @@ def get(
if backlog_acquired:
self.backlog_lock.release()

def nlm_request_batch(self, msgs, noraise=False):
"""
This function is for messages which are expected to have side effects.
Do not blindly retry in case of errors as this might duplicate them.
"""
expected_responses = []
acquired = 0
seqs = self.addr_pool.alloc_multi(len(msgs))
try:
for seq in seqs:
self.lock[seq].acquire()
acquired += 1
for seq, msg in zip(seqs, msgs):
msg['header']['sequence_number'] = seq
if 'pid' not in msg['header']:
msg['header']['pid'] = self.epid or os.getpid()
if (msg['header']['flags'] & NLM_F_ACK) or (
msg['header']['flags'] & NLM_F_DUMP
):
expected_responses.append(seq)
self._send_batch(msgs)

for seq in expected_responses:
for msg in self.get(msg_seq=seq, noraise=noraise):
if msg['header']['flags'] & NLM_F_DUMP_INTR:
# Leave error handling to the caller
raise NetlinkDumpInterrupted()
yield msg
finally:
# Release locks in reverse order.
for seq in seqs[acquired - 1 :: -1]:
self.lock[seq].release()

with self.backlog_lock:
for seq in seqs:
# Clear the backlog. We may have raised an error
# causing the backlog to not be consumed entirely.
if seq in self.backlog:
del self.backlog[seq]
self.addr_pool.free(seq, ban=0xFF)

def nlm_request(
self,
msg,
Expand Down Expand Up @@ -924,7 +997,7 @@ def nlm_request(
yield msg
break
except NetlinkError as e:
if e.code != 16:
if e.code != errno.EBUSY:
raise
if retry_count >= 30:
raise
Expand Down
30 changes: 21 additions & 9 deletions pyroute2.nftables/pr2modules/nftables/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ def table(self, cmd, **kwarg):

nft.table('add', name='test0')
'''
commands = {'add': NFT_MSG_NEWTABLE, 'del': NFT_MSG_DELTABLE}
# fix default kwargs
if 'flags' not in kwarg:
kwarg['flags'] = 0
commands = {
'add': NFT_MSG_NEWTABLE,
'create': NFT_MSG_NEWTABLE,
'del': NFT_MSG_DELTABLE,
'get': NFT_MSG_GETTABLE,
}
return self._command(nft_table_msg, commands, cmd, kwarg)

def chain(self, cmd, **kwarg):
Expand All @@ -69,7 +71,12 @@ def chain(self, cmd, **kwarg):
type='filter',
policy=0)
'''
commands = {'add': NFT_MSG_NEWCHAIN, 'del': NFT_MSG_DELCHAIN}
commands = {
'add': NFT_MSG_NEWCHAIN,
'create': NFT_MSG_NEWCHAIN,
'del': NFT_MSG_DELCHAIN,
'get': NFT_MSG_GETCHAIN,
}
# TODO: What about 'ingress' (netdev family)?
hooks = {
'prerouting': 0,
Expand Down Expand Up @@ -103,13 +110,18 @@ def rule(self, cmd, **kwarg):
expressions=(ipv4addr(src='192.168.0.0/24'),
verdict(code=1)))
'''
# TODO: more operations
commands = {'add': NFT_MSG_NEWRULE, 'del': NFT_MSG_DELRULE}
commands = {
'add': NFT_MSG_NEWRULE,
'create': NFT_MSG_NEWRULE,
'insert': NFT_MSG_NEWRULE,
'replace': NFT_MSG_NEWRULE,
'del': NFT_MSG_DELRULE,
'get': NFT_MSG_GETRULE,
}

if 'expressions' in kwarg:
expressions = []
for exp in kwarg['expressions']:
expressions.extend(exp)
kwarg['expressions'] = expressions
# FIXME: flags!!!
return self._command(nft_rule_msg, commands, cmd, kwarg, flags=3585)
return self._command(nft_rule_msg, commands, cmd, kwarg)