Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions apps/workspace-engine/pkg/db/plan_validation.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions apps/workspace-engine/pkg/db/queries/plan_validation.sql
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ JOIN release rel ON rel.id = t.current_release_id
JOIN deployment_version dv ON dv.id = rel.version_id
WHERE t.id = $1;

-- name: ListPlanValidationResultsByTargetID :many
SELECT
v.result_id,
v.rule_id,
v.violations,
r.name AS rule_name
FROM deployment_plan_target_result_validation v
JOIN deployment_plan_target_result res ON res.id = v.result_id
JOIN policy_rule_plan_validation_opa r ON r.id = v.rule_id
WHERE res.target_id = $1
AND v.passed = false
ORDER BY v.evaluated_at DESC;

-- name: UpsertPlanValidationResult :exec
INSERT INTO deployment_plan_target_result_validation (
result_id, rule_id, passed, violations, evaluated_at
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ func (m *mockGetter) ListDeploymentPlanTargetResultsByTargetID(
return nil, nil
}

func (m *mockGetter) ListPlanValidationResultsByTargetID(
_ context.Context,
_ uuid.UUID,
) ([]db.ListPlanValidationResultsByTargetIDRow, error) {
return nil, nil
}

func (m *mockGetter) GetMatchingPlanValidationOpaRules(
_ context.Context,
_ uuid.UUID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ type Getter interface {
targetID uuid.UUID,
) ([]db.ListDeploymentPlanTargetResultsByTargetIDRow, error)

ListPlanValidationResultsByTargetID(
ctx context.Context,
targetID uuid.UUID,
) ([]db.ListPlanValidationResultsByTargetIDRow, error)

GetMatchingPlanValidationOpaRules(
ctx context.Context,
workspaceID uuid.UUID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ func (g *PostgresGetter) ListDeploymentPlanTargetResultsByTargetID(
return db.GetQueries(ctx).ListDeploymentPlanTargetResultsByTargetID(ctx, targetID)
}

func (g *PostgresGetter) ListPlanValidationResultsByTargetID(
ctx context.Context,
targetID uuid.UUID,
) ([]db.ListPlanValidationResultsByTargetIDRow, error) {
return db.GetQueries(ctx).ListPlanValidationResultsByTargetID(ctx, targetID)
}

func (g *PostgresGetter) GetCurrentVersionForPlanTarget(
ctx context.Context,
planTargetID uuid.UUID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,21 @@ func (t targetContext) hasGitHubMetadata() bool {
// agentResult is a denormalized, template-friendly view of a result row
// with its dispatch context parsed for the agent's name/type.
type agentResult struct {
ResultID uuid.UUID
AgentName string
AgentType string
Status db.DeploymentPlanTargetStatus
HasChanges *bool
Current string
Proposed string
Message string
Violations []ruleViolation
}

// ruleViolation is a failed plan-validation rule attached to a result.
type ruleViolation struct {
RuleName string
Messages []string
}

// agentResultFromRow builds an agentResult from a DB row. If the row's
Expand Down Expand Up @@ -131,6 +139,7 @@ func agentResultFromRow(
}

return agentResult{
ResultID: row.ID,
AgentName: agentName,
AgentType: agentType,
Status: row.Status,
Expand All @@ -144,14 +153,15 @@ func agentResultFromRow(
// aggregate describes the overall state of all agents for one target.
// Used to pick the check run's status, conclusion, and title.
type aggregate struct {
Total int
Completed int
Errored int
Unsupported int
Changed int
Unchanged int
Additions int
Deletions int
Total int
Completed int
Errored int
Unsupported int
Changed int
Unchanged int
Additions int
Deletions int
ValidationFailures int
}

func countDiffLines(current, proposed string) (int, int) {
Expand Down Expand Up @@ -195,6 +205,9 @@ func aggregateResults(results []agentResult) aggregate {
if r.HasChanges != nil && !*r.HasChanges {
a.Unchanged++
}
if len(r.Violations) > 0 {
a.ValidationFailures++
}
}
return a
}
Expand All @@ -206,11 +219,11 @@ func (a aggregate) allDone() bool {
}

// shouldFinalize reports whether the check run should be set to
// "completed" status. We finalize as soon as any agent errors so the
// failure is surfaced on the PR immediately, or when all agents have
// reached a terminal state.
// "completed" status. We finalize as soon as any agent errors or any
// plan-validation rule fails so the failure is surfaced on the PR
// immediately, or when all agents have reached a terminal state.
func (a aggregate) shouldFinalize() bool {
return a.Errored > 0 || a.allDone()
return a.Errored > 0 || a.ValidationFailures > 0 || a.allDone()
}

// checkStatus returns the GitHub "status" field for the check run.
Expand All @@ -224,7 +237,7 @@ func (a aggregate) checkStatus() string {
// checkConclusion returns the GitHub "conclusion" field. Only
// meaningful when shouldFinalize() is true.
func (a aggregate) checkConclusion() string {
if a.Errored > 0 {
if a.Errored > 0 || a.ValidationFailures > 0 {
return "failure"
}
if a.Total > 0 && a.Unsupported == a.Total {
Expand Down Expand Up @@ -255,10 +268,25 @@ func (a aggregate) checkTitle() string {
}

diffSummary := fmt.Sprintf("+%d -%d", a.Additions, a.Deletions)
suffix := ""
if a.Errored > 0 {
return fmt.Sprintf("%s (%d errored)", diffSummary, a.Errored)
suffix += fmt.Sprintf(" (%d errored)", a.Errored)
}
if a.ValidationFailures > 0 {
suffix += fmt.Sprintf(
" (%d policy violation%s)",
a.ValidationFailures,
plural(a.ValidationFailures),
)
}
return diffSummary
return diffSummary + suffix
}

func plural(n int) string {
if n == 1 {
return ""
}
return "s"
}

// formatAgentSection renders the markdown block for one agent in the
Expand Down Expand Up @@ -298,9 +326,23 @@ func formatAgentSection(r agentResult) string {
sb.WriteString("\n```diff\n")
sb.WriteString(diff)
sb.WriteString("```\n")
writeViolations(&sb, r.Violations)
return sb.String()
}

func writeViolations(sb *strings.Builder, violations []ruleViolation) {
if len(violations) == 0 {
return
}
sb.WriteString("\n**Policy violations:**\n")
for _, v := range violations {
fmt.Fprintf(sb, "\n- `%s`\n", v.RuleName)
for _, msg := range v.Messages {
fmt.Fprintf(sb, " - %s\n", msg)
}
}
}

// truncateText trims s to fit within maxBytes (accounting for a trailing
// truncation sentinel). It rolls back to the last valid UTF-8 rune
// boundary so multi-byte characters are never cut in half.
Expand Down Expand Up @@ -538,6 +580,11 @@ func loadTargetContext(
return targetContext{}, nil, fmt.Errorf("list target results: %w", err)
}

violationsByResult, err := loadViolationsByResult(ctx, getter, tc.TargetID)
if err != nil {
return targetContext{}, nil, err
}

span := trace.SpanFromContext(ctx)
results := make([]agentResult, len(rows))
for i, r := range rows {
Expand All @@ -548,7 +595,51 @@ func loadTargetContext(
r.ID, parseErr,
))
}
result.Violations = violationsByResult[result.ResultID]
results[i] = result
}
return tc, results, nil
}

func loadViolationsByResult(
ctx context.Context,
getter Getter,
targetID uuid.UUID,
) (map[uuid.UUID][]ruleViolation, error) {
rows, err := getter.ListPlanValidationResultsByTargetID(ctx, targetID)
if err != nil {
return nil, fmt.Errorf("list plan validation results: %w", err)
}

span := trace.SpanFromContext(ctx)
out := make(map[uuid.UUID][]ruleViolation, len(rows))
for _, r := range rows {
messages, parseErr := parseViolationMessages(r.Violations)
if parseErr != nil {
span.RecordError(fmt.Errorf(
"parse violations for rule %s: %w", r.RuleID, parseErr,
))
continue
}
out[r.ResultID] = append(out[r.ResultID], ruleViolation{
RuleName: r.RuleName,
Messages: messages,
})
}
return out, nil
}

func parseViolationMessages(raw []byte) ([]string, error) {
if len(raw) == 0 {
return nil, nil
}
var parsed []oapi.PlanValidationViolation
if err := json.Unmarshal(raw, &parsed); err != nil {
return nil, err
}
messages := make([]string, len(parsed))
for i, v := range parsed {
messages[i] = v.Message
}
return messages, nil
}
Loading
Loading