Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ flowchart TD
9. 若模型返回文本答案,则保存 assistant 消息并结束本轮。
10. 若模型返回工具调用,则逐个执行 Tool:
- 生成 `ExecutionContext`
- 执行 Tool
- 通过 `Runner` 调用 `Executor` 执行统一管线(参数校验、权限策略、超时、错误归一),当前实现在 `internal/agent/runner.go` + `internal/tools/executor.go`
- 工具定义与 mode 过滤由 `internal/tools/registry.go` 提供
- 记录结果
- 将 tool result 重新写回 Session
11. 进入下一轮模型调用,直到:
Expand Down
87 changes: 87 additions & 0 deletions internal/tools/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package tools

import (
"context"
"errors"
"testing"
)

func TestNormalizeToolErrorContractTimeout(t *testing.T) {
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good contract coverage here: these tests assert Code, Retryable, and wrapped cause behavior, which protects error mapping stability.

cause := context.DeadlineExceeded
err := normalizeToolError(cause)
execErr, ok := AsToolExecError(err)
if !ok {
t.Fatalf("expected ToolExecError, got %T", err)
}
if execErr.Code != ToolErrorTimeout {
t.Fatalf("unexpected code: %s", execErr.Code)
}
if !execErr.Retryable {
t.Fatal("timeout should be retryable")
}
if !errors.Is(execErr, cause) {
t.Fatal("timeout cause must be preserved")
}
}

func TestNormalizeToolErrorContractPermissionDenied(t *testing.T) {
cause := errors.New("approval denied")
err := normalizeToolError(cause)
execErr, ok := AsToolExecError(err)
if !ok {
t.Fatalf("expected ToolExecError, got %T", err)
}
if execErr.Code != ToolErrorPermissionDenied {
t.Fatalf("unexpected code: %s", execErr.Code)
}
if execErr.Retryable {
t.Fatal("permission denied should not be retryable")
}
if !errors.Is(execErr, cause) {
t.Fatal("permission-denied cause must be preserved")
}
}

func TestNormalizeToolErrorContractInvalidArgs(t *testing.T) {
cause := errors.New("unknown argument \"extra\"")
err := normalizeToolError(cause)
execErr, ok := AsToolExecError(err)
if !ok {
t.Fatalf("expected ToolExecError, got %T", err)
}
if execErr.Code != ToolErrorInvalidArgs {
t.Fatalf("unexpected code: %s", execErr.Code)
}
if execErr.Retryable {
t.Fatal("invalid args should not be retryable")
}
if !errors.Is(execErr, cause) {
t.Fatal("invalid-args cause must be preserved")
}
}

func TestNormalizeToolErrorContractToolFailed(t *testing.T) {
cause := errors.New("tool crashed")
err := normalizeToolError(cause)
execErr, ok := AsToolExecError(err)
if !ok {
t.Fatalf("expected ToolExecError, got %T", err)
}
if execErr.Code != ToolErrorToolFailed {
t.Fatalf("unexpected code: %s", execErr.Code)
}
if !execErr.Retryable {
t.Fatal("tool failed should be retryable")
}
if !errors.Is(execErr, cause) {
t.Fatal("tool-failed cause must be preserved")
}
}

func TestNormalizeToolErrorKeepsExistingToolExecError(t *testing.T) {
original := NewToolExecError(ToolErrorInvalidArgs, "bad input", false, errors.New("root"))
err := normalizeToolError(original)
if err != original {
t.Fatal("expected existing ToolExecError to be returned unchanged")
}
}
84 changes: 77 additions & 7 deletions internal/tools/executor.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package tools

import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"time"
"unicode/utf8"
Expand Down Expand Up @@ -37,9 +39,14 @@ type OutputNormalizer interface {
Normalize(string, ResolvedTool) string
}

type WriteApprovalEngine interface {
Check(context.Context, ResolvedTool, *ExecutionContext) error
}

type Executor struct {
registry *Registry
permissionEngine PermissionEngine
writeApproval WriteApprovalEngine
argumentDecoder ArgumentDecoder
outputNormalizer OutputNormalizer
}
Expand All @@ -48,6 +55,7 @@ func NewExecutor(registry *Registry) *Executor {
return &Executor{
registry: registry,
permissionEngine: defaultPermissionEngine{},
writeApproval: defaultWriteApprovalEngine{},
argumentDecoder: strictJSONArgumentDecoder{},
outputNormalizer: maxCharsOutputNormalizer{},
}
Expand Down Expand Up @@ -90,6 +98,11 @@ func (e *Executor) ExecuteRequest(ctx context.Context, req ExecuteRequest) (Exec
if err := e.permissionEngine.Check(ctx, resolved, req.Context); err != nil {
return ExecuteResult{}, err
}
if e.writeApproval != nil {
if err := e.writeApproval.Check(ctx, resolved, req.Context); err != nil {
return ExecuteResult{}, err
}
}

execCtx := req.Context
if execCtx == nil {
Expand Down Expand Up @@ -121,6 +134,60 @@ func (defaultPermissionEngine) Check(_ context.Context, resolved ResolvedTool, e
return nil
}

type defaultWriteApprovalEngine struct{}

func (defaultWriteApprovalEngine) Check(_ context.Context, resolved ResolvedTool, execCtx *ExecutionContext) error {
if !resolved.Spec.Destructive {
return nil
}
if execCtx == nil {
return NewToolExecError(ToolErrorPermissionDenied, fmt.Sprintf("tool %q requires approval context", resolved.Definition.Function.Name), false, nil)
}

switch strings.TrimSpace(execCtx.ApprovalPolicy) {
case "never":
return nil
case "always", "on-request", "":
return promptForWriteApproval(resolved.Definition.Function.Name, execCtx)
default:
return promptForWriteApproval(resolved.Definition.Function.Name, execCtx)
}
}

func promptForWriteApproval(toolName string, execCtx *ExecutionContext) error {
reason := "writes files in the workspace"
if execCtx.Approval != nil {
approved, err := execCtx.Approval(ApprovalRequest{
Command: toolName,
Reason: reason,
})
if err != nil {
return NewToolExecError(ToolErrorPermissionDenied, err.Error(), false, err)
}
if !approved {
return NewToolExecError(ToolErrorPermissionDenied, fmt.Sprintf("tool %q was not run because approval was denied", toolName), false, nil)
}
return nil
}

if execCtx.Stdin == nil {
return NewToolExecError(ToolErrorPermissionDenied, fmt.Sprintf("tool %q requires approval but no stdin is available", toolName), false, nil)
}
if execCtx.Stdout != nil {
fmt.Fprintf(execCtx.Stdout, "Approve tool %q (%s)? [y/N]: ", toolName, reason)
}
reader := bufio.NewReader(execCtx.Stdin)
line, err := reader.ReadString('\n')
if err != nil && !errors.Is(err, io.EOF) {
return NewToolExecError(ToolErrorPermissionDenied, err.Error(), false, err)
}
answer := strings.ToLower(strings.TrimSpace(line))
if answer != "y" && answer != "yes" {
return NewToolExecError(ToolErrorPermissionDenied, fmt.Sprintf("tool %q was not run because approval was denied", toolName), false, nil)
}
return nil
}

type strictJSONArgumentDecoder struct{}

func (strictJSONArgumentDecoder) Decode(rawArgs string, resolved ResolvedTool) (json.RawMessage, error) {
Expand All @@ -143,14 +210,11 @@ func (strictJSONArgumentDecoder) Decode(rawArgs string, resolved ResolvedTool) (
return nil, NewToolExecError(ToolErrorInvalidArgs, "tool arguments must be a JSON object", false, nil)
}

if !schemaRejectsUnknownFields(resolved.Definition.Function.Parameters) {
if schemaAllowsUnknownFields(resolved.Definition.Function.Parameters) {
return json.RawMessage(rawArgs), nil
}

allowedFields := schemaPropertyNames(resolved.Definition.Function.Parameters)
if len(allowedFields) == 0 {
return json.RawMessage(rawArgs), nil
}
for key := range objectPayload {
if _, ok := allowedFields[key]; ok {
continue
Expand Down Expand Up @@ -227,13 +291,19 @@ func schemaPropertyNames(parameters map[string]any) map[string]struct{} {
return names
}

func schemaRejectsUnknownFields(parameters map[string]any) bool {
func schemaAllowsUnknownFields(parameters map[string]any) bool {
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

schemaAllowsUnknownFields only treats additionalProperties: true as allowing unknown keys. In JSON Schema, additionalProperties may also be an object schema; that should still allow unknown keys (with validation constraints). This implementation now rejects all unknown keys in that valid schema form, which is a behavior regression for tools that rely on schema-object additionalProperties.

value, ok := parameters["additionalProperties"]
if !ok {
return false
}
allowed, ok := value.(bool)
return ok && !allowed
switch typed := value.(type) {
case bool:
return typed
case map[string]any:
return true
default:
return false
}
}

func executionTimeout(raw json.RawMessage, spec ToolSpec) time.Duration {
Expand Down
Loading
Loading