Skip to content

Commit

Permalink
🛞 conn, service: improve mmsg api
Browse files Browse the repository at this point in the history
Return early to indicate error, like the underlying syscall does.
It no longer drops send errors on the floor, and allows more accurate stats.
  • Loading branch information
database64128 committed Oct 18, 2024
1 parent c8a3daa commit 60f50b0
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 149 deletions.
9 changes: 5 additions & 4 deletions conn/conn_mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ import (
"net"
)

// ListenUDPRawConn is like [ListenUDP] but wraps the [*net.UDPConn] in a [rawUDPConn] for batch I/O.
func (lc *ListenConfig) ListenUDPRawConn(ctx context.Context, network, address string) (c rawUDPConn, info SocketInfo, err error) {
// ListenUDPMmsgConn is like [ListenUDP] but wraps the [*net.UDPConn] in a [MmsgConn] for
// reading and writing multiple messages using the recvmmsg(2) and sendmmsg(2) system calls.
func (lc *ListenConfig) ListenUDPMmsgConn(ctx context.Context, network, address string) (c MmsgConn, info SocketInfo, err error) {
info.MaxUDPGSOSegments = 1
nlc := net.ListenConfig{
Control: lc.fns.controlFunc(&info),
}
pc, err := nlc.ListenPacket(ctx, network, address)
if err != nil {
return rawUDPConn{}, info, err
return MmsgConn{}, info, err
}
c, err = NewRawUDPConn(pc.(*net.UDPConn))
c, err = NewMmsgConn(pc.(*net.UDPConn))
return c, info, err
}
119 changes: 70 additions & 49 deletions conn/mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,100 +6,104 @@ import (
"net"
"os"
"syscall"
"unsafe"

"golang.org/x/sys/unix"
)

type rawUDPConn struct {
// MmsgConn wraps a [*net.UDPConn] and provides methods for reading and writing
// multiple messages using the recvmmsg(2) and sendmmsg(2) system calls.
type MmsgConn struct {
*net.UDPConn
rawConn syscall.RawConn
}

// NewRawUDPConn wraps a [net.UDPConn] in a [rawUDPConn] for batch I/O.
func NewRawUDPConn(udpConn *net.UDPConn) (rawUDPConn, error) {
// NewMmsgConn returns a new [MmsgConn] for udpConn.
func NewMmsgConn(udpConn *net.UDPConn) (MmsgConn, error) {
rawConn, err := udpConn.SyscallConn()
if err != nil {
return rawUDPConn{}, err
return MmsgConn{}, err
}

return rawUDPConn{
return MmsgConn{
UDPConn: udpConn,
rawConn: rawConn,
}, nil
}

// MmsgRConn wraps a [net.UDPConn] and provides the [ReadMsgs] method
// for reading multiple messages in a single recvmmsg(2) system call.
// MmsgRConn provides read access to the [MmsgConn].
//
// [MmsgRConn] is not safe for concurrent use.
// Use the [RConn] method to create a new [MmsgRConn] instance for each goroutine.
// MmsgRConn is not safe for concurrent use.
// Always create a new MmsgRConn for each goroutine.
type MmsgRConn struct {
rawUDPConn
MmsgConn
rawReadFunc func(fd uintptr) (done bool)
readMsgvec []Mmsghdr
readFlags int
readN int
readErr error
}

// MmsgWConn wraps a [net.UDPConn] and provides the [WriteMsgs] method
// for writing multiple messages in a single sendmmsg(2) system call.
// MmsgWConn provides write access to the [MmsgConn].
//
// [MmsgWConn] is not safe for concurrent use.
// Use the [WConn] method to create a new [MmsgWConn] instance for each goroutine.
// MmsgWConn is not safe for concurrent use.
// Always create a new MmsgWConn for each goroutine.
type MmsgWConn struct {
rawUDPConn
MmsgConn
rawWriteFunc func(fd uintptr) (done bool)
writeMsgvec []Mmsghdr
writeFlags int
writeErrno syscall.Errno
writeN int
writeErr error
}

// RConn returns a new [MmsgRConn] instance for batch reading.
func (c rawUDPConn) RConn() *MmsgRConn {
mmsgRConn := MmsgRConn{
rawUDPConn: c,
// NewRConn returns the connection wrapped in a new [*MmsgRConn] for batch reading.
func (c MmsgConn) NewRConn() *MmsgRConn {
rc := MmsgRConn{
MmsgConn: c,
}

mmsgRConn.rawReadFunc = func(fd uintptr) (done bool) {
rc.rawReadFunc = func(fd uintptr) (done bool) {
var errno syscall.Errno
mmsgRConn.readN, errno = recvmmsg(int(fd), mmsgRConn.readMsgvec, mmsgRConn.readFlags)
rc.readN, errno = recvmmsg(int(fd), rc.readMsgvec, rc.readFlags)
switch errno {
case 0:
rc.readErr = nil
case syscall.EAGAIN:
return false
default:
mmsgRConn.readErr = os.NewSyscallError("recvmmsg", errno)
rc.readErr = os.NewSyscallError("recvmmsg", errno)
}
return true
}

return &mmsgRConn
return &rc
}

// WConn returns a new [MmsgWConn] instance for batch writing.
func (c rawUDPConn) WConn() *MmsgWConn {
mmsgWConn := MmsgWConn{
rawUDPConn: c,
// NewWConn returns the connection wrapped in a new [*MmsgWConn] for batch writing.
func (c MmsgConn) NewWConn() *MmsgWConn {
wc := MmsgWConn{
MmsgConn: c,
}

mmsgWConn.rawWriteFunc = func(fd uintptr) (done bool) {
wc.rawWriteFunc = func(fd uintptr) (done bool) {
wc.writeN = 0
for {
n, errno := sendmmsg(int(fd), mmsgWConn.writeMsgvec, mmsgWConn.writeFlags)
n, errno := sendmmsg(int(fd), wc.writeMsgvec, wc.writeFlags)
switch errno {
case 0:
case syscall.EAGAIN:
return false
default:
mmsgWConn.writeErrno = errno
mmsgWConn.writeMsgvec = mmsgWConn.writeMsgvec[1:]
if len(mmsgWConn.writeMsgvec) == 0 {
return true
}
continue
wc.writeErr = os.NewSyscallError("sendmmsg", errno)
return true
}

mmsgWConn.writeMsgvec = mmsgWConn.writeMsgvec[n:]
wc.writeMsgvec = wc.writeMsgvec[n:]
wc.writeN += n

if len(mmsgWConn.writeMsgvec) == 0 {
if len(wc.writeMsgvec) == 0 {
wc.writeErr = nil
return true
}

Expand All @@ -120,32 +124,49 @@ func (c rawUDPConn) WConn() *MmsgWConn {
}
}

return &mmsgWConn
return &wc
}

// ReadMsgs reads as many messages as possible into the given msgvec
// ReadMsgs reads as many messages as possible into msgvec
// and returns the number of messages read or an error.
func (c *MmsgRConn) ReadMsgs(msgvec []Mmsghdr, flags int) (int, error) {
c.readMsgvec = msgvec
c.readFlags = flags
c.readN = 0
c.readErr = nil
if err := c.rawConn.Read(c.rawReadFunc); err != nil {
return 0, err
}
return c.readN, c.readErr
}

// WriteMsgs writes all messages in the given msgvec and returns the last encountered error.
func (c *MmsgWConn) WriteMsgs(msgvec []Mmsghdr, flags int) error {
// WriteMsgs writes the messages in msgvec to the connection.
// It returns the number of messages written as n, and if n < len(msgvec),
// the error from writing the n-th message.
func (c *MmsgWConn) WriteMsgs(msgvec []Mmsghdr, flags int) (int, error) {
c.writeMsgvec = msgvec
c.writeFlags = flags
c.writeErrno = 0
if err := c.rawConn.Write(c.rawWriteFunc); err != nil {
return err
return 0, err
}
if c.writeErrno != 0 {
return os.NewSyscallError("sendmmsg", c.writeErrno)
return c.writeN, c.writeErr
}

type Mmsghdr struct {
Msghdr unix.Msghdr
Msglen uint32
}

func mmsgSyscall(trap uintptr, fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) {
r0, _, e1 := unix.Syscall6(trap, uintptr(fd), uintptr(unsafe.Pointer(unsafe.SliceData(msgvec))), uintptr(len(msgvec)), uintptr(flags), 0, 0)
if e1 != 0 {
return 0, e1
}
return nil
return int(r0), 0
}

func recvmmsg(fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) {
return mmsgSyscall(SYS_RECVMMSG, fd, msgvec, flags)
}

func sendmmsg(fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) {
return mmsgSyscall(unix.SYS_SENDMMSG, fd, msgvec, flags)
}
File renamed without changes.
File renamed without changes.
26 changes: 0 additions & 26 deletions conn/syscall_mmsg.go

This file was deleted.

10 changes: 0 additions & 10 deletions conn/ztypes_mmsg.go

This file was deleted.

72 changes: 42 additions & 30 deletions service/client_mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (c *client) setStartFunc(batchMode string) {
}

func (c *client) startMmsg(ctx context.Context) error {
wgConn, wgConnInfo, err := c.wgConnListenConfig.ListenUDPRawConn(ctx, c.wgListenNetwork, c.wgListenAddress)
wgConn, wgConnInfo, err := c.wgConnListenConfig.ListenUDPMmsgConn(ctx, c.wgListenNetwork, c.wgListenAddress)
if err != nil {
return err
}
Expand All @@ -63,7 +63,7 @@ func (c *client) startMmsg(ctx context.Context) error {
c.mwg.Add(1)

go func() {
c.recvFromWgConnRecvmmsg(ctx, wgConn.RConn(), wgConnInfo)
c.recvFromWgConnRecvmmsg(ctx, wgConn.NewRConn(), wgConnInfo)
c.mwg.Done()
}()

Expand Down Expand Up @@ -270,7 +270,7 @@ func (c *client) recvFromWgConnRecvmmsg(ctx context.Context, wgConn *conn.MmsgRC
return
}

proxyConn, proxyConnInfo, err := c.proxyConnListenConfig.ListenUDPRawConn(ctx, c.proxyConnListenNetwork, c.proxyConnListenAddress)
proxyConn, proxyConnInfo, err := c.proxyConnListenConfig.ListenUDPMmsgConn(ctx, c.proxyConnListenNetwork, c.proxyConnListenAddress)
if err != nil {
c.logger.Warn("Failed to create UDP socket for new session",
zap.String("client", c.name),
Expand Down Expand Up @@ -334,7 +334,7 @@ func (c *client) recvFromWgConnRecvmmsg(ctx context.Context, wgConn *conn.MmsgRC
c.relayWgToProxySendmmsg(clientNatUplinkMmsg{
clientAddrPort: clientAddrPort,
proxyAddrPort: proxyAddrPort,
proxyConn: proxyConn.WConn(),
proxyConn: proxyConn.NewWConn(),
proxyConnInfo: proxyConnInfo,
proxyConnSendCh: proxyConnSendCh,
handler: handler,
Expand All @@ -348,8 +348,8 @@ func (c *client) recvFromWgConnRecvmmsg(ctx context.Context, wgConn *conn.MmsgRC
clientPktinfop: clientPktinfop,
clientPktinfo: &natEntry.clientPktinfo,
proxyAddrPort: proxyAddrPort,
proxyConn: proxyConn.RConn(),
wgConn: wgConn.WConn(),
proxyConn: proxyConn.NewRConn(),
wgConn: wgConn.NewWConn(),
wgConnInfo: wgConnInfo,
handler: handler,
maxProxyPacketSize: maxProxyPacketSize,
Expand Down Expand Up @@ -573,14 +573,24 @@ main:

sendQueuedPackets = sendQueuedPackets[:0]

if err := uplink.proxyConn.WriteMsgs(msgvec, 0); err != nil {
c.logger.Warn("Failed to write swgpPacket to proxyConn",
zap.String("client", c.name),
zap.String("listenAddress", c.wgListenAddress),
zap.Stringer("clientAddress", uplink.clientAddrPort),
zap.Stringer("proxyAddress", uplink.proxyAddrPort),
zap.Error(err),
)
for start := 0; start < len(msgvec); {
n, err := uplink.proxyConn.WriteMsgs(msgvec[start:], 0)
start += n
if err != nil {
c.logger.Warn("Failed to write swgpPacket to proxyConn",
zap.String("client", c.name),
zap.String("listenAddress", c.wgListenAddress),
zap.Stringer("clientAddress", uplink.clientAddrPort),
zap.Stringer("proxyAddress", uplink.proxyAddrPort),
zap.Uint("swgpPacketLength", uint(iovec[start].Len)),
zap.Error(err),
)
start++
}

sendmmsgCount++
msgsSent += uint64(n)
burstBatchSize = max(burstBatchSize, n)
}

if isHandshake {
Expand All @@ -595,10 +605,6 @@ main:
}
}

sendmmsgCount++
msgsSent += uint64(len(msgvec))
burstBatchSize = max(burstBatchSize, len(msgvec))

packetBuf = packetBuf[:0]
cmsgBuf = cmsgBuf[:0]
for _, buf := range bufvec {
Expand Down Expand Up @@ -898,19 +904,25 @@ func (c *client) relayProxyToWgSendmmsg(downlink clientNatDownlinkMmsg) {

queuedPackets = queuedPackets[:0]

if err = downlink.wgConn.WriteMsgs(smsgvec, 0); err != nil {
c.logger.Warn("Failed to write wgPacket to wgConn",
zap.String("client", c.name),
zap.String("listenAddress", c.wgListenAddress),
zap.Stringer("clientAddress", downlink.clientAddrPort),
zap.Stringer("proxyAddress", downlink.proxyAddrPort),
zap.Error(err),
)
}
for start := 0; start < len(smsgvec); {
n, err := downlink.wgConn.WriteMsgs(smsgvec[start:], 0)
start += n
if err != nil {
c.logger.Warn("Failed to write wgPacket to wgConn",
zap.String("client", c.name),
zap.String("listenAddress", c.wgListenAddress),
zap.Stringer("clientAddress", downlink.clientAddrPort),
zap.Stringer("proxyAddress", downlink.proxyAddrPort),
zap.Uint("wgPacketLength", uint(siovec[start].Len)),
zap.Error(err),
)
start++
}

sendmmsgCount++
msgsSent += uint64(len(smsgvec))
burstSendBatchSize = max(burstSendBatchSize, len(smsgvec))
sendmmsgCount++
msgsSent += uint64(n)
burstSendBatchSize = max(burstSendBatchSize, n)
}

sendPacketBuf = sendPacketBuf[:0]
sendCmsgBuf = sendCmsgBuf[:0]
Expand Down
Loading

0 comments on commit 60f50b0

Please sign in to comment.