From b61467e13d9a0d1a7076a986a206f0ca612010c9 Mon Sep 17 00:00:00 2001 From: David Lavieri Date: Fri, 24 Oct 2025 23:12:02 +0200 Subject: [PATCH] feat: Add Close() method, refactor and fix tests --- .github/workflows/build.yaml | 2 +- krun.go | 82 ++--- krun_test.go | 573 ++++++++++------------------------- 3 files changed, 208 insertions(+), 449 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 0635a07..ac1c985 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -24,7 +24,7 @@ jobs: run: go mod download - name: "Run tests" - run: go test -v ./... + run: go test -count 10 -v ./... - name: "Build" run: go build -o ./bin/krun . diff --git a/krun.go b/krun.go index e47fcf6..ff27375 100644 --- a/krun.go +++ b/krun.go @@ -2,10 +2,14 @@ package krun import ( "context" + "errors" "sync" "time" ) +// ErrPoolClosed it is closed (hahah) +var ErrPoolClosed = errors.New("pool's closed") + type Result struct { Data interface{} Error error @@ -16,13 +20,17 @@ type Krun interface { Run(ctx context.Context, f Job) <-chan *Result Wait(ctx context.Context) Size() int + Close() error } type krun struct { - n int - waitSleep time.Duration - workers chan *worker - mu sync.RWMutex + poolSize int + closed bool + + workers chan *worker + mu sync.RWMutex + + wg sync.WaitGroup } type worker struct { job Job @@ -36,9 +44,12 @@ type Config struct { func New(cfg *Config) Krun { k := &krun{ - n: cfg.Size, - workers: make(chan *worker, cfg.Size), - waitSleep: cfg.WaitSleep, + poolSize: cfg.Size, + closed: false, + + workers: make(chan *worker, cfg.Size), + wg: sync.WaitGroup{}, + mu: sync.RWMutex{}, } for i := 0; i < cfg.Size; i++ { @@ -50,7 +61,7 @@ func New(cfg *Config) Krun { func (k *krun) Size() int { k.mu.RLock() - s := k.n + s := k.poolSize k.mu.RUnlock() return s } @@ -58,9 +69,10 @@ func (k *krun) Size() int { func (k *krun) Run(ctx context.Context, f Job) <-chan *Result { // get worker from the channel w := k.pop() + k.wg.Add(1) // assign Job to the worker and Run it - cr := make(chan *Result) + cr := make(chan *Result, 1) w.job = f w.result = cr go k.work(ctx, w) @@ -70,27 +82,37 @@ func (k *krun) Run(ctx context.Context, f Job) <-chan *Result { } func (k *krun) Wait(ctx context.Context) { - k.mu.RLock() - n := k.n - k.mu.RUnlock() + done := make(chan struct{}) + + go func() { + k.wg.Wait() + close(done) + }() - if k.len() == n { + select { + case <-ctx.Done(): + return + case <-done: return } +} - for { - select { - case <-ctx.Done(): - return - case <-time.After(k.waitSleep): - // "wait" until all workers are back - if k.len() < n { - continue - } - - return - } +func (k *krun) Close() error { + k.mu.Lock() + if k.closed { + k.mu.Unlock() + return ErrPoolClosed } + k.closed = true + k.mu.Unlock() + + // Wait for all work to complete + k.wg.Wait() + + // Close worker channel + close(k.workers) + + return nil } func (k *krun) work(ctx context.Context, w *worker) { @@ -98,11 +120,11 @@ func (k *krun) work(ctx context.Context, w *worker) { d, err := w.job(ctx) // send Result into the caller channel - // this will block until is read w.result <- &Result{d, err} // return worker to Krun k.push(w) + k.wg.Done() } func (k *krun) push(w *worker) { k.workers <- w @@ -111,11 +133,3 @@ func (k *krun) push(w *worker) { func (k *krun) pop() *worker { return <-k.workers } - -func (k *krun) len() int { - k.mu.RLock() - l := len(k.workers) - k.mu.RUnlock() - - return l -} diff --git a/krun_test.go b/krun_test.go index b55a247..3fe07e6 100644 --- a/krun_test.go +++ b/krun_test.go @@ -3,7 +3,6 @@ package krun import ( "context" "errors" - "sync" "testing" "time" ) @@ -14,32 +13,12 @@ func TestNew(t *testing.T) { t.Run("returns a Krun", func(t *testing.T) { k := New(&Config{}) if k == nil { - t.Errorf("Expected Krun, got nil") + t.Fatalf("Expected Krun, got nil") } v, ok := k.(*krun) if !ok { - t.Errorf("Expected *krun, got %T", v) - } - }) - - t.Run("returns a Krun with the correct size", func(t *testing.T) { - k := New(&Config{Size: 5}).(*krun) - - if k.n != 5 { - t.Errorf("Expected 5, got %v", k.Size()) - } - - if len(k.workers) != 5 { - t.Errorf("Expected 5, got %v", len(k.workers)) - } - }) - - t.Run("returns a Krun with the correct waitSleep", func(t *testing.T) { - k := New(&Config{WaitSleep: time.Second}).(*krun) - - if k.waitSleep != time.Second { - t.Errorf("Expected 1s, got %v", k.waitSleep) + t.Fatalf("Expected *krun, got %T", v) } }) } @@ -48,12 +27,16 @@ func TestKrun_Size(t *testing.T) { t.Parallel() t.Run("returns the correct size", func(t *testing.T) { - k := krun{ - n: 5, - } + k := New(&Config{Size: 5}) if k.Size() != 5 { - t.Errorf("Expected 5, got %v", k.Size()) + t.Fatalf("Expected 5, got %v", k.Size()) + } + + k = New(&Config{Size: 2}) + + if k.Size() != 2 { + t.Fatalf("Expected 2, got %v", k.Size()) } }) } @@ -62,33 +45,65 @@ func TestRun(t *testing.T) { t.Parallel() t.Run("returns a channel", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } - k.push(&worker{}) + k := New(&Config{Size: 1}) r := k.Run(context.Background(), func(ctx context.Context) (interface{}, error) { return nil, nil }) if r == nil { - t.Errorf("Expected channel, got nil") + t.Fatalf("Expected channel, got nil") } }) - t.Run("blocks waiting for available worker", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, + t.Run("runs and return value", func(t *testing.T) { + k := New(&Config{Size: 1}) + + r := <-k.Run(context.Background(), func(ctx context.Context) (interface{}, error) { + return "my-string", nil + }) + + switch tp := r.Data.(type) { + case string: + if tp != "my-string" { + t.Fatalf("expected \"my-string\", received: %s", tp) + } + + break + default: + t.Fatalf("expected string, got %t", tp) + } + }) + + t.Run("runs and return error", func(t *testing.T) { + k := New(&Config{Size: 1}) + + myErr := errors.New("something went wrong") + + r := <-k.Run(context.Background(), func(ctx context.Context) (interface{}, error) { + return nil, myErr + }) + + if r.Error == nil { + t.Fatalf("expected an error, got nil") + } else if !errors.Is(r.Error, myErr) { + t.Fatalf("expected error to equal: %v, got: %v", myErr, r.Error) } + }) + + t.Run("blocks waiting for available worker", func(t *testing.T) { + ctx := t.Context() + k := New(&Config{Size: 1}) errChan := make(chan error) + k.Run(ctx, func(ctx context.Context) (interface{}, error) { + time.Sleep(time.Millisecond * 5) + return nil, nil + }) + go func() { - k.Run(context.Background(), func(ctx context.Context) (interface{}, error) { + k.Run(ctx, func(ctx context.Context) (interface{}, error) { return nil, nil }) errChan <- errors.New("expected k.Run to block") @@ -96,72 +111,28 @@ func TestRun(t *testing.T) { select { case e := <-errChan: - t.Errorf(e.Error()) + t.Fatalf(e.Error()) case <-time.After(time.Millisecond): return } }) - t.Run("sends off the job for work and return result", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } - k.push(&worker{}) - - r := k.Run(context.Background(), func(ctx context.Context) (interface{}, error) { - return 5, nil - }) - - res := <-r - if i, ok := res.Data.(int); !ok { - t.Errorf("Expected int, got %T", res.Data) - } else if i != 5 { - t.Errorf("Expected 5, got %v", i) - } - }) - t.Run("pass context to job", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } - k.push(&worker{}) + k := New(&Config{Size: 1}) ctx := context.Background() r := k.Run(ctx, func(ctx2 context.Context) (interface{}, error) { if ctx != ctx2 { - t.Errorf("Expected %v, got %v", ctx, ctx2) + t.Fatalf("Expected %v, got %v", ctx, ctx2) } return nil, nil }) if d := <-r; d.Error != nil { - t.Errorf("Expected nil, got %v", d.Error) + t.Fatalf("Expected nil, got %v", d.Error) } else if d.Data != nil { - t.Errorf("Expected nil, got %v", d.Data) - } - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - r = k.Run(ctx, func(ctx2 context.Context) (interface{}, error) { - select { - case <-ctx2.Done(): - return nil, ctx2.Err() - default: - t.Errorf("Expected context to be cancelled") - return nil, nil - } - }) - - if d := <-r; d.Error == nil { - t.Errorf("Expected error, got nil") - } else if !errors.Is(d.Error, context.Canceled) { - t.Errorf("Expected context.Canceled, got %v", d.Error) + t.Fatalf("Expected nil, got %v", d.Data) } }) } @@ -170,385 +141,159 @@ func TestKrun_Wait(t *testing.T) { t.Parallel() t.Run("blocks if workers are not done", func(t *testing.T) { - k := krun{ - waitSleep: time.Millisecond * 50, - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } + k := New(&Config{Size: 1}) ctx := context.Background() - wChan := make(chan struct{}) - // len 1, has no workers + release := make(chan struct{}) + resCh := k.Run(ctx, func(ctx context.Context) (interface{}, error) { + <-release + return "ok", nil + }) + + doneWait := make(chan struct{}) go func() { k.Wait(ctx) - wChan <- struct{}{} + close(doneWait) }() select { - case <-wChan: - t.Errorf("Expected k.Wait to block") - case <-time.After(time.Millisecond * 100): - break + case <-doneWait: + t.Fatalf("expected Wait to block while job is running") + case <-time.After(5 * time.Millisecond): } - }) - - t.Run("unblocks when workers are done", func(t *testing.T) { - k := krun{ - waitSleep: time.Millisecond * 50, - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } - ctx := context.Background() - wChan := make(chan struct{}) - // len 1, work is "done" - k.workers <- &worker{} - go func() { - k.Wait(ctx) - wChan <- struct{}{} - }() + close(release) - time.Sleep(time.Millisecond) + select { + case <-resCh: + case <-time.After(50 * time.Millisecond): + t.Fatalf("job did not complete in time") + } select { - case <-wChan: - break - case <-time.After(time.Millisecond): - t.Errorf("Expected k.Wait to unblock") + case <-doneWait: + case <-time.After(20 * time.Millisecond): + t.Fatalf("Wait did not return after jobs finished") } }) - t.Run("sleep between checks", func(t *testing.T) { - k := krun{ - waitSleep: time.Millisecond * 50, - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } + t.Run("unblocks when workers are done", func(t *testing.T) { + k := New(&Config{Size: 1}) ctx := context.Background() - wChan := make(chan struct{}) - // len 1, work is "done" - - var start time.Time - var end time.Time + doneWait := make(chan struct{}) go func() { - time.Sleep(time.Millisecond * 60) - k.workers <- &worker{} - }() - - go func() { - start = time.Now() k.Wait(ctx) - end = time.Now() - wChan <- struct{}{} + close(doneWait) }() - time.Sleep(time.Millisecond) - select { - case <-wChan: - if end.Sub(start) < time.Millisecond*50 { - t.Errorf("Expected k.Wait to sleep for at least 50ms") - } - case <-time.After(time.Millisecond * 100): - t.Errorf("Expected k.Wait to unblock") + case <-doneWait: + // ok + case <-time.After(5 * time.Millisecond): + t.Fatalf("expected Wait to return promptly when there are no jobs running") } }) - t.Run("context done unblock", func(t *testing.T) { - k := krun{ - waitSleep: time.Millisecond * 50, - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } + t.Run("unblocks when context time out", func(t *testing.T) { + k := New(&Config{Size: 1}) - var start time.Time - var end time.Time - c := make(chan struct{}) + release := make(chan struct{}) + _ = k.Run(context.Background(), func(ctx context.Context) (interface{}, error) { + <-release + return nil, nil + }) - go func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) - defer cancel() + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() - start = time.Now() + doneWait := make(chan struct{}) + go func() { k.Wait(ctx) - end = time.Now() - - c <- struct{}{} + close(doneWait) }() select { - case <-c: - if end.Sub(start) < time.Millisecond*100 { - t.Errorf("Expected k.Wait to sleep for at least 100ms") + case <-doneWait: + if time.Since(start) < 5*time.Millisecond { + t.Fatalf("expected Wait to respect context timeout; returned too early after %s", time.Since(start)) } - - break - case <-time.After(time.Millisecond * 200): - t.Errorf("Expected k.Wait to unblock") - } - }) -} - -func TestKrun_Work(t *testing.T) { - t.Parallel() - - t.Run("data is passed through", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } - - // data is passed through - rChan := make(chan *Result) - w := &worker{ - job: func(ctx context.Context) (interface{}, error) { - return 5, nil - }, - result: rChan, + case <-time.After(20 * time.Millisecond): + t.Fatalf("expected Wait to unblock due to context timeout") } - go k.work(context.Background(), w) - - r := <-rChan - if i, ok := r.Data.(int); !ok { - t.Errorf("Expected int, got %T", r.Data) - } else if i != 5 { - t.Errorf("Expected 5, got %v", i) - } - close(rChan) - - _ = <-k.workers - close(k.workers) + close(release) + k.Wait(context.Background()) }) +} - t.Run("error is passed through", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } - - // error is passed through - errChan := make(chan *Result) - testErr := errors.New("test error") - w := &worker{ - job: func(ctx context.Context) (interface{}, error) { - return nil, testErr - }, - result: errChan, - } - - go k.work(context.Background(), w) - - r := <-errChan - if !errors.Is(r.Error, testErr) { - t.Errorf("Expected %v, got %v", testErr, r.Error) +func TestKrun_Close(t *testing.T) { + t.Run("first close returns nil", func(t *testing.T) { + k := New(&Config{Size: 1}) + if err := k.Close(); err != nil { + t.Fatalf("expected nil, got %v", err) } - close(errChan) - - _ = <-k.workers - close(k.workers) }) - t.Run("context is passed through", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } - - // context carries info given to it - ctxChan := make(chan *Result) - ctx := context.WithValue(context.Background(), "test", "test") - w := &worker{ - job: func(ctx context.Context) (interface{}, error) { - return ctx.Value("test"), nil - }, - result: ctxChan, + t.Run("double close returns ErrPoolClosed", func(t *testing.T) { + k := New(&Config{Size: 1}) + if err := k.Close(); err != nil { + t.Fatalf("expected first close to be nil, got %v", err) } - - go k.work(ctx, w) - - r := <-ctxChan - if r.Data != "test" { - t.Errorf("Expected test, got %v", r.Data) + if err := k.Close(); !errors.Is(err, ErrPoolClosed) { + t.Fatalf("expected ErrPoolClosed, got %v", err) } - close(ctxChan) - - _ = <-k.workers - close(k.workers) }) - t.Run("worker is pushed back to the channel", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, - } - - // worker is pushed back to the channel - w := &worker{ - job: func(ctx context.Context) (interface{}, error) { - return nil, nil - }, - result: make(chan *Result), - } - - go k.work(context.Background(), w) - - if len(k.workers) != 0 { - t.Errorf("Expected 0, got %v", len(k.workers)) - } - - _ = <-w.result - - time.Sleep(time.Millisecond) + t.Run("waits for running jobs to finish", func(t *testing.T) { + k := New(&Config{Size: 1}) - if len(k.workers) != 1 { - t.Errorf("Expected 1, got %v", len(k.workers)) - } - - bw := <-k.workers - if w != bw { - t.Errorf("Expected %p, got %p", w, w) - } - - close(k.workers) - }) -} - -func TestKrun_Push(t *testing.T) { - t.Parallel() - - t.Run("pushes worker to channel", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 3), - mu: sync.RWMutex{}, - n: 3, - } - workers := []*worker{ - &worker{}, - &worker{}, - &worker{}, - } - - k.push(workers[0]) - k.push(workers[1]) - k.push(workers[2]) - - if len(k.workers) != 3 { - t.Errorf("Expected 3, got %v", len(k.workers)) - } - - for i := 0; i < 3; i++ { - w := <-k.workers - var found bool - - for _, expect := range workers { - if w == expect { - found = true - break - } - } - - if !found { - t.Errorf("Expected worker to be in workers") - } - } - - close(k.workers) - }) + ctx := context.Background() + started := make(chan struct{}, 1) + resCh := k.Run(ctx, func(ctx context.Context) (interface{}, error) { + started <- struct{}{} + time.Sleep(10 * time.Millisecond) + return "done", nil + }) - t.Run("blocks until worker is pulled", func(t *testing.T) { - k := krun{ - workers: make(chan *worker, 1), - mu: sync.RWMutex{}, - n: 1, + // Ensure job actually started + select { + case <-started: + case <-time.After(5 * time.Millisecond): + t.Fatalf("job did not start in time") } - errChan := make(chan error) - + doneClose := make(chan error, 1) + startClose := time.Now() go func() { - k.push(&worker{}) - k.push(&worker{}) - errChan <- errors.New("pushed more than 1 worker") + doneClose <- k.Close() }() + // get job result before its closed + var got *Result select { - case <-errChan: - t.Errorf("Expected goroutine to be blocked") - case <-time.After(time.Millisecond): - defer close(errChan) + case got = <-resCh: + // ok + case <-time.After(30 * time.Millisecond): + t.Fatalf("timed out waiting for job result while closing") } - }) -} - -func TestKrun_Pop(t *testing.T) { - t.Parallel() - - k := krun{ - workers: make(chan *worker, 3), - mu: sync.RWMutex{}, - n: 3, - } - workers := []*worker{ - &worker{}, - &worker{}, - &worker{}, - } - - for _, w := range workers { - k.workers <- w - } - - for i := 0; i < 3; i++ { - w := k.pop() - var found bool - - for _, expect := range workers { - if w == expect { - found = true - break - } + if got == nil || got.Error != nil || got.Data != "done" { + t.Fatalf("unexpected result: %#v", got) } - if !found { - t.Errorf("Expected worker to be in workers") - } - } -} - -func TestKrun_Len(t *testing.T) { - t.Parallel() - - k := krun{ - workers: make(chan *worker, 3), - mu: sync.RWMutex{}, - n: 3, - } - workers := []*worker{ - &worker{}, - &worker{}, - &worker{}, - } - - if k.len() != 0 { - t.Errorf("Expected 0, got %v", k.len()) - } - - for i, w := range workers { - k.workers <- w - - if k.len() != i+1 { - t.Errorf("Expected %d, got %v", i+1, k.len()) + // finish shortly after the jobs done + var closeErr error + select { + case closeErr = <-doneClose: + if closeErr != nil { + t.Fatalf("expected close nil, got %v", closeErr) + } + if time.Since(startClose) < 9*time.Millisecond { + t.Fatalf("Close returned too early; expected it to wait for the job") + } + case <-time.After(50 * time.Millisecond): + t.Fatalf("Close did not complete in time") } - } - - close(k.workers) + }) }