diff --git a/common/event_emitter.go b/common/event_emitter.go index 02f9df8e8..9be567b19 100644 --- a/common/event_emitter.go +++ b/common/event_emitter.go @@ -22,6 +22,7 @@ package common import ( "context" + "sync" ) // Ensure BaseEventEmitter implements the EventEmitter interface. @@ -96,9 +97,17 @@ type NavigationEvent struct { err error } +type queue struct { + writeMutex sync.Mutex + write []Event + readMutex sync.Mutex + read []Event +} + type eventHandler struct { - ctx context.Context - ch chan Event + ctx context.Context + ch chan Event + queue *queue } // EventEmitter that all event emitters need to implement. @@ -114,8 +123,10 @@ type syncFunc func() (done chan struct{}) // BaseEventEmitter emits events to registered handlers. type BaseEventEmitter struct { - handlers map[string][]eventHandler - handlersAll []eventHandler + handlers map[string][]*eventHandler + handlersAll []*eventHandler + + queues map[chan Event]*queue syncCh chan syncFunc ctx context.Context @@ -124,9 +135,10 @@ type BaseEventEmitter struct { // NewBaseEventEmitter creates a new instance of a base event emitter. func NewBaseEventEmitter(ctx context.Context) BaseEventEmitter { bem := BaseEventEmitter{ - handlers: make(map[string][]eventHandler), + handlers: make(map[string][]*eventHandler), syncCh: make(chan syncFunc), ctx: ctx, + queues: make(map[chan Event]*queue), } go bem.syncAll(ctx) return bem @@ -165,14 +177,32 @@ func (e *BaseEventEmitter) sync(fn func()) { } func (e *BaseEventEmitter) emit(event string, data interface{}) { - emitEvent := func(eh eventHandler) { + emitEvent := func(eh *eventHandler) { + eh.queue.readMutex.Lock() + defer eh.queue.readMutex.Unlock() + + // We try to read from the read queue (queue.read). + // If there isn't anything on the read queue, then there must + // be something being populated by the synched emitTo + // func below. + // Swap around the read queue with the write queue. + // Queue is now being populated again by emitTo, and all + // emitEvent goroutines can continue to consume from + // the read queue until that is again depleted. + if len(eh.queue.read) == 0 { + eh.queue.writeMutex.Lock() + eh.queue.read, eh.queue.write = eh.queue.write, eh.queue.read + eh.queue.writeMutex.Unlock() + } + select { - case eh.ch <- Event{event, data}: + case eh.ch <- eh.queue.read[0]: + eh.queue.read = eh.queue.read[1:] case <-eh.ctx.Done(): // TODO: handle the error } } - emitTo := func(handlers []eventHandler) (updated []eventHandler) { + emitTo := func(handlers []*eventHandler) (updated []*eventHandler) { for i := 0; i < len(handlers); { handler := handlers[i] select { @@ -180,6 +210,10 @@ func (e *BaseEventEmitter) emit(event string, data interface{}) { handlers = append(handlers[:i], handlers[i+1:]...) continue default: + handler.queue.writeMutex.Lock() + handler.queue.write = append(handler.queue.write, Event{typ: event, data: data}) + handler.queue.writeMutex.Unlock() + go emitEvent(handler) i++ } @@ -195,13 +229,14 @@ func (e *BaseEventEmitter) emit(event string, data interface{}) { // On registers a handler for a specific event. func (e *BaseEventEmitter) on(ctx context.Context, events []string, ch chan Event) { e.sync(func() { + q, ok := e.queues[ch] + if !ok { + q = &queue{} + e.queues[ch] = q + } + for _, event := range events { - _, ok := e.handlers[event] - if !ok { - e.handlers[event] = make([]eventHandler, 0) - } - eh := eventHandler{ctx, ch} - e.handlers[event] = append(e.handlers[event], eh) + e.handlers[event] = append(e.handlers[event], &eventHandler{ctx: ctx, ch: ch, queue: q}) } }) } @@ -209,6 +244,12 @@ func (e *BaseEventEmitter) on(ctx context.Context, events []string, ch chan Even // OnAll registers a handler for all events. func (e *BaseEventEmitter) onAll(ctx context.Context, ch chan Event) { e.sync(func() { - e.handlersAll = append(e.handlersAll, eventHandler{ctx, ch}) + q, ok := e.queues[ch] + if !ok { + q = &queue{} + e.queues[ch] = q + } + + e.handlersAll = append(e.handlersAll, &eventHandler{ctx: ctx, ch: ch, queue: q}) }) } diff --git a/common/event_emitter_test.go b/common/event_emitter_test.go index 2e37a8d9d..703832378 100644 --- a/common/event_emitter_test.go +++ b/common/event_emitter_test.go @@ -23,8 +23,10 @@ package common import ( "context" "testing" + "time" "github.com/chromedp/cdproto" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -125,3 +127,187 @@ func TestEventEmitterAllEvents(t *testing.T) { }) }) } + +//nolint:gocognit +func TestBaseEventEmitter(t *testing.T) { + t.Parallel() + + t.Run("order of emitted events kept", func(t *testing.T) { + // Test description + // + // 1. Emit many events from the emitWorker. + // 2. Handler receives the emitted events. + // + // Success criteria: Ensure that the ordering of events is + // received in the order they're emitted. + + t.Parallel() + + eventName := "AtomicIntEvent" + maxInt := 100 + + ctx, cancel := context.WithCancel(context.Background()) + emitter := NewBaseEventEmitter(ctx) + ch := make(chan Event) + emitter.on(ctx, []string{eventName}, ch) + + var expectedI int + handler := func() { + defer cancel() + + for expectedI != maxInt { + e := <-ch + + i, ok := e.data.(int) + if !ok { + assert.FailNow(t, "unexpected type read from channel", e.data) + } + + assert.Equal(t, eventName, e.typ) + assert.Equal(t, expectedI, i) + + expectedI++ + } + + close(ch) + } + go handler() + + emitWorker := func() { + for i := 0; i < maxInt; i++ { + emitter.emit(eventName, i) + } + } + go emitWorker() + + select { + case <-ctx.Done(): + case <-time.After(time.Second * 2): + assert.FailNow(t, "test timed out, deadlock?") + } + }) + + t.Run("order of emitted different event types kept", func(t *testing.T) { + // Test description + // + // 1. Emit many different event types from the emitWorker. + // 2. Handler receives the emitted events. + // + // Success criteria: Ensure that the ordering of events is + // received in the order they're emitted. + + t.Parallel() + + eventName1 := "AtomicIntEvent1" + eventName2 := "AtomicIntEvent2" + eventName3 := "AtomicIntEvent3" + eventName4 := "AtomicIntEvent4" + maxInt := 100 + + ctx, cancel := context.WithCancel(context.Background()) + emitter := NewBaseEventEmitter(ctx) + ch := make(chan Event) + // Calling on twice to ensure that the same queue is used + // internally for the same channel and handler. + emitter.on(ctx, []string{eventName1, eventName2}, ch) + emitter.on(ctx, []string{eventName3, eventName4}, ch) + + var expectedI int + handler := func() { + defer cancel() + + for expectedI != maxInt { + e := <-ch + + i, ok := e.data.(int) + if !ok { + assert.FailNow(t, "unexpected type read from channel", e.data) + } + + assert.Equal(t, expectedI, i) + + expectedI++ + } + + close(ch) + } + go handler() + + emitWorker := func() { + for i := 0; i < maxInt; i += 4 { + emitter.emit(eventName1, i) + emitter.emit(eventName2, i+1) + emitter.emit(eventName3, i+2) + emitter.emit(eventName4, i+3) + } + } + go emitWorker() + + select { + case <-ctx.Done(): + case <-time.After(time.Second * 2): + assert.FailNow(t, "test timed out, deadlock?") + } + }) + + t.Run("handler can emit without deadlocking", func(t *testing.T) { + // Test description + // + // 1. Emit many events from the emitWorker. + // 2. Handler receives emitted events (AtomicIntEvent1). + // 3. Handler emits event as AtomicIntEvent2. + // 4. Handler received emitted events again (AtomicIntEvent2). + // + // Success criteria: No deadlock should occur between receiving, + // emitting, and receiving of events. + + t.Parallel() + + eventName1 := "AtomicIntEvent1" + eventName2 := "AtomicIntEvent2" + maxInt := 100 + + ctx, cancel := context.WithCancel(context.Background()) + emitter := NewBaseEventEmitter(ctx) + ch := make(chan Event) + emitter.on(ctx, []string{eventName1, eventName2}, ch) + + var expectedI2 int + handler := func() { + defer cancel() + + for expectedI2 != maxInt { + e := <-ch + + switch e.typ { + case eventName1: + i, ok := e.data.(int) + if !ok { + assert.FailNow(t, "unexpected type read from channel", e.data) + } + emitter.emit(eventName2, i) + case eventName2: + expectedI2++ + default: + assert.FailNow(t, "unexpected event type received") + } + } + + close(ch) + } + go handler() + + emitWorker := func() { + for i := 0; i < maxInt; i++ { + emitter.emit(eventName1, i) + } + } + go emitWorker() + + select { + case <-ctx.Done(): + case <-time.After(time.Second * 2): + assert.FailNow(t, "test timed out, deadlock?") + } + }) +}