diff --git a/internal/basictestserver/server.go b/internal/basictestserver/server.go index 05fa36e..dd9993c 100644 --- a/internal/basictestserver/server.go +++ b/internal/basictestserver/server.go @@ -79,14 +79,28 @@ func (t *TestServer) Stop() { func (t *TestServer) Run() { defer close(t.done) + + incoming := make(chan *packets.ControlPacket, 65535) + + // read incoming packets in a separate goroutine to avoid deadlocks due to unbuffered t.conn + go func() { + for { + recv, err := packets.ReadPacket(t.conn) + if err != nil { + t.logger.Println("error in test server reading packet", err) + close(incoming) + return + } + incoming <- recv + } + }() + for { select { case <-t.stop: return - default: - recv, err := packets.ReadPacket(t.conn) - if err != nil { - t.logger.Println("error in test server reading packet", err) + case recv, ok := <-incoming: + if !ok { return } t.logger.Println("test server received a control packet:", recv.PacketType()) @@ -179,7 +193,7 @@ func (t *TestServer) Run() { t.logger.Println("test server sending pingresp") pr := packets.NewControlPacket(packets.PINGRESP) if _, err := pr.WriteTo(t.conn); err != nil { - t.logger.Println("error writing pingreq", err) + t.logger.Println("error writing pingresp", err) } } } diff --git a/paho/client.go b/paho/client.go index 4ef2640..83a2c40 100644 --- a/paho/client.go +++ b/paho/client.go @@ -70,8 +70,9 @@ type ( Session session.SessionManager autoCloseSession bool - AuthHandler Auther - PingHandler Pinger + AuthHandler Auther + PingHandler Pinger + defaultPinger bool // Router - new inbound messages will be passed to the `Route(*packets.Publish)` function. // @@ -199,9 +200,8 @@ func NewClient(conf ClientConfig) *Client { c.onPublishReceivedTracker = make([]int, len(c.onPublishReceived)) // Must have the same number of elements as onPublishReceived if c.config.PingHandler == nil { - c.config.PingHandler = DefaultPingerWithCustomFailHandler(func(e error) { - go c.error(e) - }) + c.config.defaultPinger = true + c.config.PingHandler = NewDefaultPinger() } if c.config.OnClientError == nil { c.config.OnClientError = func(e error) {} @@ -340,15 +340,15 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { c.serverProps.SharedSubAvailable = ca.Properties.SharedSubAvailable } - if keepalive > 0 { // "Keep Alive value of 0 has the effect of turning off..." - c.debug.Println("received CONNACK, starting PingHandler") - c.workers.Add(1) - go func() { - defer c.workers.Done() - defer c.debug.Println("returning from ping handler worker") - c.config.PingHandler.Start(c.config.Conn, time.Duration(keepalive)*time.Second) - }() - } + c.debug.Println("received CONNACK, starting PingHandler") + c.workers.Add(1) + go func() { + defer c.workers.Done() + defer c.debug.Println("returning from ping handler worker") + if err := c.config.PingHandler.Run(c.config.Conn, keepalive); err != nil { + go c.error(fmt.Errorf("ping handler error: %w", err)) + } + }() c.debug.Println("starting publish packets loop") c.workers.Add(1) @@ -502,6 +502,7 @@ func (c *Client) incoming() { go c.error(err) return } + c.config.PingHandler.PacketSent() } } case packets.PUBLISH: @@ -619,6 +620,7 @@ func (c *Client) Authenticate(ctx context.Context, a *Auth) (*AuthResponse, erro if _, err := a.Packet().WriteTo(c.config.Conn); err != nil { return nil, err } + c.config.PingHandler.PacketSent() var rp packets.ControlPacket select { @@ -679,6 +681,7 @@ func (c *Client) Subscribe(ctx context.Context, s *Subscribe) (*Suback, error) { // The packet will remain in the session state until `Session` is notified of the disconnection. return nil, err } + c.config.PingHandler.PacketSent() c.debug.Println("waiting for SUBACK") subCtx, cf := context.WithTimeout(ctx, c.config.PacketTimeout) @@ -743,6 +746,7 @@ func (c *Client) Unsubscribe(ctx context.Context, u *Unsubscribe) (*Unsuback, er // The packet will remain in the session state until `Session` is notified of the disconnection. return nil, err } + c.config.PingHandler.PacketSent() unsubCtx, cf := context.WithTimeout(ctx, c.config.PacketTimeout) defer cf() @@ -849,6 +853,7 @@ func (c *Client) PublishWithOptions(ctx context.Context, p *Publish, o PublishOp go c.error(err) return nil, err } + c.config.PingHandler.PacketSent() return nil, nil case 1, 2: return c.publishQoS12(ctx, pb, o) @@ -875,6 +880,7 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish, o Publis return nil, ErrNetworkErrorAfterStored // Async send, so we don't wait for the response (may add callbacks in the future to enable user to obtain status) } } + c.config.PingHandler.PacketSent() if o.Method == PublishMethod_AsyncSend { return nil, nil // Async send, so we don't wait for the response (may add callbacks in the future to enable user to obtain status) diff --git a/paho/client_test.go b/paho/client_test.go index 2b5f215..ef9367b 100644 --- a/paho/client_test.go +++ b/paho/client_test.go @@ -131,7 +131,7 @@ func TestClientSubscribe(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) @@ -176,7 +176,7 @@ func TestClientUnsubscribe(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) @@ -214,7 +214,7 @@ func TestClientPublishQoS0(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) @@ -256,7 +256,7 @@ func TestClientPublishQoS1(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) @@ -301,7 +301,7 @@ func TestClientPublishQoS2(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) @@ -348,7 +348,7 @@ func TestClientReceiveQoS0(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) go c.routePublishPackets() @@ -395,7 +395,7 @@ func TestClientReceiveQoS1(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) go c.routePublishPackets() @@ -443,7 +443,7 @@ func TestClientReceiveQoS2(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) go c.routePublishPackets() @@ -660,7 +660,7 @@ func TestReceiveServerDisconnect(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) @@ -701,7 +701,7 @@ func TestAuthenticate(t *testing.T) { }() go func() { defer c.workers.Done() - c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + c.config.PingHandler.Run(c.config.Conn, 30) }() c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) diff --git a/paho/default_pinger_test.go b/paho/default_pinger_test.go new file mode 100644 index 0000000..4fdeca5 --- /dev/null +++ b/paho/default_pinger_test.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2024 Contributors to the Eclipse Foundation + * + * All rights reserved. This program and the accompanying materials + * are made available under the terms of the Eclipse Public License v2.0 + * and Eclipse Distribution License v1.0 which accompany this distribution. + * + * The Eclipse Public License is available at + * https://www.eclipse.org/legal/epl-2.0/ + * and the Eclipse Distribution License is available at + * http://www.eclipse.org/org/documents/edl-v10.php. + * + * SPDX-License-Identifier: EPL-2.0 OR BSD-3-Clause + */ + +package paho + +import ( + "net" + "testing" + "time" + + "github.com/eclipse/paho.golang/packets" + paholog "github.com/eclipse/paho.golang/paho/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultPingerTimeout(t *testing.T) { + fakeServerConn, fakeClientConn := net.Pipe() + + go func() { + // keep reading from fakeServerConn and throw away the data + buf := make([]byte, 1024) + for { + _, err := fakeServerConn.Read(buf) + if err != nil { + return + } + } + }() + defer fakeServerConn.Close() + + pinger := NewDefaultPinger() + pinger.SetDebug(paholog.NewTestLogger(t, "DefaultPinger:")) + + pingResult := make(chan error, 1) + go func() { + pingResult <- pinger.Run(fakeClientConn, 1) + }() + defer pinger.Stop() + + select { + case err := <-pingResult: + require.NotNil(t, err) + assert.EqualError(t, err, "PINGRESP timed out") + case <-time.After(10 * time.Second): + t.Error("expected DefaultPinger to detect timeout and return error") + } +} + +func TestDefaultPingerSuccess(t *testing.T) { + fakeClientConn, fakeServerConn := net.Pipe() + + pinger := NewDefaultPinger() + pinger.SetDebug(paholog.NewTestLogger(t, "DefaultPinger:")) + + pingResult := make(chan error, 1) + go func() { + pingResult <- pinger.Run(fakeClientConn, 3) + }() + defer pinger.Stop() + + go func() { + // keep reading from fakeServerConn and call PingResp() when a PINGREQ is received + for { + recv, err := packets.ReadPacket(fakeServerConn) + if err != nil { + return + } + if recv.Type == packets.PINGREQ { + pinger.PingResp() + } + } + }() + defer fakeServerConn.Close() + + select { + case err := <-pingResult: + t.Errorf("expected DefaultPinger to not return error, got %v", err) + case <-time.After(10 * time.Second): + // PASS + } +} + +func TestDefaultPingerPacketSent(t *testing.T) { + fakeClientConn, fakeServerConn := net.Pipe() + + pinger := NewDefaultPinger() + pinger.SetDebug(paholog.NewTestLogger(t, "DefaultPinger:")) + + pingResult := make(chan error, 1) + go func() { + pingResult <- pinger.Run(fakeClientConn, 3) + }() + defer pinger.Stop() + + // keep calling PacketSent() in a goroutine to check that the Pinger avoids sending PINGREQs when not needed + stop := make(chan struct{}) + go func() { + for { + select { + case <-stop: + return + default: + } + // keep calling PacketSent() + pinger.PacketSent() + } + }() + defer close(stop) + + // keep reading from fakeServerConn and call PingResp() when a PINGREQ is received + // if more than one PINGREQ is received, the test will fail + count := 0 + tooManyPingreqs := make(chan struct{}) + go func() { + for { + recv, err := packets.ReadPacket(fakeServerConn) + if err != nil { + return + } + if recv.Type == packets.PINGREQ { + count++ + pinger.PingResp() + if count > 1 { // we allow the count to be 1 because the first PINGREQ is sent immediately + close(tooManyPingreqs) + } + } + } + }() + defer fakeServerConn.Close() + + select { + case <-tooManyPingreqs: + t.Error("expected DefaultPinger to not send PINGREQs when not needed") + case err := <-pingResult: + t.Errorf("expected DefaultPinger to not return error, got %v", err) + case <-time.After(10 * time.Second): + // PASS + } +} diff --git a/paho/packet_ids_test.go b/paho/packet_ids_test.go index d28b029..df74552 100644 --- a/paho/packet_ids_test.go +++ b/paho/packet_ids_test.go @@ -46,7 +46,7 @@ func TestPackedIdNoExhaustion(t *testing.T) { c.stop = make(chan struct{}) c.publishPackets = make(chan *packets.Publish) go c.incoming() - go c.config.PingHandler.Start(c.config.Conn, 30*time.Second) + go c.config.PingHandler.Run(c.config.Conn, 30) c.config.Session.ConAckReceived(c.config.Conn, &packets.Connect{}, &packets.Connack{}) for i := 0; i < 70000; i++ { diff --git a/paho/pinger.go b/paho/pinger.go index 5f2096d..d8ba4f0 100644 --- a/paho/pinger.go +++ b/paho/pinger.go @@ -19,120 +19,168 @@ import ( "fmt" "net" "sync" - "sync/atomic" "time" "github.com/eclipse/paho.golang/packets" "github.com/eclipse/paho.golang/paho/log" ) -// PingFailHandler is a type for the function that is invoked -// when we have sent a Pingreq to the server and not received -// a Pingresp within 1.5x our pingtimeout -type PingFailHandler func(error) - -// Pinger is an interface of the functions for a struct that is -// used to manage sending PingRequests and responding to -// PingResponses -// Start() takes a net.Conn which is a connection over which an -// MQTT session has already been established, and a time.Duration -// of the keepalive setting passed to the server when the MQTT -// session was established. -// Stop() is used to stop the Pinger -// PingResp() is the function that is called by the Client when -// a PingResponse is received -// SetDebug() is used to pass in a Logger to be used to log debug -// information, for example sharing a logger with the main client type Pinger interface { - Start(net.Conn, time.Duration) + // Run() starts the pinger. It blocks until the pinger is stopped. + // If the pinger stops due to an error, it returns the error. + // If the keepAlive is 0, it returns nil immediately. + // Run() must be called only once. + Run(conn net.Conn, keepAlive uint16) error + + // Stop() gracefully stops the pinger. Stop() + + // PacketSent() is called when a packet is sent to the server. + PacketSent() + + // PingResp() is called when a PINGRESP is received from the server. PingResp() + + // SetDebug() sets the logger for debugging. + // It is not thread-safe and must be called before Run() to avoid race conditions. SetDebug(log.Logger) } -// PingHandler is the library provided default Pinger -type PingHandler struct { - mu sync.Mutex - lastPing time.Time - conn net.Conn - stop chan struct{} - pingFailHandler PingFailHandler - pingOutstanding int32 - debug log.Logger +// DefaultPinger is the default implementation of Pinger. +type DefaultPinger struct { + timer *time.Timer + keepAlive uint16 + conn net.Conn + previousPingAcked chan struct{} + done chan struct{} + errChan chan error + ackReceived chan struct{} + stopOnce sync.Once + mu sync.Mutex + debug log.Logger } -// DefaultPingerWithCustomFailHandler returns an instance of the -// default Pinger but with a custom PingFailHandler that is called -// when the client has not received a response to a PingRequest -// within the appropriate amount of time -func DefaultPingerWithCustomFailHandler(pfh PingFailHandler) *PingHandler { - return &PingHandler{ - pingFailHandler: pfh, - debug: log.NOOPLogger{}, +func NewDefaultPinger() *DefaultPinger { + previousPingAcked := make(chan struct{}, 1) + previousPingAcked <- struct{}{} // initial value + return &DefaultPinger{ + previousPingAcked: previousPingAcked, + errChan: make(chan error, 1), + done: make(chan struct{}), + ackReceived: make(chan struct{}, 1), + debug: log.NOOPLogger{}, } } -// Start is the library provided Pinger's implementation of -// the required interface function() -func (p *PingHandler) Start(c net.Conn, pt time.Duration) { +func (p *DefaultPinger) Run(conn net.Conn, keepAlive uint16) error { + if keepAlive == 0 { + p.debug.Println("Run() returning immediately due to keepAlive == 0") + return nil + } + if conn == nil { + return fmt.Errorf("conn is nil") + } p.mu.Lock() - p.conn = c - p.stop = make(chan struct{}) - p.mu.Unlock() - checkTicker := time.NewTicker(pt / 4) - defer checkTicker.Stop() - for { - select { - case <-p.stop: - return - case <-checkTicker.C: - if atomic.LoadInt32(&p.pingOutstanding) > 0 && time.Since(p.lastPing) > (pt+pt>>1) { - p.pingFailHandler(fmt.Errorf("ping resp timed out")) - // ping outstanding and not reset in 1.5 times ping timer - return - } - if time.Since(p.lastPing) >= pt { - // time to send a ping - if _, err := packets.NewControlPacket(packets.PINGREQ).WriteTo(p.conn); err != nil { - if p.pingFailHandler != nil { - p.pingFailHandler(err) - } - return - } - atomic.AddInt32(&p.pingOutstanding, 1) - p.lastPing = time.Now() - p.debug.Println("pingHandler sending ping request") - } - } + if p.timer != nil { + p.mu.Unlock() + return fmt.Errorf("Run() already called") + } + select { + case <-p.done: + p.mu.Unlock() + return fmt.Errorf("Run() called after stop()") + default: } + p.keepAlive = keepAlive + p.conn = conn + p.timer = time.AfterFunc(0, p.sendPingreq) // Immediately send first pingreq + p.mu.Unlock() + + return <-p.errChan +} + +func (p *DefaultPinger) Stop() { + p.stop(nil) } -// Stop is the library provided Pinger's implementation of -// the required interface function() -func (p *PingHandler) Stop() { +func (p *DefaultPinger) PacketSent() { p.mu.Lock() defer p.mu.Unlock() - if p.stop == nil { + if p.timer == nil { + p.debug.Println("PacketSent() called before Run()") return } - p.debug.Println("pingHandler stopping") select { - case <-p.stop: - // Already stopped, do nothing + case <-p.done: + p.debug.Println("PacketSent() returning due to done channel") + return + default: + } + + p.debug.Println("PacketSent() resetting timer") + p.timer.Reset(time.Duration(p.keepAlive) * time.Second) +} + +func (p *DefaultPinger) PingResp() { + select { + case p.ackReceived <- struct{}{}: default: - close(p.stop) + p.debug.Println("PingResp() called when ackReceived channel is full") + p.stop(fmt.Errorf("received unexpected PINGRESP")) } } -// PingResp is the library provided Pinger's implementation of -// the required interface function() -func (p *PingHandler) PingResp() { - p.debug.Println("pingHandler resetting pingOutstanding") - atomic.StoreInt32(&p.pingOutstanding, 0) +func (p *DefaultPinger) SetDebug(debug log.Logger) { + p.debug = debug +} + +func (p *DefaultPinger) sendPingreq() { + // Wait for previous ping to be acked before sending another + select { + case <-p.previousPingAcked: + case <-p.done: + p.debug.Println("sendPingreq() returning before sending PINGREQ due to done channel") + return + } + + p.debug.Println("sendPingreq() sending PINGREQ packet") + if _, err := packets.NewControlPacket(packets.PINGREQ).WriteTo(p.conn); err != nil { + p.stop(fmt.Errorf("failed to send PINGREQ: %w", err)) + p.debug.Printf("sendPingreq() calling stop() and returning due to packet write error: %v", err) + return + } + p.debug.Println("sendPingreq() sent PINGREQ packet, waiting for PINGRESP") + pingrespTimeout := time.NewTimer(time.Duration(p.keepAlive) * time.Second) + + p.PacketSent() + + select { + case <-p.done: + p.debug.Println("sendPingreq() returning after sending PINGREQ due to done channel") + case <-p.ackReceived: + p.previousPingAcked <- struct{}{} + p.debug.Println("sendPingreq() returning after receiving PINGRESP") + case <-pingrespTimeout.C: + p.debug.Println("sendPingreq() calling stop() and returning due to PINGRESP timeout") + p.stop(fmt.Errorf("PINGRESP timed out")) + return + } + + // Stop the timer if it hasn't fired yet + if !pingrespTimeout.Stop() { + <-pingrespTimeout.C + } } -// SetDebug sets the logger l to be used for printing debug -// information for the pinger -func (p *PingHandler) SetDebug(l log.Logger) { - p.debug = l +func (p *DefaultPinger) stop(err error) { + p.mu.Lock() + defer p.mu.Unlock() + p.debug.Printf("stop() called with error: %v", err) + p.stopOnce.Do(func() { + if p.timer != nil { + p.timer.Stop() + } + p.errChan <- err + close(p.done) + }) }