diff --git a/js/eventloop.go b/js/eventloop.go index 1d323245209..649b2ca6205 100644 --- a/js/eventloop.go +++ b/js/eventloop.go @@ -30,6 +30,7 @@ import ( type eventLoop struct { queueLock sync.Mutex queue []func() + started int wakeupCh chan struct{} // maybe use sync.Cond ? reservedCount int } @@ -60,10 +61,15 @@ func (e *eventLoop) RunOnLoop(f func()) { func (e *eventLoop) Reserve() func(func()) { e.queueLock.Lock() e.reservedCount++ + started := e.started e.queueLock.Unlock() return func(f func()) { e.queueLock.Lock() + if started != e.started { + e.queueLock.Unlock() + return + } e.queue = append(e.queue, f) e.reservedCount-- e.queueLock.Unlock() @@ -78,6 +84,10 @@ func (e *eventLoop) Reserve() func(func()) { // or the context is done //nolint:cyclop func (e *eventLoop) Start(ctx context.Context) { + e.queueLock.Lock() + e.started++ + e.reservedCount = 0 + e.queueLock.Unlock() done := ctx.Done() for { select { // check if done diff --git a/js/eventloop_test.go b/js/eventloop_test.go index 26487fc3e6c..96dbb5ce695 100644 --- a/js/eventloop_test.go +++ b/js/eventloop_test.go @@ -36,15 +36,15 @@ func TestBasicEventLoop(t *testing.T) { defer cancel() loop.RunOnLoop(func() { ran++ }) loop.Start(ctx) - require.Equal(t, ran, 1) + require.Equal(t, 1, ran) loop.RunOnLoop(func() { ran++ }) loop.RunOnLoop(func() { ran++ }) loop.Start(ctx) - require.Equal(t, ran, 3) + require.Equal(t, 3, ran) loop.RunOnLoop(func() { ran++; cancel() }) loop.RunOnLoop(func() { ran++ }) loop.Start(ctx) - require.Equal(t, ran, 4) + require.Equal(t, 4, ran) } func TestEventLoopReserve(t *testing.T) { @@ -66,6 +66,46 @@ func TestEventLoopReserve(t *testing.T) { start := time.Now() loop.Start(ctx) took := time.Since(start) - require.Equal(t, ran, 2) - require.Greater(t, took, time.Second) + require.Equal(t, 2, ran) + require.Less(t, time.Second, took) + require.Greater(t, time.Second+time.Millisecond*100, took) +} + +func TestEventLoopReserveStopBetweenStarts(t *testing.T) { + t.Parallel() + loop := newEventLoop() + var ran int + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + loop.RunOnLoop(func() { + ran++ + r := loop.Reserve() + go func() { + time.Sleep(time.Second) + r(func() { + ran++ + }) + }() + }) + go func() { + time.Sleep(200 * time.Millisecond) + cancel() + }() + loop.Start(ctx) + require.Equal(t, 1, ran) + + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + loop.RunOnLoop(func() { + ran++ + r := loop.Reserve() + go func() { + time.Sleep(time.Second) + r(func() { + ran++ + }) + }() + }) + loop.Start(ctx) + require.Equal(t, 3, ran) }