diff --git a/Makefile b/Makefile
index 067b572d..3bba0285 100644
--- a/Makefile
+++ b/Makefile
@@ -109,11 +109,11 @@ lint: lint.check ## Run golangci-lint
##@ Run
.PHONY: reviewable
-reviewable: swag # Ensure a PR is ready for review.
+reviewable: lint test-integration # Ensure a PR is ready for review.
@go mod tidy
.PHONY: check-diff
-check-diff: reviewable # Ensure branch is clean.
+check-diff: swag # Regenerate swagger docs, then ensure branch is clean.
@test -z "$$(git status --porcelain)" || (echo "$$(git status --porcelain)" && $(FAIL))
@$(OK) branch is clean
diff --git a/cmd/run.go b/cmd/run.go
index 5ee56ece..10715419 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -3,7 +3,6 @@ package cmd
import (
"context"
"log"
- "time"
"github.com/compliance-framework/api/internal/api"
"github.com/compliance-framework/api/internal/api/handler"
@@ -14,7 +13,6 @@ import (
"github.com/compliance-framework/api/internal/service/digest"
"github.com/compliance-framework/api/internal/service/email"
"github.com/compliance-framework/api/internal/service/relational/workflows"
- "github.com/compliance-framework/api/internal/service/scheduler"
"github.com/compliance-framework/api/internal/service/worker"
"github.com/compliance-framework/api/internal/workflow"
"github.com/spf13/cobra"
@@ -86,11 +84,6 @@ func RunServer(cmd *cobra.Command, args []string) {
sugar.Fatalw("Failed to start worker service", "error", err)
}
- // Initialize scheduler for other jobs (if any)
- // Note: Digest scheduling is now handled by River's periodic jobs
- sched := scheduler.NewCronScheduler(sugar)
- sched.Start()
-
// Initialize workflow manager
workflowExecService := workflows.NewWorkflowExecutionService(db)
workflowInstService := workflows.NewWorkflowInstanceService(db)
@@ -101,11 +94,12 @@ func RunServer(cmd *cobra.Command, args []string) {
workflowInstService,
stepExecService,
sugar,
+ workerService,
)
metrics := api.NewMetricsHandler(ctx, sugar)
server := api.NewServer(ctx, sugar, cfg, metrics)
- handler.RegisterHandlers(server, sugar, db, cfg, digestService, sched, workflowManager)
+ handler.RegisterHandlers(server, sugar, db, cfg, digestService, workflowManager, workerService, workerService.GetDAGExecutor())
oscal.RegisterHandlers(server, sugar, db, cfg)
auth.RegisterHandlers(server, sugar, db, cfg, metrics, emailService, workerService)
@@ -121,23 +115,9 @@ func RunServer(cmd *cobra.Command, args []string) {
sugar.Fatalw("Failed to start server", "error", err)
}
- // Note: Defer statements are registered in reverse order of execution.
- // This ensures proper shutdown order: scheduler -> worker service
defer func() {
- // Stop worker service last (after scheduler has stopped)
if err := workerService.Stop(ctx); err != nil {
sugar.Errorw("Failed to stop worker service", "error", err)
}
}()
-
- defer func() {
- // Stop scheduler first
- stopCtx := sched.Stop()
- select {
- case <-stopCtx.Done():
- sugar.Debug("All scheduled jobs completed gracefully")
- case <-time.After(10 * time.Second):
- sugar.Warn("Scheduler shutdown timeout, some jobs may not have completed")
- }
- }()
}
diff --git a/internal/api/handler/api.go b/internal/api/handler/api.go
index 12c99b5f..cb3d9407 100644
--- a/internal/api/handler/api.go
+++ b/internal/api/handler/api.go
@@ -9,14 +9,13 @@ import (
"github.com/compliance-framework/api/internal/config"
"github.com/compliance-framework/api/internal/service/digest"
workflowsvc "github.com/compliance-framework/api/internal/service/relational/workflows"
- "github.com/compliance-framework/api/internal/service/scheduler"
"github.com/compliance-framework/api/internal/workflow"
"github.com/labstack/echo/v4"
"go.uber.org/zap"
"gorm.io/gorm"
)
-func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB, config *config.Config, digestService *digest.Service, sched scheduler.Scheduler, workflowManager *workflow.Manager) {
+func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB, config *config.Config, digestService *digest.Service, workflowManager *workflow.Manager, notificationEnqueuer workflow.NotificationEnqueuer, dagExecutor *workflow.DAGExecutor) {
healthHandler := NewHealthHandler(logger, db)
healthHandler.Register(server.API().Group("/health"))
@@ -41,8 +40,8 @@ func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB
userHandler.RegisterSelfRoutes(userGroup)
// Digest handler (admin only)
- if digestService != nil && sched != nil {
- digestHandler := NewDigestHandler(digestService, sched, logger)
+ if digestService != nil {
+ digestHandler := NewDigestHandler(digestService, logger)
digestGroup := server.API().Group("/admin/digest")
digestGroup.Use(middleware.JWTMiddleware(config.JWTPublicKey))
digestGroup.Use(middleware.RequireAdminGroups(db, config, logger))
@@ -50,11 +49,11 @@ func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB
}
// Register workflow handlers
- registerWorkflowHandlers(server, logger, db, config, workflowManager)
+ registerWorkflowHandlers(server, logger, db, config, workflowManager, notificationEnqueuer, dagExecutor)
}
// registerWorkflowHandlers registers all workflow-related HTTP handlers with authentication
-func registerWorkflowHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB, config *config.Config, workflowManager *workflow.Manager) {
+func registerWorkflowHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB, config *config.Config, workflowManager *workflow.Manager, notificationEnqueuer workflow.NotificationEnqueuer, dagExecutor *workflow.DAGExecutor) {
// Create workflow group with authentication middleware
workflowGroup := server.API().Group("/workflows")
workflowGroup.Use(middleware.JWTMiddleware(config.JWTPublicKey))
@@ -77,27 +76,28 @@ func registerWorkflowHandlers(server *api.Server, logger *zap.SugaredLogger, db
// Handlers that require workflow manager
if workflowManager != nil {
- registerWorkflowExecutionHandlers(workflowGroup, logger, db, workflowManager)
+ registerWorkflowExecutionHandlers(workflowGroup, logger, db, workflowManager, notificationEnqueuer, dagExecutor)
}
}
// registerWorkflowExecutionHandlers registers execution-related handlers that require the workflow manager
-func registerWorkflowExecutionHandlers(workflowGroup *echo.Group, logger *zap.SugaredLogger, db *gorm.DB, workflowManager *workflow.Manager) {
+func registerWorkflowExecutionHandlers(workflowGroup *echo.Group, logger *zap.SugaredLogger, db *gorm.DB, workflowManager *workflow.Manager, notificationEnqueuer workflow.NotificationEnqueuer, dagExecutor *workflow.DAGExecutor) {
roleAssignmentService := workflowsvc.NewRoleAssignmentService(db)
- assignmentService := workflow.NewAssignmentService(roleAssignmentService, db)
+ stepExecService := workflowsvc.NewStepExecutionService(db, nil)
+ assignmentService := workflow.NewAssignmentService(roleAssignmentService, stepExecService, db, logger, notificationEnqueuer)
// Workflow execution handler
workflowExecutionHandler := workflows.NewWorkflowExecutionHandler(logger, db, workflowManager, assignmentService)
workflowExecutionHandler.Register(workflowGroup.Group("/executions"))
// Step execution handler with transition service
- transitionService := createStepTransitionService(db, logger)
+ transitionService := createStepTransitionService(db, logger, notificationEnqueuer, dagExecutor)
stepExecutionHandler := workflows.NewStepExecutionHandler(logger, db, transitionService, assignmentService)
stepExecutionHandler.Register(workflowGroup.Group("/step-executions"))
}
// createStepTransitionService creates and configures the step transition service with all dependencies
-func createStepTransitionService(db *gorm.DB, logger *zap.SugaredLogger) *workflow.StepTransitionService {
+func createStepTransitionService(db *gorm.DB, logger *zap.SugaredLogger, notificationEnqueuer workflow.NotificationEnqueuer, executor *workflow.DAGExecutor) *workflow.StepTransitionService {
// Create services needed for step transition
stepExecService := workflowsvc.NewStepExecutionService(db, nil)
stepDefService := workflowsvc.NewWorkflowStepDefinitionService(db)
@@ -107,17 +107,7 @@ func createStepTransitionService(db *gorm.DB, logger *zap.SugaredLogger) *workfl
roleAssignmentService := workflowsvc.NewRoleAssignmentService(db)
// Create assignment service
- assignmentService := workflow.NewAssignmentService(roleAssignmentService, db)
-
- // Create executor for step transition coordination
- stdLogger := log.Default()
- executor := workflow.NewDAGExecutor(
- stepExecService,
- workflowExecService,
- stepDefService,
- assignmentService,
- stdLogger,
- )
+ assignmentService := workflow.NewAssignmentService(roleAssignmentService, stepExecService, db, logger, notificationEnqueuer)
// Create evidence integration for step evidence storage
evidenceIntegration := workflow.NewEvidenceIntegration(db, logger)
@@ -126,6 +116,20 @@ func createStepTransitionService(db *gorm.DB, logger *zap.SugaredLogger) *workfl
stepExecService.SetEvidenceCreator(evidenceIntegration)
workflowExecService.SetEvidenceCreator(evidenceIntegration)
+ // Use the shared executor from the worker service when available so that there is exactly
+ // one DAGExecutor instance (consistent logger, notifications, and evidence integration).
+ // Fall back to constructing a local executor when the worker is disabled (executor == nil).
+ if executor == nil {
+ executor = workflow.NewDAGExecutor(
+ stepExecService,
+ workflowExecService,
+ stepDefService,
+ assignmentService,
+ log.Default(),
+ notificationEnqueuer,
+ )
+ }
+
// Create and return step transition service
return workflow.NewStepTransitionService(
stepExecService,
diff --git a/internal/api/handler/digest.go b/internal/api/handler/digest.go
index bce83bda..954aeb9f 100644
--- a/internal/api/handler/digest.go
+++ b/internal/api/handler/digest.go
@@ -6,7 +6,6 @@ import (
"github.com/compliance-framework/api/internal/api"
"github.com/compliance-framework/api/internal/service/digest"
- "github.com/compliance-framework/api/internal/service/scheduler"
"github.com/labstack/echo/v4"
"go.uber.org/zap"
)
@@ -14,15 +13,13 @@ import (
// DigestHandler handles digest-related API endpoints
type DigestHandler struct {
digestService *digest.Service
- scheduler scheduler.Scheduler
logger *zap.SugaredLogger
}
// NewDigestHandler creates a new digest handler
-func NewDigestHandler(digestService *digest.Service, sched scheduler.Scheduler, logger *zap.SugaredLogger) *DigestHandler {
+func NewDigestHandler(digestService *digest.Service, logger *zap.SugaredLogger) *DigestHandler {
return &DigestHandler{
digestService: digestService,
- scheduler: sched,
logger: logger,
}
}
@@ -50,12 +47,11 @@ func (h *DigestHandler) TriggerDigest(ctx echo.Context) error {
if jobName == "" {
jobName = "global-evidence-digest"
}
-
- if h.scheduler == nil {
- return ctx.JSON(http.StatusInternalServerError, api.NewError(fmt.Errorf("scheduler is not available")))
+ if jobName != "global-evidence-digest" {
+ return ctx.JSON(http.StatusBadRequest, api.NewError(fmt.Errorf("unsupported digest job: %s", jobName)))
}
- if err := h.scheduler.RunNow(ctx.Request().Context(), jobName); err != nil {
+ if err := h.digestService.SendGlobalDigest(ctx.Request().Context()); err != nil {
h.logger.Errorw("Failed to trigger digest job", "job", jobName, "error", err)
return ctx.JSON(http.StatusInternalServerError, api.NewError(err))
}
diff --git a/internal/api/handler/digest_integration_test.go b/internal/api/handler/digest_integration_test.go
index 17f3adfc..41d414af 100644
--- a/internal/api/handler/digest_integration_test.go
+++ b/internal/api/handler/digest_integration_test.go
@@ -5,14 +5,12 @@ package handler
import (
"context"
"encoding/json"
- "fmt"
"net/http/httptest"
"testing"
"github.com/compliance-framework/api/internal/api"
"github.com/compliance-framework/api/internal/service/digest"
"github.com/compliance-framework/api/internal/service/email"
- "github.com/compliance-framework/api/internal/service/scheduler"
"github.com/compliance-framework/api/internal/tests"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
@@ -27,56 +25,9 @@ type DigestApiIntegrationSuite struct {
server *api.Server
logger *zap.SugaredLogger
digestHandler *DigestHandler
- mockScheduler *MockScheduler
emailService *email.Service
}
-// MockScheduler implements the scheduler.Service interface for testing
-type MockScheduler struct {
- jobs map[string]bool
-}
-
-func NewMockScheduler() *MockScheduler {
- return &MockScheduler{
- jobs: make(map[string]bool),
- }
-}
-
-func (m *MockScheduler) Start() {
- // Mock implementation
-}
-
-func (m *MockScheduler) Stop() context.Context {
- // Mock implementation
- return context.Background()
-}
-
-func (m *MockScheduler) Schedule(schedule scheduler.Schedule, job scheduler.Job) error {
- m.jobs[job.Name()] = true
- return nil
-}
-
-func (m *MockScheduler) ScheduleCron(cronExpr string, job scheduler.Job) error {
- m.jobs[job.Name()] = true
- return nil
-}
-
-func (m *MockScheduler) RunNow(ctx context.Context, name string) error {
- if _, exists := m.jobs[name]; !exists {
- return fmt.Errorf("job %q not found", name)
- }
- // Mock job execution
- return nil
-}
-
-func (m *MockScheduler) ListJobs() []string {
- var jobs []string
- for name := range m.jobs {
- jobs = append(jobs, name)
- }
- return jobs
-}
-
func (suite *DigestApiIntegrationSuite) SetupSuite() {
suite.IntegrationTestSuite.SetupSuite()
@@ -88,27 +39,22 @@ func (suite *DigestApiIntegrationSuite) SetupSuite() {
suite.Require().NoError(err, "Failed to create email service")
suite.emailService = emailService
- // Create mock scheduler
- suite.mockScheduler = NewMockScheduler()
-
// Create digest handler
digestService := digest.NewService(suite.DB, suite.emailService, nil, suite.Config, suite.logger)
- suite.digestHandler = NewDigestHandler(digestService, suite.mockScheduler, suite.logger)
+ suite.digestHandler = NewDigestHandler(digestService, suite.logger)
// Setup server
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
suite.server = api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
// Register handlers
- RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config, digestService, suite.mockScheduler, nil)
+ RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config, digestService, nil, nil, nil)
}
func (suite *DigestApiIntegrationSuite) SetupTest() {
err := suite.Migrator.Refresh()
suite.Require().NoError(err)
- // Pre-register the default job in the mock scheduler
- suite.mockScheduler.jobs["global-evidence-digest"] = true
}
func (suite *DigestApiIntegrationSuite) TestTriggerDigest() {
@@ -132,22 +78,16 @@ func (suite *DigestApiIntegrationSuite) TestTriggerDigest() {
})
suite.Run("TriggerDigestWithCustomJob", func() {
- // Pre-register the custom job
- suite.mockScheduler.jobs["custom-job"] = true
-
rec := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/admin/digest/trigger?job=custom-job", nil)
req.Header.Set("Authorization", "Bearer "+*token)
suite.server.E().ServeHTTP(rec, req)
- suite.Equal(200, rec.Code, "Expected OK response for TriggerDigest with custom job")
+ suite.Equal(400, rec.Code, "Expected Bad Request response for unsupported custom digest job")
- var response map[string]string
+ var response api.Error
err = json.Unmarshal(rec.Body.Bytes(), &response)
- suite.Require().NoError(err, "Failed to unmarshal TriggerDigest response")
-
- suite.Equal("Digest job triggered successfully", response["message"])
- suite.Equal("custom-job", response["job"])
+ suite.Require().NoError(err, "Failed to unmarshal TriggerDigest error response")
})
suite.Run("TriggerDigestUnauthorized", func() {
@@ -190,35 +130,3 @@ func (suite *DigestApiIntegrationSuite) TestPreviewDigest() {
suite.Equal(401, rec.Code, "Expected Unauthorized response for missing token")
})
}
-
-func (suite *DigestApiIntegrationSuite) TestTriggerDigestWithNilScheduler() {
- // Test with nil scheduler to verify error handling
- token, err := suite.GetAuthToken()
- suite.Require().NoError(err)
-
- // Create handler with nil scheduler
- digestService := digest.NewService(suite.DB, suite.emailService, nil, suite.Config, suite.logger)
- nilSchedulerHandler := NewDigestHandler(digestService, nil, suite.logger)
-
- // Create a temporary echo context for testing
- e := suite.server.E()
- req := httptest.NewRequest("POST", "/api/admin/digest/trigger", nil)
- req.Header.Set("Authorization", "Bearer "+*token)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
-
- err = nilSchedulerHandler.TriggerDigest(c)
- suite.NoError(err, "Expected no error from TriggerDigest with nil scheduler")
- suite.Equal(500, rec.Code, "Expected Internal Server Error when scheduler is nil")
-
- var response api.Error
- err = json.Unmarshal(rec.Body.Bytes(), &response)
- suite.Require().NoError(err, "Failed to unmarshal error response")
-
- // Check if the error contains our expected message
- for _, errMsg := range response.Errors {
- if msgStr, ok := errMsg.(string); ok {
- suite.Contains(msgStr, "scheduler is not available")
- }
- }
-}
diff --git a/internal/api/handler/evidence_integration_test.go b/internal/api/handler/evidence_integration_test.go
index 4ad7e59f..340f7cab 100644
--- a/internal/api/handler/evidence_integration_test.go
+++ b/internal/api/handler/evidence_integration_test.go
@@ -158,7 +158,7 @@ func (suite *EvidenceApiIntegrationSuite) TestCreate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(evidence)
req := httptest.NewRequest(http.MethodPost, "/api/evidence", bytes.NewReader(reqBody))
@@ -211,7 +211,7 @@ func (suite *EvidenceApiIntegrationSuite) TestSearch() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(struct {
Filter labelfilter.Filter
@@ -264,7 +264,7 @@ func (suite *EvidenceApiIntegrationSuite) TestSearch() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(struct {
Filter labelfilter.Filter
@@ -317,7 +317,7 @@ func (suite *EvidenceApiIntegrationSuite) TestSearch() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
var reqBody, _ = json.Marshal(struct {
Filter labelfilter.Filter
@@ -381,7 +381,7 @@ func (suite *EvidenceApiIntegrationSuite) TestSearch() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
var reqBody, _ = json.Marshal(struct {
Filter labelfilter.Filter
@@ -453,7 +453,7 @@ func (suite *EvidenceApiIntegrationSuite) TestSearch() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
var reqBody, _ = json.Marshal(struct {
Filter labelfilter.Filter
@@ -552,7 +552,7 @@ func (suite *EvidenceApiIntegrationSuite) TestStatusOverTime() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(struct {
Filter labelfilter.Filter
@@ -673,7 +673,7 @@ func (suite *EvidenceApiIntegrationSuite) TestComplianceByFilter() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/evidence/compliance-by-filter/%s", filter.ID), nil)
server.E().ServeHTTP(rec, req)
@@ -704,7 +704,7 @@ func (suite *EvidenceApiIntegrationSuite) TestComplianceByFilter() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/evidence/compliance-by-filter/%s", uuid.New()), nil)
server.E().ServeHTTP(rec, req)
@@ -718,7 +718,7 @@ func (suite *EvidenceApiIntegrationSuite) TestComplianceByFilter() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/evidence/compliance-by-filter/invalid-uuid", nil)
server.E().ServeHTTP(rec, req)
diff --git a/internal/api/handler/filter_integration_test.go b/internal/api/handler/filter_integration_test.go
index 50c53544..e2b036ae 100644
--- a/internal/api/handler/filter_integration_test.go
+++ b/internal/api/handler/filter_integration_test.go
@@ -52,7 +52,7 @@ func (suite *FilterApiIntegrationSuite) TestCreate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(createReq)
req := httptest.NewRequest(http.MethodPost, "/api/filters", bytes.NewReader(reqBody))
@@ -96,7 +96,7 @@ func (suite *FilterApiIntegrationSuite) TestCreate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(createReq)
req := httptest.NewRequest(http.MethodPost, "/api/filters", bytes.NewReader(reqBody))
@@ -137,7 +137,7 @@ func (suite *FilterApiIntegrationSuite) TestCreate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(createReq)
req := httptest.NewRequest(http.MethodPost, "/api/filters", bytes.NewReader(reqBody))
@@ -166,7 +166,7 @@ func (suite *FilterApiIntegrationSuite) TestList() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
// Create filter linked to AC-1
withControlReq := createFilterRequest{
@@ -244,7 +244,7 @@ func (suite *FilterApiIntegrationSuite) TestList() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
// Create filter linked to our system component
withComponentReq := createFilterRequest{
@@ -341,7 +341,7 @@ func (suite *FilterApiIntegrationSuite) TestUpdate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(updateReq)
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/filters/%s", filter.ID), bytes.NewReader(reqBody))
@@ -399,7 +399,7 @@ func (suite *FilterApiIntegrationSuite) TestUpdate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(updateReq)
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/filters/%s", filter.ID), bytes.NewReader(reqBody))
@@ -476,7 +476,7 @@ func (suite *FilterApiIntegrationSuite) TestUpdate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(updateReq)
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/filters/%s", filter.ID), bytes.NewReader(reqBody))
@@ -551,7 +551,7 @@ func (suite *FilterApiIntegrationSuite) TestUpdate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(updateReq)
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/filters/%s", filter.ID), bytes.NewReader(reqBody))
diff --git a/internal/api/handler/heartbeat_integration_test.go b/internal/api/handler/heartbeat_integration_test.go
index 518506b6..303a181d 100644
--- a/internal/api/handler/heartbeat_integration_test.go
+++ b/internal/api/handler/heartbeat_integration_test.go
@@ -40,7 +40,7 @@ func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatCreateValidation() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(heartbeat)
req := httptest.NewRequest(http.MethodPost, "/api/agent/heartbeat", bytes.NewReader(reqBody))
@@ -62,7 +62,7 @@ func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatCreate() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
reqBody, _ := json.Marshal(heartbeat)
req := httptest.NewRequest(http.MethodPost, "/api/agent/heartbeat", bytes.NewReader(reqBody))
@@ -95,7 +95,7 @@ func (suite *HeartbeatApiIntegrationSuite) TestHeartbeatOverTime() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/agent/heartbeat/over-time/", nil)
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
diff --git a/internal/api/handler/oscal/assessmentplan_integration_test.go b/internal/api/handler/oscal/assessmentplan_integration_test.go
index 0ae02c18..16aed51e 100644
--- a/internal/api/handler/oscal/assessmentplan_integration_test.go
+++ b/internal/api/handler/oscal/assessmentplan_integration_test.go
@@ -41,7 +41,7 @@ func (suite *AssessmentPlanApiIntegrationSuite) SetupSuite() {
suite.logger = logger.Sugar()
metrics := api.NewMetricsHandler(context.Background(), suite.logger)
suite.server = api.NewServer(context.Background(), suite.logger, suite.Config, metrics)
- handler.RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config, nil, nil, nil)
+ handler.RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config, nil, nil, nil, nil)
RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config)
}
diff --git a/internal/api/handler/users_integration_test.go b/internal/api/handler/users_integration_test.go
index 1be7874d..ac85e6ab 100644
--- a/internal/api/handler/users_integration_test.go
+++ b/internal/api/handler/users_integration_test.go
@@ -33,7 +33,7 @@ func (suite *UserApiIntegrationSuite) SetupSuite() {
suite.logger = logger.Sugar()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
suite.server = api.NewServer(context.Background(), logger.Sugar(), suite.Config, metrics)
- RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config, nil, nil, nil)
+ RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config, nil, nil, nil, nil)
}
func (suite *UserApiIntegrationSuite) SetupTest() {
diff --git a/internal/api/handler/workflows/base.go b/internal/api/handler/workflows/base.go
index 95cb684b..cce14dad 100644
--- a/internal/api/handler/workflows/base.go
+++ b/internal/api/handler/workflows/base.go
@@ -89,25 +89,6 @@ func (b *BaseHandler) ParseUUID(ctx echo.Context, paramName, entityName string)
return &id, nil
}
-// ParseQueryUUID parses a UUID from a query parameter
-func (b *BaseHandler) ParseQueryUUID(ctx echo.Context, paramName, entityName string) (*uuid.UUID, error) {
- idStr := ctx.QueryParam(paramName)
- if idStr == "" {
- return nil, nil
- }
-
- id, err := uuid.Parse(idStr)
- if err != nil {
- b.sugar.Errorw("Invalid "+entityName+" ID", "error", err, "param", paramName, "value", idStr)
- err = ctx.JSON(http.StatusBadRequest, api.NewError(err))
- if err != nil {
- return nil, err
- }
- return nil, ErrResponseSent
- }
- return &id, nil
-}
-
// HandleServiceError handles service layer errors with appropriate HTTP status codes
func (b *BaseHandler) HandleServiceError(ctx echo.Context, err error, operation, entityName string) error {
if err == gorm.ErrRecordNotFound || isNotFoundError(err) {
diff --git a/internal/api/handler/workflows/step_execution_integration_test.go b/internal/api/handler/workflows/step_execution_integration_test.go
index 82290256..101823b4 100644
--- a/internal/api/handler/workflows/step_execution_integration_test.go
+++ b/internal/api/handler/workflows/step_execution_integration_test.go
@@ -39,11 +39,8 @@ func setupStepExecutionTestHandler(t *testing.T) (*StepExecutionHandler, *gorm.D
workflowDefinitionService := workflows.NewWorkflowDefinitionService(db)
roleAssignmentService := workflows.NewRoleAssignmentService(db)
- // Set the workflow execution service on evidence integration to use the same instance
- evidenceIntegration.SetWorkflowExecutionService(workflowExecService)
-
// Create assignment service
- assignmentService := workflow.NewAssignmentService(roleAssignmentService, db)
+ assignmentService := workflow.NewAssignmentService(roleAssignmentService, stepExecService, db, zap.NewNop().Sugar(), nil)
// Create executor for step transition coordination
stdLogger := log.Default()
@@ -53,6 +50,7 @@ func setupStepExecutionTestHandler(t *testing.T) (*StepExecutionHandler, *gorm.D
stepDefService,
assignmentService,
stdLogger,
+ nil,
)
// Create step transition service
diff --git a/internal/api/handler/workflows/workflow_execution.go b/internal/api/handler/workflows/workflow_execution.go
index 2a87b050..625aa382 100644
--- a/internal/api/handler/workflows/workflow_execution.go
+++ b/internal/api/handler/workflows/workflow_execution.go
@@ -119,7 +119,7 @@ func (h *WorkflowExecutionHandler) Start(ctx echo.Context) error {
TriggeredByID: req.TriggeredByID,
}
- executionID, err := h.manager.StartWorkflowExecution(
+ execution, err := h.manager.StartWorkflowExecution(
ctx.Request().Context(),
req.WorkflowInstanceID,
opts,
@@ -128,13 +128,7 @@ func (h *WorkflowExecutionHandler) Start(ctx echo.Context) error {
return h.HandleServiceError(ctx, err, "start", "workflow execution")
}
- // Get the created execution
- execution, err := h.service.GetByID(executionID)
- if err != nil {
- return h.HandleServiceError(ctx, err, "get", "workflow execution")
- }
-
- h.sugar.Infow("Workflow execution started", "id", executionID)
+ h.sugar.Infow("Workflow execution started", "id", execution.ID)
return h.RespondCreated(ctx, WorkflowExecutionResponse{Data: execution})
}
@@ -307,14 +301,9 @@ func (h *WorkflowExecutionHandler) Cancel(ctx echo.Context) error {
}
// Use the manager to cancel the execution
- if err := h.manager.CancelExecution(ctx.Request().Context(), id, reason); err != nil {
- return h.HandleServiceError(ctx, err, "cancel", "workflow execution")
- }
-
- // Get the updated execution
- execution, err := h.service.GetByID(id)
+ execution, err := h.manager.CancelExecution(ctx.Request().Context(), id, reason)
if err != nil {
- return h.HandleServiceError(ctx, err, "get", "workflow execution after cancellation")
+ return h.HandleServiceError(ctx, err, "cancel", "workflow execution")
}
h.sugar.Infow("Workflow execution cancelled", "id", id)
@@ -342,18 +331,12 @@ func (h *WorkflowExecutionHandler) Retry(ctx echo.Context) error {
}
// Use the manager to retry the execution
- newExecutionID, err := h.manager.RetryExecution(ctx.Request().Context(), id)
+ execution, err := h.manager.RetryExecution(ctx.Request().Context(), id)
if err != nil {
return h.HandleServiceError(ctx, err, "retry", "workflow execution")
}
- // Get the new execution
- execution, err := h.service.GetByID(newExecutionID)
- if err != nil {
- return h.HandleServiceError(ctx, err, "get", "new workflow execution")
- }
-
- h.sugar.Infow("Workflow execution retried", "original_id", id, "new_id", newExecutionID)
+ h.sugar.Infow("Workflow execution retried", "original_id", id, "new_id", execution.ID)
return h.RespondCreated(ctx, WorkflowExecutionResponse{Data: execution})
}
diff --git a/internal/api/handler/workflows/workflow_execution_integration_test.go b/internal/api/handler/workflows/workflow_execution_integration_test.go
index 7231cd97..0b06ca96 100644
--- a/internal/api/handler/workflows/workflow_execution_integration_test.go
+++ b/internal/api/handler/workflows/workflow_execution_integration_test.go
@@ -48,7 +48,7 @@ func setupExecutionTestHandler(t *testing.T) (*WorkflowExecutionHandler, *gorm.D
workflowExecService := workflows.NewWorkflowExecutionService(db)
workflowInstService := workflows.NewWorkflowInstanceService(db)
roleAssignmentService := workflows.NewRoleAssignmentService(db)
- assignmentService := workflow.NewAssignmentService(roleAssignmentService, db)
+ assignmentService := workflow.NewAssignmentService(roleAssignmentService, stepExecService, db, zap.NewNop().Sugar(), nil)
// Create a mock river client for testing
mockRiver := &MockRiverClient{}
@@ -61,6 +61,7 @@ func setupExecutionTestHandler(t *testing.T) (*WorkflowExecutionHandler, *gorm.D
workflowInstService,
stepExecService,
logger,
+ nil,
)
handler := NewWorkflowExecutionHandler(logger, db, manager, assignmentService)
diff --git a/internal/api/handler/workflows/workflow_instance.go b/internal/api/handler/workflows/workflow_instance.go
index aba99c47..0e504b13 100644
--- a/internal/api/handler/workflows/workflow_instance.go
+++ b/internal/api/handler/workflows/workflow_instance.go
@@ -10,12 +10,14 @@ import (
type WorkflowInstanceHandler struct {
*BaseHandler
+ db *gorm.DB
service *workflows.WorkflowInstanceService
}
func NewWorkflowInstanceHandler(sugar *zap.SugaredLogger, db *gorm.DB) *WorkflowInstanceHandler {
return &WorkflowInstanceHandler{
BaseHandler: NewBaseHandler(sugar),
+ db: db,
service: workflows.NewWorkflowInstanceService(db),
}
}
@@ -80,6 +82,12 @@ func (h *WorkflowInstanceHandler) Create(ctx echo.Context) error {
h.sugar.Errorw("Failed to parse system security plan ID", "error", err)
return h.HandleServiceError(ctx, err, "parse", "system security plan ID")
}
+
+ actorID, _, err := h.GetActorFromClaims(ctx, h.db)
+ if err != nil {
+ return HandleError(err)
+ }
+
instance := &workflows.WorkflowInstance{
WorkflowDefinitionID: req.WorkflowDefinitionID,
Name: req.Name,
@@ -88,6 +96,7 @@ func (h *WorkflowInstanceHandler) Create(ctx echo.Context) error {
Cadence: req.Cadence,
IsActive: true, // Default to active
GracePeriodDays: req.GracePeriodDays,
+ CreatedByID: actorID,
}
if req.IsActive != nil {
@@ -217,6 +226,11 @@ func (h *WorkflowInstanceHandler) Update(ctx echo.Context) error {
return HandleError(err)
}
+ actorID, _, err := h.GetActorFromClaims(ctx, h.db)
+ if err != nil {
+ return HandleError(err)
+ }
+
instance, err := h.service.GetByID(id)
if err != nil {
return h.HandleServiceError(ctx, err, "get", "workflow instance")
@@ -237,6 +251,7 @@ func (h *WorkflowInstanceHandler) Update(ctx echo.Context) error {
if req.GracePeriodDays != nil {
instance.GracePeriodDays = req.GracePeriodDays
}
+ instance.UpdatedByID = actorID
if err := h.service.Update(id, instance); err != nil {
return h.HandleServiceError(ctx, err, "update", "workflow instance")
diff --git a/internal/api/handler/workflows/workflow_instance_integration_test.go b/internal/api/handler/workflows/workflow_instance_integration_test.go
index 59b83a6e..8493b6ce 100644
--- a/internal/api/handler/workflows/workflow_instance_integration_test.go
+++ b/internal/api/handler/workflows/workflow_instance_integration_test.go
@@ -10,6 +10,8 @@ import (
"testing"
"github.com/compliance-framework/api/internal/api/middleware"
+ "github.com/compliance-framework/api/internal/authn"
+ "github.com/compliance-framework/api/internal/service/relational"
"github.com/compliance-framework/api/internal/service/relational/workflows"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
@@ -26,11 +28,25 @@ func setupInstanceTestHandler(t *testing.T) (*WorkflowInstanceHandler, *gorm.DB)
return handler, db
}
+func createTestUser(t *testing.T, db *gorm.DB) *relational.User {
+ user := &relational.User{Email: "test@example.com", FirstName: "Test", LastName: "User"}
+ require.NoError(t, db.Create(user).Error)
+ return user
+}
+
+func setAuthClaims(c echo.Context, user *relational.User) {
+ claims := &authn.UserClaims{GivenName: user.FirstName, FamilyName: user.LastName}
+ claims.Subject = user.Email
+ c.Set("user", claims)
+}
+
func TestWorkflowInstanceHandler_Create(t *testing.T) {
handler, db := setupInstanceTestHandler(t)
e := echo.New()
e.Validator = middleware.NewValidator()
+ testUser := createTestUser(t, db)
+
workflowDef := &workflows.WorkflowDefinition{
Name: "Test Workflow",
Version: "1.0",
@@ -57,6 +73,7 @@ func TestWorkflowInstanceHandler_Create(t *testing.T) {
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
+ setAuthClaims(c, testUser)
err = handler.Create(c)
require.NoError(t, err)
@@ -110,6 +127,7 @@ func TestWorkflowInstanceHandler_Create(t *testing.T) {
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
+ setAuthClaims(c, testUser)
err = handler.Create(c)
require.NoError(t, err)
@@ -278,6 +296,8 @@ func TestWorkflowInstanceHandler_Update(t *testing.T) {
handler, db := setupInstanceTestHandler(t)
e := echo.New()
+ testUser := createTestUser(t, db)
+
workflowDef := &workflows.WorkflowDefinition{
Name: "Test Workflow",
Version: "1.0",
@@ -315,6 +335,7 @@ func TestWorkflowInstanceHandler_Update(t *testing.T) {
c := e.NewContext(req, rec)
c.SetParamNames("id")
c.SetParamValues(instance.ID.String())
+ setAuthClaims(c, testUser)
err = handler.Update(c)
require.NoError(t, err)
@@ -348,6 +369,7 @@ func TestWorkflowInstanceHandler_Update(t *testing.T) {
c := e.NewContext(req, rec)
c.SetParamNames("id")
c.SetParamValues(nonExistentID.String())
+ setAuthClaims(c, testUser)
err = handler.Update(c)
require.NoError(t, err)
diff --git a/internal/config/workflow.go b/internal/config/workflow.go
index 54c33e10..4c2b4953 100644
--- a/internal/config/workflow.go
+++ b/internal/config/workflow.go
@@ -22,6 +22,18 @@ type WorkflowConfig struct {
// OverdueCheckEnabled determines if we should check for overdue workflows
OverdueCheckEnabled bool `mapstructure:"overdue_check_enabled" yaml:"overdue_check_enabled" json:"overdueCheckEnabled"`
+
+ // DueSoonEnabled determines if the daily due-soon reminder emails are enabled
+ DueSoonEnabled bool `mapstructure:"due_soon_enabled" yaml:"due_soon_enabled" json:"dueSoonEnabled"`
+
+ // DueSoonSchedule is the cron schedule for the due-soon checker (default: daily at 08:00 UTC)
+ DueSoonSchedule string `mapstructure:"due_soon_schedule" yaml:"due_soon_schedule" json:"dueSoonSchedule"`
+
+ // TaskDigestEnabled determines if the daily workflow task digest emails are enabled
+ TaskDigestEnabled bool `mapstructure:"task_digest_enabled" yaml:"task_digest_enabled" json:"taskDigestEnabled"`
+
+ // TaskDigestSchedule is the cron schedule for the workflow task digest (default: daily at 08:00 UTC)
+ TaskDigestSchedule string `mapstructure:"task_digest_schedule" yaml:"task_digest_schedule" json:"taskDigestSchedule"`
}
// DefaultWorkflowConfig returns a default workflow configuration
@@ -31,6 +43,10 @@ func DefaultWorkflowConfig() *WorkflowConfig {
Schedule: "@every 15m",
GracePeriodDays: 7,
OverdueCheckEnabled: true,
+ DueSoonEnabled: false,
+ DueSoonSchedule: "0 0 8 * * *",
+ TaskDigestEnabled: false,
+ TaskDigestSchedule: "0 0 8 * * *",
}
}
@@ -43,6 +59,10 @@ func LoadWorkflowConfig(path string) (*WorkflowConfig, error) {
v.SetDefault("scheduler_schedule", "@every 15m")
v.SetDefault("grace_period_days", 7)
v.SetDefault("overdue_check_enabled", true)
+ v.SetDefault("due_soon_enabled", false)
+ v.SetDefault("due_soon_schedule", "0 0 8 * * *")
+ v.SetDefault("task_digest_enabled", false)
+ v.SetDefault("task_digest_schedule", "0 0 8 * * *")
// Configure environment variable loading
v.SetEnvPrefix("CCF_WORKFLOW")
@@ -76,12 +96,22 @@ func LoadWorkflowConfig(path string) (*WorkflowConfig, error) {
// Validate checks if the configuration is valid
func (c *WorkflowConfig) Validate() error {
+ parser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
if c.SchedulerEnabled {
- parser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
if _, err := parser.Parse(c.Schedule); err != nil {
return fmt.Errorf("invalid workflow scheduler schedule: %w", err)
}
}
+ if c.DueSoonEnabled {
+ if _, err := parser.Parse(c.DueSoonSchedule); err != nil {
+ return fmt.Errorf("invalid due_soon_schedule: %w", err)
+ }
+ }
+ if c.TaskDigestEnabled {
+ if _, err := parser.Parse(c.TaskDigestSchedule); err != nil {
+ return fmt.Errorf("invalid task_digest_schedule: %w", err)
+ }
+ }
if c.GracePeriodDays < 0 {
return fmt.Errorf("workflow grace period days must be non-negative")
}
diff --git a/internal/config/workflow_test.go b/internal/config/workflow_test.go
index 819d4dd5..6b06d101 100644
--- a/internal/config/workflow_test.go
+++ b/internal/config/workflow_test.go
@@ -15,6 +15,10 @@ func TestDefaultWorkflowConfig(t *testing.T) {
assert.Equal(t, "@every 15m", config.Schedule)
assert.Equal(t, 7, config.GracePeriodDays)
assert.True(t, config.OverdueCheckEnabled)
+ assert.False(t, config.DueSoonEnabled)
+ assert.Equal(t, "0 0 8 * * *", config.DueSoonSchedule)
+ assert.False(t, config.TaskDigestEnabled)
+ assert.Equal(t, "0 0 8 * * *", config.TaskDigestSchedule)
}
func TestLoadWorkflowConfig_Defaults(t *testing.T) {
@@ -23,6 +27,10 @@ func TestLoadWorkflowConfig_Defaults(t *testing.T) {
require.NoError(t, os.Unsetenv("CCF_WORKFLOW_SCHEDULER_SCHEDULE"))
require.NoError(t, os.Unsetenv("CCF_WORKFLOW_GRACE_PERIOD_DAYS"))
require.NoError(t, os.Unsetenv("CCF_WORKFLOW_OVERDUE_CHECK_ENABLED"))
+ require.NoError(t, os.Unsetenv("CCF_WORKFLOW_DUE_SOON_ENABLED"))
+ require.NoError(t, os.Unsetenv("CCF_WORKFLOW_DUE_SOON_SCHEDULE"))
+ require.NoError(t, os.Unsetenv("CCF_WORKFLOW_TASK_DIGEST_ENABLED"))
+ require.NoError(t, os.Unsetenv("CCF_WORKFLOW_TASK_DIGEST_SCHEDULE"))
config, err := LoadWorkflowConfig("")
require.NoError(t, err)
@@ -31,6 +39,10 @@ func TestLoadWorkflowConfig_Defaults(t *testing.T) {
assert.Equal(t, "@every 15m", config.Schedule)
assert.Equal(t, 7, config.GracePeriodDays)
assert.True(t, config.OverdueCheckEnabled)
+ assert.False(t, config.DueSoonEnabled)
+ assert.Equal(t, "0 0 8 * * *", config.DueSoonSchedule)
+ assert.False(t, config.TaskDigestEnabled)
+ assert.Equal(t, "0 0 8 * * *", config.TaskDigestSchedule)
}
func TestLoadWorkflowConfig_EnvVars(t *testing.T) {
@@ -54,6 +66,27 @@ func TestLoadWorkflowConfig_EnvVars(t *testing.T) {
assert.False(t, config.OverdueCheckEnabled)
}
+func TestLoadWorkflowConfig_NotificationEnvVars(t *testing.T) {
+ require.NoError(t, os.Setenv("CCF_WORKFLOW_DUE_SOON_ENABLED", "true"))
+ require.NoError(t, os.Setenv("CCF_WORKFLOW_DUE_SOON_SCHEDULE", "0 0 9 * * *"))
+ require.NoError(t, os.Setenv("CCF_WORKFLOW_TASK_DIGEST_ENABLED", "true"))
+ require.NoError(t, os.Setenv("CCF_WORKFLOW_TASK_DIGEST_SCHEDULE", "0 0 7 * * *"))
+ defer func() {
+ _ = os.Unsetenv("CCF_WORKFLOW_DUE_SOON_ENABLED")
+ _ = os.Unsetenv("CCF_WORKFLOW_DUE_SOON_SCHEDULE")
+ _ = os.Unsetenv("CCF_WORKFLOW_TASK_DIGEST_ENABLED")
+ _ = os.Unsetenv("CCF_WORKFLOW_TASK_DIGEST_SCHEDULE")
+ }()
+
+ config, err := LoadWorkflowConfig("")
+ require.NoError(t, err)
+
+ assert.True(t, config.DueSoonEnabled)
+ assert.Equal(t, "0 0 9 * * *", config.DueSoonSchedule)
+ assert.True(t, config.TaskDigestEnabled)
+ assert.Equal(t, "0 0 7 * * *", config.TaskDigestSchedule)
+}
+
func TestLoadWorkflowConfig_File(t *testing.T) {
content := `
scheduler_enabled: true
diff --git a/internal/service/digest/job.go b/internal/service/digest/job.go
deleted file mode 100644
index 67db3883..00000000
--- a/internal/service/digest/job.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package digest
-
-import (
- "context"
-
- "go.uber.org/zap"
-)
-
-// GlobalDigestJob is a scheduled job that sends global evidence digests
-type GlobalDigestJob struct {
- service *Service
- logger *zap.SugaredLogger
-}
-
-// NewGlobalDigestJob creates a new global digest job
-func NewGlobalDigestJob(service *Service, logger *zap.SugaredLogger) *GlobalDigestJob {
- return &GlobalDigestJob{
- service: service,
- logger: logger,
- }
-}
-
-// Name returns the unique name of the job
-func (j *GlobalDigestJob) Name() string {
- return "global-evidence-digest"
-}
-
-// Execute runs the digest job
-func (j *GlobalDigestJob) Execute(ctx context.Context) error {
- j.logger.Debug("Executing global evidence digest job")
- return j.service.SendGlobalDigest(ctx)
-}
diff --git a/internal/service/digest/service_test.go b/internal/service/digest/service_test.go
index 604b2af4..cc89cc36 100644
--- a/internal/service/digest/service_test.go
+++ b/internal/service/digest/service_test.go
@@ -53,8 +53,3 @@ func TestEvidenceSummaryStructure(t *testing.T) {
assert.Len(t, summary.TopExpired, 2)
assert.Len(t, summary.TopNotSatisfied, 1)
}
-
-func TestGlobalDigestJobName(t *testing.T) {
- job := &GlobalDigestJob{}
- assert.Equal(t, "global-evidence-digest", job.Name())
-}
diff --git a/internal/service/email/templates/service_test.go b/internal/service/email/templates/service_test.go
index ddc43f93..baf9f814 100644
--- a/internal/service/email/templates/service_test.go
+++ b/internal/service/email/templates/service_test.go
@@ -56,6 +56,200 @@ func TestTemplateService_MissingTemplates(t *testing.T) {
require.Error(t, err, "UseText should error for missing template")
}
+func TestTemplateService_WorkflowTaskAssigned(t *testing.T) {
+ service, err := NewTemplateService()
+ require.NoError(t, err)
+
+ dueDate := "2026-03-01 09:00:00 +0000 UTC"
+ data := TemplateData{
+ "UserName": "Alice Smith",
+ "StepTitle": "Review Policy",
+ "WorkflowTitle": "Annual Audit",
+ "WorkflowInstanceTitle": "Audit 2026",
+ "StepURL": "https://app.example.com/steps/abc",
+ "DueDate": dueDate,
+ }
+
+ html, text, err := service.Use("workflow-task-assigned", data)
+ require.NoError(t, err)
+ require.NotEmpty(t, html)
+ require.NotEmpty(t, text)
+ require.Contains(t, html, "Alice Smith")
+ require.Contains(t, html, "Review Policy")
+ require.Contains(t, html, "Annual Audit")
+ require.Contains(t, html, "https://app.example.com/steps/abc")
+ require.Contains(t, text, "Alice Smith")
+ require.Contains(t, text, "Review Policy")
+ require.Contains(t, text, "https://app.example.com/steps/abc")
+}
+
+func TestTemplateService_WorkflowTaskAssigned_NoDueDate(t *testing.T) {
+ service, err := NewTemplateService()
+ require.NoError(t, err)
+
+ data := TemplateData{
+ "UserName": "Bob",
+ "StepTitle": "Submit Evidence",
+ "WorkflowTitle": "SOC2 Audit",
+ "WorkflowInstanceTitle": "SOC2 2026",
+ "StepURL": "https://app.example.com/steps/xyz",
+ "DueDate": nil,
+ }
+
+ html, text, err := service.Use("workflow-task-assigned", data)
+ require.NoError(t, err)
+ require.NotEmpty(t, html)
+ require.NotEmpty(t, text)
+ require.Contains(t, html, "Bob")
+ require.NotContains(t, html, "Due Date")
+}
+
+func TestTemplateService_WorkflowTaskDueSoon(t *testing.T) {
+ service, err := NewTemplateService()
+ require.NoError(t, err)
+
+ data := TemplateData{
+ "UserName": "Alice Smith",
+ "StepTitle": "Submit Evidence",
+ "WorkflowTitle": "SOC2 Audit",
+ "WorkflowInstanceTitle": "SOC2 2026",
+ "StepURL": "https://app.example.com/steps/abc",
+ "DueDate": "2026-03-01",
+ }
+
+ html, text, err := service.Use("workflow-task-due-soon", data)
+ require.NoError(t, err)
+ require.NotEmpty(t, html)
+ require.NotEmpty(t, text)
+ require.Contains(t, html, "Alice Smith")
+ require.Contains(t, html, "Submit Evidence")
+ require.Contains(t, html, "SOC2 Audit")
+ require.Contains(t, html, "https://app.example.com/steps/abc")
+ require.Contains(t, html, "2026-03-01")
+ require.Contains(t, text, "Alice Smith")
+ require.Contains(t, text, "Submit Evidence")
+ require.Contains(t, text, "https://app.example.com/steps/abc")
+ require.Contains(t, text, "TOMORROW")
+}
+
+func TestTemplateService_WorkflowExecutionFailed_WithData(t *testing.T) {
+ service, err := NewTemplateService()
+ require.NoError(t, err)
+
+ data := TemplateData{
+ "RecipientName": "Alice Smith",
+ "WorkflowTitle": "SOC2 Audit",
+ "WorkflowInstanceName": "SOC2 2026",
+ "ExecutionID": "exec-abc-123",
+ "FailureReason": "2 of 5 steps failed",
+ "FailedAt": "Wed, 19 Feb 2026 08:00:00 UTC",
+ "FailedSteps": 2,
+ "CompletedSteps": 3,
+ "TotalSteps": 5,
+ "WorkflowURL": "https://app.example.com/workflows/abc",
+ }
+
+ html, text, err := service.Use("workflow-execution-failed", data)
+ require.NoError(t, err)
+ require.NotEmpty(t, html)
+ require.NotEmpty(t, text)
+ require.Contains(t, html, "Alice Smith")
+ require.Contains(t, html, "SOC2 Audit")
+ require.Contains(t, html, "SOC2 2026")
+ require.Contains(t, html, "2 of 5 steps failed")
+ require.Contains(t, html, "exec-abc-123")
+ require.Contains(t, text, "Alice Smith")
+ require.Contains(t, text, "SOC2 Audit")
+ require.Contains(t, text, "2 of 5 steps failed")
+ require.Contains(t, text, "FAILED")
+}
+
+func TestTemplateService_WorkflowExecutionFailed_NoURL(t *testing.T) {
+ service, err := NewTemplateService()
+ require.NoError(t, err)
+
+ data := TemplateData{
+ "RecipientName": "Bob",
+ "WorkflowTitle": "Annual Audit",
+ "WorkflowInstanceName": "Audit 2026",
+ "ExecutionID": "exec-xyz-456",
+ "FailureReason": "1 of 3 steps failed",
+ "FailedAt": "Wed, 19 Feb 2026 09:00:00 UTC",
+ "FailedSteps": 1,
+ "CompletedSteps": 2,
+ "TotalSteps": 3,
+ "WorkflowURL": "",
+ }
+
+ html, text, err := service.Use("workflow-execution-failed", data)
+ require.NoError(t, err)
+ require.NotEmpty(t, html)
+ require.NotEmpty(t, text)
+ require.Contains(t, html, "Bob")
+ require.NotContains(t, html, "View Workflow Instance")
+}
+
+func TestTemplateService_WorkflowTaskDigest_WithTasks(t *testing.T) {
+ service, err := NewTemplateService()
+ require.NoError(t, err)
+
+ pendingDue := "2026-03-15"
+ overdueDue := "2026-02-01"
+ data := TemplateData{
+ "UserName": "Alice Smith",
+ "PeriodLabel": "Daily digest — Wednesday, 19 February 2026",
+ "PendingTasks": []map[string]interface{}{
+ {
+ "StepTitle": "Submit Evidence",
+ "WorkflowTitle": "SOC2 Audit",
+ "WorkflowInstanceTitle": "SOC2 2026",
+ "DueDate": &pendingDue,
+ "StepURL": "https://app.example.com/steps/abc",
+ },
+ },
+ "OverdueTasks": []map[string]interface{}{
+ {
+ "StepTitle": "Review Policy",
+ "WorkflowTitle": "Annual Audit",
+ "WorkflowInstanceTitle": "Audit 2026",
+ "DueDate": &overdueDue,
+ "StepURL": "https://app.example.com/steps/xyz",
+ },
+ },
+ }
+
+ html, text, err := service.Use("workflow-task-digest", data)
+ require.NoError(t, err)
+ require.NotEmpty(t, html)
+ require.NotEmpty(t, text)
+ require.Contains(t, html, "Alice Smith")
+ require.Contains(t, html, "Submit Evidence")
+ require.Contains(t, html, "Review Policy")
+ require.Contains(t, html, "SOC2 Audit")
+ require.Contains(t, text, "Alice Smith")
+ require.Contains(t, text, "Submit Evidence")
+ require.Contains(t, text, "PENDING")
+ require.Contains(t, text, "OVERDUE")
+}
+
+func TestTemplateService_WorkflowTaskDigest_EmptyTasks(t *testing.T) {
+ service, err := NewTemplateService()
+ require.NoError(t, err)
+
+ data := TemplateData{
+ "UserName": "Bob",
+ "PeriodLabel": "Daily digest — Wednesday, 19 February 2026",
+ "PendingTasks": []map[string]interface{}{},
+ "OverdueTasks": []map[string]interface{}{},
+ }
+
+ html, text, err := service.Use("workflow-task-digest", data)
+ require.NoError(t, err)
+ require.NotEmpty(t, html)
+ require.NotEmpty(t, text)
+ require.Contains(t, html, "Bob")
+}
+
func TestTemplateService_ListTemplates(t *testing.T) {
service, err := NewTemplateService()
require.NoError(t, err, "Failed to create template service")
diff --git a/internal/service/email/templates/templates/workflow-execution-failed.html b/internal/service/email/templates/templates/workflow-execution-failed.html
new file mode 100644
index 00000000..4dd5c06e
--- /dev/null
+++ b/internal/service/email/templates/templates/workflow-execution-failed.html
@@ -0,0 +1,196 @@
+
+
+
+
+
+ Workflow Execution Failed
+
+
+
+
+
+
+
+
Hi {{.RecipientName}},
+
+
+
{{.WorkflowInstanceName}} — execution failed
+
Workflow: {{.WorkflowTitle}}
+
+
+
Execution Summary
+
+
+
{{.FailedSteps}}
+
Failed Steps
+
+
+
{{.CompletedSteps}}
+
Completed Steps
+
+
+
{{.TotalSteps}}
+
Total Steps
+
+
+
+
Execution Details
+
+
+ Execution ID
+ {{.ExecutionID}}
+
+
+ Failed At
+ {{.FailedAt}}
+
+
+
+
+
Failure Reason
+
{{.FailureReason}}
+
+
+ {{if .WorkflowURL}}
View Workflow Instance{{end}}
+
+
+
+
+
+
+
diff --git a/internal/service/email/templates/templates/workflow-execution-failed.txt b/internal/service/email/templates/templates/workflow-execution-failed.txt
new file mode 100644
index 00000000..1f7cfa72
--- /dev/null
+++ b/internal/service/email/templates/templates/workflow-execution-failed.txt
@@ -0,0 +1,32 @@
+WORKFLOW EXECUTION FAILED
+=========================
+
+Hi {{.RecipientName}},
+
+A workflow execution has failed and may require your attention.
+
+WORKFLOW: {{.WorkflowTitle}}
+INSTANCE: {{.WorkflowInstanceName}}
+
+EXECUTION SUMMARY
+-----------------
+Failed Steps: {{.FailedSteps}}
+Completed Steps: {{.CompletedSteps}}
+Total Steps: {{.TotalSteps}}
+
+EXECUTION DETAILS
+-----------------
+Execution ID: {{.ExecutionID}}
+Failed At: {{.FailedAt}}
+
+FAILURE REASON
+--------------
+{{.FailureReason}}
+{{if .WorkflowURL}}
+VIEW WORKFLOW INSTANCE
+----------------------
+{{.WorkflowURL}}
+{{end}}
+---
+You are receiving this because you are the owner of this workflow instance.
+Compliance Framework
diff --git a/internal/service/email/templates/templates/workflow-task-assigned.html b/internal/service/email/templates/templates/workflow-task-assigned.html
new file mode 100644
index 00000000..d64bb1bc
--- /dev/null
+++ b/internal/service/email/templates/templates/workflow-task-assigned.html
@@ -0,0 +1,198 @@
+
+
+
+
+
+ Task Ready for You
+
+
+
+
+
+
+
+
Hello {{.UserName}},
+
+
+ A workflow task has been assigned to you and is ready for action.
+
+
+
+
Task
+
{{.StepTitle}}
+
+
Workflow
+
{{.WorkflowTitle}}
+
+
Instance
+
{{.WorkflowInstanceTitle}}
+
+ {{if .DueDate}}
+
Due Date
+
+ {{.DueDate}}
+
+ {{end}}
+
+
+
+
+
+
+
+
+
+
diff --git a/internal/service/email/templates/templates/workflow-task-assigned.txt b/internal/service/email/templates/templates/workflow-task-assigned.txt
new file mode 100644
index 00000000..8c0a918c
--- /dev/null
+++ b/internal/service/email/templates/templates/workflow-task-assigned.txt
@@ -0,0 +1,19 @@
+Task Ready for You
+==================
+
+Hello {{.UserName}},
+
+A workflow task has been assigned to you and is ready for action.
+
+TASK DETAILS
+------------
+Task: {{.StepTitle}}
+Workflow: {{.WorkflowTitle}}
+Instance: {{.WorkflowInstanceTitle}}
+{{if .DueDate}}Due Date: {{.DueDate}}
+{{end}}
+View your task: {{.StepURL}}
+
+---
+This is an automated notification from Compliance Framework.
+Please do not reply to this email.
diff --git a/internal/service/email/templates/templates/workflow-task-digest.html b/internal/service/email/templates/templates/workflow-task-digest.html
new file mode 100644
index 00000000..079574b4
--- /dev/null
+++ b/internal/service/email/templates/templates/workflow-task-digest.html
@@ -0,0 +1,133 @@
+
+
+
+
+
+ Your Workflow Task Summary
+
+
+
+
+
+
+
+
Hello {{.UserName}},
+
Here is a summary of your workflow tasks that need attention.
+
+ {{if .OverdueTasks}}
+
+
Overdue Tasks ({{len .OverdueTasks}})
+ {{range .OverdueTasks}}
+
+
{{.StepTitle}}
+
+ {{.WorkflowTitle}}
+ {{if .DueDate}}Due: {{.DueDate}}{{end}}
+
+ {{if .StepURL}}
View task →{{end}}
+
+ {{end}}
+
+ {{end}}
+
+ {{if .PendingTasks}}
+
+
Pending Tasks ({{len .PendingTasks}})
+ {{range .PendingTasks}}
+
+
{{.StepTitle}}
+
+ {{.WorkflowTitle}}
+ {{if .DueDate}}Due: {{.DueDate}}{{end}}
+
+ {{if .StepURL}}
View task →{{end}}
+
+ {{end}}
+
+ {{end}}
+
+ {{if and (eq (len .PendingTasks) 0) (eq (len .OverdueTasks) 0)}}
+
You have no pending or overdue tasks at this time.
+ {{end}}
+
+
+
+
+
+
diff --git a/internal/service/email/templates/templates/workflow-task-digest.txt b/internal/service/email/templates/templates/workflow-task-digest.txt
new file mode 100644
index 00000000..99e2fd9c
--- /dev/null
+++ b/internal/service/email/templates/templates/workflow-task-digest.txt
@@ -0,0 +1,33 @@
+Your Workflow Task Summary
+==========================
+
+Hello {{.UserName}},
+
+Here is a summary of your workflow tasks that need attention.
+Period: {{.PeriodLabel}}
+
+{{if .OverdueTasks}}
+OVERDUE TASKS ({{len .OverdueTasks}})
+--------------------------------------
+{{range .OverdueTasks}}
+- {{.StepTitle}}
+ Workflow: {{.WorkflowTitle}}{{if .WorkflowInstanceTitle}} / {{.WorkflowInstanceTitle}}{{end}}
+ {{if .DueDate}}Due: {{.DueDate}}{{end}}
+ {{if .StepURL}}Link: {{.StepURL}}{{end}}
+{{end}}
+{{end}}
+{{if .PendingTasks}}
+PENDING TASKS ({{len .PendingTasks}})
+--------------------------------------
+{{range .PendingTasks}}
+- {{.StepTitle}}
+ Workflow: {{.WorkflowTitle}}{{if .WorkflowInstanceTitle}} / {{.WorkflowInstanceTitle}}{{end}}
+ {{if .DueDate}}Due: {{.DueDate}}{{end}}
+ {{if .StepURL}}Link: {{.StepURL}}{{end}}
+{{end}}
+{{end}}
+---
+{{if .MyTasksURL}}View all your tasks: {{.MyTasksURL}}
+
+{{end}}This is an automated digest from Compliance Framework.
+To unsubscribe, update your notification preferences in your account settings.
diff --git a/internal/service/email/templates/templates/workflow-task-due-soon.html b/internal/service/email/templates/templates/workflow-task-due-soon.html
new file mode 100644
index 00000000..3de40d48
--- /dev/null
+++ b/internal/service/email/templates/templates/workflow-task-due-soon.html
@@ -0,0 +1,196 @@
+
+
+
+
+
+ Task Due Soon
+
+
+
+
+
+
+
+
Hello {{.UserName}},
+
+
+ This is a reminder that the following task is due soon. Please make sure to complete it on time.
+
+
+
+
Task
+
{{.StepTitle}}
+
+
Workflow
+
{{.WorkflowTitle}}
+
+
Instance
+
{{.WorkflowInstanceTitle}}
+
+
Due Date
+
+ {{.DueDate}}
+
+
+
+
+
+
+
+
+
+
+
diff --git a/internal/service/email/templates/templates/workflow-task-due-soon.txt b/internal/service/email/templates/templates/workflow-task-due-soon.txt
new file mode 100644
index 00000000..6b757a74
--- /dev/null
+++ b/internal/service/email/templates/templates/workflow-task-due-soon.txt
@@ -0,0 +1,19 @@
+Task Due Tomorrow
+=================
+
+Hello {{.UserName}},
+
+This is a reminder that the following task is due TOMORROW. Please make sure to complete it on time.
+
+TASK DETAILS
+------------
+Task: {{.StepTitle}}
+Workflow: {{.WorkflowTitle}}
+Instance: {{.WorkflowInstanceTitle}}
+Due Date: {{.DueDate}}
+
+Complete your task: {{.StepURL}}
+
+---
+This is an automated notification from Compliance Framework.
+Please do not reply to this email.
diff --git a/internal/service/relational/workflows/service_manager.go b/internal/service/relational/workflows/service_manager.go
deleted file mode 100644
index 2459a951..00000000
--- a/internal/service/relational/workflows/service_manager.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package workflows
-
-import (
- "gorm.io/gorm"
-)
-
-// ServiceManager provides a unified interface to all workflow services
-type ServiceManager struct {
- WorkflowDefinition *WorkflowDefinitionService
- WorkflowStep *WorkflowStepDefinitionService
- WorkflowInstance *WorkflowInstanceService
- WorkflowExecution *WorkflowExecutionService
- StepExecution *StepExecutionService
- RoleAssignment *RoleAssignmentService
- ControlRelationship *ControlRelationshipService
- db *gorm.DB
-}
-
-// NewServiceManager creates a new ServiceManager with all workflow services
-func NewServiceManager(db *gorm.DB) *ServiceManager {
- return &ServiceManager{
- WorkflowDefinition: NewWorkflowDefinitionService(db),
- WorkflowStep: NewWorkflowStepDefinitionService(db),
- WorkflowInstance: NewWorkflowInstanceService(db),
- WorkflowExecution: NewWorkflowExecutionService(db),
- StepExecution: NewStepExecutionService(db, nil),
- RoleAssignment: NewRoleAssignmentService(db),
- ControlRelationship: NewControlRelationshipService(db),
- db: db,
- }
-}
-
-// DB returns the underlying database connection
-func (sm *ServiceManager) DB() *gorm.DB {
- return sm.db
-}
-
-// Transaction executes a function within a database transaction
-func (sm *ServiceManager) Transaction(fn func(*ServiceManager) error) error {
- return sm.db.Transaction(func(tx *gorm.DB) error {
- txManager := NewServiceManager(tx)
- return fn(txManager)
- })
-}
diff --git a/internal/service/relational/workflows/step_execution_service.go b/internal/service/relational/workflows/step_execution_service.go
index 60f7bcb5..98be1e1f 100644
--- a/internal/service/relational/workflows/step_execution_service.go
+++ b/internal/service/relational/workflows/step_execution_service.go
@@ -297,6 +297,43 @@ func (s *StepExecutionService) AssignTo(id *uuid.UUID, assignedToType, assignedT
}).Error
}
+// ReassignWithTx updates step assignment fields using the provided transaction.
+// If tx is nil, it falls back to the service DB handle.
+func (s *StepExecutionService) ReassignWithTx(tx *gorm.DB, id *uuid.UUID, assignedToType, assignedToID string, assignedAt time.Time) error {
+ if tx == nil {
+ tx = s.db
+ }
+
+ return tx.Model(&StepExecution{}).
+ Where("id = ?", id).
+ Updates(map[string]interface{}{
+ "assigned_to_type": assignedToType,
+ "assigned_to_id": assignedToID,
+ "assigned_at": assignedAt,
+ }).Error
+}
+
+// BulkFailWithTx marks all non-terminal steps in an execution as failed using the provided transaction.
+// If tx is nil, it falls back to the service DB handle.
+func (s *StepExecutionService) BulkFailWithTx(tx *gorm.DB, executionID *uuid.UUID, reason string, failedAt time.Time) error {
+ if tx == nil {
+ tx = s.db
+ }
+
+ return tx.Model(&StepExecution{}).
+ Where("workflow_execution_id = ? AND status IN ?", executionID, []string{
+ StepStatusPending.String(),
+ StepStatusBlocked.String(),
+ StepStatusInProgress.String(),
+ StepStatusOverdue.String(),
+ }).
+ Updates(map[string]interface{}{
+ "status": StepStatusFailed.String(),
+ "failed_at": failedAt,
+ "failure_reason": reason,
+ }).Error
+}
+
// GetPendingSteps retrieves all pending step executions for a workflow execution
func (s *StepExecutionService) GetPendingSteps(executionID *uuid.UUID) ([]StepExecution, error) {
var stepExecutions []StepExecution
diff --git a/internal/service/relational/workflows/workflow_execution_service.go b/internal/service/relational/workflows/workflow_execution_service.go
index 0448a58c..660b4c84 100644
--- a/internal/service/relational/workflows/workflow_execution_service.go
+++ b/internal/service/relational/workflows/workflow_execution_service.go
@@ -252,6 +252,30 @@ func (s *WorkflowExecutionService) Fail(id *uuid.UUID, reason string) error {
})
}
+// FailIfNotTerminal marks the execution as failed only if it is not already in a terminal state
+// (failed, completed, or cancelled). Returns true if the row was updated, false if it was already terminal.
+// This is the idempotent version of Fail, safe to call from concurrent code paths.
+func (s *WorkflowExecutionService) FailIfNotTerminal(ctx context.Context, id *uuid.UUID, reason string) (bool, error) {
+ now := time.Now()
+ terminalStatuses := []string{
+ WorkflowStatusFailed.String(),
+ WorkflowStatusCompleted.String(),
+ WorkflowStatusCancelled.String(),
+ }
+ result := s.db.WithContext(ctx).
+ Model(&WorkflowExecution{}).
+ Where("id = ? AND status NOT IN ?", id, terminalStatuses).
+ Updates(map[string]interface{}{
+ "status": WorkflowStatusFailed.String(),
+ "failed_at": now,
+ "failure_reason": reason,
+ })
+ if result.Error != nil {
+ return false, result.Error
+ }
+ return result.RowsAffected > 0, nil
+}
+
// Cancel cancels a workflow execution
func (s *WorkflowExecutionService) Cancel(id *uuid.UUID) error {
return s.db.Model(&WorkflowExecution{}).
diff --git a/internal/service/scheduler/cron.go b/internal/service/scheduler/cron.go
index 2807f2be..3de16e28 100644
--- a/internal/service/scheduler/cron.go
+++ b/internal/service/scheduler/cron.go
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"sync"
- "time"
"github.com/robfig/cron/v3"
"go.uber.org/zap"
@@ -32,21 +31,6 @@ func NewCronScheduler(logger *zap.SugaredLogger) *CronScheduler {
}
}
-// ParseCronNext parses a cron expression and returns the next scheduled time
-// Supports 6-field cron syntax (with seconds) and descriptors like @weekly, @daily, etc.
-// Format: second minute hour day month weekday
-//
-// @weekly = Monday 00:00:00 UTC
-func ParseCronNext(cronExpr string, from time.Time) (time.Time, error) {
- // Use same parser as CronScheduler - supports seconds (6-field cron)
- parser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
- schedule, err := parser.Parse(cronExpr)
- if err != nil {
- return time.Time{}, fmt.Errorf("failed to parse cron expression %q: %w", cronExpr, err)
- }
- return schedule.Next(from), nil
-}
-
// Schedule adds a job to run on the given schedule
func (s *CronScheduler) Schedule(schedule Schedule, job Job) error {
var cronExpr string
diff --git a/internal/service/worker/due_soon_checker.go b/internal/service/worker/due_soon_checker.go
new file mode 100644
index 00000000..56edb947
--- /dev/null
+++ b/internal/service/worker/due_soon_checker.go
@@ -0,0 +1,111 @@
+package worker
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/compliance-framework/api/internal/service/relational/workflows"
+ "github.com/compliance-framework/api/internal/workflow"
+ "github.com/riverqueue/river"
+ "go.uber.org/zap"
+ "gorm.io/gorm"
+)
+
+// DueSoonCheckerArgs represents the arguments for the periodic due-soon checker job
+type DueSoonCheckerArgs struct{}
+
+// Kind returns the job kind for River
+func (DueSoonCheckerArgs) Kind() string { return "workflow_due_soon_checker" }
+
+// Timeout returns the timeout for the due-soon checker job
+func (DueSoonCheckerArgs) Timeout() time.Duration { return 5 * time.Minute }
+
+// DueSoonCheckerWorker scans for step executions due in a week and enqueues reminder emails
+type DueSoonCheckerWorker struct {
+ db *gorm.DB
+ client workflow.RiverClient
+ logger *zap.SugaredLogger
+}
+
+// NewDueSoonCheckerWorker creates a new DueSoonCheckerWorker
+func NewDueSoonCheckerWorker(db *gorm.DB, client workflow.RiverClient, logger *zap.SugaredLogger) *DueSoonCheckerWorker {
+ return &DueSoonCheckerWorker{
+ db: db,
+ client: client,
+ logger: logger,
+ }
+}
+
+// Work scans for step executions due in ~1 week and enqueues WorkflowTaskDueSoonArgs jobs
+func (w *DueSoonCheckerWorker) Work(ctx context.Context, job *river.Job[DueSoonCheckerArgs]) error {
+ if w.db == nil {
+ return fmt.Errorf("DueSoonCheckerWorker: db is nil")
+ }
+
+ now := time.Now()
+ windowStart := now
+ windowEnd := now.Add(7 * 24 * time.Hour)
+
+ var steps []workflows.StepExecution
+ if err := w.db.WithContext(ctx).
+ Preload("WorkflowExecution.WorkflowInstance.WorkflowDefinition").
+ Preload("WorkflowStepDefinition").
+ Where("status IN ? AND due_date IS NOT NULL AND due_date >= ? AND due_date <= ? AND assigned_to_type = ? AND assigned_to_id != ''",
+ []string{
+ workflows.StepStatusPending.String(),
+ workflows.StepStatusInProgress.String(),
+ },
+ windowStart,
+ windowEnd,
+ workflows.AssignmentTypeUser.String(),
+ ).
+ Find(&steps).Error; err != nil {
+ return fmt.Errorf("due-soon checker: failed to query steps: %w", err)
+ }
+
+ if len(steps) == 0 {
+ w.logger.Infow("DueSoonCheckerWorker: no steps due soon", "window_start", windowStart, "window_end", windowEnd)
+ return nil
+ }
+
+ params := make([]river.InsertManyParams, 0, len(steps))
+ for i := range steps {
+ step := &steps[i]
+ if step.DueDate == nil {
+ continue
+ }
+
+ titles := resolveStepTitles(step)
+
+ args := WorkflowTaskDueSoonArgs{
+ UserID: step.AssignedToID,
+ StepExecutionID: step.ID.String(),
+ StepTitle: titles.Step,
+ WorkflowTitle: titles.Workflow,
+ WorkflowInstanceTitle: titles.Instance,
+ StepURL: "",
+ DueDate: *step.DueDate,
+ }
+ insertOpts := JobInsertOptionsForWorkflowNotification()
+ insertOpts.UniqueOpts = river.UniqueOpts{
+ ByArgs: true,
+ ByPeriod: 24 * time.Hour,
+ }
+ params = append(params, river.InsertManyParams{
+ Args: args,
+ InsertOpts: insertOpts,
+ })
+ }
+
+ if len(params) == 0 {
+ return nil
+ }
+
+ if _, err := w.client.InsertMany(ctx, params); err != nil {
+ return fmt.Errorf("due-soon checker: failed to enqueue reminder jobs: %w", err)
+ }
+
+ w.logger.Infow("DueSoonCheckerWorker: enqueued due-soon reminders", "count", len(params))
+ return nil
+}
diff --git a/internal/service/worker/helpers.go b/internal/service/worker/helpers.go
new file mode 100644
index 00000000..d6e006bb
--- /dev/null
+++ b/internal/service/worker/helpers.go
@@ -0,0 +1,53 @@
+package worker
+
+import (
+ "strings"
+ "time"
+
+ "github.com/compliance-framework/api/internal/service/relational/workflows"
+)
+
+type stepTitles struct {
+ Step string
+ Workflow string
+ Instance string
+}
+
+func resolveStepTitles(step *workflows.StepExecution) stepTitles {
+ titles := stepTitles{}
+ if step == nil {
+ return titles
+ }
+
+ if step.WorkflowStepDefinition != nil {
+ titles.Step = step.WorkflowStepDefinition.Name
+ }
+ if step.WorkflowExecution != nil && step.WorkflowExecution.WorkflowInstance != nil {
+ if step.WorkflowExecution.WorkflowInstance.WorkflowDefinition != nil {
+ titles.Workflow = step.WorkflowExecution.WorkflowInstance.WorkflowDefinition.Name
+ }
+ titles.Instance = step.WorkflowExecution.WorkflowInstance.Name
+ }
+
+ return titles
+}
+
+func resolveTaskURL(stepURL, webBaseURL string) string {
+ if stepURL != "" {
+ return stepURL
+ }
+ return webBaseURL + "/my-tasks"
+}
+
+// formatDate formats a time value as "dd/mmm/yyyy" (e.g. "05/mar/2025").
+func formatDate(t time.Time) string {
+ return strings.ToLower(t.Format("02/Jan/2006"))
+}
+
+// formatDueDate formats an optional due date pointer; returns "" when nil.
+func formatDueDate(t *time.Time) string {
+ if t == nil {
+ return ""
+ }
+ return formatDate(*t)
+}
diff --git a/internal/service/worker/helpers_test.go b/internal/service/worker/helpers_test.go
new file mode 100644
index 00000000..4b2db436
--- /dev/null
+++ b/internal/service/worker/helpers_test.go
@@ -0,0 +1,83 @@
+package worker
+
+import (
+ "testing"
+
+ "github.com/compliance-framework/api/internal/service/relational/workflows"
+ "github.com/stretchr/testify/require"
+)
+
+func TestResolveStepTitles(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil step", func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, stepTitles{}, resolveStepTitles(nil))
+ })
+
+ t.Run("partial preload", func(t *testing.T) {
+ t.Parallel()
+ step := &workflows.StepExecution{
+ WorkflowStepDefinition: &workflows.WorkflowStepDefinition{Name: "Collect Evidence"},
+ }
+
+ require.Equal(t, stepTitles{
+ Step: "Collect Evidence",
+ }, resolveStepTitles(step))
+ })
+
+ t.Run("fully preloaded", func(t *testing.T) {
+ t.Parallel()
+ step := &workflows.StepExecution{
+ WorkflowStepDefinition: &workflows.WorkflowStepDefinition{Name: "Collect Evidence"},
+ WorkflowExecution: &workflows.WorkflowExecution{
+ WorkflowInstance: &workflows.WorkflowInstance{
+ Name: "Q1 2026 Access Review",
+ WorkflowDefinition: &workflows.WorkflowDefinition{Name: "Access Review"},
+ },
+ },
+ }
+
+ require.Equal(t, stepTitles{
+ Step: "Collect Evidence",
+ Workflow: "Access Review",
+ Instance: "Q1 2026 Access Review",
+ }, resolveStepTitles(step))
+ })
+}
+
+func TestNotificationUserFullName(t *testing.T) {
+ t.Parallel()
+
+ t.Run("first name only", func(t *testing.T) {
+ t.Parallel()
+ user := NotificationUser{FirstName: "Alice"}
+ require.Equal(t, "Alice", user.FullName())
+ })
+
+ t.Run("first and last name", func(t *testing.T) {
+ t.Parallel()
+ user := NotificationUser{FirstName: "Alice", LastName: "Smith"}
+ require.Equal(t, "Alice Smith", user.FullName())
+ })
+
+ t.Run("last name only preserves existing behavior", func(t *testing.T) {
+ t.Parallel()
+ user := NotificationUser{LastName: "Smith"}
+ require.Equal(t, " Smith", user.FullName())
+ })
+}
+
+func TestResolveTaskURL(t *testing.T) {
+ t.Parallel()
+
+ t.Run("uses step URL when present", func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, "https://app.example.com/steps/123", resolveTaskURL("https://app.example.com/steps/123", "https://app.example.com"))
+ })
+
+ t.Run("falls back to my tasks URL", func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, "https://app.example.com/my-tasks", resolveTaskURL("", "https://app.example.com"))
+ })
+}
diff --git a/internal/service/worker/jobs.go b/internal/service/worker/jobs.go
index b232b309..b5e667c5 100644
--- a/internal/service/worker/jobs.go
+++ b/internal/service/worker/jobs.go
@@ -8,6 +8,8 @@ import (
"github.com/compliance-framework/api/internal/service/email/types"
"github.com/riverqueue/river"
+ "go.uber.org/zap"
+ "gorm.io/gorm"
)
// Job types for email processing
@@ -17,6 +19,71 @@ const (
JobTypeSendGlobalDigest = "send_global_digest"
)
+// Job types for workflow notifications
+const (
+ JobTypeWorkflowTaskAssigned = "workflow_task_assigned"
+ JobTypeWorkflowTaskDueSoon = "workflow_task_due_soon"
+ JobTypeWorkflowTaskDigest = "workflow_task_digest"
+ JobTypeWorkflowExecutionFailed = "workflow_execution_failed"
+)
+
+// WorkflowTaskAssignedArgs represents the arguments for a new-task-assigned notification email
+type WorkflowTaskAssignedArgs struct {
+ AssignedToType string `json:"assigned_to_type"`
+ UserID string `json:"user_id"`
+ StepExecutionID string `json:"step_execution_id"`
+ StepTitle string `json:"step_title"`
+ WorkflowTitle string `json:"workflow_title"`
+ WorkflowInstanceTitle string `json:"workflow_instance_title"`
+ StepURL string `json:"step_url"`
+ DueDate *time.Time `json:"due_date,omitempty"`
+}
+
+// WorkflowTaskDueSoonArgs represents the arguments for a task-due-soon reminder email
+type WorkflowTaskDueSoonArgs struct {
+ UserID string `json:"user_id"`
+ StepExecutionID string `json:"step_execution_id"`
+ StepTitle string `json:"step_title"`
+ WorkflowTitle string `json:"workflow_title"`
+ WorkflowInstanceTitle string `json:"workflow_instance_title"`
+ StepURL string `json:"step_url"`
+ DueDate time.Time `json:"due_date"`
+}
+
+// WorkflowTaskDigestArgs represents the arguments for a per-user task digest email
+type WorkflowTaskDigestArgs struct {
+ UserID string `json:"user_id"`
+}
+
+// WorkflowExecutionFailedArgs represents the arguments for a workflow-execution-failed notification email
+type WorkflowExecutionFailedArgs struct {
+ WorkflowExecutionID string `json:"workflow_execution_id"`
+}
+
+// Kind returns the job kind for River
+func (WorkflowTaskAssignedArgs) Kind() string { return JobTypeWorkflowTaskAssigned }
+
+// Kind returns the job kind for River
+func (WorkflowTaskDueSoonArgs) Kind() string { return JobTypeWorkflowTaskDueSoon }
+
+// Kind returns the job kind for River
+func (WorkflowTaskDigestArgs) Kind() string { return JobTypeWorkflowTaskDigest }
+
+// Kind returns the job kind for River
+func (WorkflowExecutionFailedArgs) Kind() string { return JobTypeWorkflowExecutionFailed }
+
+// Timeout returns the timeout for workflow task assigned jobs
+func (WorkflowTaskAssignedArgs) Timeout() time.Duration { return 30 * time.Second }
+
+// Timeout returns the timeout for workflow task due soon jobs
+func (WorkflowTaskDueSoonArgs) Timeout() time.Duration { return 30 * time.Second }
+
+// Timeout returns the timeout for workflow task digest jobs
+func (WorkflowTaskDigestArgs) Timeout() time.Duration { return 5 * time.Minute }
+
+// Timeout returns the timeout for workflow execution failed jobs
+func (WorkflowExecutionFailedArgs) Timeout() time.Duration { return 30 * time.Second }
+
// SendEmailArgs represents the arguments for sending an email
type SendEmailArgs struct {
// Email message fields
@@ -66,14 +133,30 @@ func (SendGlobalDigestArgs) Kind() string { return JobTypeSendGlobalDigest }
type EmailService interface {
Send(ctx context.Context, message *types.Message) (*types.SendResult, error)
SendWithProvider(ctx context.Context, providerName string, message *types.Message) (*types.SendResult, error)
+ UseTemplate(templateName string, data map[string]interface{}) (htmlContent, textContent string, err error)
+ GetDefaultFromAddress() string
+}
+
+// UserRepository is the minimal DB interface needed by notification workers
+type UserRepository interface {
+ FindUserByID(ctx context.Context, userID string) (NotificationUser, error)
}
-// Logger interface for logging
-type Logger interface {
- Infow(msg string, keysAndValues ...interface{})
- Errorw(msg string, keysAndValues ...interface{})
- Warnw(msg string, keysAndValues ...interface{})
- Debugw(msg string, keysAndValues ...interface{})
+// NotificationUser holds the user fields needed for sending notification emails
+type NotificationUser struct {
+ ID string
+ Email string
+ FirstName string
+ LastName string
+ TaskAvailableEmailSubscribed bool
+ TaskDailyDigestSubscribed bool
+}
+
+func (u NotificationUser) FullName() string {
+ if u.LastName == "" {
+ return u.FirstName
+ }
+ return u.FirstName + " " + u.LastName
}
// DigestService interface for dependency injection
@@ -99,11 +182,11 @@ func (SendGlobalDigestArgs) Timeout() time.Duration {
// SendEmailWorker handles sending email jobs
type SendEmailWorker struct {
emailService EmailService
- logger Logger
+ logger *zap.SugaredLogger
}
// NewSendEmailWorker creates a new SendEmailWorker
-func NewSendEmailWorker(emailService EmailService, logger Logger) *SendEmailWorker {
+func NewSendEmailWorker(emailService EmailService, logger *zap.SugaredLogger) *SendEmailWorker {
return &SendEmailWorker{
emailService: emailService,
logger: logger,
@@ -173,11 +256,11 @@ func (w *SendEmailWorker) Work(ctx context.Context, job *river.Job[SendEmailArgs
// SendEmailFromWorker handles sending email from provider jobs
type SendEmailFromWorker struct {
emailService EmailService
- logger Logger
+ logger *zap.SugaredLogger
}
// NewSendEmailFromWorker creates a new SendEmailFromWorker
-func NewSendEmailFromWorker(emailService EmailService, logger Logger) *SendEmailFromWorker {
+func NewSendEmailFromWorker(emailService EmailService, logger *zap.SugaredLogger) *SendEmailFromWorker {
return &SendEmailFromWorker{
emailService: emailService,
logger: logger,
@@ -254,11 +337,11 @@ func (w *SendEmailFromWorker) Work(ctx context.Context, job *river.Job[SendEmail
// SendGlobalDigestWorker handles sending global digest jobs
type SendGlobalDigestWorker struct {
digestService DigestService
- logger Logger
+ logger *zap.SugaredLogger
}
// NewSendGlobalDigestWorker creates a new SendGlobalDigestWorker
-func NewSendGlobalDigestWorker(digestService DigestService, logger Logger) *SendGlobalDigestWorker {
+func NewSendGlobalDigestWorker(digestService DigestService, logger *zap.SugaredLogger) *SendGlobalDigestWorker {
return &SendGlobalDigestWorker{
digestService: digestService,
logger: logger,
@@ -282,21 +365,234 @@ func (w *SendGlobalDigestWorker) Work(ctx context.Context, job *river.Job[SendGl
return nil
}
-// JobInsertOptions returns common insert options for email jobs
-func JobInsertOptions() *river.InsertOpts {
+// WorkflowTaskAssignedWorker handles new-task-assigned notification email jobs
+type WorkflowTaskAssignedWorker struct {
+ emailService EmailService
+ userRepo UserRepository
+ webBaseURL string
+ logger *zap.SugaredLogger
+}
+
+// NewWorkflowTaskAssignedWorker creates a new WorkflowTaskAssignedWorker
+func NewWorkflowTaskAssignedWorker(emailService EmailService, userRepo UserRepository, webBaseURL string, logger *zap.SugaredLogger) *WorkflowTaskAssignedWorker {
+ return &WorkflowTaskAssignedWorker{
+ emailService: emailService,
+ userRepo: userRepo,
+ webBaseURL: webBaseURL,
+ logger: logger,
+ }
+}
+
+// Work is the River work function for sending new-task-assigned notification emails
+func (w *WorkflowTaskAssignedWorker) Work(ctx context.Context, job *river.Job[WorkflowTaskAssignedArgs]) error {
+ args := job.Args
+
+ if args.AssignedToType == "email" {
+ return w.sendToEmailAddress(ctx, args)
+ }
+ return w.sendToUser(ctx, args)
+}
+
+// sendToUser looks up the user by ID and sends the notification if they are subscribed
+func (w *WorkflowTaskAssignedWorker) sendToUser(ctx context.Context, args WorkflowTaskAssignedArgs) error {
+ user, err := w.userRepo.FindUserByID(ctx, args.UserID)
+ if err != nil {
+ w.logger.Warnw("WorkflowTaskAssignedWorker: user not found, skipping",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return nil
+ }
+
+ if !user.TaskAvailableEmailSubscribed {
+ w.logger.Debugw("WorkflowTaskAssignedWorker: user not subscribed, skipping",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ )
+ return nil
+ }
+
+ return w.sendEmail(ctx, args, user.Email, user.FullName())
+}
+
+// sendToEmailAddress sends the notification directly to the email address assignee without a user lookup
+func (w *WorkflowTaskAssignedWorker) sendToEmailAddress(ctx context.Context, args WorkflowTaskAssignedArgs) error {
+ return w.sendEmail(ctx, args, args.UserID, "")
+}
+
+// sendEmail renders the template and sends the notification email
+func (w *WorkflowTaskAssignedWorker) sendEmail(ctx context.Context, args WorkflowTaskAssignedArgs, toAddress string, userName string) error {
+ myTasksURL := resolveTaskURL(args.StepURL, w.webBaseURL)
+ templateData := map[string]interface{}{
+ "UserName": userName,
+ "StepTitle": args.StepTitle,
+ "WorkflowTitle": args.WorkflowTitle,
+ "WorkflowInstanceTitle": args.WorkflowInstanceTitle,
+ "StepURL": myTasksURL,
+ "MyTasksURL": w.webBaseURL + "/my-tasks",
+ "DueDate": formatDueDate(args.DueDate),
+ }
+
+ htmlBody, textBody, err := w.emailService.UseTemplate("workflow-task-assigned", templateData)
+ if err != nil {
+ w.logger.Errorw("WorkflowTaskAssignedWorker: failed to render template",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return fmt.Errorf("failed to render workflow-task-assigned template: %w", err)
+ }
+
+ message := &types.Message{
+ From: w.emailService.GetDefaultFromAddress(),
+ To: []string{toAddress},
+ Subject: fmt.Sprintf("Task ready for you: %s — %s", args.StepTitle, args.WorkflowTitle),
+ HTMLBody: htmlBody,
+ TextBody: textBody,
+ }
+
+ result, err := w.emailService.Send(ctx, message)
+ if err != nil {
+ w.logger.Errorw("WorkflowTaskAssignedWorker: failed to send email",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return fmt.Errorf("failed to send workflow-task-assigned email: %w", err)
+ }
+
+ if !result.Success {
+ w.logger.Errorw("WorkflowTaskAssignedWorker: email send reported failure",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "error", result.Error,
+ )
+ return fmt.Errorf("workflow-task-assigned email send failed: %s", result.Error)
+ }
+
+ w.logger.Infow("WorkflowTaskAssignedWorker: email sent",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "message_id", result.MessageID,
+ )
+
+ return nil
+}
+
+// WorkflowTaskDueSoonWorker handles task-due-soon reminder email jobs
+type WorkflowTaskDueSoonWorker struct {
+ emailService EmailService
+ userRepo UserRepository
+ webBaseURL string
+ logger *zap.SugaredLogger
+}
+
+// NewWorkflowTaskDueSoonWorker creates a new WorkflowTaskDueSoonWorker
+func NewWorkflowTaskDueSoonWorker(emailService EmailService, userRepo UserRepository, webBaseURL string, logger *zap.SugaredLogger) *WorkflowTaskDueSoonWorker {
+ return &WorkflowTaskDueSoonWorker{
+ emailService: emailService,
+ userRepo: userRepo,
+ webBaseURL: webBaseURL,
+ logger: logger,
+ }
+}
+
+// Work is the River work function for sending task-due-soon reminder emails
+func (w *WorkflowTaskDueSoonWorker) Work(ctx context.Context, job *river.Job[WorkflowTaskDueSoonArgs]) error {
+ args := job.Args
+
+ user, err := w.userRepo.FindUserByID(ctx, args.UserID)
+ if err != nil {
+ w.logger.Warnw("WorkflowTaskDueSoonWorker: user not found, skipping",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return nil
+ }
+
+ if !user.TaskAvailableEmailSubscribed {
+ w.logger.Debugw("WorkflowTaskDueSoonWorker: user not subscribed, skipping",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ )
+ return nil
+ }
+
+ myTasksURL := resolveTaskURL(args.StepURL, w.webBaseURL)
+ templateData := map[string]interface{}{
+ "UserName": user.FullName(),
+ "StepTitle": args.StepTitle,
+ "WorkflowTitle": args.WorkflowTitle,
+ "WorkflowInstanceTitle": args.WorkflowInstanceTitle,
+ "StepURL": myTasksURL,
+ "MyTasksURL": w.webBaseURL + "/my-tasks",
+ "DueDate": formatDate(args.DueDate),
+ }
+
+ htmlBody, textBody, err := w.emailService.UseTemplate("workflow-task-due-soon", templateData)
+ if err != nil {
+ w.logger.Errorw("WorkflowTaskDueSoonWorker: failed to render template",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return fmt.Errorf("failed to render workflow-task-due-soon template: %w", err)
+ }
+
+ message := &types.Message{
+ From: w.emailService.GetDefaultFromAddress(),
+ To: []string{user.Email},
+ Subject: fmt.Sprintf("Reminder: %s is due soon — %s", args.StepTitle, args.WorkflowTitle),
+ HTMLBody: htmlBody,
+ TextBody: textBody,
+ }
+
+ result, err := w.emailService.Send(ctx, message)
+ if err != nil {
+ w.logger.Errorw("WorkflowTaskDueSoonWorker: failed to send email",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return fmt.Errorf("failed to send workflow-task-due-soon email: %w", err)
+ }
+
+ if !result.Success {
+ w.logger.Errorw("WorkflowTaskDueSoonWorker: email send reported failure",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "error", result.Error,
+ )
+ return fmt.Errorf("workflow-task-due-soon email send failed: %s", result.Error)
+ }
+
+ w.logger.Infow("WorkflowTaskDueSoonWorker: email sent",
+ "step_execution_id", args.StepExecutionID,
+ "user_id", args.UserID,
+ "message_id", result.MessageID,
+ )
+
+ return nil
+}
+
+// JobInsertOptionsForWorkflowNotification returns insert options for workflow notification email jobs
+func JobInsertOptionsForWorkflowNotification() *river.InsertOpts {
return &river.InsertOpts{
- Queue: "email", // Default queue for email jobs
- MaxAttempts: 5, // Retry up to 5 times
- // River uses exponential backoff by default
+ Queue: "email",
+ MaxAttempts: 5,
}
}
-// JobInsertOptionsWithQueue returns insert options for jobs with specified queue
-func JobInsertOptionsWithQueue(queue string) *river.InsertOpts {
+func JobInsertOptionsForWorkflowTaskAssignedNotification() *river.InsertOpts {
return &river.InsertOpts{
- Queue: queue,
- MaxAttempts: 5, // Retry up to 5 times
- // River uses exponential backoff by default
+ Queue: "email",
+ MaxAttempts: 5,
+ UniqueOpts: river.UniqueOpts{
+ ByArgs: true,
+ ByPeriod: 5 * time.Minute,
+ },
}
}
@@ -305,12 +601,11 @@ func JobInsertOptionsWithRetry(queue string, maxAttempts int) *river.InsertOpts
return &river.InsertOpts{
Queue: queue,
MaxAttempts: maxAttempts,
- // River uses exponential backoff by default
}
}
// Workers returns all workers as work functions with dependencies injected
-func Workers(emailService EmailService, digestService DigestService, logger Logger) *river.Workers {
+func Workers(emailService EmailService, digestService DigestService, userRepo UserRepository, db *gorm.DB, webBaseURL string, logger *zap.SugaredLogger) *river.Workers {
workers := river.NewWorkers()
// Create worker instances with dependencies
@@ -326,5 +621,22 @@ func Workers(emailService EmailService, digestService DigestService, logger Logg
river.AddWorker(workers, river.WorkFunc(sendGlobalDigestWorker.Work))
}
+ // Register workflow notification workers if dependencies are available
+ if userRepo != nil {
+ workflowTaskAssignedWorker := NewWorkflowTaskAssignedWorker(emailService, userRepo, webBaseURL, logger)
+ river.AddWorker(workers, river.WorkFunc(workflowTaskAssignedWorker.Work))
+
+ workflowTaskDueSoonWorker := NewWorkflowTaskDueSoonWorker(emailService, userRepo, webBaseURL, logger)
+ river.AddWorker(workers, river.WorkFunc(workflowTaskDueSoonWorker.Work))
+
+ if db != nil {
+ workflowTaskDigestWorker := NewWorkflowTaskDigestWorker(db, emailService, userRepo, webBaseURL, logger)
+ river.AddWorker(workers, river.WorkFunc(workflowTaskDigestWorker.Work))
+
+ workflowExecutionFailedWorker := NewWorkflowExecutionFailedWorker(db, emailService, userRepo, webBaseURL, logger)
+ river.AddWorker(workers, river.WorkFunc(workflowExecutionFailedWorker.Work))
+ }
+ }
+
return workers
}
diff --git a/internal/service/worker/service.go b/internal/service/worker/service.go
index 0d19495e..21fdda6c 100644
--- a/internal/service/worker/service.go
+++ b/internal/service/worker/service.go
@@ -31,6 +31,7 @@ type Service struct {
db *gorm.DB
emailSvc *email.Service
digestSvc DigestService
+ userRepo UserRepository
logger *zap.SugaredLogger
started bool
startedMu sync.RWMutex
@@ -38,11 +39,7 @@ type Service struct {
digestCfg *config.Config
// Workflow services
- workflowExecutor interface{}
- workflowManager *workflow.Manager
- stepExecutionService interface{}
- workflowExecutionService interface{}
- stepDefinitionService interface{}
+ workflowExecutor *workflow.DAGExecutor
}
type riverClientProxy struct {
@@ -56,14 +53,22 @@ func (p *riverClientProxy) InsertMany(ctx context.Context, params []river.Insert
return p.client.InsertMany(ctx, params)
}
-// NewService creates a new worker service
-func NewService(
- cfg *config.WorkerConfig,
- db *gorm.DB,
- emailSvc *email.Service,
- logger *zap.SugaredLogger,
-) (*Service, error) {
- return NewServiceWithDigest(cfg, db, emailSvc, nil, nil, logger)
+type notificationEnqueuerProxy struct {
+ enqueuer workflow.NotificationEnqueuer
+}
+
+func (p *notificationEnqueuerProxy) EnqueueWorkflowTaskAssigned(ctx context.Context, stepExecution *workflows.StepExecution) error {
+ if p.enqueuer == nil {
+ return fmt.Errorf("notification enqueuer not initialized")
+ }
+ return p.enqueuer.EnqueueWorkflowTaskAssigned(ctx, stepExecution)
+}
+
+func (p *notificationEnqueuerProxy) EnqueueWorkflowExecutionFailed(ctx context.Context, execution *workflows.WorkflowExecution) error {
+ if p.enqueuer == nil {
+ return fmt.Errorf("notification enqueuer not initialized")
+ }
+ return p.enqueuer.EnqueueWorkflowExecutionFailed(ctx, execution)
}
// NewServiceWithDigest creates a new worker service with digest support
@@ -129,8 +134,12 @@ func NewServiceWithDigest(
workflowInstService := workflows.NewWorkflowInstanceService(db)
roleAssignmentService := workflows.NewRoleAssignmentService(db)
+ // Create proxies to handle circular dependency (Service implements NotificationEnqueuer but is built after workflow objects)
+ clientProxy := &riverClientProxy{}
+ enqueuerProxy := ¬ificationEnqueuerProxy{}
+
// Create assignment service
- assignmentService := workflow.NewAssignmentService(roleAssignmentService, db)
+ assignmentService := workflow.NewAssignmentService(roleAssignmentService, stepExecService, db, logger, enqueuerProxy)
// Create workflow executor
workflowLogger := log.New(os.Stdout, "[WORKFLOW] ", log.LstdFlags)
@@ -140,6 +149,7 @@ func NewServiceWithDigest(
stepDefService,
assignmentService,
workflowLogger,
+ enqueuerProxy,
)
// Initialize evidence integration and set it on the executor
@@ -154,10 +164,6 @@ func NewServiceWithDigest(
// Create workflow workers
workflowExecutionWorker := workflow.NewWorkflowExecutionWorker(executor, evidenceIntegration, logger)
- stepExecutionWorker := workflow.NewStepExecutionWorker(stepExecService, logger)
-
- // Create a proxy for RiverClient to handle circular dependency
- clientProxy := &riverClientProxy{}
// Create Manager with proxy
workflowManager := workflow.NewManager(
@@ -166,6 +172,7 @@ func NewServiceWithDigest(
workflowInstService,
stepExecService,
logger,
+ enqueuerProxy,
)
// Determine grace period days for the workflow scheduler, with safe defaults.
@@ -176,17 +183,20 @@ func NewServiceWithDigest(
overdueCheckEnabled = digestCfg.Workflow.OverdueCheckEnabled
}
+ overdueService := workflow.NewOverdueService(
+ db,
+ workflowExecService,
+ stepExecService,
+ evidenceIntegration,
+ logger,
+ gracePeriodDays,
+ enqueuerProxy,
+ )
+
schedulerWorker := workflow.NewWorkflowSchedulerWorker(
workflowManager,
workflowInstService,
- workflow.NewOverdueService(
- db,
- workflowExecService,
- stepExecService,
- evidenceIntegration,
- logger,
- gracePeriodDays,
- ),
+ overdueService,
overdueCheckEnabled,
logger,
gracePeriodDays,
@@ -194,13 +204,25 @@ func NewServiceWithDigest(
// Register workers with dependencies injected
// We start with the email/digest workers
- workers := Workers(emailSvc, digestSvc, logger)
+ userRepo := NewGORMUserRepository(db)
+ webBaseURL := ""
+ if digestCfg != nil {
+ webBaseURL = digestCfg.WebBaseURL
+ }
+ workers := Workers(emailSvc, digestSvc, userRepo, db, webBaseURL, logger)
// Add workflow workers
river.AddWorker(workers, river.WorkFunc(workflowExecutionWorker.Work))
- river.AddWorker(workers, river.WorkFunc(stepExecutionWorker.Work))
river.AddWorker(workers, river.WorkFunc(schedulerWorker.Work))
+ // Add due-soon checker worker (uses clientProxy which is wired to the real client after construction)
+ dueSoonCheckerWorker := NewDueSoonCheckerWorker(db, clientProxy, logger)
+ river.AddWorker(workers, river.WorkFunc(dueSoonCheckerWorker.Work))
+
+ // Add workflow task digest checker worker
+ digestCheckerWorker := NewWorkflowTaskDigestCheckerWorker(db, clientProxy, logger)
+ river.AddWorker(workers, river.WorkFunc(digestCheckerWorker.Work))
+
// Configure periodic jobs
periodicJobs := periodicJobsFromConfig(digestCfg, logger)
@@ -222,25 +244,25 @@ func NewServiceWithDigest(
db: db,
emailSvc: emailSvc,
digestSvc: digestSvc,
+ userRepo: userRepo,
digestCfg: digestCfg,
logger: logger,
started: false,
pgxPool: pgxPool,
- // Store workflow services
- workflowExecutor: executor,
- workflowManager: workflowManager,
- stepExecutionService: stepExecService,
- workflowExecutionService: workflowExecService,
- stepDefinitionService: stepDefService,
+ workflowExecutor: executor,
}
+ // Wire the service itself into the notification enqueuer proxy now that it is fully constructed.
+ enqueuerProxy.enqueuer = service
+
return service, nil
}
-// GetWorkflowManager returns the workflow manager
-func (s *Service) GetWorkflowManager() *workflow.Manager {
- return s.workflowManager
+// GetDAGExecutor returns the shared DAG executor used by workflow River workers.
+// Returns nil when the worker service is disabled.
+func (s *Service) GetDAGExecutor() *workflow.DAGExecutor {
+ return s.workflowExecutor
}
// Start starts the worker service
@@ -386,6 +408,40 @@ func parseCronScheduleWithFallback(cronSchedule string, fallback string, jobName
return schedule
}
+func NewDueSoonCheckerPeriodicJob(schedule string, logger *zap.SugaredLogger) *river.PeriodicJob {
+ sched := parseCronScheduleWithFallback(schedule, "0 0 8 * * *", "due-soon checker", logger)
+
+ return river.NewPeriodicJob(
+ sched,
+ func() (river.JobArgs, *river.InsertOpts) {
+ return &DueSoonCheckerArgs{}, &river.InsertOpts{
+ Queue: "email",
+ MaxAttempts: 3,
+ }
+ },
+ &river.PeriodicJobOpts{
+ RunOnStart: false,
+ },
+ )
+}
+
+func NewWorkflowTaskDigestPeriodicJob(schedule string, logger *zap.SugaredLogger) *river.PeriodicJob {
+ sched := parseCronScheduleWithFallback(schedule, "0 0 8 * * *", "workflow task digest", logger)
+
+ return river.NewPeriodicJob(
+ sched,
+ func() (river.JobArgs, *river.InsertOpts) {
+ return &WorkflowTaskDigestCheckerArgs{}, &river.InsertOpts{
+ Queue: "digest",
+ MaxAttempts: 3,
+ }
+ },
+ &river.PeriodicJobOpts{
+ RunOnStart: false,
+ },
+ )
+}
+
func periodicJobsFromConfig(cfg *config.Config, logger *zap.SugaredLogger) []*river.PeriodicJob {
var periodicJobs []*river.PeriodicJob
if cfg == nil {
@@ -397,6 +453,12 @@ func periodicJobsFromConfig(cfg *config.Config, logger *zap.SugaredLogger) []*ri
if cfg.Workflow != nil && cfg.Workflow.SchedulerEnabled {
periodicJobs = append(periodicJobs, NewWorkflowSchedulerPeriodicJob(cfg.Workflow.Schedule, logger))
}
+ if cfg.Workflow != nil && cfg.Workflow.DueSoonEnabled {
+ periodicJobs = append(periodicJobs, NewDueSoonCheckerPeriodicJob(cfg.Workflow.DueSoonSchedule, logger))
+ }
+ if cfg.Workflow != nil && cfg.Workflow.TaskDigestEnabled {
+ periodicJobs = append(periodicJobs, NewWorkflowTaskDigestPeriodicJob(cfg.Workflow.TaskDigestSchedule, logger))
+ }
return periodicJobs
}
@@ -424,6 +486,82 @@ func buildRiverConfig(cfg *config.WorkerConfig, workers *river.Workers, periodic
}
}
+// EnqueueWorkflowTaskAssigned enqueues a workflow-task-assigned notification email job.
+// Implements the workflow.NotificationEnqueuer interface.
+func (s *Service) EnqueueWorkflowTaskAssigned(ctx context.Context, stepExecution *workflows.StepExecution) error {
+ if !s.config.Enabled || s.client == nil {
+ return nil
+ }
+
+ if stepExecution == nil {
+ return nil
+ }
+
+ // Only enqueue for user or email-type assignees
+ if (stepExecution.AssignedToType != workflows.AssignmentTypeUser.String() &&
+ stepExecution.AssignedToType != workflows.AssignmentTypeEmail.String()) ||
+ stepExecution.AssignedToID == "" {
+ return nil
+ }
+
+ // Reload with full nested relations so title fields are always populated,
+ // regardless of what the caller had preloaded on the passed-in struct.
+ var full workflows.StepExecution
+ if err := s.db.WithContext(ctx).
+ Preload("WorkflowStepDefinition").
+ Preload("WorkflowExecution.WorkflowInstance.WorkflowDefinition").
+ First(&full, "id = ?", stepExecution.ID).Error; err == nil {
+ stepExecution = &full
+ }
+
+ titles := resolveStepTitles(stepExecution)
+
+ args := &WorkflowTaskAssignedArgs{
+ AssignedToType: stepExecution.AssignedToType,
+ UserID: stepExecution.AssignedToID,
+ StepExecutionID: stepExecution.ID.String(),
+ StepTitle: titles.Step,
+ WorkflowTitle: titles.Workflow,
+ WorkflowInstanceTitle: titles.Instance,
+ StepURL: "",
+ DueDate: stepExecution.DueDate,
+ }
+
+ _, err := s.client.InsertMany(ctx, []river.InsertManyParams{
+ {Args: args, InsertOpts: JobInsertOptionsForWorkflowTaskAssignedNotification()},
+ })
+ if err != nil {
+ return fmt.Errorf("failed to enqueue workflow-task-assigned job: %w", err)
+ }
+
+ return nil
+}
+
+// EnqueueWorkflowExecutionFailed enqueues a workflow-execution-failed notification email job.
+// Implements the workflow.NotificationEnqueuer interface.
+func (s *Service) EnqueueWorkflowExecutionFailed(ctx context.Context, execution *workflows.WorkflowExecution) error {
+ if !s.config.Enabled || s.client == nil {
+ return nil
+ }
+
+ if execution == nil || execution.ID == nil {
+ return nil
+ }
+
+ args := &WorkflowExecutionFailedArgs{
+ WorkflowExecutionID: execution.ID.String(),
+ }
+
+ _, err := s.client.InsertMany(ctx, []river.InsertManyParams{
+ {Args: args, InsertOpts: JobInsertOptionsForWorkflowNotification()},
+ })
+ if err != nil {
+ return fmt.Errorf("failed to enqueue workflow-execution-failed job: %w", err)
+ }
+
+ return nil
+}
+
// EnqueueSendEmail enqueues a send email job
func (s *Service) EnqueueSendEmail(ctx context.Context, args *SendEmailArgs) error {
if !s.config.Enabled {
diff --git a/internal/service/worker/service_test.go b/internal/service/worker/service_test.go
index a476263b..1a6c9d47 100644
--- a/internal/service/worker/service_test.go
+++ b/internal/service/worker/service_test.go
@@ -29,6 +29,16 @@ func (m *MockEmailService) SendWithProvider(ctx context.Context, providerName st
return args.Get(0).(*types.SendResult), args.Error(1)
}
+func (m *MockEmailService) UseTemplate(templateName string, data map[string]interface{}) (string, string, error) {
+ args := m.Called(templateName, data)
+ return args.String(0), args.String(1), args.Error(2)
+}
+
+func (m *MockEmailService) GetDefaultFromAddress() string {
+ args := m.Called()
+ return args.String(0)
+}
+
// MockDigestService is a mock implementation of DigestService
type MockDigestService struct {
mock.Mock
@@ -39,45 +49,19 @@ func (m *MockDigestService) SendGlobalDigest(ctx context.Context) error {
return args.Error(0)
}
-// MockLogger is a mock implementation of Logger
-type MockLogger struct {
- mock.Mock
- loggedMessages []string
-}
-
-func (m *MockLogger) Infow(msg string, keysAndValues ...interface{}) {
- m.Called(msg, keysAndValues)
- m.loggedMessages = append(m.loggedMessages, "INFO: "+msg)
-}
-
-func (m *MockLogger) Errorw(msg string, keysAndValues ...interface{}) {
- m.Called(msg, keysAndValues)
- m.loggedMessages = append(m.loggedMessages, "ERROR: "+msg)
-}
-
-func (m *MockLogger) Warnw(msg string, keysAndValues ...interface{}) {
- m.Called(msg, keysAndValues)
- m.loggedMessages = append(m.loggedMessages, "WARN: "+msg)
-}
-
-func (m *MockLogger) Debugw(msg string, keysAndValues ...interface{}) {
- m.Called(msg, keysAndValues)
- m.loggedMessages = append(m.loggedMessages, "DEBUG: "+msg)
-}
-
-func TestNewService_Disabled(t *testing.T) {
+func TestNewServiceWithDigest_Disabled(t *testing.T) {
cfg := &config.WorkerConfig{
Enabled: false,
}
logger := zap.NewNop().Sugar()
- service, err := NewService(cfg, nil, nil, logger)
+ service, err := NewServiceWithDigest(cfg, nil, nil, nil, nil, logger)
assert.NoError(t, err)
assert.NotNil(t, service)
assert.False(t, service.IsStarted())
}
-func TestNewService_RequiresEmailService(t *testing.T) {
+func TestNewServiceWithDigest_RequiresEmailService(t *testing.T) {
cfg := &config.WorkerConfig{
Enabled: true,
Workers: 5,
@@ -85,7 +69,7 @@ func TestNewService_RequiresEmailService(t *testing.T) {
}
logger := zap.NewNop().Sugar()
- service, err := NewService(cfg, nil, nil, logger)
+ service, err := NewServiceWithDigest(cfg, nil, nil, nil, nil, logger)
assert.Error(t, err)
assert.Nil(t, service)
assert.Contains(t, err.Error(), "email service is required")
@@ -97,7 +81,7 @@ func TestService_EnqueueWhenDisabled(t *testing.T) {
}
logger := zap.NewNop().Sugar()
- service, err := NewService(cfg, nil, nil, logger)
+ service, err := NewServiceWithDigest(cfg, nil, nil, nil, nil, logger)
assert.NoError(t, err)
ctx := context.Background()
@@ -113,13 +97,13 @@ func TestService_EnqueueWhenDisabled(t *testing.T) {
func TestNewSendEmailWorker(t *testing.T) {
mockEmailService := &MockEmailService{}
- mockLogger := &MockLogger{}
+ logger := zap.NewNop().Sugar()
- worker := NewSendEmailWorker(mockEmailService, mockLogger)
+ worker := NewSendEmailWorker(mockEmailService, logger)
assert.NotNil(t, worker)
assert.Equal(t, mockEmailService, worker.emailService)
- assert.Equal(t, mockLogger, worker.logger)
+ assert.Equal(t, logger, worker.logger)
}
func TestSendEmailWorker_MessageConstruction(t *testing.T) {
@@ -177,8 +161,7 @@ func TestSendEmailWorker_MessageConstruction(t *testing.T) {
func TestSendEmailWorker_Work_Validation(t *testing.T) {
mockEmailService := &MockEmailService{}
- mockLogger := &MockLogger{}
- worker := NewSendEmailWorker(mockEmailService, mockLogger)
+ worker := NewSendEmailWorker(mockEmailService, zap.NewNop().Sugar())
ctx := context.Background()
@@ -215,9 +198,6 @@ func TestSendEmailWorker_Work_Validation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- // Set up mock logger to expect any call
- mockLogger.On("Infow", "Processing send email job", mock.Anything).Maybe()
-
// Create a test job with the invalid args
job := &river.Job[SendEmailArgs]{
Args: *tt.args,
@@ -234,13 +214,13 @@ func TestSendEmailWorker_Work_Validation(t *testing.T) {
func TestNewSendEmailFromWorker(t *testing.T) {
mockEmailService := &MockEmailService{}
- mockLogger := &MockLogger{}
+ logger := zap.NewNop().Sugar()
- worker := NewSendEmailFromWorker(mockEmailService, mockLogger)
+ worker := NewSendEmailFromWorker(mockEmailService, logger)
assert.NotNil(t, worker)
assert.Equal(t, mockEmailService, worker.emailService)
- assert.Equal(t, mockLogger, worker.logger)
+ assert.Equal(t, logger, worker.logger)
}
func TestSendEmailFromWorker_MessageConstruction(t *testing.T) {
@@ -291,13 +271,13 @@ func TestSendEmailFromWorker_MessageConstruction(t *testing.T) {
func TestNewSendGlobalDigestWorker(t *testing.T) {
mockDigestService := &MockDigestService{}
- mockLogger := &MockLogger{}
+ logger := zap.NewNop().Sugar()
- worker := NewSendGlobalDigestWorker(mockDigestService, mockLogger)
+ worker := NewSendGlobalDigestWorker(mockDigestService, logger)
assert.NotNil(t, worker)
assert.Equal(t, mockDigestService, worker.digestService)
- assert.Equal(t, mockLogger, worker.logger)
+ assert.Equal(t, logger, worker.logger)
}
func TestSendGlobalDigestWorker_DigestCall(t *testing.T) {
@@ -318,18 +298,19 @@ func TestSendGlobalDigestWorker_DigestCall(t *testing.T) {
func TestWorkers(t *testing.T) {
mockEmailService := &MockEmailService{}
mockDigestService := &MockDigestService{}
- mockLogger := &MockLogger{}
- workers := Workers(mockEmailService, mockDigestService, mockLogger)
+ workers := Workers(mockEmailService, mockDigestService, nil, nil, "", zap.NewNop().Sugar())
assert.NotNil(t, workers)
}
-func TestJobInsertOptions(t *testing.T) {
- opts := JobInsertOptions()
+func TestJobInsertOptionsForWorkflowTaskAssignedNotification(t *testing.T) {
+ opts := JobInsertOptionsForWorkflowTaskAssignedNotification()
assert.Equal(t, "email", opts.Queue)
assert.Equal(t, 5, opts.MaxAttempts)
+ assert.True(t, opts.UniqueOpts.ByArgs)
+ assert.Equal(t, 5*time.Minute, opts.UniqueOpts.ByPeriod)
}
func TestParseCronScheduleWithFallback_InvalidUsesFallback(t *testing.T) {
@@ -349,23 +330,43 @@ func TestParseCronScheduleWithFallback_InvalidUsesFallback(t *testing.T) {
func TestPeriodicJobsFromConfig_WorkflowSchedulerEnabledGuard(t *testing.T) {
logger := zap.NewNop().Sugar()
+ // Nothing enabled → 0 jobs
jobs := periodicJobsFromConfig(&config.Config{
DigestEnabled: false,
Workflow: &config.WorkflowConfig{
- SchedulerEnabled: false,
- Schedule: "@every 15m",
+ SchedulerEnabled: false,
+ Schedule: "@every 15m",
+ DueSoonEnabled: false,
+ TaskDigestEnabled: false,
},
}, logger)
assert.Len(t, jobs, 0)
+ // Scheduler only → 1 job
jobs = periodicJobsFromConfig(&config.Config{
DigestEnabled: false,
Workflow: &config.WorkflowConfig{
- SchedulerEnabled: true,
- Schedule: "@every 15m",
+ SchedulerEnabled: true,
+ Schedule: "@every 15m",
+ DueSoonEnabled: false,
+ TaskDigestEnabled: false,
},
}, logger)
assert.Len(t, jobs, 1)
+
+ // Scheduler + due-soon + task digest → 3 jobs
+ jobs = periodicJobsFromConfig(&config.Config{
+ DigestEnabled: false,
+ Workflow: &config.WorkflowConfig{
+ SchedulerEnabled: true,
+ Schedule: "@every 15m",
+ DueSoonEnabled: true,
+ DueSoonSchedule: "0 8 * * *",
+ TaskDigestEnabled: true,
+ TaskDigestSchedule: "0 8 * * *",
+ },
+ }, logger)
+ assert.Len(t, jobs, 3)
}
func TestWorkflowSchedulerPeriodicJobConstructor_InsertOpts(t *testing.T) {
@@ -378,13 +379,6 @@ func TestWorkflowSchedulerPeriodicJobConstructor_InsertOpts(t *testing.T) {
assert.Equal(t, "schedule_workflows", args.Kind())
}
-func TestJobInsertOptionsWithQueue(t *testing.T) {
- opts := JobInsertOptionsWithQueue("custom-queue")
-
- assert.Equal(t, "custom-queue", opts.Queue)
- assert.Equal(t, 5, opts.MaxAttempts)
-}
-
func TestJobInsertOptionsWithRetry(t *testing.T) {
opts := JobInsertOptionsWithRetry("custom-queue", 10)
diff --git a/internal/service/worker/user_repository.go b/internal/service/worker/user_repository.go
new file mode 100644
index 00000000..538bca88
--- /dev/null
+++ b/internal/service/worker/user_repository.go
@@ -0,0 +1,46 @@
+package worker
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/compliance-framework/api/internal/service/relational"
+ "github.com/google/uuid"
+ "gorm.io/gorm"
+)
+
+// GORMUserRepository implements UserRepository using GORM
+type GORMUserRepository struct {
+ db *gorm.DB
+}
+
+// NewGORMUserRepository creates a new GORMUserRepository
+func NewGORMUserRepository(db *gorm.DB) *GORMUserRepository {
+ return &GORMUserRepository{db: db}
+}
+
+// FindUserByID looks up a user by UUID string and returns a NotificationUser
+func (r *GORMUserRepository) FindUserByID(ctx context.Context, userID string) (NotificationUser, error) {
+ parsed, err := uuid.Parse(userID)
+ if err != nil {
+ return NotificationUser{}, fmt.Errorf("invalid user ID %q: %w", userID, err)
+ }
+
+ var user relational.User
+ if err := r.db.WithContext(ctx).First(&user, parsed).Error; err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return NotificationUser{}, fmt.Errorf("user %s not found", userID)
+ }
+ return NotificationUser{}, fmt.Errorf("failed to fetch user %s: %w", userID, err)
+ }
+
+ return NotificationUser{
+ ID: user.ID.String(),
+ Email: user.Email,
+ FirstName: user.FirstName,
+ LastName: user.LastName,
+ TaskAvailableEmailSubscribed: user.TaskAvailableEmailSubscribed,
+ TaskDailyDigestSubscribed: user.TaskDailyDigestSubscribed,
+ }, nil
+}
diff --git a/internal/service/worker/workflow_execution_failed_worker.go b/internal/service/worker/workflow_execution_failed_worker.go
new file mode 100644
index 00000000..72be31c8
--- /dev/null
+++ b/internal/service/worker/workflow_execution_failed_worker.go
@@ -0,0 +1,159 @@
+package worker
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/compliance-framework/api/internal/service/email/types"
+ "github.com/compliance-framework/api/internal/service/relational/workflows"
+ "github.com/compliance-framework/api/internal/workflow"
+ "github.com/google/uuid"
+ "github.com/riverqueue/river"
+ "go.uber.org/zap"
+ "gorm.io/gorm"
+)
+
+// WorkflowExecutionFailedWorker sends a failure notification email to the workflow instance creator
+type WorkflowExecutionFailedWorker struct {
+ db *gorm.DB
+ emailService EmailService
+ userRepo UserRepository
+ webBaseURL string
+ logger *zap.SugaredLogger
+}
+
+// NewWorkflowExecutionFailedWorker creates a new WorkflowExecutionFailedWorker
+func NewWorkflowExecutionFailedWorker(db *gorm.DB, emailService EmailService, userRepo UserRepository, webBaseURL string, logger *zap.SugaredLogger) *WorkflowExecutionFailedWorker {
+ return &WorkflowExecutionFailedWorker{
+ db: db,
+ emailService: emailService,
+ userRepo: userRepo,
+ webBaseURL: webBaseURL,
+ logger: logger,
+ }
+}
+
+// Work sends a failure notification email for the workflow execution identified by job.Args.WorkflowExecutionID
+func (w *WorkflowExecutionFailedWorker) Work(ctx context.Context, job *river.Job[WorkflowExecutionFailedArgs]) error {
+ args := job.Args
+
+ executionID, err := uuid.Parse(args.WorkflowExecutionID)
+ if err != nil {
+ w.logger.Warnw("WorkflowExecutionFailedWorker: invalid execution ID, skipping",
+ "workflow_execution_id", args.WorkflowExecutionID,
+ "error", err,
+ )
+ return nil
+ }
+
+ if w.db == nil {
+ return fmt.Errorf("WorkflowExecutionFailedWorker: db is nil")
+ }
+
+ var execution workflows.WorkflowExecution
+ if err := w.db.WithContext(ctx).
+ Preload("WorkflowInstance.WorkflowDefinition").
+ Preload("StepExecutions").
+ First(&execution, "id = ?", executionID).Error; err != nil {
+ return fmt.Errorf("WorkflowExecutionFailedWorker: failed to load execution %s: %w", args.WorkflowExecutionID, err)
+ }
+
+ if execution.WorkflowInstance == nil {
+ w.logger.Warnw("WorkflowExecutionFailedWorker: workflow instance not found, skipping",
+ "workflow_execution_id", args.WorkflowExecutionID,
+ )
+ return nil
+ }
+
+ instance := execution.WorkflowInstance
+ if instance.CreatedByID == nil {
+ w.logger.Warnw("WorkflowExecutionFailedWorker: instance has no CreatedByID, skipping",
+ "workflow_execution_id", args.WorkflowExecutionID,
+ "workflow_instance_id", instance.ID,
+ )
+ return nil
+ }
+
+ recipient, err := w.userRepo.FindUserByID(ctx, instance.CreatedByID.String())
+ if err != nil {
+ w.logger.Warnw("WorkflowExecutionFailedWorker: creator user not found, skipping",
+ "workflow_execution_id", args.WorkflowExecutionID,
+ "user_id", instance.CreatedByID,
+ "error", err,
+ )
+ return nil
+ }
+
+ workflowTitle := ""
+ if instance.WorkflowDefinition != nil {
+ workflowTitle = instance.WorkflowDefinition.Name
+ }
+
+ counts := workflow.CountStepStatuses(execution.StepExecutions)
+ failedSteps := counts.Failed
+ completedSteps := counts.Completed
+ totalSteps := len(execution.StepExecutions)
+
+ failedAt := "unknown"
+ if execution.FailedAt != nil {
+ failedAt = formatDate(*execution.FailedAt)
+ }
+
+ templateData := map[string]interface{}{
+ "RecipientName": recipient.FullName(),
+ "WorkflowTitle": workflowTitle,
+ "WorkflowInstanceName": instance.Name,
+ "ExecutionID": execution.ID.String(),
+ "FailureReason": execution.FailureReason,
+ "FailedAt": failedAt,
+ "FailedSteps": failedSteps,
+ "CompletedSteps": completedSteps,
+ "TotalSteps": totalSteps,
+ "WorkflowURL": w.webBaseURL + "/my-tasks",
+ "MyTasksURL": w.webBaseURL + "/my-tasks",
+ }
+
+ htmlBody, textBody, err := w.emailService.UseTemplate("workflow-execution-failed", templateData)
+ if err != nil {
+ w.logger.Errorw("WorkflowExecutionFailedWorker: failed to render template",
+ "workflow_execution_id", args.WorkflowExecutionID,
+ "error", err,
+ )
+ return fmt.Errorf("failed to render workflow-execution-failed template: %w", err)
+ }
+
+ message := &types.Message{
+ From: w.emailService.GetDefaultFromAddress(),
+ To: []string{recipient.Email},
+ Subject: fmt.Sprintf("Workflow execution failed: %s", instance.Name),
+ HTMLBody: htmlBody,
+ TextBody: textBody,
+ }
+
+ result, err := w.emailService.Send(ctx, message)
+ if err != nil {
+ w.logger.Errorw("WorkflowExecutionFailedWorker: failed to send email",
+ "workflow_execution_id", args.WorkflowExecutionID,
+ "recipient", recipient.Email,
+ "error", err,
+ )
+ return fmt.Errorf("failed to send workflow-execution-failed email: %w", err)
+ }
+
+ if !result.Success {
+ w.logger.Errorw("WorkflowExecutionFailedWorker: email send reported failure",
+ "workflow_execution_id", args.WorkflowExecutionID,
+ "recipient", recipient.Email,
+ "error", result.Error,
+ )
+ return fmt.Errorf("workflow-execution-failed email send failed: %s", result.Error)
+ }
+
+ w.logger.Infow("WorkflowExecutionFailedWorker: failure notification sent",
+ "workflow_execution_id", args.WorkflowExecutionID,
+ "recipient", recipient.Email,
+ "message_id", result.MessageID,
+ )
+
+ return nil
+}
diff --git a/internal/service/worker/workflow_execution_failed_worker_test.go b/internal/service/worker/workflow_execution_failed_worker_test.go
new file mode 100644
index 00000000..836122ae
--- /dev/null
+++ b/internal/service/worker/workflow_execution_failed_worker_test.go
@@ -0,0 +1,43 @@
+package worker
+
+import (
+ "context"
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/riverqueue/river"
+ "github.com/stretchr/testify/assert"
+ "go.uber.org/zap"
+)
+
+func makeFailedJob(args WorkflowExecutionFailedArgs) *river.Job[WorkflowExecutionFailedArgs] {
+ return &river.Job[WorkflowExecutionFailedArgs]{Args: args}
+}
+
+func TestWorkflowExecutionFailedWorker_InvalidExecutionID_Skips(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ w := NewWorkflowExecutionFailedWorker(nil, mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ err := w.Work(ctx, makeFailedJob(WorkflowExecutionFailedArgs{WorkflowExecutionID: "not-a-uuid"}))
+ assert.NoError(t, err)
+ mockEmail.AssertNotCalled(t, "Send")
+}
+
+func TestWorkflowExecutionFailedWorker_NilDB_ReturnsError(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ w := NewWorkflowExecutionFailedWorker(nil, mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ err := w.Work(ctx, makeFailedJob(WorkflowExecutionFailedArgs{WorkflowExecutionID: uuid.New().String()}))
+ assert.Error(t, err)
+ mockEmail.AssertNotCalled(t, "Send")
+}
diff --git a/internal/service/worker/workflow_task_assigned_worker_test.go b/internal/service/worker/workflow_task_assigned_worker_test.go
new file mode 100644
index 00000000..36d5efaa
--- /dev/null
+++ b/internal/service/worker/workflow_task_assigned_worker_test.go
@@ -0,0 +1,151 @@
+package worker
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/compliance-framework/api/internal/service/email/types"
+ "github.com/riverqueue/river"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "go.uber.org/zap"
+)
+
+// MockUserRepository is a mock implementation of UserRepository
+type MockUserRepository struct {
+ mock.Mock
+}
+
+func (m *MockUserRepository) FindUserByID(ctx context.Context, userID string) (NotificationUser, error) {
+ args := m.Called(ctx, userID)
+ return args.Get(0).(NotificationUser), args.Error(1)
+}
+
+func makeTaskAssignedJob(args WorkflowTaskAssignedArgs) *river.Job[WorkflowTaskAssignedArgs] {
+ return &river.Job[WorkflowTaskAssignedArgs]{Args: args}
+}
+
+func TestWorkflowTaskAssignedWorker_SubscribedUser_SendsEmail(t *testing.T) {
+ ctx := context.Background()
+ dueDate := time.Now().Add(48 * time.Hour)
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ user := NotificationUser{
+ ID: "user-1",
+ Email: "alice@example.com",
+ FirstName: "Alice",
+ LastName: "Smith",
+ TaskAvailableEmailSubscribed: true,
+ }
+ mockRepo.On("FindUserByID", ctx, "user-1").Return(user, nil)
+ mockEmail.On("UseTemplate", "workflow-task-assigned", mock.Anything).Return("Task", "Task text", nil)
+ mockEmail.On("GetDefaultFromAddress").Return("noreply@example.com")
+ mockEmail.On("Send", ctx, mock.MatchedBy(func(msg *types.Message) bool {
+ return msg.To[0] == "alice@example.com"
+ })).Return(&types.SendResult{Success: true, MessageID: "msg-1"}, nil)
+
+ w := NewWorkflowTaskAssignedWorker(mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ args := WorkflowTaskAssignedArgs{
+ UserID: "user-1",
+ StepExecutionID: "step-1",
+ StepTitle: "Review Policy",
+ WorkflowTitle: "Annual Audit",
+ WorkflowInstanceTitle: "Audit 2026",
+ StepURL: "https://app.example.com/steps/step-1",
+ DueDate: &dueDate,
+ }
+
+ err := w.Work(ctx, makeTaskAssignedJob(args))
+ assert.NoError(t, err)
+ mockEmail.AssertExpectations(t)
+ mockRepo.AssertExpectations(t)
+}
+
+func TestWorkflowTaskAssignedWorker_UnsubscribedUser_Skips(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ user := NotificationUser{
+ ID: "user-2",
+ Email: "bob@example.com",
+ FirstName: "Bob",
+ TaskAvailableEmailSubscribed: false,
+ }
+ mockRepo.On("FindUserByID", ctx, "user-2").Return(user, nil)
+
+ w := NewWorkflowTaskAssignedWorker(mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ args := WorkflowTaskAssignedArgs{
+ UserID: "user-2",
+ StepExecutionID: "step-2",
+ StepTitle: "Review Policy",
+ WorkflowTitle: "Annual Audit",
+ }
+
+ err := w.Work(ctx, makeTaskAssignedJob(args))
+ assert.NoError(t, err)
+ // Send must NOT be called
+ mockEmail.AssertNotCalled(t, "Send")
+}
+
+func TestWorkflowTaskAssignedWorker_UserNotFound_Skips(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ mockRepo.On("FindUserByID", ctx, "missing-user").Return(NotificationUser{}, errors.New("not found"))
+
+ w := NewWorkflowTaskAssignedWorker(mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ args := WorkflowTaskAssignedArgs{
+ UserID: "missing-user",
+ StepExecutionID: "step-3",
+ }
+
+ err := w.Work(ctx, makeTaskAssignedJob(args))
+ // Should return nil (non-fatal skip)
+ assert.NoError(t, err)
+ mockEmail.AssertNotCalled(t, "Send")
+}
+
+func TestWorkflowTaskAssignedWorker_TemplateError_ReturnsError(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ user := NotificationUser{
+ ID: "user-3",
+ Email: "carol@example.com",
+ FirstName: "Carol",
+ TaskAvailableEmailSubscribed: true,
+ }
+ mockRepo.On("FindUserByID", ctx, "user-3").Return(user, nil)
+ mockEmail.On("UseTemplate", "workflow-task-assigned", mock.Anything).Return("", "", errors.New("template broken"))
+ mockEmail.On("GetDefaultFromAddress").Return("noreply@example.com")
+
+ w := NewWorkflowTaskAssignedWorker(mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ args := WorkflowTaskAssignedArgs{
+ UserID: "user-3",
+ StepExecutionID: "step-4",
+ StepTitle: "Review",
+ WorkflowTitle: "Audit",
+ }
+
+ err := w.Work(ctx, makeTaskAssignedJob(args))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "workflow-task-assigned template")
+}
diff --git a/internal/service/worker/workflow_task_digest_checker.go b/internal/service/worker/workflow_task_digest_checker.go
new file mode 100644
index 00000000..f0a480e0
--- /dev/null
+++ b/internal/service/worker/workflow_task_digest_checker.go
@@ -0,0 +1,93 @@
+package worker
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/compliance-framework/api/internal/service/relational"
+ "github.com/compliance-framework/api/internal/workflow"
+ "github.com/riverqueue/river"
+ "go.uber.org/zap"
+ "gorm.io/gorm"
+)
+
+// WorkflowTaskDigestCheckerArgs represents the arguments for the periodic digest checker job
+type WorkflowTaskDigestCheckerArgs struct{}
+
+// Kind returns the job kind for River
+func (WorkflowTaskDigestCheckerArgs) Kind() string { return "workflow_task_digest_checker" }
+
+// Timeout returns the timeout for the digest checker job
+func (WorkflowTaskDigestCheckerArgs) Timeout() time.Duration { return 5 * time.Minute }
+
+// WorkflowTaskDigestCheckerWorker queries all subscribed users and enqueues per-user digest jobs
+type WorkflowTaskDigestCheckerWorker struct {
+ db *gorm.DB
+ client workflow.RiverClient
+ logger *zap.SugaredLogger
+}
+
+// NewWorkflowTaskDigestCheckerWorker creates a new WorkflowTaskDigestCheckerWorker
+func NewWorkflowTaskDigestCheckerWorker(db *gorm.DB, client workflow.RiverClient, logger *zap.SugaredLogger) *WorkflowTaskDigestCheckerWorker {
+ return &WorkflowTaskDigestCheckerWorker{
+ db: db,
+ client: client,
+ logger: logger,
+ }
+}
+
+// Work queries all users with TaskDailyDigestSubscribed=true and enqueues a WorkflowTaskDigestArgs job for each
+func (w *WorkflowTaskDigestCheckerWorker) Work(ctx context.Context, job *river.Job[WorkflowTaskDigestCheckerArgs]) error {
+ if w.db == nil {
+ return fmt.Errorf("WorkflowTaskDigestCheckerWorker: db is nil")
+ }
+
+ var users []relational.User
+ if err := w.db.WithContext(ctx).
+ Where("task_daily_digest_subscribed = ? AND deleted_at IS NULL", true).
+ Find(&users).Error; err != nil {
+ return fmt.Errorf("workflow-task-digest-checker: failed to query subscribed users: %w", err)
+ }
+
+ if len(users) == 0 {
+ w.logger.Infow("WorkflowTaskDigestCheckerWorker: no subscribed users found")
+ return nil
+ }
+
+ params := make([]river.InsertManyParams, 0, len(users))
+ for i := range users {
+ if users[i].ID == nil {
+ continue
+ }
+ params = append(params, river.InsertManyParams{
+ Args: WorkflowTaskDigestArgs{UserID: users[i].ID.String()},
+ InsertOpts: &river.InsertOpts{
+ Queue: "digest",
+ MaxAttempts: 3,
+ },
+ })
+ }
+
+ if len(params) == 0 {
+ return nil
+ }
+
+ results, err := w.client.InsertMany(ctx, params)
+ if err != nil {
+ return fmt.Errorf("workflow-task-digest-checker: failed to enqueue digest jobs: %w", err)
+ }
+
+ inserted := 0
+ for _, r := range results {
+ if r != nil && r.Job != nil {
+ inserted++
+ }
+ }
+
+ w.logger.Infow("WorkflowTaskDigestCheckerWorker: enqueued digest jobs",
+ "total_users", len(users),
+ "enqueued", inserted,
+ )
+ return nil
+}
diff --git a/internal/service/worker/workflow_task_digest_worker.go b/internal/service/worker/workflow_task_digest_worker.go
new file mode 100644
index 00000000..0e702cc5
--- /dev/null
+++ b/internal/service/worker/workflow_task_digest_worker.go
@@ -0,0 +1,176 @@
+package worker
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/compliance-framework/api/internal/service/email/types"
+ "github.com/compliance-framework/api/internal/service/relational/workflows"
+ "github.com/riverqueue/river"
+ "go.uber.org/zap"
+ "gorm.io/gorm"
+)
+
// DigestTask represents a single task entry in the digest email
type DigestTask struct {
	StepTitle             string  // name of the step awaiting action
	WorkflowTitle         string  // title of the owning workflow definition
	WorkflowInstanceTitle string  // title of the specific workflow instance
	DueDate               *string // pre-formatted due date; nil when the step has none
	StepURL               string  // deep link to the step — NOTE(review): buildDigestTask never sets this; confirm the template does not rely on it
}
+
// WorkflowTaskDigestWorker sends a per-user digest of pending and overdue workflow tasks
type WorkflowTaskDigestWorker struct {
	db           *gorm.DB           // source of step execution rows
	emailService EmailService       // renders templates and delivers mail
	userRepo     UserRepository     // resolves recipients and their digest preferences
	webBaseURL   string             // base URL used to build links in the email body
	logger       *zap.SugaredLogger
}
+
+// NewWorkflowTaskDigestWorker creates a new WorkflowTaskDigestWorker
+func NewWorkflowTaskDigestWorker(db *gorm.DB, emailService EmailService, userRepo UserRepository, webBaseURL string, logger *zap.SugaredLogger) *WorkflowTaskDigestWorker {
+ return &WorkflowTaskDigestWorker{
+ db: db,
+ emailService: emailService,
+ userRepo: userRepo,
+ webBaseURL: webBaseURL,
+ logger: logger,
+ }
+}
+
+// Work sends a digest email for the user identified by job.Args.UserID
+func (w *WorkflowTaskDigestWorker) Work(ctx context.Context, job *river.Job[WorkflowTaskDigestArgs]) error {
+ args := job.Args
+
+ user, err := w.userRepo.FindUserByID(ctx, args.UserID)
+ if err != nil {
+ w.logger.Warnw("WorkflowTaskDigestWorker: user not found, skipping",
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return nil
+ }
+
+ if !user.TaskDailyDigestSubscribed {
+ w.logger.Debugw("WorkflowTaskDigestWorker: user not subscribed to digest, skipping",
+ "user_id", args.UserID,
+ )
+ return nil
+ }
+
+ if w.db == nil {
+ return fmt.Errorf("WorkflowTaskDigestWorker: db is nil")
+ }
+
+ now := time.Now()
+
+ var steps []workflows.StepExecution
+ if err := w.db.WithContext(ctx).
+ Preload("WorkflowExecution.WorkflowInstance.WorkflowDefinition").
+ Preload("WorkflowStepDefinition").
+ Where("assigned_to_type = ? AND assigned_to_id = ? AND status IN ?",
+ workflows.AssignmentTypeUser.String(),
+ args.UserID,
+ []string{
+ workflows.StepStatusPending.String(),
+ workflows.StepStatusInProgress.String(),
+ workflows.StepStatusOverdue.String(),
+ },
+ ).
+ Find(&steps).Error; err != nil {
+ return fmt.Errorf("WorkflowTaskDigestWorker: failed to query steps for user %s: %w", args.UserID, err)
+ }
+
+ if len(steps) == 0 {
+ w.logger.Debugw("WorkflowTaskDigestWorker: no tasks for user, skipping",
+ "user_id", args.UserID,
+ )
+ return nil
+ }
+
+ var pendingTasks []DigestTask
+ var overdueTasks []DigestTask
+
+ for i := range steps {
+ step := &steps[i]
+ task := buildDigestTask(step)
+
+ if step.Status == workflows.StepStatusOverdue.String() ||
+ (step.DueDate != nil && step.DueDate.Before(now)) {
+ overdueTasks = append(overdueTasks, task)
+ } else {
+ pendingTasks = append(pendingTasks, task)
+ }
+ }
+
+ periodLabel := "Daily digest — " + now.Format("Monday, 2 January 2006")
+
+ templateData := map[string]interface{}{
+ "UserName": user.FullName(),
+ "PeriodLabel": periodLabel,
+ "PendingTasks": pendingTasks,
+ "OverdueTasks": overdueTasks,
+ "MyTasksURL": w.webBaseURL + "/my-tasks",
+ }
+
+ htmlBody, textBody, err := w.emailService.UseTemplate("workflow-task-digest", templateData)
+ if err != nil {
+ w.logger.Errorw("WorkflowTaskDigestWorker: failed to render template",
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return fmt.Errorf("failed to render workflow-task-digest template: %w", err)
+ }
+
+ message := &types.Message{
+ From: w.emailService.GetDefaultFromAddress(),
+ To: []string{user.Email},
+ Subject: fmt.Sprintf("Your workflow task summary — %s", formatDate(now)),
+ HTMLBody: htmlBody,
+ TextBody: textBody,
+ }
+
+ result, err := w.emailService.Send(ctx, message)
+ if err != nil {
+ w.logger.Errorw("WorkflowTaskDigestWorker: failed to send email",
+ "user_id", args.UserID,
+ "error", err,
+ )
+ return fmt.Errorf("failed to send workflow-task-digest email: %w", err)
+ }
+
+ if !result.Success {
+ w.logger.Errorw("WorkflowTaskDigestWorker: email send reported failure",
+ "user_id", args.UserID,
+ "error", result.Error,
+ )
+ return fmt.Errorf("workflow-task-digest email send failed: %s", result.Error)
+ }
+
+ w.logger.Infow("WorkflowTaskDigestWorker: digest email sent",
+ "user_id", args.UserID,
+ "pending", len(pendingTasks),
+ "overdue", len(overdueTasks),
+ "message_id", result.MessageID,
+ )
+
+ return nil
+}
+
+func buildDigestTask(step *workflows.StepExecution) DigestTask {
+ task := DigestTask{}
+ titles := resolveStepTitles(step)
+
+ task.StepTitle = titles.Step
+ task.WorkflowTitle = titles.Workflow
+ task.WorkflowInstanceTitle = titles.Instance
+ if step.DueDate != nil {
+ formatted := formatDate(*step.DueDate)
+ task.DueDate = &formatted
+ }
+
+ return task
+}
diff --git a/internal/service/worker/workflow_task_digest_worker_test.go b/internal/service/worker/workflow_task_digest_worker_test.go
new file mode 100644
index 00000000..34898ae9
--- /dev/null
+++ b/internal/service/worker/workflow_task_digest_worker_test.go
@@ -0,0 +1,53 @@
+package worker
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/riverqueue/river"
+ "github.com/stretchr/testify/assert"
+ "go.uber.org/zap"
+)
+
+func makeDigestJob(args WorkflowTaskDigestArgs) *river.Job[WorkflowTaskDigestArgs] {
+ return &river.Job[WorkflowTaskDigestArgs]{Args: args}
+}
+
+func TestWorkflowTaskDigestWorker_UnsubscribedUser_Skips(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ user := NotificationUser{
+ ID: "user-1",
+ Email: "alice@example.com",
+ FirstName: "Alice",
+ TaskDailyDigestSubscribed: false,
+ }
+ mockRepo.On("FindUserByID", ctx, "user-1").Return(user, nil)
+
+ w := NewWorkflowTaskDigestWorker(nil, mockEmail, mockRepo, "", mockLog)
+
+ err := w.Work(ctx, makeDigestJob(WorkflowTaskDigestArgs{UserID: "user-1"}))
+ assert.NoError(t, err)
+ mockEmail.AssertNotCalled(t, "Send")
+}
+
+func TestWorkflowTaskDigestWorker_UserNotFound_Skips(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ mockRepo.On("FindUserByID", ctx, "missing").Return(NotificationUser{}, errors.New("not found"))
+
+ w := NewWorkflowTaskDigestWorker(nil, mockEmail, mockRepo, "", mockLog)
+
+ err := w.Work(ctx, makeDigestJob(WorkflowTaskDigestArgs{UserID: "missing"}))
+ assert.NoError(t, err)
+ mockEmail.AssertNotCalled(t, "Send")
+}
diff --git a/internal/service/worker/workflow_task_due_soon_worker_test.go b/internal/service/worker/workflow_task_due_soon_worker_test.go
new file mode 100644
index 00000000..62d28550
--- /dev/null
+++ b/internal/service/worker/workflow_task_due_soon_worker_test.go
@@ -0,0 +1,142 @@
+package worker
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/compliance-framework/api/internal/service/email/types"
+ "github.com/riverqueue/river"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "go.uber.org/zap"
+)
+
+func makeDueSoonJob(args WorkflowTaskDueSoonArgs) *river.Job[WorkflowTaskDueSoonArgs] {
+ return &river.Job[WorkflowTaskDueSoonArgs]{Args: args}
+}
+
+func TestWorkflowTaskDueSoonWorker_SubscribedUser_SendsEmail(t *testing.T) {
+ ctx := context.Background()
+ dueDate := time.Now().Add(24 * time.Hour)
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ user := NotificationUser{
+ ID: "user-1",
+ Email: "alice@example.com",
+ FirstName: "Alice",
+ LastName: "Smith",
+ TaskAvailableEmailSubscribed: true,
+ }
+ mockRepo.On("FindUserByID", ctx, "user-1").Return(user, nil)
+ mockEmail.On("UseTemplate", "workflow-task-due-soon", mock.Anything).Return("Due Soon", "Due soon text", nil)
+ mockEmail.On("GetDefaultFromAddress").Return("noreply@example.com")
+ mockEmail.On("Send", ctx, mock.MatchedBy(func(msg *types.Message) bool {
+ return msg.To[0] == "alice@example.com" && msg.Subject != ""
+ })).Return(&types.SendResult{Success: true, MessageID: "msg-2"}, nil)
+
+ w := NewWorkflowTaskDueSoonWorker(mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ args := WorkflowTaskDueSoonArgs{
+ UserID: "user-1",
+ StepExecutionID: "step-1",
+ StepTitle: "Submit Evidence",
+ WorkflowTitle: "SOC2 Audit",
+ WorkflowInstanceTitle: "SOC2 2026",
+ StepURL: "https://app.example.com/steps/step-1",
+ DueDate: dueDate,
+ }
+
+ err := w.Work(ctx, makeDueSoonJob(args))
+ assert.NoError(t, err)
+ mockEmail.AssertExpectations(t)
+ mockRepo.AssertExpectations(t)
+}
+
+func TestWorkflowTaskDueSoonWorker_UnsubscribedUser_Skips(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ user := NotificationUser{
+ ID: "user-2",
+ Email: "bob@example.com",
+ FirstName: "Bob",
+ TaskAvailableEmailSubscribed: false,
+ }
+ mockRepo.On("FindUserByID", ctx, "user-2").Return(user, nil)
+
+ w := NewWorkflowTaskDueSoonWorker(mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ args := WorkflowTaskDueSoonArgs{
+ UserID: "user-2",
+ StepExecutionID: "step-2",
+ StepTitle: "Submit Evidence",
+ WorkflowTitle: "SOC2 Audit",
+ DueDate: time.Now().Add(24 * time.Hour),
+ }
+
+ err := w.Work(ctx, makeDueSoonJob(args))
+ assert.NoError(t, err)
+ mockEmail.AssertNotCalled(t, "Send")
+}
+
+func TestWorkflowTaskDueSoonWorker_UserNotFound_Skips(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ mockRepo.On("FindUserByID", ctx, "missing-user").Return(NotificationUser{}, errors.New("not found"))
+
+ w := NewWorkflowTaskDueSoonWorker(mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ args := WorkflowTaskDueSoonArgs{
+ UserID: "missing-user",
+ StepExecutionID: "step-3",
+ DueDate: time.Now().Add(24 * time.Hour),
+ }
+
+ err := w.Work(ctx, makeDueSoonJob(args))
+ assert.NoError(t, err)
+ mockEmail.AssertNotCalled(t, "Send")
+}
+
+func TestWorkflowTaskDueSoonWorker_TemplateError_ReturnsError(t *testing.T) {
+ ctx := context.Background()
+
+ mockEmail := &MockEmailService{}
+ mockRepo := &MockUserRepository{}
+ mockLog := zap.NewNop().Sugar()
+
+ user := NotificationUser{
+ ID: "user-3",
+ Email: "carol@example.com",
+ FirstName: "Carol",
+ TaskAvailableEmailSubscribed: true,
+ }
+ mockRepo.On("FindUserByID", ctx, "user-3").Return(user, nil)
+ mockEmail.On("UseTemplate", "workflow-task-due-soon", mock.Anything).Return("", "", errors.New("template broken"))
+ mockEmail.On("GetDefaultFromAddress").Return("noreply@example.com")
+
+ w := NewWorkflowTaskDueSoonWorker(mockEmail, mockRepo, "http://localhost:8000", mockLog)
+
+ args := WorkflowTaskDueSoonArgs{
+ UserID: "user-3",
+ StepExecutionID: "step-4",
+ StepTitle: "Submit Evidence",
+ WorkflowTitle: "SOC2 Audit",
+ DueDate: time.Now().Add(24 * time.Hour),
+ }
+
+ err := w.Work(ctx, makeDueSoonJob(args))
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "workflow-task-due-soon template")
+}
diff --git a/internal/workflow/assignment.go b/internal/workflow/assignment.go
index bb376880..b7567ed3 100644
--- a/internal/workflow/assignment.go
+++ b/internal/workflow/assignment.go
@@ -10,6 +10,7 @@ import (
"github.com/compliance-framework/api/internal/service/relational"
"github.com/compliance-framework/api/internal/service/relational/workflows"
"github.com/google/uuid"
+ "go.uber.org/zap"
"gorm.io/gorm"
)
@@ -22,14 +23,36 @@ type Assignee struct {
// AssignmentService handles logic for resolving step assignments
type AssignmentService struct {
	roleAssignmentService RoleAssignmentServiceInterface
	stepExecutionService  StepExecutionAssignmentService
	db                    *gorm.DB
	notificationEnqueuer  NotificationEnqueuer // Optional: for task-assigned notification emails
	logger                *zap.SugaredLogger
}

// StepExecutionAssignmentService abstracts the step-execution update used
// during reassignment so it can participate in a caller-supplied transaction.
type StepExecutionAssignmentService interface {
	// ReassignWithTx updates the assignment of the step execution identified
	// by id within tx — presumably the assigned_to_type/assigned_to_id and
	// assigned_at columns; confirm against the workflows implementation.
	ReassignWithTx(tx *gorm.DB, id *uuid.UUID, assignedToType, assignedToID string, assignedAt time.Time) error
}
// NewAssignmentService creates a new assignment service
-func NewAssignmentService(roleAssignmentService RoleAssignmentServiceInterface, db *gorm.DB) *AssignmentService {
+func NewAssignmentService(
+ roleAssignmentService RoleAssignmentServiceInterface,
+ stepExecutionService StepExecutionAssignmentService,
+ db *gorm.DB,
+ logger *zap.SugaredLogger,
+ notificationEnqueuer NotificationEnqueuer,
+) *AssignmentService {
+ if stepExecutionService == nil && db != nil {
+ stepExecutionService = workflows.NewStepExecutionService(db, nil)
+ }
+ if logger == nil {
+ logger = zap.NewNop().Sugar()
+ }
return &AssignmentService{
roleAssignmentService: roleAssignmentService,
+ stepExecutionService: stepExecutionService,
db: db,
+ logger: logger,
+ notificationEnqueuer: notificationEnqueuer,
}
}
@@ -86,40 +109,42 @@ func (s *AssignmentService) ReassignStep(
if s.db == nil {
return fmt.Errorf("assignment service database is not configured")
}
+ if s.stepExecutionService == nil {
+ return fmt.Errorf("assignment service step execution service is not configured")
+ }
if err := s.validateAssignee(newAssignee); err != nil {
return err
}
- return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
- var stepExecution workflows.StepExecution
- if err := tx.First(&stepExecution, stepExecutionID).Error; err != nil {
+ var updatedStepExecution workflows.StepExecution
+ if err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+ if err := tx.First(&updatedStepExecution, stepExecutionID).Error; err != nil {
return err
}
- if !isReassignableStatus(stepExecution.Status) {
- return fmt.Errorf("%w: current status is %s", ErrReassignmentNotAllowed, stepExecution.Status)
+ if !isReassignableStatus(updatedStepExecution.Status) {
+ return fmt.Errorf("%w: current status is %s", ErrReassignmentNotAllowed, updatedStepExecution.Status)
}
if err := s.validateAssigneeExists(tx, newAssignee); err != nil {
return err
}
+ prevType := updatedStepExecution.AssignedToType
+ prevID := updatedStepExecution.AssignedToID
+
now := time.Now()
- if err := tx.Model(&workflows.StepExecution{}).
- Where("id = ?", stepExecutionID).
- Updates(map[string]interface{}{
- "assigned_to_type": newAssignee.Type,
- "assigned_to_id": newAssignee.ID,
- "assigned_at": now,
- }).Error; err != nil {
+ if err := s.stepExecutionService.ReassignWithTx(tx, updatedStepExecution.ID, newAssignee.Type, newAssignee.ID, now); err != nil {
return err
}
+ updatedStepExecution.AssignedToType = newAssignee.Type
+ updatedStepExecution.AssignedToID = newAssignee.ID
history := &workflows.StepReassignmentHistory{
- StepExecutionID: stepExecution.ID,
- WorkflowExecutionID: stepExecution.WorkflowExecutionID,
- PreviousAssignedToType: stepExecution.AssignedToType,
- PreviousAssignedToID: stepExecution.AssignedToID,
+ StepExecutionID: updatedStepExecution.ID,
+ WorkflowExecutionID: updatedStepExecution.WorkflowExecutionID,
+ PreviousAssignedToType: prevType,
+ PreviousAssignedToID: prevID,
NewAssignedToType: newAssignee.Type,
NewAssignedToID: newAssignee.ID,
Reason: reason,
@@ -127,7 +152,19 @@ func (s *AssignmentService) ReassignStep(
ReassignedByEmail: reassignedByEmail,
}
return tx.Create(history).Error
- })
+ }); err != nil {
+ return err
+ }
+
+ if s.notificationEnqueuer != nil &&
+ (newAssignee.Type == workflows.AssignmentTypeUser.String() || newAssignee.Type == workflows.AssignmentTypeEmail.String()) {
+ if err := s.notificationEnqueuer.EnqueueWorkflowTaskAssigned(ctx, &updatedStepExecution); err != nil {
+ // Non-fatal: log but don't fail the reassignment
+ s.logger.Errorw("failed to enqueue workflow task assigned notification", "error", err)
+ }
+ }
+
+ return nil
}
func (s *AssignmentService) BulkReassignByRole(
@@ -142,6 +179,9 @@ func (s *AssignmentService) BulkReassignByRole(
if s.db == nil {
return nil, fmt.Errorf("assignment service database is not configured")
}
+ if s.stepExecutionService == nil {
+ return nil, fmt.Errorf("assignment service step execution service is not configured")
+ }
if roleName == "" {
return nil, fmt.Errorf("role name is required")
}
@@ -180,13 +220,7 @@ func (s *AssignmentService) BulkReassignByRole(
continue
}
- if err := tx.Model(&workflows.StepExecution{}).
- Where("id = ?", stepExecution.ID).
- Updates(map[string]interface{}{
- "assigned_to_type": newAssignee.Type,
- "assigned_to_id": newAssignee.ID,
- "assigned_at": now,
- }).Error; err != nil {
+ if err := s.stepExecutionService.ReassignWithTx(tx, stepExecution.ID, newAssignee.Type, newAssignee.ID, now); err != nil {
return err
}
@@ -215,6 +249,24 @@ func (s *AssignmentService) BulkReassignByRole(
return nil, err
}
+ if s.notificationEnqueuer != nil &&
+ (newAssignee.Type == workflows.AssignmentTypeUser.String() || newAssignee.Type == workflows.AssignmentTypeEmail.String()) &&
+ len(result.ReassignedStepExecIDs) > 0 {
+ var reassignedSteps []workflows.StepExecution
+ if err := s.db.WithContext(ctx).
+ Where("id IN ?", result.ReassignedStepExecIDs).
+ Find(&reassignedSteps).Error; err != nil {
+ s.logger.Errorw("failed to load reassigned steps for bulk notification enqueue", "error", err)
+ return result, nil
+ }
+
+ for i := range reassignedSteps {
+ if err := s.notificationEnqueuer.EnqueueWorkflowTaskAssigned(ctx, &reassignedSteps[i]); err != nil {
+ s.logger.Errorw("failed to enqueue workflow task assigned notification for bulk reassignment", "step_execution_id", reassignedSteps[i].ID, "error", err)
+ }
+ }
+ }
+
return result, nil
}
diff --git a/internal/workflow/assignment_test.go b/internal/workflow/assignment_test.go
index 2c290b33..5ec9cf80 100644
--- a/internal/workflow/assignment_test.go
+++ b/internal/workflow/assignment_test.go
@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
+ "go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
@@ -36,6 +37,20 @@ func (m *MockRoleAssignmentService) GetByWorkflowInstanceID(instanceID *uuid.UUI
return args.Get(0).([]workflows.RoleAssignment), args.Error(1)
}
+type MockAssignmentNotificationEnqueuer struct {
+ mock.Mock
+}
+
+func (m *MockAssignmentNotificationEnqueuer) EnqueueWorkflowTaskAssigned(ctx context.Context, stepExecution *workflows.StepExecution) error {
+ args := m.Called(ctx, stepExecution)
+ return args.Error(0)
+}
+
+func (m *MockAssignmentNotificationEnqueuer) EnqueueWorkflowExecutionFailed(ctx context.Context, execution *workflows.WorkflowExecution) error {
+ args := m.Called(ctx, execution)
+ return args.Error(0)
+}
+
func TestResolveStepAssignees(t *testing.T) {
instanceID := uuid.New()
instance := &workflows.WorkflowInstance{
@@ -62,7 +77,7 @@ func TestResolveStepAssignees(t *testing.T) {
}
mockRoleService := new(MockRoleAssignmentService)
- assignmentService := NewAssignmentService(mockRoleService, nil)
+ assignmentService := NewAssignmentService(mockRoleService, nil, nil, zap.NewNop().Sugar(), nil)
// Mock responses
// Step 1: Role "admin" -> User "user-1"
@@ -128,7 +143,7 @@ func TestResolveStepAssignees_NoRole(t *testing.T) {
}
mockRoleService := new(MockRoleAssignmentService)
- assignmentService := NewAssignmentService(mockRoleService, nil)
+ assignmentService := NewAssignmentService(mockRoleService, nil, nil, zap.NewNop().Sugar(), nil)
assignments, err := assignmentService.ResolveStepAssignees(context.Background(), instance, steps)
assert.NoError(t, err)
@@ -194,7 +209,7 @@ func createAssignmentServiceGraph(t *testing.T, db *gorm.DB) (*workflows.Workflo
func TestReassignStep(t *testing.T) {
db := setupAssignmentServiceTestDB(t)
roleService := new(MockRoleAssignmentService)
- service := NewAssignmentService(roleService, db)
+ service := NewAssignmentService(roleService, nil, db, zap.NewNop().Sugar(), nil)
_, _, stepExec := createAssignmentServiceGraph(t, db)
@@ -252,7 +267,7 @@ func TestReassignStep(t *testing.T) {
func TestReassignStep_RejectsInvalidStatus(t *testing.T) {
db := setupAssignmentServiceTestDB(t)
roleService := new(MockRoleAssignmentService)
- service := NewAssignmentService(roleService, db)
+ service := NewAssignmentService(roleService, nil, db, zap.NewNop().Sugar(), nil)
_, _, stepExec := createAssignmentServiceGraph(t, db)
@@ -277,7 +292,7 @@ func TestReassignStep_RejectsInvalidStatus(t *testing.T) {
func TestReassignStep_AllowsOverdueStatus(t *testing.T) {
db := setupAssignmentServiceTestDB(t)
roleService := new(MockRoleAssignmentService)
- service := NewAssignmentService(roleService, db)
+ service := NewAssignmentService(roleService, nil, db, zap.NewNop().Sugar(), nil)
_, _, stepExec := createAssignmentServiceGraph(t, db)
stepExec.Status = workflows.StepStatusOverdue.String()
@@ -302,7 +317,7 @@ func TestReassignStep_AllowsOverdueStatus(t *testing.T) {
func TestReassignStep_RejectsInvalidAssigneeAndMissingUser(t *testing.T) {
db := setupAssignmentServiceTestDB(t)
roleService := new(MockRoleAssignmentService)
- service := NewAssignmentService(roleService, db)
+ service := NewAssignmentService(roleService, nil, db, zap.NewNop().Sugar(), nil)
_, _, stepExec := createAssignmentServiceGraph(t, db)
@@ -333,7 +348,7 @@ func TestReassignStep_RejectsInvalidAssigneeAndMissingUser(t *testing.T) {
func TestBulkReassignByRole(t *testing.T) {
db := setupAssignmentServiceTestDB(t)
roleService := new(MockRoleAssignmentService)
- service := NewAssignmentService(roleService, db)
+ service := NewAssignmentService(roleService, nil, db, zap.NewNop().Sugar(), nil)
execution, stepDef, stepExec := createAssignmentServiceGraph(t, db)
@@ -408,3 +423,58 @@ func TestBulkReassignByRole(t *testing.T) {
require.NoError(t, db.First(&otherRole, otherRoleStep.ID).Error)
assert.Equal(t, "other-role", otherRole.AssignedToID)
}
+
+func TestBulkReassignByRole_EnqueuesNotificationsForReassignedSteps(t *testing.T) {
+ db := setupAssignmentServiceTestDB(t)
+ roleService := new(MockRoleAssignmentService)
+ notificationEnqueuer := new(MockAssignmentNotificationEnqueuer)
+ service := NewAssignmentService(roleService, nil, db, zap.NewNop().Sugar(), notificationEnqueuer)
+
+ execution, stepDef, stepExec := createAssignmentServiceGraph(t, db)
+
+ secondStepDef := &workflows.WorkflowStepDefinition{
+ WorkflowDefinitionID: stepDef.WorkflowDefinitionID,
+ Name: "Step 2",
+ ResponsibleRole: stepDef.ResponsibleRole,
+ }
+ require.NoError(t, db.Create(secondStepDef).Error)
+
+ secondStep := &workflows.StepExecution{
+ WorkflowExecutionID: execution.ID,
+ WorkflowStepDefinitionID: secondStepDef.ID,
+ Status: workflows.StepStatusInProgress.String(),
+ AssignedToType: workflows.AssignmentTypeUser.String(),
+ AssignedToID: "old-user",
+ }
+ require.NoError(t, db.Create(secondStep).Error)
+
+ newAssigneeID := uuid.New()
+ user := &relational.User{
+ UUIDModel: relational.UUIDModel{ID: &newAssigneeID},
+ Email: "bulk-notify@example.com",
+ FirstName: "Bulk",
+ LastName: "Notify",
+ IsActive: true,
+ AuthMethod: "local",
+ }
+ require.NoError(t, db.Create(user).Error)
+
+ notificationEnqueuer.On("EnqueueWorkflowTaskAssigned", mock.Anything, mock.MatchedBy(func(step *workflows.StepExecution) bool {
+ return step != nil && step.ID != nil && (*step.ID == *stepExec.ID || *step.ID == *secondStep.ID)
+ })).Return(nil).Twice()
+
+ result, err := service.BulkReassignByRole(
+ context.Background(),
+ *execution.ID,
+ stepDef.ResponsibleRole,
+ Assignee{Type: workflows.AssignmentTypeUser.String(), ID: newAssigneeID.String()},
+ "bulk handoff",
+ nil,
+ "actor@example.com",
+ )
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ assert.Equal(t, 2, result.ReassignedCount)
+
+ notificationEnqueuer.AssertExpectations(t)
+}
diff --git a/internal/workflow/constants.go b/internal/workflow/constants.go
index 0bc4c421..b8ec19af 100644
--- a/internal/workflow/constants.go
+++ b/internal/workflow/constants.go
@@ -1,19 +1,6 @@
package workflow
-import "time"
-
-// Execution timing constants
-const (
- DefaultExecutionTimeout = 30 * time.Second
- StepPollInterval = 100 * time.Millisecond
- StepSimulationTime = 100 * time.Millisecond
-)
-
-// Orphan step reasons
-const (
- OrphanReasonNoDependencies = "no_dependencies"
- OrphanReasonNoDependents = "no_dependents"
-)
+import "github.com/compliance-framework/api/internal/service/relational/workflows"
// StepStatus represents the status of a workflow step execution
// Note: This type mirrors workflows.StepExecutionStatus for use in the workflow orchestration layer.
@@ -33,16 +20,53 @@ const (
StatusCancelled StepStatus = "cancelled"
)
-// IsValid checks if the step status is valid
-func (s StepStatus) IsValid() bool {
- switch s {
- case StatusPending, StatusBlocked, StatusInProgress, StatusOverdue, StatusCompleted, StatusFailed, StatusSkipped, StatusCancelled:
- return true
- }
- return false
-}
-
// String returns the string representation of the status.
func (s StepStatus) String() string {
	return string(s)
}
+
+// StepStatusCounts holds a count of step executions per status.
+type StepStatusCounts struct {
+ Pending int
+ Blocked int
+ InProgress int
+ Overdue int
+ Completed int
+ Failed int
+ Skipped int
+ Cancelled int
+}
+
+// AllTerminal returns true when no step can make further progress, matching the semantics
+// of checkWorkflowCompletion: all steps must be in {completed, failed, skipped} — cancelled
+// steps are intentionally not treated as terminal here so that a partially-cancelled execution
+// is not incorrectly marked complete.
+func (c StepStatusCounts) AllTerminal() bool {
+ return c.Pending == 0 && c.Blocked == 0 && c.InProgress == 0 && c.Overdue == 0 && c.Cancelled == 0
+}
+
+// CountStepStatuses counts step executions by status.
+func CountStepStatuses(steps []workflows.StepExecution) StepStatusCounts {
+ var c StepStatusCounts
+ for _, s := range steps {
+ switch s.Status {
+ case StatusPending.String():
+ c.Pending++
+ case StatusBlocked.String():
+ c.Blocked++
+ case StatusInProgress.String():
+ c.InProgress++
+ case StatusOverdue.String():
+ c.Overdue++
+ case StatusCompleted.String():
+ c.Completed++
+ case StatusFailed.String():
+ c.Failed++
+ case StatusSkipped.String():
+ c.Skipped++
+ case StatusCancelled.String():
+ c.Cancelled++
+ }
+ }
+ return c
+}
diff --git a/internal/workflow/dag.go b/internal/workflow/dag.go
deleted file mode 100644
index 372c1ee6..00000000
--- a/internal/workflow/dag.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package workflow
-
-import (
- "github.com/google/uuid"
-)
-
-// DAGValidationResult contains the results of DAG validation
-type DAGValidationResult struct {
- IsValid bool
- Errors []string
- Warnings []string
- Cycles []DAGCycle
- OrphanedSteps []OrphanedStep
- Dependencies []DependencyIssue
-}
-
-// DAGCycle represents a detected cycle in the DAG
-type DAGCycle struct {
- Steps []string // Step names/IDs in cycle order
- Path string // Human-readable cycle path
-}
-
-// OrphanedStep represents a step with no dependencies or dependents
-type OrphanedStep struct {
- StepID uuid.UUID
- StepName string
- Reason string // "no_dependencies" or "no_dependents"
-}
-
-// DependencyIssue represents a dependency validation problem
-type DependencyIssue struct {
- FromStepID uuid.UUID
- FromStepName string
- ToStepID uuid.UUID
- ToStepName string
- Issue string
-}
diff --git a/internal/workflow/dag_helpers.go b/internal/workflow/dag_helpers.go
deleted file mode 100644
index a009e67c..00000000
--- a/internal/workflow/dag_helpers.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package workflow
-
-import (
- "fmt"
-
- "github.com/compliance-framework/api/internal/service/relational/workflows"
- "github.com/google/uuid"
-)
-
-// stepIndex provides O(1) lookup for steps by ID
-type stepIndex map[string]*workflows.WorkflowStepDefinition
-
-// buildStepIndex creates an index for fast step lookups
-func buildStepIndex(steps []workflows.WorkflowStepDefinition) stepIndex {
- index := make(stepIndex, len(steps))
- for i := range steps {
- index[steps[i].ID.String()] = &steps[i]
- }
- return index
-}
-
-// findStep retrieves a step by ID from the index
-func (idx stepIndex) findStep(id *uuid.UUID) (*workflows.WorkflowStepDefinition, error) {
- step, exists := idx[id.String()]
- if !exists {
- return nil, fmt.Errorf("step with id %s not found", id.String())
- }
- return step, nil
-}
-
-// validateStepsNotEmpty checks if steps slice is empty
-func validateStepsNotEmpty(steps []workflows.WorkflowStepDefinition) error {
- if len(steps) == 0 {
- return fmt.Errorf("workflow has no steps defined")
- }
- return nil
-}
diff --git a/internal/workflow/dag_test.go b/internal/workflow/dag_test.go
deleted file mode 100644
index 487dedc5..00000000
--- a/internal/workflow/dag_test.go
+++ /dev/null
@@ -1,167 +0,0 @@
-package workflow
-
-import (
- "testing"
-
- "github.com/compliance-framework/api/internal/service/relational/workflows"
- "github.com/google/uuid"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestValidateStepsNotEmpty(t *testing.T) {
- // Test empty steps
- err := validateStepsNotEmpty([]workflows.WorkflowStepDefinition{})
- require.Error(t, err)
- assert.Contains(t, err.Error(), "no steps defined")
-
- // Test non-empty steps
- stepA := workflows.WorkflowStepDefinition{Name: "Step A"}
- err = validateStepsNotEmpty([]workflows.WorkflowStepDefinition{stepA})
- require.NoError(t, err)
-}
-
-func TestBuildStepIndex(t *testing.T) {
- stepA := workflows.WorkflowStepDefinition{Name: "Step A"}
- stepB := workflows.WorkflowStepDefinition{Name: "Step B"}
-
- // Set the UUID IDs
- idA := uuid.New()
- idB := uuid.New()
- stepA.ID = &idA
- stepB.ID = &idB
-
- steps := []workflows.WorkflowStepDefinition{stepA, stepB}
-
- index := buildStepIndex(steps)
-
- // Test finding existing steps
- foundStep, err := index.findStep(stepA.ID)
- require.NoError(t, err)
- assert.Equal(t, stepA.Name, foundStep.Name)
-
- foundStep, err = index.findStep(stepB.ID)
- require.NoError(t, err)
- assert.Equal(t, stepB.Name, foundStep.Name)
-
- // Test finding non-existent step
- nonExistentID := uuid.New()
- _, err = index.findStep(&nonExistentID)
- require.Error(t, err)
- assert.Contains(t, err.Error(), "not found")
-}
-
-func TestCountDependents(t *testing.T) {
- stepA := workflows.WorkflowStepDefinition{Name: "Step A"}
- stepB := workflows.WorkflowStepDefinition{Name: "Step B"}
- stepC := workflows.WorkflowStepDefinition{Name: "Step C"}
-
- // Set the UUID IDs
- idA := uuid.New()
- idB := uuid.New()
- idC := uuid.New()
- stepA.ID = &idA
- stepB.ID = &idB
- stepC.ID = &idC
-
- steps := []workflows.WorkflowStepDefinition{stepA, stepB, stepC}
-
- // Create dependency map: A depends on B, B depends on C
- depMap := map[string][]workflows.WorkflowStepDefinition{
- stepA.ID.String(): {stepB},
- stepB.ID.String(): {stepC},
- stepC.ID.String(): {},
- }
-
- dependents := countDependents(steps, depMap)
-
- // C has 1 dependent (B)
- assert.Equal(t, 1, dependents[stepC.ID.String()])
- // B has 1 dependent (A)
- assert.Equal(t, 1, dependents[stepB.ID.String()])
- // A has 0 dependents
- assert.Equal(t, 0, dependents[stepA.ID.String()])
-}
-
-func TestDAGValidationResult_Structure(t *testing.T) {
- result := &DAGValidationResult{
- IsValid: false,
- Errors: []string{"test error"},
- Warnings: []string{"test warning"},
- Cycles: []DAGCycle{{Path: "A -> B -> A"}},
- OrphanedSteps: []OrphanedStep{{StepName: "Orphaned Step"}},
- Dependencies: []DependencyIssue{{Issue: "dependency issue"}},
- }
-
- assert.False(t, result.IsValid)
- assert.Len(t, result.Errors, 1)
- assert.Len(t, result.Warnings, 1)
- assert.Len(t, result.Cycles, 1)
- assert.Len(t, result.OrphanedSteps, 1)
- assert.Len(t, result.Dependencies, 1)
-}
-
-func TestDAGCycle_Structure(t *testing.T) {
- cycle := DAGCycle{
- Steps: []string{"A", "B", "C"},
- Path: "A -> B -> C -> A",
- }
-
- assert.Len(t, cycle.Steps, 3)
- assert.Equal(t, "A -> B -> C -> A", cycle.Path)
-}
-
-func TestOrphanedStep_Structure(t *testing.T) {
- stepID := uuid.New()
- orphaned := OrphanedStep{
- StepID: stepID,
- StepName: "Orphaned Step",
- Reason: "no_dependencies",
- }
-
- assert.Equal(t, stepID, orphaned.StepID)
- assert.Equal(t, "Orphaned Step", orphaned.StepName)
- assert.Equal(t, "no_dependencies", orphaned.Reason)
-}
-
-func TestDependencyIssue_Structure(t *testing.T) {
- fromID := uuid.New()
- toID := uuid.New()
-
- issue := DependencyIssue{
- FromStepID: fromID,
- FromStepName: "Step A",
- ToStepID: toID,
- ToStepName: "Step B",
- Issue: "test issue",
- }
-
- assert.Equal(t, fromID, issue.FromStepID)
- assert.Equal(t, "Step A", issue.FromStepName)
- assert.Equal(t, toID, issue.ToStepID)
- assert.Equal(t, "Step B", issue.ToStepName)
- assert.Equal(t, "test issue", issue.Issue)
-}
-
-// Benchmark test for performance validation
-func BenchmarkDAGValidation_ComplexWorkflow(b *testing.B) {
- // Create a complex workflow with many steps and dependencies
- steps := make([]workflows.WorkflowStepDefinition, 100)
- for i := 0; i < 100; i++ {
- steps[i] = workflows.WorkflowStepDefinition{Name: "Step " + string(rune(i))}
- }
-
- // Create dependency map
- depMap := make(map[string][]workflows.WorkflowStepDefinition)
- for i := 1; i < 100; i++ {
- depMap[steps[i].ID.String()] = []workflows.WorkflowStepDefinition{steps[i-1]}
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- dependents := countDependents(steps, depMap)
- if len(dependents) != 100 {
- b.Fatal("Unexpected dependents count")
- }
- }
-}
diff --git a/internal/workflow/evidence.go b/internal/workflow/evidence.go
index edd8c1ec..e24b05f9 100644
--- a/internal/workflow/evidence.go
+++ b/internal/workflow/evidence.go
@@ -49,11 +49,6 @@ func NewEvidenceIntegration(
}
}
-// SetWorkflowExecutionService sets the workflow execution service (to avoid circular dependency)
-func (e *EvidenceIntegration) SetWorkflowExecutionService(svc *workflows.WorkflowExecutionService) {
- e.workflowExecutionSvc = svc
-}
-
// GetOrCreateExecutionStream gets or creates the evidence stream for a workflow execution
// This stream accumulates all step completion evidence for this execution
func (e *EvidenceIntegration) GetOrCreateExecutionStream(ctx context.Context, workflowExecutionID *uuid.UUID) (*relational.Evidence, error) {
@@ -338,100 +333,6 @@ func (e *EvidenceIntegration) AddStepStartedEvidence(ctx context.Context, stepEx
return nil
}
-// AddStepCompletionEvidence adds a step completion evidence record to the execution stream
-// and links any user-submitted StepEvidence records to it
-func (e *EvidenceIntegration) AddStepCompletionEvidence(ctx context.Context, stepExecutionID *uuid.UUID) error {
- // Get step execution
- stepExecution, err := e.stepExecutionSvc.GetByID(stepExecutionID)
- if err != nil {
- return fmt.Errorf("failed to get step execution: %w", err)
- }
-
- // Get or create execution stream
- stream, err := e.GetOrCreateExecutionStream(ctx, stepExecution.WorkflowExecutionID)
- if err != nil {
- return fmt.Errorf("failed to get execution stream: %w", err)
- }
-
- // Get step definition
- stepDef, err := e.stepDefinitionSvc.GetByID(stepExecution.WorkflowStepDefinitionID)
- if err != nil {
- return fmt.Errorf("failed to get step definition: %w", err)
- }
-
- // Get user-submitted step evidence
- var stepEvidences []workflows.StepEvidence
- if err := e.db.Where("step_execution_id = ?", stepExecutionID).Find(&stepEvidences).Error; err != nil {
- e.logger.Warnw("Failed to get step evidence", "error", err)
- }
-
- // Create individual evidence record for this step completion
- description := fmt.Sprintf("Step '%s' completed\nStatus: %s\nStarted: %s\nCompleted: %s",
- stepDef.Name,
- stepExecution.Status,
- stepExecution.StartedAt.Format(time.RFC3339),
- stepExecution.CompletedAt.Format(time.RFC3339),
- )
-
- if len(stepEvidences) > 0 {
- description += fmt.Sprintf("\nEvidence Submitted: %d items", len(stepEvidences))
- }
-
- // Build links to user-submitted evidence
- var links []relational.Link
- for _, stepEvidence := range stepEvidences {
- links = append(links, relational.Link{
- Href: fmt.Sprintf("#/evidence/%s", stepEvidence.ID.String()),
- Rel: "related",
- Text: stepEvidence.Name,
- })
- }
-
- evidence := &relational.Evidence{
- UUID: stream.UUID, // Same stream UUID
- Title: fmt.Sprintf("Step Completion: %s", stepDef.Name),
- Description: description,
- Start: *stepExecution.StartedAt,
- End: *stepExecution.CompletedAt,
- }
-
- // Add links if we have any
- if len(links) > 0 {
- evidence.Links = links
- }
-
- // Generate unique ID for this evidence record
- id := uuid.New()
- evidence.ID = &id
-
- if err := e.db.Create(evidence).Error; err != nil {
- return fmt.Errorf("failed to create step evidence: %w", err)
- }
-
- // Add labels
- labels := []relational.Labels{
- {Name: "step.execution.id", Value: stepExecution.ID.String()},
- {Name: "step.definition.id", Value: stepDef.ID.String()},
- {Name: "step.name", Value: stepDef.Name},
- {Name: "step.status", Value: stepExecution.Status},
- {Name: "evidence.type", Value: "step_completion"},
- {Name: "evidence.submitted_count", Value: fmt.Sprintf("%d", len(stepEvidences))},
- }
-
- if err := e.db.Model(evidence).Association("Labels").Append(labels); err != nil {
- return fmt.Errorf("failed to add labels: %w", err)
- }
-
- e.logger.Infow("Step completion evidence added to stream",
- "stream_uuid", stream.UUID,
- "evidence_id", evidence.ID,
- "step_execution_id", stepExecutionID,
- "linked_evidence_count", len(stepEvidences),
- )
-
- return nil
-}
-
// AddExecutionCompletionEvidence adds an execution completion evidence record to the instance stream
func (e *EvidenceIntegration) AddExecutionCompletionEvidence(ctx context.Context, workflowExecutionID *uuid.UUID) error {
// Get workflow execution
diff --git a/internal/workflow/evidence_test.go b/internal/workflow/evidence_test.go
index 85b4492d..368eb2b9 100644
--- a/internal/workflow/evidence_test.go
+++ b/internal/workflow/evidence_test.go
@@ -279,118 +279,6 @@ func TestGetOrCreateInstanceStream(t *testing.T) {
})
}
-func TestAddStepCompletionEvidence(t *testing.T) {
- db := setupEvidenceTestDB(t)
- defer func() {
- sqlDB, _ := db.DB()
- err := sqlDB.Close()
- require.NoError(t, err)
- }()
-
- logger := zap.NewNop().Sugar()
- integration := NewEvidenceIntegration(db, logger)
- ctx := context.Background()
-
- definition, _, execution, _ := createTestWorkflowContext(t, db)
-
- // Create step definition
- stepDef := &workflows.WorkflowStepDefinition{
- WorkflowDefinitionID: definition.ID,
- Name: "Test Step",
- ResponsibleRole: "engineer",
- }
- require.NoError(t, db.Create(stepDef).Error)
-
- // Create step execution
- startTime := time.Now()
- completedTime := time.Now().Add(5 * time.Minute)
- stepExecution := &workflows.StepExecution{
- WorkflowExecutionID: execution.ID,
- WorkflowStepDefinitionID: stepDef.ID,
- Status: "completed",
- StartedAt: &startTime,
- CompletedAt: &completedTime,
- }
- require.NoError(t, db.Create(stepExecution).Error)
-
- t.Run("AddStepEvidence", func(t *testing.T) {
- err := integration.AddStepCompletionEvidence(ctx, stepExecution.ID)
- require.NoError(t, err)
-
- // Verify execution stream was created
- stream, err := integration.GetOrCreateExecutionStream(ctx, execution.ID)
- require.NoError(t, err)
-
- // Verify evidence record was created with same stream UUID
- var evidenceRecords []relational.Evidence
- err = db.Where("uuid = ?", stream.UUID).Find(&evidenceRecords).Error
- require.NoError(t, err)
- assert.GreaterOrEqual(t, len(evidenceRecords), 2) // Stream + step evidence
-
- // Find the step evidence record
- var stepEvidence *relational.Evidence
- for _, record := range evidenceRecords {
- if record.Title == "Step Completion: Test Step" {
- stepEvidence = &record
- break
- }
- }
- require.NotNil(t, stepEvidence, "Step evidence not found")
- assert.Contains(t, stepEvidence.Description, "Test Step")
- assert.Contains(t, stepEvidence.Description, "completed")
- assert.Equal(t, startTime.Unix(), stepEvidence.Start.Unix())
- assert.Equal(t, completedTime.Unix(), stepEvidence.End.Unix())
-
- // Verify labels
- var labels []relational.Labels
- err = db.Model(stepEvidence).Association("Labels").Find(&labels)
- require.NoError(t, err)
- assert.Greater(t, len(labels), 0)
-
- labelMap := make(map[string]string)
- for _, label := range labels {
- labelMap[label.Name] = label.Value
- }
- assert.Equal(t, "step_completion", labelMap["evidence.type"])
- assert.Equal(t, stepExecution.ID.String(), labelMap["step.execution.id"])
- assert.Equal(t, stepDef.ID.String(), labelMap["step.definition.id"])
- assert.Equal(t, "Test Step", labelMap["step.name"])
- })
-
- t.Run("MultipleStepsInSameStream", func(t *testing.T) {
- // Create another step execution
- stepDef2 := &workflows.WorkflowStepDefinition{
- WorkflowDefinitionID: definition.ID,
- Name: "Test Step 2",
- ResponsibleRole: "engineer",
- }
- require.NoError(t, db.Create(stepDef2).Error)
-
- startTime2 := time.Now()
- completedTime2 := time.Now().Add(3 * time.Minute)
- stepExecution2 := &workflows.StepExecution{
- WorkflowExecutionID: execution.ID,
- WorkflowStepDefinitionID: stepDef2.ID,
- Status: "completed",
- StartedAt: &startTime2,
- CompletedAt: &completedTime2,
- }
- require.NoError(t, db.Create(stepExecution2).Error)
-
- err := integration.AddStepCompletionEvidence(ctx, stepExecution2.ID)
- require.NoError(t, err)
-
- // Verify both evidence records share the same stream UUID
- stream, err := integration.GetOrCreateExecutionStream(ctx, execution.ID)
- require.NoError(t, err)
-
- var evidenceRecords []relational.Evidence
- err = db.Where("uuid = ?", stream.UUID).Find(&evidenceRecords).Error
- require.NoError(t, err)
- assert.GreaterOrEqual(t, len(evidenceRecords), 3) // Stream + 2 step evidences
- })
-}
-
func TestAddExecutionCompletionEvidence(t *testing.T) {
db := setupEvidenceTestDB(t)
defer func() {
diff --git a/internal/workflow/executor.go b/internal/workflow/executor.go
index 1428a3f3..77c8fbd9 100644
--- a/internal/workflow/executor.go
+++ b/internal/workflow/executor.go
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"log"
- "sync"
"time"
"github.com/compliance-framework/api/internal/config"
@@ -12,6 +11,13 @@ import (
"github.com/google/uuid"
)
+// NotificationEnqueuer is the minimal interface for enqueuing workflow notification jobs.
+// Implemented by the worker service to avoid a direct River dependency in this package.
+type NotificationEnqueuer interface {
+ EnqueueWorkflowTaskAssigned(ctx context.Context, stepExecution *workflows.StepExecution) error
+ EnqueueWorkflowExecutionFailed(ctx context.Context, execution *workflows.WorkflowExecution) error
+}
+
// DAGExecutor handles the execution of workflow DAGs with dependency resolution
// and parallel step execution capabilities
type DAGExecutor struct {
@@ -20,6 +26,7 @@ type DAGExecutor struct {
stepDefinitionService WorkflowStepDefinitionServiceInterface
assignmentService AssignmentServiceInterface
evidenceIntegration *EvidenceIntegration // Optional: for evidence stream integration
+ notificationEnqueuer NotificationEnqueuer // Optional: for workflow notification emails
logger *log.Logger
}
@@ -41,6 +48,7 @@ type WorkflowExecutionServiceInterface interface {
UpdateStatus(ctx context.Context, id *uuid.UUID, status string) error
Cancel(id *uuid.UUID) error
Fail(id *uuid.UUID, reason string) error
+ FailIfNotTerminal(ctx context.Context, id *uuid.UUID, reason string) (bool, error)
}
type WorkflowStepDefinitionServiceInterface interface {
@@ -68,6 +76,7 @@ func NewDAGExecutor(
stepDefinitionService WorkflowStepDefinitionServiceInterface,
assignmentService AssignmentServiceInterface,
logger *log.Logger,
+ notificationEnqueuer NotificationEnqueuer,
) *DAGExecutor {
if logger == nil {
logger = log.Default()
@@ -79,6 +88,7 @@ func NewDAGExecutor(
stepDefinitionService: stepDefinitionService,
assignmentService: assignmentService,
logger: logger,
+ notificationEnqueuer: notificationEnqueuer,
}
}
@@ -87,49 +97,6 @@ func (e *DAGExecutor) SetEvidenceIntegration(evidenceIntegration *EvidenceIntegr
e.evidenceIntegration = evidenceIntegration
}
-// ExecutionState tracks the current state of a workflow execution
-type ExecutionState struct {
- WorkflowExecutionID uuid.UUID
- StepStates map[uuid.UUID]*StepState
- CompletedSteps map[uuid.UUID]bool
- FailedSteps map[uuid.UUID]bool
- RunningSteps map[uuid.UUID]bool
- BlockedSteps map[uuid.UUID]bool
- mutex sync.RWMutex
-}
-
-// StepState represents the execution state of an individual step
-type StepState struct {
- StepDefinitionID uuid.UUID
- Status string
- StartedAt *time.Time
- CompletedAt *time.Time
- FailureReason string
- Dependencies []uuid.UUID
- Dependents []uuid.UUID
-}
-
-// ExecutionResult contains the results of a workflow execution
-type ExecutionResult struct {
- Success bool
- CompletedSteps int
- FailedSteps int
- TotalSteps int
- ExecutionTime time.Duration
- Errors []string
- StepResults map[uuid.UUID]*StepExecutionResult
-}
-
-// StepExecutionResult contains the result of an individual step execution
-type StepExecutionResult struct {
- StepDefinitionID uuid.UUID
- Success bool
- Status string
- StartedAt time.Time
- CompletedAt time.Time
- FailureReason string
-}
-
// InitializeWorkflow initializes a workflow execution by creating step execution records
// and setting up the initial DAG state (blocked/pending based on dependencies)
func (e *DAGExecutor) InitializeWorkflow(ctx context.Context, workflowExecutionID *uuid.UUID) error {
@@ -229,10 +196,20 @@ func (e *DAGExecutor) InitializeWorkflow(ctx context.Context, workflowExecutionI
if err := e.stepExecutionService.Create(stepExecution); err != nil {
return fmt.Errorf("failed to create step execution for step %s: %w", stepDef.ID.String(), err)
}
+
+ // Enqueue task-assigned notification for user or email-assigned, actionable steps
+ if e.notificationEnqueuer != nil &&
+ (stepExecution.AssignedToType == workflows.AssignmentTypeUser.String() ||
+ stepExecution.AssignedToType == workflows.AssignmentTypeEmail.String()) &&
+ initialStatus == StatusPending.String() {
+ if err := e.notificationEnqueuer.EnqueueWorkflowTaskAssigned(ctx, stepExecution); err != nil {
+ e.logger.Printf("Warning: failed to enqueue task-assigned notification for step %s: %v", stepDef.ID.String(), err)
+ }
+ }
}
if len(retryCompletedSteps) > 0 {
- if err := e.unblockReadySteps(workflowExecutionID); err != nil {
+ if err := e.unblockReadySteps(ctx, workflowExecutionID); err != nil {
e.logger.Printf("Warning: failed to unblock ready steps: %v", err)
}
}
@@ -288,56 +265,17 @@ func (e *DAGExecutor) resolveRetryCompletedStepDefinitions(workflowExecution *wo
return completed
}
-func (e *DAGExecutor) unblockReadySteps(workflowExecutionID *uuid.UUID) error {
+func (e *DAGExecutor) unblockReadySteps(ctx context.Context, workflowExecutionID *uuid.UUID) error {
stepExecutions, err := e.stepExecutionService.GetByWorkflowExecutionID(workflowExecutionID)
if err != nil {
return err
}
for i := range stepExecutions {
- _ = e.tryUnblockStep(&stepExecutions[i])
+ _ = e.tryUnblockStep(ctx, &stepExecutions[i])
}
return nil
}
-// initializeExecutionState creates and initializes the execution state for a workflow
-func (e *DAGExecutor) initializeExecutionState(workflowExecutionID *uuid.UUID, stepDefinitions []workflows.WorkflowStepDefinition) *ExecutionState {
- state := &ExecutionState{
- WorkflowExecutionID: *workflowExecutionID,
- StepStates: make(map[uuid.UUID]*StepState),
- CompletedSteps: make(map[uuid.UUID]bool),
- FailedSteps: make(map[uuid.UUID]bool),
- RunningSteps: make(map[uuid.UUID]bool),
- BlockedSteps: make(map[uuid.UUID]bool),
- }
-
- // Initialize step states
- for _, stepDef := range stepDefinitions {
- // Get dependencies for this step
- dependencies, _ := e.stepDefinitionService.GetDependencies(stepDef.ID)
-
- // Convert dependencies to UUID slice
- depIDs := make([]uuid.UUID, len(dependencies))
- for i, dep := range dependencies {
- depIDs[i] = *dep.ID
- }
-
- stepState := &StepState{
- StepDefinitionID: *stepDef.ID,
- Status: StatusPending.String(),
- Dependencies: depIDs,
- }
-
- state.StepStates[*stepDef.ID] = stepState
-
- // Initially block steps that have dependencies
- if len(depIDs) > 0 {
- state.BlockedSteps[*stepDef.ID] = true
- }
- }
-
- return state
-}
-
// ProcessStepCompletion processes a step completion and unblocks dependent steps
// This is called after a user manually completes a step
func (e *DAGExecutor) ProcessStepCompletion(ctx context.Context, stepExecutionID *uuid.UUID) error {
@@ -356,7 +294,7 @@ func (e *DAGExecutor) ProcessStepCompletion(ctx context.Context, stepExecutionID
}
// Unblock dependent steps that are now ready
- unblockedCount := e.unblockDependentSteps(stepExecution, dependentSteps)
+ unblockedCount := e.unblockDependentSteps(ctx, stepExecution, dependentSteps)
e.logger.Printf("Unblocked %d dependent steps", unblockedCount)
// Check if workflow is complete
@@ -414,7 +352,7 @@ func (e *DAGExecutor) ProcessStepFailure(ctx context.Context, stepExecutionID *u
}
// unblockDependentSteps processes dependent steps and unblocks those that are ready
-func (e *DAGExecutor) unblockDependentSteps(stepExecution *workflows.StepExecution, dependentSteps []workflows.WorkflowStepDefinition) int {
+func (e *DAGExecutor) unblockDependentSteps(ctx context.Context, stepExecution *workflows.StepExecution, dependentSteps []workflows.WorkflowStepDefinition) int {
unblockedCount := 0
for _, dependentStepDef := range dependentSteps {
@@ -423,7 +361,7 @@ func (e *DAGExecutor) unblockDependentSteps(stepExecution *workflows.StepExecuti
continue
}
- if e.tryUnblockStep(dependentStepExec) {
+ if e.tryUnblockStep(ctx, dependentStepExec) {
unblockedCount++
}
}
@@ -450,7 +388,7 @@ func (e *DAGExecutor) findDependentStepExecution(workflowExecutionID, stepDefini
}
// tryUnblockStep attempts to unblock a step if it's blocked and all dependencies are satisfied
-func (e *DAGExecutor) tryUnblockStep(stepExec *workflows.StepExecution) bool {
+func (e *DAGExecutor) tryUnblockStep(ctx context.Context, stepExec *workflows.StepExecution) bool {
if stepExec.Status != StatusBlocked.String() {
return false
}
@@ -471,56 +409,19 @@ func (e *DAGExecutor) tryUnblockStep(stepExec *workflows.StepExecution) bool {
}
e.logger.Printf("Unblocked step: %s", stepExec.ID.String())
- // TODO: Hook for notification - step is now ready for user action
- return true
-}
-
-// getReadySteps returns steps that are ready to be executed (dependencies satisfied)
-// NOTE: This method uses in-memory ExecutionState and is NOT used in the runtime execution path.
-// Runtime execution uses database-backed CanUnblock() in ProcessStepCompletion.
-// This method is kept for testing and benchmarking purposes only.
-func (e *DAGExecutor) getReadySteps(state *ExecutionState) []uuid.UUID {
- state.mutex.RLock()
- defer state.mutex.RUnlock()
-
- var readySteps []uuid.UUID
- for stepID, stepState := range state.StepStates {
- // Skip steps that are already completed, failed, or running
- if state.CompletedSteps[stepID] || state.FailedSteps[stepID] || state.RunningSteps[stepID] {
- continue
- }
-
- // Check if all dependencies are completed
- if e.areDependenciesCompleted(state, stepState.Dependencies) {
- readySteps = append(readySteps, stepID)
+ if e.notificationEnqueuer != nil {
+ reloaded, err := e.stepExecutionService.GetByID(stepExec.ID)
+ if err == nil && reloaded != nil {
+ if err := e.notificationEnqueuer.EnqueueWorkflowTaskAssigned(ctx, reloaded); err != nil {
+ e.logger.Printf("Warning: failed to enqueue task-assigned notification for step %s: %v", stepExec.ID.String(), err)
+ }
}
}
- return readySteps
-}
-
-// areDependenciesCompleted checks if all dependencies for a step are completed
-func (e *DAGExecutor) areDependenciesCompleted(state *ExecutionState, dependencies []uuid.UUID) bool {
- for _, depID := range dependencies {
- if !state.CompletedSteps[depID] {
- return false
- }
- }
return true
}
-// isExecutionComplete checks if all steps are completed or failed
-func (e *DAGExecutor) isExecutionComplete(state *ExecutionState) bool {
- state.mutex.RLock()
- defer state.mutex.RUnlock()
-
- totalSteps := len(state.StepStates)
- completedOrFailed := len(state.CompletedSteps) + len(state.FailedSteps)
-
- return completedOrFailed >= totalSteps
-}
-
// checkWorkflowCompletion checks if all steps are complete and updates workflow status
func (e *DAGExecutor) checkWorkflowCompletion(ctx context.Context, workflowExecutionID *uuid.UUID) error {
// Get all step executions
@@ -533,29 +434,31 @@ func (e *DAGExecutor) checkWorkflowCompletion(ctx context.Context, workflowExecu
return nil
}
- completedCount := 0
- skippedCount := 0
- failedCount := 0
- for _, stepExec := range stepExecutions {
- switch stepExec.Status {
- case StatusCompleted.String():
- completedCount++
- case StatusSkipped.String():
- skippedCount++
- case StatusFailed.String():
- failedCount++
- }
- }
+ counts := CountStepStatuses(stepExecutions)
// Check if all steps are in terminal states
- if completedCount+failedCount+skippedCount == len(stepExecutions) {
- if failedCount > 0 {
- // Workflow failed
- reason := fmt.Sprintf("%d of %d steps failed", failedCount, len(stepExecutions))
- if err := e.workflowExecutionService.Fail(workflowExecutionID, reason); err != nil {
+ if counts.AllTerminal() {
+ if counts.Failed > 0 {
+ // Workflow failed — use FailIfNotTerminal so that concurrent completions
+ // from two step-failure HTTP requests only enqueue the failure email once.
+ reason := fmt.Sprintf("%d of %d steps failed", counts.Failed, len(stepExecutions))
+ updated, err := e.workflowExecutionService.FailIfNotTerminal(ctx, workflowExecutionID, reason)
+ if err != nil {
return fmt.Errorf("failed to mark workflow as failed: %w", err)
}
e.logger.Printf("Workflow execution failed: %s", reason)
+
+ // Only enqueue the notification if this goroutine was the one that performed the update.
+ if updated && e.notificationEnqueuer != nil {
+ execution, execErr := e.workflowExecutionService.GetByID(workflowExecutionID)
+ if execErr == nil {
+ if notifyErr := e.notificationEnqueuer.EnqueueWorkflowExecutionFailed(ctx, execution); notifyErr != nil {
+ e.logger.Printf("Failed to enqueue workflow-execution-failed notification: %v", notifyErr)
+ }
+ } else {
+ e.logger.Printf("Failed to reload execution for failure notification: %v", execErr)
+ }
+ }
} else {
// All steps reached successful terminal states (completed/skipped)
if err := e.workflowExecutionService.UpdateStatus(ctx, workflowExecutionID, StatusCompleted.String()); err != nil {
@@ -580,16 +483,9 @@ func resolveStepGraceDays(workflowExecution *workflows.WorkflowExecution, stepDe
if stepDef.GracePeriodDays != nil {
return *stepDef.GracePeriodDays
}
- if workflowExecution.WorkflowInstance != nil && workflowExecution.WorkflowInstance.GracePeriodDays != nil {
- return *workflowExecution.WorkflowInstance.GracePeriodDays
- }
- if workflowExecution.WorkflowInstance != nil && workflowExecution.WorkflowInstance.WorkflowDefinition != nil &&
- workflowExecution.WorkflowInstance.WorkflowDefinition.GracePeriodDays != nil {
- return *workflowExecution.WorkflowInstance.WorkflowDefinition.GracePeriodDays
- }
// Step due-date initialization uses global workflow defaults from config.
// Execution failure grace in OverdueService can use an injected default to support worker-level overrides.
- return config.DefaultWorkflowConfig().GracePeriodDays
+ return ResolveGraceDays(workflowExecution.WorkflowInstance, config.DefaultWorkflowConfig().GracePeriodDays)
}
// CheckAutomaticTriggers checks if a step has automatic triggers configured
@@ -601,74 +497,3 @@ func (e *DAGExecutor) CheckAutomaticTriggers(ctx context.Context, stepExecutionI
e.logger.Printf("Checking automatic triggers for step: %s (not yet implemented)", stepExecutionID.String())
return nil
}
-
-// GetExecutionStatus returns the current status of a workflow execution
-func (e *DAGExecutor) GetExecutionStatus(workflowExecutionID *uuid.UUID) (*ExecutionState, error) {
- // Get all step executions for this workflow
- stepExecutions, err := e.stepExecutionService.GetByWorkflowExecutionID(workflowExecutionID)
- if err != nil {
- return nil, fmt.Errorf("failed to get step executions: %w", err)
- }
-
- // Build execution state from step executions
- state := &ExecutionState{
- WorkflowExecutionID: *workflowExecutionID,
- StepStates: make(map[uuid.UUID]*StepState),
- CompletedSteps: make(map[uuid.UUID]bool),
- FailedSteps: make(map[uuid.UUID]bool),
- RunningSteps: make(map[uuid.UUID]bool),
- BlockedSteps: make(map[uuid.UUID]bool),
- }
-
- for _, stepExec := range stepExecutions {
- stepState := &StepState{
- StepDefinitionID: *stepExec.WorkflowStepDefinitionID,
- Status: stepExec.Status,
- StartedAt: stepExec.StartedAt,
- CompletedAt: stepExec.CompletedAt,
- FailureReason: stepExec.FailureReason,
- }
-
- state.StepStates[*stepExec.WorkflowStepDefinitionID] = stepState
-
- switch stepExec.Status {
- case StatusCompleted.String():
- state.CompletedSteps[*stepExec.WorkflowStepDefinitionID] = true
- case StatusFailed.String():
- state.FailedSteps[*stepExec.WorkflowStepDefinitionID] = true
- case StatusInProgress.String():
- state.RunningSteps[*stepExec.WorkflowStepDefinitionID] = true
- case StatusBlocked.String():
- state.BlockedSteps[*stepExec.WorkflowStepDefinitionID] = true
- }
- }
-
- return state, nil
-}
-
-// CancelExecution cancels a running workflow execution
-func (e *DAGExecutor) CancelExecution(ctx context.Context, workflowExecutionID *uuid.UUID) error {
- e.logger.Printf("Cancelling workflow execution: %s", workflowExecutionID.String())
-
- // Update workflow execution status to cancelled
- if err := e.workflowExecutionService.UpdateStatus(ctx, workflowExecutionID, StatusCancelled.String()); err != nil {
- return fmt.Errorf("failed to update workflow execution status: %w", err)
- }
-
- // Cancel all running step executions
- stepExecutions, err := e.stepExecutionService.GetByWorkflowExecutionID(workflowExecutionID)
- if err != nil {
- return fmt.Errorf("failed to get step executions: %w", err)
- }
-
- for _, stepExec := range stepExecutions {
- if stepExec.Status == StatusInProgress.String() {
- if err := e.stepExecutionService.UpdateStatus(ctx, stepExec.ID, StatusCancelled.String()); err != nil {
- e.logger.Printf("Failed to cancel step execution %s: %v", stepExec.ID.String(), err)
- }
- }
- }
-
- e.logger.Printf("Workflow execution cancelled: %s", workflowExecutionID.String())
- return nil
-}
diff --git a/internal/workflow/executor_integration_test.go b/internal/workflow/executor_integration_test.go
index b5d72675..d7421920 100644
--- a/internal/workflow/executor_integration_test.go
+++ b/internal/workflow/executor_integration_test.go
@@ -13,6 +13,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -174,11 +175,11 @@ func TestDAGExecutor_Integration_InitializeWorkflow(t *testing.T) {
workflowExecService := workflows.NewWorkflowExecutionService(db)
stepDefService := workflows.NewWorkflowStepDefinitionService(db)
roleAssignmentService := workflows.NewRoleAssignmentService(db)
- assignmentService := NewAssignmentService(roleAssignmentService, db)
+ assignmentService := NewAssignmentService(roleAssignmentService, stepExecService, db, zap.NewNop().Sugar(), nil)
// Create executor
logger := log.New(os.Stdout, "[TEST] ", log.LstdFlags)
- executor := NewDAGExecutor(stepExecService, workflowExecService, stepDefService, assignmentService, logger)
+ executor := NewDAGExecutor(stepExecService, workflowExecService, stepDefService, assignmentService, logger, nil)
// Create test workflow
workflowDef, stepDefs := createTestWorkflow(t, db)
@@ -236,11 +237,11 @@ func TestDAGExecutor_Integration_ProcessStepCompletion(t *testing.T) {
workflowExecService := workflows.NewWorkflowExecutionService(db)
stepDefService := workflows.NewWorkflowStepDefinitionService(db)
roleAssignmentService := workflows.NewRoleAssignmentService(db)
- assignmentService := NewAssignmentService(roleAssignmentService, db)
+ assignmentService := NewAssignmentService(roleAssignmentService, stepExecService, db, zap.NewNop().Sugar(), nil)
// Create executor
logger := log.New(os.Stdout, "[TEST] ", log.LstdFlags)
- executor := NewDAGExecutor(stepExecService, workflowExecService, stepDefService, assignmentService, logger)
+ executor := NewDAGExecutor(stepExecService, workflowExecService, stepDefService, assignmentService, logger, nil)
// Create test workflow
workflowDef, stepDefs := createTestWorkflow(t, db)
@@ -282,65 +283,6 @@ func TestDAGExecutor_Integration_ProcessStepCompletion(t *testing.T) {
assert.Equal(t, "blocked", step3Exec.Status)
}
-func TestDAGExecutor_Integration_GetExecutionStatus(t *testing.T) {
- if testing.Short() {
- t.Skip("Skipping integration test in short mode")
- }
-
- // Setup test database
- db := setupTestDB(t)
- defer func() {
- sqlDB, _ := db.DB()
- sqlDB.Close()
- }()
-
- // Create services
- stepExecService := workflows.NewStepExecutionService(db, nil)
- workflowExecService := workflows.NewWorkflowExecutionService(db)
- stepDefService := workflows.NewWorkflowStepDefinitionService(db)
- roleAssignmentService := workflows.NewRoleAssignmentService(db)
- assignmentService := NewAssignmentService(roleAssignmentService, db)
-
- // Create executor
- logger := log.New(os.Stdout, "[TEST] ", log.LstdFlags)
- executor := NewDAGExecutor(stepExecService, workflowExecService, stepDefService, assignmentService, logger)
-
- // Create test workflow
- workflowDef, stepDefs := createTestWorkflow(t, db)
- instance := createTestWorkflowInstance(t, db, workflowDef)
- execution := createTestWorkflowExecution(t, db, instance)
-
- // Get execution status before initialization
- state, err := executor.GetExecutionStatus(execution.ID)
- require.NoError(t, err)
- assert.Equal(t, *execution.ID, state.WorkflowExecutionID)
- assert.Len(t, state.StepStates, 0) // No step executions yet
-
- // Initialize workflow
- ctx := context.Background()
- err = executor.InitializeWorkflow(ctx, execution.ID)
- require.NoError(t, err)
-
- // Get execution status after initialization
- state, err = executor.GetExecutionStatus(execution.ID)
- require.NoError(t, err)
- assert.Len(t, state.StepStates, 3)
- assert.Len(t, state.CompletedSteps, 0)
- assert.Len(t, state.FailedSteps, 0)
- assert.Len(t, state.BlockedSteps, 2) // Steps 2 and 3 are blocked
-
- // Verify step statuses
- for i, stepDef := range stepDefs {
- stepState, exists := state.StepStates[*stepDef.ID]
- require.True(t, exists, "Step state not found for step %s", stepDef.Name)
- if i == 0 {
- assert.Equal(t, "pending", stepState.Status)
- } else {
- assert.Equal(t, "blocked", stepState.Status)
- }
- }
-}
-
func TestDAGExecutor_Integration_ParallelSteps(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
@@ -358,11 +300,11 @@ func TestDAGExecutor_Integration_ParallelSteps(t *testing.T) {
workflowExecService := workflows.NewWorkflowExecutionService(db)
stepDefService := workflows.NewWorkflowStepDefinitionService(db)
roleAssignmentService := workflows.NewRoleAssignmentService(db)
- assignmentService := NewAssignmentService(roleAssignmentService, db)
+ assignmentService := NewAssignmentService(roleAssignmentService, stepExecService, db, zap.NewNop().Sugar(), nil)
// Create executor
logger := log.New(os.Stdout, "[TEST] ", log.LstdFlags)
- executor := NewDAGExecutor(stepExecService, workflowExecService, stepDefService, assignmentService, logger)
+ executor := NewDAGExecutor(stepExecService, workflowExecService, stepDefService, assignmentService, logger, nil)
// Create workflow with parallel steps (no dependencies)
workflowDefID := uuid.New()
diff --git a/internal/workflow/executor_test.go b/internal/workflow/executor_test.go
index 020d68b4..e2454722 100644
--- a/internal/workflow/executor_test.go
+++ b/internal/workflow/executor_test.go
@@ -6,7 +6,6 @@ import (
"fmt"
"log"
"testing"
- "time"
"github.com/compliance-framework/api/internal/service/relational"
"github.com/compliance-framework/api/internal/service/relational/workflows"
@@ -94,6 +93,11 @@ func (m *MockWorkflowExecutionService) Fail(id *uuid.UUID, reason string) error
return args.Error(0)
}
+func (m *MockWorkflowExecutionService) FailIfNotTerminal(ctx context.Context, id *uuid.UUID, reason string) (bool, error) {
+ args := m.Called(ctx, id, reason)
+ return args.Bool(0), args.Error(1)
+}
+
// MockWorkflowStepDefinitionService is a mock for workflows.WorkflowStepDefinitionService
type MockWorkflowStepDefinitionService struct {
mock.Mock
@@ -138,6 +142,20 @@ func (m *MockAssignmentService) ResolveStepAssignees(ctx context.Context, instan
return args.Get(0).(map[uuid.UUID]Assignee), args.Error(1)
}
+type MockNotificationEnqueuer struct {
+ mock.Mock
+}
+
+func (m *MockNotificationEnqueuer) EnqueueWorkflowTaskAssigned(ctx context.Context, stepExecution *workflows.StepExecution) error {
+ args := m.Called(ctx, stepExecution)
+ return args.Error(0)
+}
+
+func (m *MockNotificationEnqueuer) EnqueueWorkflowExecutionFailed(ctx context.Context, execution *workflows.WorkflowExecution) error {
+ args := m.Called(ctx, execution)
+ return args.Error(0)
+}
+
func TestNewDAGExecutor(t *testing.T) {
mockStepExecService := &MockStepExecutionService{}
mockWorkflowExecService := &MockWorkflowExecutionService{}
@@ -152,6 +170,7 @@ func TestNewDAGExecutor(t *testing.T) {
mockStepDefService,
mockAssignmentService,
logger,
+ nil,
)
assert.NotNil(t, executor)
@@ -162,166 +181,6 @@ func TestNewDAGExecutor(t *testing.T) {
assert.Equal(t, logger, executor.logger)
}
-func TestInitializeExecutionState(t *testing.T) {
- mockStepExecService := &MockStepExecutionService{}
- mockWorkflowExecService := &MockWorkflowExecutionService{}
- mockStepDefService := &MockWorkflowStepDefinitionService{}
- mockAssignmentService := &MockAssignmentService{}
- logger := log.New(bytes.NewBufferString(""), "", log.LstdFlags)
-
- executor := NewDAGExecutor(mockStepExecService, mockWorkflowExecService, mockStepDefService, mockAssignmentService, logger)
-
- // Create test step definitions
- stepDefID1 := uuid.New()
- stepDefID2 := uuid.New()
- stepDefID3 := uuid.New()
-
- stepDefinitions := []workflows.WorkflowStepDefinition{
- {UUIDModel: relational.UUIDModel{ID: &stepDefID1}, Name: "Step 1"},
- {UUIDModel: relational.UUIDModel{ID: &stepDefID2}, Name: "Step 2"},
- {UUIDModel: relational.UUIDModel{ID: &stepDefID3}, Name: "Step 3"},
- }
-
- // Mock dependencies: Step 2 depends on Step 1, Step 3 depends on Step 2
- mockStepDefService.On("GetDependencies", &stepDefID1).Return([]workflows.WorkflowStepDefinition{}, nil)
- mockStepDefService.On("GetDependencies", &stepDefID2).Return([]workflows.WorkflowStepDefinition{workflows.WorkflowStepDefinition{UUIDModel: relational.UUIDModel{ID: &stepDefID1}}}, nil)
- mockStepDefService.On("GetDependencies", &stepDefID3).Return([]workflows.WorkflowStepDefinition{workflows.WorkflowStepDefinition{UUIDModel: relational.UUIDModel{ID: &stepDefID2}}}, nil)
-
- workflowExecutionID := uuid.New()
- state := executor.initializeExecutionState(&workflowExecutionID, stepDefinitions)
-
- // Verify state initialization
- assert.Equal(t, workflowExecutionID, state.WorkflowExecutionID)
- assert.Len(t, state.StepStates, 3)
- assert.Len(t, state.BlockedSteps, 2) // Steps 2 and 3 should be blocked
- assert.Contains(t, state.BlockedSteps, stepDefID2)
- assert.Contains(t, state.BlockedSteps, stepDefID3)
- assert.NotContains(t, state.BlockedSteps, stepDefID1) // Step 1 has no dependencies
-
- mockStepDefService.AssertExpectations(t)
-}
-
-func TestGetReadySteps(t *testing.T) {
- executor := createTestExecutor(t)
-
- // Create test state
- workflowExecutionID := uuid.New()
- state := &ExecutionState{
- WorkflowExecutionID: workflowExecutionID,
- StepStates: make(map[uuid.UUID]*StepState),
- CompletedSteps: make(map[uuid.UUID]bool),
- FailedSteps: make(map[uuid.UUID]bool),
- RunningSteps: make(map[uuid.UUID]bool),
- BlockedSteps: make(map[uuid.UUID]bool),
- }
-
- // Create test steps
- stepID1 := uuid.New()
- stepID2 := uuid.New()
- stepID3 := uuid.New()
-
- // Step 1: no dependencies, should be ready
- state.StepStates[stepID1] = &StepState{
- StepDefinitionID: stepID1,
- Status: "pending",
- Dependencies: []uuid.UUID{},
- }
-
- // Step 2: depends on step 1, not ready yet
- state.StepStates[stepID2] = &StepState{
- StepDefinitionID: stepID2,
- Status: "pending",
- Dependencies: []uuid.UUID{stepID1},
- }
-
- // Step 3: depends on step 2, not ready yet
- state.StepStates[stepID3] = &StepState{
- StepDefinitionID: stepID3,
- Status: "pending",
- Dependencies: []uuid.UUID{stepID2},
- }
-
- // Initially, only step 1 should be ready
- readySteps := executor.getReadySteps(state)
- assert.Len(t, readySteps, 1)
- assert.Contains(t, readySteps, stepID1)
-
- // Mark step 1 as completed
- state.CompletedSteps[stepID1] = true
-
- // Now step 2 should be ready
- readySteps = executor.getReadySteps(state)
- assert.Len(t, readySteps, 1)
- assert.Contains(t, readySteps, stepID2)
-
- // Mark step 2 as completed
- state.CompletedSteps[stepID2] = true
-
- // Now step 3 should be ready
- readySteps = executor.getReadySteps(state)
- assert.Len(t, readySteps, 1)
- assert.Contains(t, readySteps, stepID3)
-}
-
-func TestAreDependenciesCompleted(t *testing.T) {
- executor := createTestExecutor(t)
-
- state := &ExecutionState{
- CompletedSteps: make(map[uuid.UUID]bool),
- }
-
- stepID1 := uuid.New()
- stepID2 := uuid.New()
- stepID3 := uuid.New()
-
- // Test with no dependencies
- dependencies := []uuid.UUID{}
- assert.True(t, executor.areDependenciesCompleted(state, dependencies))
-
- // Test with completed dependencies
- dependencies = []uuid.UUID{stepID1, stepID2}
- state.CompletedSteps[stepID1] = true
- state.CompletedSteps[stepID2] = true
- assert.True(t, executor.areDependenciesCompleted(state, dependencies))
-
- // Test with incomplete dependencies
- dependencies = []uuid.UUID{stepID1, stepID3}
- assert.False(t, executor.areDependenciesCompleted(state, dependencies))
-}
-
-func TestIsExecutionComplete(t *testing.T) {
- executor := createTestExecutor(t)
-
- state := &ExecutionState{
- StepStates: make(map[uuid.UUID]*StepState),
- CompletedSteps: make(map[uuid.UUID]bool),
- FailedSteps: make(map[uuid.UUID]bool),
- }
-
- stepID1 := uuid.New()
- stepID2 := uuid.New()
-
- // Add steps to state
- state.StepStates[stepID1] = &StepState{StepDefinitionID: stepID1}
- state.StepStates[stepID2] = &StepState{StepDefinitionID: stepID2}
-
- // Initially not complete
- assert.False(t, executor.isExecutionComplete(state))
-
- // Mark one step as completed
- state.CompletedSteps[stepID1] = true
- assert.False(t, executor.isExecutionComplete(state))
-
- // Mark second step as completed
- state.CompletedSteps[stepID2] = true
- assert.True(t, executor.isExecutionComplete(state))
-
- // Test with failure
- state.CompletedSteps[stepID2] = false
- state.FailedSteps[stepID2] = true
- assert.True(t, executor.isExecutionComplete(state))
-}
-
func TestResolveStepGraceDays_Preference(t *testing.T) {
defGrace := 9
instanceGrace := 5
@@ -353,7 +212,7 @@ func TestInitializeWorkflow_Success(t *testing.T) {
mockAssignmentService := &MockAssignmentService{}
logger := log.New(bytes.NewBufferString(""), "", log.LstdFlags)
- executor := NewDAGExecutor(mockStepExecService, mockWorkflowExecService, mockStepDefService, mockAssignmentService, logger)
+ executor := NewDAGExecutor(mockStepExecService, mockWorkflowExecService, mockStepDefService, mockAssignmentService, logger, nil)
// Setup test data
workflowExecutionID := uuid.New()
@@ -420,7 +279,7 @@ func TestInitializeWorkflow_Failure(t *testing.T) {
mockAssignmentService := &MockAssignmentService{}
logger := log.New(bytes.NewBufferString(""), "", log.LstdFlags)
- executor := NewDAGExecutor(mockStepExecService, mockWorkflowExecService, mockStepDefService, mockAssignmentService, logger)
+ executor := NewDAGExecutor(mockStepExecService, mockWorkflowExecService, mockStepDefService, mockAssignmentService, logger, nil)
// Setup test data
workflowExecutionID := uuid.New()
@@ -440,163 +299,74 @@ func TestInitializeWorkflow_Failure(t *testing.T) {
mockWorkflowExecService.AssertExpectations(t)
}
-func TestGetExecutionStatus(t *testing.T) {
+func TestProcessStepCompletion_PassesRequestContextToNotificationEnqueuer(t *testing.T) {
mockStepExecService := &MockStepExecutionService{}
mockWorkflowExecService := &MockWorkflowExecutionService{}
mockStepDefService := &MockWorkflowStepDefinitionService{}
mockAssignmentService := &MockAssignmentService{}
+ mockNotificationEnqueuer := &MockNotificationEnqueuer{}
logger := log.New(bytes.NewBufferString(""), "", log.LstdFlags)
- executor := NewDAGExecutor(mockStepExecService, mockWorkflowExecService, mockStepDefService, mockAssignmentService, logger)
+ executor := NewDAGExecutor(
+ mockStepExecService,
+ mockWorkflowExecService,
+ mockStepDefService,
+ mockAssignmentService,
+ logger,
+ mockNotificationEnqueuer,
+ )
- // Setup test data
+ stepExecutionID := uuid.New()
workflowExecutionID := uuid.New()
- stepDefID1 := uuid.New()
- stepDefID2 := uuid.New()
- stepExecID1 := uuid.New()
- stepExecID2 := uuid.New()
-
- stepExecutions := []workflows.StepExecution{
- {
- UUIDModel: relational.UUIDModel{ID: &stepExecID1},
- WorkflowExecutionID: &workflowExecutionID,
- WorkflowStepDefinitionID: &stepDefID1,
- Status: "completed",
- StartedAt: &time.Time{},
- CompletedAt: &time.Time{},
- },
- {
- UUIDModel: relational.UUIDModel{ID: &stepExecID2},
- WorkflowExecutionID: &workflowExecutionID,
- WorkflowStepDefinitionID: &stepDefID2,
- Status: "in_progress",
- StartedAt: &time.Time{},
- },
+ completedStepDefID := uuid.New()
+ dependentStepDefID := uuid.New()
+ dependentStepExecID := uuid.New()
+
+ completedStep := &workflows.StepExecution{
+ UUIDModel: relational.UUIDModel{ID: &stepExecutionID},
+ WorkflowExecutionID: &workflowExecutionID,
+ WorkflowStepDefinitionID: &completedStepDefID,
+ Status: StatusCompleted.String(),
}
-
- // Setup mocks
- mockStepExecService.On("GetByWorkflowExecutionID", &workflowExecutionID).Return(stepExecutions, nil)
-
- // Get execution status
- state, err := executor.GetExecutionStatus(&workflowExecutionID)
-
- // Verify results
- require.NoError(t, err)
- assert.NotNil(t, state)
- assert.Equal(t, workflowExecutionID, state.WorkflowExecutionID)
- assert.Len(t, state.StepStates, 2)
- assert.Len(t, state.CompletedSteps, 1)
- assert.Len(t, state.RunningSteps, 1)
- assert.Contains(t, state.CompletedSteps, stepDefID1)
- assert.Contains(t, state.RunningSteps, stepDefID2)
-
- // Verify mocks were called
- mockStepExecService.AssertExpectations(t)
-}
-
-func TestCancelExecution(t *testing.T) {
- mockStepExecService := &MockStepExecutionService{}
- mockWorkflowExecService := &MockWorkflowExecutionService{}
- mockStepDefService := &MockWorkflowStepDefinitionService{}
- mockAssignmentService := &MockAssignmentService{}
- logger := log.New(bytes.NewBufferString(""), "", log.LstdFlags)
-
- executor := NewDAGExecutor(mockStepExecService, mockWorkflowExecService, mockStepDefService, mockAssignmentService, logger)
-
- // Setup test data
- workflowExecutionID := uuid.New()
- stepDefID1 := uuid.New()
- stepExecID1 := uuid.New()
-
- stepExecutions := []workflows.StepExecution{
- {
- UUIDModel: relational.UUIDModel{ID: &stepExecID1},
- WorkflowExecutionID: &workflowExecutionID,
- WorkflowStepDefinitionID: &stepDefID1,
- Status: "in_progress",
- },
+ dependentExec := workflows.StepExecution{
+ UUIDModel: relational.UUIDModel{ID: &dependentStepExecID},
+ WorkflowExecutionID: &workflowExecutionID,
+ WorkflowStepDefinitionID: &dependentStepDefID,
+ Status: StatusBlocked.String(),
+ }
+ reloadedDependent := &workflows.StepExecution{
+ UUIDModel: relational.UUIDModel{ID: &dependentStepExecID},
+ WorkflowExecutionID: &workflowExecutionID,
+ WorkflowStepDefinitionID: &dependentStepDefID,
+ Status: StatusPending.String(),
}
- // Setup mocks
- mockWorkflowExecService.On("UpdateStatus", mock.Anything, &workflowExecutionID, "cancelled").Return(nil)
- mockStepExecService.On("GetByWorkflowExecutionID", &workflowExecutionID).Return(stepExecutions, nil)
- mockStepExecService.On("UpdateStatus", mock.Anything, &stepExecID1, "cancelled").Return(nil)
-
- // Cancel execution
- ctx := context.Background()
- err := executor.CancelExecution(ctx, &workflowExecutionID)
-
- // Verify results
+ mockStepExecService.On("GetByID", &stepExecutionID).Return(completedStep, nil).Once()
+ mockStepDefService.On("GetDependentSteps", &completedStepDefID).Return([]workflows.WorkflowStepDefinition{
+ {UUIDModel: relational.UUIDModel{ID: &dependentStepDefID}},
+ }, nil).Once()
+ mockStepExecService.On("GetByWorkflowExecutionID", &workflowExecutionID).Return([]workflows.StepExecution{
+ dependentExec,
+ }, nil).Twice()
+ mockStepExecService.On("CanUnblock", &dependentStepExecID).Return(true, nil).Once()
+ mockStepExecService.On("Unblock", &dependentStepExecID).Return(nil).Once()
+ mockStepExecService.On("GetByID", &dependentStepExecID).Return(reloadedDependent, nil).Once()
+
+ type ctxKey string
+ const traceKey ctxKey = "trace_id"
+ ctx := context.WithValue(context.Background(), traceKey, "trace-123")
+ mockNotificationEnqueuer.On(
+ "EnqueueWorkflowTaskAssigned",
+ mock.MatchedBy(func(c context.Context) bool {
+ return c != nil && c.Value(traceKey) == "trace-123"
+ }),
+ reloadedDependent,
+ ).Return(nil).Once()
+
+ err := executor.ProcessStepCompletion(ctx, &stepExecutionID)
require.NoError(t, err)
- // Verify mocks were called
- mockWorkflowExecService.AssertExpectations(t)
mockStepExecService.AssertExpectations(t)
-}
-
-// Helper function to create a test executor
-func createTestExecutor(t *testing.T) *DAGExecutor {
- mockStepExecService := &MockStepExecutionService{}
- mockWorkflowExecService := &MockWorkflowExecutionService{}
- mockStepDefService := &MockWorkflowStepDefinitionService{}
- mockAssignmentService := &MockAssignmentService{}
- logger := log.New(bytes.NewBufferString(""), "", log.LstdFlags)
-
- return NewDAGExecutor(mockStepExecService, mockWorkflowExecService, mockStepDefService, mockAssignmentService, logger)
-}
-
-// Benchmark tests
-func BenchmarkGetReadySteps(b *testing.B) {
- executor := createTestExecutor(&testing.T{})
-
- // Create a large state with many steps
- state := &ExecutionState{
- StepStates: make(map[uuid.UUID]*StepState),
- CompletedSteps: make(map[uuid.UUID]bool),
- FailedSteps: make(map[uuid.UUID]bool),
- RunningSteps: make(map[uuid.UUID]bool),
- BlockedSteps: make(map[uuid.UUID]bool),
- }
-
- // Create 1000 steps
- for i := 0; i < 1000; i++ {
- stepID := uuid.New()
- state.StepStates[stepID] = &StepState{
- StepDefinitionID: stepID,
- Status: "pending",
- Dependencies: []uuid.UUID{},
- }
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- readySteps := executor.getReadySteps(state)
- if len(readySteps) != 1000 {
- b.Fatal("Unexpected number of ready steps")
- }
- }
-}
-
-func BenchmarkAreDependenciesCompleted(b *testing.B) {
- executor := createTestExecutor(&testing.T{})
-
- state := &ExecutionState{
- CompletedSteps: make(map[uuid.UUID]bool),
- }
-
- // Create many dependencies
- dependencies := make([]uuid.UUID, 100)
- for i := 0; i < 100; i++ {
- stepID := uuid.New()
- dependencies[i] = stepID
- state.CompletedSteps[stepID] = true
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- completed := executor.areDependenciesCompleted(state, dependencies)
- if !completed {
- b.Fatal("Dependencies should be completed")
- }
- }
+ mockStepDefService.AssertExpectations(t)
+ mockNotificationEnqueuer.AssertExpectations(t)
}
diff --git a/internal/workflow/grace.go b/internal/workflow/grace.go
new file mode 100644
index 00000000..7d7ed93e
--- /dev/null
+++ b/internal/workflow/grace.go
@@ -0,0 +1,18 @@
+package workflow
+
+import "github.com/compliance-framework/api/internal/service/relational/workflows"
+
+// ResolveGraceDays returns the effective grace period days for a workflow instance,
+// falling back through: instance → definition → provided default.
+func ResolveGraceDays(instance *workflows.WorkflowInstance, defaultDays int) int {
+ if instance == nil {
+ return defaultDays
+ }
+ if instance.GracePeriodDays != nil {
+ return *instance.GracePeriodDays
+ }
+ if instance.WorkflowDefinition != nil && instance.WorkflowDefinition.GracePeriodDays != nil {
+ return *instance.WorkflowDefinition.GracePeriodDays
+ }
+ return defaultDays
+}
diff --git a/internal/workflow/grace_test.go b/internal/workflow/grace_test.go
new file mode 100644
index 00000000..26ad7440
--- /dev/null
+++ b/internal/workflow/grace_test.go
@@ -0,0 +1,39 @@
+package workflow
+
+import (
+ "testing"
+
+ "github.com/compliance-framework/api/internal/service/relational/workflows"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestResolveGraceDays_NilInstance(t *testing.T) {
+ assert.Equal(t, 7, ResolveGraceDays(nil, 7))
+}
+
+func TestResolveGraceDays_InstanceOverride(t *testing.T) {
+ days := 3
+ instance := &workflows.WorkflowInstance{
+ GracePeriodDays: &days,
+ WorkflowDefinition: &workflows.WorkflowDefinition{
+ GracePeriodDays: intPtr(14),
+ },
+ }
+ assert.Equal(t, 3, ResolveGraceDays(instance, 7))
+}
+
+func TestResolveGraceDays_DefinitionFallback(t *testing.T) {
+ instance := &workflows.WorkflowInstance{
+ WorkflowDefinition: &workflows.WorkflowDefinition{
+ GracePeriodDays: intPtr(14),
+ },
+ }
+ assert.Equal(t, 14, ResolveGraceDays(instance, 7))
+}
+
+func TestResolveGraceDays_DefaultFallback(t *testing.T) {
+ instance := &workflows.WorkflowInstance{}
+ assert.Equal(t, 7, ResolveGraceDays(instance, 7))
+}
+
+func intPtr(i int) *int { return &i }
diff --git a/internal/workflow/jobs.go b/internal/workflow/jobs.go
index 15c04c1f..a61c54e3 100644
--- a/internal/workflow/jobs.go
+++ b/internal/workflow/jobs.go
@@ -7,20 +7,12 @@ import (
"github.com/google/uuid"
"github.com/riverqueue/river"
+ "go.uber.org/zap"
)
-// Logger interface for logging
-type Logger interface {
- Infow(msg string, keysAndValues ...interface{})
- Errorw(msg string, keysAndValues ...interface{})
- Warnw(msg string, keysAndValues ...interface{})
- Debugw(msg string, keysAndValues ...interface{})
-}
-
// Job types for workflow processing
const (
JobTypeExecuteWorkflow = "execute_workflow"
- JobTypeExecuteStep = "execute_step"
JobTypeScheduleWorkflows = "schedule_workflows"
)
@@ -36,41 +28,26 @@ type ScheduleWorkflowsArgs struct {
// No arguments needed for the periodic scheduler job
}
-// ExecuteStepArgs represents the arguments for executing a single workflow step
-type ExecuteStepArgs struct {
- WorkflowExecutionID uuid.UUID `json:"workflow_execution_id"`
- WorkflowStepDefinitionID uuid.UUID `json:"workflow_step_definition_id"`
- StepExecutionID uuid.UUID `json:"step_execution_id"`
-}
-
// Kind returns the job kind for River
func (ExecuteWorkflowArgs) Kind() string { return JobTypeExecuteWorkflow }
// Kind returns the job kind for River
func (ScheduleWorkflowsArgs) Kind() string { return JobTypeScheduleWorkflows }
-// Kind returns the job kind for River
-func (ExecuteStepArgs) Kind() string { return JobTypeExecuteStep }
-
// Timeout returns the timeout for workflow execution jobs
func (ExecuteWorkflowArgs) Timeout() time.Duration {
return 30 * time.Minute // Workflows can take longer
}
-// Timeout returns the timeout for step execution jobs
-func (ExecuteStepArgs) Timeout() time.Duration {
- return 5 * time.Minute // Individual steps should be faster
-}
-
// WorkflowExecutionWorker handles workflow execution jobs
type WorkflowExecutionWorker struct {
executor *DAGExecutor
evidenceIntegration *EvidenceIntegration
- logger Logger
+ logger *zap.SugaredLogger
}
// NewWorkflowExecutionWorker creates a new WorkflowExecutionWorker
-func NewWorkflowExecutionWorker(executor *DAGExecutor, evidenceIntegration *EvidenceIntegration, logger Logger) *WorkflowExecutionWorker {
+func NewWorkflowExecutionWorker(executor *DAGExecutor, evidenceIntegration *EvidenceIntegration, logger *zap.SugaredLogger) *WorkflowExecutionWorker {
return &WorkflowExecutionWorker{
executor: executor,
evidenceIntegration: evidenceIntegration,
@@ -125,58 +102,6 @@ func (w *WorkflowExecutionWorker) Work(ctx context.Context, job *river.Job[Execu
return nil
}
-// StepExecutionWorker handles individual step execution jobs
-type StepExecutionWorker struct {
- stepExecutionService StepExecutionServiceInterface
- logger Logger
-}
-
-// NewStepExecutionWorker creates a new StepExecutionWorker
-func NewStepExecutionWorker(stepExecutionService StepExecutionServiceInterface, logger Logger) *StepExecutionWorker {
- return &StepExecutionWorker{
- stepExecutionService: stepExecutionService,
- logger: logger,
- }
-}
-
-// Work is the River work function for executing individual steps
-// NOTE: In Phase 1, steps are manually executed by users via the StepTransitionService.
-// This worker is reserved for future automatic step execution (Phase 5).
-// For now, it only logs that a step execution was requested but does not auto-complete it.
-func (w *StepExecutionWorker) Work(ctx context.Context, job *river.Job[ExecuteStepArgs]) error {
- args := job.Args
-
- w.logger.Infow("Step execution job received (manual execution mode - no auto-completion)",
- "job_id", job.ID,
- "workflow_execution_id", args.WorkflowExecutionID,
- "step_definition_id", args.WorkflowStepDefinitionID,
- "step_execution_id", args.StepExecutionID,
- )
-
- // Get the step execution to verify it exists
- stepExec, err := w.stepExecutionService.GetByID(&args.StepExecutionID)
- if err != nil {
- w.logger.Errorw("Failed to get step execution",
- "job_id", job.ID,
- "step_execution_id", args.StepExecutionID,
- "error", err,
- )
- return fmt.Errorf("failed to get step execution: %w", err)
- }
-
- w.logger.Infow("Step execution verified - awaiting manual user action",
- "job_id", job.ID,
- "step_execution_id", args.StepExecutionID,
- "current_status", stepExec.Status,
- )
-
- // Phase 1: Manual execution only - users must transition steps via StepTransitionService
- // Phase 5: This worker will handle automatic step execution based on triggers
- // TODO: Implement automatic step execution logic for Phase 5
-
- return nil
-}
-
// JobInsertOptionsForWorkflow returns insert options for workflow execution jobs
func JobInsertOptionsForWorkflow() *river.InsertOpts {
return &river.InsertOpts{
@@ -186,15 +111,6 @@ func JobInsertOptionsForWorkflow() *river.InsertOpts {
}
}
-// JobInsertOptionsForStep returns insert options for step execution jobs
-func JobInsertOptionsForStep() *river.InsertOpts {
- return &river.InsertOpts{
- Queue: "steps",
- MaxAttempts: 5, // More retries for individual steps
- Priority: 2, // Lower priority than workflow jobs
- }
-}
-
// JobInsertOptionsForScheduler returns insert options for the scheduler job
func JobInsertOptionsForScheduler() *river.InsertOpts {
return &river.InsertOpts{
diff --git a/internal/workflow/manager.go b/internal/workflow/manager.go
index 6ef95bbc..b6a550ea 100644
--- a/internal/workflow/manager.go
+++ b/internal/workflow/manager.go
@@ -8,20 +8,26 @@ import (
"github.com/compliance-framework/api/internal/service/relational/workflows"
"github.com/google/uuid"
- "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/riverqueue/river"
+ "github.com/riverqueue/river/rivertype"
"go.uber.org/zap"
)
var ErrWorkflowExecutionAlreadyExists = errors.New("workflow execution already exists for instance and period")
+// RiverClient interface for job enqueueing (enables testing)
+type RiverClient interface {
+ InsertMany(ctx context.Context, params []river.InsertManyParams) ([]*rivertype.JobInsertResult, error)
+}
+
// Manager orchestrates workflow execution lifecycle using River for async operations
type Manager struct {
riverClient RiverClient
workflowExecutionService WorkflowExecutionServiceInterface
workflowInstanceService WorkflowInstanceServiceInterface
stepExecutionService StepExecutionServiceInterface
+ notificationEnqueuer NotificationEnqueuer // Optional: for workflow notification emails
logger *zap.SugaredLogger
}
@@ -32,6 +38,7 @@ func NewManager(
workflowInstanceService WorkflowInstanceServiceInterface,
stepExecutionService StepExecutionServiceInterface,
logger *zap.SugaredLogger,
+ notificationEnqueuer NotificationEnqueuer,
) *Manager {
return &Manager{
riverClient: riverClient,
@@ -39,26 +46,10 @@ func NewManager(
workflowInstanceService: workflowInstanceService,
stepExecutionService: stepExecutionService,
logger: logger,
+ notificationEnqueuer: notificationEnqueuer,
}
}
-// NewManagerWithRiver creates a manager with a River client and concrete services
-func NewManagerWithRiver(
- riverClient *river.Client[pgx.Tx],
- workflowExecutionService *workflows.WorkflowExecutionService,
- workflowInstanceService *workflows.WorkflowInstanceService,
- stepExecutionService *workflows.StepExecutionService,
- logger *zap.SugaredLogger,
-) *Manager {
- return NewManager(
- riverClient,
- workflowExecutionService,
- workflowInstanceService,
- stepExecutionService,
- logger,
- )
-}
-
// StartWorkflowOptions contains options for starting a workflow execution
type StartWorkflowOptions struct {
TriggeredBy string
@@ -68,7 +59,7 @@ type StartWorkflowOptions struct {
}
// StartWorkflowExecution creates and starts a workflow execution via River
-func (m *Manager) StartWorkflowExecution(ctx context.Context, workflowInstanceID *uuid.UUID, opts StartWorkflowOptions) (*uuid.UUID, error) {
+func (m *Manager) StartWorkflowExecution(ctx context.Context, workflowInstanceID *uuid.UUID, opts StartWorkflowOptions) (*workflows.WorkflowExecution, error) {
m.logger.Infow("Starting workflow execution",
"workflow_instance_id", workflowInstanceID,
"triggered_by", opts.TriggeredBy,
@@ -126,6 +117,12 @@ func (m *Manager) StartWorkflowExecution(ctx context.Context, workflowInstanceID
// Mark execution as failed
if failErr := m.workflowExecutionService.Fail(execution.ID, fmt.Sprintf("Failed to enqueue job: %v", err)); failErr != nil {
m.logger.Errorw("Failed to mark execution as failed", "error", failErr)
+ } else if m.notificationEnqueuer != nil {
+ if reloaded, reloadErr := m.workflowExecutionService.GetByID(execution.ID); reloadErr == nil {
+ if notifyErr := m.notificationEnqueuer.EnqueueWorkflowExecutionFailed(ctx, reloaded); notifyErr != nil {
+ m.logger.Errorw("Failed to enqueue workflow-execution-failed notification", "error", notifyErr)
+ }
+ }
}
return nil, fmt.Errorf("failed to enqueue workflow execution job: %w", err)
}
@@ -135,7 +132,7 @@ func (m *Manager) StartWorkflowExecution(ctx context.Context, workflowInstanceID
"job_kind", JobTypeExecuteWorkflow,
)
- return execution.ID, nil
+ return execution, nil
}
// GetExecutionStatus returns the current status of a workflow execution
@@ -152,38 +149,19 @@ func (m *Manager) GetExecutionStatus(ctx context.Context, executionID *uuid.UUID
return nil, fmt.Errorf("failed to get step executions: %w", err)
}
- // Count steps by status
- var pending, blocked, inProgress, overdue, completed, failed, cancelled int
- for _, step := range stepExecutions {
- switch step.Status {
- case "pending":
- pending++
- case "blocked":
- blocked++
- case "in_progress":
- inProgress++
- case "overdue":
- overdue++
- case "completed":
- completed++
- case "failed":
- failed++
- case "cancelled":
- cancelled++
- }
- }
+ counts := CountStepStatuses(stepExecutions)
status := &ExecutionStatus{
ExecutionID: *executionID,
Status: execution.Status,
TotalSteps: len(stepExecutions),
- PendingSteps: pending,
- BlockedSteps: blocked,
- InProgressSteps: inProgress,
- OverdueSteps: overdue,
- CompletedSteps: completed,
- FailedSteps: failed,
- CancelledSteps: cancelled,
+ PendingSteps: counts.Pending,
+ BlockedSteps: counts.Blocked,
+ InProgressSteps: counts.InProgress,
+ OverdueSteps: counts.Overdue,
+ CompletedSteps: counts.Completed,
+ FailedSteps: counts.Failed,
+ CancelledSteps: counts.Cancelled,
StartedAt: execution.StartedAt,
CompletedAt: execution.CompletedAt,
FailedAt: execution.FailedAt,
@@ -194,7 +172,7 @@ func (m *Manager) GetExecutionStatus(ctx context.Context, executionID *uuid.UUID
}
// CancelExecution cancels a running workflow execution
-func (m *Manager) CancelExecution(ctx context.Context, executionID *uuid.UUID, reason string) error {
+func (m *Manager) CancelExecution(ctx context.Context, executionID *uuid.UUID, reason string) (*workflows.WorkflowExecution, error) {
m.logger.Infow("Cancelling workflow execution",
"execution_id", executionID,
"reason", reason,
@@ -203,29 +181,32 @@ func (m *Manager) CancelExecution(ctx context.Context, executionID *uuid.UUID, r
// Get workflow execution
execution, err := m.workflowExecutionService.GetByID(executionID)
if err != nil {
- return fmt.Errorf("failed to get workflow execution: %w", err)
+ return nil, fmt.Errorf("failed to get workflow execution: %w", err)
}
// Check if execution can be cancelled
if execution.Status == "completed" || execution.Status == "failed" || execution.Status == "cancelled" {
- return fmt.Errorf("cannot cancel execution in status: %s", execution.Status)
+ return nil, fmt.Errorf("cannot cancel execution in status: %s", execution.Status)
}
// Update execution status
if err := m.workflowExecutionService.Cancel(executionID); err != nil {
- return fmt.Errorf("failed to cancel workflow execution: %w", err)
+ return nil, fmt.Errorf("failed to cancel workflow execution: %w", err)
}
- // Cancel all in-progress and pending steps
+ // Cancel all non-terminal actionable steps
stepExecutions, err := m.stepExecutionService.GetByWorkflowExecutionID(executionID)
if err != nil {
- return fmt.Errorf("failed to get step executions: %w", err)
+ return nil, fmt.Errorf("failed to get step executions: %w", err)
}
for _, step := range stepExecutions {
- if step.Status == "in_progress" || step.Status == "pending" || step.Status == "blocked" {
+ if step.Status == workflows.StepStatusInProgress.String() ||
+ step.Status == workflows.StepStatusPending.String() ||
+ step.Status == workflows.StepStatusBlocked.String() ||
+ step.Status == workflows.StepStatusOverdue.String() {
// Update step status to cancelled
- if err := m.stepExecutionService.UpdateStatus(ctx, step.ID, "cancelled"); err != nil {
+ if err := m.stepExecutionService.UpdateStatus(ctx, step.ID, StatusCancelled.String()); err != nil {
m.logger.Warnw("Failed to cancel step execution",
"step_execution_id", step.ID,
"error", err,
@@ -238,11 +219,13 @@ func (m *Manager) CancelExecution(ctx context.Context, executionID *uuid.UUID, r
"execution_id", executionID,
)
- return nil
+ // Return the updated execution so API handlers can respond immediately without an extra read.
+ execution.Status = workflows.WorkflowStatusCancelled.String()
+ return execution, nil
}
// RetryExecution creates a new execution for a failed workflow
-func (m *Manager) RetryExecution(ctx context.Context, executionID *uuid.UUID) (*uuid.UUID, error) {
+func (m *Manager) RetryExecution(ctx context.Context, executionID *uuid.UUID) (*workflows.WorkflowExecution, error) {
m.logger.Infow("Retrying workflow execution",
"original_execution_id", executionID,
)
@@ -267,7 +250,7 @@ func (m *Manager) RetryExecution(ctx context.Context, executionID *uuid.UUID) (*
TriggeredByID: executionID.String(),
}
- newExecutionID, err := m.StartWorkflowExecution(
+ newExecution, err := m.StartWorkflowExecution(
ctx,
execution.WorkflowInstanceID,
opts,
@@ -278,10 +261,10 @@ func (m *Manager) RetryExecution(ctx context.Context, executionID *uuid.UUID) (*
m.logger.Infow("Workflow execution retry started",
"original_execution_id", executionID,
- "new_execution_id", newExecutionID,
+ "new_execution_id", newExecution.ID,
)
- return newExecutionID, nil
+ return newExecution, nil
}
// ListExecutions returns workflow executions for a workflow instance
diff --git a/internal/workflow/manager_test.go b/internal/workflow/manager_test.go
index a0fac2dd..9c20a3a0 100644
--- a/internal/workflow/manager_test.go
+++ b/internal/workflow/manager_test.go
@@ -6,15 +6,30 @@ import (
"testing"
"time"
+ "github.com/compliance-framework/api/internal/service/relational"
"github.com/compliance-framework/api/internal/service/relational/workflows"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgconn"
+ "github.com/riverqueue/river"
+ "github.com/riverqueue/river/rivertype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
+type MockRiverClient struct {
+ mock.Mock
+}
+
+func (m *MockRiverClient) InsertMany(ctx context.Context, params []river.InsertManyParams) ([]*rivertype.JobInsertResult, error) {
+ args := m.Called(ctx, params)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).([]*rivertype.JobInsertResult), args.Error(1)
+}
+
type MockWorkflowInstanceService struct {
mock.Mock
}
@@ -61,6 +76,7 @@ func TestManager_StartWorkflowExecution_UniqueViolationScheduledReturnsAlreadyEx
mockWorkflowInstService,
mockStepExecService,
logger,
+ nil,
)
mockWorkflowInstService.On("GetByID", &instanceID).Return(&workflows.WorkflowInstance{IsActive: true}, nil).Once()
@@ -71,9 +87,9 @@ func TestManager_StartWorkflowExecution_UniqueViolationScheduledReturnsAlreadyEx
PeriodLabel: "2026-01",
}
- executionID, err := manager.StartWorkflowExecution(ctx, &instanceID, opts)
+ execution, err := manager.StartWorkflowExecution(ctx, &instanceID, opts)
require.Error(t, err)
- assert.Nil(t, executionID)
+ assert.Nil(t, execution)
assert.True(t, errors.Is(err, ErrWorkflowExecutionAlreadyExists))
mockClient.AssertNotCalled(t, "InsertMany", mock.Anything, mock.Anything)
@@ -97,6 +113,7 @@ func TestManager_StartWorkflowExecution_UniqueViolationManualDoesNotReturnAlread
mockWorkflowInstService,
mockStepExecService,
logger,
+ nil,
)
mockWorkflowInstService.On("GetByID", &instanceID).Return(&workflows.WorkflowInstance{IsActive: true}, nil).Once()
@@ -108,9 +125,57 @@ func TestManager_StartWorkflowExecution_UniqueViolationManualDoesNotReturnAlread
PeriodLabel: "2026-01",
}
- executionID, err := manager.StartWorkflowExecution(ctx, &instanceID, opts)
+ execution, err := manager.StartWorkflowExecution(ctx, &instanceID, opts)
require.Error(t, err)
- assert.Nil(t, executionID)
+ assert.Nil(t, execution)
assert.False(t, errors.Is(err, ErrWorkflowExecutionAlreadyExists))
assert.Contains(t, err.Error(), "failed to create workflow execution")
}
+
+func TestManager_CancelExecution_CancelsOverdueSteps(t *testing.T) {
+ ctx := context.Background()
+ logger := zap.NewNop().Sugar()
+
+ executionID := uuid.New()
+ overdueStepID := uuid.New()
+ completedStepID := uuid.New()
+
+ mockClient := &MockRiverClient{}
+ mockWorkflowExecService := &MockWorkflowExecutionService{}
+ mockWorkflowInstService := &MockWorkflowInstanceService{}
+ mockStepExecService := &MockStepExecutionService{}
+
+ manager := NewManager(
+ mockClient,
+ mockWorkflowExecService,
+ mockWorkflowInstService,
+ mockStepExecService,
+ logger,
+ nil,
+ )
+
+ mockWorkflowExecService.On("GetByID", &executionID).Return(&workflows.WorkflowExecution{
+ Status: workflows.WorkflowStatusInProgress.String(),
+ }, nil).Once()
+ mockWorkflowExecService.On("Cancel", &executionID).Return(nil).Once()
+ mockStepExecService.On("GetByWorkflowExecutionID", &executionID).Return([]workflows.StepExecution{
+ {
+ UUIDModel: relational.UUIDModel{ID: &overdueStepID},
+ Status: workflows.StepStatusOverdue.String(),
+ },
+ {
+ UUIDModel: relational.UUIDModel{ID: &completedStepID},
+ Status: workflows.StepStatusCompleted.String(),
+ },
+ }, nil).Once()
+ mockStepExecService.On("UpdateStatus", ctx, &overdueStepID, StatusCancelled.String()).Return(nil).Once()
+
+ execution, err := manager.CancelExecution(ctx, &executionID, "user requested cancellation")
+ require.NoError(t, err)
+ require.NotNil(t, execution)
+ assert.Equal(t, workflows.WorkflowStatusCancelled.String(), execution.Status)
+
+ mockStepExecService.AssertNotCalled(t, "UpdateStatus", ctx, &completedStepID, StatusCancelled.String())
+ mockWorkflowExecService.AssertExpectations(t)
+ mockStepExecService.AssertExpectations(t)
+}
diff --git a/internal/workflow/overdue.go b/internal/workflow/overdue.go
index 0e68f0bd..f42905ac 100644
--- a/internal/workflow/overdue.go
+++ b/internal/workflow/overdue.go
@@ -6,7 +6,7 @@ import (
"time"
"github.com/compliance-framework/api/internal/service/relational/workflows"
- "github.com/google/uuid"
+ "go.uber.org/zap"
"gorm.io/gorm"
)
@@ -16,7 +16,8 @@ type OverdueService struct {
workflowExecutionService *workflows.WorkflowExecutionService
stepExecutionService *workflows.StepExecutionService
evidenceIntegration *EvidenceIntegration
- logger Logger
+ notificationEnqueuer NotificationEnqueuer // Optional: for workflow notification emails
+ logger *zap.SugaredLogger
defaultGracePeriodDays int
}
@@ -25,8 +26,9 @@ func NewOverdueService(
workflowExecutionService *workflows.WorkflowExecutionService,
stepExecutionService *workflows.StepExecutionService,
evidenceIntegration *EvidenceIntegration,
- logger Logger,
+ logger *zap.SugaredLogger,
defaultGracePeriodDays int,
+ notificationEnqueuer NotificationEnqueuer,
) *OverdueService {
return &OverdueService{
db: db,
@@ -35,6 +37,7 @@ func NewOverdueService(
evidenceIntegration: evidenceIntegration,
logger: logger,
defaultGracePeriodDays: defaultGracePeriodDays,
+ notificationEnqueuer: notificationEnqueuer,
}
}
@@ -43,10 +46,23 @@ func (s *OverdueService) CheckOverdueExecutions(ctx context.Context) (int, error
var executions []workflows.WorkflowExecution
now := time.Now()
if err := s.db.WithContext(ctx).
- Where("status IN ? AND due_date IS NOT NULL AND due_date < ?", []string{
- workflows.WorkflowStatusPending.String(),
- workflows.WorkflowStatusInProgress.String(),
- }, now).
+ Where(
+ `status IN ? AND (
+ (due_date IS NOT NULL AND due_date < ?) OR
+ EXISTS (
+ SELECT 1
+ FROM step_executions se
+ WHERE se.workflow_execution_id = workflow_executions.id
+ AND se.status = ?
+ )
+ )`,
+ []string{
+ workflows.WorkflowStatusPending.String(),
+ workflows.WorkflowStatusInProgress.String(),
+ },
+ now,
+ workflows.StepStatusOverdue.String(),
+ ).
Find(&executions).Error; err != nil {
return 0, fmt.Errorf("failed to query overdue executions: %w", err)
}
@@ -86,12 +102,7 @@ func (s *OverdueService) CheckOverdueSteps(ctx context.Context) (int, error) {
updated++
}
- executionUpdates, err := s.markExecutionsOverdueFromStepOverdue(ctx)
- if err != nil {
- return updated, err
- }
-
- s.logger.Infow("Checked overdue workflow steps", "found", len(steps), "updated", updated, "execution_status_updated", executionUpdates)
+ s.logger.Infow("Checked overdue workflow steps", "found", len(steps), "updated", updated)
return updated, nil
}
@@ -110,7 +121,7 @@ func (s *OverdueService) CheckFailedExecutions(ctx context.Context) (int, error)
failed := 0
for i := range overdueExecutions {
exec := overdueExecutions[i]
- graceDays := s.resolveExecutionGraceDays(&exec)
+ graceDays := ResolveGraceDays(exec.WorkflowInstance, s.defaultGracePeriodDays)
if exec.OverdueAt == nil || exec.OverdueAt.AddDate(0, 0, graceDays).After(now) {
continue
}
@@ -120,23 +131,17 @@ func (s *OverdueService) CheckFailedExecutions(ctx context.Context) (int, error)
continue
}
failed++
+ if s.notificationEnqueuer != nil {
+ if notifyErr := s.notificationEnqueuer.EnqueueWorkflowExecutionFailed(ctx, &exec); notifyErr != nil {
+ s.logger.Errorw("Failed to enqueue workflow-execution-failed notification", "workflow_execution_id", exec.ID, "error", notifyErr)
+ }
+ }
}
s.logger.Infow("Checked failed workflow executions", "checked", len(overdueExecutions), "failed", failed)
return failed, nil
}
-func (s *OverdueService) resolveExecutionGraceDays(execution *workflows.WorkflowExecution) int {
- if execution.WorkflowInstance != nil && execution.WorkflowInstance.GracePeriodDays != nil {
- return *execution.WorkflowInstance.GracePeriodDays
- }
- if execution.WorkflowInstance != nil && execution.WorkflowInstance.WorkflowDefinition != nil &&
- execution.WorkflowInstance.WorkflowDefinition.GracePeriodDays != nil {
- return *execution.WorkflowInstance.WorkflowDefinition.GracePeriodDays
- }
- return s.defaultGracePeriodDays
-}
-
func (s *OverdueService) failExecutionAndSteps(ctx context.Context, execution *workflows.WorkflowExecution) error {
now := time.Now()
failureReason := "overdue - grace period expired"
@@ -158,18 +163,7 @@ func (s *OverdueService) failExecutionAndSteps(ctx context.Context, execution *w
}
executionFailed = true
- if err := tx.Model(&workflows.StepExecution{}).
- Where("workflow_execution_id = ? AND status IN ?", execution.ID, []string{
- workflows.StepStatusPending.String(),
- workflows.StepStatusBlocked.String(),
- workflows.StepStatusInProgress.String(),
- workflows.StepStatusOverdue.String(),
- }).
- Updates(map[string]interface{}{
- "status": workflows.StepStatusFailed.String(),
- "failed_at": now,
- "failure_reason": failureReason,
- }).Error; err != nil {
+ if err := s.stepExecutionService.BulkFailWithTx(tx, execution.ID, failureReason, now); err != nil {
return err
}
@@ -189,30 +183,3 @@ func (s *OverdueService) failExecutionAndSteps(ctx context.Context, execution *w
return nil
}
-
-func (s *OverdueService) markExecutionsOverdueFromStepOverdue(ctx context.Context) (int, error) {
- var executionIDs []uuid.UUID
- if err := s.db.WithContext(ctx).
- Table("step_executions se").
- Select("DISTINCT se.workflow_execution_id").
- Joins("JOIN workflow_executions we ON we.id = se.workflow_execution_id").
- Where("se.status = ? AND we.status IN ?", workflows.StepStatusOverdue.String(), []string{
- workflows.WorkflowStatusPending.String(),
- workflows.WorkflowStatusInProgress.String(),
- }).
- Scan(&executionIDs).Error; err != nil {
- return 0, fmt.Errorf("failed to query executions with overdue steps: %w", err)
- }
-
- updated := 0
- for i := range executionIDs {
- executionID := executionIDs[i]
- if err := s.workflowExecutionService.UpdateStatus(ctx, &executionID, workflows.WorkflowStatusOverdue.String()); err != nil {
- s.logger.Errorw("Failed to mark workflow execution overdue from step overdue", "workflow_execution_id", executionID, "error", err)
- continue
- }
- updated++
- }
-
- return updated, nil
-}
diff --git a/internal/workflow/overdue_test.go b/internal/workflow/overdue_test.go
index 2379067d..7cbb691b 100644
--- a/internal/workflow/overdue_test.go
+++ b/internal/workflow/overdue_test.go
@@ -108,7 +108,7 @@ func TestOverdueService_CheckOverdueTransitions(t *testing.T) {
workflowExecSvc := workflows.NewWorkflowExecutionService(db)
stepExecSvc := workflows.NewStepExecutionService(db, nil)
- svc := NewOverdueService(db, workflowExecSvc, stepExecSvc, nil, zap.NewNop().Sugar(), 7)
+ svc := NewOverdueService(db, workflowExecSvc, stepExecSvc, nil, zap.NewNop().Sugar(), 7, nil)
updatedSteps, err := svc.CheckOverdueSteps(context.Background())
require.NoError(t, err)
@@ -116,7 +116,7 @@ func TestOverdueService_CheckOverdueTransitions(t *testing.T) {
updatedExecutions, err := svc.CheckOverdueExecutions(context.Background())
require.NoError(t, err)
- assert.Equal(t, 0, updatedExecutions)
+ assert.Equal(t, 1, updatedExecutions)
var stepAfter workflows.StepExecution
require.NoError(t, db.First(&stepAfter, step.ID).Error)
@@ -146,12 +146,16 @@ func TestOverdueService_CheckFailedExecutions_StepOverduePromotesExecutionAndFai
workflowExecSvc := workflows.NewWorkflowExecutionService(db)
stepExecSvc := workflows.NewStepExecutionService(db, nil)
- svc := NewOverdueService(db, workflowExecSvc, stepExecSvc, nil, zap.NewNop().Sugar(), 0)
+ svc := NewOverdueService(db, workflowExecSvc, stepExecSvc, nil, zap.NewNop().Sugar(), 0, nil)
updatedSteps, err := svc.CheckOverdueSteps(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, updatedSteps)
+ updatedExecutions, err := svc.CheckOverdueExecutions(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, 1, updatedExecutions)
+
failed, err := svc.CheckFailedExecutions(context.Background())
require.NoError(t, err)
assert.Equal(t, 1, failed)
@@ -188,7 +192,7 @@ func TestOverdueService_CheckFailedExecutions(t *testing.T) {
workflowExecSvc := workflows.NewWorkflowExecutionService(db)
stepExecSvc := workflows.NewStepExecutionService(db, nil)
- svc := NewOverdueService(db, workflowExecSvc, stepExecSvc, nil, zap.NewNop().Sugar(), 1)
+ svc := NewOverdueService(db, workflowExecSvc, stepExecSvc, nil, zap.NewNop().Sugar(), 1, nil)
failed, err := svc.CheckFailedExecutions(context.Background())
require.NoError(t, err)
@@ -221,7 +225,7 @@ func TestOverdueService_CheckOverdueSteps_DoesNotMarkBlockedSteps(t *testing.T)
workflowExecSvc := workflows.NewWorkflowExecutionService(db)
stepExecSvc := workflows.NewStepExecutionService(db, nil)
- svc := NewOverdueService(db, workflowExecSvc, stepExecSvc, nil, zap.NewNop().Sugar(), 7)
+ svc := NewOverdueService(db, workflowExecSvc, stepExecSvc, nil, zap.NewNop().Sugar(), 7, nil)
updatedSteps, err := svc.CheckOverdueSteps(context.Background())
require.NoError(t, err)
diff --git a/internal/workflow/scheduler.go b/internal/workflow/scheduler.go
index 95e3b28d..a74bfedb 100644
--- a/internal/workflow/scheduler.go
+++ b/internal/workflow/scheduler.go
@@ -8,6 +8,7 @@ import (
"github.com/compliance-framework/api/internal/service/relational/workflows"
"github.com/riverqueue/river"
+ "go.uber.org/zap"
)
// WorkflowSchedulerWorker handles the periodic scheduling of workflows
@@ -16,7 +17,7 @@ type WorkflowSchedulerWorker struct {
workflowInstanceService WorkflowInstanceServiceInterface
overdueService *OverdueService
overdueCheckEnabled bool
- logger Logger
+ logger *zap.SugaredLogger
defaultGracePeriod int
}
@@ -26,7 +27,7 @@ func NewWorkflowSchedulerWorker(
workflowInstanceService WorkflowInstanceServiceInterface,
overdueService *OverdueService,
overdueCheckEnabled bool,
- logger Logger,
+ logger *zap.SugaredLogger,
defaultGracePeriod int,
) *WorkflowSchedulerWorker {
return &WorkflowSchedulerWorker{
@@ -108,12 +109,7 @@ func (w *WorkflowSchedulerWorker) Work(ctx context.Context, job *river.Job[Sched
periodLabel := GeneratePeriodLabel(instance.Cadence, refTime)
// Determine grace period
- gracePeriod := w.defaultGracePeriod
- if instance.GracePeriodDays != nil {
- gracePeriod = *instance.GracePeriodDays
- } else if instance.WorkflowDefinition != nil && instance.WorkflowDefinition.GracePeriodDays != nil {
- gracePeriod = *instance.WorkflowDefinition.GracePeriodDays
- }
+ gracePeriod := ResolveGraceDays(&instance, w.defaultGracePeriod)
// Calculate due date
// Due date is based on the scheduled time (when it should have run), not necessarily now
@@ -128,7 +124,7 @@ func (w *WorkflowSchedulerWorker) Work(ctx context.Context, job *river.Job[Sched
DueDate: &dueDate,
}
- executionID, err := w.manager.StartWorkflowExecution(ctx, instance.ID, options)
+ execution, err := w.manager.StartWorkflowExecution(ctx, instance.ID, options)
if err != nil {
if errors.Is(err, ErrWorkflowExecutionAlreadyExists) {
w.logger.Infow("Skipping already executed workflow instance for this period",
@@ -157,7 +153,7 @@ func (w *WorkflowSchedulerWorker) Work(ctx context.Context, job *river.Job[Sched
if err := w.workflowInstanceService.AdvanceSchedule(ctx, instance.ID); err != nil {
w.logger.Errorw("Failed to update next schedule",
"instance_id", instance.ID,
- "execution_id", executionID,
+ "execution_id", execution.ID,
"error", err,
)
// Don't fail the whole job, just log error
@@ -172,7 +168,7 @@ func (w *WorkflowSchedulerWorker) Work(ctx context.Context, job *river.Job[Sched
w.logger.Infow("Scheduled workflow execution",
"instance_id", instance.ID,
- "execution_id", executionID,
+ "execution_id", execution.ID,
"period_label", periodLabel,
)
}
diff --git a/internal/workflow/scheduler_test.go b/internal/workflow/scheduler_test.go
index 4d6daac0..3314353e 100644
--- a/internal/workflow/scheduler_test.go
+++ b/internal/workflow/scheduler_test.go
@@ -9,7 +9,6 @@ import (
"github.com/compliance-framework/api/internal/service/relational/workflows"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/mock"
)
func TestGeneratePeriodLabel(t *testing.T) {
@@ -190,27 +189,6 @@ func TestWorkflowSchedulerWorker_BasicFunctionality(t *testing.T) {
mockInstanceService.AssertExpectations(t)
}
-// Mock logger for testing
-type MockLoggerForScheduler struct {
- mock.Mock
-}
-
-func (m *MockLoggerForScheduler) Infow(msg string, keysAndValues ...interface{}) {
- m.Called(msg, keysAndValues)
-}
-
-func (m *MockLoggerForScheduler) Errorw(msg string, keysAndValues ...interface{}) {
- m.Called(msg, keysAndValues)
-}
-
-func (m *MockLoggerForScheduler) Warnw(msg string, keysAndValues ...interface{}) {
- m.Called(msg, keysAndValues)
-}
-
-func (m *MockLoggerForScheduler) Debugw(msg string, keysAndValues ...interface{}) {
- m.Called(msg, keysAndValues)
-}
-
// Helper function to create UUID pointer
func uuidPtr(u uuid.UUID) *uuid.UUID {
return &u
diff --git a/internal/workflow/service.go b/internal/workflow/service.go
deleted file mode 100644
index 48b51b04..00000000
--- a/internal/workflow/service.go
+++ /dev/null
@@ -1,62 +0,0 @@
-package workflow
-
-import (
- "context"
- "fmt"
-
- "github.com/google/uuid"
- "github.com/jackc/pgx/v5"
- "github.com/riverqueue/river"
- "github.com/riverqueue/river/rivertype"
-)
-
-// RiverClient interface for job enqueueing (enables testing)
-type RiverClient interface {
- InsertMany(ctx context.Context, params []river.InsertManyParams) ([]*rivertype.JobInsertResult, error)
-}
-
-// WorkflowService provides workflow execution capabilities via River job queue
-type WorkflowService struct {
- executor *DAGExecutor
- riverClient RiverClient
-}
-
-// NewWorkflowService creates a new workflow service
-func NewWorkflowService(executor *DAGExecutor, riverClient RiverClient) (*WorkflowService, error) {
- if riverClient == nil {
- return nil, fmt.Errorf("river client is required for workflow service")
- }
- return &WorkflowService{
- executor: executor,
- riverClient: riverClient,
- }, nil
-}
-
-// NewWorkflowServiceWithRiver creates a workflow service with a River client
-func NewWorkflowServiceWithRiver(executor *DAGExecutor, riverClient *river.Client[pgx.Tx]) (*WorkflowService, error) {
- return NewWorkflowService(executor, riverClient)
-}
-
-// EnqueueWorkflowExecution enqueues a workflow execution job
-func (s *WorkflowService) EnqueueWorkflowExecution(ctx context.Context, workflowExecutionID *uuid.UUID, triggeredBy, triggeredByID string) error {
- args := &ExecuteWorkflowArgs{
- WorkflowExecutionID: *workflowExecutionID,
- TriggeredBy: triggeredBy,
- TriggeredByID: triggeredByID,
- }
-
- // Insert the job
- _, err := s.riverClient.InsertMany(ctx, []river.InsertManyParams{
- {Args: args, InsertOpts: JobInsertOptionsForWorkflow()},
- })
- if err != nil {
- return fmt.Errorf("failed to enqueue workflow execution job: %w", err)
- }
-
- return nil
-}
-
-// GetExecutor returns the underlying DAG executor (for testing)
-func (s *WorkflowService) GetExecutor() *DAGExecutor {
- return s.executor
-}
diff --git a/internal/workflow/service_test.go b/internal/workflow/service_test.go
deleted file mode 100644
index e6b3c538..00000000
--- a/internal/workflow/service_test.go
+++ /dev/null
@@ -1,106 +0,0 @@
-package workflow
-
-import (
- "context"
- "testing"
-
- "github.com/google/uuid"
- "github.com/riverqueue/river"
- "github.com/riverqueue/river/rivertype"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/mock"
- "github.com/stretchr/testify/require"
-)
-
-// MockRiverClient mocks the River client for testing
-type MockRiverClient struct {
- mock.Mock
-}
-
-func (m *MockRiverClient) InsertMany(ctx context.Context, params []river.InsertManyParams) ([]*rivertype.JobInsertResult, error) {
- args := m.Called(ctx, params)
- if args.Get(0) == nil {
- return nil, args.Error(1)
- }
- return args.Get(0).([]*rivertype.JobInsertResult), args.Error(1)
-}
-
-func TestNewWorkflowService(t *testing.T) {
- executor := &DAGExecutor{}
-
- t.Run("Success", func(t *testing.T) {
- mockClient := &MockRiverClient{}
-
- service, err := NewWorkflowService(executor, mockClient)
-
- require.NoError(t, err)
- assert.NotNil(t, service)
- assert.Equal(t, executor, service.executor)
- assert.Equal(t, mockClient, service.riverClient)
- })
-
- t.Run("ErrorWhenRiverClientIsNil", func(t *testing.T) {
- service, err := NewWorkflowService(executor, nil)
-
- require.Error(t, err)
- assert.Nil(t, service)
- assert.Contains(t, err.Error(), "river client is required")
- })
-}
-
-func TestWorkflowService_EnqueueWorkflowExecution(t *testing.T) {
- executor := &DAGExecutor{}
- mockClient := &MockRiverClient{}
-
- service, err := NewWorkflowService(executor, mockClient)
- require.NoError(t, err)
-
- t.Run("Success", func(t *testing.T) {
- ctx := context.Background()
- executionID := uuid.New()
- triggeredBy := "manual"
- triggeredByID := "user-123"
-
- // Mock successful job insertion
- mockClient.On("InsertMany", ctx, mock.AnythingOfType("[]river.InsertManyParams")).Return(
- []*rivertype.JobInsertResult{},
- nil,
- ).Once()
-
- err := service.EnqueueWorkflowExecution(ctx, &executionID, triggeredBy, triggeredByID)
-
- assert.NoError(t, err)
- mockClient.AssertExpectations(t)
- })
-
- t.Run("ErrorWhenInsertFails", func(t *testing.T) {
- ctx := context.Background()
- executionID := uuid.New()
- triggeredBy := "manual"
- triggeredByID := "user-123"
-
- // Mock failed job insertion
- mockClient.On("InsertMany", ctx, mock.AnythingOfType("[]river.InsertManyParams")).Return(
- nil,
- assert.AnError,
- ).Once()
-
- err := service.EnqueueWorkflowExecution(ctx, &executionID, triggeredBy, triggeredByID)
-
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "failed to enqueue workflow execution job")
- mockClient.AssertExpectations(t)
- })
-}
-
-func TestWorkflowService_GetExecutor(t *testing.T) {
- executor := &DAGExecutor{}
- mockClient := &MockRiverClient{}
-
- service, err := NewWorkflowService(executor, mockClient)
- require.NoError(t, err)
-
- retrievedExecutor := service.GetExecutor()
-
- assert.Equal(t, executor, retrievedExecutor)
-}
diff --git a/internal/workflow/step_transition.go b/internal/workflow/step_transition.go
index b216aa90..7773c66f 100644
--- a/internal/workflow/step_transition.go
+++ b/internal/workflow/step_transition.go
@@ -138,10 +138,14 @@ func (s *StepTransitionService) TransitionStepStatus(ctx context.Context, stepEx
return fmt.Errorf("failed to update step status: %w", err)
}
+ // Keep the in-memory copy aligned with the persisted status so downstream
+ // evidence labeling reflects the new transition state without another read.
+ stepExecution.Status = request.Status
+
// If transitioning to completed, process the completion
if request.Status == StatusCompleted.String() {
// Store submitted evidence
- if err := s.storeStepEvidence(stepExecutionID, request.Evidence, request.UserID); err != nil {
+ if err := s.storeStepEvidence(ctx, stepExecution, stepDef, workflowExecution, request.Evidence, request.UserID); err != nil {
return fmt.Errorf("failed to store evidence: %w", err)
}
@@ -250,19 +254,26 @@ func (s *StepTransitionService) validateEvidenceRequirements(stepDef *workflows.
// storeStepEvidence stores the submitted evidence for a step execution as relational.Evidence
// with BackMatter resources for uploaded files and proper labels for the workflow execution stream
-func (s *StepTransitionService) storeStepEvidence(stepExecutionID *uuid.UUID, evidenceSubmissions []EvidenceSubmission, completedBy string) error {
+func (s *StepTransitionService) storeStepEvidence(
+ ctx context.Context,
+ stepExecution *workflows.StepExecution,
+ stepDef *workflows.WorkflowStepDefinition,
+ workflowExecution *workflows.WorkflowExecution,
+ evidenceSubmissions []EvidenceSubmission,
+ completedBy string,
+) error {
if len(evidenceSubmissions) == 0 {
return nil
}
// Gather all workflow context needed for evidence creation
- ctx, err := s.gatherWorkflowContext(stepExecutionID)
+ workflowCtx, err := s.gatherWorkflowContext(stepExecution, stepDef, workflowExecution)
if err != nil {
return err
}
// Get or create the execution evidence stream
- stream, err := s.evidenceIntegration.GetOrCreateExecutionStream(context.Background(), ctx.workflowExecution.ID)
+ stream, err := s.evidenceIntegration.GetOrCreateExecutionStream(ctx, workflowCtx.workflowExecution.ID)
if err != nil {
return fmt.Errorf("failed to get or create execution stream: %w", err)
}
@@ -271,7 +282,7 @@ func (s *StepTransitionService) storeStepEvidence(stepExecutionID *uuid.UUID, ev
backMatter, evidenceLinks := s.buildBackMatterFromSubmissions(evidenceSubmissions)
// Create the evidence record
- evidence := s.createEvidenceRecord(ctx, stream, backMatter, evidenceLinks, len(evidenceSubmissions))
+ evidence := s.createEvidenceRecord(workflowCtx, stream, backMatter, evidenceLinks, len(evidenceSubmissions))
// Save evidence to database
if err := s.db.Create(evidence).Error; err != nil {
@@ -279,7 +290,7 @@ func (s *StepTransitionService) storeStepEvidence(stepExecutionID *uuid.UUID, ev
}
// Build and attach labels
- labels := s.buildEvidenceLabels(ctx, completedBy, len(evidenceSubmissions))
+ labels := s.buildEvidenceLabels(workflowCtx, completedBy, len(evidenceSubmissions))
if err := s.db.Model(evidence).Association("Labels").Append(labels); err != nil {
return fmt.Errorf("failed to add labels to evidence: %w", err)
}
@@ -297,20 +308,19 @@ type workflowContext struct {
}
// gatherWorkflowContext retrieves all workflow entities needed for evidence creation
-func (s *StepTransitionService) gatherWorkflowContext(stepExecutionID *uuid.UUID) (*workflowContext, error) {
- stepExecution, err := s.stepExecutionService.GetByID(stepExecutionID)
- if err != nil {
- return nil, fmt.Errorf("failed to get step execution: %w", err)
+func (s *StepTransitionService) gatherWorkflowContext(
+ stepExecution *workflows.StepExecution,
+ stepDef *workflows.WorkflowStepDefinition,
+ workflowExecution *workflows.WorkflowExecution,
+) (*workflowContext, error) {
+ if stepExecution == nil {
+ return nil, fmt.Errorf("failed to get step execution: step execution is nil")
}
-
- stepDef, err := s.stepDefinitionService.GetByID(stepExecution.WorkflowStepDefinitionID)
- if err != nil {
- return nil, fmt.Errorf("failed to get step definition: %w", err)
+ if stepDef == nil {
+ return nil, fmt.Errorf("failed to get step definition: step definition is nil")
}
-
- workflowExecution, err := s.workflowExecutionService.GetByID(stepExecution.WorkflowExecutionID)
- if err != nil {
- return nil, fmt.Errorf("failed to get workflow execution: %w", err)
+ if workflowExecution == nil {
+ return nil, fmt.Errorf("failed to get workflow execution: workflow execution is nil")
}
instance, err := s.workflowInstanceService.GetByID(workflowExecution.WorkflowInstanceID)
@@ -499,30 +509,59 @@ func (s *StepTransitionService) GetStepExecutionService() *workflows.StepExecuti
// CanUserTransitionStep checks if a user can transition a specific step
func (s *StepTransitionService) CanUserTransitionStep(stepExecutionID *uuid.UUID, userID, userType string) (bool, error) {
- // Get the step execution
- stepExecution, err := s.stepExecutionService.GetByID(stepExecutionID)
- if err != nil {
- return false, fmt.Errorf("failed to get step execution: %w", err)
- }
+ if s.db == nil {
+ // Fallback for contexts that construct the service without a DB handle.
+ // This preserves legacy behavior while primary code path uses a single query.
+ stepExecution, err := s.stepExecutionService.GetByID(stepExecutionID)
+ if err != nil {
+ return false, fmt.Errorf("failed to get step execution: %w", err)
+ }
- // Get the step definition
- stepDef, err := s.getStepDefinition(stepExecution.WorkflowStepDefinitionID)
- if err != nil {
- return false, fmt.Errorf("failed to get step definition: %w", err)
- }
+ stepDef, err := s.getStepDefinition(stepExecution.WorkflowStepDefinitionID)
+ if err != nil {
+ return false, fmt.Errorf("failed to get step definition: %w", err)
+ }
- // Get the workflow execution
- workflowExecution, err := s.workflowExecutionService.GetByID(stepExecution.WorkflowExecutionID)
- if err != nil {
- return false, fmt.Errorf("failed to get workflow execution: %w", err)
- }
+ workflowExecution, err := s.workflowExecutionService.GetByID(stepExecution.WorkflowExecutionID)
+ if err != nil {
+ return false, fmt.Errorf("failed to get workflow execution: %w", err)
+ }
- // Verify user permission
- if err := s.verifyUserPermission(workflowExecution.WorkflowInstanceID, stepDef.ResponsibleRole, userID, userType); err != nil {
- return false, nil // No error, just not permitted
+ if err := s.verifyUserPermission(workflowExecution.WorkflowInstanceID, stepDef.ResponsibleRole, userID, userType); err != nil {
+ return false, nil
+ }
+ return true, nil
+ }
+
+ type canTransitionRow struct {
+ StepExecutionID uuid.UUID
+ MatchCount int64
+ }
+ var row canTransitionRow
+
+ err := s.db.Table("step_executions se").
+ Select("se.id AS step_execution_id, COUNT(ra.id) AS match_count").
+ Joins("JOIN workflow_step_definitions wsd ON wsd.id = se.workflow_step_definition_id").
+ Joins("JOIN workflow_executions we ON we.id = se.workflow_execution_id").
+ Joins(
+ `LEFT JOIN role_assignments ra
+ ON ra.workflow_instance_id = we.workflow_instance_id
+ AND ra.role_name = wsd.responsible_role
+ AND ra.assigned_to_type = ?
+ AND ra.assigned_to_id = ?
+ AND ra.is_active = ?`,
+ userType,
+ userID,
+ true,
+ ).
+ Where("se.id = ?", stepExecutionID).
+ Group("se.id").
+ Take(&row).Error
+ if err != nil {
+ return false, fmt.Errorf("failed to check transition permission: %w", err)
}
- return true, nil
+ return row.MatchCount > 0, nil
}
// GetEvidenceRequirements returns the evidence requirements for a step
diff --git a/internal/workflow/step_transition_test.go b/internal/workflow/step_transition_test.go
new file mode 100644
index 00000000..87497718
--- /dev/null
+++ b/internal/workflow/step_transition_test.go
@@ -0,0 +1,182 @@
+package workflow
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/compliance-framework/api/internal/service/relational"
+ "github.com/compliance-framework/api/internal/service/relational/workflows"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/zap"
+ "gorm.io/driver/sqlite"
+ "gorm.io/gorm"
+)
+
+func setupStepTransitionTestDB(t *testing.T) *gorm.DB {
+ t.Helper()
+
+ db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
+ require.NoError(t, err)
+
+ require.NoError(t, db.AutoMigrate(&relational.User{}))
+ for _, entity := range workflows.GetWorkflowEntities() {
+ require.NoError(t, db.AutoMigrate(entity))
+ }
+
+ return db
+}
+
+func TestCanUserTransitionStep_QueryPath(t *testing.T) {
+ db := setupStepTransitionTestDB(t)
+
+ stepExecService := workflows.NewStepExecutionService(db, nil)
+ stepDefService := workflows.NewWorkflowStepDefinitionService(db)
+ workflowExecService := workflows.NewWorkflowExecutionService(db)
+ roleAssignmentService := workflows.NewRoleAssignmentService(db)
+ workflowInstanceService := workflows.NewWorkflowInstanceService(db)
+ workflowDefinitionService := workflows.NewWorkflowDefinitionService(db)
+
+ svc := NewStepTransitionService(
+ stepExecService,
+ stepDefService,
+ workflowExecService,
+ roleAssignmentService,
+ workflowInstanceService,
+ workflowDefinitionService,
+ nil,
+ db,
+ nil,
+ )
+
+ workflowDef := &workflows.WorkflowDefinition{Name: "WF", Version: "1.0"}
+ require.NoError(t, db.Create(workflowDef).Error)
+ sysID := uuid.New()
+ instance := &workflows.WorkflowInstance{
+ WorkflowDefinitionID: workflowDef.ID,
+ SystemSecurityPlanID: &sysID,
+ Name: "instance",
+ }
+ require.NoError(t, db.Create(instance).Error)
+ execution := &workflows.WorkflowExecution{
+ WorkflowInstanceID: instance.ID,
+ Status: workflows.WorkflowStatusInProgress.String(),
+ TriggeredBy: workflows.TriggerManual.String(),
+ }
+ require.NoError(t, db.Create(execution).Error)
+ stepDef := &workflows.WorkflowStepDefinition{
+ WorkflowDefinitionID: workflowDef.ID,
+ Name: "Step 1",
+ ResponsibleRole: "engineer",
+ }
+ require.NoError(t, db.Create(stepDef).Error)
+ stepExec := &workflows.StepExecution{
+ WorkflowExecutionID: execution.ID,
+ WorkflowStepDefinitionID: stepDef.ID,
+ Status: workflows.StepStatusPending.String(),
+ }
+ require.NoError(t, db.Create(stepExec).Error)
+
+ require.NoError(t, db.Create(&workflows.RoleAssignment{
+ WorkflowInstanceID: instance.ID,
+ RoleName: "engineer",
+ AssignedToType: workflows.AssignmentTypeUser.String(),
+ AssignedToID: "user-1",
+ IsActive: true,
+ }).Error)
+
+ can, err := svc.CanUserTransitionStep(stepExec.ID, "user-1", workflows.AssignmentTypeUser.String())
+ require.NoError(t, err)
+ assert.True(t, can)
+
+ can, err = svc.CanUserTransitionStep(stepExec.ID, "user-2", workflows.AssignmentTypeUser.String())
+ require.NoError(t, err)
+ assert.False(t, can)
+
+ missingID := uuid.New()
+ _, err = svc.CanUserTransitionStep(&missingID, "user-1", workflows.AssignmentTypeUser.String())
+ require.Error(t, err)
+}
+
+func TestCanUserTransitionStep_FallbackWhenDBNil(t *testing.T) {
+ stepExecID := uuid.New()
+ stepDefID := uuid.New()
+ execID := uuid.New()
+ instanceID := uuid.New()
+
+ mockStepExec := &MockStepExecutionService{}
+ mockStepDef := &MockWorkflowStepDefinitionService{}
+ mockWorkflowExec := &MockWorkflowExecutionService{}
+ mockRole := &MockRoleAssignmentService{}
+
+ svc := NewStepTransitionService(
+ mockStepExec,
+ mockStepDef,
+ mockWorkflowExec,
+ mockRole,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ mockStepExec.On("GetByID", &stepExecID).Return(&workflows.StepExecution{
+ UUIDModel: relational.UUIDModel{ID: &stepExecID},
+ WorkflowStepDefinitionID: &stepDefID,
+ WorkflowExecutionID: &execID,
+ }, nil).Once()
+ mockStepDef.On("GetByID", &stepDefID).Return(&workflows.WorkflowStepDefinition{
+ UUIDModel: relational.UUIDModel{ID: &stepDefID},
+ ResponsibleRole: "engineer",
+ }, nil).Once()
+ mockWorkflowExec.On("GetByID", &execID).Return(&workflows.WorkflowExecution{
+ UUIDModel: relational.UUIDModel{ID: &execID},
+ WorkflowInstanceID: &instanceID,
+ }, nil).Once()
+ mockRole.On("FindAssigneeForRole", &instanceID, "engineer").Return(&workflows.RoleAssignment{
+ AssignedToType: workflows.AssignmentTypeUser.String(),
+ AssignedToID: "user-1",
+ IsActive: true,
+ }, nil).Once()
+
+ can, err := svc.CanUserTransitionStep(&stepExecID, "user-1", workflows.AssignmentTypeUser.String())
+ require.NoError(t, err)
+ assert.True(t, can)
+
+ mockStepExec.AssertExpectations(t)
+ mockStepDef.AssertExpectations(t)
+ mockWorkflowExec.AssertExpectations(t)
+ mockRole.AssertExpectations(t)
+}
+
+type mockStepAssignmentService struct {
+ called bool
+}
+
+func (m *mockStepAssignmentService) ReassignWithTx(tx *gorm.DB, id *uuid.UUID, assignedToType, assignedToID string, assignedAt time.Time) error {
+ m.called = true
+ return nil
+}
+
+func TestAssignmentService_ReassignStep_UsesStepExecutionService(t *testing.T) {
+ db := setupAssignmentServiceTestDB(t)
+ roleService := new(MockRoleAssignmentService)
+ mockStepService := &mockStepAssignmentService{}
+ service := NewAssignmentService(roleService, mockStepService, db, zap.NewNop().Sugar(), nil)
+
+ _, _, stepExec := createAssignmentServiceGraph(t, db)
+
+ err := service.ReassignStep(
+ context.Background(),
+ *stepExec.ID,
+ Assignee{Type: "group", ID: "new-group"},
+ "route through service",
+ nil,
+ "actor@example.com",
+ )
+ require.NoError(t, err)
+ assert.True(t, mockStepService.called)
+}
diff --git a/internal/workflow/validator.go b/internal/workflow/validator.go
deleted file mode 100644
index d755af0a..00000000
--- a/internal/workflow/validator.go
+++ /dev/null
@@ -1,382 +0,0 @@
-package workflow
-
-import (
- "errors"
- "fmt"
-
- "github.com/compliance-framework/api/internal/service/relational/workflows"
- "github.com/google/uuid"
-)
-
-// DAGValidator provides comprehensive DAG validation functionality
-type DAGValidator struct {
- stepService *workflows.WorkflowStepDefinitionService
-}
-
-// NewDAGValidator creates a new DAG validator
-func NewDAGValidator(stepService *workflows.WorkflowStepDefinitionService) *DAGValidator {
- return &DAGValidator{
- stepService: stepService,
- }
-}
-
-// ValidateDAG performs comprehensive validation of a workflow DAG
-func (dv *DAGValidator) ValidateDAG(workflowDefID *uuid.UUID) (*DAGValidationResult, error) {
- result := &DAGValidationResult{
- IsValid: true,
- Errors: []string{},
- Warnings: []string{},
- }
-
- // Get all steps in the workflow
- steps, err := dv.stepService.GetByWorkflowDefinitionID(workflowDefID)
- if err != nil {
- return nil, fmt.Errorf("failed to retrieve workflow steps: %w", err)
- }
-
- if err := validateStepsNotEmpty(steps); err != nil {
- result.IsValid = false
- result.Errors = append(result.Errors, err.Error())
- return result, nil
- }
-
- // Build dependency map once for efficiency
- depMap, err := dv.getDependencyMap(steps)
- if err != nil {
- return nil, err
- }
-
- // 1. Check for cycles
- if err := dv.checkCycles(steps, depMap, result); err != nil {
- return nil, err
- }
-
- // 2. Check for orphaned steps
- if err := dv.checkOrphanedSteps(steps, depMap, result); err != nil {
- return nil, err
- }
-
- // 3. Validate dependencies
- if err := dv.checkDependencyIssues(steps, depMap, result); err != nil {
- return nil, err
- }
-
- return result, nil
-}
-
-// getDependencyMap builds a dependency map for efficient lookups
-func (dv *DAGValidator) getDependencyMap(steps []workflows.WorkflowStepDefinition) (map[string][]workflows.WorkflowStepDefinition, error) {
- depMap := make(map[string][]workflows.WorkflowStepDefinition)
-
- for _, step := range steps {
- dependencies, err := dv.stepService.GetDependencies(step.ID)
- if err != nil {
- return nil, fmt.Errorf("failed to get dependencies for step %s: %w", step.ID.String(), err)
- }
-
- depMap[step.ID.String()] = dependencies
- }
-
- return depMap, nil
-}
-
-// checkCycles detects cycles and updates the validation result
-func (dv *DAGValidator) checkCycles(steps []workflows.WorkflowStepDefinition, depMap map[string][]workflows.WorkflowStepDefinition, result *DAGValidationResult) error {
- cycles, err := dv.detectAllCycles(steps, depMap)
- if err != nil {
- return fmt.Errorf("cycle detection failed: %w", err)
- }
-
- if len(cycles) > 0 {
- result.IsValid = false
- result.Cycles = cycles
- for _, cycle := range cycles {
- result.Errors = append(result.Errors, fmt.Sprintf("circular dependency detected: %s", cycle.Path))
- }
- }
-
- return nil
-}
-
-// detectAllCycles finds all cycles in the workflow DAG
-func (dv *DAGValidator) detectAllCycles(steps []workflows.WorkflowStepDefinition, depMap map[string][]workflows.WorkflowStepDefinition) ([]DAGCycle, error) {
- var cycles []DAGCycle
- visited := make(map[string]bool)
- recursionStack := make(map[string]bool)
- path := make(map[string]string)
-
- for _, step := range steps {
- if !visited[step.ID.String()] {
- cycle, err := dv.dfsCycleDetection(step.ID, depMap, visited, recursionStack, path)
- if err != nil {
- return nil, err
- }
- if cycle != nil {
- cycles = append(cycles, *cycle)
- }
- }
- }
-
- return cycles, nil
-}
-
-// dfsCycleDetection performs DFS to detect cycles using pre-built dependency map
-func (dv *DAGValidator) dfsCycleDetection(
- currentStepID *uuid.UUID,
- depMap map[string][]workflows.WorkflowStepDefinition,
- visited, recursionStack map[string]bool,
- path map[string]string,
-) (*DAGCycle, error) {
- currentID := currentStepID.String()
- visited[currentID] = true
- recursionStack[currentID] = true
-
- // Get dependencies from pre-built map (O(1) lookup)
- dependencies := depMap[currentID]
-
- for _, dep := range dependencies {
- depID := dep.ID.String()
-
- // Cycle detected
- if recursionStack[depID] {
- return dv.createCycleResult(currentID, depID, path), nil
- }
-
- // Recurse if not visited
- if !visited[depID] {
- path[currentID] = depID
- cycle, err := dv.dfsCycleDetection(dep.ID, depMap, visited, recursionStack, path)
- if err != nil {
- return nil, err
- }
- if cycle != nil {
- return cycle, nil
- }
- delete(path, currentID)
- }
- }
-
- recursionStack[currentID] = false
- return nil, nil
-}
-
-// createCycleResult builds a DAGCycle result from detected cycle
-func (dv *DAGValidator) createCycleResult(fromID, toID string, path map[string]string) *DAGCycle {
- cyclePath := dv.buildCyclePath(fromID, toID, path)
- cycle := DAGCycle{
- Steps: []string{toID, fromID},
- Path: cyclePath,
- }
-
- // Add all steps in the cycle path
- for stepID := range path {
- if path[stepID] != "" {
- cycle.Steps = append(cycle.Steps, stepID)
- }
- }
-
- return &cycle
-}
-
-// buildCyclePath creates a human-readable cycle path
-func (dv *DAGValidator) buildCyclePath(fromID, toID string, path map[string]string) string {
- cyclePath := toID
- current := fromID
-
- for current != "" && current != toID {
- cyclePath += " -> " + current
- current = path[current]
- if current == toID {
- cyclePath += " -> " + toID
- break
- }
- }
-
- return cyclePath
-}
-
-// checkOrphanedSteps detects orphaned steps and updates the validation result
-func (dv *DAGValidator) checkOrphanedSteps(steps []workflows.WorkflowStepDefinition, depMap map[string][]workflows.WorkflowStepDefinition, result *DAGValidationResult) error {
- orphaned, err := dv.detectOrphanedSteps(steps, depMap)
- if err != nil {
- return fmt.Errorf("orphaned step detection failed: %w", err)
- }
-
- if len(orphaned) > 0 {
- result.OrphanedSteps = orphaned
- for _, orphan := range orphaned {
- result.Warnings = append(result.Warnings, fmt.Sprintf("orphaned step '%s': %s", orphan.StepName, orphan.Reason))
- }
- }
-
- return nil
-}
-
-// detectOrphanedSteps finds steps with no dependencies or no dependents
-func (dv *DAGValidator) detectOrphanedSteps(steps []workflows.WorkflowStepDefinition, depMap map[string][]workflows.WorkflowStepDefinition) ([]OrphanedStep, error) {
- var orphaned []OrphanedStep
-
- // Use helper to count dependents efficiently
- dependents := countDependents(steps, depMap)
-
- // Find orphaned steps
- for _, step := range steps {
- stepID := step.ID.String()
- dependencies := depMap[stepID]
-
- // Check if step has no dependencies (isolated starting point)
- if len(dependencies) == 0 && len(steps) > 1 {
- orphaned = append(orphaned, OrphanedStep{
- StepID: *step.ID,
- StepName: step.Name,
- Reason: OrphanReasonNoDependencies,
- })
- }
-
- // Check if step has no dependents (dead-end step)
- if dependents[stepID] == 0 && len(steps) > 1 && len(dependencies) > 0 {
- orphaned = append(orphaned, OrphanedStep{
- StepID: *step.ID,
- StepName: step.Name,
- Reason: OrphanReasonNoDependents,
- })
- }
- }
-
- return orphaned, nil
-}
-
-// countDependents counts the number of dependents for each step
-func countDependents(steps []workflows.WorkflowStepDefinition, depMap map[string][]workflows.WorkflowStepDefinition) map[string]int {
- dependents := make(map[string]int)
-
- for _, step := range steps {
- stepID := step.ID.String()
- dependencies := depMap[stepID]
-
- for _, dep := range dependencies {
- dependents[dep.ID.String()]++
- }
- }
-
- return dependents
-}
-
-// checkDependencyIssues validates dependencies and updates the validation result
-func (dv *DAGValidator) checkDependencyIssues(steps []workflows.WorkflowStepDefinition, depMap map[string][]workflows.WorkflowStepDefinition, result *DAGValidationResult) error {
- issues, err := dv.validateDependencies(steps, depMap)
- if err != nil {
- return fmt.Errorf("dependency validation failed: %w", err)
- }
-
- if len(issues) > 0 {
- result.Dependencies = issues
- for _, issue := range issues {
- result.Errors = append(result.Errors, fmt.Sprintf("dependency issue: %s", issue.Issue))
- }
- result.IsValid = false
- }
-
- return nil
-}
-
-// validateDependencies performs comprehensive dependency validation
-func (dv *DAGValidator) validateDependencies(steps []workflows.WorkflowStepDefinition, depMap map[string][]workflows.WorkflowStepDefinition) ([]DependencyIssue, error) {
- var issues []DependencyIssue
- stepIndex := buildStepIndex(steps)
-
- for _, step := range steps {
- stepID := step.ID.String()
- dependencies := depMap[stepID]
-
- for _, dep := range dependencies {
- depID := dep.ID.String()
-
- // Check for self-dependency
- if stepID == depID {
- issues = append(issues, DependencyIssue{
- FromStepID: *step.ID,
- FromStepName: step.Name,
- ToStepID: *dep.ID,
- ToStepName: step.Name,
- Issue: fmt.Sprintf("step '%s' depends on itself", step.Name),
- })
- continue
- }
-
- // Check if dependency step exists in the same workflow (O(1) lookup)
- if _, err := stepIndex.findStep(dep.ID); err != nil {
- issues = append(issues, DependencyIssue{
- FromStepID: *step.ID,
- FromStepName: step.Name,
- ToStepID: *dep.ID,
- ToStepName: "unknown",
- Issue: fmt.Sprintf("step '%s' depends on non-existent step '%s'", step.Name, depID),
- })
- }
- }
- }
-
- return issues, nil
-}
-
-// ValidateDependencyChange validates adding/removing a specific dependency
-func (dv *DAGValidator) ValidateDependencyChange(workflowDefID, stepID, dependsOnStepID *uuid.UUID, isAdd bool) error {
- if isAdd {
- return dv.validateDependencyAddition(stepID, dependsOnStepID)
- }
- return dv.validateDependencyRemoval(stepID, dependsOnStepID)
-}
-
-// validateDependencyAddition validates adding a new dependency
-func (dv *DAGValidator) validateDependencyAddition(stepID, dependsOnStepID *uuid.UUID) error {
- // Check if it would create a cycle
- hasCycle, err := dv.stepService.HasCircularDependency(stepID, dependsOnStepID)
- if err != nil {
- return fmt.Errorf("cycle detection failed: %w", err)
- }
-
- if hasCycle {
- return errors.New("adding this dependency would create a circular reference")
- }
-
- // Check if dependency already exists
- return dv.checkDependencyNotExists(stepID, dependsOnStepID)
-}
-
-// validateDependencyRemoval validates removing an existing dependency
-func (dv *DAGValidator) validateDependencyRemoval(stepID, dependsOnStepID *uuid.UUID) error {
- return dv.checkDependencyExists(stepID, dependsOnStepID)
-}
-
-// checkDependencyExists verifies that a dependency exists
-func (dv *DAGValidator) checkDependencyExists(stepID, dependsOnStepID *uuid.UUID) error {
- dependencies, err := dv.stepService.GetDependencies(stepID)
- if err != nil {
- return fmt.Errorf("failed to get current dependencies: %w", err)
- }
-
- for _, dep := range dependencies {
- if dep.ID.String() == dependsOnStepID.String() {
- return nil
- }
- }
-
- return errors.New("dependency does not exist")
-}
-
-// checkDependencyNotExists verifies that a dependency does not already exist
-func (dv *DAGValidator) checkDependencyNotExists(stepID, dependsOnStepID *uuid.UUID) error {
- dependencies, err := dv.stepService.GetDependencies(stepID)
- if err != nil {
- return fmt.Errorf("failed to get current dependencies: %w", err)
- }
-
- for _, dep := range dependencies {
- if dep.ID.String() == dependsOnStepID.String() {
- return errors.New("dependency already exists")
- }
- }
-
- return nil
-}
diff --git a/sdk/integration_base_test.go b/sdk/integration_base_test.go
index 65aca568..424364ef 100644
--- a/sdk/integration_base_test.go
+++ b/sdk/integration_base_test.go
@@ -98,7 +98,7 @@ func (suite *IntegrationBaseTestSuite) SetupSuite() {
logger, _ := zap.NewDevelopment()
metrics := api.NewMetricsHandler(context.Background(), logger.Sugar())
server := api.NewServer(context.Background(), logger.Sugar(), cfg, metrics)
- handler.RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil)
+ handler.RegisterHandlers(server, logger.Sugar(), suite.DB, suite.Config, nil, nil, nil, nil)
suite.Server = server
diff --git a/workflow.yaml b/workflow.yaml
index cd91a32a..2e478131 100644
--- a/workflow.yaml
+++ b/workflow.yaml
@@ -1 +1,5 @@
-scheduler_enabled: true
\ No newline at end of file
+scheduler_enabled: true
+due_soon_enabled: true
+due_soon_schedule: "0 * * * * *"
+task_digest_enabled: true
+task_digest_schedule: "0 * * * * *"
\ No newline at end of file