diff --git a/main.go b/main.go index cfffede..dd836d7 100644 --- a/main.go +++ b/main.go @@ -383,11 +383,45 @@ func distributeSQS(app *appctx, runID string, tagFilters []string, metadata map[ return true } +type cancelledEntry struct { + expiresAt time.Time +} + type appctx struct { - pub *lspubsub.PubsubPublisher // starter publisher topic - rpub *lspubsub.PubsubPublisher // topic to publish reports - mtx *sync.Mutex - topicArn *string + pub *lspubsub.PubsubPublisher // starter publisher topic + rpub *lspubsub.PubsubPublisher // topic to publish reports + mtx *sync.Mutex + topicArn *string + cancelledMtx sync.RWMutex + cancelledRuns map[string]cancelledEntry // run_ids that have been cancelled, with expiry +} + +const cancelledRunTTL = 10 * time.Minute + +func (a *appctx) cancelRun(runID string) { + a.cancelledMtx.Lock() + defer a.cancelledMtx.Unlock() + a.cancelledRuns[runID] = cancelledEntry{expiresAt: time.Now().Add(cancelledRunTTL)} + log.Printf("cancelRun: run_id=%s will expire at %s", runID, a.cancelledRuns[runID].expiresAt.Format(time.RFC3339)) +} + +func (a *appctx) isRunCancelled(runID string) bool { + if runID == "" { + return false + } + a.cancelledMtx.RLock() + entry, ok := a.cancelledRuns[runID] + a.cancelledMtx.RUnlock() + if !ok { + return false + } + if time.Now().After(entry.expiresAt) { + a.cancelledMtx.Lock() + delete(a.cancelledRuns, runID) + a.cancelledMtx.Unlock() + return false + } + return true } // Our message processing callback. @@ -484,6 +518,10 @@ func process(ctx any, data []byte) error { } case "process": log.Printf("process: %+v", c) + if app.isRunCancelled(c.ID) { + log.Printf("process: run_id=%s is cancelled, skipping scenario %s", c.ID, c.Scenario) + return nil + } doScenario(&doScenarioInput{ app: app, ScenarioFiles: []string{c.Scenario}, @@ -558,7 +596,50 @@ func handleScenarioCompletion(ctx any, data []byte) error { } } - case "completed": + case "cancelled": + log.Printf("run cancelled: run_id=%s repo=%s sha=%s pr=%s", + msg.RunID, msg.Repository, msg.CommitSHA, msg.PRNumber) + if app, ok := ctx.(*appctx); ok && app != nil && msg.RunID != "" { + app.cancelRun(msg.RunID) + log.Printf("cancelled: run_id=%s marked as cancelled in-process, pending scenarios will be skipped", msg.RunID) + } + + if msg.CommitSHA == "" || msg.Repository == "" { + log.Printf("cancelled: missing commit_sha or repository, skipping github status update") + return nil + } + + if err := postCommitStatus( + githubtoken, + msg.CommitSHA, + msg.Repository, + msg.RunURL, + "failure", + fmt.Sprintf("Test run cancelled — PR #%s was closed", msg.PRNumber), + ); err != nil { + log.Printf("postCommitStatus (cancelled) failed: %v", err) + } + if repslack != "" { + payload := SlackMessage{ + Attachments: []SlackAttachment{ + { + Color: "warning", + Title: "Test Run Cancelled", + Text: fmt.Sprintf("*PR #%s* in `%s` was closed.\nIn-progress test run `%s` has been cancelled.\n<%s|View workflow>", + msg.PRNumber, msg.Repository, msg.RunID, msg.RunURL), + Footer: fmt.Sprintf("oops • pr: %s • sha: %.7s", msg.PRNumber, msg.CommitSHA), + Timestamp: time.Now().Unix(), + MrkdwnIn: []string{"text"}, + }, + }, + } + + if err := payload.Notify(repslack); err != nil { + log.Printf("Notify (slack) cancelled failed: %v", err) + } + } + + case "completed": log.Printf("run completed: run_id=%s overall_status=%s failed=%d repo=%s sha=%s", msg.RunID, msg.OverallStatus, msg.FailedCount, msg.Repository, msg.CommitSHA) @@ -657,7 +738,8 @@ func run(ctx context.Context, done chan error) { } app := &appctx{ - mtx: &sync.Mutex{}, + mtx: &sync.Mutex{}, + cancelledRuns: make(map[string]cancelledEntry), } ctx0, cancelCtx0 := context.WithCancel(ctx) defer cancelCtx0() @@ -754,7 +836,7 @@ func run(ctx context.Context, done chan error) { done1 := make(chan error, 1) go func() { - ls := lspubsub.NewLengthySubscriber(nil, project, scenariopubsub, handleScenarioCompletion) + ls := lspubsub.NewLengthySubscriber(app, project, scenariopubsub, handleScenarioCompletion) err := ls.Start(ctx0, done1) if err != nil { log.Fatalf("listener for scenario progress failed: %v", err)