diff --git a/service/udp.go b/service/udp.go index 4830e302..57a0efb7 100644 --- a/service/udp.go +++ b/service/udp.go @@ -87,6 +87,7 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis type packetHandler struct { natTimeout time.Duration ciphers CipherList + nm *natmap m UDPMetrics targetIPValidator onet.TargetIPValidator } @@ -113,108 +114,94 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali func (h *packetHandler) Handle(clientConn net.PacketConn) { var running sync.WaitGroup - nm := newNATmap(h.natTimeout, h.m, &running) - defer nm.Close() - cipherBuf := make([]byte, serverUDPBufferSize) - textBuf := make([]byte, serverUDPBufferSize) + h.nm = newNATmap(h.natTimeout, h.m, &running) + defer h.nm.Close() for { - clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) - if errors.Is(err, net.ErrClosed) { - break + status := "OK" + keyID, clientInfo, clientProxyBytes, proxyTargetBytes, connErr := h.handleConnection(clientConn) + if connErr != nil { + if errors.Is(connErr.Cause, net.ErrClosed) { + break + } + logger.Debugf("UDP Error: %v: %v", connErr.Message, connErr.Cause) + status = connErr.Status } + h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes) + } +} - var clientInfo ipinfo.IPInfo - keyID := "" - var proxyTargetBytes int +func (h *packetHandler) authenticate(clientConn net.PacketConn) (*natconn, []byte, int, *onet.ConnectionError) { + cipherBuf := make([]byte, serverUDPBufferSize) + textBuf := make([]byte, serverUDPBufferSize) + clientProxyBytes, clientAddr, err := clientConn.ReadFrom(cipherBuf) + if err != nil { + return nil, nil, 0, onet.NewConnectionError("ERR_READ", "Failed to read from client", err) + } - connError := func() (connError *onet.ConnectionError) { - defer func() { - if r := recover(); r != nil { - logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r) - debug.PrintStack() - } - }() + if logger.IsEnabledFor(logging.DEBUG) { + defer logger.Debugf("UDP(%v): done", clientAddr) + logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes) + } - // Error from ReadFrom - if err != nil { - return onet.NewConnectionError("ERR_READ", "Failed to read from client", err) - } - if logger.IsEnabledFor(logging.DEBUG) { - defer logger.Debugf("UDP(%v): done", clientAddr) - logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes) - } + targetConn := h.nm.Get(clientAddr.String()) + remoteIP := clientAddr.(*net.UDPAddr).AddrPort().Addr() - cipherData := cipherBuf[:clientProxyBytes] - var payload []byte - var tgtUDPAddr *net.UDPAddr - targetConn := nm.Get(clientAddr.String()) - if targetConn == nil { - var locErr error - clientInfo, locErr = ipinfo.GetIPInfoFromAddr(h.m, clientAddr) - if locErr != nil { - logger.Warningf("Failed client info lookup: %v", locErr) - } - debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - - ip := clientAddr.(*net.UDPAddr).AddrPort().Addr() - var textData []byte - var cryptoKey *shadowsocks.EncryptionKey - unpackStart := time.Now() - textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers) - timeToCipher := time.Since(unpackStart) - h.m.AddUDPCipherSearch(err == nil, timeToCipher) - - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack initial packet", err) - } + unpackStart := time.Now() + textData, keyID, cryptoKey, keyErr := findAccessKeyUDP(remoteIP, textBuf, cipherBuf[:clientProxyBytes], h.ciphers) + timeToCipher := time.Since(unpackStart) + h.m.AddUDPCipherSearch(err == nil, timeToCipher) + if keyErr != nil { + return targetConn, nil, 0, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr) + } - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } + if targetConn != nil { + return targetConn, textData, clientProxyBytes, nil + } - udpConn, err := net.ListenPacket("udp", "") - if err != nil { - return onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) - } - targetConn = nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID) - } else { - clientInfo = targetConn.clientInfo + udpConn, err := net.ListenPacket("udp", "") + if err != nil { + return targetConn, textData, clientProxyBytes, onet.NewConnectionError("ERR_CREATE_SOCKET", "Failed to create UDP socket", err) + } - unpackStart := time.Now() - textData, err := shadowsocks.Unpack(nil, cipherData, targetConn.cryptoKey) - timeToCipher := time.Since(unpackStart) - h.m.AddUDPCipherSearch(err == nil, timeToCipher) + clientInfo, locErr := ipinfo.GetIPInfoFromAddr(h.m, clientAddr) + if locErr != nil { + logger.Warningf("Failed client info lookup: %v", locErr) + } + debugUDPAddr(clientAddr, "Got info \"%#v\"", clientInfo) - if err != nil { - return onet.NewConnectionError("ERR_CIPHER", "Failed to unpack data from client", err) - } + targetConn = h.nm.Add(clientAddr, clientConn, cryptoKey, udpConn, clientInfo, keyID) + return targetConn, textData, clientProxyBytes, nil +} - // The key ID is known with confidence once decryption succeeds. - keyID = targetConn.keyID +func (h *packetHandler) handleConnection(clientConn net.PacketConn) (string, ipinfo.IPInfo, int, int, *onet.ConnectionError) { + defer func() { + if r := recover(); r != nil { + logger.Errorf("Panic in UDP loop: %v. Continuing to listen.", r) + debug.PrintStack() + } + }() - var onetErr *onet.ConnectionError - if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil { - return onetErr - } - } + targetConn, textData, clientProxyBytes, authErr := h.authenticate(clientConn) + if authErr != nil { + var clientInfo ipinfo.IPInfo + if targetConn != nil { + clientInfo = targetConn.clientInfo + } + return "", clientInfo, clientProxyBytes, 0, authErr + } - debugUDPAddr(clientAddr, "Proxy exit %v", targetConn.LocalAddr()) - proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature - if err != nil { - return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) - } - return nil - }() + payload, tgtUDPAddr, onetErr := h.validatePacket(textData) + if onetErr != nil { + return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, 0, onetErr + } - status := "OK" - if connError != nil { - logger.Debugf("UDP Error: %v: %v", connError.Message, connError.Cause) - status = connError.Status - } - h.m.AddUDPPacketFromClient(clientInfo, keyID, status, clientProxyBytes, proxyTargetBytes) + debugUDPAddr(targetConn.clientAddr, "Proxy exit %v", targetConn.LocalAddr()) + proxyTargetBytes, err := targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature + if err != nil { + return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) } + return targetConn.keyID, targetConn.clientInfo, clientProxyBytes, proxyTargetBytes, nil } // Given the decrypted contents of a UDP packet, return @@ -245,8 +232,9 @@ func isDNS(addr net.Addr) bool { type natconn struct { net.PacketConn - cryptoKey *shadowsocks.EncryptionKey - keyID string + cryptoKey *shadowsocks.EncryptionKey + keyID string + clientAddr net.Addr // We store the client information in the NAT map to avoid recomputing it // for every downstream packet in a UDP-based connection. clientInfo ipinfo.IPInfo @@ -327,11 +315,12 @@ func (m *natmap) Get(key string) *natconn { return m.keyConn[key] } -func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn { +func (m *natmap) set(clientAddr net.Addr, pc net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, keyID string, clientInfo ipinfo.IPInfo) *natconn { entry := &natconn{ PacketConn: pc, cryptoKey: cryptoKey, keyID: keyID, + clientAddr: clientAddr, clientInfo: clientInfo, defaultTimeout: m.timeout, } @@ -339,7 +328,7 @@ func (m *natmap) set(key string, pc net.PacketConn, cryptoKey *shadowsocks.Encry m.Lock() defer m.Unlock() - m.keyConn[key] = entry + m.keyConn[clientAddr.String()] = entry return entry } @@ -356,7 +345,7 @@ func (m *natmap) del(key string) net.PacketConn { } func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *shadowsocks.EncryptionKey, targetConn net.PacketConn, clientInfo ipinfo.IPInfo, keyID string) *natconn { - entry := m.set(clientAddr.String(), targetConn, cryptoKey, keyID, clientInfo) + entry := m.set(clientAddr, targetConn, cryptoKey, keyID, clientInfo) m.metrics.AddUDPNatEntry(clientAddr, keyID) m.running.Add(1) diff --git a/service/udp_test.go b/service/udp_test.go index f94238c5..e461280e 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -162,12 +162,20 @@ func TestIPFilter(t *testing.T) { t.Run("Localhost allowed", func(t *testing.T) { metrics := sendToDiscard(payloads, allowAll) + assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) + assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) + for _, report := range metrics.upstreamPackets { + assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size") + assert.Greater(t, report.proxyTargetBytes, 0, "Expected nonzero bytes to be sent for allowed packet") + assert.Equal(t, report.accessKey, "id-0", "Unexpected access key: %s", report.accessKey) + } }) t.Run("Localhost not allowed", func(t *testing.T) { metrics := sendToDiscard(payloads, onet.RequirePublicIP) - assert.Equal(t, 0, metrics.natEntriesAdded, "Unexpected NAT entry on rejected packet") + + assert.Equal(t, metrics.natEntriesAdded, 1, "Expected 1 NAT entry, not %d", metrics.natEntriesAdded) assert.Equal(t, 2, len(metrics.upstreamPackets), "Expected 2 reports, not %v", metrics.upstreamPackets) for _, report := range metrics.upstreamPackets { assert.Greater(t, report.clientProxyBytes, 0, "Expected nonzero input packet size")