diff --git a/internal/runtime/system_tool.go b/internal/runtime/system_tool.go index 33e4cd7b..fa70de49 100644 --- a/internal/runtime/system_tool.go +++ b/internal/runtime/system_tool.go @@ -43,12 +43,12 @@ func (s *Service) ExecuteSystemTool(ctx context.Context, input SystemToolInput) ) if sessionID != "" { sessionMu, releaseLockRef := s.acquireSessionLock(sessionID) - defer releaseLockRef() sessionMu.Lock() - defer sessionMu.Unlock() session, err := s.sessionStore.LoadSession(ctx, sessionID) if err != nil { + sessionMu.Unlock() + releaseLockRef() return tools.ToolResult{}, err } loaded = session @@ -57,6 +57,8 @@ func (s *Service) ExecuteSystemTool(ctx context.Context, input SystemToolInput) } runStateValue := newRunState(runID, session) state = &runStateValue + sessionMu.Unlock() + releaseLockRef() } call := providertypes.ToolCall{ @@ -114,10 +116,15 @@ func normalizeToolName(name string) string { // newSystemToolRunID 为系统工具调用生成稳定前缀的运行标识,便于事件与日志定位。 func newSystemToolRunID(toolName string) string { - return fmt.Sprintf("system-tool-%s-%d", normalizeToolName(toolName), time.Now().UnixNano()) + return formatSystemToolID("system-tool", toolName) } // newSystemToolCallID 为系统工具调用生成单次执行唯一的 tool call id。 func newSystemToolCallID(toolName string) string { - return fmt.Sprintf("call-%s-%d", normalizeToolName(toolName), time.Now().UnixNano()) + return formatSystemToolID("call", toolName) +} + +// formatSystemToolID 统一构造系统工具相关 ID,避免不同类型 ID 生成逻辑分散重复。 +func formatSystemToolID(prefix, toolName string) string { + return fmt.Sprintf("%s-%s-%d", prefix, normalizeToolName(toolName), time.Now().UnixNano()) } diff --git a/internal/runtime/system_tool_test.go b/internal/runtime/system_tool_test.go index dd6591d9..b6af7330 100644 --- a/internal/runtime/system_tool_test.go +++ b/internal/runtime/system_tool_test.go @@ -5,6 +5,7 @@ import ( "errors" "strings" "testing" + "time" agentsession "neo-code/internal/session" "neo-code/internal/tools" @@ -153,6 +154,78 @@ func TestExecuteSystemToolWithSession(t *testing.T) { assertEventContains(t, events, EventToolResult) } +// TestExecuteSystemToolReleasesSessionLockBeforeToolExecution 验证工具执行期间不会继续持有会话锁。 +func TestExecuteSystemToolReleasesSessionLockBeforeToolExecution(t *testing.T) { + t.Parallel() + + store := newMemoryStore() + session, err := store.CreateSession(context.Background(), agentsession.CreateSessionInput{ + Title: "test-session", + }) + if err != nil { + t.Fatalf("create session: %v", err) + } + + started := make(chan struct{}) + releaseTool := make(chan struct{}) + tm := &stubToolManager{ + executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { + close(started) + <-releaseTool + return tools.ToolResult{Content: "done"}, nil + }, + } + service := NewWithFactory( + newRuntimeConfigManager(t), + tm, + store, + &scriptedProviderFactory{provider: &scriptedProvider{}}, + nil, + ) + + runErr := make(chan error, 1) + go func() { + _, runErrValue := service.ExecuteSystemTool(context.Background(), SystemToolInput{ + SessionID: session.ID, + ToolName: "bash", + Arguments: []byte(`{"command":"ls"}`), + }) + runErr <- runErrValue + }() + + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting tool execution to start") + } + + sessionMu, releaseLockRef := service.acquireSessionLock(session.ID) + lockAcquired := make(chan struct{}) + go func() { + sessionMu.Lock() + close(lockAcquired) + sessionMu.Unlock() + releaseLockRef() + }() + + select { + case <-lockAcquired: + case <-time.After(2 * time.Second): + close(releaseTool) + t.Fatal("session lock is still held during tool execution") + } + + close(releaseTool) + select { + case err := <-runErr: + if err != nil { + t.Fatalf("execute system tool: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting ExecuteSystemTool to return") + } +} + // TestExecuteSystemToolWithSessionLoadError 验证会话加载失败时返回错误。 func TestExecuteSystemToolWithSessionLoadError(t *testing.T) { t.Parallel() @@ -184,10 +257,8 @@ func TestExecuteSystemToolWithSessionLoadError(t *testing.T) { func TestExecuteSystemToolCustomRunID(t *testing.T) { t.Parallel() - var capturedInput tools.ToolCallInput tm := &stubToolManager{ executeFn: func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { - capturedInput = input return tools.ToolResult{Content: "ok"}, nil }, } @@ -222,7 +293,6 @@ func TestExecuteSystemToolCustomRunID(t *testing.T) { if !found { t.Fatalf("expected event with RunID 'my-custom-run-id' in %d events", len(events)) } - _ = capturedInput } // TestExecuteSystemToolDefaultWorkdir 验证 workdir 为空时使用配置默认值。 @@ -337,19 +407,7 @@ func TestNewSystemToolRunID(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { got := newSystemToolRunID(tc.input) - if !strings.HasPrefix(got, tc.prefix) { - t.Fatalf("expected prefix %q, got %q", tc.prefix, got) - } - // 验证前缀之后是数字(时间戳) - suffix := strings.TrimPrefix(got, tc.prefix) - if suffix == "" { - t.Fatal("expected numeric suffix after prefix") - } - for _, ch := range suffix { - if ch < '0' || ch > '9' { - t.Fatalf("expected numeric suffix, got %q in %q", string(ch), got) - } - } + assertGeneratedIDWithPrefix(t, got, tc.prefix) }) } } @@ -393,18 +451,25 @@ func TestNewSystemToolCallID(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { got := newSystemToolCallID(tc.input) - if !strings.HasPrefix(got, tc.prefix) { - t.Fatalf("expected prefix %q, got %q", tc.prefix, got) - } - suffix := strings.TrimPrefix(got, tc.prefix) - if suffix == "" { - t.Fatal("expected numeric suffix after prefix") - } - for _, ch := range suffix { - if ch < '0' || ch > '9' { - t.Fatalf("expected numeric suffix, got %q in %q", string(ch), got) - } - } + assertGeneratedIDWithPrefix(t, got, tc.prefix) }) } } + +func assertGeneratedIDWithPrefix(t *testing.T, got, prefix string) { + t.Helper() + + if !strings.HasPrefix(got, prefix) { + t.Fatalf("expected prefix %q, got %q", prefix, got) + } + + suffix := strings.TrimPrefix(got, prefix) + if suffix == "" { + t.Fatal("expected numeric suffix after prefix") + } + for _, ch := range suffix { + if ch < '0' || ch > '9' { + t.Fatalf("expected numeric suffix, got %q in %q", string(ch), got) + } + } +}