Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

timer: Add a NewTimer method to Clock #13

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions clock.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ func (c defaultClock) ContextWithTimeout(ctx context.Context, d time.Duration) (
return context.WithTimeout(ctx, d)
}

func (c defaultClock) NewTimer(d time.Duration) Timer {
t := time.NewTimer(d)
return &defaultTimer{Timer: t}
}

// DefaultClock returns a clock that minimally wraps the `time` package
func DefaultClock() Clock {
return defaultClock{}
Expand Down Expand Up @@ -103,4 +108,14 @@ type Clock interface {
// uses the clock to determine the when the timeout has elapsed. Cause is
// ignored in Go 1.20 and earlier.
ContextWithTimeoutCause(ctx context.Context, d time.Duration, cause error) (context.Context, context.CancelFunc)

// NewTimer returns a Timer implementation which will fire after at
// least the specified duration [d]. The Ch() method returns a channel,
// and should be called inline with the receive or select case.
//
// Timers are most useful in select/case blocks. For simple cases,
// SleepFor should be preferred.
//
// Stop() is inherently racy. Be wary of the return value.
NewTimer(d time.Duration) Timer
}
98 changes: 91 additions & 7 deletions fake/fake_clock.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
// testing and skipping through timestamps without having to actually sleep in
// the test.
type Clock struct {
mu sync.Mutex
current time.Time
// sleepers contains a map from a channel on which that
// sleeper is sleeping to a target-time. When time is advanced past a
Expand All @@ -28,9 +27,8 @@ type Clock struct {
// protection necessary).
cbsWG sync.WaitGroup

// cond is broadcasted() upon any sleep or wakeup event (mutations to
// sleepers or cbs).
cond sync.Cond
// timer tracker
timerTrack timerTracker

// counter tracking the number of wakeups (protected by mu).
wakeups int
Expand All @@ -51,6 +49,21 @@ type Clock struct {
// counter tracking the number of callbacks that have ever been
// registered (via AfterFunc) (protected by mu).
callbacksAggregate int

// counter tracking the number of extracted channels (protected by mu).
extractedChans int

// counter tracking the aggregate number of extracted channels (protected by mu).
extractedChansAggregate int

// counter tracking the number of number of aggregate signaled timer channels
signaledChans int

// cond is broadcasted() upon any sleep or wakeup event (mutations to
// sleepers or cbs).
cond sync.Cond

mu sync.Mutex
}

var _ clocks.Clock = (*Clock)(nil)
Expand All @@ -62,7 +75,11 @@ func NewClock(initialTime time.Time) *Clock {
sleepers: map[chan<- struct{}]time.Time{},
cbs: map[*stopTimer]time.Time{},
cond: sync.Cond{},
timerTrack: timerTracker{
timers: map[*fakeTimer]time.Time{},
},
}
fc.timerTrack.fc = &fc
fc.cond.L = &fc.mu
return &fc
}
Expand All @@ -77,6 +94,10 @@ func (f *Clock) setClockLocked(t time.Time, cbRunningWG *sync.WaitGroup) int {
awoken++
}
}

timerWakeRes := f.timerTrack.wakeup(t)
f.signaledChans += timerWakeRes.notified

cbsRun := 0
for s, target := range f.cbs {
if target.Sub(t) <= 0 {
Expand All @@ -95,7 +116,7 @@ func (f *Clock) setClockLocked(t time.Time, cbRunningWG *sync.WaitGroup) int {
f.callbackExecs += cbsRun
f.current = t
f.cond.Broadcast()
return awoken + cbsRun
return awoken + cbsRun + timerWakeRes.awoken
}

// SetClock skips the FakeClock to the specified time (forward or backwards) The
Expand Down Expand Up @@ -344,6 +365,22 @@ func (f *Clock) AfterFunc(d time.Duration, cb func()) clocks.StopTimer {
return s
}

// NewTimer creates a new Timer
func (f *Clock) NewTimer(d time.Duration) clocks.Timer {
target := f.Now().Add(d)
// Capacity 1 so sending never blocks
ch := make(chan time.Time, 1)

ft := fakeTimer{
ch: ch,
tracker: &f.timerTrack,
}

f.timerTrack.registerTimer(&ft, target)

return &ft
}

// NumCallbackExecs returns the number of registered callbacks that have been
// executed due to time advancement.
func (f *Clock) NumCallbackExecs() int {
Expand Down Expand Up @@ -396,8 +433,8 @@ func (f *Clock) AwaitRegisteredCallbacks(n int) {
}
}

// AwaitTimerAborts waits until the aggregate number of registered callbacks
// (via AfterFunc) exceeds its argument.
// AwaitTimerAborts waits until the aggregate number of aborted callbacks
// (via AfterFunc) or timers exceeds its argument.
func (f *Clock) AwaitTimerAborts(n int) {
f.mu.Lock()
defer f.mu.Unlock()
Expand All @@ -406,6 +443,53 @@ func (f *Clock) AwaitTimerAborts(n int) {
}
}

// AwaitAggExtractedChans waits the aggregate number of calls to Ch() on
// timers to equal or exceed its argument.
// For this method to be most useful, users of timers should not store the
// value of .Ch(). Instead, call .Ch(), dereference the pointer, and attempt a
// receive immediately, as in case <-*timer.Ch().
func (f *Clock) AwaitAggExtractedChans(n int) {
f.mu.Lock()
defer f.mu.Unlock()
for f.extractedChansAggregate < n {
f.cond.Wait()
}
}

// NumAggExtractedChans returns the aggregate number of calls to Ch() on
// timers.
// For this method to be most useful, users of timers should not store the
// value of .Ch(). Instead, call .Ch(), dereference the pointer, and attempt a
// receive immediately, as in case <-*timer.Ch().
func (f *Clock) NumAggExtractedChans() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.extractedChansAggregate
}

// numExtractedChans returns the aggregate number of calls to Ch() on
// timers.
func (f *Clock) numExtractedChans() int {
f.mu.Lock()
defer f.mu.Unlock()
return f.extractedChans
}

// awaitExtractedChans waits the number of calls to Ch() on
// timers to equal or exceed its argument.
func (f *Clock) awaitExtractedChans(n int) {
f.mu.Lock()
defer f.mu.Unlock()
for f.extractedChans < n {
f.cond.Wait()
}
}

// RegisteredTimers returns the execution-times of registered timers.
func (f *Clock) RegisteredTimers() []time.Time {
return f.timerTrack.registeredTimers()
}

// WaitAfterFuncs blocks until all currently running AfterFunc callbacks
// return.
func (f *Clock) WaitAfterFuncs() {
Expand Down
Loading