diff --git a/main.go b/main.go index cfffede..39640f9 100644 --- a/main.go +++ b/main.go @@ -88,6 +88,7 @@ type ScenarioProgressMessage struct { Code string `json:"code"` OverallStatus string `json:"overall_status,omitempty"` FailedCount int64 `json:"failed_count,omitempty"` + CancelledCount int64 `json:"cancelled_count,omitempty"` FailedScenarios []string `json:"failed_scenarios,omitempty"` CommitSHA string `json:"commit_sha,omitempty"` Repository string `json:"repository,omitempty"` @@ -388,8 +389,93 @@ type appctx struct { rpub *lspubsub.PubsubPublisher // topic to publish reports mtx *sync.Mutex topicArn *string + + activeRunsMu sync.RWMutex + activeRuns map[string]map[string]context.CancelFunc + pendingCancelsMu sync.RWMutex + pendingCancels map[string]struct{} +} + +func (a *appctx) registerRun(commitSHA string, cancel context.CancelFunc) string { + if commitSHA == "" { + return "" + } + instanceID := uniuri.NewLen(12) + a.activeRunsMu.Lock() + defer a.activeRunsMu.Unlock() + if a.activeRuns == nil { + a.activeRuns = make(map[string]map[string]context.CancelFunc) + } + if a.activeRuns[commitSHA] == nil { + a.activeRuns[commitSHA] = make(map[string]context.CancelFunc) + } + a.activeRuns[commitSHA][instanceID] = cancel + return instanceID +} + +func (a *appctx) unregisterRun(commitSHA, instanceID string) { + if commitSHA == "" || instanceID == "" { + return + } + a.activeRunsMu.Lock() + defer a.activeRunsMu.Unlock() + delete(a.activeRuns[commitSHA], instanceID) + if len(a.activeRuns[commitSHA]) == 0 { + delete(a.activeRuns, commitSHA) + } } +// cancelRun cancels every in-flight scenario for commitSHA and returns true +// if at least one was found. +func (a *appctx) cancelRun(commitSHA string) bool { + if commitSHA == "" { + return false + } + a.activeRunsMu.RLock() + funcs := a.activeRuns[commitSHA] + a.activeRunsMu.RUnlock() + if len(funcs) == 0 { + return false + } + for _, cancel := range funcs { + cancel() + } + log.Printf("cancelRun: commit_sha=%s cancelled %d in-flight scenario(s)", commitSHA, len(funcs)) + return true +} + +func (a *appctx) markCancelled(commitSHA string) { + if commitSHA == "" { + return + } + a.pendingCancelsMu.Lock() + defer a.pendingCancelsMu.Unlock() + if a.pendingCancels == nil { + a.pendingCancels = make(map[string]struct{}) + } + a.pendingCancels[commitSHA] = struct{}{} + log.Printf("markCancelled: commit_sha=%s tombstoned", commitSHA) +} + +func (a *appctx) isCancelled(commitSHA string) bool { + if commitSHA == "" { + return false + } + a.pendingCancelsMu.RLock() + defer a.pendingCancelsMu.RUnlock() + _, ok := a.pendingCancels[commitSHA] + return ok +} + +func (a *appctx) unmarkCancelled(commitSHA string) { + if commitSHA == "" { + return + } + a.pendingCancelsMu.Lock() + defer a.pendingCancelsMu.Unlock() + delete(a.pendingCancels, commitSHA) + log.Printf("unmarkCancelled: commit_sha=%s tombstone cleared", commitSHA) +} // Our message processing callback. func process(ctx any, data []byte) error { app := ctx.(*appctx) @@ -484,7 +570,29 @@ func process(ctx any, data []byte) error { } case "process": log.Printf("process: %+v", c) - doScenario(&doScenarioInput{ + commitSHA, _ := c.Metadata["commit_sha"].(string) + if commitSHA != "" && app.isCancelled(commitSHA) { + log.Printf("process: commit_sha=%s is tombstoned, publishing cancelled result for %s", commitSHA, c.Scenario) + in := &doScenarioInput{ + app: app, + ScenarioFiles: []string{c.Scenario}, + ReportPubsub: reppubsub, + Metadata: c.Metadata, + RunID: c.ID, + } + publishCancelledResult(app, c.Scenario, in) + return nil + } + + runCtx, runCancel := context.WithCancel(context.Background()) + defer runCancel() + var instanceID string + if commitSHA != "" { + instanceID = app.registerRun(commitSHA, runCancel) + defer app.unregisterRun(commitSHA, instanceID) + } + + in := &doScenarioInput{ app: app, ScenarioFiles: []string{c.Scenario}, ReportSlack: repslack, @@ -492,7 +600,17 @@ func process(ctx any, data []byte) error { Verbose: verbose, Metadata: c.Metadata, RunID: c.ID, - }) + cancelCtx: runCtx, + } + select { + case <-runCtx.Done(): + log.Printf("process: commit_sha=%s cancelled just after register, publishing cancelled result for %s", commitSHA, c.Scenario) + publishCancelledResult(app, c.Scenario, in) + return nil + default: + } + + doScenario(in) } return nil @@ -506,8 +624,73 @@ func handleScenarioCompletion(ctx any, data []byte) error { } log.Printf("scenario progress: run_id=%s code=%s progress=%s", msg.RunID, msg.Code, msg.TotalScenarios) + var app *appctx + if ctx != nil { + app, _ = ctx.(*appctx) + } switch msg.Code { + case "closed": + log.Printf("received closed event: commit_sha=%s repo=%s pr=%s run_id=%s", + msg.CommitSHA, msg.Repository, msg.PRNumber, msg.RunID) + + if msg.CommitSHA == "" { + log.Printf("cancel: missing commit_sha, skipping") + return nil + } + if app != nil { + app.markCancelled(msg.CommitSHA) + } + + cancelled := false + if app != nil { + cancelled = app.cancelRun(msg.CommitSHA) + } + + if cancelled { + log.Printf("cancel: commit_sha=%s cancelled successfully", msg.CommitSHA) + } else { + log.Printf("cancel: commit_sha=%s not found in active runs (may have already finished or not yet registered)", msg.CommitSHA) + } + + if msg.Repository != "" { + if err := postCommitStatus(githubtoken, msg.CommitSHA, msg.Repository, msg.RunURL, "error", "PR closed — test run cancelled"); err != nil { + log.Printf("cancel: postCommitStatus failed: %v", err) + } + } + + if repslack != "" { + env := "dev" + if strings.Contains(pubsub, "prod") { + env = "prod" + } else if strings.Contains(pubsub, "next") { + env = "next" + } + + text := fmt.Sprintf("*Environment:* %s\n*Repository:* %s\n*PR:* #%s\n*Commit:* %s", + env, msg.Repository, msg.PRNumber, msg.CommitSHA) + if msg.RunURL != "" { + text += fmt.Sprintf("\n\n<%s|View run>", msg.RunURL) + } + + payload := SlackMessage{ + Attachments: []SlackAttachment{ + { + Color: "warning", + Title: "Test Run Cancelled (PR Closed)", + Text: text, + Footer: fmt.Sprintf("oops • sha: %s", msg.CommitSHA), + Timestamp: time.Now().Unix(), + MrkdwnIn: []string{"text"}, + }, + }, + } + + if err := payload.Notify(repslack); err != nil { + log.Printf("cancel: Notify (slack) failed: %v", err) + } + } + case "approve": log.Printf("received approve event: repo=%s sha=%s approvals=%d reviewers=%s", msg.Repository, msg.CommitSHA, msg.ApprovalCount, msg.Reviewers) @@ -559,8 +742,16 @@ func handleScenarioCompletion(ctx any, data []byte) error { } case "completed": - log.Printf("run completed: run_id=%s overall_status=%s failed=%d repo=%s sha=%s", - msg.RunID, msg.OverallStatus, msg.FailedCount, msg.Repository, msg.CommitSHA) + log.Printf("run completed: run_id=%s overall_status=%s failed=%d cancelled=%d repo=%s sha=%s", + msg.RunID, msg.OverallStatus, msg.FailedCount, msg.CancelledCount, msg.Repository, msg.CommitSHA) + if app != nil { + app.unmarkCancelled(msg.CommitSHA) + } + + if msg.CancelledCount > 0 { + log.Printf("completed: run_id=%s has %d cancelled scenario(s), skipping dispatch and notifications", msg.RunID, msg.CancelledCount) + return nil + } if err := sendRepositoryDispatch(githubtoken, &msg); err != nil { log.Printf("sendRepositoryDispatch failed: %v", err) @@ -694,8 +885,6 @@ func run(ctx context.Context, done chan error) { } go func() { - // Messages should be payer level. We will subdivide linked accts to separate messages for - // linked-acct-level processing. ls := lspubsub.NewLengthySubscriber(app, project, pubsub, process) err = ls.Start(ctx0, done0) if err != nil { @@ -754,7 +943,7 @@ func run(ctx context.Context, done chan error) { done1 := make(chan error, 1) go func() { - ls := lspubsub.NewLengthySubscriber(nil, project, scenariopubsub, handleScenarioCompletion) + ls := lspubsub.NewLengthySubscriber(app, project, scenariopubsub, handleScenarioCompletion) err := ls.Start(ctx0, done1) if err != nil { log.Fatalf("listener for scenario progress failed: %v", err) @@ -824,4 +1013,4 @@ func main() { log.SetPrefix("[oops] ") log.SetOutput(os.Stdout) rootcmd.Execute() -} +} \ No newline at end of file diff --git a/scenario.go b/scenario.go index 9a42682..f7e3c35 100644 --- a/scenario.go +++ b/scenario.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "log" @@ -72,7 +73,7 @@ func (s Scenario) getHead(file string) ([]byte, error) { } // RunScript runs file and returns the combined stdout+stderr result. -func (s *Scenario) RunScript(file string) ([]byte, error) { +func (s *Scenario) RunScript(ctx context.Context, file string) ([]byte, error) { l1, err := s.getHead(file) if err != nil { return nil, err @@ -89,10 +90,10 @@ func (s *Scenario) RunScript(file string) ([]byte, error) { var c *exec.Cmd switch { case strings.Contains(runner, "python"): - c = exec.Command(runner, file) + c = exec.CommandContext(ctx, runner, file) default: // Assume it's a shell interpreter. - c = exec.Command(runner, "-c", file) + c = exec.CommandContext(ctx, runner, "-c", file) } c.Env = os.Environ() @@ -109,7 +110,7 @@ func (s *Scenario) RunScript(file string) ([]byte, error) { // ParseValue tries to check if contents is in script form and if it is, writes it // to disk as an executable, runs it and returns the resulting stream output. // Otherwise, return the contents as is. -func (s *Scenario) ParseValue(contents string, file ...string) (string, error) { +func (s *Scenario) ParseValue(ctx context.Context, contents string, file ...string) (string, error) { if strings.HasPrefix(contents, "#!") { f := fmt.Sprintf("oops_%v", uuid.NewString()) f = filepath.Join(os.TempDir(), f) @@ -122,7 +123,7 @@ func (s *Scenario) ParseValue(contents string, file ...string) (string, error) { return contents, err } - b, err := s.RunScript(f) + b, err := s.RunScript(ctx, f) return string(b), err } @@ -169,6 +170,7 @@ type doScenarioInput struct { Verbose bool Metadata map[string]interface{} RunID string + cancelCtx context.Context OnScenarioDone func(scenario, status string) } @@ -195,7 +197,22 @@ func isAllowed(s *Scenario) bool { } func doScenario(in *doScenarioInput) error { + scriptCtx := context.Background() + if in.cancelCtx != nil { + scriptCtx = in.cancelCtx + } + for _, f := range in.ScenarioFiles { + if in.cancelCtx != nil { + select { + case <-in.cancelCtx.Done(): + log.Printf("doScenario: cancelled before starting %v, publishing cancelled result", filepath.Base(f)) + publishCancelledResult(in.app, f, in) + continue + default: + } + } + yml, err := os.ReadFile(f) if err != nil { continue @@ -220,7 +237,7 @@ func doScenario(in *doScenarioInput) error { basef := filepath.Base(f) fn := filepath.Join(os.TempDir(), fmt.Sprintf("%v_prepare", basef)) fn, _ = s.WriteScript(fn, s.Prepare) - b, err := s.RunScript(fn) + b, err := s.RunScript(scriptCtx, fn) if err != nil { s.errs = append(s.errs, errors.Wrapf(err, "prepare:\n%v: %v", s.Prepare, string(b))) @@ -237,7 +254,7 @@ func doScenario(in *doScenarioInput) error { // Parse url. fn := fmt.Sprintf("%v_url", prefix) - nv, err := s.ParseValue(run.HTTP.URL, fn) + nv, err := s.ParseValue(scriptCtx, run.HTTP.URL, fn) if err != nil { s.errs = append(s.errs, errors.Wrapf(err, "ParseValue[%v]: %v", i, run.HTTP.URL)) continue @@ -253,7 +270,7 @@ func doScenario(in *doScenarioInput) error { req := e.Request(run.HTTP.Method, u.Path) for k, v := range run.HTTP.Headers { fn := fmt.Sprintf("%v_hdr.%v", prefix, k) - nv, err := s.ParseValue(v, fn) + nv, err := s.ParseValue(scriptCtx, v, fn) if err != nil { s.errs = append(s.errs, errors.Wrapf(err, "ParseValue[%v]: %v", i, v)) continue @@ -265,7 +282,7 @@ func doScenario(in *doScenarioInput) error { for k, v := range run.HTTP.QueryParams { fn := fmt.Sprintf("%v_qparams.%v", prefix, k) - nv, _ := s.ParseValue(v, fn) + nv, _ := s.ParseValue(scriptCtx, v, fn) req = req.WithQuery(k, nv) } @@ -274,19 +291,19 @@ func doScenario(in *doScenarioInput) error { } for k, v := range run.HTTP.Files { fn := fmt.Sprintf("%v_files.%v", prefix, k) - nv, _ := s.ParseValue(v, fn) + nv, _ := s.ParseValue(scriptCtx, v, fn) req = req.WithFile(k, nv) } for k, v := range run.HTTP.Forms { fn := fmt.Sprintf("%v_forms.%v", prefix, k) - nv, _ := s.ParseValue(v, fn) + nv, _ := s.ParseValue(scriptCtx, v, fn) req = req.WithFormField(k, nv) } if run.HTTP.Payload != "" { fn := fmt.Sprintf("%v_payload", prefix) - nv, _ := s.ParseValue(run.HTTP.Payload, fn) + nv, _ := s.ParseValue(scriptCtx, run.HTTP.Payload, fn) req = req.WithBytes([]byte(nv)) } @@ -310,7 +327,7 @@ func doScenario(in *doScenarioInput) error { if run.HTTP.Asserts.Script != "" { fn := fmt.Sprintf("%v_assertscript", prefix) s.WriteScript(fn, run.HTTP.Asserts.Script) - b, err := s.RunScript(fn) + b, err := s.RunScript(scriptCtx, fn) if err != nil { s.errs = append(s.errs, errors.Wrapf(err, "assert.script[%v]:\n%v: %v", i, run.HTTP.Asserts.Script, string(b))) @@ -326,7 +343,7 @@ func doScenario(in *doScenarioInput) error { basef := filepath.Base(f) fn := filepath.Join(os.TempDir(), fmt.Sprintf("%v_check", basef)) fn, _ = s.WriteScript(fn, s.Check) - b, err := s.RunScript(fn) + b, err := s.RunScript(scriptCtx, fn) if err != nil { s.errs = append(s.errs, errors.Wrapf(err, "check:\n%v: %v", s.Check, string(b))) @@ -340,6 +357,15 @@ func doScenario(in *doScenarioInput) error { if len(s.errs) > 0 { log.Printf("errs: %v", s.errs) } + if in.cancelCtx != nil { + select { + case <-in.cancelCtx.Done(): + log.Printf("doScenario: cancelled mid-flight for %v, publishing cancelled result", filepath.Base(f)) + publishCancelledResult(in.app, f, in) + continue + default: + } + } if in.ReportSlack != "" { if len(s.errs) > 0 { @@ -447,3 +473,59 @@ func doScenario(in *doScenarioInput) error { return nil } + +// publishCancelledResult publishes a "cancelled" status to oopsdev-report for a +// scenario that was skipped due to a PR close cancellation. +func publishCancelledResult(app *appctx, scenarioFile string, in *doScenarioInput) { + if app == nil || app.rpub == nil { + log.Printf("publishCancelledResult: rpub not initialised, cannot publish cancelled result for %v (run_id=%v) — batch total may never complete", scenarioFile, in.RunID) + return + } + if in.ReportPubsub == "" { + log.Printf("publishCancelledResult: report-pubsub not set, cannot publish cancelled result for %v (run_id=%v) — batch total may never complete", scenarioFile, in.RunID) + return + } + + attr := make(map[string]string) + if pubsub != "" { + attr["pubsub"] = pubsub + } + if snssqs != "" { + attr["snssqs"] = snssqs + } + if in.Metadata != nil { + for _, key := range []string{ + "pr_number", "branch", "commit_sha", "actor", + "trigger_type", "run_url", "repository", "workflow", "total_scenarios", + } { + if v, ok := in.Metadata[key].(string); ok && v != "" { + attr[key] = v + } + } + if ta, ok := in.Metadata["test_analysis"].(map[string]interface{}); ok { + for _, key := range []string{"missing_tests_in_pr", "should_run_tests"} { + if v, ok := ta[key].(bool); ok { + attr[key] = fmt.Sprintf("%v", v) + } + } + } + if b, err := json.Marshal(in.Metadata); err == nil { + attr["metadata"] = string(b) + } + } + + r := ReportPubsub{ + Scenario: scenarioFile, + Attributes: attr, + Status: "cancelled", + Data: "skipped: PR was closed", + MessageID: uuid.NewString(), + RunID: in.RunID, + } + + if err := app.rpub.Publish(r.MessageID, r); err != nil { + log.Printf("publishCancelledResult: Publish failed for %v: %v", scenarioFile, err) + } else { + log.Printf("publishCancelledResult: scenario=%v run_id=%v", scenarioFile, in.RunID) + } +} \ No newline at end of file