Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions internal/runtime/system_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ func (s *Service) ExecuteSystemTool(ctx context.Context, input SystemToolInput)
)
if sessionID != "" {
sessionMu, releaseLockRef := s.acquireSessionLock(sessionID)
defer releaseLockRef()
sessionMu.Lock()
defer sessionMu.Unlock()

session, err := s.sessionStore.LoadSession(ctx, sessionID)
if err != nil {
sessionMu.Unlock()
releaseLockRef()
return tools.ToolResult{}, err
}
loaded = session
Expand All @@ -57,6 +57,8 @@ func (s *Service) ExecuteSystemTool(ctx context.Context, input SystemToolInput)
}
runStateValue := newRunState(runID, session)
state = &runStateValue
sessionMu.Unlock()
releaseLockRef()
}

call := providertypes.ToolCall{
Expand Down Expand Up @@ -114,10 +116,15 @@ func normalizeToolName(name string) string {

// newSystemToolRunID 为系统工具调用生成稳定前缀的运行标识,便于事件与日志定位。
func newSystemToolRunID(toolName string) string {
return fmt.Sprintf("system-tool-%s-%d", normalizeToolName(toolName), time.Now().UnixNano())
return formatSystemToolID("system-tool", toolName)
}

// newSystemToolCallID 为系统工具调用生成单次执行唯一的 tool call id。
func newSystemToolCallID(toolName string) string {
return fmt.Sprintf("call-%s-%d", normalizeToolName(toolName), time.Now().UnixNano())
return formatSystemToolID("call", toolName)
}

// formatSystemToolID 统一构造系统工具相关 ID,避免不同类型 ID 生成逻辑分散重复。
func formatSystemToolID(prefix, toolName string) string {
return fmt.Sprintf("%s-%s-%d", prefix, normalizeToolName(toolName), time.Now().UnixNano())
}
121 changes: 93 additions & 28 deletions internal/runtime/system_tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"strings"
"testing"
"time"

agentsession "neo-code/internal/session"
"neo-code/internal/tools"
Expand Down Expand Up @@ -153,6 +154,78 @@ func TestExecuteSystemToolWithSession(t *testing.T) {
assertEventContains(t, events, EventToolResult)
}

// TestExecuteSystemToolReleasesSessionLockBeforeToolExecution 验证工具执行期间不会继续持有会话锁。
func TestExecuteSystemToolReleasesSessionLockBeforeToolExecution(t *testing.T) {
t.Parallel()

store := newMemoryStore()
session, err := store.CreateSession(context.Background(), agentsession.CreateSessionInput{
Title: "test-session",
})
if err != nil {
t.Fatalf("create session: %v", err)
}

started := make(chan struct{})
releaseTool := make(chan struct{})
tm := &stubToolManager{
executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) {
close(started)
<-releaseTool
return tools.ToolResult{Content: "done"}, nil
},
}
service := NewWithFactory(
newRuntimeConfigManager(t),
tm,
store,
&scriptedProviderFactory{provider: &scriptedProvider{}},
nil,
)

runErr := make(chan error, 1)
go func() {
_, runErrValue := service.ExecuteSystemTool(context.Background(), SystemToolInput{
SessionID: session.ID,
ToolName: "bash",
Arguments: []byte(`{"command":"ls"}`),
})
runErr <- runErrValue
}()

select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting tool execution to start")
}

sessionMu, releaseLockRef := service.acquireSessionLock(session.ID)
lockAcquired := make(chan struct{})
go func() {
sessionMu.Lock()
close(lockAcquired)
sessionMu.Unlock()
releaseLockRef()
}()

select {
case <-lockAcquired:
case <-time.After(2 * time.Second):
close(releaseTool)
t.Fatal("session lock is still held during tool execution")
}

close(releaseTool)
select {
case err := <-runErr:
if err != nil {
t.Fatalf("execute system tool: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting ExecuteSystemTool to return")
}
}

// TestExecuteSystemToolWithSessionLoadError 验证会话加载失败时返回错误。
func TestExecuteSystemToolWithSessionLoadError(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -184,10 +257,8 @@ func TestExecuteSystemToolWithSessionLoadError(t *testing.T) {
func TestExecuteSystemToolCustomRunID(t *testing.T) {
t.Parallel()

var capturedInput tools.ToolCallInput
tm := &stubToolManager{
executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) {
capturedInput = input
return tools.ToolResult{Content: "ok"}, nil
},
}
Expand Down Expand Up @@ -222,7 +293,6 @@ func TestExecuteSystemToolCustomRunID(t *testing.T) {
if !found {
t.Fatalf("expected event with RunID 'my-custom-run-id' in %d events", len(events))
}
_ = capturedInput
}

// TestExecuteSystemToolDefaultWorkdir 验证 workdir 为空时使用配置默认值。
Expand Down Expand Up @@ -337,19 +407,7 @@ func TestNewSystemToolRunID(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := newSystemToolRunID(tc.input)
if !strings.HasPrefix(got, tc.prefix) {
t.Fatalf("expected prefix %q, got %q", tc.prefix, got)
}
// 验证前缀之后是数字(时间戳)
suffix := strings.TrimPrefix(got, tc.prefix)
if suffix == "" {
t.Fatal("expected numeric suffix after prefix")
}
for _, ch := range suffix {
if ch < '0' || ch > '9' {
t.Fatalf("expected numeric suffix, got %q in %q", string(ch), got)
}
}
assertGeneratedIDWithPrefix(t, got, tc.prefix)
})
}
}
Expand Down Expand Up @@ -393,18 +451,25 @@ func TestNewSystemToolCallID(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := newSystemToolCallID(tc.input)
if !strings.HasPrefix(got, tc.prefix) {
t.Fatalf("expected prefix %q, got %q", tc.prefix, got)
}
suffix := strings.TrimPrefix(got, tc.prefix)
if suffix == "" {
t.Fatal("expected numeric suffix after prefix")
}
for _, ch := range suffix {
if ch < '0' || ch > '9' {
t.Fatalf("expected numeric suffix, got %q in %q", string(ch), got)
}
}
assertGeneratedIDWithPrefix(t, got, tc.prefix)
})
}
}

func assertGeneratedIDWithPrefix(t *testing.T, got, prefix string) {
t.Helper()

if !strings.HasPrefix(got, prefix) {
t.Fatalf("expected prefix %q, got %q", prefix, got)
}

suffix := strings.TrimPrefix(got, prefix)
if suffix == "" {
t.Fatal("expected numeric suffix after prefix")
}
for _, ch := range suffix {
if ch < '0' || ch > '9' {
t.Fatalf("expected numeric suffix, got %q in %q", string(ch), got)
}
}
}
Loading