From 700c0e37279890b14dd5372c22aad7efbd0b78ee Mon Sep 17 00:00:00 2001 From: Jeanette Booher Date: Wed, 20 Mar 2024 15:55:15 -0600 Subject: [PATCH] add udp option; very wip --- src/pkg/auth/uaa_client_test.go | 4 +-- src/pkg/leanstreams/tcpclient.go | 4 +-- src/pkg/leanstreams/tcplistener.go | 48 ++++++++++++++++++++++++++++-- src/pkg/leanstreams/tcpserver.go | 28 ++++++++++++----- src/pkg/leanstreams/util_test.go | 2 +- 5 files changed, 71 insertions(+), 15 deletions(-) diff --git a/src/pkg/auth/uaa_client_test.go b/src/pkg/auth/uaa_client_test.go index cc4646ccd..2b7759418 100644 --- a/src/pkg/auth/uaa_client_test.go +++ b/src/pkg/auth/uaa_client_test.go @@ -10,9 +10,9 @@ import ( "sync" "time" - "github.com/cloudfoundry/metric-store-release/src/pkg/logger" "github.com/cloudfoundry/metric-store-release/src/internal/testing" "github.com/cloudfoundry/metric-store-release/src/pkg/auth" + "github.com/cloudfoundry/metric-store-release/src/pkg/logger" "bytes" "encoding/base64" @@ -29,7 +29,7 @@ import ( ) var _ = Describe("UAAClient", func() { - Context("Read()", func() { + Context("ReadTCP()", func() { var tc *UAATestContext BeforeEach(func() { tc = uaaSetup() diff --git a/src/pkg/leanstreams/tcpclient.go b/src/pkg/leanstreams/tcpclient.go index 4f5372410..fce702b16 100644 --- a/src/pkg/leanstreams/tcpclient.go +++ b/src/pkg/leanstreams/tcpclient.go @@ -51,9 +51,9 @@ func newTCPClient(cfg *TCPClientConfig) *TCPClient { return &TCPClient{ MaxMessageSize: maxMessageSize, - headerByteSize: headerByteSize, + headerByteSize: tcpHeaderByteSize, address: cfg.Address, - incomingHeaderBuffer: make([]byte, headerByteSize), + incomingHeaderBuffer: make([]byte, tcpHeaderByteSize), tlsConfig: cfg.TLSConfig, done: make(chan struct{}), } diff --git a/src/pkg/leanstreams/tcplistener.go b/src/pkg/leanstreams/tcplistener.go index 53024a9bd..c9fb6ba97 100644 --- a/src/pkg/leanstreams/tcplistener.go +++ b/src/pkg/leanstreams/tcplistener.go @@ -26,6 +26,7 @@ type TCPListener struct { ConnConfig *TCPServerConfig tlsConfig *tls.Config Address string + IsUDP bool connectionCount int connectionCountMetricName string @@ -68,6 +69,7 @@ type TCPListenerConfig struct { TLSConfig *tls.Config MetricRegistrar MetricRegistrar ConnCountMetricName string + isUDP bool } // ListenTCP creates a TCPListener, and opens it's local connection to @@ -99,6 +101,7 @@ func ListenTCP(cfg TCPListenerConfig) (*TCPListener, error) { connectionCount: 0, connectionCountMetricName: cfg.ConnCountMetricName, metrics: cfg.MetricRegistrar, + IsUDP: false, } if err := btl.openSocket(); err != nil { @@ -108,6 +111,41 @@ func ListenTCP(cfg TCPListenerConfig) (*TCPListener, error) { return btl, nil } +func ListenUDP(cfg TCPListenerConfig) (*TCPListener, error) { + maxMessageSize := DefaultMaxMessageSize + // 0 is the default, and the message must be atleast 1 byte large + if cfg.MaxMessageSize != 0 { + maxMessageSize = cfg.MaxMessageSize + } + connCfg := TCPServerConfig{ + MaxMessageSize: maxMessageSize, + Address: cfg.Address, + TLSConfig: cfg.TLSConfig, + } + + ctx, cancel := context.WithCancel(context.Background()) + + btl := &TCPListener{ + logger: cfg.Logger, + callback: cfg.Callback, + shutdown: cancel, + shutdownCtx: ctx, + shutdownGroup: &sync.WaitGroup{}, + ConnConfig: &connCfg, + tlsConfig: cfg.TLSConfig, + Address: "", + connectionCount: 0, + connectionCountMetricName: cfg.ConnCountMetricName, + metrics: cfg.MetricRegistrar, + IsUDP: true, + } + + if err := btl.openSocket(); err != nil { + return nil, err + } + + return btl, nil +} func (t *TCPListener) Addr() net.Addr { if t.socket == nil { return nil @@ -137,7 +175,7 @@ func (t *TCPListener) blockListen() error { continue } - conn := newTCPServer(t.ConnConfig) + conn := newTCPServer(t.ConnConfig, t.IsUDP) // Don't dial out, wrap the underlying conn in one of ours conn.socket = c @@ -279,7 +317,13 @@ func (t *TCPListener) readLoop(conn *TCPServer) { // we want to kill the connection, exit the goroutine, and let the client handle re-connecting if need be. // Handle getting the data header for { - msgLen, err := conn.Read(dataBuffer) + var msgLen int + var err error + if t.IsUDP { + msgLen, err = conn.ReadUDP(dataBuffer) + } else { + msgLen, err = conn.ReadTCP(dataBuffer) + } if err != nil { if t.logger != nil { t.logger.Printf("Address %s: Failure to read from connection. Underlying error: %s", conn.address, err) diff --git a/src/pkg/leanstreams/tcpserver.go b/src/pkg/leanstreams/tcpserver.go index 605f80999..774f071de 100644 --- a/src/pkg/leanstreams/tcpserver.go +++ b/src/pkg/leanstreams/tcpserver.go @@ -10,7 +10,7 @@ import ( ) const ( - headerByteSize = 8 + tcpHeaderByteSize = 8 ) var ( @@ -30,6 +30,7 @@ type TCPServer struct { tlsConfig *tls.Config headerByteSize int MaxMessageSize int + isUDP bool // For processing incoming data incomingHeaderBuffer []byte @@ -51,7 +52,7 @@ type TCPServerConfig struct { TLSConfig *tls.Config } -func newTCPServer(cfg *TCPServerConfig) *TCPServer { +func newTCPServer(cfg *TCPServerConfig, isUDP bool) *TCPServer { maxMessageSize := DefaultMaxMessageSize // 0 is the default, and the message must be atleast 1 byte large if cfg.MaxMessageSize != 0 { @@ -60,12 +61,13 @@ func newTCPServer(cfg *TCPServerConfig) *TCPServer { return &TCPServer{ MaxMessageSize: maxMessageSize, - headerByteSize: headerByteSize, + headerByteSize: tcpHeaderByteSize, address: cfg.Address, - incomingHeaderBuffer: make([]byte, headerByteSize), + incomingHeaderBuffer: make([]byte, tcpHeaderByteSize), writeLock: sync.Mutex{}, outgoingDataBuffer: make([]byte, maxMessageSize), tlsConfig: cfg.TLSConfig, + isUDP: isUDP, } } @@ -79,7 +81,7 @@ func (c *TCPServer) Close() error { } func (c *TCPServer) lowLevelRead(buffer []byte) (int, error) { - fmt.Println("In low level Read") + fmt.Println("In low level ReadTCP") var totalBytesRead = 0 var err error var bytesRead = 0 @@ -106,13 +108,13 @@ func (c *TCPServer) lowLevelRead(buffer []byte) (int, error) { // as a general case in other networking code. Following in the footsteps of (greatness|madness) return totalBytesRead, err } - // Read some bytes, return the length + // ReadTCP some bytes, return the length return totalBytesRead, nil } -func (c *TCPServer) Read(b []byte) (int, error) { - // Read the header +func (c *TCPServer) ReadTCP(b []byte) (int, error) { + // ReadTCP the header hLength, err := c.lowLevelRead(c.incomingHeaderBuffer) if err != nil { return hLength, err @@ -147,3 +149,13 @@ func (c *TCPServer) Read(b []byte) (int, error) { } return bLength, err } + +func (c *TCPServer) ReadUDP(b []byte) (int, error) { + + // Using the header, read the remaining body + bLength, err := c.lowLevelRead(b) + if err != nil { + c.Close() + } + return bLength, err +} diff --git a/src/pkg/leanstreams/util_test.go b/src/pkg/leanstreams/util_test.go index 1ac778758..4074c7631 100644 --- a/src/pkg/leanstreams/util_test.go +++ b/src/pkg/leanstreams/util_test.go @@ -24,7 +24,7 @@ func TestMessageBytesToInt(t *testing.T) { } for _, c := range cases { - bytes := int64ToByteArray(c.input, headerByteSize) + bytes := int64ToByteArray(c.input, tcpHeaderByteSize) result, _ := byteArrayToInt64(bytes) if result != c.output { t.Errorf("Conversion between bytes incorrect. Original value %d, got %d", c.input, result)