Skip to content

Commit

Permalink
Optimize locking (#143)
Browse files Browse the repository at this point in the history
* optimise locking of streams list

* fix two race conditions
  • Loading branch information
merlinran authored Jan 18, 2023
1 parent c2d7462 commit 2287b9d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 40 deletions.
12 changes: 7 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"net/http"
"sync"
"sync/atomic"
"time"

"gopkg.in/cenkalti/backoff.v1"
Expand Down Expand Up @@ -49,7 +50,7 @@ type Client struct {
ResponseValidator ResponseValidator
Connection *http.Client
URL string
EventID string
LastEventID atomic.Value // []byte
maxBufferSize int
mu sync.Mutex
EncodingBase64 bool
Expand Down Expand Up @@ -234,9 +235,9 @@ func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan c
var msg *Event
if msg, err = c.processEvent(event); err == nil {
if len(msg.ID) > 0 {
c.EventID = string(msg.ID)
c.LastEventID.Store(msg.ID)
} else {
msg.ID = []byte(c.EventID)
msg.ID, _ = c.LastEventID.Load().([]byte)
}

// Send downstream if the event has something useful
Expand Down Expand Up @@ -305,8 +306,9 @@ func (c *Client) request(ctx context.Context, stream string) (*http.Response, er
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Connection", "keep-alive")

if c.EventID != "" {
req.Header.Set("Last-Event-ID", c.EventID)
lastID, exists := c.LastEventID.Load().([]byte)
if exists && lastID != nil {
req.Header.Set("Last-Event-ID", string(lastID))
}

// Add user specified headers
Expand Down
2 changes: 1 addition & 1 deletion http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestHTTPStreamHandlerEventID(t *testing.T) {
time.Sleep(time.Millisecond * 100)

c := NewClient(server.URL + "/events")
c.EventID = "2"
c.LastEventID.Store([]byte("2"))

events := make(chan *Event)
var cErr error
Expand Down
68 changes: 37 additions & 31 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ const DefaultBufferSize = 1024

// Server Is our main struct
type Server struct {
Streams map[string]*Stream
// Extra headers adding to the HTTP response to each client
Headers map[string]string
// Sets a ttl that prevents old events from being transmitted
EventTTL time.Duration
// Specifies the size of the message buffer for each stream
BufferSize int
mu sync.Mutex
// Encodes all data as base64
EncodeBase64 bool
// Splits an events data into multiple data: entries
Expand All @@ -34,6 +33,9 @@ type Server struct {
// Specifies the function to run when client subscribe or un-subscribe
OnSubscribe func(streamID string, sub *Subscriber)
OnUnsubscribe func(streamID string, sub *Subscriber)

streams map[string]*Stream
muStreams sync.RWMutex
}

// New will create a server and setup defaults
Expand All @@ -42,7 +44,7 @@ func New() *Server {
BufferSize: DefaultBufferSize,
AutoStream: false,
AutoReplay: true,
Streams: make(map[string]*Stream),
streams: make(map[string]*Stream),
Headers: map[string]string{},
}
}
Expand All @@ -53,7 +55,7 @@ func NewWithCallback(onSubscribe, onUnsubscribe func(streamID string, sub *Subsc
BufferSize: DefaultBufferSize,
AutoStream: false,
AutoReplay: true,
Streams: make(map[string]*Stream),
streams: make(map[string]*Stream),
Headers: map[string]string{},
OnSubscribe: onSubscribe,
OnUnsubscribe: onUnsubscribe,
Expand All @@ -62,64 +64,68 @@ func NewWithCallback(onSubscribe, onUnsubscribe func(streamID string, sub *Subsc

// Close shuts down the server, closes all of the streams and connections
func (s *Server) Close() {
s.mu.Lock()
defer s.mu.Unlock()
s.muStreams.Lock()
defer s.muStreams.Unlock()

for id := range s.Streams {
s.Streams[id].quit <- struct{}{}
delete(s.Streams, id)
for id := range s.streams {
s.streams[id].close()
delete(s.streams, id)
}
}

// CreateStream will create a new stream and register it
func (s *Server) CreateStream(id string) *Stream {
s.mu.Lock()
defer s.mu.Unlock()
s.muStreams.Lock()
defer s.muStreams.Unlock()

if s.Streams[id] != nil {
return s.Streams[id]
if s.streams[id] != nil {
return s.streams[id]
}

str := newStream(id, s.BufferSize, s.AutoReplay, s.AutoStream, s.OnSubscribe, s.OnUnsubscribe)
str.run()

s.Streams[id] = str
s.streams[id] = str

return str
}

// RemoveStream will remove a stream
func (s *Server) RemoveStream(id string) {
s.mu.Lock()
defer s.mu.Unlock()
s.muStreams.Lock()
defer s.muStreams.Unlock()

if s.Streams[id] != nil {
s.Streams[id].close()
delete(s.Streams, id)
if s.streams[id] != nil {
s.streams[id].close()
delete(s.streams, id)
}
}

// StreamExists checks whether a stream by a given id exists
func (s *Server) StreamExists(id string) bool {
s.mu.Lock()
defer s.mu.Unlock()

return s.Streams[id] != nil
return s.getStream(id) != nil
}

// Publish sends a mesage to every client in a streamID
// Publish sends a mesage to every client in a streamID.
// If the stream's buffer is full, it blocks until the message is sent out to
// all subscribers (but not necessarily arrived the clients), or when the
// stream is closed.
func (s *Server) Publish(id string, event *Event) {
s.mu.Lock()
defer s.mu.Unlock()
if s.Streams[id] != nil {
s.Streams[id].event <- s.process(event)
stream := s.getStream(id)
if stream == nil {
return
}

select {
case <-stream.quit:
case stream.event <- s.process(event):
}
}

func (s *Server) getStream(id string) *Stream {
s.mu.Lock()
defer s.mu.Unlock()
return s.Streams[id]
s.muStreams.RLock()
defer s.muStreams.RUnlock()
return s.streams[id]
}

func (s *Server) process(event *Event) *Event {
Expand Down
6 changes: 5 additions & 1 deletion stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package sse

import (
"net/url"
"sync"
"sync/atomic"
)

Expand All @@ -14,6 +15,7 @@ type Stream struct {
ID string
event chan *Event
quit chan struct{}
quitOnce sync.Once
register chan *Subscriber
deregister chan *Subscriber
subscribers []*Subscriber
Expand Down Expand Up @@ -87,7 +89,9 @@ func (str *Stream) run() {
}

func (str *Stream) close() {
str.quit <- struct{}{}
str.quitOnce.Do(func() {
close(str.quit)
})
}

func (str *Stream) getSubIndex(sub *Subscriber) int {
Expand Down
5 changes: 3 additions & 2 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ func TestStreamRemoveSubscriber(t *testing.T) {
s.run()
defer s.close()

s.addSubscriber(0, nil)
sub := s.addSubscriber(0, nil)
time.Sleep(time.Millisecond * 100)
s.deregister <- sub
time.Sleep(time.Millisecond * 100)
s.removeSubscriber(0)

assert.Equal(t, 0, s.getSubscriberCount())
}
Expand Down

0 comments on commit 2287b9d

Please sign in to comment.