diff --git a/README.md b/README.md index 7b58fc96..43c5a21b 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,12 @@ cd neo-code go run ./cmd/neocode ``` +Gateway 独立进程(Step 1 骨架): + +```bash +go run ./cmd/neocode-gateway +``` + 设置 API Key 示例(按你使用的 provider 选择): ```bash diff --git a/cmd/neocode-gateway/main.go b/cmd/neocode-gateway/main.go new file mode 100644 index 00000000..cb231449 --- /dev/null +++ b/cmd/neocode-gateway/main.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "os" + "os/signal" + "strings" + "syscall" + + "neo-code/internal/gateway" +) + +const ( + defaultLogLevel = "info" +) + +var errHelpRequested = errors.New("help requested") + +// main 负责启动 Gateway 独立进程,并在收到系统信号时优雅退出。 +func main() { + if err := run(); err != nil { + fmt.Fprintf(os.Stderr, "neocode-gateway: %v\n", err) + os.Exit(1) + } +} + +// run 解析启动参数并驱动网关服务生命周期。 +func run() error { + listenAddress, logLevel, err := parseFlags() + if err != nil { + if errors.Is(err, errHelpRequested) { + return nil + } + return err + } + + logger := log.New(os.Stderr, "neocode-gateway: ", log.LstdFlags) + logger.Printf("starting gateway (log-level=%s)", logLevel) + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + server, err := gateway.NewServer(gateway.ServerOptions{ + ListenAddress: listenAddress, + Logger: logger, + }) + if err != nil { + return err + } + defer func() { + _ = server.Close(context.Background()) + }() + + logger.Printf("gateway listen address: %s", server.ListenAddress()) + return server.Serve(ctx, nil) +} + +// parseFlags 解析命令行参数并执行基础校验。 +func parseFlags() (listenAddress string, logLevel string, err error) { + fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + fs.SetOutput(os.Stdout) + + var listen string + var level string + fs.StringVar(&listen, "listen", "", "gateway listen address (optional override)") + fs.StringVar(&level, "log-level", defaultLogLevel, "gateway log 
level: debug|info|warn|error") + + if parseErr := fs.Parse(os.Args[1:]); parseErr != nil { + if errors.Is(parseErr, flag.ErrHelp) { + return "", "", errHelpRequested + } + return "", "", parseErr + } + + normalizedLevel := strings.ToLower(strings.TrimSpace(level)) + switch normalizedLevel { + case "debug", "info", "warn", "error": + default: + return "", "", fmt.Errorf("invalid --log-level %q: must be debug|info|warn|error", level) + } + + return strings.TrimSpace(listen), normalizedLevel, nil +} diff --git a/cmd/neocode-gateway/main_test.go b/cmd/neocode-gateway/main_test.go new file mode 100644 index 00000000..6f706e3d --- /dev/null +++ b/cmd/neocode-gateway/main_test.go @@ -0,0 +1,89 @@ +package main + +import ( + "errors" + "flag" + "os" + "strings" + "testing" +) + +func TestParseFlagsValid(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--listen", " /tmp/gateway.sock ", "--log-level", " WARN "}, func() { + listen, level, err := parseFlags() + if err != nil { + t.Fatalf("parse flags: %v", err) + } + if listen != "/tmp/gateway.sock" { + t.Fatalf("listen = %q, want %q", listen, "/tmp/gateway.sock") + } + if level != "warn" { + t.Fatalf("log level = %q, want %q", level, "warn") + } + }) +} + +func TestParseFlagsHelp(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--help"}, func() { + _, _, err := parseFlags() + if !errors.Is(err, errHelpRequested) { + t.Fatalf("parse flags error = %v, want %v", err, errHelpRequested) + } + }) +} + +func TestParseFlagsInvalidLogLevel(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--log-level", "trace"}, func() { + _, _, err := parseFlags() + if err == nil { + t.Fatal("expected invalid log level error") + } + if !strings.Contains(err.Error(), "invalid --log-level") { + t.Fatalf("error = %v, want contains %q", err, "invalid --log-level") + } + }) +} + +func TestParseFlagsUnknownFlag(t *testing.T) { + withArgs(t, []string{"neocode-gateway", "--unknown"}, func() { + _, _, err := parseFlags() + if err 
// withArgs runs fn while os.Args is temporarily replaced by args. The previous
// os.Args value is restored when fn returns, including when fn panics.
func withArgs(t *testing.T, args []string, fn func()) {
	t.Helper()

	// The deferred call captures the current os.Args as its argument at the
	// time the defer statement executes, i.e. before the swap below.
	defer func(saved []string) {
		os.Args = saved
	}(os.Args)

	os.Args = args
	fn()
}
const (
	// MaxFrameSize is the maximum number of bytes allowed for a single JSON
	// frame; it bounds memory use when a peer sends abnormal input.
	MaxFrameSize int64 = 1 << 20 // 1 MiB

	// DefaultMaxConnections is the maximum number of concurrent connections the
	// server tracks; connections beyond the limit are rejected immediately.
	DefaultMaxConnections = 128
	// DefaultReadTimeout is the maximum time to wait for a single frame read,
	// so slow connections cannot hold resources indefinitely.
	DefaultReadTimeout = 30 * time.Second
	// DefaultWriteTimeout is the maximum time to wait for a single frame write,
	// so a blocked write cannot pin a handler goroutine.
	DefaultWriteTimeout = 30 * time.Second
)

var (
	// errFrameTooLarge reports a frame whose raw size exceeds the allowed maximum.
	errFrameTooLarge = errors.New("frame exceeds max size")
	// errFrameEmpty reports a frame that is empty after trimming whitespace.
	errFrameEmpty = errors.New("empty frame")

	// defaultListenAddressFn resolves the listen address when none is
	// configured. It is a variable so tests can substitute it.
	defaultListenAddressFn = transport.DefaultListenAddress
)

// ServerOptions describes the optional configuration used to start the gateway
// server. NewServer substitutes defaults for an empty ListenAddress, a nil
// Logger, and non-positive MaxConnections, ReadTimeout and WriteTimeout.
type ServerOptions struct {
	ListenAddress  string
	Logger         *log.Logger
	MaxConnections int
	ReadTimeout    time.Duration
	WriteTimeout   time.Duration
	// listenFn creates the listener; unexported so only in-package tests can
	// inject a fake. NewServer falls back to transport.Listen when it is nil.
	listenFn func(address string) (net.Listener, error)
}

// Server is the skeleton implementation of the gateway service on top of local IPC.
type Server struct {
	listenAddress  string
	logger         *log.Logger
	listenFn       func(address string) (net.Listener, error)
	maxConnections int
	readTimeout    time.Duration
	writeTimeout   time.Duration

	// mu guards listener and conns.
	mu sync.Mutex
	// listener is non-nil only while the server is serving.
	listener net.Listener
	// conns is the set of connections currently being handled.
	conns map[net.Conn]struct{}
	// wg counts connection handler goroutines so Close can wait for them.
	wg sync.WaitGroup
}

// registerConnectionResult is the outcome of trying to register an accepted connection.
type registerConnectionResult int

const (
	// registerConnectionAccepted means the connection is tracked and will be handled.
	registerConnectionAccepted registerConnectionResult = iota
	// registerConnectionServerClosed means the server has no active listener.
	registerConnectionServerClosed
	// registerConnectionLimitExceeded means the connection limit is already reached.
	registerConnectionLimitExceeded
)
+ + listenFn := options.listenFn + if listenFn == nil { + listenFn = transport.Listen + } + + maxConnections := options.MaxConnections + if maxConnections <= 0 { + maxConnections = DefaultMaxConnections + } + + readTimeout := options.ReadTimeout + if readTimeout <= 0 { + readTimeout = DefaultReadTimeout + } + + writeTimeout := options.WriteTimeout + if writeTimeout <= 0 { + writeTimeout = DefaultWriteTimeout + } + + return &Server{ + listenAddress: listenAddress, + logger: logger, + listenFn: listenFn, + maxConnections: maxConnections, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + conns: make(map[net.Conn]struct{}), + }, nil +} + +// ListenAddress 返回当前服务绑定的监听地址。 +func (s *Server) ListenAddress() string { + return s.listenAddress +} + +// Serve 启动 IPC 监听并处理客户端请求。 +func (s *Server) Serve(ctx context.Context, runtimePort RuntimePort) error { + listener, err := s.listenFn(s.listenAddress) + if err != nil { + return err + } + + s.mu.Lock() + if s.listener != nil { + s.mu.Unlock() + _ = listener.Close() + return fmt.Errorf("gateway: server is already serving") + } + s.listener = listener + s.mu.Unlock() + + s.logger.Printf("listening on %s", s.listenAddress) + + go func() { + <-ctx.Done() + _ = s.Close(context.Background()) + }() + + for { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + if errors.Is(acceptErr, net.ErrClosed) || ctx.Err() != nil || s.isClosed() { + return nil + } + return fmt.Errorf("gateway: accept connection: %w", acceptErr) + } + + switch s.registerConnection(conn) { + case registerConnectionAccepted: + case registerConnectionServerClosed: + _ = conn.Close() + continue + case registerConnectionLimitExceeded: + s.logger.Printf("reject connection: max connections %d reached", s.maxConnections) + _ = conn.Close() + continue + } + + go func() { + defer s.wg.Done() + defer s.untrackConnection(conn) + s.handleConnection(ctx, conn, runtimePort) + }() + } +} + +// Close 关闭监听器并等待所有连接处理协程退出。 +func (s *Server) Close(ctx 
context.Context) error { + s.mu.Lock() + listener := s.listener + s.listener = nil + s.mu.Unlock() + + var closeErr error + if listener != nil { + closeErr = listener.Close() + } + + for conn := range s.snapshotConnections() { + closeErr = errors.Join(closeErr, conn.Close()) + } + + waitDone := make(chan struct{}) + go func() { + s.wg.Wait() + close(waitDone) + }() + + select { + case <-ctx.Done(): + closeErr = errors.Join(closeErr, ctx.Err()) + case <-waitDone: + } + + return closeErr +} + +// isClosed 判断监听器是否已经关闭。 +func (s *Server) isClosed() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.listener == nil +} + +// snapshotConnections 返回当前连接集合的拷贝,用于关闭流程安全遍历。 +func (s *Server) snapshotConnections() map[net.Conn]struct{} { + s.mu.Lock() + defer s.mu.Unlock() + + copied := make(map[net.Conn]struct{}, len(s.conns)) + for conn := range s.conns { + copied[conn] = struct{}{} + } + return copied +} + +// registerConnection 在服务可用且未超限时登记连接,并原子增加连接处理 WaitGroup 计数。 +func (s *Server) registerConnection(conn net.Conn) registerConnectionResult { + s.mu.Lock() + defer s.mu.Unlock() + if s.listener == nil { + return registerConnectionServerClosed + } + if len(s.conns) >= s.maxConnections { + return registerConnectionLimitExceeded + } + s.conns[conn] = struct{}{} + s.wg.Add(1) + return registerConnectionAccepted +} + +// untrackConnection 移除已结束连接,避免连接集合持续增长。 +func (s *Server) untrackConnection(conn net.Conn) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.conns, conn) +} + +// handleConnection 在单连接上循环处理消息帧并返回响应帧。 +func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePort RuntimePort) { + defer func() { + _ = conn.Close() + }() + + reader := bufio.NewReader(conn) + encoder := json.NewEncoder(conn) + + for { + select { + case <-ctx.Done(): + return + default: + } + + if err := s.applyReadDeadline(conn); err != nil { + s.logger.Printf("set read deadline failed: %v", err) + return + } + + frame, err := decodeFrame(reader) + if err != nil { + if 
// isTimeoutError reports whether err, or any error it wraps, is a net.Error
// whose Timeout method returns true. It is used to tell slow-connection
// timeouts apart from protocol errors.
func isTimeoutError(err error) bool {
	var timeoutErr net.Error
	if !errors.As(err, &timeoutErr) {
		return false
	}
	return timeoutErr.Timeout()
}
} + + limitedReader := &io.LimitedReader{R: bytes.NewReader(payload), N: MaxFrameSize} + decoder := json.NewDecoder(limitedReader) + + var frame MessageFrame + if err := decoder.Decode(&frame); err != nil { + return MessageFrame{}, err + } + + var trailing any + if err := decoder.Decode(&trailing); !errors.Is(err, io.EOF) { + return MessageFrame{}, fmt.Errorf("frame contains trailing json values") + } + + return frame, nil +} + +// readFramePayload 按换行边界读取单条帧,并限制单帧最大字节数。 +func readFramePayload(reader *bufio.Reader, maxSize int64) ([]byte, error) { + var payload []byte + + for { + chunk, err := reader.ReadSlice('\n') + if int64(len(payload)+len(chunk)) > maxSize { + return nil, errFrameTooLarge + } + payload = append(payload, chunk...) + + if err == nil { + break + } + if errors.Is(err, bufio.ErrBufferFull) { + continue + } + if errors.Is(err, io.EOF) { + if len(payload) == 0 { + return nil, io.EOF + } + break + } + return nil, err + } + + payload = bytes.TrimSpace(payload) + if len(payload) == 0 { + return nil, errFrameEmpty + } + return payload, nil +} + +// dispatchFrame 根据请求动作生成响应帧。 +func (s *Server) dispatchFrame(_ context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + _ = runtimePort + + if validationErr := ValidateFrame(frame); validationErr != nil { + return errorFrame(frame, validationErr) + } + + if frame.Type != FrameTypeRequest { + return errorFrame(frame, NewFrameError(ErrorCodeInvalidFrame, "only request frames are supported")) + } + + switch frame.Action { + case FrameActionPing: + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionPing, + RequestID: frame.RequestID, + Payload: map[string]string{ + "message": "pong", + }, + } + default: + return errorFrame(frame, NewFrameError(ErrorCodeUnsupportedAction, "action is not implemented in gateway step 1")) + } +} + +// errorFrame 构建统一错误响应帧。 +func errorFrame(frame MessageFrame, frameErr *FrameError) MessageFrame { + return MessageFrame{ + Type: FrameTypeError, + Action: 
frame.Action, + RequestID: frame.RequestID, + Error: frameErr, + } +} + +var _ Gateway = (*Server)(nil) diff --git a/internal/gateway/server_additional_test.go b/internal/gateway/server_additional_test.go new file mode 100644 index 00000000..3fcc7118 --- /dev/null +++ b/internal/gateway/server_additional_test.go @@ -0,0 +1,490 @@ +package gateway + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "io" + "log" + "net" + "strings" + "sync" + "testing" + "time" +) + +func TestNewServerUsesDefaultsAndOverrides(t *testing.T) { + originalDefaultListenAddress := defaultListenAddressFn + defaultListenAddressFn = func() (string, error) { + return "default-address", nil + } + t.Cleanup(func() { + defaultListenAddressFn = originalDefaultListenAddress + }) + + server, err := NewServer(ServerOptions{}) + if err != nil { + t.Fatalf("new server with defaults: %v", err) + } + if server.ListenAddress() != "default-address" { + t.Fatalf("default listen address = %q, want %q", server.ListenAddress(), "default-address") + } + if server.logger == nil { + t.Fatal("default logger should not be nil") + } + if server.listenFn == nil { + t.Fatal("default listen function should not be nil") + } + if server.maxConnections != DefaultMaxConnections { + t.Fatalf("default max connections = %d, want %d", server.maxConnections, DefaultMaxConnections) + } + if server.readTimeout != DefaultReadTimeout { + t.Fatalf("default read timeout = %v, want %v", server.readTimeout, DefaultReadTimeout) + } + if server.writeTimeout != DefaultWriteTimeout { + t.Fatalf("default write timeout = %v, want %v", server.writeTimeout, DefaultWriteTimeout) + } + + customLogger := log.New(io.Discard, "custom", 0) + customServer, err := NewServer(ServerOptions{ + ListenAddress: " custom-address ", + Logger: customLogger, + MaxConnections: 7, + ReadTimeout: 150 * time.Millisecond, + WriteTimeout: 250 * time.Millisecond, + listenFn: func(string) (net.Listener, error) { + return nil, nil + }, + }) + if err != nil 
{ + t.Fatalf("new server with custom options: %v", err) + } + if customServer.ListenAddress() != "custom-address" { + t.Fatalf("custom listen address = %q, want %q", customServer.ListenAddress(), "custom-address") + } + if customServer.logger != customLogger { + t.Fatal("custom logger was not used") + } + if customServer.maxConnections != 7 { + t.Fatalf("custom max connections = %d, want %d", customServer.maxConnections, 7) + } + if customServer.readTimeout != 150*time.Millisecond { + t.Fatalf("custom read timeout = %v, want %v", customServer.readTimeout, 150*time.Millisecond) + } + if customServer.writeTimeout != 250*time.Millisecond { + t.Fatalf("custom write timeout = %v, want %v", customServer.writeTimeout, 250*time.Millisecond) + } +} + +func TestNewServerReturnsDefaultAddressError(t *testing.T) { + originalDefaultListenAddress := defaultListenAddressFn + defaultListenAddressFn = func() (string, error) { + return "", errors.New("default address failed") + } + t.Cleanup(func() { + defaultListenAddressFn = originalDefaultListenAddress + }) + + _, err := NewServer(ServerOptions{}) + if err == nil { + t.Fatal("expected error when default listen address fails") + } + if !strings.Contains(err.Error(), "default address failed") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestServerIsClosedState(t *testing.T) { + server := &Server{} + if !server.isClosed() { + t.Fatal("expected server to be closed when listener is nil") + } + + server.listener = &simpleListener{} + if server.isClosed() { + t.Fatal("expected server to be open when listener exists") + } +} + +func TestServeReturnsListenError(t *testing.T) { + server, err := NewServer(ServerOptions{ + ListenAddress: "listen-error", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return nil, errors.New("listen failed") + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + + serveErr := server.Serve(context.Background(), nil) + if serveErr == nil || 
!strings.Contains(serveErr.Error(), "listen failed") { + t.Fatalf("expected listen failure, got %v", serveErr) + } +} + +func TestServeRejectsAlreadyServing(t *testing.T) { + created := &simpleListener{} + server, err := NewServer(ServerOptions{ + ListenAddress: "already-serving", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return created, nil + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + server.listener = &simpleListener{} + + serveErr := server.Serve(context.Background(), nil) + if serveErr == nil || !strings.Contains(serveErr.Error(), "already serving") { + t.Fatalf("expected already serving error, got %v", serveErr) + } + if !created.closed { + t.Fatal("newly created listener should be closed when server is already serving") + } +} + +func TestServeReturnsAcceptError(t *testing.T) { + listener := &scriptedListener{results: []acceptResult{{err: errors.New("accept failed")}}} + server, err := NewServer(ServerOptions{ + ListenAddress: "accept-error", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return listener, nil + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + + serveErr := server.Serve(context.Background(), nil) + if serveErr == nil || !strings.Contains(serveErr.Error(), "accept connection") { + t.Fatalf("expected accept error, got %v", serveErr) + } +} + +func TestServeSkipsConnectionWhenRegisterRejected(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + listener := &scriptedListener{ + results: []acceptResult{ + { + conn: serverConn, + }, + {err: net.ErrClosed}, + }, + } + + server, err := NewServer(ServerOptions{ + ListenAddress: "register-reject", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return listener, nil + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + + listener.results[0].beforeReturn = func() { + server.mu.Lock() + 
server.listener = nil + server.mu.Unlock() + } + + serveErr := server.Serve(context.Background(), nil) + if serveErr != nil { + t.Fatalf("serve should exit cleanly when listener closed, got %v", serveErr) + } + + readDone := make(chan error, 1) + go func() { + var buf [1]byte + _, err := clientConn.Read(buf[:]) + readDone <- err + }() + + select { + case err := <-readDone: + if !errors.Is(err, io.EOF) && (err == nil || !strings.Contains(err.Error(), "closed pipe")) { + t.Fatalf("expected rejected connection to be closed, got %v", err) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("rejected connection was not closed") + } +} + +func TestCloseReturnsContextErrorWhenWaitCanceled(t *testing.T) { + server := &Server{conns: make(map[net.Conn]struct{})} + server.wg.Add(1) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := server.Close(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("close error = %v, want context canceled", err) + } + + server.wg.Done() +} + +func TestDecodeFrameTrailingJSON(t *testing.T) { + reader := bufio.NewReader(strings.NewReader(`{"type":"request","action":"ping"} {"extra":1}` + "\n")) + _, err := decodeFrame(reader) + if err == nil || !strings.Contains(err.Error(), "trailing") { + t.Fatalf("expected trailing json error, got %v", err) + } +} + +func TestReadFramePayloadBranches(t *testing.T) { + if _, err := readFramePayload(bufio.NewReader(strings.NewReader("")), MaxFrameSize); !errors.Is(err, io.EOF) { + t.Fatalf("empty payload error = %v, want io.EOF", err) + } + + payload, err := readFramePayload(bufio.NewReader(strings.NewReader("{\"type\":\"request\"}")), MaxFrameSize) + if err != nil { + t.Fatalf("payload without newline should decode at EOF: %v", err) + } + if string(payload) != `{"type":"request"}` { + t.Fatalf("payload mismatch: %q", string(payload)) + } + + tooLarge := strings.Repeat("a", 5000) + if _, err := readFramePayload(bufio.NewReaderSize(strings.NewReader(tooLarge), 64), 1024); 
!errors.Is(err, errFrameTooLarge) { + t.Fatalf("oversized payload error = %v, want errFrameTooLarge", err) + } + + if _, err := readFramePayload(bufio.NewReader(&failingReader{}), MaxFrameSize); err == nil || err.Error() != "read failed" { + t.Fatalf("expected read failure, got %v", err) + } +} + +func TestDispatchFrameNonRequest(t *testing.T) { + server := &Server{} + response := server.dispatchFrame(context.Background(), MessageFrame{Type: FrameTypeEvent, Action: FrameActionPing}, nil) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("response error = %#v, want invalid frame", response.Error) + } +} + +func TestDispatchFrameValidationError(t *testing.T) { + server := &Server{} + response := server.dispatchFrame(context.Background(), MessageFrame{Type: FrameType("invalid")}, nil) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("response error = %#v, want invalid frame", response.Error) + } +} + +func TestServerHandleConnectionSkipsEmptyFrame(t *testing.T) { + server := &Server{logger: log.New(io.Discard, "", 0)} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + _, _ = io.WriteString(clientConn, "\n") + _, _ = io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"empty-then-ping"}`+"\n") + + decoder := json.NewDecoder(clientConn) + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode response: %v", err) + } + if response.Type != FrameTypeAck || response.Action != FrameActionPing { + t.Fatalf("unexpected response after empty frame: %#v", response) + } + 
+ _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} + +func TestServerHandleConnectionInvalidJSONFrame(t *testing.T) { + server := &Server{logger: log.New(io.Discard, "", 0)} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + _, _ = io.WriteString(clientConn, "{invalid-json}\n") + decoder := json.NewDecoder(clientConn) + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode response: %v", err) + } + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("response error = %#v, want invalid frame", response.Error) + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} + +func TestRegisterConnectionRejectsWhenLimitExceeded(t *testing.T) { + server := &Server{ + listener: &simpleListener{}, + maxConnections: 1, + conns: make(map[net.Conn]struct{}), + } + + conn1Server, conn1Client := net.Pipe() + defer conn1Client.Close() + defer conn1Server.Close() + if got := server.registerConnection(conn1Server); got != registerConnectionAccepted { + t.Fatalf("first register result = %v, want accepted", got) + } + + conn2Server, conn2Client := net.Pipe() + defer conn2Client.Close() + defer conn2Server.Close() + if got := server.registerConnection(conn2Server); got != registerConnectionLimitExceeded { + t.Fatalf("second register result = %v, want limit exceeded", got) + } + + server.untrackConnection(conn1Server) + server.wg.Done() +} + +func TestServerHandleConnectionReadTimeoutClosesConnection(t *testing.T) { + server := &Server{ + logger: log.New(io.Discard, "", 0), + readTimeout: 20 * 
time.Millisecond, + } + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("handleConnection should exit after read timeout") + } + + var buf [1]byte + _, err := clientConn.Read(buf[:]) + if !errors.Is(err, io.EOF) && (err == nil || !strings.Contains(err.Error(), "closed pipe")) { + t.Fatalf("expected closed connection after timeout, got %v", err) + } + _ = clientConn.Close() +} + +func TestServerHandleConnectionWriteTimeoutClosesConnection(t *testing.T) { + server := &Server{ + logger: log.New(io.Discard, "", 0), + readTimeout: time.Second, + writeTimeout: 20 * time.Millisecond, + } + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + _, err := io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"write-timeout"}`+"\n") + if err != nil { + t.Fatalf("write request: %v", err) + } + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("handleConnection should exit after write timeout") + } + + _ = clientConn.Close() +} + +type failingReader struct{} + +func (r *failingReader) Read(_ []byte) (int, error) { + return 0, errors.New("read failed") +} + +type simpleListener struct { + closed bool +} + +func (l *simpleListener) Accept() (net.Conn, error) { + return nil, net.ErrClosed +} + +func (l *simpleListener) Close() error { + l.closed = true + return nil +} + +func (l *simpleListener) Addr() net.Addr { + return stubAddr("simple") +} + +type acceptResult struct { + conn net.Conn + err error + beforeReturn func() +} + +type scriptedListener struct { + mu sync.Mutex + results []acceptResult + closed bool +} + +func (l *scriptedListener) Accept() (net.Conn, error) { + l.mu.Lock() + defer 
l.mu.Unlock() + if len(l.results) == 0 { + return nil, net.ErrClosed + } + result := l.results[0] + l.results = l.results[1:] + if result.beforeReturn != nil { + result.beforeReturn() + } + if result.err != nil { + return nil, result.err + } + if result.conn == nil { + return nil, net.ErrClosed + } + return result.conn, nil +} + +func (l *scriptedListener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + l.closed = true + return nil +} + +func (l *scriptedListener) Addr() net.Addr { + return stubAddr("scripted") +} diff --git a/internal/gateway/server_race_test.go b/internal/gateway/server_race_test.go new file mode 100644 index 00000000..5798b14f --- /dev/null +++ b/internal/gateway/server_race_test.go @@ -0,0 +1,136 @@ +package gateway + +import ( + "context" + "errors" + "io" + "log" + "net" + "os" + "strings" + "sync" + "testing" + "time" +) + +func TestServeCloseDuringAcceptDoesNotLeakConnection(t *testing.T) { + t.Parallel() + + listener := newStubListener() + server, err := NewServer(ServerOptions{ + ListenAddress: "stub://gateway", + Logger: log.New(io.Discard, "", 0), + listenFn: func(string) (net.Listener, error) { + return listener, nil + }, + }) + if err != nil { + t.Fatalf("new server: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serveDone := make(chan error, 1) + go func() { + serveDone <- server.Serve(ctx, nil) + }() + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + closeDone := make(chan error, 1) + listener.onAccept = func() { + go func() { + closeDone <- server.Close(context.Background()) + }() + } + + listener.acceptCh <- serverConn + + select { + case closeErr := <-closeDone: + if closeErr != nil { + t.Fatalf("close server: %v", closeErr) + } + case <-time.After(2 * time.Second): + t.Fatal("close timed out") + } + + select { + case serveErr := <-serveDone: + if serveErr != nil { + t.Fatalf("serve returned error: %v", serveErr) + } + case <-time.After(2 * time.Second): + 
// stubListener is a net.Listener test double driven entirely by channels:
// connections pushed into acceptCh are handed out by Accept, and Close makes
// every later Accept fail with net.ErrClosed.
type stubListener struct {
	acceptCh chan net.Conn
	closeCh  chan struct{}

	// onAccept, when set, runs after a connection was received from acceptCh
	// and before Accept returns it.
	onAccept  func()
	closeOnce sync.Once
}

// newStubListener returns a stubListener whose accept queue holds one
// pending connection.
func newStubListener() *stubListener {
	l := new(stubListener)
	l.acceptCh = make(chan net.Conn, 1)
	l.closeCh = make(chan struct{})
	return l
}

// Accept blocks until a connection is queued or the listener is closed.
func (l *stubListener) Accept() (net.Conn, error) {
	select {
	case accepted := <-l.acceptCh:
		if hook := l.onAccept; hook != nil {
			hook()
		}
		return accepted, nil
	case <-l.closeCh:
		return nil, net.ErrClosed
	}
}

// Close signals closure exactly once; repeated calls are harmless and the
// result is always nil.
func (l *stubListener) Close() error {
	signalClosed := func() { close(l.closeCh) }
	l.closeOnce.Do(signalClosed)
	return nil
}

// Addr returns a fixed placeholder address.
func (l *stubListener) Addr() net.Addr {
	return stubAddr("stub://gateway")
}

// stubAddr is a string-backed net.Addr for test listeners.
type stubAddr string

// Network always reports the "stub" network.
func (a stubAddr) Network() string {
	return "stub"
}

// String returns the address text unchanged.
func (a stubAddr) String() string {
	return string(a)
}
server.handleConnection(context.Background(), serverConn, nil) + }() + + encoder := json.NewEncoder(clientConn) + decoder := json.NewDecoder(clientConn) + + if err := encoder.Encode(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionPing, + RequestID: "req-1", + }); err != nil { + t.Fatalf("encode request: %v", err) + } + + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode response: %v", err) + } + + if response.Type != FrameTypeAck { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeAck) + } + if response.Action != FrameActionPing { + t.Fatalf("response action = %q, want %q", response.Action, FrameActionPing) + } + if response.RequestID != "req-1" { + t.Fatalf("response request_id = %q, want %q", response.RequestID, "req-1") + } + + payloadMap, ok := response.Payload.(map[string]any) + if !ok { + t.Fatalf("response payload type = %T, want map[string]any", response.Payload) + } + if got, _ := payloadMap["message"].(string); got != "pong" { + t.Fatalf("response payload message = %q, want %q", got, "pong") + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} + +func TestServerHandleConnectionUnsupportedAction(t *testing.T) { + t.Parallel() + + server := &Server{logger: log.New(io.Discard, "", 0)} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + encoder := json.NewEncoder(clientConn) + decoder := json.NewDecoder(clientConn) + + if err := encoder.Encode(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRun, + RequestID: "req-2", + InputText: "hello", + }); err != nil { + t.Fatalf("encode request: %v", err) + } + + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode response: %v", err) + } + + if response.Type != 
FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil { + t.Fatal("response error is nil") + } + if response.Error.Code != ErrorCodeUnsupportedAction.String() { + t.Fatalf("error code = %q, want %q", response.Error.Code, ErrorCodeUnsupportedAction.String()) + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} + +func TestServerHandleConnectionRejectsOversizedFrame(t *testing.T) { + t.Parallel() + + server := &Server{logger: log.New(io.Discard, "", 0)} + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + decoder := json.NewDecoder(clientConn) + oversizedPayload := strings.Repeat("a", int(MaxFrameSize)+128) + requestFrame := fmt.Sprintf( + `{"type":"request","action":"ping","request_id":"req-oversize","input_text":"%s"}`+"\n", + oversizedPayload, + ) + + writeDone := make(chan error, 1) + go func() { + _, err := io.WriteString(clientConn, requestFrame) + writeDone <- err + }() + + var response MessageFrame + if err := decoder.Decode(&response); err != nil { + t.Fatalf("decode oversized response: %v", err) + } + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil { + t.Fatal("response error is nil") + } + if response.Error.Code != ErrorCodeInvalidFrame.String() { + t.Fatalf("error code = %q, want %q", response.Error.Code, ErrorCodeInvalidFrame.String()) + } + if !strings.Contains(response.Error.Message, "frame exceeds max size") { + t.Fatalf("error message = %q, want contains %q", response.Error.Message, "frame exceeds max size") + } + + select { + case <-writeDone: + case <-time.After(2 * time.Second): + t.Fatal("write oversized frame timed out") + } + + readDone := make(chan error, 1) + go func() { 
+ var buf [1]byte + _, err := clientConn.Read(buf[:]) + readDone <- err + }() + + select { + case err := <-readDone: + if errors.Is(err, io.EOF) { + break + } + if err != nil && strings.Contains(err.Error(), "closed pipe") { + break + } + t.Fatalf("expected connection close after oversized frame, got %v", err) + case <-time.After(500 * time.Millisecond): + t.Fatal("connection was not closed after oversized frame") + } + + _ = clientConn.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConnection did not exit") + } +} diff --git a/internal/gateway/transport/address_unix.go b/internal/gateway/transport/address_unix.go new file mode 100644 index 00000000..ed6b0c89 --- /dev/null +++ b/internal/gateway/transport/address_unix.go @@ -0,0 +1,20 @@ +//go:build !windows + +package transport + +import ( + "fmt" + "os" + "path/filepath" +) + +const defaultUnixSocketRelativePath = ".neocode/run/gateway.sock" + +// DefaultListenAddress 返回 Unix 系统默认监听地址。 +func DefaultListenAddress() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("gateway: resolve user home dir: %w", err) + } + return filepath.Join(homeDir, defaultUnixSocketRelativePath), nil +} diff --git a/internal/gateway/transport/address_unix_test.go b/internal/gateway/transport/address_unix_test.go new file mode 100644 index 00000000..568d3f43 --- /dev/null +++ b/internal/gateway/transport/address_unix_test.go @@ -0,0 +1,23 @@ +//go:build !windows + +package transport + +import ( + "path/filepath" + "testing" +) + +func TestDefaultListenAddress(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + address, err := DefaultListenAddress() + if err != nil { + t.Fatalf("default listen address: %v", err) + } + + want := filepath.Join(home, defaultUnixSocketRelativePath) + if address != want { + t.Fatalf("default listen address = %q, want %q", address, want) + } +} diff --git a/internal/gateway/transport/address_windows.go 
b/internal/gateway/transport/address_windows.go new file mode 100644 index 00000000..f6b00306 --- /dev/null +++ b/internal/gateway/transport/address_windows.go @@ -0,0 +1,10 @@ +//go:build windows + +package transport + +const defaultWindowsNamedPipePath = `\\.\pipe\neocode-gateway` + +// DefaultListenAddress 返回 Windows 系统默认监听地址。 +func DefaultListenAddress() (string, error) { + return defaultWindowsNamedPipePath, nil +} diff --git a/internal/gateway/transport/listen.go b/internal/gateway/transport/listen.go new file mode 100644 index 00000000..91514564 --- /dev/null +++ b/internal/gateway/transport/listen.go @@ -0,0 +1,28 @@ +package transport + +import ( + "errors" + "net" +) + +// cleanupListener 在关闭底层监听器后执行额外清理逻辑。 +type cleanupListener struct { + net.Listener + cleanup func() error +} + +// newCleanupListener 包装监听器并注入清理钩子。 +func newCleanupListener(listener net.Listener, cleanup func() error) net.Listener { + if cleanup == nil { + return listener + } + return &cleanupListener{ + Listener: listener, + cleanup: cleanup, + } +} + +// Close 关闭监听器并执行额外清理。 +func (l *cleanupListener) Close() error { + return errors.Join(l.Listener.Close(), l.cleanup()) +} diff --git a/internal/gateway/transport/listen_unix.go b/internal/gateway/transport/listen_unix.go new file mode 100644 index 00000000..de2c3430 --- /dev/null +++ b/internal/gateway/transport/listen_unix.go @@ -0,0 +1,91 @@ +//go:build !windows + +package transport + +import ( + "fmt" + "net" + "os" + "path/filepath" +) + +const ( + // unixSocketDirPerm 定义 Unix socket 父目录权限(仅当前用户可访问)。 + unixSocketDirPerm os.FileMode = 0o700 + // unixSocketFilePerm 定义 Unix socket 文件权限(仅当前用户可读写)。 + unixSocketFilePerm os.FileMode = 0o600 +) + +// Listen 在 Unix 系统上启动 UDS 监听并在关闭时清理 socket 文件。 +func Listen(address string) (net.Listener, error) { + socketDir := filepath.Dir(address) + created, err := ensureSocketDir(socketDir) + if err != nil { + return nil, err + } + if created { + if err := os.Chmod(socketDir, unixSocketDirPerm); err != nil 
{ + return nil, fmt.Errorf("gateway: set socket dir permission: %w", err) + } + } + + if err := removeStaleUnixSocket(address); err != nil { + return nil, err + } + + listener, err := net.Listen("unix", address) + if err != nil { + return nil, fmt.Errorf("gateway: listen unix socket: %w", err) + } + if err := os.Chmod(address, unixSocketFilePerm); err != nil { + _ = listener.Close() + return nil, fmt.Errorf("gateway: set socket file permission: %w", err) + } + + return newCleanupListener(listener, func() error { + if err := os.Remove(address); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("gateway: remove unix socket: %w", err) + } + return nil + }), nil +} + +// ensureSocketDir 确保 socket 父目录可用,并返回该目录是否由当前流程创建。 +func ensureSocketDir(socketDir string) (bool, error) { + info, err := os.Stat(socketDir) + if err == nil { + if !info.IsDir() { + return false, fmt.Errorf("gateway: socket dir path exists and is not directory: %s", socketDir) + } + return false, nil + } + if !os.IsNotExist(err) { + return false, fmt.Errorf("gateway: stat socket dir: %w", err) + } + + if err := os.MkdirAll(socketDir, unixSocketDirPerm); err != nil { + return false, fmt.Errorf("gateway: create socket dir: %w", err) + } + return true, nil +} + +// removeStaleUnixSocket 清理历史残留的 socket 文件,避免监听失败。 +func removeStaleUnixSocket(address string) error { + info, err := os.Lstat(address) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("gateway: stat unix socket path: %w", err) + } + + if info.Mode()&os.ModeSocket == 0 { + return fmt.Errorf("gateway: unix socket path exists and is not socket: %s", address) + } + + if err := os.Remove(address); err != nil { + return fmt.Errorf("gateway: remove stale unix socket: %w", err) + } + + return nil +} diff --git a/internal/gateway/transport/listen_unix_test.go b/internal/gateway/transport/listen_unix_test.go new file mode 100644 index 00000000..4e103a5d --- /dev/null +++ b/internal/gateway/transport/listen_unix_test.go 
@@ -0,0 +1,205 @@ +//go:build !windows + +package transport + +import ( + "errors" + "net" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestListenUnixAcceptsConnectionAndCleansSocket(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "run", "gateway.sock") + socketDir := filepath.Dir(socketPath) + listener, err := Listen(socketPath) + if err != nil { + t.Fatalf("listen unix socket: %v", err) + } + defer func() { + _ = listener.Close() + }() + + acceptDone := make(chan error, 1) + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + acceptDone <- acceptErr + return + } + _ = conn.Close() + acceptDone <- nil + }() + + conn, err := net.Dial("unix", socketPath) + if err != nil { + t.Fatalf("dial unix socket: %v", err) + } + _ = conn.Close() + + socketInfo, err := os.Stat(socketPath) + if err != nil { + t.Fatalf("stat socket file: %v", err) + } + if got := socketInfo.Mode() & os.ModePerm; got != unixSocketFilePerm { + t.Fatalf("socket file perm = %#o, want %#o", got, unixSocketFilePerm) + } + + dirInfo, err := os.Stat(socketDir) + if err != nil { + t.Fatalf("stat socket dir: %v", err) + } + if got := dirInfo.Mode() & os.ModePerm; got != unixSocketDirPerm { + t.Fatalf("socket dir perm = %#o, want %#o", got, unixSocketDirPerm) + } + + select { + case acceptErr := <-acceptDone: + if acceptErr != nil { + t.Fatalf("accept connection: %v", acceptErr) + } + case <-time.After(2 * time.Second): + t.Fatal("accept timed out") + } + + if err := listener.Close(); err != nil { + t.Fatalf("close listener: %v", err) + } + + if _, err := os.Stat(socketPath); !os.IsNotExist(err) { + t.Fatalf("socket file should be removed on close, stat err: %v", err) + } +} + +func TestListenUnixDoesNotChmodExistingDir(t *testing.T) { + t.Parallel() + + parentDir := filepath.Join(t.TempDir(), "existing") + if err := os.MkdirAll(parentDir, 0o755); err != nil { + t.Fatalf("create parent dir: %v", err) + } + + socketPath := 
filepath.Join(parentDir, "gateway.sock") + listener, err := Listen(socketPath) + if err != nil { + t.Fatalf("listen unix socket: %v", err) + } + defer func() { + _ = listener.Close() + }() + + dirInfo, err := os.Stat(parentDir) + if err != nil { + t.Fatalf("stat parent dir: %v", err) + } + if got := dirInfo.Mode() & os.ModePerm; got != 0o755 { + t.Fatalf("existing dir perm = %#o, want %#o", got, 0o755) + } +} + +func TestListenUnixSocketDirPathIsFile(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + filePath := filepath.Join(baseDir, "not-dir") + if err := os.WriteFile(filePath, []byte("x"), 0o600); err != nil { + t.Fatalf("write marker file: %v", err) + } + + socketPath := filepath.Join(filePath, "gateway.sock") + _, err := Listen(socketPath) + if err == nil { + t.Fatal("expected error when socket dir path is file") + } + if !strings.Contains(err.Error(), "is not directory") { + t.Fatalf("error = %v, want contains %q", err, "is not directory") + } +} + +func TestRemoveStaleUnixSocket(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "gateway.sock") + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("listen unix socket: %v", err) + } + _ = listener.Close() + + if err := removeStaleUnixSocket(socketPath); err != nil { + t.Fatalf("remove stale socket: %v", err) + } + if _, err := os.Stat(socketPath); !os.IsNotExist(err) { + t.Fatalf("socket should be removed, stat err: %v", err) + } +} + +func TestRemoveStaleUnixSocketNonSocketPath(t *testing.T) { + t.Parallel() + + filePath := filepath.Join(t.TempDir(), "plain-file") + if err := os.WriteFile(filePath, []byte("x"), 0o600); err != nil { + t.Fatalf("write marker file: %v", err) + } + + err := removeStaleUnixSocket(filePath) + if err == nil { + t.Fatal("expected error when stale path is non-socket") + } + if !strings.Contains(err.Error(), "is not socket") { + t.Fatalf("error = %v, want contains %q", err, "is not socket") + } +} + +func 
TestRemoveStaleUnixSocketNotExist(t *testing.T) { + t.Parallel() + + err := removeStaleUnixSocket(filepath.Join(t.TempDir(), "missing.sock")) + if err != nil { + t.Fatalf("remove missing stale socket: %v", err) + } +} + +func TestNewCleanupListenerWithoutCleanup(t *testing.T) { + t.Parallel() + + baseListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp: %v", err) + } + defer func() { + _ = baseListener.Close() + }() + + wrapped := newCleanupListener(baseListener, nil) + if wrapped != baseListener { + t.Fatal("expected original listener when cleanup is nil") + } +} + +func TestCleanupListenerCloseReturnsJoinedError(t *testing.T) { + t.Parallel() + + baseListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen tcp: %v", err) + } + + cleanupErr := errors.New("cleanup failed") + wrapped := newCleanupListener(baseListener, func() error { + return cleanupErr + }) + + closeErr := wrapped.Close() + if closeErr == nil { + t.Fatal("expected close error") + } + if !errors.Is(closeErr, cleanupErr) { + t.Fatalf("close error = %v, want contains cleanup err %v", closeErr, cleanupErr) + } +} diff --git a/internal/gateway/transport/listen_windows.go b/internal/gateway/transport/listen_windows.go new file mode 100644 index 00000000..4ed05073 --- /dev/null +++ b/internal/gateway/transport/listen_windows.go @@ -0,0 +1,98 @@ +//go:build windows + +package transport + +import ( + "fmt" + "net" + + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" +) + +const ( + pipeSDDLDiscretionaryACL = "D:P" +) + +var ( + listenPipeFn = winio.ListenPipe + currentProcessUserSIDFn = currentProcessUserSID + wellKnownSIDStringFn = wellKnownSIDString + getCurrentProcessTokenFn = windows.GetCurrentProcessToken + createWellKnownSIDFn = windows.CreateWellKnownSid +) + +// Listen 在 Windows 系统上启动 Named Pipe 监听,并显式收敛访问控制。 +func Listen(address string) (net.Listener, error) { + config, err := newRestrictedPipeConfig() + if err != nil { 
+ return nil, err + } + + listener, err := listenPipeFn(address, config) + if err != nil { + return nil, fmt.Errorf("gateway: listen named pipe: %w", err) + } + return newCleanupListener(listener, nil), nil +} + +// newRestrictedPipeConfig 构建最小权限 PipeConfig,仅允许 SYSTEM、管理员组与当前用户访问。 +func newRestrictedPipeConfig() (*winio.PipeConfig, error) { + securityDescriptor, err := buildRestrictedPipeSecurityDescriptor() + if err != nil { + return nil, err + } + return &winio.PipeConfig{SecurityDescriptor: securityDescriptor}, nil +} + +// buildRestrictedPipeSecurityDescriptor 生成管道 ACL 的 SDDL 表达式。 +func buildRestrictedPipeSecurityDescriptor() (string, error) { + currentUserSID, err := currentProcessUserSIDFn() + if err != nil { + return "", err + } + + systemSID, err := wellKnownSIDStringFn(windows.WinLocalSystemSid) + if err != nil { + return "", fmt.Errorf("gateway: resolve local-system sid: %w", err) + } + + administratorsSID, err := wellKnownSIDStringFn(windows.WinBuiltinAdministratorsSid) + if err != nil { + return "", fmt.Errorf("gateway: resolve administrators sid: %w", err) + } + + return fmt.Sprintf( + "%s(%s)(%s)(%s)", + pipeSDDLDiscretionaryACL, + allowGenericAllAce(systemSID), + allowGenericAllAce(administratorsSID), + allowGenericAllAce(currentUserSID), + ), nil +} + +// currentProcessUserSID 返回当前进程用户的 SID 字符串。 +func currentProcessUserSID() (string, error) { + tokenUser, err := getCurrentProcessTokenFn().GetTokenUser() + if err != nil { + return "", fmt.Errorf("gateway: query current token user: %w", err) + } + if tokenUser == nil || tokenUser.User.Sid == nil { + return "", fmt.Errorf("gateway: current token user sid is empty") + } + return tokenUser.User.Sid.String(), nil +} + +// wellKnownSIDString 将系统内置 SID 类型转换为 SID 字符串。 +func wellKnownSIDString(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { + sid, err := createWellKnownSIDFn(sidType) + if err != nil { + return "", err + } + return sid.String(), nil +} + +// allowGenericAllAce 为指定 SID 生成“完全控制”ACE。 +func 
allowGenericAllAce(sid string) string { + return fmt.Sprintf("A;;GA;;;%s", sid) +} diff --git a/internal/gateway/transport/listen_windows_acl_test.go b/internal/gateway/transport/listen_windows_acl_test.go new file mode 100644 index 00000000..6d10e7ab --- /dev/null +++ b/internal/gateway/transport/listen_windows_acl_test.go @@ -0,0 +1,209 @@ +//go:build windows + +package transport + +import ( + "errors" + "net" + "strings" + "testing" + + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" +) + +func TestDefaultListenAddressWindows(t *testing.T) { + address, err := DefaultListenAddress() + if err != nil { + t.Fatalf("default listen address: %v", err) + } + if address != defaultWindowsNamedPipePath { + t.Fatalf("default address = %q, want %q", address, defaultWindowsNamedPipePath) + } +} + +func TestNewCleanupListenerBranches(t *testing.T) { + base := &stubNetListener{} + if got := newCleanupListener(base, nil); got != base { + t.Fatal("expected original listener when cleanup is nil") + } + + closeErr := errors.New("close failed") + cleanupErr := errors.New("cleanup failed") + wrapped := newCleanupListener(&stubNetListener{closeErr: closeErr}, func() error { return cleanupErr }) + if err := wrapped.Close(); err == nil { + t.Fatal("expected joined error") + } else { + if !errors.Is(err, closeErr) { + t.Fatalf("joined error should include close error, got %v", err) + } + if !errors.Is(err, cleanupErr) { + t.Fatalf("joined error should include cleanup error, got %v", err) + } + } +} + +func TestBuildRestrictedPipeSecurityDescriptorContainsExpectedACEs(t *testing.T) { + sddl, err := buildRestrictedPipeSecurityDescriptor() + if err != nil { + t.Fatalf("build restricted descriptor: %v", err) + } + if !strings.HasPrefix(sddl, pipeSDDLDiscretionaryACL) { + t.Fatalf("sddl prefix = %q, want starts with %q", sddl, pipeSDDLDiscretionaryACL) + } + if strings.Count(sddl, "A;;GA;;;") != 3 { + t.Fatalf("sddl should contain 3 allow full-control ACE entries, got %q", sddl) 
+ } +} + +func TestNewRestrictedPipeConfigErrorBranch(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + currentProcessUserSIDFn = func() (string, error) { + return "", errors.New("current user failed") + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + }) + + _, err := newRestrictedPipeConfig() + if err == nil || !strings.Contains(err.Error(), "current user failed") { + t.Fatalf("expected current user error, got %v", err) + } +} + +func TestListenReturnsConfigError(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + currentProcessUserSIDFn = func() (string, error) { + return "", errors.New("restricted config failed") + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + }) + + _, err := Listen(`\\.\pipe\neocode-gateway-config-error-test`) + if err == nil || !strings.Contains(err.Error(), "restricted config failed") { + t.Fatalf("expected config build failure, got %v", err) + } +} + +func TestBuildRestrictedPipeSecurityDescriptorSystemErrorBranch(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + originalWellKnown := wellKnownSIDStringFn + currentProcessUserSIDFn = func() (string, error) { return "S-1-5-21-1", nil } + wellKnownSIDStringFn = func(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { + if sidType == windows.WinLocalSystemSid { + return "", errors.New("system sid failed") + } + return "S-1-5-32-544", nil + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + wellKnownSIDStringFn = originalWellKnown + }) + + _, err := buildRestrictedPipeSecurityDescriptor() + if err == nil || !strings.Contains(err.Error(), "system sid failed") { + t.Fatalf("expected system sid error, got %v", err) + } +} + +func TestBuildRestrictedPipeSecurityDescriptorAdminErrorBranch(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + originalWellKnown := wellKnownSIDStringFn + currentProcessUserSIDFn = func() (string, error) { return "S-1-5-21-1", nil } + wellKnownSIDStringFn = 
func(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { + if sidType == windows.WinBuiltinAdministratorsSid { + return "", errors.New("admin sid failed") + } + return "S-1-5-18", nil + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + wellKnownSIDStringFn = originalWellKnown + }) + + _, err := buildRestrictedPipeSecurityDescriptor() + if err == nil || !strings.Contains(err.Error(), "admin sid failed") { + t.Fatalf("expected admin sid error, got %v", err) + } +} + +func TestListenReturnsListenPipeError(t *testing.T) { + originalCurrent := currentProcessUserSIDFn + originalWellKnown := wellKnownSIDStringFn + originalListenPipe := listenPipeFn + currentProcessUserSIDFn = func() (string, error) { return "S-1-5-21-1", nil } + wellKnownSIDStringFn = func(sidType windows.WELL_KNOWN_SID_TYPE) (string, error) { + if sidType == windows.WinLocalSystemSid { + return "S-1-5-18", nil + } + if sidType == windows.WinBuiltinAdministratorsSid { + return "S-1-5-32-544", nil + } + return "", errors.New("unexpected sid type") + } + listenPipeFn = func(_ string, _ *winio.PipeConfig) (net.Listener, error) { + return nil, errors.New("listen pipe failed") + } + t.Cleanup(func() { + currentProcessUserSIDFn = originalCurrent + wellKnownSIDStringFn = originalWellKnown + listenPipeFn = originalListenPipe + }) + + _, err := Listen(`\\.\pipe\neocode-gateway-error-test`) + if err == nil || !strings.Contains(err.Error(), "listen pipe failed") { + t.Fatalf("expected listen pipe failure, got %v", err) + } +} + +func TestCurrentProcessUserSIDErrorBranch(t *testing.T) { + originalTokenFn := getCurrentProcessTokenFn + getCurrentProcessTokenFn = func() windows.Token { + return windows.Token(0) + } + t.Cleanup(func() { + getCurrentProcessTokenFn = originalTokenFn + }) + + _, err := currentProcessUserSID() + if err == nil { + t.Fatal("expected current process token user error") + } +} + +func TestWellKnownSIDStringErrorBranch(t *testing.T) { + originalCreateWellKnownSID := 
createWellKnownSIDFn + createWellKnownSIDFn = func(_ windows.WELL_KNOWN_SID_TYPE) (*windows.SID, error) { + return nil, errors.New("create sid failed") + } + t.Cleanup(func() { + createWellKnownSIDFn = originalCreateWellKnownSID + }) + + _, err := wellKnownSIDString(windows.WinLocalSystemSid) + if err == nil || !strings.Contains(err.Error(), "create sid failed") { + t.Fatalf("expected create sid failure, got %v", err) + } +} + +type stubNetListener struct { + closeErr error +} + +func (l *stubNetListener) Accept() (net.Conn, error) { + return nil, net.ErrClosed +} + +func (l *stubNetListener) Close() error { + return l.closeErr +} + +func (l *stubNetListener) Addr() net.Addr { + return pipeAddr("stub") +} + +type pipeAddr string + +func (a pipeAddr) Network() string { return "pipe" } +func (a pipeAddr) String() string { return string(a) } diff --git a/internal/gateway/transport/listen_windows_test.go b/internal/gateway/transport/listen_windows_test.go new file mode 100644 index 00000000..826734ec --- /dev/null +++ b/internal/gateway/transport/listen_windows_test.go @@ -0,0 +1,88 @@ +//go:build windows + +package transport + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/Microsoft/go-winio" + "golang.org/x/sys/windows" +) + +func TestListenNamedPipeAcceptsConnection(t *testing.T) { + pipePath := fmt.Sprintf(`\\.\pipe\neocode-gateway-test-%d`, time.Now().UnixNano()) + listener, err := Listen(pipePath) + if err != nil { + t.Fatalf("listen named pipe: %v", err) + } + defer func() { + _ = listener.Close() + }() + + acceptDone := make(chan error, 1) + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + acceptDone <- acceptErr + return + } + _ = conn.Close() + acceptDone <- nil + }() + + timeout := 2 * time.Second + conn, err := winio.DialPipe(pipePath, &timeout) + if err != nil { + t.Fatalf("dial named pipe: %v", err) + } + _ = conn.Close() + + select { + case acceptErr := <-acceptDone: + if acceptErr != nil { + 
t.Fatalf("accept connection: %v", acceptErr) + } + case <-time.After(3 * time.Second): + t.Fatal("accept timed out") + } +} + +func TestNewRestrictedPipeConfigContainsExpectedSIDs(t *testing.T) { + config, err := newRestrictedPipeConfig() + if err != nil { + t.Fatalf("new restricted pipe config: %v", err) + } + if config == nil { + t.Fatal("pipe config is nil") + } + if config.SecurityDescriptor == "" { + t.Fatal("security descriptor is empty") + } + + currentUserSID, err := currentProcessUserSID() + if err != nil { + t.Fatalf("current user sid: %v", err) + } + if !strings.Contains(config.SecurityDescriptor, currentUserSID) { + t.Fatalf("security descriptor does not contain current user sid %q", currentUserSID) + } + + systemSID, err := wellKnownSIDString(windows.WinLocalSystemSid) + if err != nil { + t.Fatalf("system sid: %v", err) + } + if !strings.Contains(config.SecurityDescriptor, systemSID) { + t.Fatalf("security descriptor does not contain system sid %q", systemSID) + } + + adminSID, err := wellKnownSIDString(windows.WinBuiltinAdministratorsSid) + if err != nil { + t.Fatalf("administrators sid: %v", err) + } + if !strings.Contains(config.SecurityDescriptor, adminSID) { + t.Fatalf("security descriptor does not contain administrators sid %q", adminSID) + } +} diff --git a/internal/gateway/types.go b/internal/gateway/types.go index 34807ba8..851fffc4 100644 --- a/internal/gateway/types.go +++ b/internal/gateway/types.go @@ -18,6 +18,8 @@ const ( type FrameAction string const ( + // FrameActionPing 表示探活动作,用于验证网关可用性。 + FrameActionPing FrameAction = "ping" // FrameActionRun 表示发起一次运行。 FrameActionRun FrameAction = "run" // FrameActionCompact 表示触发一次手动压缩。 diff --git a/internal/gateway/validate.go b/internal/gateway/validate.go index d83d1a15..aae8283f 100644 --- a/internal/gateway/validate.go +++ b/internal/gateway/validate.go @@ -30,6 +30,8 @@ func validateRequestFrame(frame MessageFrame) *FrameError { } switch frame.Action { + case FrameActionPing: + return nil case 
FrameActionRun: return validateRunFrame(frame) case FrameActionCompact, FrameActionLoadSession: @@ -168,7 +170,8 @@ func isValidFrameType(frameType FrameType) bool { // isValidFrameAction 判断动作是否属于协议定义集合。 func isValidFrameAction(action FrameAction) bool { switch action { - case FrameActionRun, + case FrameActionPing, + FrameActionRun, FrameActionCompact, FrameActionCancel, FrameActionListSessions, diff --git a/internal/gateway/validate_additional_test.go b/internal/gateway/validate_additional_test.go new file mode 100644 index 00000000..e8611842 --- /dev/null +++ b/internal/gateway/validate_additional_test.go @@ -0,0 +1,89 @@ +package gateway + +import ( + "strings" + "testing" +) + +func TestDecodePermissionResolutionInputAdditionalBranches(t *testing.T) { + t.Parallel() + + t.Run("nil permission pointer", func(t *testing.T) { + var input *PermissionResolutionInput + _, err := decodePermissionResolutionInput(input) + if err == nil || !strings.Contains(err.Error(), "is nil") { + t.Fatalf("expected nil pointer error, got %v", err) + } + }) + + t.Run("marshal error", func(t *testing.T) { + payload := map[string]any{"bad": func() {}} + _, err := decodePermissionResolutionInput(payload) + if err == nil { + t.Fatal("expected marshal error") + } + }) + + t.Run("unmarshal error", func(t *testing.T) { + _, err := decodePermissionResolutionInput([]byte("not-json-object")) + if err == nil { + t.Fatal("expected unmarshal error") + } + }) +} + +func TestValidateRequestFrameRunsInputPartsValidationForCompact(t *testing.T) { + t.Parallel() + + err := ValidateFrame(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionCompact, + SessionID: "sess-1", + InputParts: []InputPart{{ + Type: InputPartTypeText, + Text: " ", + }}, + }) + if err == nil { + t.Fatal("expected input_parts validation error") + } + if err.Code != ErrorCodeInvalidMultimodalPayload.String() { + t.Fatalf("error code = %q, want %q", err.Code, ErrorCodeInvalidMultimodalPayload.String()) + } +} + +func 
TestValidateFrameCancelAndListSessions(t *testing.T) { + t.Parallel() + + cancelErr := ValidateFrame(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionCancel, + }) + if cancelErr != nil { + t.Fatalf("cancel request should be valid, got %v", cancelErr) + } + + listErr := ValidateFrame(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionListSessions, + }) + if listErr != nil { + t.Fatalf("list_sessions request should be valid, got %v", listErr) + } +} + +func TestValidateResolvePermissionInvalidPayloadType(t *testing.T) { + t.Parallel() + + err := ValidateFrame(MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionResolvePermission, + Payload: make(chan int), + }) + if err == nil { + t.Fatal("expected invalid resolve_permission payload error") + } + if err.Code != ErrorCodeInvalidAction.String() { + t.Fatalf("error code = %q, want %q", err.Code, ErrorCodeInvalidAction.String()) + } +} diff --git a/internal/gateway/validate_test.go b/internal/gateway/validate_test.go index 48267aa0..85f2dad1 100644 --- a/internal/gateway/validate_test.go +++ b/internal/gateway/validate_test.go @@ -13,6 +13,15 @@ func TestValidateFrame_BasicRules(t *testing.T) { wantCode string wantField string }{ + { + name: "valid ping request", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionPing, + RequestID: "req-ping", + }, + wantNil: true, + }, { name: "valid run with input_text", frame: MessageFrame{