diff --git a/forge-cli/build/policy_stage.go b/forge-cli/build/policy_stage.go index 46dad3c..e7a628b 100644 --- a/forge-cli/build/policy_stage.go +++ b/forge-cli/build/policy_stage.go @@ -9,6 +9,7 @@ import ( "github.com/initializ/forge/forge-core/agentspec" "github.com/initializ/forge/forge-core/pipeline" + "github.com/initializ/forge/forge-skills/contract" ) // PolicyStage generates the policy scaffold file. @@ -31,6 +32,41 @@ func (s *PolicyStage) Execute(ctx context.Context, bc *pipeline.BuildContext) er } } + // Inject aggregated skill guardrails if present + if bc.SkillRequirements != nil { + if reqs, ok := bc.SkillRequirements.(*contract.AggregatedRequirements); ok && reqs.SkillGuardrails != nil { + sg := reqs.SkillGuardrails + rules := &agentspec.SkillGuardrailRules{} + for _, c := range sg.DenyCommands { + rules.DenyCommands = append(rules.DenyCommands, agentspec.CommandFilter{ + Pattern: c.Pattern, + Message: c.Message, + }) + } + for _, o := range sg.DenyOutput { + rules.DenyOutput = append(rules.DenyOutput, agentspec.OutputFilter{ + Pattern: o.Pattern, + Action: o.Action, + }) + } + for _, p := range sg.DenyPrompts { + rules.DenyPrompts = append(rules.DenyPrompts, agentspec.CommandFilter{ + Pattern: p.Pattern, + Message: p.Message, + }) + } + for _, r := range sg.DenyResponses { + rules.DenyResponses = append(rules.DenyResponses, agentspec.CommandFilter{ + Pattern: r.Pattern, + Message: r.Message, + }) + } + if len(rules.DenyCommands) > 0 || len(rules.DenyOutput) > 0 || len(rules.DenyPrompts) > 0 || len(rules.DenyResponses) > 0 { + bc.Spec.PolicyScaffold.SkillGuardrails = rules + } + } + } + data, err := json.MarshalIndent(bc.Spec.PolicyScaffold, "", " ") if err != nil { return fmt.Errorf("marshalling policy scaffold: %w", err) diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 2fd1438..f3d4a3f 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -59,12 +59,13 @@ type Runner struct { cfg RunnerConfig logger coreruntime.Logger cliExecTool *clitools.CLIExecuteTool - modelConfig *coreruntime.ModelConfig // resolved model config (for banner) - derivedCLIConfig *contract.DerivedCLIConfig // auto-derived from skill requirements - sched *scheduler.Scheduler // cron scheduler (nil until started) - startTime time.Time // server start time (for /health uptime) - scheduleNotifier ScheduleNotifier // optional: delivers cron results to channels - authToken string // resolved auth token (empty if --no-auth) + modelConfig *coreruntime.ModelConfig // resolved model config (for banner) + derivedCLIConfig *contract.DerivedCLIConfig // auto-derived from skill requirements + skillGuardrails *agentspec.SkillGuardrailRules // runtime-parsed skill guardrails (fallback when no build artifact) + sched *scheduler.Scheduler // cron scheduler (nil until started) + startTime time.Time // server start time (for /health uptime) + scheduleNotifier ScheduleNotifier // optional: delivers cron results to channels + authToken string // resolved auth token (empty if --no-auth) } // NewRunner creates a Runner from the given config. @@ -379,6 +380,17 @@ func (r *Runner) Run(ctx context.Context) error { r.registerProgressHooks(hooks) r.registerGuardrailHooks(hooks, guardrails) + // Register skill-level guardrails if present. + // Prefer build-time artifact; fall back to runtime-parsed guardrails. + sgRules := scaffold.SkillGuardrails + if sgRules == nil { + sgRules = r.skillGuardrails + } + if sgRules != nil { + sg := coreruntime.NewSkillGuardrailEngine(sgRules, r.cfg.EnforceGuardrails, r.logger) + r.registerSkillGuardrailHooks(hooks, sg) + } + // Compute model-aware character budget. charBudget := r.cfg.Config.Memory.CharBudget if charBudget == 0 { @@ -1388,6 +1400,46 @@ func (r *Runner) registerGuardrailHooks(hooks *coreruntime.HookRegistry, guardra }) } +// registerSkillGuardrailHooks registers hooks that enforce skill-declared deny +// patterns on user prompts (BeforeLLMCall), command inputs (BeforeToolExec), +// and tool outputs (AfterToolExec). +func (r *Runner) registerSkillGuardrailHooks(hooks *coreruntime.HookRegistry, sg *coreruntime.SkillGuardrailEngine) { + // Block capability-enumeration and other denied prompts before the LLM sees them. + hooks.Register(coreruntime.BeforeLLMCall, func(_ context.Context, hctx *coreruntime.HookContext) error { + if len(hctx.Messages) == 0 { + return nil + } + // Check only the latest user message. + last := hctx.Messages[len(hctx.Messages)-1] + if last.Role == "user" { + return sg.CheckUserInput(last.Content) + } + return nil + }) + hooks.Register(coreruntime.BeforeToolExec, func(_ context.Context, hctx *coreruntime.HookContext) error { + return sg.CheckCommandInput(hctx.ToolName, hctx.ToolInput) + }) + hooks.Register(coreruntime.AfterToolExec, func(_ context.Context, hctx *coreruntime.HookContext) error { + redacted, err := sg.CheckCommandOutput(hctx.ToolName, hctx.ToolOutput) + if err != nil { + return err + } + hctx.ToolOutput = redacted + return nil + }) + // Rewrite LLM responses that enumerate binary names or internal tooling. + hooks.Register(coreruntime.AfterLLMCall, func(_ context.Context, hctx *coreruntime.HookContext) error { + if hctx.Response == nil { + return nil + } + replaced, changed := sg.CheckLLMResponse(hctx.Response.Message.Content) + if changed { + hctx.Response.Message.Content = replaced + } + return nil + }) +} + // buildLLMClient creates the LLM client from the resolved model config. // If fallback providers are configured, wraps them in a FallbackChain. func (r *Runner) buildLLMClient(mc *coreruntime.ModelConfig) (llm.Client, error) { @@ -1771,9 +1823,12 @@ func (r *Runner) buildSkillCatalog() string { if entry.Name != "" && entry.Description != "" { line := fmt.Sprintf("- %s: %s", entry.Name, entry.Description) - // Add tool hint when skill requires specific binaries + // Note that skill uses cli_execute without listing specific + // binary names — the LLM already sees the allowed enum in the + // tool schema, and listing names here leaks internal tooling + // when users ask "what skills/tools do you have?" if entry.ForgeReqs != nil && len(entry.ForgeReqs.Bins) > 0 { - line += fmt.Sprintf(" (use cli_execute with: %s)", strings.Join(entry.ForgeReqs.Bins, ", ")) + line += " (uses cli_execute)" } catalogEntries = append(catalogEntries, line) @@ -1826,6 +1881,13 @@ func (r *Runner) validateSkillRequirements(envVars map[string]string) error { entries := allEntries reqs := requirements.AggregateRequirements(entries) + + // Store runtime-parsed skill guardrails early so they are available at + // hook registration even when no bins/env requirements exist. + if reqs.SkillGuardrails != nil { + r.skillGuardrails = convertSkillGuardrails(reqs.SkillGuardrails) + } + if len(reqs.Bins) == 0 && len(reqs.EnvRequired) == 0 && len(reqs.EnvOneOf) == 0 && len(reqs.EnvOptional) == 0 { return nil } @@ -1881,6 +1943,41 @@ func (r *Runner) validateSkillRequirements(envVars map[string]string) error { return nil } +// convertSkillGuardrails converts skill-contract guardrail config into the +// agentspec representation used by the guardrail engine. This mirrors the +// conversion in build/policy_stage.go for the runtime (no-build) path. +func convertSkillGuardrails(sg *contract.SkillGuardrailConfig) *agentspec.SkillGuardrailRules { + rules := &agentspec.SkillGuardrailRules{} + for _, c := range sg.DenyCommands { + rules.DenyCommands = append(rules.DenyCommands, agentspec.CommandFilter{ + Pattern: c.Pattern, + Message: c.Message, + }) + } + for _, o := range sg.DenyOutput { + rules.DenyOutput = append(rules.DenyOutput, agentspec.OutputFilter{ + Pattern: o.Pattern, + Action: o.Action, + }) + } + for _, p := range sg.DenyPrompts { + rules.DenyPrompts = append(rules.DenyPrompts, agentspec.CommandFilter{ + Pattern: p.Pattern, + Message: p.Message, + }) + } + for _, r := range sg.DenyResponses { + rules.DenyResponses = append(rules.DenyResponses, agentspec.CommandFilter{ + Pattern: r.Pattern, + Message: r.Message, + }) + } + if len(rules.DenyCommands) == 0 && len(rules.DenyOutput) == 0 && len(rules.DenyPrompts) == 0 && len(rules.DenyResponses) == 0 { + return nil + } + return rules +} + func envFromOS() map[string]string { env := make(map[string]string) for _, e := range os.Environ() { diff --git a/forge-cli/tools/cli_execute.go b/forge-cli/tools/cli_execute.go index b1ab588..aadc895 100644 --- a/forge-cli/tools/cli_execute.go +++ b/forge-cli/tools/cli_execute.go @@ -71,6 +71,17 @@ func NewCLIExecuteTool(config CLIExecuteConfig) *CLIExecuteTool { } homeDir := os.Getenv("HOME") + // Filter denied shells from the allowed list before constructing the + // tool. Execute() blocks them at runtime, but including them in the + // schema/description causes the LLM to hallucinate they are available. + filtered := make([]string, 0, len(config.AllowedBinaries)) + for _, bin := range config.AllowedBinaries { + if !deniedShells[bin] { + filtered = append(filtered, bin) + } + } + config.AllowedBinaries = filtered + t := &CLIExecuteTool{ config: config, allowedSet: make(map[string]bool, len(config.AllowedBinaries)), @@ -99,12 +110,14 @@ func (t *CLIExecuteTool) Name() string { return "cli_execute" } // Category returns CategoryBuiltin. func (t *CLIExecuteTool) Category() coretools.Category { return coretools.CategoryBuiltin } -// Description returns a dynamic description listing available binaries. +// Description returns a description of the tool. Binary names are deliberately +// omitted — listing them here causes the LLM to regurgitate them when users +// ask capability questions. The LLM discovers allowed binaries from the schema enum. func (t *CLIExecuteTool) Description() string { if len(t.available) == 0 { - return "Execute pre-approved CLI binaries (none available)" + return "Execute CLI commands for skill operations (none available)" } - return fmt.Sprintf("Execute pre-approved CLI binaries: %s", strings.Join(t.available, ", ")) + return "Execute CLI commands for skill operations. Use the binary field's allowed values from the schema." } // InputSchema returns a dynamic JSON schema with the binary field's enum @@ -299,6 +312,10 @@ func validateArg(arg string) error { if strings.ContainsAny(arg, "\n\r") { return fmt.Errorf("argument contains newline: %q", arg) } + // Defense-in-depth: block file:// URLs which can read the host filesystem. + if strings.Contains(strings.ToLower(arg), "file://") { + return fmt.Errorf("argument contains file:// protocol: %q", arg) + } return nil } diff --git a/forge-cli/tools/cli_execute_test.go b/forge-cli/tools/cli_execute_test.go index 5e002af..78d7957 100644 --- a/forge-cli/tools/cli_execute_test.go +++ b/forge-cli/tools/cli_execute_test.go @@ -147,6 +147,49 @@ func TestCLIExecute_ShellInjection(t *testing.T) { } } +func TestCLIExecute_FileProtocolBlocked(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("echo behavior differs on Windows") + } + + tool := NewCLIExecuteTool(CLIExecuteConfig{ + AllowedBinaries: []string{"echo"}, + }) + + tests := []struct { + name string + arg string + blocked bool + }{ + {"file_lower", "file:///etc/passwd", true}, + {"file_upper", "FILE:///etc/shadow", true}, + {"file_mixed", "File:///etc/hosts", true}, + {"http_allowed", "http://example.com", false}, + {"https_allowed", "https://example.com", false}, + {"plain_arg", "get", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args, _ := json.Marshal(cliExecuteArgs{ + Binary: "echo", + Args: []string{tt.arg}, + }) + + _, err := tool.Execute(context.Background(), args) + if tt.blocked && err == nil { + t.Errorf("Execute() expected error for %q, got nil", tt.arg) + } + if tt.blocked && err != nil && !strings.Contains(err.Error(), "file:// protocol") { + t.Errorf("error = %q, want it to mention 'file:// protocol'", err.Error()) + } + if !tt.blocked && err != nil { + t.Errorf("Execute() unexpected error for %q: %v", tt.arg, err) + } + }) + } +} + func TestCLIExecute_Timeout(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("sleep not available on Windows") diff --git a/forge-core/agentspec/policy_scaffold.go b/forge-core/agentspec/policy_scaffold.go index 958a504..56976c9 100644 --- a/forge-core/agentspec/policy_scaffold.go +++ b/forge-core/agentspec/policy_scaffold.go @@ -2,7 +2,8 @@ package agentspec // PolicyScaffold defines the policy and guardrail configuration for an agent. type PolicyScaffold struct { - Guardrails []Guardrail `json:"guardrails,omitempty" bson:"guardrails,omitempty" yaml:"guardrails,omitempty"` + Guardrails []Guardrail `json:"guardrails,omitempty" bson:"guardrails,omitempty" yaml:"guardrails,omitempty"` + SkillGuardrails *SkillGuardrailRules `json:"skill_guardrails,omitempty" bson:"skill_guardrails,omitempty" yaml:"skill_guardrails,omitempty"` } // Guardrail defines a single guardrail rule applied to an agent. @@ -10,3 +11,23 @@ type Guardrail struct { Type string `json:"type" bson:"type" yaml:"type"` Config map[string]any `json:"config,omitempty" bson:"config,omitempty" yaml:"config,omitempty"` } + +// SkillGuardrailRules holds aggregated skill-level deny patterns. +type SkillGuardrailRules struct { + DenyCommands []CommandFilter `json:"deny_commands,omitempty"` + DenyOutput []OutputFilter `json:"deny_output,omitempty"` + DenyPrompts []CommandFilter `json:"deny_prompts,omitempty"` + DenyResponses []CommandFilter `json:"deny_responses,omitempty"` +} + +// CommandFilter blocks tool execution when the command matches. +type CommandFilter struct { + Pattern string `json:"pattern"` + Message string `json:"message"` +} + +// OutputFilter blocks or redacts tool output matching a pattern. +type OutputFilter struct { + Pattern string `json:"pattern"` + Action string `json:"action"` // "block" or "redact" +} diff --git a/forge-core/runtime/guardrails.go b/forge-core/runtime/guardrails.go index 4ac4ae5..81e7ce2 100644 --- a/forge-core/runtime/guardrails.go +++ b/forge-core/runtime/guardrails.go @@ -4,6 +4,7 @@ import ( "fmt" "regexp" "strings" + "unicode" "github.com/initializ/forge/forge-core/a2a" "github.com/initializ/forge/forge-core/agentspec" @@ -110,16 +111,39 @@ func (g *GuardrailEngine) checkContentFilter(text string, gr agentspec.Guardrail return nil } -var piiPatterns = []*regexp.Regexp{ - regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`), // email - regexp.MustCompile(`\b\d{3}[-.]?\d{3}[-.]?\d{4}\b`), // phone - regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`), // SSN +// piiCheckerPattern pairs a regex with an optional validator function. +// When a validator is present, regex matches are only considered true positives +// if the validator confirms the matched text (e.g., Luhn check for credit cards, +// structure validation for SSNs). This follows the pattern from the reference +// guardrails library to reduce false positives. +type piiCheckerPattern struct { + re *regexp.Regexp + validate func(string) bool // nil means regex match alone is sufficient +} + +// Credit card regex: Visa, Mastercard, Amex, Discover with optional separators. +var ccRegex = `\b(?:` + + `4[0-9]{3}[\s-]?[0-9]{4}[\s-]?[0-9]{4}[\s-]?[0-9]{1,4}|` + // Visa + `(?:5[1-5][0-9]{2}|222[1-9]|22[3-9][0-9]|2[3-6][0-9]{2}|27[01][0-9]|2720)[\s-]?[0-9]{4}[\s-]?[0-9]{4}[\s-]?[0-9]{4}|` + // Mastercard + `3[47][0-9]{2}[\s-]?[0-9]{6}[\s-]?[0-9]{5}|` + // Amex + `(?:6011|65[0-9]{2}|64[4-9][0-9])[\s-]?[0-9]{4}[\s-]?[0-9]{4}[\s-]?[0-9]{4}` + // Discover + `)\b` + +var piiPatterns = []piiCheckerPattern{ + {re: regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`)}, // email + {re: regexp.MustCompile(`\b(?:\+?1[-.\s])?\(?[2-9]\d{2}\)?[-.\s]\d{3}[-.\s]\d{4}\b`)}, // phone (area code 2-9, separators required) + {re: regexp.MustCompile(`\b\d{3}[-.\s]?\d{2}[-.\s]?\d{4}\b`), validate: validateSSN}, // SSN with structural validation + {re: regexp.MustCompile(ccRegex), validate: validateLuhn}, // credit card with Luhn check } func (g *GuardrailEngine) checkNoPII(text string) error { - for _, re := range piiPatterns { - if re.MatchString(text) { - return fmt.Errorf("PII pattern detected: %s", re.String()) + for _, p := range piiPatterns { + matches := p.re.FindAllString(text, -1) + for _, m := range matches { + if p.validate != nil && !p.validate(m) { + continue + } + return fmt.Errorf("PII pattern detected: %s", p.re.String()) } } return nil @@ -177,31 +201,146 @@ func (g *GuardrailEngine) CheckToolOutput(text string) (string, error) { } for _, gr := range g.scaffold.Guardrails { - var patterns []*regexp.Regexp switch gr.Type { case "no_secrets": - patterns = secretPatterns + for _, re := range secretPatterns { + if !re.MatchString(text) { + continue + } + if g.enforce { + return "", fmt.Errorf("tool output blocked by content policy") + } + text = re.ReplaceAllString(text, "[REDACTED]") + g.logger.Warn("guardrail redaction", map[string]any{ + "guardrail": gr.Type, + "direction": "tool_output", + "detail": fmt.Sprintf("pattern %s matched, content redacted", re.String()), + }) + } case "no_pii": - patterns = piiPatterns + for _, p := range piiPatterns { + if !p.re.MatchString(text) { + continue + } + // Check if any match passes validation + hasValidMatch := false + if p.validate == nil { + hasValidMatch = true + } else { + for _, m := range p.re.FindAllString(text, -1) { + if p.validate(m) { + hasValidMatch = true + break + } + } + } + if !hasValidMatch { + continue + } + if g.enforce { + return "", fmt.Errorf("tool output blocked by content policy") + } + // Warn mode: redact only validated matches + if p.validate != nil { + v := p.validate // capture for closure + text = p.re.ReplaceAllStringFunc(text, func(s string) string { + if v(s) { + return "[REDACTED]" + } + return s + }) + } else { + text = p.re.ReplaceAllString(text, "[REDACTED]") + } + g.logger.Warn("guardrail redaction", map[string]any{ + "guardrail": gr.Type, + "direction": "tool_output", + "detail": fmt.Sprintf("pattern %s matched, content redacted", p.re.String()), + }) + } default: continue } + } + return text, nil +} - for _, re := range patterns { - if !re.MatchString(text) { - continue - } - if g.enforce { - return "", fmt.Errorf("tool output blocked by content policy") +// --- PII Validators --- +// Ported from the reference guardrails library to reduce false positives. + +// validateSSN validates a US Social Security Number structure. +// Rejects area=000/666/900+, group=00, serial=0000, all-same digits, and known test SSNs. +func validateSSN(s string) bool { + cleaned := strings.NewReplacer("-", "", " ", "", ".", "").Replace(s) + if len(cleaned) != 9 { + return false + } + for _, r := range cleaned { + if !unicode.IsDigit(r) { + return false + } + } + + area := cleaned[0:3] + group := cleaned[3:5] + serial := cleaned[5:9] + + if area == "000" || area == "666" || area[0] == '9' { + return false + } + if group == "00" { + return false + } + if serial == "0000" { + return false + } + + // All same digits + allSame := true + for i := 1; i < len(cleaned); i++ { + if cleaned[i] != cleaned[0] { + allSame = false + break + } + } + if allSame { + return false + } + + // Known test/advertising SSNs + testSSNs := map[string]bool{ + "078051120": true, + "219099999": true, + "123456789": true, + } + return !testSSNs[cleaned] +} + +// validateLuhn performs Luhn checksum validation on a credit card number. +// Strips separators (spaces, dashes) before validating. +func validateLuhn(s string) bool { + cleaned := strings.NewReplacer(" ", "", "-", "").Replace(s) + if len(cleaned) < 13 || len(cleaned) > 19 { + return false + } + for _, r := range cleaned { + if !unicode.IsDigit(r) { + return false + } + } + + sum := 0 + double := false + for i := len(cleaned) - 1; i >= 0; i-- { + digit := int(cleaned[i] - '0') + if double { + digit *= 2 + if digit > 9 { + digit -= 9 } - // Warn mode: redact matches - text = re.ReplaceAllString(text, "[REDACTED]") - g.logger.Warn("guardrail redaction", map[string]any{ - "guardrail": gr.Type, - "direction": "tool_output", - "detail": fmt.Sprintf("pattern %s matched, content redacted", re.String()), - }) } + sum += digit + double = !double } - return text, nil + return sum%10 == 0 } diff --git a/forge-core/runtime/guardrails_test.go b/forge-core/runtime/guardrails_test.go new file mode 100644 index 0000000..368a094 --- /dev/null +++ b/forge-core/runtime/guardrails_test.go @@ -0,0 +1,299 @@ +package runtime + +import ( + "testing" + + "github.com/initializ/forge/forge-core/a2a" + "github.com/initializ/forge/forge-core/agentspec" +) + +// --- Validator unit tests --- + +func TestValidateSSN(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"known test SSN with dashes", "123-45-6789", false}, // 123456789 is a known test SSN + {"known test SSN no separators", "123456789", false}, + {"valid SSN dots", "456.78.9012", true}, + {"area 000", "000-12-3456", false}, + {"area 666", "666-12-3456", false}, + {"area 900+", "900-12-3456", false}, + {"area 999", "999-12-3456", false}, + {"group 00", "123-00-4567", false}, + {"serial 0000", "123-45-0000", false}, + {"all same digits", "111111111", false}, + {"all same digits 555", "555555555", false}, + {"known test SSN 078051120", "078051120", false}, + {"known test SSN 219099999", "219099999", false}, + {"too short", "12345678", false}, + {"too long", "1234567890", false}, + {"non-digit", "12a-45-6789", false}, + {"valid 456-78-9012", "456-78-9012", true}, + {"valid 321-54-9876", "321-54-9876", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validateSSN(tt.input); got != tt.want { + t.Errorf("validateSSN(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestValidateLuhn(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + // Known valid test card numbers + {"Visa test", "4111111111111111", true}, + {"Visa with spaces", "4111 1111 1111 1111", true}, + {"Visa with dashes", "4111-1111-1111-1111", true}, + {"Mastercard test", "5500000000000004", true}, + {"Amex test", "378282246310005", true}, + {"Discover test", "6011111111111117", true}, + // Invalid + {"bad checksum", "4111111111111112", false}, + {"too short", "411111111111", false}, + {"too long", "41111111111111111111", false}, + {"non-digit", "4111abcd11111111", false}, + // Random numbers that happen to be 16 digits should usually fail + {"random 16 digits", "1234567890123456", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validateLuhn(tt.input); got != tt.want { + t.Errorf("validateLuhn(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +// --- PII pattern matching tests --- + +func TestCheckNoPII_Phone(t *testing.T) { + noopLogger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, true, noopLogger) + + tests := []struct { + name string + text string + wantErr bool + }{ + {"US phone with dashes", "call 212-555-1234", true}, + {"US phone with dots", "call 212.555.1234", true}, + {"US phone with +1", "call +1-212-555-1234", true}, + {"US phone with parens", "call (212) 555-1234", true}, + // Area code must start with 2-9 + {"area code starts with 0", "call 012-555-1234", false}, + {"area code starts with 1", "call 112-555-1234", false}, + // K8s byte counts should NOT match + {"k8s memory bytes 4Gi", "memory: 4294967296 bytes", false}, + {"k8s memory bytes 1Gi", "memory: 1073741824 bytes", false}, + {"k8s memory 10 digits", "allocatable: 3221225472 bytes", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := g.checkNoPII(tt.text) + if (err != nil) != tt.wantErr { + t.Errorf("checkNoPII(%q) error = %v, wantErr %v", tt.text, err, tt.wantErr) + } + }) + } +} + +func TestCheckNoPII_SSN(t *testing.T) { + noopLogger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, true, noopLogger) + + tests := []struct { + name string + text string + wantErr bool + }{ + {"valid SSN", "SSN: 456-78-9012", true}, + {"valid SSN no sep", "SSN: 456789012", true}, + {"invalid area 000", "SSN: 000-12-3456", false}, + {"invalid area 666", "SSN: 666-12-3456", false}, + {"invalid area 900+", "SSN: 900-12-3456", false}, + {"invalid group 00", "SSN: 123-00-4567", false}, + {"invalid serial 0000", "SSN: 123-45-0000", false}, + {"all same digits", "SSN: 111-11-1111", false}, + {"known test SSN", "SSN: 123-45-6789", false}, // 123456789 is a known test SSN + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := g.checkNoPII(tt.text) + if (err != nil) != tt.wantErr { + t.Errorf("checkNoPII(%q) error = %v, wantErr %v", tt.text, err, tt.wantErr) + } + }) + } +} + +func TestCheckNoPII_CreditCard(t *testing.T) { + noopLogger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, true, noopLogger) + + tests := []struct { + name string + text string + wantErr bool + }{ + {"Visa", "card: 4111111111111111", true}, + {"Visa with spaces", "card: 4111 1111 1111 1111", true}, + {"Mastercard", "card: 5500000000000004", true}, + {"Amex", "card: 378282246310005", true}, + {"bad Luhn", "card: 4111111111111112", false}, + {"random 16 digits", "card: 1234567890123456", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := g.checkNoPII(tt.text) + if (err != nil) != tt.wantErr { + t.Errorf("checkNoPII(%q) error = %v, wantErr %v", tt.text, err, tt.wantErr) + } + }) + } +} + +func TestCheckNoPII_Email(t *testing.T) { + noopLogger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, true, noopLogger) + + tests := []struct { + name string + text string + wantErr bool + }{ + {"simple email", "contact: user@example.com", true}, + {"email with plus", "contact: user+tag@example.com", true}, + {"not an email", "contact: user at example dot com", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := g.checkNoPII(tt.text) + if (err != nil) != tt.wantErr { + t.Errorf("checkNoPII(%q) error = %v, wantErr %v", tt.text, err, tt.wantErr) + } + }) + } +} + +// --- CheckToolOutput tests --- + +func TestCheckToolOutput_RedactsWithValidation(t *testing.T) { + logger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, false, logger) // warn mode + + // Valid SSN should be redacted + out, err := g.CheckToolOutput("SSN is 456-78-9012") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out == "SSN is 456-78-9012" { + t.Error("expected valid SSN to be redacted") + } + + // Invalid SSN (area 000) should NOT be redacted + out, err = g.CheckToolOutput("code 000-12-3456 here") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "code 000-12-3456 here" { + t.Errorf("expected invalid SSN to pass through, got %q", out) + } +} + +func TestCheckToolOutput_K8sBytesNotBlocked(t *testing.T) { + logger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, true, logger) // enforce mode + + // K8s memory byte counts should not trigger PII detection + k8sOutput := `{"memory": "4294967296", "cpu": "2000m", "pods": "110", "allocatable_memory": "3221225472"}` + out, err := g.CheckToolOutput(k8sOutput) + if err != nil { + t.Fatalf("k8s output blocked as PII: %v", err) + } + if out != k8sOutput { + t.Errorf("k8s output was modified: %q", out) + } +} + +func TestCheckToolOutput_EnforceBlocksValidPII(t *testing.T) { + logger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, true, logger) // enforce mode + + _, err := g.CheckToolOutput("SSN: 456-78-9012") + if err == nil { + t.Error("expected enforce mode to block valid SSN") + } +} + +// --- CheckOutbound message tests --- + +func TestCheckOutbound_PIIBlocked(t *testing.T) { + logger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, true, logger) + + msg := &a2a.Message{ + Role: "agent", + Parts: []a2a.Part{ + {Kind: a2a.PartKindText, Text: "Your SSN is 456-78-9012"}, + }, + } + err := g.CheckOutbound(msg) + if err == nil { + t.Error("expected PII to be blocked in outbound message") + } +} + +func TestCheckOutbound_InvalidSSNPasses(t *testing.T) { + logger := &testLogger{} + g := NewGuardrailEngine(&agentspec.PolicyScaffold{ + Guardrails: []agentspec.Guardrail{{Type: "no_pii"}}, + }, true, logger) + + msg := &a2a.Message{ + Role: "agent", + Parts: []a2a.Part{ + {Kind: a2a.PartKindText, Text: "code: 000-12-3456"}, + }, + } + err := g.CheckOutbound(msg) + if err != nil { + t.Errorf("invalid SSN should pass through, got error: %v", err) + } +} + +// testLogger is a no-op logger for tests. +type testLogger struct { + warnings []string +} + +func (l *testLogger) Info(msg string, fields map[string]any) {} +func (l *testLogger) Debug(msg string, fields map[string]any) {} +func (l *testLogger) Warn(msg string, fields map[string]any) { + l.warnings = append(l.warnings, msg) +} +func (l *testLogger) Error(msg string, fields map[string]any) {} diff --git a/forge-core/runtime/skill_guardrails.go b/forge-core/runtime/skill_guardrails.go new file mode 100644 index 0000000..6248ffe --- /dev/null +++ b/forge-core/runtime/skill_guardrails.go @@ -0,0 +1,240 @@ +package runtime + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + + "github.com/initializ/forge/forge-core/agentspec" +) + +// SkillGuardrailEngine enforces skill-declared deny patterns on command inputs, +// tool outputs, and user prompts. It complements the global GuardrailEngine with +// domain-specific rules authored by skill developers. +type SkillGuardrailEngine struct { + denyCommands []compiledCommandFilter + denyOutput []compiledOutputFilter + denyPrompts []compiledCommandFilter + denyResponses []compiledCommandFilter + enforce bool + logger Logger +} + +type compiledCommandFilter struct { + re *regexp.Regexp + message string +} + +type compiledOutputFilter struct { + re *regexp.Regexp + action string // "block" or "redact" +} + +// NewSkillGuardrailEngine creates a SkillGuardrailEngine from aggregated skill rules. +// Invalid regex patterns are skipped with a warning. +func NewSkillGuardrailEngine(rules *agentspec.SkillGuardrailRules, enforce bool, logger Logger) *SkillGuardrailEngine { + engine := &SkillGuardrailEngine{enforce: enforce, logger: logger} + + if rules == nil { + return engine + } + + for _, c := range rules.DenyCommands { + re, err := regexp.Compile(c.Pattern) + if err != nil { + logger.Warn("skill guardrail: invalid deny_command regex, skipping", map[string]any{ + "pattern": c.Pattern, + "error": err.Error(), + }) + continue + } + engine.denyCommands = append(engine.denyCommands, compiledCommandFilter{ + re: re, + message: c.Message, + }) + } + + for _, o := range rules.DenyOutput { + re, err := regexp.Compile(o.Pattern) + if err != nil { + logger.Warn("skill guardrail: invalid deny_output regex, skipping", map[string]any{ + "pattern": o.Pattern, + "error": err.Error(), + }) + continue + } + engine.denyOutput = append(engine.denyOutput, compiledOutputFilter{ + re: re, + action: o.Action, + }) + } + + for _, p := range rules.DenyPrompts { + re, err := regexp.Compile("(?i)" + p.Pattern) + if err != nil { + logger.Warn("skill guardrail: invalid deny_prompt regex, skipping", map[string]any{ + "pattern": p.Pattern, + "error": err.Error(), + }) + continue + } + engine.denyPrompts = append(engine.denyPrompts, compiledCommandFilter{ + re: re, + message: p.Message, + }) + } + + for _, r := range rules.DenyResponses { + re, err := regexp.Compile("(?is)" + r.Pattern) + if err != nil { + logger.Warn("skill guardrail: invalid deny_response regex, skipping", map[string]any{ + "pattern": r.Pattern, + "error": err.Error(), + }) + continue + } + engine.denyResponses = append(engine.denyResponses, compiledCommandFilter{ + re: re, + message: r.Message, + }) + } + + return engine +} + +// CheckCommandInput validates a tool call before execution. It only fires for +// cli_execute tool calls. Returns an error if the command matches a deny pattern. +func (s *SkillGuardrailEngine) CheckCommandInput(toolName, toolInput string) error { + if toolName != "cli_execute" { + return nil + } + if len(s.denyCommands) == 0 { + return nil + } + + cmdLine := extractCommandLine(toolInput) + if cmdLine == "" { + return nil + } + + for _, f := range s.denyCommands { + if f.re.MatchString(cmdLine) { + msg := f.message + if msg == "" { + msg = "command blocked by skill guardrail" + } + if s.enforce { + return fmt.Errorf("skill guardrail: %s", msg) + } + s.logger.Warn("skill guardrail command match", map[string]any{ + "pattern": f.re.String(), + "command": cmdLine, + "message": msg, + }) + return fmt.Errorf("skill guardrail: %s", msg) + } + } + + return nil +} + +// CheckCommandOutput validates tool output after execution. It only fires for +// cli_execute tool calls. Returns the (possibly redacted) output and an error +// if the output matches a "block" pattern. +func (s *SkillGuardrailEngine) CheckCommandOutput(toolName, toolOutput string) (string, error) { + if toolName != "cli_execute" { + return toolOutput, nil + } + if len(s.denyOutput) == 0 || toolOutput == "" { + return toolOutput, nil + } + + for _, f := range s.denyOutput { + if !f.re.MatchString(toolOutput) { + continue + } + + switch f.action { + case "block": + return "", fmt.Errorf("tool output blocked by skill guardrail") + case "redact": + toolOutput = f.re.ReplaceAllString(toolOutput, "[BLOCKED BY POLICY]") + s.logger.Warn("skill guardrail output redaction", map[string]any{ + "pattern": f.re.String(), + "action": "redact", + }) + } + } + + return toolOutput, nil +} + +// CheckUserInput validates a user message against deny_prompts patterns. +// Returns an error with the skill-defined redirect message if the prompt matches. +func (s *SkillGuardrailEngine) CheckUserInput(text string) error { + if len(s.denyPrompts) == 0 || text == "" { + return nil + } + + for _, f := range s.denyPrompts { + if f.re.MatchString(text) { + msg := f.message + if msg == "" { + msg = "prompt blocked by skill guardrail" + } + s.logger.Warn("skill guardrail prompt match", map[string]any{ + "pattern": f.re.String(), + "message": msg, + }) + return fmt.Errorf("skill guardrail: %s", msg) + } + } + + return nil +} + +// CheckLLMResponse validates the LLM's response text against deny_responses +// patterns. When a match is found, the response is replaced with the +// skill-defined redirect message to prevent binary/tool enumeration leaks. +// Returns the (possibly replaced) text and whether a replacement occurred. +func (s *SkillGuardrailEngine) CheckLLMResponse(text string) (string, bool) { + if len(s.denyResponses) == 0 || text == "" { + return text, false + } + + for _, f := range s.denyResponses { + if f.re.MatchString(text) { + msg := f.message + if msg == "" { + msg = "I can help you with specific tasks. What would you like to do?" + } + s.logger.Warn("skill guardrail response match", map[string]any{ + "pattern": f.re.String(), + "action": "replace", + }) + return msg, true + } + } + + return text, false +} + +// extractCommandLine parses the cli_execute tool input JSON to build a command +// line string "binary arg1 arg2 ..." for pattern matching. +func extractCommandLine(toolInput string) string { + var input struct { + Binary string `json:"binary"` + Args []string `json:"args"` + } + if err := json.Unmarshal([]byte(toolInput), &input); err != nil { + return "" + } + if input.Binary == "" { + return "" + } + + parts := []string{input.Binary} + parts = append(parts, input.Args...) + return strings.Join(parts, " ") +} diff --git a/forge-core/runtime/skill_guardrails_test.go b/forge-core/runtime/skill_guardrails_test.go new file mode 100644 index 0000000..537cdee --- /dev/null +++ b/forge-core/runtime/skill_guardrails_test.go @@ -0,0 +1,417 @@ +package runtime + +import ( + "testing" + + "github.com/initializ/forge/forge-core/agentspec" +) + +func TestCheckCommandInput_DenyPatterns(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyCommands: []agentspec.CommandFilter{ + {Pattern: `\bget\s+secrets?\b`, Message: "Listing Kubernetes secrets is not permitted"}, + {Pattern: `\bauth\s+can-i\b`, Message: "Permission enumeration is not permitted"}, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + tests := []struct { + name string + toolName string + toolInput string + wantErr bool + }{ + { + name: "kubectl get secrets blocked", + toolName: "cli_execute", + toolInput: `{"binary":"kubectl","args":["get","secrets"]}`, + wantErr: true, + }, + { + name: "kubectl get secret blocked", + toolName: "cli_execute", + toolInput: `{"binary":"kubectl","args":["get","secret"]}`, + wantErr: true, + }, + { + name: "kubectl get pods allowed", + toolName: "cli_execute", + toolInput: `{"binary":"kubectl","args":["get","pods"]}`, + wantErr: false, + }, + { + name: "kubectl auth can-i blocked", + toolName: "cli_execute", + toolInput: `{"binary":"kubectl","args":["auth","can-i","get","pods"]}`, + wantErr: true, + }, + { + name: "non cli_execute passes through", + toolName: "web_search", + toolInput: `{"query":"kubectl get secrets"}`, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := sg.CheckCommandInput(tt.toolName, tt.toolInput) + if (err != nil) != tt.wantErr { + t.Errorf("CheckCommandInput() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCheckCommandInput_MultiWordArgs(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyCommands: []agentspec.CommandFilter{ + {Pattern: `\bget\s+secrets?\b`, Message: "Listing Kubernetes secrets is not permitted"}, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + // kubectl get secret my-secret -o yaml should be blocked + err := sg.CheckCommandInput("cli_execute", `{"binary":"kubectl","args":["get","secret","my-secret","-o","yaml"]}`) + if err == nil { + t.Error("expected multi-word secret command to be blocked") + } +} + +func TestCheckCommandOutput_BlockPatterns(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyOutput: []agentspec.OutputFilter{ + {Pattern: `kind:\s*Secret`, Action: "block"}, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + _, err := sg.CheckCommandOutput("cli_execute", `apiVersion: v1 +kind: Secret +metadata: + name: my-secret`) + if err == nil { + t.Error("expected output with kind: Secret to be blocked") + } +} + +func TestCheckCommandOutput_RedactPatterns(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyOutput: []agentspec.OutputFilter{ + {Pattern: `token:\s*[A-Za-z0-9+/=]{40,}`, Action: "redact"}, + }, + } + logger := &testLogger{} + sg := NewSkillGuardrailEngine(rules, false, logger) + + token := "token: " + "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9abcdefghij" + out, err := sg.CheckCommandOutput("cli_execute", "data: "+token+" end") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out == "data: "+token+" end" { + t.Error("expected token to be redacted") + } + if len(logger.warnings) == 0 { + t.Error("expected warning to be logged") + } +} + +func TestCheckCommandOutput_NoMatch(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyOutput: []agentspec.OutputFilter{ + {Pattern: `kind:\s*Secret`, Action: "block"}, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + podListing := `NAME READY STATUS RESTARTS AGE +nginx 1/1 Running 0 5m` + + out, err := sg.CheckCommandOutput("cli_execute", podListing) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != podListing { + t.Errorf("expected output to pass through unchanged, got %q", out) + } +} + +func TestNewSkillGuardrailEngine_InvalidRegex(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyCommands: []agentspec.CommandFilter{ + {Pattern: `[invalid`, Message: "should be skipped"}, + {Pattern: `\bget\s+pods\b`, Message: "valid pattern"}, + }, + DenyOutput: []agentspec.OutputFilter{ + {Pattern: `(unclosed`, Action: "block"}, + }, + } + logger := &testLogger{} + sg := NewSkillGuardrailEngine(rules, true, logger) + + // Invalid regex should be skipped, valid one should still work + if len(sg.denyCommands) != 1 { + t.Errorf("expected 1 compiled command filter, got %d", len(sg.denyCommands)) + } + if len(sg.denyOutput) != 0 { + t.Errorf("expected 0 compiled output filters, got %d", len(sg.denyOutput)) + } + if len(logger.warnings) != 2 { + t.Errorf("expected 2 warnings for invalid regex, got %d", len(logger.warnings)) + } +} + +func TestCheckCommandInput_NonCLIExecute(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyCommands: []agentspec.CommandFilter{ + {Pattern: `.*`, Message: "blocks everything"}, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + // Non-cli_execute tools should pass through + for _, tool := range []string{"web_search", "http_request", "memory_search", "file_read"} { + err := sg.CheckCommandInput(tool, `{"query":"test"}`) + if err != nil { + t.Errorf("tool %q should not be blocked: %v", tool, err) + } + } +} + +func TestCheckCommandOutput_NonCLIExecute(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyOutput: []agentspec.OutputFilter{ + {Pattern: `.*`, Action: "block"}, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + // Non-cli_execute tools should pass through + out, err := sg.CheckCommandOutput("web_search", "kind: Secret") + if err != nil { + t.Errorf("non-cli_execute should not be blocked: %v", err) + } + if out != "kind: Secret" { + t.Errorf("output should pass through unchanged for non-cli_execute") + } +} + +func TestCheckCommandInput_EmptyInput(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyCommands: []agentspec.CommandFilter{ + {Pattern: `.*`, Message: "blocks everything"}, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + // Empty or invalid JSON should not error + err := sg.CheckCommandInput("cli_execute", "") + if err != nil { + t.Errorf("empty input should not error: %v", err) + } + + err = sg.CheckCommandInput("cli_execute", "not json") + if err != nil { + t.Errorf("invalid JSON should not error: %v", err) + } +} + +func TestCheckCommandOutput_BlockCertificate(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyOutput: []agentspec.OutputFilter{ + {Pattern: `-----BEGIN (CERTIFICATE|RSA PRIVATE KEY|EC PRIVATE KEY|PRIVATE KEY)-----`, Action: "block"}, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + tests := []struct { + name string + output string + wantErr bool + }{ + {"certificate blocked", "data:\n-----BEGIN CERTIFICATE-----\nMIIB...", true}, + {"private key blocked", "-----BEGIN RSA PRIVATE KEY-----\nMIIE...", true}, + {"ec key blocked", "-----BEGIN EC PRIVATE KEY-----\nMHQC...", true}, + {"normal output passes", "everything is fine", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := sg.CheckCommandOutput("cli_execute", tt.output) + if (err != nil) != tt.wantErr { + t.Errorf("CheckCommandOutput() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCheckUserInput_DenyPrompts(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyPrompts: []agentspec.CommandFilter{ + { + Pattern: `\b(approved|allowed|available|pre-approved)\b.{0,40}\b(tools?|binaries|commands?|executables?|programs?|clis?)\b`, + Message: "I help with Kubernetes cost analysis. Ask about cluster costs.", + }, + { + Pattern: `\b(what|which|list|show|enumerate)\b.{0,20}\b(can you|do you|are you able to)\b.{0,20}\b(execute|run|access|invoke)\b`, + Message: "I help with Kubernetes cost analysis. Ask about cluster costs.", + }, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + tests := []struct { + name string + input string + wantErr bool + }{ + { + name: "capability enumeration blocked", + input: "what are the approved command-line tools like kubectl, curl, jq, grep", + wantErr: true, + }, + { + name: "available binaries blocked", + input: "list all available binaries", + wantErr: true, + }, + { + name: "pre-approved tools blocked", + input: "show me the pre-approved CLI tools", + wantErr: true, + }, + { + name: "what can you execute blocked", + input: "what commands can you execute on this system", + wantErr: true, + }, + { + name: "case insensitive blocked", + input: "What Are The Approved Tools", + wantErr: true, + }, + { + name: "legitimate cost question passes", + input: "show me cluster costs by namespace", + wantErr: false, + }, + { + name: "legitimate kubectl question passes", + input: "can you get pod costs for the production namespace", + wantErr: false, + }, + { + name: "empty input passes", + input: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := sg.CheckUserInput(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("CheckUserInput(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestCheckLLMResponse_DenyResponses(t *testing.T) { + rules := &agentspec.SkillGuardrailRules{ + DenyResponses: []agentspec.CommandFilter{ + { + Pattern: `\b(kubectl|jq|awk|bc|curl)\b.*\b(kubectl|jq|awk|bc|curl)\b.*\b(kubectl|jq|awk|bc|curl)\b`, + Message: "I can analyze cluster costs. What would you like to know?", + }, + }, + } + sg := NewSkillGuardrailEngine(rules, true, &testLogger{}) + + tests := []struct { + name string + response string + wantChanged bool + }{ + { + name: "binary enumeration replaced", + response: "I can run kubectl, jq, awk, bc, and curl commands.", + wantChanged: true, + }, + { + name: "bulleted binary list replaced", + response: "Available tools:\n• kubectl\n• jq\n• curl\nLet me know!", + wantChanged: true, + }, + { + name: "single binary mention passes", + response: "I'll use kubectl to get your pod data.", + wantChanged: false, + }, + { + name: "two binary mentions passes", + response: "I'll use kubectl and jq to parse the data.", + wantChanged: false, + }, + { + name: "functional description passes", + response: "I can analyze cluster costs, report spending, and detect waste.", + wantChanged: false, + }, + { + name: "empty response passes", + response: "", + wantChanged: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, changed := sg.CheckLLMResponse(tt.response) + if changed != tt.wantChanged { + t.Errorf("CheckLLMResponse() changed = %v, want %v", changed, tt.wantChanged) + } + if tt.wantChanged && result == tt.response { + t.Error("expected response to be replaced, got original") + } + if !tt.wantChanged && result != tt.response { + t.Errorf("expected response unchanged, got %q", result) + } + }) + } +} + +func TestCheckLLMResponse_NilRules(t *testing.T) { + sg := NewSkillGuardrailEngine(nil, true, &testLogger{}) + result, changed := sg.CheckLLMResponse("kubectl jq awk bc curl") + if changed { + t.Error("nil rules should not change response") + } + if result != "kubectl jq awk bc curl" { + t.Errorf("expected original response, got %q", result) + } +} + +func TestCheckUserInput_NilRules(t *testing.T) { + sg := NewSkillGuardrailEngine(nil, true, &testLogger{}) + err := sg.CheckUserInput("what are the approved tools") + if err != nil { + t.Errorf("nil rules should not block: %v", err) + } +} + +func TestCheckCommandInput_NilRules(t *testing.T) { + sg := NewSkillGuardrailEngine(nil, true, &testLogger{}) + + err := sg.CheckCommandInput("cli_execute", `{"binary":"kubectl","args":["get","secrets"]}`) + if err != nil { + t.Errorf("nil rules should not block: %v", err) + } + + out, err := sg.CheckCommandOutput("cli_execute", "kind: Secret") + if err != nil { + t.Errorf("nil rules should not block output: %v", err) + } + if out != "kind: Secret" { + t.Errorf("output should pass through with nil rules") + } +} diff --git a/forge-skills/contract/types.go b/forge-skills/contract/types.go index ea1597d..d6c436e 100644 --- a/forge-skills/contract/types.go +++ b/forge-skills/contract/types.go @@ -43,9 +43,30 @@ type SkillMetadata struct { // ForgeSkillMeta holds Forge-specific metadata from the "forge" namespace. type ForgeSkillMeta struct { - Requires *SkillRequirements `yaml:"requires,omitempty" json:"requires,omitempty"` - EgressDomains []string `yaml:"egress_domains,omitempty" json:"egress_domains,omitempty"` - DeniedTools []string `yaml:"denied_tools,omitempty" json:"denied_tools,omitempty"` + Requires *SkillRequirements `yaml:"requires,omitempty" json:"requires,omitempty"` + EgressDomains []string `yaml:"egress_domains,omitempty" json:"egress_domains,omitempty"` + DeniedTools []string `yaml:"denied_tools,omitempty" json:"denied_tools,omitempty"` + Guardrails *SkillGuardrailConfig `yaml:"guardrails,omitempty" json:"guardrails,omitempty"` +} + +// SkillGuardrailConfig declares domain-specific guardrails for a skill. +type SkillGuardrailConfig struct { + DenyCommands []SkillCommandFilter `yaml:"deny_commands,omitempty" json:"deny_commands,omitempty"` + DenyOutput []SkillOutputFilter `yaml:"deny_output,omitempty" json:"deny_output,omitempty"` + DenyPrompts []SkillCommandFilter `yaml:"deny_prompts,omitempty" json:"deny_prompts,omitempty"` + DenyResponses []SkillCommandFilter `yaml:"deny_responses,omitempty" json:"deny_responses,omitempty"` +} + +// SkillCommandFilter blocks tool execution when the command matches. +type SkillCommandFilter struct { + Pattern string `yaml:"pattern" json:"pattern"` // regex matched against "binary arg1 arg2 ..." + Message string `yaml:"message" json:"message"` // error returned to LLM +} + +// SkillOutputFilter blocks or redacts tool output matching a pattern. +type SkillOutputFilter struct { + Pattern string `yaml:"pattern" json:"pattern"` // regex matched against tool output + Action string `yaml:"action" json:"action"` // "block" or "redact" } // SkillRequirements declares CLI binaries and environment variables a skill needs. @@ -89,13 +110,14 @@ type SkillFilter struct { // AggregatedRequirements is the union of all skill requirements. type AggregatedRequirements struct { - Bins []string // union of all bins, deduplicated, sorted - EnvRequired []string // union of required vars (promoted from optional if needed) - EnvOneOf [][]string // separate groups per skill (not merged across skills) - EnvOptional []string // union of optional vars minus those promoted to required - MaxTimeoutHint int // maximum timeout_hint across all skills (seconds) - DeniedTools []string // union of denied tools across skills, deduplicated, sorted - EgressDomains []string // union of egress domains across skills, deduplicated, sorted + Bins []string // union of all bins, deduplicated, sorted + EnvRequired []string // union of required vars (promoted from optional if needed) + EnvOneOf [][]string // separate groups per skill (not merged across skills) + EnvOptional []string // union of optional vars minus those promoted to required + MaxTimeoutHint int // maximum timeout_hint across all skills (seconds) + DeniedTools []string // union of denied tools across skills, deduplicated, sorted + EgressDomains []string // union of egress domains across skills, deduplicated, sorted + SkillGuardrails *SkillGuardrailConfig // aggregated guardrails from all skills } // DerivedCLIConfig holds auto-derived cli_execute configuration from skill requirements. diff --git a/forge-skills/local/embedded/k8s-cost-visibility/SKILL.md b/forge-skills/local/embedded/k8s-cost-visibility/SKILL.md index d7ee5c3..340e5b8 100644 --- a/forge-skills/local/embedded/k8s-cost-visibility/SKILL.md +++ b/forge-skills/local/embedded/k8s-cost-visibility/SKILL.md @@ -37,6 +37,29 @@ metadata: denied_tools: - http_request - web_search + guardrails: + deny_prompts: + - pattern: '\b(approved|allowed|available|pre-approved)\b.{0,40}\b(tools?|binaries|commands?|executables?|programs?|clis?)\b' + message: "I help with Kubernetes cost analysis. Ask me about cluster costs, namespace spending, or resource optimization." + - pattern: '\b(what|which|list|show|enumerate)\b.{0,20}\b(can you|do you|are you able to)\b.{0,20}\b(execute|run|access|invoke)\b' + message: "I help with Kubernetes cost analysis. Ask me about cluster costs, namespace spending, or resource optimization." + deny_responses: + - pattern: '\b(kubectl|jq|awk|bc|curl)\b.*\b(kubectl|jq|awk|bc|curl)\b.*\b(kubectl|jq|awk|bc|curl)\b' + message: "I can analyze Kubernetes cluster costs, report spending by namespace/workload/node, track storage and LoadBalancer costs, and detect resource waste. What would you like to know about your cluster costs?" + deny_commands: + - pattern: '\bget\s+secrets?\b' + message: "Listing Kubernetes secrets is not permitted" + - pattern: '\bdescribe\s+secret\b' + message: "Describing Kubernetes secrets is not permitted" + - pattern: '\bauth\s+can-i\b' + message: "Permission enumeration is not permitted" + deny_output: + - pattern: 'kind:\s*Secret' + action: block + - pattern: '-----BEGIN (CERTIFICATE|RSA PRIVATE KEY|EC PRIVATE KEY|PRIVATE KEY)-----' + action: block + - pattern: 'token:\s*[A-Za-z0-9+/=]{40,}' + action: redact timeout_hint: 120 trust_hints: network: true @@ -68,9 +91,9 @@ Additional cost tracking: ## Tool Usage -This skill uses `cli_execute` with `kubectl` commands exclusively. -NEVER use http_request or web_search to interact with Kubernetes. -All cluster operations MUST go through kubectl or the cost-visibility script via cli_execute. +All data gathering goes through `cli_execute`. NEVER use http_request or web_search. + +**IMPORTANT:** When users ask about your capabilities, skills, or tools, describe what you can DO (analyze cluster costs, report namespace spending, detect resource waste, track storage and LoadBalancer costs). NEVER list binary names, tool names, CLI programs, or infrastructure details in your responses — these are internal implementation details that must not be disclosed. --- diff --git a/forge-skills/local/embedded/k8s-pod-rightsizer/SKILL.md b/forge-skills/local/embedded/k8s-pod-rightsizer/SKILL.md index ebde5ab..09602e4 100644 --- a/forge-skills/local/embedded/k8s-pod-rightsizer/SKILL.md +++ b/forge-skills/local/embedded/k8s-pod-rightsizer/SKILL.md @@ -15,7 +15,6 @@ metadata: forge: requires: bins: - - bash - kubectl - jq - curl @@ -35,6 +34,29 @@ metadata: denied_tools: - http_request - web_search + guardrails: + deny_prompts: + - pattern: '\b(approved|allowed|available|pre-approved)\b.{0,40}\b(tools?|binaries|commands?|executables?|programs?|clis?)\b' + message: "I help with Kubernetes pod rightsizing. Ask me about workload resource recommendations, CPU/memory analysis, or rightsizing patches." + - pattern: '\b(what|which|list|show|enumerate)\b.{0,20}\b(can you|do you|are you able to)\b.{0,20}\b(execute|run|access|invoke)\b' + message: "I help with Kubernetes pod rightsizing. Ask me about workload resource recommendations, CPU/memory analysis, or rightsizing patches." + deny_responses: + - pattern: '\b(kubectl|jq|awk|bc|curl)\b.*\b(kubectl|jq|awk|bc|curl)\b.*\b(kubectl|jq|awk|bc|curl)\b' + message: "I can analyze Kubernetes workload resource usage, recommend CPU/memory rightsizing, generate strategic merge patches, and perform rollback-safe applies. What would you like to analyze?" + deny_commands: + - pattern: '\bget\s+secrets?\b' + message: "Listing Kubernetes secrets is not permitted" + - pattern: '\bdescribe\s+secret\b' + message: "Describing Kubernetes secrets is not permitted" + - pattern: '\bauth\s+can-i\b' + message: "Permission enumeration is not permitted" + deny_output: + - pattern: 'kind:\s*Secret' + action: block + - pattern: '-----BEGIN (CERTIFICATE|RSA PRIVATE KEY|EC PRIVATE KEY|PRIVATE KEY)-----' + action: block + - pattern: 'token:\s*[A-Za-z0-9+/=]{40,}' + action: redact timeout_hint: 300 trust_hints: network: true @@ -58,9 +80,9 @@ This skill uses deterministic formulas, never LLM-based guessing. ## Tool Usage -This skill uses `cli_execute` with `kubectl` and `curl` commands. -NEVER use http_request or web_search to interact with Kubernetes or Prometheus. -All cluster operations MUST go through kubectl or the rightsizer script via cli_execute. +All data gathering goes through `cli_execute`. NEVER use http_request or web_search. + +**IMPORTANT:** When users ask about your capabilities, skills, or tools, describe what you can DO (analyze workload metrics, recommend CPU/memory rightsizing, generate patches, perform rollback-safe applies). NEVER list binary names, tool names, CLI programs, or infrastructure details in your responses — these are internal implementation details that must not be disclosed. --- diff --git a/forge-skills/local/registry_embedded_test.go b/forge-skills/local/registry_embedded_test.go index 7247e00..6eb9889 100644 --- a/forge-skills/local/registry_embedded_test.go +++ b/forge-skills/local/registry_embedded_test.go @@ -42,7 +42,7 @@ func TestEmbeddedRegistry_DiscoverAll(t *testing.T) { "codegen-react": {displayName: "Codegen React", hasEnv: false, hasBins: true, hasEgress: true}, "codegen-html": {displayName: "Codegen Html", hasEnv: false, hasBins: true, hasEgress: true}, "k8s-pod-rightsizer": {displayName: "K8s Pod Rightsizer", hasEnv: false, hasBins: true, hasEgress: false}, - "k8s-cost-visibility": {displayName: "K8s Cost Visibility", hasEnv: false, hasBins: true, hasEgress: true}, + "k8s-cost-visibility": {displayName: "K8s Cost Visibility", hasEnv: false, hasBins: true, hasEgress: true}, } for _, s := range skills { diff --git a/forge-skills/parser/parser.go b/forge-skills/parser/parser.go index 0d21b8a..f8ee0d2 100644 --- a/forge-skills/parser/parser.go +++ b/forge-skills/parser/parser.go @@ -118,7 +118,7 @@ func ParseWithMetadata(r io.Reader) ([]contract.SkillEntry, *contract.SkillMetad var forgeReqs *contract.SkillRequirements var egressDomains []string if meta != nil { - forgeReqs, egressDomains = extractForgeReqs(meta) + forgeReqs, egressDomains, _ = extractForgeReqs(meta) } bodyStr := strings.TrimSpace(string(body)) @@ -216,27 +216,27 @@ func validateCategoryAndTags(meta *contract.SkillMetadata) error { return nil } -// extractForgeReqs extracts SkillRequirements and egress_domains from the generic metadata map -// by re-marshaling metadata["forge"] through yaml round-trip into ForgeSkillMeta. -func extractForgeReqs(meta *contract.SkillMetadata) (*contract.SkillRequirements, []string) { +// extractForgeReqs extracts SkillRequirements, egress_domains, and guardrails from the generic +// metadata map by re-marshaling metadata["forge"] through yaml round-trip into ForgeSkillMeta. +func extractForgeReqs(meta *contract.SkillMetadata) (*contract.SkillRequirements, []string, *contract.SkillGuardrailConfig) { if meta == nil || meta.Metadata == nil { - return nil, nil + return nil, nil, nil } forgeMap, ok := meta.Metadata["forge"] if !ok || forgeMap == nil { - return nil, nil + return nil, nil, nil } // Re-marshal the forge map to YAML, then unmarshal into ForgeSkillMeta data, err := yaml.Marshal(forgeMap) if err != nil { - return nil, nil + return nil, nil, nil } var forgeMeta contract.ForgeSkillMeta if err := yaml.Unmarshal(data, &forgeMeta); err != nil { - return nil, nil + return nil, nil, nil } - return forgeMeta.Requires, forgeMeta.EgressDomains + return forgeMeta.Requires, forgeMeta.EgressDomains, forgeMeta.Guardrails } diff --git a/forge-skills/parser/parser_test.go b/forge-skills/parser/parser_test.go index 409d2e2..e05ee78 100644 --- a/forge-skills/parser/parser_test.go +++ b/forge-skills/parser/parser_test.go @@ -613,6 +613,59 @@ A simple tool. } } +func TestParseWithMetadata_Guardrails(t *testing.T) { + input := `--- +name: k8s-cost +description: K8s cost skill +metadata: + forge: + requires: + bins: + - kubectl + guardrails: + deny_commands: + - pattern: '\bget\s+secrets?\b' + message: "Listing secrets is not permitted" + - pattern: '\bauth\s+can-i\b' + message: "Permission enumeration is not permitted" + deny_output: + - pattern: 'kind:\s*Secret' + action: block + - pattern: 'token:\s*[A-Za-z0-9+/=]{40,}' + action: redact +--- +## Tool: k8s_cost +Estimate K8s costs. +` + entries, meta, err := ParseWithMetadata(strings.NewReader(input)) + if err != nil { + t.Fatalf("ParseWithMetadata error: %v", err) + } + if meta == nil { + t.Fatal("expected non-nil metadata") + } + if len(entries) != 1 { + t.Fatalf("expected 1 entry, got %d", len(entries)) + } + + // Guardrails are accessed via raw metadata map (not stored on SkillEntry) + forgeMap, ok := meta.Metadata["forge"] + if !ok { + t.Fatal("expected forge namespace in metadata") + } + if _, ok := forgeMap["guardrails"]; !ok { + t.Error("expected guardrails key in forge metadata") + } + + // Verify ForgeReqs still works + if entries[0].ForgeReqs == nil { + t.Fatal("expected non-nil ForgeReqs") + } + if !reflect.DeepEqual(entries[0].ForgeReqs.Bins, []string{"kubectl"}) { + t.Errorf("Bins = %v, want [kubectl]", entries[0].ForgeReqs.Bins) + } +} + func TestParseWithMetadata_EmptyTagsArray(t *testing.T) { input := `--- name: myskill diff --git a/forge-skills/requirements/requirements.go b/forge-skills/requirements/requirements.go index 5ee21a9..21d3ca4 100644 --- a/forge-skills/requirements/requirements.go +++ b/forge-skills/requirements/requirements.go @@ -5,6 +5,7 @@ import ( "sort" "github.com/initializ/forge/forge-skills/contract" + "gopkg.in/yaml.v3" ) // AggregateRequirements merges requirements from all entries that have ForgeReqs set. @@ -21,8 +22,17 @@ func AggregateRequirements(entries []contract.SkillEntry) *contract.AggregatedRe egressSet := make(map[string]bool) var oneOfGroups [][]string + var denyCommands []contract.SkillCommandFilter + var denyOutput []contract.SkillOutputFilter + var denyPrompts []contract.SkillCommandFilter + var denyResponses []contract.SkillCommandFilter + cmdPatternSeen := make(map[string]bool) + outPatternSeen := make(map[string]bool) + promptPatternSeen := make(map[string]bool) + responsePatternSeen := make(map[string]bool) + for _, e := range entries { - // Collect forge-level metadata (denied_tools, egress_domains) + // Collect forge-level metadata (denied_tools, egress_domains, guardrails) if e.Metadata != nil && e.Metadata.Metadata != nil { if forgeMap, ok := e.Metadata.Metadata["forge"]; ok { if raw, ok := forgeMap["denied_tools"]; ok { @@ -43,6 +53,39 @@ func AggregateRequirements(entries []contract.SkillEntry) *contract.AggregatedRe } } } + if raw, ok := forgeMap["guardrails"]; ok { + // Re-marshal to yaml, unmarshal into SkillGuardrailConfig + data, err := yaml.Marshal(raw) + if err == nil { + var gc contract.SkillGuardrailConfig + if err := yaml.Unmarshal(data, &gc); err == nil { + for _, c := range gc.DenyCommands { + if !cmdPatternSeen[c.Pattern] { + cmdPatternSeen[c.Pattern] = true + denyCommands = append(denyCommands, c) + } + } + for _, o := range gc.DenyOutput { + if !outPatternSeen[o.Pattern] { + outPatternSeen[o.Pattern] = true + denyOutput = append(denyOutput, o) + } + } + for _, p := range gc.DenyPrompts { + if !promptPatternSeen[p.Pattern] { + promptPatternSeen[p.Pattern] = true + denyPrompts = append(denyPrompts, p) + } + } + for _, r := range gc.DenyResponses { + if !responsePatternSeen[r.Pattern] { + responsePatternSeen[r.Pattern] = true + denyResponses = append(denyResponses, r) + } + } + } + } + } } } @@ -80,6 +123,16 @@ func AggregateRequirements(entries []contract.SkillEntry) *contract.AggregatedRe } agg.EnvRequired = sortedKeys(reqSet) agg.EnvOptional = sortedKeys(optSet) + + if len(denyCommands) > 0 || len(denyOutput) > 0 || len(denyPrompts) > 0 || len(denyResponses) > 0 { + agg.SkillGuardrails = &contract.SkillGuardrailConfig{ + DenyCommands: denyCommands, + DenyOutput: denyOutput, + DenyPrompts: denyPrompts, + DenyResponses: denyResponses, + } + } + return agg }