Skip to content

Commit

Permalink
nlsocket: Implement improved batch handling
Browse files Browse the repository at this point in the history
Currently, the handling of batch messages is difficult as nlm_request
sends each message separately and listens for responses from a single
message sequence number only. Therefore, this commit introduces
`nlm_request_batch`, which can be supplied with a list of messages. The
function will then listen for expected responses for any of the supplied
messages and return/yield them to the caller.

In the future, this will allow for improved error handling. Since some
nftables write operations require batches, the current approach is to
accumulate messages in a batch and send them to the socket in a
fire-and-forget manner. This causes errors reported by the kernel to go
unnoticed.

This commit prepares improved error handling for nftables by making it
possible to read those errors. It is related to issue #892.
  • Loading branch information
blechschmidt committed Apr 23, 2022
1 parent 8d6498d commit 39741a4
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 4 deletions.
20 changes: 20 additions & 0 deletions pyroute2.core/pr2modules/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,26 @@ def alloc(self):
else:
raise KeyError('no free address available')

def alloc_multi(self, count):
    """
    Allocate `count` addresses in one all-or-nothing operation.

    The whole operation runs while holding `self.lock`. On success a list
    with exactly `count` addresses is returned. If anything fails along
    the way, every address allocated so far is freed again and the
    original exception is re-raised.

    :param count: number of addresses to allocate
    :return: list of allocated addresses (empty if `count` is 0)
    :raises: whatever `self.alloc()` raises, e.g. when the pool is
        exhausted
    """
    # NOTE(review): alloc() is called while self.lock is held; this
    # assumes the lock is re-entrant if alloc() takes it too -- confirm.
    with self.lock:
        addresses = []
        try:
            for _ in range(count):
                addr = self.alloc()
                try:
                    addresses.append(addr)
                except BaseException:
                    # `addr` is not tracked in `addresses` yet, so the
                    # rollback below would miss it. Free it here and
                    # re-raise: swallowing the error would silently
                    # return fewer addresses than requested.
                    self.free(addr)
                    raise
        except BaseException:
            # Roll back everything allocated so far, then propagate.
            # BaseException is caught on purpose so that the rollback
            # also happens on KeyboardInterrupt and friends.
            for addr in addresses:
                self.free(addr)
            raise
        return addresses

def locate(self, addr):
if self.reverse:
addr = self.maxaddr - addr
Expand Down
72 changes: 68 additions & 4 deletions pyroute2.core/pr2modules/netlink/nlsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
from pr2modules.config import AF_NETLINK
from pr2modules.common import AddrPool
from pr2modules.common import DEFAULT_RCVBUF
from pr2modules.netlink import nlmsg
from pr2modules.netlink import nlmsg, NLM_F_ACK
from pr2modules.netlink import nlmsgerr
from pr2modules.netlink import mtypes
from pr2modules.netlink import NLMSG_ERROR
Expand Down Expand Up @@ -412,6 +412,12 @@ def get(*argv, **kwarg):
self.nlm_request = nlm_request
self.get = get

def nlm_request_batch(*argv, **kwarg):
return tuple(self._genlm_request_batch(*argv, **kwarg))

self._genlm_request_batch = self.nlm_request_batch
self.nlm_request_batch = nlm_request_batch

# Set defaults
self.post_init()

Expand Down Expand Up @@ -597,6 +603,21 @@ def async_recv(self):
else:
return

def _send_batch(self, msgs, addr=(0, 0)):
    """
    Encode all messages and push them out with a single `sendto()` call.

    Before sending, an empty backlog list is registered for the sequence
    number of every message, under `self.backlog_lock`.

    :param msgs: iterable of messages; items that are not `nlmsg`
        instances are converted using the class looked up in
        `self.marshal.msg_map` by their header type
    :param addr: destination address passed to `sendto()`
    """
    with self.backlog_lock:
        for item in msgs:
            sequence_number = item['header']['sequence_number']
            self.backlog[sequence_number] = []

    # The per-sequence message locks are expected to be held by the
    # caller already, so they are not taken here.
    payload = bytearray()
    for item in msgs:
        encoded = item
        if not isinstance(item, nlmsg):
            factory = self.marshal.msg_map[item['header']['type']]
            encoded = factory(item)
        encoded.reset()
        encoded.encode()
        payload += encoded.data

    self._sock.sendto(payload, addr)

def put(
self,
msg,
Expand Down Expand Up @@ -655,7 +676,8 @@ def sendto_gate(self, msg, addr):
raise NotImplementedError()

def get(
self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None, callback=None
self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None, callback=None,
noraise=False
):
'''
Get parsed messages list. If `msg_seq` is given, return
Expand All @@ -670,6 +692,9 @@ def get(
the network data
- 0: bufsize will be calculated from SO_RCVBUF sockopt
- int >= 0: just a bufsize
If `noraise` is true, error messages will be treated as any
other message.
'''
ctime = time.time()

Expand Down Expand Up @@ -728,7 +753,7 @@ def get(
self.backlog[msg_seq].remove(msg)

# If there is an error, raise exception
if msg['header']['error'] is not None:
if msg['header']['error'] is not None and not noraise:
# reschedule all the remaining messages,
# including errors and acks, into a
# separate deque
Expand Down Expand Up @@ -893,6 +918,45 @@ def get(
if backlog_acquired:
self.backlog_lock.release()

def nlm_request_batch(self, msgs, noraise=False):
    """
    Send several messages as one batch and yield the responses.

    This function is for messages which are expected to have side effects.
    Do not blindly retry in case of errors as this might duplicate them.

    Every message gets a freshly allocated sequence number written into
    its header (the messages are modified in place), and a pid if the
    header has none. Responses are only awaited for messages that carry
    the NLM_F_ACK or NLM_F_DUMP flag.

    :param msgs: sized sequence of messages to send
    :param noraise: passed through to `get()`
    :raises NetlinkDumpInterrupted: if a response carries NLM_F_DUMP_INTR;
        error handling is left to the caller
    """
    expected_responses = []
    acquired = 0
    seqs = self.addr_pool.alloc_multi(len(msgs))
    try:
        for seq in seqs:
            self.lock[seq].acquire()
            acquired += 1
        for seq, msg in zip(seqs, msgs):
            msg['header']['sequence_number'] = seq
            if 'pid' not in msg['header']:
                msg['header']['pid'] = self.epid or os.getpid()
            if (msg['header']['flags'] & NLM_F_ACK) or (
                msg['header']['flags'] & NLM_F_DUMP
            ):
                expected_responses.append(seq)
        self._send_batch(msgs)

        for seq in expected_responses:
            for msg in self.get(msg_seq=seq, noraise=noraise):
                if msg['header']['flags'] & NLM_F_DUMP_INTR:
                    # Leave error handling to the caller
                    raise NetlinkDumpInterrupted()
                yield msg
    finally:
        # Release only the locks that were actually acquired, in reverse
        # order. A slice like seqs[acquired - 1::-1] must not be used
        # here: with acquired == 0 it becomes seqs[-1::-1], i.e. the
        # whole list, and would release locks that were never taken.
        for seq in reversed(seqs[:acquired]):
            self.lock[seq].release()

        with self.backlog_lock:
            for seq in seqs:
                # Clear the backlog. We may have raised an error
                # causing the backlog to not be consumed entirely.
                if seq in self.backlog:
                    del self.backlog[seq]
                self.addr_pool.free(seq, ban=0xFF)

def nlm_request(
self,
msg,
Expand Down Expand Up @@ -924,7 +988,7 @@ def nlm_request(
yield msg
break
except NetlinkError as e:
if e.code != 16:
if e.code != errno.EBUSY:
raise
if retry_count >= 30:
raise
Expand Down

0 comments on commit 39741a4

Please sign in to comment.