From 44a1363046388d90f33f1390854acfb31674bda3 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Tue, 10 Mar 2026 14:39:28 +0000 Subject: [PATCH 1/4] Fix samples to use OrchestratorOptions struct The NewOrchestrationWorker signature was changed to take an OrchestratorOptions struct but the samples were not updated. Signed-off-by: joshvanl --- samples/distributedtracing/distributedtracing.go | 6 +++++- samples/externalevents/externalevents.go | 6 +++++- samples/parallel/parallel.go | 6 +++++- samples/retries/retries.go | 6 +++++- samples/sequence/sequence.go | 6 +++++- samples/taskexecutionid/taskexecutionid.go | 6 +++++- 6 files changed, 30 insertions(+), 6 deletions(-) diff --git a/samples/distributedtracing/distributedtracing.go b/samples/distributedtracing/distributedtracing.go index 1701dc43..9d3a6538 100644 --- a/samples/distributedtracing/distributedtracing.go +++ b/samples/distributedtracing/distributedtracing.go @@ -76,7 +76,11 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac // Create a new backend // Use the in-memory sqlite provider by specifying "" be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger) - orchestrationWorker := backend.NewOrchestrationWorker(be, executor, logger) + orchestrationWorker := backend.NewOrchestrationWorker(backend.OrchestratorOptions{ + Backend: be, + Executor: executor, + Logger: logger, + }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) diff --git a/samples/externalevents/externalevents.go b/samples/externalevents/externalevents.go index 4baa9a83..3960e760 100644 --- a/samples/externalevents/externalevents.go +++ b/samples/externalevents/externalevents.go @@ -67,7 +67,11 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac // Create a new backend // Use the in-memory sqlite provider by specifying "" be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger) - orchestrationWorker := backend.NewOrchestrationWorker(be, executor, logger) + orchestrationWorker := backend.NewOrchestrationWorker(backend.OrchestratorOptions{ + Backend: be, + Executor: executor, + Logger: logger, + }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) diff --git a/samples/parallel/parallel.go b/samples/parallel/parallel.go index d12f7904..b950968f 100644 --- a/samples/parallel/parallel.go +++ b/samples/parallel/parallel.go @@ -59,7 +59,11 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac // Create a new backend // Use the in-memory sqlite provider by specifying "" be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger) - orchestrationWorker := backend.NewOrchestrationWorker(be, executor, logger) + orchestrationWorker := backend.NewOrchestrationWorker(backend.OrchestratorOptions{ + Backend: be, + Executor: executor, + Logger: logger, + }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) diff --git a/samples/retries/retries.go b/samples/retries/retries.go index 6df30e17..45fdc7fa 100644 --- a/samples/retries/retries.go +++ b/samples/retries/retries.go @@ -57,7 +57,11 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac // Create a new backend // Use the in-memory sqlite provider by specifying "" be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger) - orchestrationWorker := backend.NewOrchestrationWorker(be, executor, logger) + orchestrationWorker := backend.NewOrchestrationWorker(backend.OrchestratorOptions{ + Backend: be, + Executor: executor, + Logger: logger, + }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) diff --git a/samples/sequence/sequence.go b/samples/sequence/sequence.go index bd2721a7..ec940900 100644 --- a/samples/sequence/sequence.go +++ b/samples/sequence/sequence.go @@ -55,7 +55,11 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac // Create a new backend // Use the in-memory sqlite provider by specifying "" be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger) - orchestrationWorker := backend.NewOrchestrationWorker(be, executor, logger) + orchestrationWorker := backend.NewOrchestrationWorker(backend.OrchestratorOptions{ + Backend: be, + Executor: executor, + Logger: logger, + }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) diff --git a/samples/taskexecutionid/taskexecutionid.go b/samples/taskexecutionid/taskexecutionid.go index 7ea894e0..5a329b3d 100644 --- a/samples/taskexecutionid/taskexecutionid.go +++ b/samples/taskexecutionid/taskexecutionid.go @@ -61,7 +61,11 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac // Create a new backend // Use the in-memory sqlite provider by specifying "" be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger) - orchestrationWorker := backend.NewOrchestrationWorker(be, executor, logger) + orchestrationWorker := backend.NewOrchestrationWorker(backend.OrchestratorOptions{ + Backend: be, + Executor: executor, + Logger: logger, + }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) From 24e9405f4e8fbd067d4869d76359a6b8fd4aba11 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Tue, 10 Mar 2026 14:42:02 +0000 Subject: [PATCH 2/4] Add loop event types and handlers Introduces the event-driven loop infrastructure that will be used by the worker and executor refactors: - backend/loops: EventWorker and EventExecutor marker interfaces with concrete event types (DispatchWorkItem, Shutdown, ExecuteOrchestrator, ExecuteActivity, ConnectStream, DisconnectStream, etc.) - backend/loops/worker: Handler that processes dispatched work items inline within a loop, calling Process/Complete/Abandon on the processor. - backend/loops/executor: Handler that manages gRPC streams and dispatches orchestrator/activity work items to connected clients. - backend/local/loops: EventTask types and handler for the local tasks backend, replacing sync.Map-based pending task tracking with serialized loop processing. Branched from https://github.com/dapr/durabletask-go/pull/72 Signed-off-by: joshvanl --- backend/local/loops/loops.go | 64 +++++++ backend/local/loops/task/task.go | 100 +++++++++++ backend/loops/executor/executor.go | 265 +++++++++++++++++++++++++++++ backend/loops/loops.go | 74 ++++++++ backend/loops/worker/worker.go | 83 +++++++++ go.mod | 28 +-- go.sum | 75 ++++---- 7 files changed, 642 insertions(+), 47 deletions(-) create mode 100644 backend/local/loops/loops.go create mode 100644 backend/local/loops/task/task.go create mode 100644 backend/loops/executor/executor.go create mode 100644 backend/loops/loops.go create mode 100644 backend/loops/worker/worker.go diff --git a/backend/local/loops/loops.go b/backend/local/loops/loops.go new file mode 100644 index 00000000..b54c6448 --- /dev/null +++ b/backend/local/loops/loops.go @@ -0,0 +1,64 @@ +package loops + +import ( + "github.com/dapr/durabletask-go/api" + "github.com/dapr/durabletask-go/api/protos" +) + +type taskbase struct{} + +func (*taskbase) isEventTask() {} + +// EventTask is the marker interface for local task backend events. +type EventTask interface{ isEventTask() } + +// RegisterPendingOrchestrator registers a new pending orchestrator waiter. +type RegisterPendingOrchestrator struct { + *taskbase + InstanceID string + Response chan<- *protos.OrchestratorResponse +} + +// CompleteOrchestrator signals completion of an orchestrator task. +type CompleteOrchestrator struct { + *taskbase + InstanceID string + Response *protos.OrchestratorResponse + ErrCh chan<- error +} + +// CancelOrchestrator cancels a pending orchestrator task. +type CancelOrchestrator struct { + *taskbase + InstanceID api.InstanceID + ErrCh chan<- error +} + +// RegisterPendingActivity registers a new pending activity waiter. +type RegisterPendingActivity struct { + *taskbase + Key string + Response chan<- *protos.ActivityResponse +} + +// CompleteActivity signals completion of an activity task. +type CompleteActivity struct { + *taskbase + InstanceID string + TaskID int32 + Response *protos.ActivityResponse + ErrCh chan<- error +} + +// CancelActivity cancels a pending activity task. +type CancelActivity struct { + *taskbase + InstanceID api.InstanceID + TaskID int32 + ErrCh chan<- error +} + +// Shutdown signals the loop to stop. +type Shutdown struct { + *taskbase +} diff --git a/backend/local/loops/task/task.go b/backend/local/loops/task/task.go new file mode 100644 index 00000000..bba983ee --- /dev/null +++ b/backend/local/loops/task/task.go @@ -0,0 +1,100 @@ +package task + +import ( + "context" + "fmt" + + "github.com/dapr/durabletask-go/api" + "github.com/dapr/durabletask-go/api/protos" + "github.com/dapr/durabletask-go/backend" + "github.com/dapr/durabletask-go/backend/local/loops" + "github.com/dapr/kit/events/loop" +) + +type pendingOrchestrator struct { + response chan<- *protos.OrchestratorResponse +} + +type pendingActivity struct { + response chan<- *protos.ActivityResponse +} + +type task struct { + pendingOrchestrators map[string]*pendingOrchestrator + pendingActivities map[string]*pendingActivity +} + +// New creates a new task backend loop. +func New() loop.Interface[loops.EventTask] { + return loop.New[loops.EventTask](64).NewLoop(&task{ + pendingOrchestrators: make(map[string]*pendingOrchestrator), + pendingActivities: make(map[string]*pendingActivity), + }) +} + +// Handle implements loop.Handler[loops.EventTask]. +func (t *task) Handle(_ context.Context, event loops.EventTask) error { + switch e := event.(type) { + case *loops.RegisterPendingOrchestrator: + t.pendingOrchestrators[e.InstanceID] = &pendingOrchestrator{ + response: e.Response, + } + case *loops.CompleteOrchestrator: + e.ErrCh <- t.completeOrchestrator(e.InstanceID, e.Response) + case *loops.CancelOrchestrator: + e.ErrCh <- t.completeOrchestrator(string(e.InstanceID), nil) + case *loops.RegisterPendingActivity: + t.pendingActivities[e.Key] = &pendingActivity{ + response: e.Response, + } + case *loops.CompleteActivity: + key := backend.GetActivityExecutionKey(e.InstanceID, e.TaskID) + e.ErrCh <- t.completeActivity(key, e.Response) + case *loops.CancelActivity: + key := backend.GetActivityExecutionKey(string(e.InstanceID), e.TaskID) + e.ErrCh <- t.completeActivity(key, nil) + case *loops.Shutdown: + t.shutdown() + default: + return fmt.Errorf("unexpected event type %T", event) + } + + return nil +} + +func (t *task) completeOrchestrator(instanceID string, res *protos.OrchestratorResponse) error { + pending, ok := t.pendingOrchestrators[instanceID] + if !ok { + return api.NewUnknownInstanceIDError(instanceID) + } + delete(t.pendingOrchestrators, instanceID) + + // Send response (nil means cancelled) and close the channel. + pending.response <- res + close(pending.response) + return nil +} + +func (t *task) completeActivity(key string, res *protos.ActivityResponse) error { + pending, ok := t.pendingActivities[key] + if !ok { + return api.NewUnknownTaskIDError(key, 0) + } + delete(t.pendingActivities, key) + + // Send response (nil means cancelled) and close the channel. + pending.response <- res + close(pending.response) + return nil +} + +func (t *task) shutdown() { + for id, pending := range t.pendingOrchestrators { + close(pending.response) + delete(t.pendingOrchestrators, id) + } + for key, pending := range t.pendingActivities { + close(pending.response) + delete(t.pendingActivities, key) + } +} diff --git a/backend/loops/executor/executor.go b/backend/loops/executor/executor.go new file mode 100644 index 00000000..8a6f81ad --- /dev/null +++ b/backend/loops/executor/executor.go @@ -0,0 +1,265 @@ +package executor + +import ( + "context" + "fmt" + + "github.com/dapr/durabletask-go/api" + "github.com/dapr/durabletask-go/api/protos" + "github.com/dapr/durabletask-go/backend/loops" + "github.com/dapr/kit/events/loop" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +var errShuttingDown error = status.Error(codes.Canceled, "shutting down") + +// Logger is the logging interface needed by the executor handler. +type Logger interface { + Debugf(format string, v ...any) + Infof(format string, v ...any) + Warnf(format string, v ...any) + Errorf(format string, v ...any) +} + +// Backend is the subset of the backend interface needed for cancellation. +type Backend interface { + CancelOrchestratorTask(context.Context, api.InstanceID) error + CancelActivityTask(context.Context, api.InstanceID, int32) error +} + +type pendingOrchestrator struct { + instanceID api.InstanceID + streamID string +} + +type pendingActivity struct { + instanceID api.InstanceID + taskID int32 + streamID string +} + +type streamState struct { + streamID string + stream protos.TaskHubSidecarService_GetWorkItemsServer + errCh chan<- error +} + +type pendingDispatchItem struct { + workItem *protos.WorkItem + dispatched chan<- error + orchInstanceID *api.InstanceID + activityKey *string + actInstanceID *api.InstanceID + actTaskID *int32 +} + +// Options configures the executor loop handler. +type Options struct { + Backend Backend + Logger Logger +} + +type executor struct { + backend Backend + logger Logger + + stream *streamState + pendingOrch map[api.InstanceID]*pendingOrchestrator + pendingAct map[string]*pendingActivity + pendingDispatch []pendingDispatchItem + + loop loop.Interface[loops.EventExecutor] +} + +// New creates a new executor loop. +func New(opts Options) loop.Interface[loops.EventExecutor] { + e := &executor{ + backend: opts.Backend, + logger: opts.Logger, + pendingOrch: make(map[api.InstanceID]*pendingOrchestrator), + pendingAct: make(map[string]*pendingActivity), + } + e.loop = loop.New[loops.EventExecutor](64).NewLoop(e) + return e.loop +} + +// Handle implements loop.Handler[loops.EventExecutor]. +func (e *executor) Handle(ctx context.Context, event loops.EventExecutor) error { + switch ev := event.(type) { + case *loops.ExecuteOrchestrator: + e.handleExecuteOrchestrator(ev) + case *loops.ExecuteActivity: + e.handleExecuteActivity(ev) + case *loops.ConnectStream: + e.handleConnectStream(ev) + case *loops.DisconnectStream: + e.handleDisconnectStream(ev) + case *loops.StreamShutdown: + e.handleStreamShutdown() + case *loops.ShutdownExecutor: + e.handleShutdown(ctx) + } + return nil +} + +func (e *executor) handleExecuteOrchestrator(ev *loops.ExecuteOrchestrator) { + iid := api.InstanceID(ev.InstanceID) + e.pendingOrch[iid] = &pendingOrchestrator{instanceID: iid} + + if e.stream == nil { + iidCopy := iid + e.pendingDispatch = append(e.pendingDispatch, pendingDispatchItem{ + workItem: ev.WorkItem, + dispatched: ev.Dispatched, + orchInstanceID: &iidCopy, + }) + return + } + + e.pendingOrch[iid].streamID = e.stream.streamID + + if err := e.sendWorkItem(ev.WorkItem); err != nil { + ev.Dispatched <- err + return + } + ev.Dispatched <- nil +} + +func (e *executor) handleExecuteActivity(ev *loops.ExecuteActivity) { + iid := api.InstanceID(ev.InstanceID) + e.pendingAct[ev.Key] = &pendingActivity{ + instanceID: iid, + taskID: ev.TaskID, + } + + if e.stream == nil { + key := ev.Key + iidCopy := iid + taskID := ev.TaskID + e.pendingDispatch = append(e.pendingDispatch, pendingDispatchItem{ + workItem: ev.WorkItem, + dispatched: ev.Dispatched, + activityKey: &key, + actInstanceID: &iidCopy, + actTaskID: &taskID, + }) + return + } + + e.pendingAct[ev.Key].streamID = e.stream.streamID + + if err := e.sendWorkItem(ev.WorkItem); err != nil { + ev.Dispatched <- err + return + } + ev.Dispatched <- nil +} + +func (e *executor) handleConnectStream(ev *loops.ConnectStream) { + if e.stream != nil { + e.logger.Warnf("rejecting stream %s: another stream %s is already connected", ev.StreamID, e.stream.streamID) + ev.ErrCh <- fmt.Errorf("another stream is already connected") + return + } + + e.stream = &streamState{ + streamID: ev.StreamID, + stream: ev.Stream, + errCh: ev.ErrCh, + } + + // Flush any buffered work items. + for i, item := range e.pendingDispatch { + if item.orchInstanceID != nil { + if p, ok := e.pendingOrch[*item.orchInstanceID]; ok { + p.streamID = ev.StreamID + } + } + if item.activityKey != nil { + if p, ok := e.pendingAct[*item.activityKey]; ok { + p.streamID = ev.StreamID + } + } + + if err := e.sendWorkItem(item.workItem); err != nil { + item.dispatched <- err + // Stream failed; keep remaining items for next stream. + e.pendingDispatch = e.pendingDispatch[i+1:] + e.stream = nil + return + } + item.dispatched <- nil + } + e.pendingDispatch = nil +} + +func (e *executor) handleDisconnectStream(ev *loops.DisconnectStream) { + if e.stream == nil || e.stream.streamID != ev.StreamID { + return + } + + e.logger.Infof("stream %s disconnected, cleaning up", ev.StreamID) + + for iid, p := range e.pendingOrch { + if p.streamID == ev.StreamID { + e.logger.Debugf("cleaning up pending orchestrator: %s", iid) + if err := e.backend.CancelOrchestratorTask(context.Background(), p.instanceID); err != nil { + e.logger.Warnf("failed to cancel orchestrator task: %v", err) + } + delete(e.pendingOrch, iid) + } + } + for key, p := range e.pendingAct { + if p.streamID == ev.StreamID { + e.logger.Debugf("cleaning up pending activity: %s", key) + if err := e.backend.CancelActivityTask(context.Background(), p.instanceID, p.taskID); err != nil { + e.logger.Warnf("failed to cancel activity task: %v", err) + } + delete(e.pendingAct, key) + } + } + + e.stream = nil +} + +func (e *executor) handleStreamShutdown() { + if e.stream != nil { + e.stream.errCh <- errShuttingDown + e.stream = nil + } +} + +func (e *executor) handleShutdown(ctx context.Context) { + if e.stream != nil { + e.stream.errCh <- errShuttingDown + e.stream = nil + } + + for _, item := range e.pendingDispatch { + item.dispatched <- errShuttingDown + } + e.pendingDispatch = nil + + for iid, p := range e.pendingOrch { + if err := e.backend.CancelOrchestratorTask(ctx, p.instanceID); err != nil { + e.logger.Warnf("failed to cancel orchestrator task: %v", err) + } + delete(e.pendingOrch, iid) + } + for key, p := range e.pendingAct { + if err := e.backend.CancelActivityTask(ctx, p.instanceID, p.taskID); err != nil { + e.logger.Warnf("failed to cancel activity task: %v", err) + } + delete(e.pendingAct, key) + } +} + +func (e *executor) sendWorkItem(wi *protos.WorkItem) error { + if e.stream == nil { + return fmt.Errorf("no stream connected") + } + + stream := e.stream.stream + return stream.Send(wi) +} diff --git a/backend/loops/loops.go b/backend/loops/loops.go new file mode 100644 index 00000000..cf85b726 --- /dev/null +++ b/backend/loops/loops.go @@ -0,0 +1,74 @@ +package loops + +import ( + "github.com/dapr/durabletask-go/api/protos" +) + +type workerbase struct{} + +func (*workerbase) isEventWorker() {} + +// EventWorker is the marker interface for worker loop events. +type EventWorker interface{ isEventWorker() } + +// Shutdown signals the worker loop to stop. +type Shutdown struct{ *workerbase } + +// DispatchWorkItem delivers a work item directly to the worker loop from +// the backend. The worker processes it inline and signals completion via +// the Callback channel (nil = completed, non-nil error = abandoned). +type DispatchWorkItem struct { + *workerbase + WorkItem any + Callback chan<- error +} + +type executorbase struct{} + +func (*executorbase) isEventExecutor() {} + +// EventExecutor is the marker interface for executor loop events. +type EventExecutor interface{ isEventExecutor() } + +// ExecuteOrchestrator delivers an orchestrator work item to the connected +// stream. The caller blocks on Dispatched until the item has been sent to +// a stream (or an error occurs), preserving back-pressure semantics. +type ExecuteOrchestrator struct { + *executorbase + InstanceID string + WorkItem *protos.WorkItem + Dispatched chan<- error +} + +// ExecuteActivity delivers an activity work item to the connected stream. +// Same back-pressure semantics as ExecuteOrchestrator. +type ExecuteActivity struct { + *executorbase + Key string + InstanceID string + TaskID int32 + WorkItem *protos.WorkItem + Dispatched chan<- error +} + +// ConnectStream signals that a new GetWorkItems stream has connected. +type ConnectStream struct { + *executorbase + StreamID string + Stream protos.TaskHubSidecarService_GetWorkItemsServer + // ErrCh receives the final error when the stream should terminate. + ErrCh chan<- error +} + +// DisconnectStream signals that a GetWorkItems stream has disconnected. +type DisconnectStream struct { + *executorbase + StreamID string +} + +// StreamShutdown signals that the external shutdown channel has fired. +type StreamShutdown struct{ *executorbase } + +// ShutdownExecutor signals the executor loop to stop. All pending items +// are cancelled and the loop terminates. +type ShutdownExecutor struct{ *executorbase } diff --git a/backend/loops/worker/worker.go b/backend/loops/worker/worker.go new file mode 100644 index 00000000..d1f47739 --- /dev/null +++ b/backend/loops/worker/worker.go @@ -0,0 +1,83 @@ +package worker + +import ( + "context" + "fmt" + + "github.com/dapr/durabletask-go/backend/loops" +) + +// Logger is the logging interface needed by the worker handler. +type Logger interface { + Debugf(format string, v ...any) + Infof(format string, v ...any) + Errorf(format string, v ...any) +} + +// Processor is the interface for processing work items. +type Processor[T any] interface { + Name() string + ProcessWorkItem(context.Context, T) error + AbandonWorkItem(context.Context, T) error + CompleteWorkItem(context.Context, T) error +} + +// Options configures the worker loop handler. +type Options[T fmt.Stringer] struct { + Processor Processor[T] + Logger Logger +} + +type worker[T fmt.Stringer] struct { + logger Logger + processor Processor[T] +} + +// New creates a new worker handler. The caller is responsible for creating the +// loop and wiring this handler into it. +func New[T fmt.Stringer](opts Options[T]) *worker[T] { + return &worker[T]{ + processor: opts.Processor, + logger: opts.Logger, + } +} + +// Handle implements loop.Handler[loops.EventWorker]. +func (w *worker[T]) Handle(ctx context.Context, event loops.EventWorker) error { + switch e := event.(type) { + case *loops.DispatchWorkItem: + w.handleDispatch(ctx, e) + case *loops.Shutdown: + w.logger.Infof("%v: worker stopped", w.processor.Name()) + } + + return nil +} + +func (w *worker[T]) handleDispatch(ctx context.Context, e *loops.DispatchWorkItem) { + wi := e.WorkItem.(T) + w.processWorkItem(ctx, wi) + e.Callback <- nil +} + +func (w *worker[T]) processWorkItem(ctx context.Context, wi T) { + w.logger.Debugf("%v: processing work item: %s", w.processor.Name(), wi) + + if err := w.processor.ProcessWorkItem(ctx, wi); err != nil { + w.logger.Errorf("%v: failed to process work item: %v", w.processor.Name(), err) + if err = w.processor.AbandonWorkItem(context.Background(), wi); err != nil { + w.logger.Errorf("%v: failed to abandon work item: %v", w.processor.Name(), err) + } + return + } + + if err := w.processor.CompleteWorkItem(ctx, wi); err != nil { + w.logger.Errorf("%v: failed to complete work item: %v", w.processor.Name(), err) + if err = w.processor.AbandonWorkItem(context.Background(), wi); err != nil { + w.logger.Errorf("%v: failed to abandon work item: %v", w.processor.Name(), err) + } + return + } + + w.logger.Debugf("%v: work item processed successfully", w.processor.Name()) +} diff --git a/go.mod b/go.mod index e52484a7..486ca0b6 100644 --- a/go.mod +++ b/go.mod @@ -4,17 +4,17 @@ go 1.26.0 require ( github.com/cenkalti/backoff/v4 v4.3.0 - github.com/dapr/kit v0.15.3-0.20250616160611-598b032bce69 + github.com/dapr/kit v0.17.1-0.20260303145220-d749ae76d3c3 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.4 github.com/stretchr/testify v1.10.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 - go.opentelemetry.io/otel v1.34.0 + go.opentelemetry.io/otel v1.35.0 go.opentelemetry.io/otel/exporters/zipkin v1.34.0 - go.opentelemetry.io/otel/sdk v1.34.0 - go.opentelemetry.io/otel/trace v1.34.0 - google.golang.org/grpc v1.70.0 - google.golang.org/protobuf v1.36.4 + go.opentelemetry.io/otel/sdk v1.35.0 + go.opentelemetry.io/otel/trace v1.35.0 + google.golang.org/grpc v1.73.0 + google.golang.org/protobuf v1.36.6 modernc.org/sqlite v1.34.5 ) @@ -32,17 +32,19 @@ require ( github.com/openzipkin/zipkin-go v0.4.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/stretchr/objx v0.5.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/otel/metric v1.34.0 // indirect - golang.org/x/crypto v0.32.0 // indirect + go.opentelemetry.io/otel/metric v1.35.0 // indirect + golang.org/x/crypto v0.39.0 // indirect golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect - golang.org/x/net v0.34.0 // indirect - golang.org/x/sync v0.10.0 // indirect - golang.org/x/sys v0.29.0 // indirect - golang.org/x/text v0.21.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250124145028-65684f501c47 // indirect + golang.org/x/net v0.41.0 // indirect + golang.org/x/sync v0.15.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.26.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect modernc.org/libc v1.61.9 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.8.2 // indirect diff --git a/go.sum b/go.sum index 8760f4c0..85232e98 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/dapr/kit v0.15.3-0.20250616160611-598b032bce69 h1:I1Uoy3fn906AZZdG8+n8fHitgY7Wn9c+smz4WQdOy1Q= -github.com/dapr/kit v0.15.3-0.20250616160611-598b032bce69/go.mod h1:6w2Pr38zOAtBn+ld/jknwI4kgMfwanCIcFVnPykdPZQ= +github.com/dapr/kit v0.17.1-0.20260303145220-d749ae76d3c3 h1:noPW1pCxCefI0O19Ay70TX2TwLs5OrnVhfz7aJsOKkM= +github.com/dapr/kit v0.17.1-0.20260303145220-d749ae76d3c3/go.mod h1:40ZWs5P6xfYf7O59XgwqZkIyDldTIXlhTQhGop8QoSM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -16,8 +16,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -46,6 +46,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= @@ -57,47 +59,52 @@ go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJyS go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 h1:CV7UdSGJt/Ao6Gp4CXckLxVRRsRgDHoI8XjbL3PDl8s= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0/go.mod h1:FRmFuRJfag1IZ2dPkHnEoSFVgTVPUd2qf5Vi69hLb8I= -go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= -go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= go.opentelemetry.io/otel/exporters/zipkin v1.34.0 h1:GSjCkoYqsnvUMCjxF18j2tCWH8fhGZYjH3iYgechPTI= go.opentelemetry.io/otel/exporters/zipkin v1.34.0/go.mod h1:h830hluwAqgSNnZbxL2rJhmAlE7/0SF9esoHVLU04Gc= -go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= -go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= -go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= -go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= -go.opentelemetry.io/otel/sdk/metric v1.32.0 h1:rZvFnvmvawYb0alrYkjraqJq0Z4ZUJAiyYCU9snn1CU= -go.opentelemetry.io/otel/sdk/metric v1.32.0/go.mod h1:PWeZlq0zt9YkYAp3gjKZ0eicRYvOh1Gd+X99x6GHpCQ= -go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= -go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= -golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= -golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= +go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= +go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= +go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= -golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= -golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= -golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250124145028-65684f501c47 h1:91mG8dNTpkC0uChJUQ9zCiRqx3GEEFOWaRZ0mI6Oj2I= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250124145028-65684f501c47/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= -google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ= -google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw= -google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM= -google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= +google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/utils v0.0.0-20230726121419-3b25d923346b h1:sgn3ZU783SCgtaSJjpcVVlRqd6GSnlTLKgpAAttJvpI= +k8s.io/utils v0.0.0-20230726121419-3b25d923346b/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= modernc.org/ccgo/v4 v4.23.13 h1:PFiaemQwE/jdwi8XEHyEV+qYWoIuikLP3T4rvDeJb00= From 6ceb40b541b0a9292cd2b464b3037a90553328f1 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Tue, 10 Mar 2026 14:56:47 +0000 Subject: [PATCH 3/4] Refactor executor and local tasks backend to loop-based dispatch Replaces the channel-based work item queue and sync.Map-based pending task tracking in the gRPC executor with a single-threaded event loop. All stream management, work item dispatch, and cleanup now happen serially within the loop handler, eliminating data races. - grpcExecutor: replaces workItemQueue channel and pendingOrchestrators/ pendingActivities sync.Maps with an executor event loop. Adds Start() to the Executor interface. GetWorkItems now connects a stream via the loop and blocks until disconnected. - local.TasksBackend: replaces sync.Map-based pending task tracking with a task event loop. All complete/cancel/register operations are now serialized through the loop. - sqlite/postgres: Start now runs the TasksBackend loop in a goroutine, Stop closes it. - task.taskExecutor: adds no-op Start() that blocks until context done. - Executor mock: adds Start() mock method. - gRPC tests: start the executor loop and cancel on cleanup. Branched from https://github.com/dapr/durabletask-go/pull/72 Signed-off-by: joshvanl --- backend/executor.go | 257 ++++++++++++----------------------- backend/local/task.go | 150 ++++++++++---------- backend/postgres/postgres.go | 6 +- backend/sqlite/sqlite.go | 10 +- task/executor.go | 6 + tests/grpc/grpc_test.go | 10 +- tests/mocks/Executor.go | 46 +++++++ 7 files changed, 230 insertions(+), 255 deletions(-) diff --git a/backend/executor.go b/backend/executor.go index 523c32a8..b602fd03 100644 --- a/backend/executor.go +++ b/backend/executor.go @@ -6,7 +6,6 @@ import ( "fmt" "strconv" "strings" - "sync" "time" "github.com/google/uuid" @@ -21,24 +20,16 @@ import ( "github.com/dapr/durabletask-go/api" "github.com/dapr/durabletask-go/api/helpers" "github.com/dapr/durabletask-go/api/protos" + "github.com/dapr/durabletask-go/backend/loops" + loopexecutor "github.com/dapr/durabletask-go/backend/loops/executor" + "github.com/dapr/kit/concurrency" + "github.com/dapr/kit/events/loop" ) var emptyCompleteTaskResponse = &protos.CompleteTaskResponse{} -var errShuttingDown error = status.Error(codes.Canceled, "shutting down") - -type pendingOrchestrator struct { - instanceID api.InstanceID - streamID string -} - -type pendingActivity struct { - instanceID api.InstanceID - taskID int32 - streamID string -} - type Executor interface { + Start(context.Context) error ExecuteOrchestrator(ctx context.Context, iid api.InstanceID, oldEvents []*protos.HistoryEvent, newEvents []*protos.HistoryEvent) (*protos.OrchestratorResponse, error) ExecuteActivity(context.Context, api.InstanceID, *protos.HistoryEvent) (*protos.HistoryEvent, error) Shutdown(ctx context.Context) error @@ -47,9 +38,7 @@ type Executor interface { type grpcExecutor struct { protos.UnimplementedTaskHubSidecarServiceServer - workItemQueue chan *protos.WorkItem - pendingOrchestrators *sync.Map // map[api.InstanceID]*pendingOrchestrator - pendingActivities *sync.Map // map[string]*pendingActivity + executorLoop loop.Interface[loops.EventExecutor] backend Backend logger Logger onWorkItemConnection func(context.Context) error @@ -91,12 +80,6 @@ func WithStreamShutdownChannel(c <-chan any) grpcExecutorOptions { } } -func WithStreamSendTimeout(d time.Duration) grpcExecutorOptions { - return func(g *grpcExecutor) { - g.streamSendTimeout = &d - } -} - func WithSkipWaitForInstanceStart() grpcExecutorOptions { return func(g *grpcExecutor) { g.skipWaitForInstanceStart = true @@ -105,27 +88,49 @@ func WithSkipWaitForInstanceStart() grpcExecutorOptions { // NewGrpcExecutor returns the Executor object and a method to invoke to register the gRPC server in the executor. func NewGrpcExecutor(be Backend, logger Logger, opts ...grpcExecutorOptions) (executor Executor, registerServerFn func(grpcServer grpc.ServiceRegistrar)) { - grpcExecutor := &grpcExecutor{ - workItemQueue: make(chan *protos.WorkItem), - backend: be, - logger: logger, - pendingOrchestrators: &sync.Map{}, - pendingActivities: &sync.Map{}, + grpcExec := &grpcExecutor{ + backend: be, + logger: logger, } for _, opt := range opts { - opt(grpcExecutor) + opt(grpcExec) } - return grpcExecutor, func(grpcServer grpc.ServiceRegistrar) { - protos.RegisterTaskHubSidecarServiceServer(grpcServer, grpcExecutor) + grpcExec.executorLoop = loopexecutor.New(loopexecutor.Options{ + Backend: be, + Logger: logger, + }) + + return grpcExec, func(grpcServer grpc.ServiceRegistrar) { + protos.RegisterTaskHubSidecarServiceServer(grpcServer, grpcExec) } } -// ExecuteOrchestrator implements Executor -func (executor *grpcExecutor) ExecuteOrchestrator(ctx context.Context, iid api.InstanceID, oldEvents []*protos.HistoryEvent, newEvents []*protos.HistoryEvent) (*protos.OrchestratorResponse, error) { - executor.pendingOrchestrators.Store(iid, &pendingOrchestrator{instanceID: iid}) +// Start starts the executor loop and blocks until the context is cancelled. +func (g *grpcExecutor) Start(ctx context.Context) error { + manager := concurrency.NewRunnerManager(g.executorLoop.Run) + if g.streamShutdownChan != nil { + manager.Add(func(ctx context.Context) error { + select { + case <-g.streamShutdownChan: + g.executorLoop.Enqueue(new(loops.StreamShutdown)) + case <-ctx.Done(): + } + return nil + }) + } + // When context is cancelled, close the executor loop so Run unblocks. + manager.Add(func(ctx context.Context) error { + <-ctx.Done() + g.executorLoop.Close(new(loops.ShutdownExecutor)) + return nil + }) + return manager.Run(ctx) +} +// ExecuteOrchestrator implements Executor +func (g *grpcExecutor) ExecuteOrchestrator(ctx context.Context, iid api.InstanceID, oldEvents []*protos.HistoryEvent, newEvents []*protos.HistoryEvent) (*protos.OrchestratorResponse, error) { req := &protos.OrchestratorRequest{ InstanceId: string(iid), ExecutionId: nil, @@ -139,26 +144,32 @@ func (executor *grpcExecutor) ExecuteOrchestrator(ctx context.Context, iid api.I }, } - wait := executor.backend.WaitForOrchestratorCompletion(req) + wait := g.backend.WaitForOrchestratorCompletion(req) + + dispatched := make(chan error, 1) + g.executorLoop.Enqueue(&loops.ExecuteOrchestrator{ + InstanceID: string(iid), + WorkItem: workItem, + Dispatched: dispatched, + }) - // Send the orchestration execution work-item to the connected worker. - // This will block if the worker isn't listening for work items. + // Block until dispatched to stream (preserves back-pressure). select { case <-ctx.Done(): - executor.logger.Warnf("%s: context canceled before dispatching orchestrator work item", iid) + g.logger.Warnf("%s: context canceled before dispatching orchestrator work item", iid) return nil, fmt.Errorf("context canceled before dispatching orchestrator work item: %w", ctx.Err()) - case executor.workItemQueue <- workItem: + case err := <-dispatched: + if err != nil { + return nil, fmt.Errorf("failed to dispatch orchestrator work item: %w", err) + } } resp, err := wait(ctx) - - // this orchestrator is either completed or cancelled, but its no longer pending, delete it - executor.pendingOrchestrators.Delete(iid) if err != nil { if errors.Is(err, api.ErrTaskCancelled) { return nil, errors.New("operation aborted") } - executor.logger.Warnf("%s: failed before receiving orchestration result", iid) + g.logger.Warnf("%s: failed before receiving orchestration result", iid) return nil, err } @@ -166,10 +177,8 @@ func (executor *grpcExecutor) ExecuteOrchestrator(ctx context.Context, iid api.I } // ExecuteActivity implements Executor -func (executor *grpcExecutor) ExecuteActivity(ctx context.Context, iid api.InstanceID, e *protos.HistoryEvent) (*protos.HistoryEvent, error) { +func (g *grpcExecutor) ExecuteActivity(ctx context.Context, iid api.InstanceID, e *protos.HistoryEvent) (*protos.HistoryEvent, error) { key := GetActivityExecutionKey(string(iid), e.EventId) - executor.pendingActivities.Store(key, &pendingActivity{instanceID: iid, taskID: e.EventId}) - task := e.GetTaskScheduled() req := &protos.ActivityRequest{ @@ -187,26 +196,34 @@ func (executor *grpcExecutor) ExecuteActivity(ctx context.Context, iid api.Insta }, } - wait := executor.backend.WaitForActivityCompletion(req) + wait := g.backend.WaitForActivityCompletion(req) + + dispatched := make(chan error, 1) + g.executorLoop.Enqueue(&loops.ExecuteActivity{ + Key: key, + InstanceID: string(iid), + TaskID: e.EventId, + WorkItem: workItem, + Dispatched: dispatched, + }) - // Send the activity execution work-item to the connected worker. - // This will block if the worker isn't listening for work items. + // Block until dispatched to stream (preserves back-pressure). select { case <-ctx.Done(): - executor.logger.Warnf("%s/%s#%d: context canceled before dispatching activity work item", iid, task.Name, e.EventId) + g.logger.Warnf("%s/%s#%d: context canceled before dispatching activity work item", iid, task.Name, e.EventId) return nil, fmt.Errorf("context canceled before dispatching activity work item: %w", ctx.Err()) - case executor.workItemQueue <- workItem: + case err := <-dispatched: + if err != nil { + return nil, fmt.Errorf("failed to dispatch activity work item: %w", err) + } } resp, err := wait(ctx) - - // this activity is either completed or cancelled, but its no longer pending, delete it - executor.pendingActivities.Delete(key) if err != nil { if errors.Is(err, api.ErrTaskCancelled) { return nil, errors.New("operation aborted") } - executor.logger.Warnf("%s/%s#%d: failed before receiving activity result", iid, task.Name, e.EventId) + g.logger.Warnf("%s/%s#%d: failed before receiving activity result", iid, task.Name, e.EventId) return nil, err } @@ -244,31 +261,7 @@ func (executor *grpcExecutor) ExecuteActivity(ctx context.Context, iid api.Insta // Shutdown implements Executor func (g *grpcExecutor) Shutdown(ctx context.Context) error { - // closing the work item queue is a signal for shutdown - close(g.workItemQueue) - - // Iterate through all pending items and close them to unblock the goroutines waiting on this - g.pendingActivities.Range(func(_, value any) bool { - p, ok := value.(*pendingActivity) - if ok { - err := g.backend.CancelActivityTask(ctx, p.instanceID, p.taskID) - if err != nil { - g.logger.Warnf("failed to cancel activity task: %v", err) - } - } - return true - }) - g.pendingOrchestrators.Range(func(_, value any) bool { - p, ok := value.(*pendingOrchestrator) - if ok { - err := g.backend.CancelOrchestratorTask(ctx, p.instanceID) - if err != nil { - g.logger.Warnf("failed to cancel orchestrator task: %v", err) - } - } - return true - }) - + g.executorLoop.Close(new(loops.ShutdownExecutor)) return nil } @@ -295,109 +288,28 @@ func (g *grpcExecutor) GetWorkItems(req *protos.GetWorkItemsRequest, stream prot g.logger.Warnf("error while disconnecting work item stream: %v", derr) } - return status.Errorf(codes.Unavailable, message) + return status.Errorf(codes.Unavailable, "%s", message) } defer func() { - // If there's any pending activity left, remove them - g.pendingActivities.Range(func(key, value any) bool { - if p, ok := value.(*pendingActivity); ok && p.streamID == streamID { - g.logger.Debugf("cleaning up pending activity: %s", key) - err := g.backend.CancelActivityTask(context.Background(), p.instanceID, p.taskID) - if err != nil { - g.logger.Warnf("failed to cancel activity task: %v", err) - } - g.pendingActivities.Delete(key) - } - return true - }) - g.pendingOrchestrators.Range(func(key, value any) bool { - if p, ok := value.(*pendingOrchestrator); ok && p.streamID == streamID { - g.logger.Debugf("cleaning up pending orchestrator: %s", key) - err := g.backend.CancelOrchestratorTask(context.Background(), p.instanceID) - if err != nil { - g.logger.Warnf("failed to cancel orchestrator task: %v", err) - } - } - return true - }) + g.executorLoop.Enqueue(&loops.DisconnectStream{StreamID: streamID}) if err := g.executeOnWorkItemDisconnect(stream.Context()); err != nil { g.logger.Warnf("error while disconnecting work item stream: %v", err) } }() - ch := make(chan *protos.WorkItem) errCh := make(chan error, 1) - go func() { - for { - select { - case <-stream.Context().Done(): - return - case wi := <-ch: - errCh <- stream.Send(wi) - } - } - }() - - // The worker client invokes this method, which streams back work-items as they arrive. - for { - select { - case <-stream.Context().Done(): - g.logger.Info("work item stream closed") - return nil - case wi, ok := <-g.workItemQueue: - if !ok { - continue - } - - switch x := wi.Request.(type) { - case *protos.WorkItem_OrchestratorRequest: - key := x.OrchestratorRequest.GetInstanceId() - if value, ok := g.pendingOrchestrators.Load(api.InstanceID(key)); ok { - if p, ok := value.(*pendingOrchestrator); ok { - p.streamID = streamID - } - } - case *protos.WorkItem_ActivityRequest: - key := GetActivityExecutionKey(x.ActivityRequest.GetOrchestrationInstance().GetInstanceId(), x.ActivityRequest.GetTaskId()) - if value, ok := g.pendingActivities.Load(key); ok { - if p, ok := value.(*pendingActivity); ok { - p.streamID = streamID - } - } - } - - if err := g.sendWorkItem(stream, wi, ch, errCh); err != nil { - g.logger.Errorf("encountered an error while sending work item: %v", err) - return err - } - - case <-g.streamShutdownChan: - return errShuttingDown - } - } -} + g.executorLoop.Enqueue(&loops.ConnectStream{ + StreamID: streamID, + Stream: stream, + ErrCh: errCh, + }) -func (g *grpcExecutor) sendWorkItem(stream protos.TaskHubSidecarService_GetWorkItemsServer, wi *protos.WorkItem, - ch chan *protos.WorkItem, errCh chan error, -) error { + // Wait for either the stream context to be done or the loop to signal an error. select { case <-stream.Context().Done(): - return stream.Context().Err() - case ch <- wi: - } - - ctx := stream.Context() - if g.streamSendTimeout != nil { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, *g.streamSendTimeout) - defer cancel() - } - - select { - case <-ctx.Done(): - g.logger.Errorf("timed out while sending work item") - return fmt.Errorf("timed out while sending work item: %w", ctx.Err()) + g.logger.Info("work item stream closed") + return nil case err := <-errCh: return err } @@ -680,4 +592,3 @@ func createGetInstanceResponse(req *protos.GetInstanceRequest, metadata *Orchest return &protos.GetInstanceResponse{Exists: true, OrchestrationState: state} } - diff --git a/backend/local/task.go b/backend/local/task.go index 6d1666ea..35d2ab35 100644 --- a/backend/local/task.go +++ b/backend/local/task.go @@ -2,128 +2,136 @@ package local import ( "context" - "sync" "github.com/dapr/durabletask-go/api" "github.com/dapr/durabletask-go/api/protos" "github.com/dapr/durabletask-go/backend" + "github.com/dapr/durabletask-go/backend/local/loops" + looptask "github.com/dapr/durabletask-go/backend/local/loops/task" + "github.com/dapr/kit/events/loop" ) -type pendingOrchestrator struct { - response *protos.OrchestratorResponse - complete chan struct{} -} - -type pendingActivity struct { - response *protos.ActivityResponse - complete chan struct{} -} - type TasksBackend struct { - pendingOrchestrators *sync.Map - pendingActivities *sync.Map + loop loop.Interface[loops.EventTask] } func NewTasksBackend() *TasksBackend { return &TasksBackend{ - pendingOrchestrators: &sync.Map{}, - pendingActivities: &sync.Map{}, + loop: looptask.New(), } } +func (be *TasksBackend) Run(ctx context.Context) error { + return be.loop.Run(ctx) +} + +func (be *TasksBackend) Close() { + be.loop.Close(new(loops.Shutdown)) +} + func (be *TasksBackend) CompleteActivityTask(ctx context.Context, response *protos.ActivityResponse) error { - if be.deletePendingActivityTask(response.GetInstanceId(), response.GetTaskId(), response) { - return nil + errCh := make(chan error, 1) + be.loop.Enqueue(&loops.CompleteActivity{ + InstanceID: response.GetInstanceId(), + TaskID: response.GetTaskId(), + Response: response, + ErrCh: errCh, + }) + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err } - - return api.NewUnknownTaskIDError(response.GetInstanceId(), response.GetTaskId()) } func (be *TasksBackend) CancelActivityTask(ctx context.Context, instanceID api.InstanceID, taskID int32) error { - if be.deletePendingActivityTask(string(instanceID), taskID, nil) { - return nil + errCh := make(chan error, 1) + be.loop.Enqueue(&loops.CancelActivity{ + InstanceID: instanceID, + TaskID: taskID, + ErrCh: errCh, + }) + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err } - return api.NewUnknownTaskIDError(instanceID.String(), taskID) } func (be *TasksBackend) WaitForActivityCompletion(request *protos.ActivityRequest) func(context.Context) (*protos.ActivityResponse, error) { key := backend.GetActivityExecutionKey(request.GetOrchestrationInstance().GetInstanceId(), request.GetTaskId()) - pending := &pendingActivity{ - response: nil, - complete: make(chan struct{}, 1), - } - be.pendingActivities.Store(key, pending) + responseCh := make(chan *protos.ActivityResponse, 1) + + be.loop.Enqueue(&loops.RegisterPendingActivity{ + Key: key, + Response: responseCh, + }) return func(ctx context.Context) (*protos.ActivityResponse, error) { select { case <-ctx.Done(): return nil, ctx.Err() - case <-pending.complete: - if pending.response == nil { + case resp, ok := <-responseCh: + if !ok || resp == nil { return nil, api.ErrTaskCancelled } - return pending.response, nil + return resp, nil } } } func (be *TasksBackend) CompleteOrchestratorTask(ctx context.Context, response *protos.OrchestratorResponse) error { - if be.deletePendingOrchestrator(response.GetInstanceId(), response) { - return nil + errCh := make(chan error, 1) + be.loop.Enqueue(&loops.CompleteOrchestrator{ + InstanceID: response.GetInstanceId(), + Response: response, + ErrCh: errCh, + }) + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err } - return api.NewUnknownInstanceIDError(response.GetInstanceId()) } func (be *TasksBackend) CancelOrchestratorTask(ctx context.Context, instanceID api.InstanceID) error { - if be.deletePendingOrchestrator(string(instanceID), nil) { - return nil + errCh := make(chan error, 1) + be.loop.Enqueue(&loops.CancelOrchestrator{ + InstanceID: instanceID, + ErrCh: errCh, + }) + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err } - return api.NewUnknownInstanceIDError(instanceID.String()) } func (be *TasksBackend) WaitForOrchestratorCompletion(request *protos.OrchestratorRequest) func(context.Context) (*protos.OrchestratorResponse, error) { - pending := &pendingOrchestrator{ - response: nil, - complete: make(chan struct{}, 1), - } - be.pendingOrchestrators.Store(request.GetInstanceId(), pending) + responseCh := make(chan *protos.OrchestratorResponse, 1) + + be.loop.Enqueue(&loops.RegisterPendingOrchestrator{ + InstanceID: request.GetInstanceId(), + Response: responseCh, + }) return func(ctx context.Context) (*protos.OrchestratorResponse, error) { select { case <-ctx.Done(): return nil, ctx.Err() - case <-pending.complete: - if pending.response == nil { + case resp, ok := <-responseCh: + if !ok || resp == nil { return nil, api.ErrTaskCancelled } - return pending.response, nil + return resp, nil } } } - -func (be *TasksBackend) deletePendingActivityTask(iid string, taskID int32, res *protos.ActivityResponse) bool { - key := backend.GetActivityExecutionKey(iid, taskID) - p, ok := be.pendingActivities.LoadAndDelete(key) - if !ok { - return false - } - - // Note that res can be nil in case of certain failures - pending := p.(*pendingActivity) - pending.response = res - close(pending.complete) - return true -} - -func (be *TasksBackend) deletePendingOrchestrator(instanceID string, res *protos.OrchestratorResponse) bool { - p, ok := be.pendingOrchestrators.LoadAndDelete(instanceID) - if !ok { - return false - } - - // Note that res can be nil in case of certain failures - pending := p.(*pendingOrchestrator) - pending.response = res - close(pending.complete) - return true -} diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index a3f08efc..9e1d1b1f 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -1149,12 +1149,14 @@ func (be *postgresBackend) PurgeOrchestrationState(ctx context.Context, id api.I } // Start implements backend.Backend -func (*postgresBackend) Start(context.Context) error { +func (be *postgresBackend) Start(ctx context.Context) error { + go be.TasksBackend.Run(ctx) return nil } // Stop implements backend.Backend -func (*postgresBackend) Stop(context.Context) error { +func (be *postgresBackend) Stop(context.Context) error { + be.TasksBackend.Close() return nil } diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index 7356819f..accafa82 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -98,6 +98,10 @@ func NewSqliteBackend(opts *SqliteOptions, logger backend.Logger) backend.Backen // CreateTaskHub creates the sqlite database and applies the schema func (be *sqliteBackend) CreateTaskHub(context.Context) error { + if be.db != nil { + return backend.ErrTaskHubExists + } + db, err := sql.Open("sqlite", be.dsn) if err != nil { panic(fmt.Errorf("failed to open the database: %w", err)) @@ -1089,12 +1093,14 @@ func (be *sqliteBackend) PurgeOrchestrationState(ctx context.Context, id api.Ins } // Start implements backend.Backend -func (*sqliteBackend) Start(context.Context) error { +func (be *sqliteBackend) Start(ctx context.Context) error { + go be.TasksBackend.Run(ctx) return nil } // Stop implements backend.Backend -func (*sqliteBackend) Stop(context.Context) error { +func (be *sqliteBackend) Stop(context.Context) error { + be.TasksBackend.Close() return nil } diff --git a/task/executor.go b/task/executor.go index a6158133..9892a46b 100644 --- a/task/executor.go +++ b/task/executor.go @@ -154,6 +154,12 @@ func (te *taskExecutor) ExecuteOrchestrator(ctx context.Context, id api.Instance return response, nil } +func (te taskExecutor) Start(ctx context.Context) error { + // In-process executor has no background work. Block until context is cancelled. + <-ctx.Done() + return nil +} + func (te taskExecutor) Shutdown(ctx context.Context) error { // Nothing to do return nil diff --git a/tests/grpc/grpc_test.go b/tests/grpc/grpc_test.go index aae673d6..685dd752 100644 --- a/tests/grpc/grpc_test.go +++ b/tests/grpc/grpc_test.go @@ -54,6 +54,8 @@ func TestMain(m *testing.M) { if err := taskHubWorker.Start(ctx); err != nil { log.Fatalf("failed to start worker: %v", err) } + executorCtx, executorCancel := context.WithCancel(ctx) + go grpcExecutor.Start(executorCtx) lis, err := net.Listen("tcp", ":0") if err != nil { @@ -79,15 +81,9 @@ func TestMain(m *testing.M) { // Run the test exitCode exitCode := m.Run() + executorCancel() timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - err = grpcExecutor.Shutdown(timeoutCtx) - if err != nil { - log.Fatalf("failed to shutdown grpc Executor: %v", err) - } - - timeoutCtx, cancel = context.WithTimeout(ctx, 5*time.Second) - defer cancel() if err := taskHubWorker.Shutdown(timeoutCtx); err != nil { log.Fatalf("failed to shutdown worker: %v", err) } diff --git a/tests/mocks/Executor.go b/tests/mocks/Executor.go index 8da3a2d1..d464e30e 100644 --- a/tests/mocks/Executor.go +++ b/tests/mocks/Executor.go @@ -146,6 +146,52 @@ func (_c *Executor_ExecuteOrchestrator_Call) RunAndReturn(run func(context.Conte return _c } +// Start provides a mock function with given fields: _a0 +func (_m *Executor) Start(_a0 context.Context) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Executor_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type Executor_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +// - _a0 context.Context +func (_e *Executor_Expecter) Start(_a0 interface{}) *Executor_Start_Call { + return &Executor_Start_Call{Call: _e.mock.On("Start", _a0)} +} + +func (_c *Executor_Start_Call) Run(run func(_a0 context.Context)) *Executor_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Executor_Start_Call) Return(_a0 error) *Executor_Start_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Executor_Start_Call) RunAndReturn(run func(context.Context) error) *Executor_Start_Call { + _c.Call.Return(run) + return _c +} + // Shutdown provides a mock function with given fields: ctx func (_m *Executor) Shutdown(ctx context.Context) error { ret := _m.Called(ctx) From 9590e23c0bf35295ff6e1db63cf36b1b416bef81 Mon Sep 17 00:00:00 2001 From: joshvanl Date: Tue, 10 Mar 2026 15:00:58 +0000 Subject: [PATCH 4/4] Push-based worker dispatch with blocking Start Replaces the poll-based worker model with push-based dispatch and makes all Start methods blocking (return when context is cancelled). Worker changes: - TaskWorker interface: Start(ctx) returns error (blocking), adds Dispatch(wi, callback) for push-based dispatch, removes StopAndDrain. - TaskProcessor interface: removes NextWorkItem (no longer polls). - taskWorker uses N event loops (one per parallelism slot) with round-robin dispatch via atomic counter. - Removes processWorkItem from worker.go (moved to loops/worker handler). - activity.go/orchestration.go: removes NextWorkItem methods. TaskHub changes: - TaskHubWorker interface: removes Shutdown (callers cancel context). - taskHubWorker.Start uses RunnerManager to run backend, workers, and pollAndDispatch bridges concurrently. Blocks until context cancelled. - pollAndDispatch bridges blocking Next*WorkItem calls into fire-and- forget Dispatch calls to worker loops. Backend changes: - sqlite/postgres Start: now blocking (delegates to TasksBackend.Run). - sqlite CreateTaskHub: idempotent (returns ErrTaskHubExists if already initialized). Test changes: - All tests use context cancellation instead of StopAndDrain/Shutdown. - orchestrations_test.go: initTaskHubWorker returns CancelFunc, uses ready channel for synchronization. - worker_test.go: rewritten for push-based Dispatch API. - taskhub_test.go: simplified for blocking Start. - grpc_test.go: uses goroutines for blocking Start, context cancel for cleanup. Sample changes: - All samples updated: Init returns CancelFunc, Start runs in goroutine. Branched from https://github.com/dapr/durabletask-go/pull/73 Signed-off-by: joshvanl --- backend/activity.go | 7 +- backend/orchestration.go | 5 - backend/postgres/postgres.go | 3 +- backend/sqlite/sqlite.go | 3 +- backend/taskhub.go | 76 +++---- backend/worker.go | 153 +++++--------- .../distributedtracing/distributedtracing.go | 17 +- samples/externalevents/externalevents.go | 17 +- samples/parallel/parallel.go | 17 +- samples/retries/retries.go | 17 +- samples/sequence/sequence.go | 17 +- samples/taskexecutionid/taskexecutionid.go | 19 +- tests/grpc/grpc_test.go | 16 +- tests/mocks/TaskWorker.go | 69 ++++--- tests/mocks/task.go | 33 ---- tests/orchestrations_test.go | 171 ++++++++-------- tests/taskhub_test.go | 50 +++-- tests/worker_test.go | 186 ++++++++---------- 18 files changed, 388 insertions(+), 488 deletions(-) diff --git a/backend/activity.go b/backend/activity.go index 096c681b..91954fb9 100644 --- a/backend/activity.go +++ b/backend/activity.go @@ -38,12 +38,7 @@ func (*activityProcessor) Name() string { return "activity-processor" } -// NextWorkItem implements TaskDispatcher -func (ap *activityProcessor) NextWorkItem(ctx context.Context) (*ActivityWorkItem, error) { - return ap.be.NextActivityWorkItem(ctx) -} - -// ProcessWorkItem implements TaskDispatcher +// ProcessWorkItem implements TaskProcessor func (p *activityProcessor) ProcessWorkItem(ctx context.Context, awi *ActivityWorkItem) error { ts := awi.NewEvent.GetTaskScheduled() if ts == nil { diff --git a/backend/orchestration.go b/backend/orchestration.go index 59753398..2756cf0a 100644 --- a/backend/orchestration.go +++ b/backend/orchestration.go @@ -55,11 +55,6 @@ func (*orchestratorProcessor) Name() string { return "orchestration-processor" } -// NextWorkItem implements TaskProcessor -func (p *orchestratorProcessor) NextWorkItem(ctx context.Context) (*OrchestrationWorkItem, error) { - return p.be.NextOrchestrationWorkItem(ctx) -} - // ProcessWorkItem implements TaskProcessor func (w *orchestratorProcessor) ProcessWorkItem(ctx context.Context, wi *OrchestrationWorkItem) error { w.logger.Debugf("%v: received work item with %d new event(s): %v", wi.InstanceID, len(wi.NewEvents), helpers.HistoryListSummary(wi.NewEvents)) diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index 9e1d1b1f..57ffb473 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -1150,8 +1150,7 @@ func (be *postgresBackend) PurgeOrchestrationState(ctx context.Context, id api.I // Start implements backend.Backend func (be *postgresBackend) Start(ctx context.Context) error { - go be.TasksBackend.Run(ctx) - return nil + return be.TasksBackend.Run(ctx) } // Stop implements backend.Backend diff --git a/backend/sqlite/sqlite.go b/backend/sqlite/sqlite.go index accafa82..b9304b7f 100644 --- a/backend/sqlite/sqlite.go +++ b/backend/sqlite/sqlite.go @@ -1094,8 +1094,7 @@ func (be *sqliteBackend) PurgeOrchestrationState(ctx context.Context, id api.Ins // Start implements backend.Backend func (be *sqliteBackend) Start(ctx context.Context) error { - go be.TasksBackend.Run(ctx) - return nil + return be.TasksBackend.Run(ctx) } // Stop implements backend.Backend diff --git a/backend/taskhub.go b/backend/taskhub.go index 2fa92ccc..3bbb8569 100644 --- a/backend/taskhub.go +++ b/backend/taskhub.go @@ -2,15 +2,15 @@ package backend import ( "context" - "sync" + "errors" + + "github.com/dapr/kit/concurrency" ) type TaskHubWorker interface { - // Start starts the backend and the configured internal workers. + // Start starts the backend and the configured internal workers. Blocks + // until the context is cancelled. Start(context.Context) error - - // Shutdown stops the backend and all internal workers. - Shutdown(context.Context) error } type taskHubWorker struct { @@ -34,40 +34,48 @@ func (w *taskHubWorker) Start(ctx context.Context) error { return err } - if err := w.backend.Start(ctx); err != nil { - return err - } - w.logger.Infof("worker started with backend %v", w.backend) - w.orchestrationWorker.Start(ctx) - w.activityWorker.Start(ctx) - return nil -} + manager := concurrency.NewRunnerManager( + w.backend.Start, + w.orchestrationWorker.Start, + w.activityWorker.Start, + func(ctx context.Context) error { + pollAndDispatch(ctx, w.backend.NextOrchestrationWorkItem, w.orchestrationWorker, w.logger) + return nil + }, + func(ctx context.Context) error { + pollAndDispatch(ctx, w.backend.NextActivityWorkItem, w.activityWorker, w.logger) + return nil + }, + ) -func (w *taskHubWorker) Shutdown(ctx context.Context) error { - w.logger.Info("backend stopping...") - if err := w.backend.Stop(ctx); err != nil { - return err - } - - w.logger.Info("workers stopping and draining...") - defer w.logger.Info("finished stopping and draining workers!") - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - w.orchestrationWorker.StopAndDrain() + defer func() { + w.logger.Info("backend stopping...") + w.backend.Stop(context.Background()) }() - wg.Add(1) - go func() { - defer wg.Done() - w.activityWorker.StopAndDrain() - }() + return manager.Run(ctx) +} - wg.Wait() +// pollAndDispatch bridges a blocking NextWorkItem call into Dispatch. It runs +// until the context is cancelled. +func pollAndDispatch[T WorkItem](ctx context.Context, next func(context.Context) (T, error), worker TaskWorker[T], logger Logger) { + for { + wi, err := next(ctx) + if ctx.Err() != nil { + return + } + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + logger.Errorf("failed to get next work item: %v", err) + continue + } - return nil + // Fire-and-forget: the callback is not needed for poll-based + // backends since they don't need completion signaling. + worker.Dispatch(wi, make(chan error, 1)) + } } diff --git a/backend/worker.go b/backend/worker.go index 45ced88d..66adc9c0 100644 --- a/backend/worker.go +++ b/backend/worker.go @@ -2,33 +2,34 @@ package backend import ( "context" - "sync" + "sync/atomic" + + "github.com/dapr/durabletask-go/backend/loops" + loopworker "github.com/dapr/durabletask-go/backend/loops/worker" + "github.com/dapr/kit/concurrency" + "github.com/dapr/kit/events/loop" ) type TaskWorker[T WorkItem] interface { - // Start starts background polling for the activity work items. - Start(context.Context) + // Start starts the worker loops and blocks until the context is cancelled. + Start(context.Context) error - // StopAndDrain stops the worker and waits for all outstanding work items to finish. - StopAndDrain() + // Dispatch pushes a work item directly to a worker loop. The callback + // channel receives nil on completion or an error if the work item was + // abandoned. Dispatch round-robins across worker loops. + Dispatch(wi T, callback chan<- error) } type TaskProcessor[T WorkItem] interface { Name() string ProcessWorkItem(context.Context, T) error - NextWorkItem(context.Context) (T, error) AbandonWorkItem(context.Context, T) error CompleteWorkItem(context.Context, T) error } -type worker[T WorkItem] struct { - logger Logger - - processor TaskProcessor[T] - closeCh chan struct{} - wg sync.WaitGroup - workItems chan T - parallelLock chan struct{} +type taskWorker[T WorkItem] struct { + workers []loop.Interface[loops.EventWorker] + nextWorker atomic.Uint64 } type NewTaskWorkerOptions func(*WorkerOptions) @@ -53,104 +54,48 @@ func NewTaskWorker[T WorkItem](p TaskProcessor[T], logger Logger, opts ...NewTas configure(options) } - var parallelLock chan struct{} - if options.MaxParallelWorkItems != nil { - parallelLock = make(chan struct{}, *options.MaxParallelWorkItems) + n := int32(1) + if options.MaxParallelWorkItems != nil && *options.MaxParallelWorkItems > 1 { + n = *options.MaxParallelWorkItems } - return &worker[T]{ - processor: p, - logger: logger, - workItems: make(chan T), - parallelLock: parallelLock, - closeCh: make(chan struct{}), + workers := make([]loop.Interface[loops.EventWorker], n) + for i := range workers { + handler := loopworker.New(loopworker.Options[T]{ + Processor: p, + Logger: logger, + }) + workers[i] = loop.New[loops.EventWorker](64).NewLoop(handler) } -} -func (w *worker[T]) Name() string { - return w.processor.Name() + return &taskWorker[T]{ + workers: workers, + } } -func (w *worker[T]) Start(ctx context.Context) { - w.wg.Add(2) - - ctx, cancel := context.WithCancel(ctx) - - go func() { - defer w.wg.Done() - defer cancel() - - select { - case <-w.closeCh: - case <-ctx.Done(): - } - }() - - go func() { - defer w.wg.Done() - defer w.logger.Infof("%v: worker stopped", w.Name()) - - for { - - if w.parallelLock != nil { - select { - case w.parallelLock <- struct{}{}: - case <-ctx.Done(): - return - } - } - - wi, err := w.processor.NextWorkItem(ctx) - if err != nil { - if w.parallelLock != nil { - <-w.parallelLock - } - - if ctx.Err() != nil { - return - } - - w.logger.Errorf("%v: failed to get next work item: %v", w.Name(), err) - continue - } - - w.wg.Add(1) - go func() { - defer func() { - if w.parallelLock != nil { - <-w.parallelLock - } - w.wg.Done() - }() - w.processWorkItem(ctx, wi) - }() +func (w *taskWorker[T]) Start(ctx context.Context) error { + manager := concurrency.NewRunnerManager() + for _, worker := range w.workers { + manager.Add(worker.Run) + } + // When context is cancelled, close all worker loops so their Run methods + // unblock from the channel read and return. + manager.Add(func(ctx context.Context) error { + <-ctx.Done() + for _, worker := range w.workers { + worker.Close(new(loops.Shutdown)) } - }() + return nil + }) + return manager.Run(ctx) } -func (w *worker[T]) StopAndDrain() { - close(w.closeCh) - w.wg.Wait() +func (w *taskWorker[T]) Dispatch(wi T, callback chan<- error) { + // Round-robin across worker loops. + idx := w.nextWorker.Add(1) - 1 + w.workers[idx%uint64(len(w.workers))].Enqueue(&loops.DispatchWorkItem{ + WorkItem: wi, + Callback: callback, + }) } -func (w *worker[T]) processWorkItem(ctx context.Context, wi T) { - w.logger.Debugf("%v: processing work item: %s", w.Name(), wi) - - if err := w.processor.ProcessWorkItem(ctx, wi); err != nil { - w.logger.Errorf("%v: failed to process work item: %v", w.Name(), err) - if err = w.processor.AbandonWorkItem(context.Background(), wi); err != nil { - w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) - } - return - } - - if err := w.processor.CompleteWorkItem(ctx, wi); err != nil { - w.logger.Errorf("%v: failed to complete work item: %v", w.Name(), err) - if err = w.processor.AbandonWorkItem(context.Background(), wi); err != nil { - w.logger.Errorf("%v: failed to abandon work item: %v", w.Name(), err) - } - return - } - - w.logger.Debugf("%v: work item processed successfully", w.Name()) -} diff --git a/samples/distributedtracing/distributedtracing.go b/samples/distributedtracing/distributedtracing.go index 9d3a6538..6f810f78 100644 --- a/samples/distributedtracing/distributedtracing.go +++ b/samples/distributedtracing/distributedtracing.go @@ -40,11 +40,11 @@ func main() { // Init the client ctx := context.Background() - client, worker, err := Init(ctx, r) + client, shutdown, err := Init(ctx, r) if err != nil { log.Fatalf("Failed to initialize the client: %v", err) } - defer worker.Shutdown(ctx) + defer shutdown() // Start a new orchestration id, err := client.ScheduleNewOrchestration(ctx, DistributedTraceSampleOrchestrator) @@ -67,7 +67,7 @@ func main() { } // Init creates and initializes an in-memory client and worker pair with default configuration. -func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, backend.TaskHubWorker, error) { +func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, context.CancelFunc, error) { logger := backend.DefaultLogger() // Create an executor @@ -80,20 +80,17 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac Backend: be, Executor: executor, Logger: logger, + AppID: "sample", }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) - // Start the worker - err := taskHubWorker.Start(ctx) - if err != nil { - return nil, nil, err - } + ctx, cancel := context.WithCancel(ctx) + go taskHubWorker.Start(ctx) - // Get the client to the backend taskHubClient := backend.NewTaskHubClient(be) - return taskHubClient, taskHubWorker, nil + return taskHubClient, cancel, nil } func ConfigureZipkinTracing() (*trace.TracerProvider, error) { diff --git a/samples/externalevents/externalevents.go b/samples/externalevents/externalevents.go index 3960e760..5b5b14b8 100644 --- a/samples/externalevents/externalevents.go +++ b/samples/externalevents/externalevents.go @@ -19,11 +19,11 @@ func main() { // Init the client ctx := context.Background() - client, worker, err := Init(ctx, r) + client, shutdown, err := Init(ctx, r) if err != nil { log.Fatalf("Failed to initialize the client: %v", err) } - defer worker.Shutdown(ctx) + defer shutdown() // Start a new orchestration id, err := client.ScheduleNewOrchestration(ctx, "ExternalEventOrchestrator") @@ -58,7 +58,7 @@ func main() { } // Init creates and initializes an in-memory client and worker pair with default configuration. -func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, backend.TaskHubWorker, error) { +func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, context.CancelFunc, error) { logger := backend.DefaultLogger() // Create an executor @@ -71,20 +71,17 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac Backend: be, Executor: executor, Logger: logger, + AppID: "sample", }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) - // Start the worker - err := taskHubWorker.Start(ctx) - if err != nil { - return nil, nil, err - } + ctx, cancel := context.WithCancel(ctx) + go taskHubWorker.Start(ctx) - // Get the client to the backend taskHubClient := backend.NewTaskHubClient(be) - return taskHubClient, taskHubWorker, nil + return taskHubClient, cancel, nil } // ExternalEventOrchestrator is an orchestrator function that blocks for 30 seconds or diff --git a/samples/parallel/parallel.go b/samples/parallel/parallel.go index b950968f..c33ff8cd 100644 --- a/samples/parallel/parallel.go +++ b/samples/parallel/parallel.go @@ -23,11 +23,11 @@ func main() { // Init the client ctx := context.Background() - client, worker, err := Init(ctx, r) + client, shutdown, err := Init(ctx, r) if err != nil { log.Fatalf("Failed to initialize the client: %v", err) } - defer worker.Shutdown(ctx) + defer shutdown() // Start a new orchestration id, err := client.ScheduleNewOrchestration(ctx, UpdateDevicesOrchestrator) @@ -50,7 +50,7 @@ func main() { } // Init creates and initializes an in-memory client and worker pair with default configuration. -func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, backend.TaskHubWorker, error) { +func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, context.CancelFunc, error) { logger := backend.DefaultLogger() // Create an executor @@ -63,20 +63,19 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac Backend: be, Executor: executor, Logger: logger, + AppID: "sample", }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) - // Start the worker - err := taskHubWorker.Start(ctx) - if err != nil { - return nil, nil, err - } + // Start the worker in a goroutine (Start is blocking) + ctx, cancel := context.WithCancel(ctx) + go taskHubWorker.Start(ctx) // Get the client to the backend taskHubClient := backend.NewTaskHubClient(be) - return taskHubClient, taskHubWorker, nil + return taskHubClient, cancel, nil } // UpdateDevicesOrchestrator is an orchestrator that runs activities in parallel diff --git a/samples/retries/retries.go b/samples/retries/retries.go index 45fdc7fa..ac26c1a9 100644 --- a/samples/retries/retries.go +++ b/samples/retries/retries.go @@ -21,11 +21,11 @@ func main() { // Init the client ctx := context.Background() - client, worker, err := Init(ctx, r) + client, shutdown, err := Init(ctx, r) if err != nil { log.Fatalf("Failed to initialize the client: %v", err) } - defer worker.Shutdown(ctx) + defer shutdown() // Start a new orchestration id, err := client.ScheduleNewOrchestration(ctx, RetryActivityOrchestrator) @@ -48,7 +48,7 @@ func main() { } // Init creates and initializes an in-memory client and worker pair with default configuration. -func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, backend.TaskHubWorker, error) { +func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, context.CancelFunc, error) { logger := backend.DefaultLogger() // Create an executor @@ -61,20 +61,17 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac Backend: be, Executor: executor, Logger: logger, + AppID: "sample", }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) - // Start the worker - err := taskHubWorker.Start(ctx) - if err != nil { - return nil, nil, err - } + ctx, cancel := context.WithCancel(ctx) + go taskHubWorker.Start(ctx) - // Get the client to the backend taskHubClient := backend.NewTaskHubClient(be) - return taskHubClient, taskHubWorker, nil + return taskHubClient, cancel, nil } func RetryActivityOrchestrator(ctx *task.OrchestrationContext) (any, error) { diff --git a/samples/sequence/sequence.go b/samples/sequence/sequence.go index ec940900..47d31c7b 100644 --- a/samples/sequence/sequence.go +++ b/samples/sequence/sequence.go @@ -19,11 +19,11 @@ func main() { // Init the client ctx := context.Background() - client, worker, err := Init(ctx, r) + client, shutdown, err := Init(ctx, r) if err != nil { log.Fatalf("Failed to initialize the client: %v", err) } - defer worker.Shutdown(ctx) + defer shutdown() // Start a new orchestration id, err := client.ScheduleNewOrchestration(ctx, ActivitySequenceOrchestrator) @@ -46,7 +46,7 @@ func main() { } // Init creates and initializes an in-memory client and worker pair with default configuration. -func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, backend.TaskHubWorker, error) { +func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, context.CancelFunc, error) { logger := backend.DefaultLogger() // Create an executor @@ -59,20 +59,19 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac Backend: be, Executor: executor, Logger: logger, + AppID: "sample", }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) - // Start the worker - err := taskHubWorker.Start(ctx) - if err != nil { - return nil, nil, err - } + // Start the worker in a goroutine (Start is blocking) + ctx, cancel := context.WithCancel(ctx) + go taskHubWorker.Start(ctx) // Get the client to the backend taskHubClient := backend.NewTaskHubClient(be) - return taskHubClient, taskHubWorker, nil + return taskHubClient, cancel, nil } // ActivitySequenceOrchestrator makes three activity calls in sequence and results the results diff --git a/samples/taskexecutionid/taskexecutionid.go b/samples/taskexecutionid/taskexecutionid.go index 5a329b3d..29c379c7 100644 --- a/samples/taskexecutionid/taskexecutionid.go +++ b/samples/taskexecutionid/taskexecutionid.go @@ -23,13 +23,11 @@ func main() { // Init the client ctx := context.Background() - client, worker, err := Init(ctx, r) + client, shutdown, err := Init(ctx, r) if err != nil { log.Fatalf("Failed to initialize the client: %v", err) } - defer func() { - must(worker.Shutdown(ctx)) - }() + defer shutdown() // Start a new orchestration id, err := client.ScheduleNewOrchestration(ctx, RetryActivityOrchestrator) @@ -52,7 +50,7 @@ func main() { } // Init creates and initializes an in-memory client and worker pair with default configuration. -func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, backend.TaskHubWorker, error) { +func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, context.CancelFunc, error) { logger := backend.DefaultLogger() // Create an executor @@ -65,20 +63,17 @@ func Init(ctx context.Context, r *task.TaskRegistry) (backend.TaskHubClient, bac Backend: be, Executor: executor, Logger: logger, + AppID: "sample", }) activityWorker := backend.NewActivityTaskWorker(be, executor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) - // Start the worker - err := taskHubWorker.Start(ctx) - if err != nil { - return nil, nil, err - } + ctx, cancel := context.WithCancel(ctx) + go taskHubWorker.Start(ctx) - // Get the client to the backend taskHubClient := backend.NewTaskHubClient(be) - return taskHubClient, taskHubWorker, nil + return taskHubClient, cancel, nil } func RetryActivityOrchestrator(ctx *task.OrchestrationContext) (any, error) { diff --git a/tests/grpc/grpc_test.go b/tests/grpc/grpc_test.go index 685dd752..60444901 100644 --- a/tests/grpc/grpc_test.go +++ b/tests/grpc/grpc_test.go @@ -51,11 +51,10 @@ func TestMain(m *testing.M) { activityWorker := backend.NewActivityTaskWorker(be, grpcExecutor, logger) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) - if err := taskHubWorker.Start(ctx); err != nil { - log.Fatalf("failed to start worker: %v", err) - } - executorCtx, executorCancel := context.WithCancel(ctx) - go grpcExecutor.Start(executorCtx) + + workerCtx, workerCancel := context.WithCancel(ctx) + go taskHubWorker.Start(workerCtx) + go grpcExecutor.Start(workerCtx) lis, err := net.Listen("tcp", ":0") if err != nil { @@ -81,12 +80,7 @@ func TestMain(m *testing.M) { // Run the test exitCode exitCode := m.Run() - executorCancel() - timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - if err := taskHubWorker.Shutdown(timeoutCtx); err != nil { - log.Fatalf("failed to shutdown worker: %v", err) - } + workerCancel() grpcServer.Stop() os.Exit(exitCode) } diff --git a/tests/mocks/TaskWorker.go b/tests/mocks/TaskWorker.go index d55dfe01..8e7fdfa6 100644 --- a/tests/mocks/TaskWorker.go +++ b/tests/mocks/TaskWorker.go @@ -23,68 +23,83 @@ func (_m *TaskWorker[T]) EXPECT() *TaskWorker_Expecter[T] { return &TaskWorker_Expecter[T]{mock: &_m.Mock} } -// Start provides a mock function with given fields: _a0 -func (_m *TaskWorker[T]) Start(_a0 context.Context) { - _m.Called(_a0) +// Dispatch provides a mock function with given fields: wi, callback +func (_m *TaskWorker[T]) Dispatch(wi T, callback chan<- error) { + _m.Called(wi, callback) } -// TaskWorker_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' -type TaskWorker_Start_Call[T backend.WorkItem] struct { +// TaskWorker_Dispatch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Dispatch' +type TaskWorker_Dispatch_Call[T backend.WorkItem] struct { *mock.Call } -// Start is a helper method to define mock.On call -// - _a0 context.Context -func (_e *TaskWorker_Expecter[T]) Start(_a0 interface{}) *TaskWorker_Start_Call[T] { - return &TaskWorker_Start_Call[T]{Call: _e.mock.On("Start", _a0)} +// Dispatch is a helper method to define mock.On call +// - wi T +// - callback chan<- error +func (_e *TaskWorker_Expecter[T]) Dispatch(wi interface{}, callback interface{}) *TaskWorker_Dispatch_Call[T] { + return &TaskWorker_Dispatch_Call[T]{Call: _e.mock.On("Dispatch", wi, callback)} } -func (_c *TaskWorker_Start_Call[T]) Run(run func(_a0 context.Context)) *TaskWorker_Start_Call[T] { +func (_c *TaskWorker_Dispatch_Call[T]) Run(run func(wi T, callback chan<- error)) *TaskWorker_Dispatch_Call[T] { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(T), args[1].(chan<- error)) }) return _c } -func (_c *TaskWorker_Start_Call[T]) Return() *TaskWorker_Start_Call[T] { +func (_c *TaskWorker_Dispatch_Call[T]) Return() *TaskWorker_Dispatch_Call[T] { _c.Call.Return() return _c } -func (_c *TaskWorker_Start_Call[T]) RunAndReturn(run func(context.Context)) *TaskWorker_Start_Call[T] { +func (_c *TaskWorker_Dispatch_Call[T]) RunAndReturn(run func(T, chan<- error)) *TaskWorker_Dispatch_Call[T] { _c.Run(run) return _c } -// StopAndDrain provides a mock function with no fields -func (_m *TaskWorker[T]) StopAndDrain() { - _m.Called() +// Start provides a mock function with given fields: _a0 +func (_m *TaskWorker[T]) Start(_a0 context.Context) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 } -// TaskWorker_StopAndDrain_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StopAndDrain' -type TaskWorker_StopAndDrain_Call[T backend.WorkItem] struct { +// TaskWorker_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type TaskWorker_Start_Call[T backend.WorkItem] struct { *mock.Call } -// StopAndDrain is a helper method to define mock.On call -func (_e *TaskWorker_Expecter[T]) StopAndDrain() *TaskWorker_StopAndDrain_Call[T] { - return &TaskWorker_StopAndDrain_Call[T]{Call: _e.mock.On("StopAndDrain")} +// Start is a helper method to define mock.On call +// - _a0 context.Context +func (_e *TaskWorker_Expecter[T]) Start(_a0 interface{}) *TaskWorker_Start_Call[T] { + return &TaskWorker_Start_Call[T]{Call: _e.mock.On("Start", _a0)} } -func (_c *TaskWorker_StopAndDrain_Call[T]) Run(run func()) *TaskWorker_StopAndDrain_Call[T] { +func (_c *TaskWorker_Start_Call[T]) Run(run func(_a0 context.Context)) *TaskWorker_Start_Call[T] { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } -func (_c *TaskWorker_StopAndDrain_Call[T]) Return() *TaskWorker_StopAndDrain_Call[T] { - _c.Call.Return() +func (_c *TaskWorker_Start_Call[T]) Return(_a0 error) *TaskWorker_Start_Call[T] { + _c.Call.Return(_a0) return _c } -func (_c *TaskWorker_StopAndDrain_Call[T]) RunAndReturn(run func()) *TaskWorker_StopAndDrain_Call[T] { - _c.Run(run) +func (_c *TaskWorker_Start_Call[T]) RunAndReturn(run func(context.Context) error) *TaskWorker_Start_Call[T] { + _c.Call.Return(run) return _c } diff --git a/tests/mocks/task.go b/tests/mocks/task.go index f16a969b..6cb1ec96 100644 --- a/tests/mocks/task.go +++ b/tests/mocks/task.go @@ -16,9 +16,6 @@ type TestTaskProcessor[T backend.WorkItem] struct { processingBlocked atomic.Bool - workItemMu sync.Mutex - workItems []T - abandonedWorkItemMu sync.Mutex abandonedWorkItems []T @@ -40,14 +37,6 @@ func (t *TestTaskProcessor[T]) UnblockProcessing() { t.processingBlocked.Store(false) } -func (t *TestTaskProcessor[T]) PendingWorkItems() []T { - t.workItemMu.Lock() - defer t.workItemMu.Unlock() - - // copy array - return append([]T{}, t.workItems...) -} - func (t *TestTaskProcessor[T]) AbandonedWorkItems() []T { t.abandonedWorkItemMu.Lock() defer t.abandonedWorkItemMu.Unlock() @@ -64,32 +53,10 @@ func (t *TestTaskProcessor[T]) CompletedWorkItems() []T { return append([]T{}, t.completedWorkItems...) } -func (t *TestTaskProcessor[T]) AddWorkItems(wis ...T) { - t.workItemMu.Lock() - defer t.workItemMu.Unlock() - - t.workItems = append(t.workItems, wis...) -} - func (t *TestTaskProcessor[T]) Name() string { return t.name } -func (t *TestTaskProcessor[T]) NextWorkItem(context.Context) (T, error) { - t.workItemMu.Lock() - defer t.workItemMu.Unlock() - - if len(t.workItems) == 0 { - var tt T - return tt, errors.New("no work items") - } - - wi := t.workItems[0] - t.workItems = t.workItems[1:] - - return wi, nil -} - func (t *TestTaskProcessor[T]) ProcessWorkItem(ctx context.Context, wi T) error { if !t.processingBlocked.Load() { return nil diff --git a/tests/orchestrations_test.go b/tests/orchestrations_test.go index e3aa4cc8..8874ec19 100644 --- a/tests/orchestrations_test.go +++ b/tests/orchestrations_test.go @@ -37,8 +37,8 @@ func Test_EmptyOrchestration(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "EmptyOrchestrator") @@ -66,8 +66,8 @@ func Test_SingleTimer(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "SingleTimer") @@ -108,8 +108,8 @@ func Test_ConcurrentTimers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "TimerFanOut") @@ -147,8 +147,8 @@ func Test_IsReplaying(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "IsReplayingOrch") @@ -193,8 +193,8 @@ func Test_SingleActivity(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("世界")) @@ -242,8 +242,8 @@ func Test_SingleActivity_TaskSpan(t *testing.T) { ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "SingleActivity", api.WithInput("世界")) @@ -290,8 +290,8 @@ func Test_ActivityChain(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ActivityChain") @@ -333,8 +333,8 @@ func Test_ActivityRetries(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ActivityRetries") @@ -391,8 +391,8 @@ func Test_ActivityFanOut(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r, backend.WithMaxParallelism(10)) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r, backend.WithMaxParallelism(10)) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ActivityFanOut") @@ -439,8 +439,8 @@ func Test_SingleSubOrchestrator_Completed(t *testing.T) { ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() id, err := client.ScheduleNewOrchestration(ctx, "Parent", api.WithInput("Hello, world!")) require.NoError(t, err) @@ -471,8 +471,8 @@ func Test_SingleSubOrchestrator_Failed(t *testing.T) { ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() id, err := client.ScheduleNewOrchestration(ctx, "Parent") require.NoError(t, err) @@ -510,8 +510,8 @@ func Test_SingleSubOrchestrator_Failed_Retries(t *testing.T) { ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() id, err := client.ScheduleNewOrchestration(ctx, "Parent") require.NoError(t, err) @@ -555,8 +555,8 @@ func Test_ContinueAsNew(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ContinueAsNewTest", api.WithInput(0)) @@ -607,8 +607,8 @@ func Test_ContinueAsNew_Events(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ContinueAsNewTest", api.WithInput(0)) @@ -647,8 +647,8 @@ func Test_ExternalEventContention(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ContinueAsNewTest") @@ -691,8 +691,8 @@ func Test_ExternalEventOrchestration(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ExternalEventOrchestration", api.WithInput(0)) @@ -744,8 +744,8 @@ func Test_ExternalEventTimeout(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run two variations, one where we raise the external event and one where we don't (timeout) for _, raiseEvent := range []bool{true, false} { @@ -816,8 +816,8 @@ func Test_SuspendResumeOrchestration(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration, which will block waiting for external events id, err := client.ScheduleNewOrchestration(ctx, "SuspendResumeOrchestration", api.WithInput(0)) @@ -889,8 +889,8 @@ func Test_TerminateOrchestration(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration, which will block waiting for external events id, err := client.ScheduleNewOrchestration(ctx, "MyOrchestrator") @@ -945,8 +945,8 @@ func Test_TerminateOrchestration_Recursive(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Test terminating with and without recursion for _, recurse := range []bool{true, false} { @@ -1021,8 +1021,8 @@ func Test_TerminateOrchestration_Recursive_TerminateCompletedSubOrchestration(t // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Test terminating with and without recursion for _, recurse := range []bool{true, false} { @@ -1101,8 +1101,8 @@ func Test_PurgeCompletedOrchestration(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "ExternalEventOrchestration") @@ -1160,8 +1160,8 @@ func Test_PurgeOrchestration_Recursive(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Test terminating with and without recursion for _, recurse := range []bool{true, false} { @@ -1229,8 +1229,8 @@ func Test_RecreateCompletedOrchestration(t *testing.T) { // Initialization ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the first orchestration id, err := client.ScheduleNewOrchestration(ctx, "HelloOrchestration", api.WithInput("世界")) @@ -1284,8 +1284,8 @@ func Test_SingleActivity_ReuseInstanceIDIgnore(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() instanceID := api.InstanceID("IGNORE_IF_RUNNING_OR_COMPLETED") reuseIdPolicy := &api.OrchestrationIdReusePolicy{ @@ -1335,8 +1335,8 @@ func Test_SingleActivity_ReuseInstanceIDTerminate(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() instanceID := api.InstanceID("TERMINATE_IF_RUNNING_OR_COMPLETED") reuseIdPolicy := &api.OrchestrationIdReusePolicy{ @@ -1386,8 +1386,8 @@ func Test_SingleActivity_ReuseInstanceIDError(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() instanceID := api.InstanceID("ERROR_IF_RUNNING_OR_COMPLETED") @@ -1428,8 +1428,8 @@ func Test_TaskExecutionId(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "TaskExecutionID") @@ -1489,8 +1489,8 @@ func Test_TaskExecutionId(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "TaskExecutionID") @@ -1530,8 +1530,8 @@ func Test_TaskExecutionId(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "TaskExecutionID") @@ -1587,8 +1587,8 @@ func Test_ActivityTraceContext(t *testing.T) { ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "TraceContextOrchestration") @@ -1623,8 +1623,8 @@ func Test_OrchestrationPatching_DefaultToPatched(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "Orchestrator") @@ -1658,8 +1658,8 @@ func Test_OrchestrationPatching_RunUnpatchedVersion(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "Orchestrator") @@ -1695,8 +1695,8 @@ func Test_OrchestrationPatching_MultiplePatches(t *testing.T) { // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "Orchestrator") @@ -1735,8 +1735,8 @@ func Test_OrchestrationPatching_ContinueAsNewDoNotCarryOverChoices(t *testing.T) // Initialization ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() // Run the orchestration id, err := client.ScheduleNewOrchestration(ctx, "Orchestrator") @@ -1777,8 +1777,8 @@ func Test_OrchestrationPatching_PatchPersistsAcrossReplays(t *testing.T) { }) ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() id, err := client.ScheduleNewOrchestration(ctx, "Orchestrator") require.NoError(t, err) @@ -1830,8 +1830,8 @@ func Test_OrchestrationPatching_PatchRemembersToStayFalse(t *testing.T) { }) ctx := context.Background() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() id, err := client.ScheduleNewOrchestration(ctx, "Orchestrator") require.NoError(t, err) @@ -1868,8 +1868,8 @@ func Test_OrchestrationPatching_TracingSpans(t *testing.T) { ctx := context.Background() exporter := utils.InitTracing() - client, worker := initTaskHubWorker(ctx, r) - defer worker.Shutdown(ctx) + client, shutdown := initTaskHubWorker(ctx, r) + defer shutdown() id, err := client.ScheduleNewOrchestration(ctx, "PatchTracingOrchestrator") require.NoError(t, err) @@ -1889,7 +1889,7 @@ func Test_OrchestrationPatching_TracingSpans(t *testing.T) { ) } -func initTaskHubWorker(ctx context.Context, r *task.TaskRegistry, opts ...backend.NewTaskWorkerOptions) (backend.TaskHubClient, backend.TaskHubWorker) { +func initTaskHubWorker(ctx context.Context, r *task.TaskRegistry, opts ...backend.NewTaskWorkerOptions) (backend.TaskHubClient, context.CancelFunc) { // TODO: Switch to options pattern logger := backend.DefaultLogger() be := sqlite.NewSqliteBackend(sqlite.NewSqliteOptions(""), logger) @@ -1902,9 +1902,18 @@ func initTaskHubWorker(ctx context.Context, r *task.TaskRegistry, opts ...backen }, opts...) activityWorker := backend.NewActivityTaskWorker(be, executor, logger, opts...) taskHubWorker := backend.NewTaskHubWorker(be, orchestrationWorker, activityWorker, logger) - if err := taskHubWorker.Start(ctx); err != nil { - panic(err) - } + ctx, cancel := context.WithCancel(ctx) + ready := make(chan struct{}) + go func() { + // CreateTaskHub initializes the DB. Signal readiness so the test + // can proceed once the backend is initialized. + if err := be.CreateTaskHub(ctx); err != nil { + panic(err) + } + close(ready) + taskHubWorker.Start(ctx) + }() + <-ready taskHubClient := backend.NewTaskHubClient(be) - return taskHubClient, taskHubWorker + return taskHubClient, cancel } diff --git a/tests/taskhub_test.go b/tests/taskhub_test.go index 538004e6..8f908fd3 100644 --- a/tests/taskhub_test.go +++ b/tests/taskhub_test.go @@ -7,37 +7,49 @@ import ( "github.com/dapr/durabletask-go/backend" "github.com/dapr/durabletask-go/tests/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func Test_TaskHubWorkerStartsDependencies(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) be := mocks.NewBackend(t) orchWorker := mocks.NewTaskWorker[*backend.OrchestrationWorkItem](t) actWorker := mocks.NewTaskWorker[*backend.ActivityWorkItem](t) - be.EXPECT().CreateTaskHub(ctx).Return(nil).Once() - be.EXPECT().Start(ctx).Return(nil).Once() - orchWorker.EXPECT().Start(ctx).Return().Once() - actWorker.EXPECT().Start(ctx).Return().Once() + be.EXPECT().CreateTaskHub(mock.Anything).Return(nil).Once() + + // All Start methods block until context is cancelled. Return nil on cancel. + be.EXPECT().Start(mock.Anything).RunAndReturn(func(ctx context.Context) error { + <-ctx.Done() + return nil + }).Once() + orchWorker.EXPECT().Start(mock.Anything).RunAndReturn(func(ctx context.Context) error { + <-ctx.Done() + return nil + }).Once() + actWorker.EXPECT().Start(mock.Anything).RunAndReturn(func(ctx context.Context) error { + <-ctx.Done() + return nil + }).Once() + + // The pollAndDispatch goroutines will call Next*WorkItem on the backend. + be.On("NextOrchestrationWorkItem", mock.Anything).Return(nil, context.Canceled).Maybe() + be.On("NextActivityWorkItem", mock.Anything).Return(nil, context.Canceled).Maybe() + + // Stop is called during cleanup. + be.EXPECT().Stop(mock.Anything).Return(nil).Once() w := backend.NewTaskHubWorker(be, orchWorker, actWorker, logger) - err := w.Start(ctx) - assert.NoError(t, err) -} -func Test_TaskHubWorkerStopsDependencies(t *testing.T) { - ctx := context.Background() + errCh := make(chan error, 1) + go func() { + errCh <- w.Start(ctx) + }() - be := mocks.NewBackend(t) - orchWorker := mocks.NewTaskWorker[*backend.OrchestrationWorkItem](t) - actWorker := mocks.NewTaskWorker[*backend.ActivityWorkItem](t) - - be.EXPECT().Stop(ctx).Return(nil).Once() - orchWorker.EXPECT().StopAndDrain().Return().Once() - actWorker.EXPECT().StopAndDrain().Return().Once() + // Cancel context to unblock Start. + cancel() - w := backend.NewTaskHubWorker(be, orchWorker, actWorker, logger) - err := w.Shutdown(ctx) + err := <-errCh assert.NoError(t, err) } diff --git a/tests/worker_test.go b/tests/worker_test.go index 6324569a..1126522c 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -3,7 +3,6 @@ package tests import ( "context" "errors" - "sync/atomic" "testing" "time" @@ -13,7 +12,6 @@ import ( "github.com/dapr/durabletask-go/backend/runtimestate" "github.com/dapr/durabletask-go/tests/mocks" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" @@ -26,7 +24,9 @@ var ( ) func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + wi := &backend.OrchestrationWorkItem{ InstanceID: "test123", NewEvents: []*protos.HistoryEvent{ @@ -48,18 +48,9 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { state := &backend.OrchestrationRuntimeState{} result := &protos.OrchestratorResponse{} - ctx, cancel := context.WithCancel(ctx) - completed := atomic.Bool{} be := mocks.NewBackend(t) - be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(wi, nil).Once() - be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(nil, errors.New("")).Once().Run(func(mock.Arguments) { - cancel() - }) be.EXPECT().GetOrchestrationRuntimeState(anyContext, wi).Return(state, nil).Once() - be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).RunAndReturn(func(ctx context.Context, owi *backend.OrchestrationWorkItem) error { - completed.Store(true) - return nil - }).Once() + be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).Return(nil).Once() ex := mocks.NewExecutor(t) ex.EXPECT().ExecuteOrchestrator(anyContext, wi.InstanceID, state.OldEvents, mock.Anything).Return(result, nil).Once() @@ -70,15 +61,19 @@ func Test_TryProcessSingleOrchestrationWorkItem_BasicFlow(t *testing.T) { Logger: logger, AppID: "testapp", }) - worker.Start(ctx) + go worker.Start(ctx) - require.EventuallyWithT(t, func(collect *assert.CollectT) { - if !completed.Load() { - collect.Errorf("process next not called CompleteOrchestrationWorkItem yet") - } - }, 1*time.Second, 100*time.Millisecond) + callback := make(chan error, 1) + worker.Dispatch(wi, callback) + + select { + case err := <-callback: + require.NoError(t, err) + case <-time.After(1 * time.Second): + t.Fatal("dispatch callback not received within timeout") + } - worker.StopAndDrain() + cancel() t.Logf("state.NewEvents: %v", state.NewEvents) require.Len(t, state.NewEvents, 2) @@ -109,9 +104,8 @@ func Test_TryProcessSingleOrchestrationWorkItem_Idempotency(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) + defer cancel() - completed := atomic.Bool{} be := mocks.NewBackend(t) ex := mocks.NewExecutor(t) @@ -125,18 +119,8 @@ func Test_TryProcessSingleOrchestrationWorkItem_Idempotency(t *testing.T) { return &protos.OrchestratorResponse{}, nil }).Times(2) - be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(wi, nil).Once() be.EXPECT().AbandonOrchestrationWorkItem(anyContext, wi).Return(nil).Once() - - be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(wi, nil).Once() - be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).RunAndReturn(func(ctx context.Context, owi *backend.OrchestrationWorkItem) error { - completed.Store(true) - return nil - }).Once() - - be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(nil, errors.New("")).Once().Run(func(mock.Arguments) { - cancel() - }) + be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).Return(nil).Once() worker := backend.NewOrchestrationWorker(backend.OrchestratorOptions{ Backend: be, @@ -144,11 +128,27 @@ func Test_TryProcessSingleOrchestrationWorkItem_Idempotency(t *testing.T) { Logger: logger, AppID: "testapp", }, backend.WithMaxParallelism(1)) - worker.Start(ctx) + go worker.Start(ctx) + + // First dispatch: orchestrator returns an error, work item should be abandoned. + cb1 := make(chan error, 1) + worker.Dispatch(wi, cb1) + select { + case <-cb1: + case <-time.After(1 * time.Second): + t.Fatal("first dispatch callback not received within timeout") + } - require.Eventually(t, completed.Load, 2*time.Second, 10*time.Millisecond) + // Second dispatch: orchestrator succeeds, work item should be completed. + cb2 := make(chan error, 1) + worker.Dispatch(wi, cb2) + select { + case <-cb2: + case <-time.After(1 * time.Second): + t.Fatal("second dispatch callback not received within timeout") + } - worker.StopAndDrain() + cancel() t.Logf("state.NewEvents: %v", wi.State.NewEvents) require.Len(t, wi.State.NewEvents, 3) @@ -158,7 +158,6 @@ func Test_TryProcessSingleOrchestrationWorkItem_Idempotency(t *testing.T) { } func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t *testing.T) { - ctx := context.Background() iid := api.InstanceID("test123") // Simulate getting an ExecutionStarted message from the orchestration queue @@ -184,13 +183,10 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * // Empty orchestration runtime state since we're starting a new execution from scratch state := runtimestate.NewOrchestrationRuntimeState(string(iid), nil, []*protos.HistoryEvent{}) - ctx, cancel := context.WithCancel(ctx) - be := mocks.NewBackend(t) - be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(wi, nil).Once() - be.EXPECT().NextOrchestrationWorkItem(anyContext).Return(nil, errors.New("")).Once().Run(func(mock.Arguments) { - cancel() - }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + be := mocks.NewBackend(t) be.EXPECT().GetOrchestrationRuntimeState(anyContext, wi).Return(state, nil).Once() ex := mocks.NewExecutor(t) @@ -216,11 +212,7 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * ex.EXPECT().ExecuteOrchestrator(anyContext, iid, []*protos.HistoryEvent{}, mock.Anything).Return(result, nil).Once() // After execution, the Complete action should be called - completed := atomic.Bool{} - be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).RunAndReturn(func(ctx context.Context, owi *backend.OrchestrationWorkItem) error { - completed.Store(true) - return nil - }).Once() + be.EXPECT().CompleteOrchestrationWorkItem(anyContext, wi).Return(nil).Once() // Set up and run the test worker := backend.NewOrchestrationWorker(backend.OrchestratorOptions{ @@ -229,19 +221,19 @@ func Test_TryProcessSingleOrchestrationWorkItem_ExecutionStartedAndCompleted(t * Logger: logger, AppID: "testapp", }) - worker.Start(ctx) - //ok, err := worker.ProcessNext(ctx) - //// Successfully processing a work-item should result in a nil error - //assert.Nil(t, err) - //assert.True(t, ok) - - require.EventuallyWithT(t, func(collect *assert.CollectT) { - if !completed.Load() { - collect.Errorf("process next not called CompleteOrchestrationWorkItem yet") - } - }, 1*time.Second, 100*time.Millisecond) + go worker.Start(ctx) - worker.StopAndDrain() + callback := make(chan error, 1) + worker.Dispatch(wi, callback) + + select { + case err := <-callback: + require.NoError(t, err) + case <-time.After(1 * time.Second): + t.Fatal("dispatch callback not received within timeout") + } + + cancel() t.Logf("state.NewEvents: %v", state.NewEvents) require.Len(t, state.NewEvents, 3) @@ -263,79 +255,65 @@ func Test_TaskWorker(t *testing.T) { second := &backend.ActivityWorkItem{ SequenceNumber: 2, } - tp.AddWorkItems(first, second) worker := backend.NewTaskWorker[*backend.ActivityWorkItem](tp, logger, backend.WithMaxParallelism(1)) + go worker.Start(ctx) - worker.Start(ctx) + cb1 := make(chan error, 1) + cb2 := make(chan error, 1) + worker.Dispatch(first, cb1) + worker.Dispatch(second, cb2) - require.EventuallyWithT(t, func(collect *assert.CollectT) { - if len(tp.PendingWorkItems()) == 0 { - return - } - collect.Errorf("work items not consumed yet") - }, 500*time.Millisecond, 100*time.Millisecond) + select { + case <-cb1: + case <-time.After(1 * time.Second): + t.Fatal("first dispatch callback not received within timeout") + } + select { + case <-cb2: + case <-time.After(1 * time.Second): + t.Fatal("second dispatch callback not received within timeout") + } - require.Len(t, tp.PendingWorkItems(), 0) require.Len(t, tp.AbandonedWorkItems(), 0) require.Len(t, tp.CompletedWorkItems(), 2) require.Equal(t, first, tp.CompletedWorkItems()[0]) require.Equal(t, second, tp.CompletedWorkItems()[1]) - - drainFinished := make(chan bool) - go func() { - worker.StopAndDrain() - drainFinished <- true - }() - - select { - case <-drainFinished: - return - case <-time.After(1 * time.Second): - t.Fatalf("worker stop and drain not finished within timeout") - } - } func Test_StartAndStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - defer cancel() tp := mocks.NewTestTaskPocessor[*backend.ActivityWorkItem]("test") tp.BlockProcessing() - first := backend.ActivityWorkItem{ + first := &backend.ActivityWorkItem{ SequenceNumber: 1, } - second := backend.ActivityWorkItem{ - SequenceNumber: 2, - } - tp.AddWorkItems(&first, &second) worker := backend.NewTaskWorker[*backend.ActivityWorkItem](tp, logger, backend.WithMaxParallelism(1)) - worker.Start(ctx) - - require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Len(c, tp.PendingWorkItems(), 1) - }, time.Second*5, 100*time.Millisecond) - - // due to the configuration of the TestTaskProcessor, now the work item is blocked on ProcessWorkItem until the context is cancelled - drainFinished := make(chan bool) + startDone := make(chan error, 1) go func() { - worker.StopAndDrain() - drainFinished <- true + startDone <- worker.Start(ctx) }() + // Dispatch a work item that will block on processing. + cb := make(chan error, 1) + worker.Dispatch(first, cb) + + // The work item should be in-flight (blocked). Give it a moment to start. + time.Sleep(50 * time.Millisecond) + + // Cancel context to stop the worker. This unblocks processing and causes abandon. + cancel() + select { - case <-drainFinished: - return + case <-startDone: case <-time.After(1 * time.Second): - t.Fatalf("worker stop and drain not finished within timeout") + t.Fatalf("worker start not finished within timeout") } - require.Len(t, tp.PendingWorkItems(), 1) - require.Equal(t, second, tp.PendingWorkItems()[0]) require.Len(t, tp.AbandonedWorkItems(), 1) require.Equal(t, first, tp.AbandonedWorkItems()[0]) require.Len(t, tp.CompletedWorkItems(), 0)