diff --git a/chasm/chasmtest/task_helpers.go b/chasm/chasmtest/task_helpers.go new file mode 100644 index 00000000000..c92836887a8 --- /dev/null +++ b/chasm/chasmtest/task_helpers.go @@ -0,0 +1,98 @@ +package chasmtest + +import ( + "context" + "fmt" + + "go.temporal.io/server/chasm" +) + +// ExecutePureTask validates and executes a pure task atomically via [Engine.UpdateComponent]. +// It returns taskDropped=true if [chasm.PureTaskHandler.Validate] returns (false, nil), +// indicating the task is no longer relevant and was not executed. +// +// For root components, construct ref with [chasm.NewComponentRef]. For subcomponents, obtain +// the ref via [chasm.Context.Ref] inside a [Engine.ReadComponent] or [Engine.UpdateComponent] +// callback and deserialize it with [chasm.DeserializeComponentRef]. +// +// This helper ensures that Validate is always exercised alongside Execute, matching the real +// engine's behavior. Use [chasm.MockMutableContext] directly when you need to inspect the +// typed task payloads added to the context during execution. +func ExecutePureTask[C chasm.Component, T any]( + ctx context.Context, + e *Engine, + ref chasm.ComponentRef, + handler chasm.PureTaskHandler[C, T], + attrs chasm.TaskAttributes, + task T, +) (taskDropped bool, err error) { + engineCtx := chasm.NewEngineContext(ctx, e) + _, err = e.UpdateComponent( + engineCtx, + ref, + func(mutableCtx chasm.MutableContext, c chasm.Component) error { + typedC, ok := c.(C) + if !ok { + return fmt.Errorf("component type mismatch: got %T", c) + } + var valid bool + valid, err = handler.Validate(mutableCtx, typedC, attrs, task) + if err != nil { + return err + } + if !valid { + taskDropped = true + return nil + } + return handler.Execute(mutableCtx, typedC, attrs, task) + }, + ) + return taskDropped, err +} + +// ExecuteSideEffectTask validates and executes a side-effect task. 
+// Validation runs via [Engine.ReadComponent] (read-only), and if valid, +// [chasm.SideEffectTaskHandler.Execute] is called with an engine context so that +// [chasm.UpdateComponent] and [chasm.ReadComponent] inside the handler route through +// the test engine. +// +// It returns taskDropped=true if [chasm.SideEffectTaskHandler.Validate] returns (false, nil), +// indicating the task is no longer relevant and was not executed. +// +// For root components, construct ref with [chasm.NewComponentRef]. For subcomponents, obtain +// the ref via [chasm.Context.Ref] inside a [Engine.ReadComponent] or [Engine.UpdateComponent] +// callback and deserialize it with [chasm.DeserializeComponentRef]. +// +// Use [chasm.MockMutableContext] directly when you need to inspect typed task payloads added +// during execution, since the real engine serializes them into history layer tasks. +func ExecuteSideEffectTask[C chasm.Component, T any]( + ctx context.Context, + e *Engine, + ref chasm.ComponentRef, + handler chasm.SideEffectTaskHandler[C, T], + attrs chasm.TaskAttributes, + task T, +) (taskDropped bool, err error) { + engineCtx := chasm.NewEngineContext(ctx, e) + + var valid bool + if err = e.ReadComponent( + engineCtx, + ref, + func(chasmCtx chasm.Context, c chasm.Component) error { + typedC, ok := c.(C) + if !ok { + return fmt.Errorf("component type mismatch: got %T", c) + } + valid, err = handler.Validate(chasmCtx, typedC, attrs, task) + return err + }, + ); err != nil { + return false, err + } + if !valid { + return true, nil + } + + return false, handler.Execute(engineCtx, ref, attrs, task) +} diff --git a/chasm/chasmtest/test_engine.go b/chasm/chasmtest/test_engine.go new file mode 100644 index 00000000000..41699370cec --- /dev/null +++ b/chasm/chasmtest/test_engine.go @@ -0,0 +1,711 @@ +package chasmtest + +import ( + "context" + "fmt" + "sync" + "testing" + + enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/api/serviceerror" + enumsspb 
"go.temporal.io/server/api/enums/v1" + persistencespb "go.temporal.io/server/api/persistence/v1" + "go.temporal.io/server/chasm" + "go.temporal.io/server/common/clock" + "go.temporal.io/server/common/definition" + "go.temporal.io/server/common/log" + "go.temporal.io/server/common/metrics" + "go.temporal.io/server/common/testing/testlogger" + "go.temporal.io/server/service/history/tasks" +) + +// WithTimeSource overrides the engine's default real-time clock with the provided time source. +// Pass a *clock.EventTimeSource when tests need to control what ctx.Now() returns inside handlers. +// The caller holds the reference and calls ts.Update(...) directly to advance time. +func WithTimeSource(ts clock.TimeSource) EngineOption { + return func(e *Engine) { + e.timeSource = ts + } +} + +type ( + EngineOption func(*Engine) + + // Engine is a lightweight in-memory CHASM engine for unit tests. It implements + // [chasm.Engine] and supports the full set of conflict and reuse policies, as + // well as blocking [PollComponent] with [NotifyExecution] — matching the behavior + // of the production engine as closely as possible without persistence or shard logic. + Engine struct { + t *testing.T + registry *chasm.Registry + logger log.Logger + metrics metrics.Handler + timeSource clock.TimeSource + // current maps (namespaceID, businessID) -> the latest run (running or closed). + current map[businessKey]*execution + // all maps (namespaceID, businessID, runID) -> any run, for lookups by specific RunID. + all map[runKey]*execution + notifier *executionNotifier + } + + executionStatus int + + execution struct { + key chasm.ExecutionKey + node *chasm.Node + backend *chasm.MockNodeBackend + root chasm.RootComponent + status executionStatus + requestID string + // failed is only meaningful when status == executionStatusClosed. + // It controls whether AllowDuplicateFailedOnly reuse policy permits a new run. 
+ failed bool + } + + businessKey struct { + namespaceID string + businessID string + } + + runKey struct { + namespaceID string + businessID string + runID string + } +) + +const ( + executionStatusRunning executionStatus = iota + executionStatusClosed +) + +var defaultTransitionOptions = chasm.TransitionOptions{ + ReusePolicy: chasm.BusinessIDReusePolicyAllowDuplicate, + ConflictPolicy: chasm.BusinessIDConflictPolicyFail, +} + +var _ chasm.Engine = (*Engine)(nil) + +func NewEngine( + t *testing.T, + registry *chasm.Registry, + opts ...EngineOption, +) *Engine { + t.Helper() + + e := &Engine{ + t: t, + registry: registry, + logger: testlogger.NewTestLogger(t, testlogger.FailOnExpectedErrorOnly), + metrics: metrics.NoopMetricsHandler, + timeSource: clock.NewRealTimeSource(), + current: make(map[businessKey]*execution), + all: make(map[runKey]*execution), + notifier: newExecutionNotifier(), + } + + for _, opt := range opts { + opt(e) + } + + return e +} + +// CloseExecution marks the execution identified by ref as closed, removing it from the +// set of running executions. Set failed=true to simulate a failed or terminated execution, +// which affects the [chasm.BusinessIDReusePolicyAllowDuplicateFailedOnly] reuse policy. +// A closed execution can still be read via [Engine.ReadComponent] using its specific RunID, +// but it will no longer be returned as the current run for the businessID. +func (e *Engine) CloseExecution(_ context.Context, ref chasm.ComponentRef, failed bool) error { + exec, err := e.executionForRef(ref) + if err != nil { + return err + } + exec.status = executionStatusClosed + exec.failed = failed + status := enumspb.WORKFLOW_EXECUTION_STATUS_COMPLETED + if failed { + status = enumspb.WORKFLOW_EXECUTION_STATUS_FAILED + } + // Keep the backend execution state consistent so that CloseTransaction correctly + // skips the lifecycle-change logic on any subsequent UpdateComponent calls. 
+ _, err = exec.backend.UpdateWorkflowStateStatus(enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, status) + return err +} + +// Tasks returns all tasks scheduled for the execution identified by ref, grouped by category. +// Tasks accumulate across every [Engine.UpdateComponent], [Engine.StartExecution], and +// [Engine.UpdateWithStartExecution] call on the execution, matching what the real engine +// would deliver to task processors. +func (e *Engine) Tasks(ref chasm.ComponentRef) (map[tasks.Category][]tasks.Task, error) { + exec, err := e.executionForRef(ref) + if err != nil { + return nil, err + } + // Copy the map and each per-category slice so callers cannot mutate the internal task lists. + result := make(map[tasks.Category][]tasks.Task, len(exec.backend.TasksByCategory)) + for cat, ts := range exec.backend.TasksByCategory { + result[cat] = append(make([]tasks.Task, 0, len(ts)), ts...) + } + return result, nil +} + +func (e *Engine) StartExecution( + ctx context.Context, + ref chasm.ComponentRef, + startFn func(chasm.MutableContext) (chasm.RootComponent, error), + opts ...chasm.TransitionOption, +) (chasm.StartExecutionResult, error) { + options := constructTransitionOptions(opts...) + bKey := newBusinessKey(ref.ExecutionKey) + + current, hasCurrent := e.current[bKey] + if hasCurrent { + // Idempotency: if the requestID matches the original create request, return the existing run.
+ if options.RequestID != "" && options.RequestID == current.requestID { + serializedRef, err := current.node.Ref(current.root) + if err != nil { + return chasm.StartExecutionResult{}, err + } + return chasm.StartExecutionResult{ + ExecutionKey: current.key, + ExecutionRef: serializedRef, + Created: false, + }, nil + } + + switch current.status { + case executionStatusRunning: + return e.handleConflictPolicy(ctx, ref, current, startFn, options) + case executionStatusClosed: + return e.handleReusePolicy(ctx, ref, current, startFn, options) + } + } + + return e.startNew(ctx, ref.ExecutionKey, startFn, options.RequestID) +} + +func (e *Engine) UpdateWithStartExecution( + ctx context.Context, + ref chasm.ComponentRef, + startFn func(chasm.MutableContext) (chasm.RootComponent, error), + updateFn func(chasm.MutableContext, chasm.Component) error, + opts ...chasm.TransitionOption, +) (chasm.EngineUpdateWithStartExecutionResult, error) { + options := constructTransitionOptions(opts...) + bKey := newBusinessKey(ref.ExecutionKey) + + current, hasCurrent := e.current[bKey] + if hasCurrent { + if current.status == executionStatusRunning { + // Execution is running — just apply the update, no start. + serializedRef, err := e.updateComponentInExecution(ctx, current, ref, updateFn) + if err != nil { + return chasm.EngineUpdateWithStartExecutionResult{}, err + } + return chasm.EngineUpdateWithStartExecutionResult{ + ExecutionKey: current.key, + ExecutionRef: serializedRef, + Created: false, + }, nil + } + + // Execution is closed — check reuse policy before starting a new one. + switch options.ReusePolicy { + case chasm.BusinessIDReusePolicyAllowDuplicate: + // No restriction; fall through to start+update. + case chasm.BusinessIDReusePolicyAllowDuplicateFailedOnly: + if !current.failed { + return chasm.EngineUpdateWithStartExecutionResult{}, chasm.NewExecutionAlreadyStartedErr( + fmt.Sprintf( + "CHASM execution already completed successfully. 
BusinessID: %s, RunID: %s, ID Reuse Policy: %v", + ref.BusinessID, current.key.RunID, options.ReusePolicy, + ), + current.requestID, + current.key.RunID, + ) + } + case chasm.BusinessIDReusePolicyRejectDuplicate: + return chasm.EngineUpdateWithStartExecutionResult{}, chasm.NewExecutionAlreadyStartedErr( + fmt.Sprintf( + "CHASM execution already finished. BusinessID: %s, RunID: %s, ID Reuse Policy: %v", + ref.BusinessID, current.key.RunID, options.ReusePolicy, + ), + current.requestID, + current.key.RunID, + ) + default: + return chasm.EngineUpdateWithStartExecutionResult{}, serviceerror.NewInternal( + fmt.Sprintf("unknown business ID reuse policy: %v", options.ReusePolicy), + ) + } + } + + return e.startAndUpdateNew(ctx, ref.ExecutionKey, startFn, updateFn, options.RequestID) +} + +func (e *Engine) UpdateComponent( + ctx context.Context, + ref chasm.ComponentRef, + updateFn func(chasm.MutableContext, chasm.Component) error, + _ ...chasm.TransitionOption, +) ([]byte, error) { + execution, err := e.executionForRef(ref) + if err != nil { + return nil, err + } + return e.updateComponentInExecution(ctx, execution, ref, updateFn) +} + +func (e *Engine) ReadComponent( + ctx context.Context, + ref chasm.ComponentRef, + readFn func(chasm.Context, chasm.Component) error, + _ ...chasm.TransitionOption, +) error { + execution, err := e.executionForRef(ref) + if err != nil { + return err + } + + chasmCtx := chasm.NewContext(ctx, execution.node) + component, err := execution.node.Component(chasmCtx, ref) + if err != nil { + return err + } + + return readFn(chasmCtx, component) +} + +// PollComponent waits until the supplied predicate is satisfied when evaluated against the +// component identified by ref. If the predicate is true immediately it returns without blocking. +// Otherwise it subscribes to [NotifyExecution] signals and re-evaluates after each one, just +// like the production engine. 
Returns (nil, nil) if ctx is cancelled (long-poll timeout +// semantics — the caller should re-poll). +func (e *Engine) PollComponent( + ctx context.Context, + ref chasm.ComponentRef, + predicate func(chasm.Context, chasm.Component) (bool, error), + _ ...chasm.TransitionOption, +) ([]byte, error) { + executionKey := ref.ExecutionKey + + checkPredicate := func() ([]byte, bool, error) { + exec, err := e.executionForRef(ref) + if err != nil { + return nil, false, err + } + chasmCtx := chasm.NewContext(ctx, exec.node) + component, err := exec.node.Component(chasmCtx, ref) + if err != nil { + return nil, false, err + } + satisfied, err := predicate(chasmCtx, component) + if err != nil || !satisfied { + return nil, satisfied, err + } + serializedRef, err := exec.node.Ref(component) + return serializedRef, true, err + } + + // Evaluate once before subscribing. + serializedRef, satisfied, err := checkPredicate() + if err != nil || satisfied { + return serializedRef, err + } + + for { + ch, unsubscribe := e.notifier.subscribe(executionKey) + // Re-evaluate while holding the subscription to avoid missing a notification + // that arrives between the failed check above and this subscribe call. + serializedRef, satisfied, err = checkPredicate() + if err != nil || satisfied { + unsubscribe() + return serializedRef, err + } + + select { + case <-ch: + unsubscribe() + serializedRef, satisfied, err = checkPredicate() + if err != nil || satisfied { + return serializedRef, err + } + case <-ctx.Done(): + unsubscribe() + return nil, nil //nolint:nilerr // nil, nil = long-poll timeout; caller should re-poll + } + } +} + +// NotifyExecution wakes up any [PollComponent] callers waiting on the execution. 
+func (e *Engine) NotifyExecution(key chasm.ExecutionKey) { + e.notifier.notify(key) +} + +func (e *Engine) DeleteExecution( + _ context.Context, + ref chasm.ComponentRef, + _ chasm.DeleteExecutionRequest, +) error { + exec, err := e.executionForRef(ref) + if err != nil { + return err + } + rKey := newRunKey(exec.key) + bKey := newBusinessKey(exec.key) + delete(e.all, rKey) + // Only evict from current if this is still the current run for the businessID. + if cur, ok := e.current[bKey]; ok && cur == exec { + delete(e.current, bKey) + } + return nil +} + +// handleConflictPolicy is called when a StartExecution arrives for a businessID whose +// current run is still running. +func (e *Engine) handleConflictPolicy( + ctx context.Context, + ref chasm.ComponentRef, + current *execution, + startFn func(chasm.MutableContext) (chasm.RootComponent, error), + options chasm.TransitionOptions, +) (chasm.StartExecutionResult, error) { + switch options.ConflictPolicy { + case chasm.BusinessIDConflictPolicyFail: + return chasm.StartExecutionResult{}, chasm.NewExecutionAlreadyStartedErr( + fmt.Sprintf( + "CHASM execution still running. 
BusinessID: %s, RunID: %s, ID Conflict Policy: %v", + ref.BusinessID, current.key.RunID, options.ConflictPolicy, + ), + current.requestID, + current.key.RunID, + ) + case chasm.BusinessIDConflictPolicyTerminateExisting: + // Mirror CloseExecution: move the backend's execution state to a closed status + // before replacing the run, so the terminated run's tree no longer looks running. + if _, err := current.backend.UpdateWorkflowStateStatus( + enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, + enumspb.WORKFLOW_EXECUTION_STATUS_TERMINATED, + ); err != nil { + return chasm.StartExecutionResult{}, err + } + // A terminated run counts as failed for AllowDuplicateFailedOnly reuse checks. + current.status = executionStatusClosed + current.failed = true + return e.startNew(ctx, ref.ExecutionKey, startFn, options.RequestID) + case chasm.BusinessIDConflictPolicyUseExisting: + serializedRef, err := current.node.Ref(current.root) + if err != nil { + return chasm.StartExecutionResult{}, err + } + return chasm.StartExecutionResult{ + ExecutionKey: current.key, + ExecutionRef: serializedRef, + Created: false, + }, nil + default: + return chasm.StartExecutionResult{}, serviceerror.NewInternal( + fmt.Sprintf("unknown business ID conflict policy: %v", options.ConflictPolicy), + ) + } +} + +// handleReusePolicy is called when a StartExecution arrives for a businessID whose +// current run is closed/completed. +func (e *Engine) handleReusePolicy( + ctx context.Context, + ref chasm.ComponentRef, + current *execution, + startFn func(chasm.MutableContext) (chasm.RootComponent, error), + options chasm.TransitionOptions, +) (chasm.StartExecutionResult, error) { + switch options.ReusePolicy { + case chasm.BusinessIDReusePolicyAllowDuplicate: + // No restriction. + case chasm.BusinessIDReusePolicyAllowDuplicateFailedOnly: + if !current.failed { + return chasm.StartExecutionResult{}, chasm.NewExecutionAlreadyStartedErr( + fmt.Sprintf( + "CHASM execution already completed successfully. BusinessID: %s, RunID: %s, ID Reuse Policy: %v", + ref.BusinessID, current.key.RunID, options.ReusePolicy, + ), + current.requestID, + current.key.RunID, + ) + } + case chasm.BusinessIDReusePolicyRejectDuplicate: + return chasm.StartExecutionResult{}, chasm.NewExecutionAlreadyStartedErr( + fmt.Sprintf( + "CHASM execution already finished.
BusinessID: %s, RunID: %s, ID Reuse Policy: %v", + ref.BusinessID, current.key.RunID, options.ReusePolicy, + ), + current.requestID, + current.key.RunID, + ) + default: + return chasm.StartExecutionResult{}, serviceerror.NewInternal( + fmt.Sprintf("unknown business ID reuse policy: %v", options.ReusePolicy), + ) + } + return e.startNew(ctx, ref.ExecutionKey, startFn, options.RequestID) +} + +// startNew creates a new execution and registers it as the current run for the businessID. +func (e *Engine) startNew( + ctx context.Context, + key chasm.ExecutionKey, + startFn func(chasm.MutableContext) (chasm.RootComponent, error), + requestID string, +) (chasm.StartExecutionResult, error) { + exec := e.newExecution(key) + exec.requestID = requestID + + mutableCtx := chasm.NewMutableContext(ctx, exec.node) + root, err := startFn(mutableCtx) + if err != nil { + return chasm.StartExecutionResult{}, err + } + if err := exec.node.SetRootComponent(root); err != nil { + return chasm.StartExecutionResult{}, err + } + if _, err = exec.node.CloseTransaction(); err != nil { + return chasm.StartExecutionResult{}, err + } + + exec.root = root + e.current[newBusinessKey(exec.key)] = exec + e.all[newRunKey(exec.key)] = exec + + serializedRef, err := exec.node.Ref(root) + if err != nil { + return chasm.StartExecutionResult{}, err + } + + return chasm.StartExecutionResult{ + ExecutionKey: exec.key, + ExecutionRef: serializedRef, + Created: true, + }, nil +} + +// startAndUpdateNew creates a new execution, applies startFn and updateFn in the same +// transaction, and registers it as the current run for the businessID. 
+func (e *Engine) startAndUpdateNew( + ctx context.Context, + key chasm.ExecutionKey, + startFn func(chasm.MutableContext) (chasm.RootComponent, error), + updateFn func(chasm.MutableContext, chasm.Component) error, + requestID string, +) (chasm.EngineUpdateWithStartExecutionResult, error) { + exec := e.newExecution(key) + exec.requestID = requestID + + mutableCtx := chasm.NewMutableContext(ctx, exec.node) + root, err := startFn(mutableCtx) + if err != nil { + return chasm.EngineUpdateWithStartExecutionResult{}, err + } + if err := exec.node.SetRootComponent(root); err != nil { + return chasm.EngineUpdateWithStartExecutionResult{}, err + } + if err := updateFn(mutableCtx, root); err != nil { + return chasm.EngineUpdateWithStartExecutionResult{}, err + } + if _, err = exec.node.CloseTransaction(); err != nil { + return chasm.EngineUpdateWithStartExecutionResult{}, err + } + + exec.root = root + e.current[newBusinessKey(exec.key)] = exec + e.all[newRunKey(exec.key)] = exec + + serializedRef, err := exec.node.Ref(root) + if err != nil { + return chasm.EngineUpdateWithStartExecutionResult{}, err + } + + return chasm.EngineUpdateWithStartExecutionResult{ + ExecutionKey: exec.key, + ExecutionRef: serializedRef, + Created: true, + }, nil +} + +func (e *Engine) newExecution(key chasm.ExecutionKey) *execution { + // bsMu guards transitionCount and execState, which are shared across handler closures. + // This is a separate mutex from MockNodeBackend's internal mu to avoid deadlocks. + var ( + bsMu sync.Mutex + transitionCount int64 = 1 + execState = persistencespb.WorkflowExecutionState{ + State: enumsspb.WORKFLOW_EXECUTION_STATE_CREATED, + Status: enumspb.WORKFLOW_EXECUTION_STATUS_RUNNING, + } + ) + + backend := &chasm.MockNodeBackend{ + // NextTransitionCount increments on every CloseTransaction call, matching + // the real engine's per-transition monotonic counter. 
+ HandleNextTransitionCount: func() int64 { + bsMu.Lock() + defer bsMu.Unlock() + transitionCount++ + return transitionCount + }, + // CurrentVersionedTransition reflects the latest committed transition count. + HandleCurrentVersionedTransition: func() *persistencespb.VersionedTransition { + bsMu.Lock() + defer bsMu.Unlock() + return &persistencespb.VersionedTransition{ + NamespaceFailoverVersion: 1, + TransitionCount: transitionCount, + } + }, + HandleGetCurrentVersion: func() int64 { return 1 }, + HandleGetWorkflowKey: func() definition.WorkflowKey { + return definition.NewWorkflowKey(key.NamespaceID, key.BusinessID, key.RunID) + }, + HandleIsWorkflow: func() bool { return false }, + // GetExecutionState returns the current lifecycle state, which CloseTransaction + // uses to decide whether to call UpdateWorkflowStateStatus. + HandleGetExecutionState: func() *persistencespb.WorkflowExecutionState { + bsMu.Lock() + defer bsMu.Unlock() + s := execState // copy to avoid aliasing + return &s + }, + // UpdateWorkflowStateStatus is called by CloseTransaction when the root + // component's LifecycleState changes (Running → Completed/Failed/Terminated). + HandleUpdateWorkflowStateStatus: func(state enumsspb.WorkflowExecutionState, status enumspb.WorkflowExecutionStatus) (bool, error) { + bsMu.Lock() + defer bsMu.Unlock() + changed := execState.State != state || execState.Status != status + execState.State = state + execState.Status = status + return changed, nil + }, + } + return &execution{ + key: key, + backend: backend, + node: chasm.NewEmptyTree( + e.registry, + e.timeSource, + backend, + chasm.DefaultPathEncoder, + e.logger, + e.metrics, + ), + } +} + +// executionForRef looks up an execution by the ref's RunID when present, or falls back +// to the current run for the businessID when RunID is empty. 
+func (e *Engine) executionForRef(ref chasm.ComponentRef) (*execution, error) { + if ref.RunID != "" { + exec, ok := e.all[newRunKey(ref.ExecutionKey)] + if !ok { + return nil, serviceerror.NewNotFound( + fmt.Sprintf("execution not found: namespace=%q business_id=%q run_id=%q", ref.NamespaceID, ref.BusinessID, ref.RunID), + ) + } + return exec, nil + } + exec, ok := e.current[newBusinessKey(ref.ExecutionKey)] + if !ok { + return nil, serviceerror.NewNotFound( + fmt.Sprintf("execution not found: namespace=%q business_id=%q", ref.NamespaceID, ref.BusinessID), + ) + } + return exec, nil +} + +func (e *Engine) updateComponentInExecution( + ctx context.Context, + execution *execution, + ref chasm.ComponentRef, + updateFn func(chasm.MutableContext, chasm.Component) error, +) ([]byte, error) { + chasmCtx := chasm.NewContext(ctx, execution.node) + component, err := execution.node.Component(chasmCtx, ref) + if err != nil { + return nil, err + } + + mutableCtx := chasm.NewMutableContext(ctx, execution.node) + if err := updateFn(mutableCtx, component); err != nil { + return nil, err + } + + if _, err = execution.node.CloseTransaction(); err != nil { + return nil, err + } + + return mutableCtx.Ref(component) +} + +func constructTransitionOptions(opts ...chasm.TransitionOption) chasm.TransitionOptions { + options := defaultTransitionOptions + for _, opt := range opts { + opt(&options) + } + // NOTE: TransitionOptions.Speculative is intentionally not implemented here — it is also + // unimplemented in the production engine (see the TODO in service/history/chasm_engine.go). 
+ return options +} + +func newBusinessKey(key chasm.ExecutionKey) businessKey { + return businessKey{namespaceID: key.NamespaceID, businessID: key.BusinessID} +} + +func newRunKey(key chasm.ExecutionKey) runKey { + return runKey{namespaceID: key.NamespaceID, businessID: key.BusinessID, runID: key.RunID} +} + +// executionNotifier allows [PollComponent] callers to subscribe to state-change +// signals for a given execution. [notify] closes the channel for all current +// subscribers; each subscriber must resubscribe after being woken. +type executionNotifier struct { + mu sync.Mutex + subscribers map[chasm.ExecutionKey][]chan struct{} +} + +func newExecutionNotifier() *executionNotifier { + return &executionNotifier{ + subscribers: make(map[chasm.ExecutionKey][]chan struct{}), + } +} + +// subscribe returns a channel that will be closed on the next [notify] for key, +// and an unsubscribe function that must be called when the caller is done waiting. +func (n *executionNotifier) subscribe(key chasm.ExecutionKey) (<-chan struct{}, func()) { + ch := make(chan struct{}) + n.mu.Lock() + n.subscribers[key] = append(n.subscribers[key], ch) + n.mu.Unlock() + + unsubscribed := false + unsubscribe := func() { + n.mu.Lock() + defer n.mu.Unlock() + if unsubscribed { + return + } + unsubscribed = true + subs := n.subscribers[key] + for i, s := range subs { + if s == ch { + n.subscribers[key] = append(subs[:i], subs[i+1:]...) + if len(n.subscribers[key]) == 0 { + delete(n.subscribers, key) + } + break + } + } + } + return ch, unsubscribe +} + +// notify closes all subscriber channels for key, waking any blocked [PollComponent] callers. 
+func (n *executionNotifier) notify(key chasm.ExecutionKey) { + n.mu.Lock() + subs := n.subscribers[key] + delete(n.subscribers, key) + n.mu.Unlock() + + for _, ch := range subs { + close(ch) + } +} diff --git a/chasm/lib/callback/tasks_test.go b/chasm/lib/callback/tasks_test.go index 7804d084edc..b455d529253 100644 --- a/chasm/lib/callback/tasks_test.go +++ b/chasm/lib/callback/tasks_test.go @@ -15,6 +15,7 @@ import ( "go.temporal.io/server/api/historyservicemock/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/chasm" + "go.temporal.io/server/chasm/chasmtest" callbackspb "go.temporal.io/server/chasm/lib/callback/gen/callbackpb/v1" "go.temporal.io/server/common/backoff" "go.temporal.io/server/common/clock" @@ -168,17 +169,13 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { metrics.DestinationTag("http://localhost"), metrics.OutcomeTag(tc.expectedMetricOutcome)) - // Setup logger and time source + // Setup logger logger := log.NewTestLogger() - timeSource := clock.NewEventTimeSource() - timeSource.Update(time.Now()) // Create task handler with mock namespace registry nsRegistry := namespace.NewMockRegistry(ctrl) nsRegistry.EXPECT().GetNamespaceByID(gomock.Any()).Return(ns, nil) - // Create mock engine - mockEngine := chasm.NewMockEngine(ctrl) handler := &invocationTaskHandler{ config: &Config{ RequestTimeout: dynamicconfig.GetDurationPropertyFnFilteredByDestination(time.Second), @@ -202,13 +199,10 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { err = chasmRegistry.Register(&mockNexusCompletionGetterLibrary{}) require.NoError(t, err) - nodeBackend := &chasm.MockNodeBackend{} - root := chasm.NewEmptyTree(chasmRegistry, timeSource, nodeBackend, chasm.DefaultPathEncoder, logger, metricsHandler) - callback := &Callback{ CallbackState: &callbackspb.CallbackState{ RequestId: "request-id", - RegistrationTime: timestamppb.New(timeSource.Now()), + RegistrationTime: timestamppb.New(time.Now()), Callback: 
&callbackspb.Callback{ Variant: &callbackspb.Callback_Nexus_{ Nexus: &callbackspb.Callback_Nexus{ @@ -224,72 +218,55 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { // Create completion completion := nexusrpc.CompleteOperationOptions{} - // Set up the CompletionSource field to return our mock completion - require.NoError(t, root.SetRootComponent(&mockNexusCompletionGetterComponent{ - completion: completion, - // Create callback in SCHEDULED state - Callback: chasm.NewComponentField( - chasm.NewMutableContext(context.Background(), root), - callback, - ), - })) - _, err = root.CloseTransaction() - require.NoError(t, err) - - // Setup engine expectations to directly call handler logic with MockMutableContext - mockEngine.EXPECT().ReadComponent( - gomock.Any(), - gomock.Any(), - gomock.Any(), - ).DoAndReturn(func(ctx context.Context, ref chasm.ComponentRef, readFn func(chasm.Context, chasm.Component) error, opts ...chasm.TransitionOption) error { - mockCtx := &chasm.MockContext{ - HandleNow: func(component chasm.Component) time.Time { - return timeSource.Now() - }, - HandleRef: func(component chasm.Component) ([]byte, error) { - return []byte{}, nil - }, - } - return readFn(mockCtx, callback) - }) - - mockEngine.EXPECT().UpdateComponent( - gomock.Any(), - gomock.Any(), - gomock.Any(), - ).DoAndReturn(func(ctx context.Context, ref chasm.ComponentRef, updateFn func(chasm.MutableContext, chasm.Component) error, opts ...chasm.TransitionOption) ([]any, error) { - mockCtx := &chasm.MockMutableContext{ - MockContext: chasm.MockContext{ - HandleNow: func(component chasm.Component) time.Time { - return timeSource.Now() - }, - HandleRef: func(component chasm.Component) ([]byte, error) { - return []byte{}, nil - }, - }, - } - err := updateFn(mockCtx, callback) - return nil, err - }) - - // Create ComponentRef - ref := chasm.NewComponentRef[*Callback](chasm.ExecutionKey{ + executionKey := chasm.ExecutionKey{ NamespaceID: "namespace-id", BusinessID: "workflow-id", 
RunID: "run-id", - }) + } + testEngine := chasmtest.NewEngine(t, chasmRegistry) + engineCtx := chasm.NewEngineContext(context.Background(), testEngine) + _, err = chasm.StartExecution[*mockNexusCompletionGetterComponent, struct{}]( + engineCtx, + executionKey, + func(ctx chasm.MutableContext, _ struct{}) (*mockNexusCompletionGetterComponent, error) { + return &mockNexusCompletionGetterComponent{ + completion: completion, + Callback: chasm.NewComponentField(ctx, callback), + }, nil + }, + struct{}{}, + ) + require.NoError(t, err) + + rootRef := chasm.NewComponentRef[*mockNexusCompletionGetterComponent](executionKey) + var callbackRef chasm.ComponentRef + require.NoError(t, testEngine.ReadComponent(engineCtx, rootRef, func(chasmCtx chasm.Context, _ chasm.Component) error { + serialized, err := chasmCtx.Ref(callback) + if err != nil { + return err + } + callbackRef, err = chasm.DeserializeComponentRef(serialized) + return err + })) - // Execute with engine context - engineCtx := chasm.NewEngineContext(context.Background(), mockEngine) err = handler.Execute( engineCtx, - ref, + callbackRef, chasm.TaskAttributes{Destination: "http://localhost"}, &callbackspb.InvocationTask{Attempt: 0}, ) - // Verify the outcome and tasks - tc.assertOutcome(t, callback, err) + // Verify outcome by reading component state directly + var resultCallback *Callback + require.NoError(t, testEngine.ReadComponent( + engineCtx, + callbackRef, + func(_ chasm.Context, c chasm.Component) error { + resultCallback = c.(*Callback) + return nil + }, + )) + tc.assertOutcome(t, resultCallback, err) }) } } diff --git a/chasm/lib/scheduler/handler_test.go b/chasm/lib/scheduler/handler_test.go index 03ca797714c..209209ddeca 100644 --- a/chasm/lib/scheduler/handler_test.go +++ b/chasm/lib/scheduler/handler_test.go @@ -1,14 +1,17 @@ package scheduler_test import ( + "context" "testing" "github.com/stretchr/testify/require" "go.temporal.io/api/serviceerror" "go.temporal.io/api/workflowservice/v1" 
"go.temporal.io/server/chasm" + "go.temporal.io/server/chasm/chasmtest" "go.temporal.io/server/chasm/lib/scheduler" schedulerpb "go.temporal.io/server/chasm/lib/scheduler/gen/schedulerpb/v1" + "go.temporal.io/server/common/log" legacyscheduler "go.temporal.io/server/service/worker/scheduler" "go.uber.org/mock/gomock" "google.golang.org/protobuf/types/known/timestamppb" @@ -108,20 +111,29 @@ func TestSentinelHandler_MigrateToWorkflow(t *testing.T) { } func TestHandler_CreateFromMigrationState_Sentinel(t *testing.T) { - env := newTestEnv(t, withMockEngine()) - sentinel, ctx, _ := setupSentinelForTest(t) - - h := scheduler.NewTestHandler(env.Logger) - - // StartExecution returns already-started because the sentinel occupies the key. - env.MockEngine.EXPECT().StartExecution(gomock.Any(), gomock.Any(), gomock.Any()). - Return(chasm.StartExecutionResult{}, chasm.NewExecutionAlreadyStartedErr("already exists", "", "")) - - // ReadComponent invokes the read function with the sentinel. - env.ExpectReadComponent(ctx, sentinel) + ctrl := gomock.NewController(t) + logger := log.NewTestLogger() + registry := chasm.NewRegistry(logger) + require.NoError(t, registry.Register(&chasm.CoreLibrary{})) + require.NoError(t, registry.Register(newTestLibrary(logger, newRealSpecProcessor(ctrl, logger)))) + + h := scheduler.NewTestHandler(logger) + testEngine := chasmtest.NewEngine(t, registry) + engineCtx := chasm.NewEngineContext(context.Background(), testEngine) + _, err := chasm.StartExecution[*scheduler.Scheduler, struct{}]( + engineCtx, + chasm.ExecutionKey{ + NamespaceID: namespaceID, + BusinessID: scheduleID, + }, + func(ctx chasm.MutableContext, _ struct{}) (*scheduler.Scheduler, error) { + return scheduler.NewSentinel(ctx, namespace, namespaceID, scheduleID), nil + }, + struct{}{}, + ) + require.NoError(t, err) - engineCtx := env.EngineContext() - _, err := h.TestCreateFromMigrationState(engineCtx, &schedulerpb.CreateFromMigrationStateRequest{ + _, err = 
h.TestCreateFromMigrationState(engineCtx, &schedulerpb.CreateFromMigrationStateRequest{ NamespaceId: namespaceID, State: &schedulerpb.SchedulerMigrationState{ SchedulerState: &schedulerpb.SchedulerState{ @@ -137,16 +149,29 @@ func TestHandler_CreateFromMigrationState_Sentinel(t *testing.T) { } func TestHandler_MigrateToWorkflow_Sentinel(t *testing.T) { - env := newTestEnv(t, withMockEngine()) - sentinel, ctx, _ := setupSentinelForTest(t) - - h := scheduler.NewTestHandler(env.Logger) - - // UpdateComponent invokes the update function with the sentinel. - env.ExpectUpdateComponent(ctx, sentinel) + ctrl := gomock.NewController(t) + logger := log.NewTestLogger() + registry := chasm.NewRegistry(logger) + require.NoError(t, registry.Register(&chasm.CoreLibrary{})) + require.NoError(t, registry.Register(newTestLibrary(logger, newRealSpecProcessor(ctrl, logger)))) + + h := scheduler.NewTestHandler(logger) + testEngine := chasmtest.NewEngine(t, registry) + engineCtx := chasm.NewEngineContext(context.Background(), testEngine) + _, err := chasm.StartExecution[*scheduler.Scheduler, struct{}]( + engineCtx, + chasm.ExecutionKey{ + NamespaceID: namespaceID, + BusinessID: scheduleID, + }, + func(ctx chasm.MutableContext, _ struct{}) (*scheduler.Scheduler, error) { + return scheduler.NewSentinel(ctx, namespace, namespaceID, scheduleID), nil + }, + struct{}{}, + ) + require.NoError(t, err) - engineCtx := env.EngineContext() - _, err := h.TestMigrateToWorkflow(engineCtx, &schedulerpb.MigrateToWorkflowRequest{ + _, err = h.TestMigrateToWorkflow(engineCtx, &schedulerpb.MigrateToWorkflowRequest{ NamespaceId: namespaceID, ScheduleId: scheduleID, })