diff --git a/schedule.go b/schedule.go index b12c18a..331c69b 100644 --- a/schedule.go +++ b/schedule.go @@ -1,13 +1,18 @@ package schedule -import "time" +import ( + "sync" + "sync/atomic" + "time" +) // Task holds information about the running task and can be used to stop running tasks. type Task struct { stop chan struct{} nextExecution time.Time startedAt time.Time - stopped bool + stopped int32 // 0 means active, 1 means stopped + once sync.Once } // newTask creates a new Task. @@ -35,7 +40,7 @@ func (s *Task) ExecutesIn() time.Duration { // IsActive returns true if the scheduler is active. func (s *Task) IsActive() bool { - return !s.stopped + return atomic.LoadInt32(&s.stopped) == 0 } // Wait blocks until the scheduler is stopped. @@ -46,12 +51,10 @@ func (s *Task) Wait() { // Stop stops the scheduler. func (s *Task) Stop() { - if s.stopped { - return - } - - s.stopped = true - close(s.stop) + s.once.Do(func() { + atomic.StoreInt32(&s.stopped, 1) + close(s.stop) + }) } // After executes the task after the given duration. @@ -59,13 +62,18 @@ func (s *Task) Stop() { func After(duration time.Duration, task func()) *Task { scheduler := newTask() scheduler.nextExecution = time.Now().Add(duration) + timer := time.NewTimer(duration) go func() { select { - case <-time.After(duration): + case <-timer.C: task() scheduler.Stop() case <-scheduler.stop: + // If the task is stopped before the timer fires, stop the timer. + if !timer.Stop() { + <-timer.C // drain if necessary + } return } }() @@ -78,13 +86,21 @@ func After(duration time.Duration, task func()) *Task { func At(t time.Time, task func()) *Task { scheduler := newTask() scheduler.nextExecution = t + d := time.Until(t) + if d < 0 { + d = 0 + } + timer := time.NewTimer(d) go func() { select { - case <-time.After(time.Until(t)): + case <-timer.C: task() scheduler.Stop() case <-scheduler.stop: + if !timer.Stop() { + <-timer.C + } return } }() @@ -97,23 +113,20 @@ func At(t time.Time, task func()) *Task { func Every(interval time.Duration, task func() bool) *Task { scheduler := newTask() scheduler.nextExecution = time.Now().Add(interval) - ticker := time.NewTicker(interval) go func() { for { select { case <-ticker.C: - res := task() - if !res { + if !task() { scheduler.Stop() + ticker.Stop() + return } - scheduler.nextExecution = time.Now().Add(interval) - case <-scheduler.stop: ticker.Stop() - return } }