From bc989c14705a2850e7b7f69267d91abf71118e37 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Mon, 20 Apr 2026 14:54:22 +0000 Subject: [PATCH] fix(runtime,tui,session): address review reliability and error-classification issues - harden gateway rpc notification backpressure with fail-fast timeout - avoid close/readLoop channel panic by removing queue close path - replace runtime already-exists string matching with sentinel errors.Is checks - add ErrSessionAlreadyExists and map sqlite unique constraints to sentinel - add regression tests for close interleaving and backpressure fail-fast Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- internal/runtime/create_session_test.go | 13 +++- internal/runtime/runtime.go | 6 +- internal/session/sqlite_store.go | 23 ++++++ internal/session/store.go | 3 + internal/session/store_test.go | 22 ++++++ internal/tui/services/gateway_rpc_client.go | 66 ++++++++++------ .../gateway_rpc_client_additional_test.go | 77 ++++++++++++++++++- 7 files changed, 177 insertions(+), 33 deletions(-) diff --git a/internal/runtime/create_session_test.go b/internal/runtime/create_session_test.go index a4d6ef21..ed6a3d73 100644 --- a/internal/runtime/create_session_test.go +++ b/internal/runtime/create_session_test.go @@ -3,6 +3,7 @@ package runtime import ( "context" "fmt" + "os" "testing" agentsession "neo-code/internal/session" @@ -160,7 +161,7 @@ func TestServiceCreateSessionDuplicateCreateFallsBackToLoad(t *testing.T) { memoryStore: newMemoryStore(), missingErr: fmt.Errorf("load session row: %w", agentsession.ErrSessionNotFound), }, - createErr: fmt.Errorf("unique constraint failed"), + createErr: fmt.Errorf("sqlite: %w", agentsession.ErrSessionAlreadyExists), loaded: agentsession.Session{ID: "session-dup", Title: "loaded"}, } service := &Service{ @@ -190,9 +191,15 @@ func TestCreateSessionErrorPredicates(t *testing.T) { if isRuntimeSessionAlreadyExistsError(nil) { t.Fatalf("isRuntimeSessionAlreadyExistsError(nil) should be false") } + if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("wrapped: %w", agentsession.ErrSessionAlreadyExists)) { + t.Fatalf("wrapped ErrSessionAlreadyExists should be detected") + } + if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("wrapped: %w", os.ErrExist)) { + t.Fatalf("wrapped os.ErrExist should be detected") + } for _, text := range []string{"already exists", "UNIQUE CONSTRAINT", "duplicate key"} { - if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("%s", text)) { - t.Fatalf("expected %q to be treated as already exists", text) + if isRuntimeSessionAlreadyExistsError(fmt.Errorf("%s", text)) { + t.Fatalf("plain text %q should not be treated as already exists", text) } } } diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 45c16172..93f8db91 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -3,6 +3,7 @@ package runtime import ( "context" "errors" + "os" "strings" "sync" "time" @@ -310,10 +311,7 @@ func isRuntimeSessionAlreadyExistsError(err error) bool { if err == nil { return false } - normalized := strings.ToLower(strings.TrimSpace(err.Error())) - return strings.Contains(normalized, "already exists") || - strings.Contains(normalized, "unique constraint") || - strings.Contains(normalized, "duplicate") + return errors.Is(err, agentsession.ErrSessionAlreadyExists) || errors.Is(err, os.ErrExist) } // SetAutoCompactThresholdResolver 注入自动压缩阈值解析能力,避免 runtime 直接处理模型目录细节。 diff --git a/internal/session/sqlite_store.go b/internal/session/sqlite_store.go index 75e4f77d..036e4948 100644 --- a/internal/session/sqlite_store.go +++ b/internal/session/sqlite_store.go @@ -177,6 +177,9 @@ INSERT INTO sessions ( session.TokenOutputTotal, ) if err != nil { + if isSQLiteSessionUniqueConstraintError(err) { + return Session{}, wrapSessionAlreadyExists(err) + } return Session{}, fmt.Errorf("session: insert session %s: %w", session.ID, err) } if err := tx.Commit(); err != nil { @@ -1306,6 +1309,14 @@ func wrapSessionNotFound(cause error) error { return fmt.Errorf("%w: %w", ErrSessionNotFound, fmt.Errorf("%w: %w", os.ErrNotExist, cause)) } +// wrapSessionAlreadyExists 统一包装会话重复创建错误,确保上层可通过 ErrSessionAlreadyExists 做精确判断。 +func wrapSessionAlreadyExists(cause error) error { + if cause == nil { + cause = os.ErrExist + } + return fmt.Errorf("%w: %w", ErrSessionAlreadyExists, fmt.Errorf("%w: %w", os.ErrExist, cause)) +} + // cloneMessage 深拷贝消息,避免共享底层切片和映射。 // mapSessionAssetInsertError 统一收敛附件元数据插入阶段的缺失会话语义,避免向上泄漏底层 SQLite 错误。 func mapSessionAssetInsertError(assetID string, err error) error { @@ -1324,6 +1335,18 @@ func isSQLiteForeignKeyConstraintError(err error) bool { return false } +// isSQLiteSessionUniqueConstraintError 判断底层错误是否为 SQLite 主键/唯一约束失败。 +func isSQLiteSessionUniqueConstraintError(err error) bool { + var sqliteErr *sqlitedriver.Error + if !errors.As(err, &sqliteErr) { + return false + } + code := sqliteErr.Code() + return code == sqlite3.SQLITE_CONSTRAINT || + code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY || + code == sqlite3.SQLITE_CONSTRAINT_UNIQUE +} + func cloneMessage(message providertypes.Message) providertypes.Message { next := message next.Parts = providertypes.CloneParts(message.Parts) diff --git a/internal/session/store.go b/internal/session/store.go index 32164721..87646a9c 100644 --- a/internal/session/store.go +++ b/internal/session/store.go @@ -28,6 +28,9 @@ var storageIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,127}$`) // ErrSessionNotFound 表示会话在存储层不存在,用于 runtime 做精确错误分流。 var ErrSessionNotFound = errors.New("session: session not found") +// ErrSessionAlreadyExists 表示会话在存储层已存在,用于 runtime 处理并发创建冲突。 +var ErrSessionAlreadyExists = errors.New("session: session already exists") + // Session 表示单个会话的运行态与持久化聚合模型。 type Session struct { ID string diff --git a/internal/session/store_test.go b/internal/session/store_test.go index 52bc42f4..e9a76862 100644 --- a/internal/session/store_test.go +++ b/internal/session/store_test.go @@ -122,6 +122,28 @@ func TestSQLiteStoreLifecycleRoundTrip(t *testing.T) { } } +func TestSQLiteStoreCreateSessionDuplicateReturnsSentinel(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := newTestStore(t) + input := CreateSessionInput{ID: "dup_session", Title: "dup"} + if _, err := store.CreateSession(ctx, input); err != nil { + t.Fatalf("first CreateSession() error = %v", err) + } + + _, err := store.CreateSession(ctx, input) + if err == nil { + t.Fatalf("expected duplicate CreateSession() to fail") + } + if !errors.Is(err, ErrSessionAlreadyExists) { + t.Fatalf("expected ErrSessionAlreadyExists, got %v", err) + } + if !errors.Is(err, os.ErrExist) { + t.Fatalf("expected os.ErrExist chain, got %v", err) + } +} + func TestSQLiteStoreListSummariesSortedAndLegacyJSONIgnored(t *testing.T) { ctx := context.Background() baseDir, err := os.MkdirTemp("", "session-base-") diff --git a/internal/tui/services/gateway_rpc_client.go b/internal/tui/services/gateway_rpc_client.go index 3b42219f..6236f075 100644 --- a/internal/tui/services/gateway_rpc_client.go +++ b/internal/tui/services/gateway_rpc_client.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "net" "strings" "sync" @@ -17,10 +18,11 @@ import ( ) const ( - defaultGatewayRPCRequestTimeout = 8 * time.Second - defaultGatewayRPCRetryCount = 1 - defaultGatewayNotificationBuffer = 64 - defaultGatewayNotificationQueue = 256 + defaultGatewayRPCRequestTimeout = 8 * time.Second + defaultGatewayRPCRetryCount = 1 + defaultGatewayNotificationBuffer = 64 + defaultGatewayNotificationQueue = 256 + defaultGatewayNotificationEnqueueTimeout = 3 * time.Second ) // GatewayRPCClientOptions 描述网关 JSON-RPC 客户端的初始化参数。 @@ -107,11 +109,12 @@ type GatewayRPCClient struct { conn net.Conn pending map[string]chan gatewayRPCResponse - notifications chan gatewayRPCNotification - notificationQueue chan gatewayRPCNotification - notificationWG sync.WaitGroup - notificationStart sync.Once - sequence uint64 + notifications chan gatewayRPCNotification + notificationQueue chan gatewayRPCNotification + notificationEnqueueTimeout time.Duration + notificationWG sync.WaitGroup + notificationStart sync.Once + sequence uint64 } // NewGatewayRPCClient 创建网关 RPC 客户端,并在启动时静默读取认证 Token。 @@ -146,15 +149,16 @@ func NewGatewayRPCClient(options GatewayRPCClientOptions) (*GatewayRPCClient, er } return &GatewayRPCClient{ - listenAddress: listenAddress, - token: token, - requestTimeout: requestTimeout, - retryCount: retryCount, - dialFn: dialFn, - closed: make(chan struct{}), - pending: make(map[string]chan gatewayRPCResponse), - notifications: make(chan gatewayRPCNotification, defaultGatewayNotificationBuffer), - notificationQueue: make(chan gatewayRPCNotification, defaultGatewayNotificationQueue), + listenAddress: listenAddress, + token: token, + requestTimeout: requestTimeout, + retryCount: retryCount, + dialFn: dialFn, + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification, defaultGatewayNotificationBuffer), + notificationQueue: make(chan gatewayRPCNotification, defaultGatewayNotificationQueue), + notificationEnqueueTimeout: defaultGatewayNotificationEnqueueTimeout, }, nil } @@ -232,7 +236,6 @@ func (c *GatewayRPCClient) Close() error { c.closeOnce.Do(func() { close(c.closed) firstErr = c.forceCloseWithError(errors.New("gateway rpc client closed")) - close(c.notificationQueue) c.notificationWG.Wait() close(c.notifications) }) @@ -372,7 +375,9 @@ func (c *GatewayRPCClient) readLoop(conn net.Conn) { if paramsRaw, hasParams := envelope["params"]; hasParams { notification.Params = cloneJSONRawMessage(paramsRaw) } - c.enqueueNotification(notification) + if !c.enqueueNotification(notification) { + return + } continue } @@ -387,7 +392,7 @@ func (c *GatewayRPCClient) readLoop(conn net.Conn) { } } -// startNotificationDispatcher 启动通知转发协程,确保 readLoop 不会被 UI 消费速度阻塞。 +// startNotificationDispatcher 启动通知转发协程,配合 enqueue 超时保护避免 readLoop 长时间背压阻塞。 func (c *GatewayRPCClient) startNotificationDispatcher() { c.notificationStart.Do(func() { c.notificationWG.Add(1) @@ -412,12 +417,25 @@ func (c *GatewayRPCClient) startNotificationDispatcher() { }) } -// enqueueNotification 以阻塞方式投递通知,确保 gateway.event 不会因队列满被静默丢弃。 -func (c *GatewayRPCClient) enqueueNotification(notification gatewayRPCNotification) { +// enqueueNotification 投递通知到内部队列;若背压持续超时则主动断开连接,避免 readLoop 无限阻塞。 +func (c *GatewayRPCClient) enqueueNotification(notification gatewayRPCNotification) bool { + enqueueTimeout := c.notificationEnqueueTimeout + if enqueueTimeout <= 0 { + enqueueTimeout = defaultGatewayNotificationEnqueueTimeout + } + timer := time.NewTimer(enqueueTimeout) + defer timer.Stop() + select { case <-c.closed: - return + return false case c.notificationQueue <- notification: + return true + case <-timer.C: + err := fmt.Errorf("gateway rpc client: notification queue blocked for %s", enqueueTimeout) + log.Printf("warning: gateway rpc client force close due to notification backpressure method=%s err=%v", notification.Method, err) + _ = c.forceCloseWithError(err) + return false } } diff --git a/internal/tui/services/gateway_rpc_client_additional_test.go b/internal/tui/services/gateway_rpc_client_additional_test.go index 7f1e5bd9..10041047 100644 --- a/internal/tui/services/gateway_rpc_client_additional_test.go +++ b/internal/tui/services/gateway_rpc_client_additional_test.go @@ -443,7 +443,7 @@ func TestGatewayRPCClientReadLoopAdditionalBranches(t *testing.T) { _ = client.Close() } -func TestGatewayRPCClientNotificationDispatcherStopsOnQueueClose(t *testing.T) { +func TestGatewayRPCClientNotificationDispatcherStopsOnCloseSignal(t *testing.T) { t.Parallel() client := &GatewayRPCClient{ @@ -453,7 +453,7 @@ func TestGatewayRPCClientNotificationDispatcherStopsOnQueueClose(t *testing.T) { notificationQueue: make(chan gatewayRPCNotification, 1), } client.startNotificationDispatcher() - close(client.notificationQueue) + close(client.closed) client.notificationWG.Wait() } @@ -510,6 +510,79 @@ func TestGatewayRPCClientEnqueueNotificationDoesNotDropUnderQueuePressure(t *tes } } +func TestGatewayRPCClientReadLoopFailsFastOnNotificationBackpressure(t *testing.T) { + t.Parallel() + + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() + }) + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification), + notificationQueue: make(chan gatewayRPCNotification, 1), + notificationEnqueueTimeout: 50 * time.Millisecond, + } + client.startNotificationDispatcher() + t.Cleanup(func() { _ = client.Close() }) + + readDone := make(chan struct{}) + go func() { + defer close(readDone) + client.readLoop(clientConn) + }() + encoder := json.NewEncoder(serverConn) + if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 1}}); err != nil { + t.Fatalf("encode first notification: %v", err) + } + if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 2}}); err != nil { + t.Fatalf("encode second notification: %v", err) + } + if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 3}}); err != nil { + t.Fatalf("encode third notification: %v", err) + } + + select { + case <-readDone: + case <-time.After(time.Second): + t.Fatalf("expected readLoop to fail-fast on sustained notification backpressure") + } +} + +func TestGatewayRPCClientEnqueueNotificationUnblocksOnClose(t *testing.T) { + t.Parallel() + + client := &GatewayRPCClient{ + closed: make(chan struct{}), + pending: make(map[string]chan gatewayRPCResponse), + notifications: make(chan gatewayRPCNotification), + notificationQueue: make(chan gatewayRPCNotification, 1), + notificationEnqueueTimeout: time.Second, + } + client.startNotificationDispatcher() + + // 首条通知占满队列,第二条通知会阻塞在 enqueue,关闭客户端后应立即退出。 + client.notificationQueue <- gatewayRPCNotification{Method: protocol.MethodGatewayEvent} + + done := make(chan struct{}) + go func() { + defer close(done) + client.enqueueNotification(gatewayRPCNotification{Method: protocol.MethodGatewayEvent}) + }() + + time.Sleep(20 * time.Millisecond) + _ = client.Close() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("enqueueNotification should unblock when client closes") + } +} + func TestGatewayRPCClientWriteRequestFailure(t *testing.T) { t.Parallel()