diff --git a/semaphore.go b/semaphore.go index 835fb48..236b3ae 100644 --- a/semaphore.go +++ b/semaphore.go @@ -77,13 +77,16 @@ func (s *semaphore) Acquire(ctx context.Context, n int) error { if n <= 0 { panic("n must be positive number") } + var ctxDoneCh <-chan struct{} + if ctx != nil { + ctxDoneCh = ctx.Done() + } for { - if ctx != nil { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } + // check if context is done + select { + case <-ctxDoneCh: + return ctx.Err() + default: } // get current semaphore count and limit @@ -108,18 +111,12 @@ func (s *semaphore) Acquire(ctx context.Context, n int) error { broadcastCh := s.broadcastCh s.lock.RUnlock() - if ctx != nil { - select { - case <-ctx.Done(): - return ctx.Err() - // waiting for broadcast signal - case <-broadcastCh: - } - } else { - select { - // waiting for broadcast signal - case <-broadcastCh: - } + select { + // check if context is done + case <-ctxDoneCh: + return ctx.Err() + // waiting for broadcast signal + case <-broadcastCh: } } }