-
Notifications
You must be signed in to change notification settings - Fork 6
feat(tools): standardize executor approval and strict arg contract #191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| package tools | ||
|
|
||
| import ( | ||
| "context" | ||
| "errors" | ||
| "testing" | ||
| ) | ||
|
|
||
| func TestNormalizeToolErrorContractTimeout(t *testing.T) { | ||
| cause := context.DeadlineExceeded | ||
| err := normalizeToolError(cause) | ||
| execErr, ok := AsToolExecError(err) | ||
| if !ok { | ||
| t.Fatalf("expected ToolExecError, got %T", err) | ||
| } | ||
| if execErr.Code != ToolErrorTimeout { | ||
| t.Fatalf("unexpected code: %s", execErr.Code) | ||
| } | ||
| if !execErr.Retryable { | ||
| t.Fatal("timeout should be retryable") | ||
| } | ||
| if !errors.Is(execErr, cause) { | ||
| t.Fatal("timeout cause must be preserved") | ||
| } | ||
| } | ||
|
|
||
| func TestNormalizeToolErrorContractPermissionDenied(t *testing.T) { | ||
| cause := errors.New("approval denied") | ||
| err := normalizeToolError(cause) | ||
| execErr, ok := AsToolExecError(err) | ||
| if !ok { | ||
| t.Fatalf("expected ToolExecError, got %T", err) | ||
| } | ||
| if execErr.Code != ToolErrorPermissionDenied { | ||
| t.Fatalf("unexpected code: %s", execErr.Code) | ||
| } | ||
| if execErr.Retryable { | ||
| t.Fatal("permission denied should not be retryable") | ||
| } | ||
| if !errors.Is(execErr, cause) { | ||
| t.Fatal("permission-denied cause must be preserved") | ||
| } | ||
| } | ||
|
|
||
| func TestNormalizeToolErrorContractInvalidArgs(t *testing.T) { | ||
| cause := errors.New("unknown argument \"extra\"") | ||
| err := normalizeToolError(cause) | ||
| execErr, ok := AsToolExecError(err) | ||
| if !ok { | ||
| t.Fatalf("expected ToolExecError, got %T", err) | ||
| } | ||
| if execErr.Code != ToolErrorInvalidArgs { | ||
| t.Fatalf("unexpected code: %s", execErr.Code) | ||
| } | ||
| if execErr.Retryable { | ||
| t.Fatal("invalid args should not be retryable") | ||
| } | ||
| if !errors.Is(execErr, cause) { | ||
| t.Fatal("invalid-args cause must be preserved") | ||
| } | ||
| } | ||
|
|
||
| func TestNormalizeToolErrorContractToolFailed(t *testing.T) { | ||
| cause := errors.New("tool crashed") | ||
| err := normalizeToolError(cause) | ||
| execErr, ok := AsToolExecError(err) | ||
| if !ok { | ||
| t.Fatalf("expected ToolExecError, got %T", err) | ||
| } | ||
| if execErr.Code != ToolErrorToolFailed { | ||
| t.Fatalf("unexpected code: %s", execErr.Code) | ||
| } | ||
| if !execErr.Retryable { | ||
| t.Fatal("tool failed should be retryable") | ||
| } | ||
| if !errors.Is(execErr, cause) { | ||
| t.Fatal("tool-failed cause must be preserved") | ||
| } | ||
| } | ||
|
|
||
| func TestNormalizeToolErrorKeepsExistingToolExecError(t *testing.T) { | ||
| original := NewToolExecError(ToolErrorInvalidArgs, "bad input", false, errors.New("root")) | ||
| err := normalizeToolError(original) | ||
| if err != original { | ||
| t.Fatal("expected existing ToolExecError to be returned unchanged") | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,12 @@ | ||
| package tools | ||
|
|
||
| import ( | ||
| "bufio" | ||
| "context" | ||
| "encoding/json" | ||
| "errors" | ||
| "fmt" | ||
| "io" | ||
| "strings" | ||
| "time" | ||
| "unicode/utf8" | ||
|
|
@@ -37,9 +39,14 @@ type OutputNormalizer interface { | |
| Normalize(string, ResolvedTool) string | ||
| } | ||
|
|
||
| type WriteApprovalEngine interface { | ||
| Check(context.Context, ResolvedTool, *ExecutionContext) error | ||
| } | ||
|
|
||
| type Executor struct { | ||
| registry *Registry | ||
| permissionEngine PermissionEngine | ||
| writeApproval WriteApprovalEngine | ||
| argumentDecoder ArgumentDecoder | ||
| outputNormalizer OutputNormalizer | ||
| } | ||
|
|
@@ -48,6 +55,7 @@ func NewExecutor(registry *Registry) *Executor { | |
| return &Executor{ | ||
| registry: registry, | ||
| permissionEngine: defaultPermissionEngine{}, | ||
| writeApproval: defaultWriteApprovalEngine{}, | ||
| argumentDecoder: strictJSONArgumentDecoder{}, | ||
| outputNormalizer: maxCharsOutputNormalizer{}, | ||
| } | ||
|
|
@@ -90,6 +98,11 @@ func (e *Executor) ExecuteRequest(ctx context.Context, req ExecuteRequest) (Exec | |
| if err := e.permissionEngine.Check(ctx, resolved, req.Context); err != nil { | ||
| return ExecuteResult{}, err | ||
| } | ||
| if e.writeApproval != nil { | ||
| if err := e.writeApproval.Check(ctx, resolved, req.Context); err != nil { | ||
| return ExecuteResult{}, err | ||
| } | ||
| } | ||
|
|
||
| execCtx := req.Context | ||
| if execCtx == nil { | ||
|
|
@@ -121,6 +134,60 @@ func (defaultPermissionEngine) Check(_ context.Context, resolved ResolvedTool, e | |
| return nil | ||
| } | ||
|
|
||
| type defaultWriteApprovalEngine struct{} | ||
|
|
||
| func (defaultWriteApprovalEngine) Check(_ context.Context, resolved ResolvedTool, execCtx *ExecutionContext) error { | ||
| if !resolved.Spec.Destructive { | ||
| return nil | ||
| } | ||
| if execCtx == nil { | ||
| return NewToolExecError(ToolErrorPermissionDenied, fmt.Sprintf("tool %q requires approval context", resolved.Definition.Function.Name), false, nil) | ||
| } | ||
|
|
||
| switch strings.TrimSpace(execCtx.ApprovalPolicy) { | ||
| case "never": | ||
| return nil | ||
| case "always", "on-request", "": | ||
| return promptForWriteApproval(resolved.Definition.Function.Name, execCtx) | ||
| default: | ||
| return promptForWriteApproval(resolved.Definition.Function.Name, execCtx) | ||
| } | ||
| } | ||
|
|
||
| func promptForWriteApproval(toolName string, execCtx *ExecutionContext) error { | ||
| reason := "writes files in the workspace" | ||
| if execCtx.Approval != nil { | ||
| approved, err := execCtx.Approval(ApprovalRequest{ | ||
| Command: toolName, | ||
| Reason: reason, | ||
| }) | ||
| if err != nil { | ||
| return NewToolExecError(ToolErrorPermissionDenied, err.Error(), false, err) | ||
| } | ||
| if !approved { | ||
| return NewToolExecError(ToolErrorPermissionDenied, fmt.Sprintf("tool %q was not run because approval was denied", toolName), false, nil) | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| if execCtx.Stdin == nil { | ||
| return NewToolExecError(ToolErrorPermissionDenied, fmt.Sprintf("tool %q requires approval but no stdin is available", toolName), false, nil) | ||
| } | ||
| if execCtx.Stdout != nil { | ||
| fmt.Fprintf(execCtx.Stdout, "Approve tool %q (%s)? [y/N]: ", toolName, reason) | ||
| } | ||
| reader := bufio.NewReader(execCtx.Stdin) | ||
| line, err := reader.ReadString('\n') | ||
| if err != nil && !errors.Is(err, io.EOF) { | ||
| return NewToolExecError(ToolErrorPermissionDenied, err.Error(), false, err) | ||
| } | ||
| answer := strings.ToLower(strings.TrimSpace(line)) | ||
| if answer != "y" && answer != "yes" { | ||
| return NewToolExecError(ToolErrorPermissionDenied, fmt.Sprintf("tool %q was not run because approval was denied", toolName), false, nil) | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| type strictJSONArgumentDecoder struct{} | ||
|
|
||
| func (strictJSONArgumentDecoder) Decode(rawArgs string, resolved ResolvedTool) (json.RawMessage, error) { | ||
|
|
@@ -143,14 +210,11 @@ func (strictJSONArgumentDecoder) Decode(rawArgs string, resolved ResolvedTool) ( | |
| return nil, NewToolExecError(ToolErrorInvalidArgs, "tool arguments must be a JSON object", false, nil) | ||
| } | ||
|
|
||
| if !schemaRejectsUnknownFields(resolved.Definition.Function.Parameters) { | ||
| if schemaAllowsUnknownFields(resolved.Definition.Function.Parameters) { | ||
| return json.RawMessage(rawArgs), nil | ||
| } | ||
|
|
||
| allowedFields := schemaPropertyNames(resolved.Definition.Function.Parameters) | ||
| if len(allowedFields) == 0 { | ||
| return json.RawMessage(rawArgs), nil | ||
| } | ||
| for key := range objectPayload { | ||
| if _, ok := allowedFields[key]; ok { | ||
| continue | ||
|
|
@@ -227,13 +291,19 @@ func schemaPropertyNames(parameters map[string]any) map[string]struct{} { | |
| return names | ||
| } | ||
|
|
||
| func schemaRejectsUnknownFields(parameters map[string]any) bool { | ||
| func schemaAllowsUnknownFields(parameters map[string]any) bool { | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| value, ok := parameters["additionalProperties"] | ||
| if !ok { | ||
| return false | ||
| } | ||
| allowed, ok := value.(bool) | ||
| return ok && !allowed | ||
| switch typed := value.(type) { | ||
| case bool: | ||
| return typed | ||
| case map[string]any: | ||
| return true | ||
| default: | ||
| return false | ||
| } | ||
| } | ||
|
|
||
| func executionTimeout(raw json.RawMessage, spec ToolSpec) time.Duration { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good contract coverage here: these tests assert `Code`, `Retryable`, and wrapped-cause behavior, which protects error-mapping stability.