Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions internal/runtime/create_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package runtime
import (
"context"
"fmt"
"os"
"testing"

agentsession "neo-code/internal/session"
Expand Down Expand Up @@ -160,7 +161,7 @@ func TestServiceCreateSessionDuplicateCreateFallsBackToLoad(t *testing.T) {
memoryStore: newMemoryStore(),
missingErr: fmt.Errorf("load session row: %w", agentsession.ErrSessionNotFound),
},
createErr: fmt.Errorf("unique constraint failed"),
createErr: fmt.Errorf("sqlite: %w", agentsession.ErrSessionAlreadyExists),
loaded: agentsession.Session{ID: "session-dup", Title: "loaded"},
}
service := &Service{
Expand Down Expand Up @@ -190,9 +191,15 @@ func TestCreateSessionErrorPredicates(t *testing.T) {
if isRuntimeSessionAlreadyExistsError(nil) {
t.Fatalf("isRuntimeSessionAlreadyExistsError(nil) should be false")
}
if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("wrapped: %w", agentsession.ErrSessionAlreadyExists)) {
t.Fatalf("wrapped ErrSessionAlreadyExists should be detected")
}
if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("wrapped: %w", os.ErrExist)) {
t.Fatalf("wrapped os.ErrExist should be detected")
}
for _, text := range []string{"already exists", "UNIQUE CONSTRAINT", "duplicate key"} {
if !isRuntimeSessionAlreadyExistsError(fmt.Errorf("%s", text)) {
t.Fatalf("expected %q to be treated as already exists", text)
if isRuntimeSessionAlreadyExistsError(fmt.Errorf("%s", text)) {
t.Fatalf("plain text %q should not be treated as already exists", text)
}
}
}
6 changes: 2 additions & 4 deletions internal/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package runtime
import (
"context"
"errors"
"os"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -310,10 +311,7 @@ func isRuntimeSessionAlreadyExistsError(err error) bool {
if err == nil {
return false
}
normalized := strings.ToLower(strings.TrimSpace(err.Error()))
return strings.Contains(normalized, "already exists") ||
strings.Contains(normalized, "unique constraint") ||
strings.Contains(normalized, "duplicate")
return errors.Is(err, agentsession.ErrSessionAlreadyExists) || errors.Is(err, os.ErrExist)
}

// SetAutoCompactThresholdResolver 注入自动压缩阈值解析能力,避免 runtime 直接处理模型目录细节。
Expand Down
23 changes: 23 additions & 0 deletions internal/session/sqlite_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ INSERT INTO sessions (
session.TokenOutputTotal,
)
if err != nil {
if isSQLiteSessionUniqueConstraintError(err) {
return Session{}, wrapSessionAlreadyExists(err)
}
return Session{}, fmt.Errorf("session: insert session %s: %w", session.ID, err)
}
if err := tx.Commit(); err != nil {
Expand Down Expand Up @@ -1306,6 +1309,14 @@ func wrapSessionNotFound(cause error) error {
return fmt.Errorf("%w: %w", ErrSessionNotFound, fmt.Errorf("%w: %w", os.ErrNotExist, cause))
}

// wrapSessionAlreadyExists 统一包装会话重复创建错误,确保上层可通过 ErrSessionAlreadyExists 做精确判断。
func wrapSessionAlreadyExists(cause error) error {
if cause == nil {
cause = os.ErrExist
}
return fmt.Errorf("%w: %w", ErrSessionAlreadyExists, fmt.Errorf("%w: %w", os.ErrExist, cause))
}

// cloneMessage 深拷贝消息,避免共享底层切片和映射。
// mapSessionAssetInsertError 统一收敛附件元数据插入阶段的缺失会话语义,避免向上泄漏底层 SQLite 错误。
func mapSessionAssetInsertError(assetID string, err error) error {
Expand All @@ -1324,6 +1335,18 @@ func isSQLiteForeignKeyConstraintError(err error) bool {
return false
}

// isSQLiteSessionUniqueConstraintError 判断底层错误是否为 SQLite 主键/唯一约束失败。
func isSQLiteSessionUniqueConstraintError(err error) bool {
var sqliteErr *sqlitedriver.Error
if !errors.As(err, &sqliteErr) {
return false
}
code := sqliteErr.Code()
return code == sqlite3.SQLITE_CONSTRAINT ||
code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY ||
code == sqlite3.SQLITE_CONSTRAINT_UNIQUE
}

func cloneMessage(message providertypes.Message) providertypes.Message {
next := message
next.Parts = providertypes.CloneParts(message.Parts)
Expand Down
3 changes: 3 additions & 0 deletions internal/session/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ var storageIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,127}$`)
// ErrSessionNotFound 表示会话在存储层不存在,用于 runtime 做精确错误分流。
var ErrSessionNotFound = errors.New("session: session not found")

// ErrSessionAlreadyExists 表示会话在存储层已存在,用于 runtime 处理并发创建冲突。
var ErrSessionAlreadyExists = errors.New("session: session already exists")

// Session 表示单个会话的运行态与持久化聚合模型。
type Session struct {
ID string
Expand Down
22 changes: 22 additions & 0 deletions internal/session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,28 @@ func TestSQLiteStoreLifecycleRoundTrip(t *testing.T) {
}
}

func TestSQLiteStoreCreateSessionDuplicateReturnsSentinel(t *testing.T) {
t.Parallel()

ctx := context.Background()
store := newTestStore(t)
input := CreateSessionInput{ID: "dup_session", Title: "dup"}
if _, err := store.CreateSession(ctx, input); err != nil {
t.Fatalf("first CreateSession() error = %v", err)
}

_, err := store.CreateSession(ctx, input)
if err == nil {
t.Fatalf("expected duplicate CreateSession() to fail")
}
if !errors.Is(err, ErrSessionAlreadyExists) {
t.Fatalf("expected ErrSessionAlreadyExists, got %v", err)
}
if !errors.Is(err, os.ErrExist) {
t.Fatalf("expected os.ErrExist chain, got %v", err)
}
}

func TestSQLiteStoreListSummariesSortedAndLegacyJSONIgnored(t *testing.T) {
ctx := context.Background()
baseDir, err := os.MkdirTemp("", "session-base-")
Expand Down
66 changes: 42 additions & 24 deletions internal/tui/services/gateway_rpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net"
"strings"
"sync"
Expand All @@ -17,10 +18,11 @@ import (
)

const (
defaultGatewayRPCRequestTimeout = 8 * time.Second
defaultGatewayRPCRetryCount = 1
defaultGatewayNotificationBuffer = 64
defaultGatewayNotificationQueue = 256
defaultGatewayRPCRequestTimeout = 8 * time.Second
defaultGatewayRPCRetryCount = 1
defaultGatewayNotificationBuffer = 64
defaultGatewayNotificationQueue = 256
defaultGatewayNotificationEnqueueTimeout = 3 * time.Second
)

// GatewayRPCClientOptions 描述网关 JSON-RPC 客户端的初始化参数。
Expand Down Expand Up @@ -107,11 +109,12 @@ type GatewayRPCClient struct {
conn net.Conn
pending map[string]chan gatewayRPCResponse

notifications chan gatewayRPCNotification
notificationQueue chan gatewayRPCNotification
notificationWG sync.WaitGroup
notificationStart sync.Once
sequence uint64
notifications chan gatewayRPCNotification
notificationQueue chan gatewayRPCNotification
notificationEnqueueTimeout time.Duration
notificationWG sync.WaitGroup
notificationStart sync.Once
sequence uint64
}

// NewGatewayRPCClient 创建网关 RPC 客户端,并在启动时静默读取认证 Token。
Expand Down Expand Up @@ -146,15 +149,16 @@ func NewGatewayRPCClient(options GatewayRPCClientOptions) (*GatewayRPCClient, er
}

return &GatewayRPCClient{
listenAddress: listenAddress,
token: token,
requestTimeout: requestTimeout,
retryCount: retryCount,
dialFn: dialFn,
closed: make(chan struct{}),
pending: make(map[string]chan gatewayRPCResponse),
notifications: make(chan gatewayRPCNotification, defaultGatewayNotificationBuffer),
notificationQueue: make(chan gatewayRPCNotification, defaultGatewayNotificationQueue),
listenAddress: listenAddress,
token: token,
requestTimeout: requestTimeout,
retryCount: retryCount,
dialFn: dialFn,
closed: make(chan struct{}),
pending: make(map[string]chan gatewayRPCResponse),
notifications: make(chan gatewayRPCNotification, defaultGatewayNotificationBuffer),
notificationQueue: make(chan gatewayRPCNotification, defaultGatewayNotificationQueue),
notificationEnqueueTimeout: defaultGatewayNotificationEnqueueTimeout,
}, nil
}

Expand Down Expand Up @@ -232,7 +236,6 @@ func (c *GatewayRPCClient) Close() error {
c.closeOnce.Do(func() {
close(c.closed)
firstErr = c.forceCloseWithError(errors.New("gateway rpc client closed"))
close(c.notificationQueue)
c.notificationWG.Wait()
close(c.notifications)
})
Expand Down Expand Up @@ -372,7 +375,9 @@ func (c *GatewayRPCClient) readLoop(conn net.Conn) {
if paramsRaw, hasParams := envelope["params"]; hasParams {
notification.Params = cloneJSONRawMessage(paramsRaw)
}
c.enqueueNotification(notification)
if !c.enqueueNotification(notification) {
return
}
continue
}

Expand All @@ -387,7 +392,7 @@ func (c *GatewayRPCClient) readLoop(conn net.Conn) {
}
}

// startNotificationDispatcher 启动通知转发协程,确保 readLoop 不会被 UI 消费速度阻塞
// startNotificationDispatcher 启动通知转发协程,配合 enqueue 超时保护避免 readLoop 长时间背压阻塞
func (c *GatewayRPCClient) startNotificationDispatcher() {
c.notificationStart.Do(func() {
c.notificationWG.Add(1)
Expand All @@ -412,12 +417,25 @@ func (c *GatewayRPCClient) startNotificationDispatcher() {
})
}

// enqueueNotification 以阻塞方式投递通知,确保 gateway.event 不会因队列满被静默丢弃。
func (c *GatewayRPCClient) enqueueNotification(notification gatewayRPCNotification) {
// enqueueNotification 投递通知到内部队列;若背压持续超时则主动断开连接,避免 readLoop 无限阻塞。
func (c *GatewayRPCClient) enqueueNotification(notification gatewayRPCNotification) bool {
enqueueTimeout := c.notificationEnqueueTimeout
if enqueueTimeout <= 0 {
enqueueTimeout = defaultGatewayNotificationEnqueueTimeout
}
timer := time.NewTimer(enqueueTimeout)
defer timer.Stop()

select {
case <-c.closed:
return
return false
case c.notificationQueue <- notification:
return true
case <-timer.C:
err := fmt.Errorf("gateway rpc client: notification queue blocked for %s", enqueueTimeout)
log.Printf("warning: gateway rpc client force close due to notification backpressure method=%s err=%v", notification.Method, err)
_ = c.forceCloseWithError(err)
return false
}
}

Expand Down
77 changes: 75 additions & 2 deletions internal/tui/services/gateway_rpc_client_additional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ func TestGatewayRPCClientReadLoopAdditionalBranches(t *testing.T) {
_ = client.Close()
}

func TestGatewayRPCClientNotificationDispatcherStopsOnQueueClose(t *testing.T) {
func TestGatewayRPCClientNotificationDispatcherStopsOnCloseSignal(t *testing.T) {
t.Parallel()

client := &GatewayRPCClient{
Expand All @@ -453,7 +453,7 @@ func TestGatewayRPCClientNotificationDispatcherStopsOnQueueClose(t *testing.T) {
notificationQueue: make(chan gatewayRPCNotification, 1),
}
client.startNotificationDispatcher()
close(client.notificationQueue)
close(client.closed)
client.notificationWG.Wait()
}

Expand Down Expand Up @@ -510,6 +510,79 @@ func TestGatewayRPCClientEnqueueNotificationDoesNotDropUnderQueuePressure(t *tes
}
}

func TestGatewayRPCClientReadLoopFailsFastOnNotificationBackpressure(t *testing.T) {
t.Parallel()

clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
})

client := &GatewayRPCClient{
closed: make(chan struct{}),
pending: make(map[string]chan gatewayRPCResponse),
notifications: make(chan gatewayRPCNotification),
notificationQueue: make(chan gatewayRPCNotification, 1),
notificationEnqueueTimeout: 50 * time.Millisecond,
}
client.startNotificationDispatcher()
t.Cleanup(func() { _ = client.Close() })

readDone := make(chan struct{})
go func() {
defer close(readDone)
client.readLoop(clientConn)
}()
encoder := json.NewEncoder(serverConn)
if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 1}}); err != nil {
t.Fatalf("encode first notification: %v", err)
}
if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 2}}); err != nil {
t.Fatalf("encode second notification: %v", err)
}
if err := encoder.Encode(map[string]any{"method": protocol.MethodGatewayEvent, "params": map[string]any{"idx": 3}}); err != nil {
t.Fatalf("encode third notification: %v", err)
}

select {
case <-readDone:
case <-time.After(time.Second):
t.Fatalf("expected readLoop to fail-fast on sustained notification backpressure")
}
}

func TestGatewayRPCClientEnqueueNotificationUnblocksOnClose(t *testing.T) {
t.Parallel()

client := &GatewayRPCClient{
closed: make(chan struct{}),
pending: make(map[string]chan gatewayRPCResponse),
notifications: make(chan gatewayRPCNotification),
notificationQueue: make(chan gatewayRPCNotification, 1),
notificationEnqueueTimeout: time.Second,
}
client.startNotificationDispatcher()

// 首条通知占满队列,第二条通知会阻塞在 enqueue,关闭客户端后应立即退出。
client.notificationQueue <- gatewayRPCNotification{Method: protocol.MethodGatewayEvent}

done := make(chan struct{})
go func() {
defer close(done)
client.enqueueNotification(gatewayRPCNotification{Method: protocol.MethodGatewayEvent})
}()

time.Sleep(20 * time.Millisecond)
_ = client.Close()

select {
case <-done:
case <-time.After(time.Second):
t.Fatalf("enqueueNotification should unblock when client closes")
}
}

func TestGatewayRPCClientWriteRequestFailure(t *testing.T) {
t.Parallel()

Expand Down
Loading