From 2287b9d2025e1ba578a27511ee32fe130fbe2896 Mon Sep 17 00:00:00 2001 From: Merlin Ran Date: Wed, 18 Jan 2023 07:19:47 -0500 Subject: [PATCH] Optimize locking (#143) * optimise locking of streams list * fix two race conditions --- client.go | 12 +++++---- http_test.go | 2 +- server.go | 68 +++++++++++++++++++++++++++----------------------- stream.go | 6 ++++- stream_test.go | 5 ++-- 5 files changed, 53 insertions(+), 40 deletions(-) diff --git a/client.go b/client.go index a054b55..61772b6 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,7 @@ import ( "io" "net/http" "sync" + "sync/atomic" "time" "gopkg.in/cenkalti/backoff.v1" @@ -49,7 +50,7 @@ type Client struct { ResponseValidator ResponseValidator Connection *http.Client URL string - EventID string + LastEventID atomic.Value // []byte maxBufferSize int mu sync.Mutex EncodingBase64 bool @@ -234,9 +235,9 @@ func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan c var msg *Event if msg, err = c.processEvent(event); err == nil { if len(msg.ID) > 0 { - c.EventID = string(msg.ID) + c.LastEventID.Store(msg.ID) } else { - msg.ID = []byte(c.EventID) + msg.ID, _ = c.LastEventID.Load().([]byte) } // Send downstream if the event has something useful @@ -305,8 +306,9 @@ func (c *Client) request(ctx context.Context, stream string) (*http.Response, er req.Header.Set("Accept", "text/event-stream") req.Header.Set("Connection", "keep-alive") - if c.EventID != "" { - req.Header.Set("Last-Event-ID", c.EventID) + lastID, exists := c.LastEventID.Load().([]byte) + if exists && lastID != nil { + req.Header.Set("Last-Event-ID", string(lastID)) } // Add user specified headers diff --git a/http_test.go b/http_test.go index a7e4a5e..a415022 100644 --- a/http_test.go +++ b/http_test.go @@ -102,7 +102,7 @@ func TestHTTPStreamHandlerEventID(t *testing.T) { time.Sleep(time.Millisecond * 100) c := NewClient(server.URL + "/events") - c.EventID = "2" + c.LastEventID.Store([]byte("2")) events := make(chan *Event) var cErr error diff --git a/server.go b/server.go index db82257..9774db4 100644 --- a/server.go +++ b/server.go @@ -15,13 +15,12 @@ const DefaultBufferSize = 1024 // Server Is our main struct type Server struct { - Streams map[string]*Stream + // Extra headers adding to the HTTP response to each client Headers map[string]string // Sets a ttl that prevents old events from being transmitted EventTTL time.Duration // Specifies the size of the message buffer for each stream BufferSize int - mu sync.Mutex // Encodes all data as base64 EncodeBase64 bool // Splits an events data into multiple data: entries @@ -34,6 +33,9 @@ type Server struct { // Specifies the function to run when client subscribe or un-subscribe OnSubscribe func(streamID string, sub *Subscriber) OnUnsubscribe func(streamID string, sub *Subscriber) + + streams map[string]*Stream + muStreams sync.RWMutex } // New will create a server and setup defaults @@ -42,7 +44,7 @@ func New() *Server { BufferSize: DefaultBufferSize, AutoStream: false, AutoReplay: true, - Streams: make(map[string]*Stream), + streams: make(map[string]*Stream), Headers: map[string]string{}, } } @@ -53,7 +55,7 @@ func NewWithCallback(onSubscribe, onUnsubscribe func(streamID string, sub *Subsc BufferSize: DefaultBufferSize, AutoStream: false, AutoReplay: true, - Streams: make(map[string]*Stream), + streams: make(map[string]*Stream), Headers: map[string]string{}, OnSubscribe: onSubscribe, OnUnsubscribe: onUnsubscribe, @@ -62,64 +64,68 @@ func NewWithCallback(onSubscribe, onUnsubscribe func(streamID string, sub *Subsc // Close shuts down the server, closes all of the streams and connections func (s *Server) Close() { - s.mu.Lock() - defer s.mu.Unlock() + s.muStreams.Lock() + defer s.muStreams.Unlock() - for id := range s.Streams { - s.Streams[id].quit <- struct{}{} - delete(s.Streams, id) + for id := range s.streams { + s.streams[id].close() + delete(s.streams, id) } } // CreateStream will create a new stream and register it func (s *Server) CreateStream(id string) *Stream { - s.mu.Lock() - defer s.mu.Unlock() + s.muStreams.Lock() + defer s.muStreams.Unlock() - if s.Streams[id] != nil { - return s.Streams[id] + if s.streams[id] != nil { + return s.streams[id] } str := newStream(id, s.BufferSize, s.AutoReplay, s.AutoStream, s.OnSubscribe, s.OnUnsubscribe) str.run() - s.Streams[id] = str + s.streams[id] = str return str } // RemoveStream will remove a stream func (s *Server) RemoveStream(id string) { - s.mu.Lock() - defer s.mu.Unlock() + s.muStreams.Lock() + defer s.muStreams.Unlock() - if s.Streams[id] != nil { - s.Streams[id].close() - delete(s.Streams, id) + if s.streams[id] != nil { + s.streams[id].close() + delete(s.streams, id) } } // StreamExists checks whether a stream by a given id exists func (s *Server) StreamExists(id string) bool { - s.mu.Lock() - defer s.mu.Unlock() - - return s.Streams[id] != nil + return s.getStream(id) != nil } -// Publish sends a mesage to every client in a streamID +// Publish sends a mesage to every client in a streamID. +// If the stream's buffer is full, it blocks until the message is sent out to +// all subscribers (but not necessarily arrived the clients), or when the +// stream is closed. func (s *Server) Publish(id string, event *Event) { - s.mu.Lock() - defer s.mu.Unlock() - if s.Streams[id] != nil { - s.Streams[id].event <- s.process(event) + stream := s.getStream(id) + if stream == nil { + return + } + + select { + case <-stream.quit: + case stream.event <- s.process(event): } } func (s *Server) getStream(id string) *Stream { - s.mu.Lock() - defer s.mu.Unlock() - return s.Streams[id] + s.muStreams.RLock() + defer s.muStreams.RUnlock() + return s.streams[id] } func (s *Server) process(event *Event) *Event { diff --git a/stream.go b/stream.go index 5e251c1..bfbcb9b 100644 --- a/stream.go +++ b/stream.go @@ -6,6 +6,7 @@ package sse import ( "net/url" + "sync" "sync/atomic" ) @@ -14,6 +15,7 @@ type Stream struct { ID string event chan *Event quit chan struct{} + quitOnce sync.Once register chan *Subscriber deregister chan *Subscriber subscribers []*Subscriber @@ -87,7 +89,9 @@ func (str *Stream) run() { } func (str *Stream) close() { - str.quit <- struct{}{} + str.quitOnce.Do(func() { + close(str.quit) + }) } func (str *Stream) getSubIndex(sub *Subscriber) int { diff --git a/stream_test.go b/stream_test.go index 761f1a1..1c89a6e 100644 --- a/stream_test.go +++ b/stream_test.go @@ -38,9 +38,10 @@ func TestStreamRemoveSubscriber(t *testing.T) { s.run() defer s.close() - s.addSubscriber(0, nil) + sub := s.addSubscriber(0, nil) + time.Sleep(time.Millisecond * 100) + s.deregister <- sub time.Sleep(time.Millisecond * 100) - s.removeSubscriber(0) assert.Equal(t, 0, s.getSubscriberCount()) }