diff --git a/bucket.go b/bucket.go index 5c297bf..d31b121 100644 --- a/bucket.go +++ b/bucket.go @@ -131,3 +131,16 @@ func (b *bucket) Flush(reinsert func(*Timer)) { b.SetExpiration(-1) } + +// Clear timer list +func (b *bucket) Clear() { + b.mu.Lock() + defer b.mu.Unlock() + b.timers = b.timers.Init() +} + +func (b *bucket) Len() int { + b.mu.Lock() + defer b.mu.Unlock() + return b.timers.Len() +} diff --git a/delayqueue/delayqueue.go b/delayqueue/delayqueue.go index 82e6a9d..7a07b1f 100644 --- a/delayqueue/delayqueue.go +++ b/delayqueue/delayqueue.go @@ -80,6 +80,13 @@ func (pq *priorityQueue) PeekAndShift(max int64) (*item, int64) { return item, 0 } +// Clear the priorityQueue +func (pq *priorityQueue) Clear() { + if pq.Len() != 0 { + *pq = nil + } +} + // The end of PriorityQueue implementation. // DelayQueue is an unbounded blocking queue of *Delayed* elements, in which @@ -184,3 +191,13 @@ exit: // Reset the states atomic.StoreInt32(&dq.sleeping, 0) } + +func (dq *DelayQueue) Len() int { + return dq.pq.Len() +} + +func (dq *DelayQueue) Clear() { + dq.mu.Lock() + dq.pq.Clear() + dq.mu.Unlock() +} diff --git a/timingwheel.go b/timingwheel.go index e2c64e0..830831e 100644 --- a/timingwheel.go +++ b/timingwheel.go @@ -2,6 +2,7 @@ package timingwheel import ( "errors" + "sync" "sync/atomic" "time" "unsafe" @@ -16,6 +17,7 @@ type TimingWheel struct { interval int64 // in milliseconds currentTime int64 // in milliseconds + mu sync.RWMutex buckets []*bucket queue *delayqueue.DelayQueue @@ -71,7 +73,13 @@ func (tw *TimingWheel) add(t *Timer) bool { } else if t.expiration < currentTime+tw.interval { // Put it into its own bucket virtualID := t.expiration / tw.tick + tw.mu.RLock() + if tw.buckets == nil { + tw.mu.RUnlock() + return false + } b := tw.buckets[virtualID%tw.wheelSize] + tw.mu.RUnlock() b.Add(t) // Set the bucket expiration time @@ -109,6 +117,9 @@ func (tw *TimingWheel) add(t *Timer) bool { // addOrRun inserts the timer t into the current timing wheel, or run the // timer's task if it has already expired. func (tw *TimingWheel) addOrRun(t *Timer) { + if tw.IsStopped() { + return + } if !tw.add(t) { // Already expired @@ -160,8 +171,41 @@ func (tw *TimingWheel) Start() { // not wait for the task to complete before returning. If the caller needs to // know whether the task is completed, it must coordinate with the task explicitly. func (tw *TimingWheel) Stop() { + if tw.IsStopped() { + return + } close(tw.exitC) tw.waitGroup.Wait() + tw.clear() +} + +func (tw *TimingWheel) clear() { + tw.queue.Clear() + tw.mu.Lock() + for _, b := range tw.buckets { + b.Clear() + } + tw.buckets = nil + tw.mu.Unlock() + // Try to clear the overflow wheel if present + overflowWheel := atomic.LoadPointer(&tw.overflowWheel) + if overflowWheel != nil { + (*TimingWheel)(overflowWheel).clear() + } +} + +func (tw *TimingWheel) Len() int { + l := 0 + tw.mu.Lock() + for i := 0; i < len(tw.buckets); i++ { + l += tw.buckets[i].Len() + } + tw.mu.Unlock() + overflowWheel := atomic.LoadPointer(&tw.overflowWheel) + if overflowWheel != nil { + l += (*TimingWheel)(overflowWheel).Len() + } + return l } // AfterFunc waits for the duration to elapse and then calls f in its own goroutine. @@ -224,3 +268,12 @@ func (tw *TimingWheel) ScheduleFunc(s Scheduler, f func()) (t *Timer) { return } + +func (tw *TimingWheel) IsStopped() bool { + select { + case <-tw.exitC: + return true + default: + } + return false +} diff --git a/timingwheel_benchmark_test.go b/timingwheel_benchmark_test.go index f9bd7f1..58927f8 100644 --- a/timingwheel_benchmark_test.go +++ b/timingwheel_benchmark_test.go @@ -72,3 +72,19 @@ func BenchmarkStandardTimer_StartStop(b *testing.B) { }) } } + +func BenchmarkTimingWheel_KeepStartStop(b *testing.B) { + var tw *timingwheel.TimingWheel + for j := 0; j < 10; j++ { + b.ResetTimer() + tw = timingwheel.NewTimingWheel(1*time.Minute, 20) + tw.Start() + l := 100 + for i := 0; i < l; i++ { + tw.AfterFunc(time.Duration(i+1)*time.Minute, func() { + }) + } + tw.Stop() + b.StopTimer() + } +} diff --git a/timingwheel_test.go b/timingwheel_test.go index 8f97403..0503010 100644 --- a/timingwheel_test.go +++ b/timingwheel_test.go @@ -89,3 +89,101 @@ func TestTimingWheel_ScheduleFunc(t *testing.T) { } } } + +func TestTimingWheel_IsStopped(t *testing.T) { + tw := timingwheel.NewTimingWheel(time.Millisecond, 20) + tw.Start() + if tw.IsStopped() { + t.Errorf("IsStopped() = true before stop") + } + tw.Stop() + if !tw.IsStopped() { + t.Errorf("IsStopped() = false after stop") + } + // test stop 2 times + tw.Stop() +} + +func TestTimingWheel_Len(t *testing.T) { + type fields struct { + tw *timingwheel.TimingWheel + len int + } + tests := []struct { + name string + fields fields + want int + }{ + { + name: "", + fields: fields{ + tw: timingwheel.NewTimingWheel(1*time.Millisecond, 20), + len: 0, + }, + want: 0, + }, + { + name: "", + fields: fields{ + tw: timingwheel.NewTimingWheel(1*time.Second, 20), + len: 100, + }, + want: 100, + }, + { + name: "", + fields: fields{ + tw: timingwheel.NewTimingWheel(1*time.Minute, 20), + len: 100, + }, + want: 100, + }, + { + name: "", + fields: fields{ + tw: timingwheel.NewTimingWheel(1*time.Minute, 20), + len: 10000, + }, + want: 10000, + }, + { + name: "", + fields: fields{ + tw: timingwheel.NewTimingWheel(1*time.Minute, 200), + len: 100, + }, + want: 100, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tw := tt.fields.tw + tw.Start() + defer tw.Stop() + for i := 0; i < tt.fields.len; i++ { + tw.AfterFunc(time.Duration(i+1)*time.Minute, func() { + }) + } + if got := tw.Len(); got != tt.want { + t.Errorf("Len() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTimingWheel_clear(t *testing.T) { + tw := timingwheel.NewTimingWheel(1*time.Minute, 20) + tw.Start() + l := 10000 + for i := 0; i < l; i++ { + tw.AfterFunc(time.Duration(i+1)*time.Minute, func() { + }) + } + if tw.Len() != l { + t.Errorf("add events fail") + } + tw.Stop() + if tw.Len() != 0 { + t.Errorf("clear events fail. tw.Len(): %d", tw.Len()) + } +}