diff --git a/internal/gateway/server.go b/internal/gateway/server.go index b253c26d..9078d082 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -13,6 +13,7 @@ import ( "os" "strings" "sync" + "time" "neo-code/internal/gateway/transport" ) @@ -20,6 +21,13 @@ import ( const ( // MaxFrameSize 定义单条 JSON 帧允许的最大字节数,避免异常输入导致内存放大。 MaxFrameSize int64 = 1 << 20 // 1 MiB + + // DefaultMaxConnections 定义服务允许的最大并发连接数,超过上限的连接会被快速拒绝。 + DefaultMaxConnections = 128 + // DefaultReadTimeout 定义单次读帧的最大等待时间,避免慢连接长期占用资源。 + DefaultReadTimeout = 30 * time.Second + // DefaultWriteTimeout 定义单次写帧的最大等待时间,避免写阻塞占用处理协程。 + DefaultWriteTimeout = 30 * time.Second ) var ( @@ -31,16 +39,22 @@ var ( // ServerOptions 描述网关服务启动所需的可选配置。 type ServerOptions struct { - ListenAddress string - Logger *log.Logger - listenFn func(address string) (net.Listener, error) + ListenAddress string + Logger *log.Logger + MaxConnections int + ReadTimeout time.Duration + WriteTimeout time.Duration + listenFn func(address string) (net.Listener, error) } // Server 提供基于本地 IPC 的网关服务骨架实现。 type Server struct { - listenAddress string - logger *log.Logger - listenFn func(address string) (net.Listener, error) + listenAddress string + logger *log.Logger + listenFn func(address string) (net.Listener, error) + maxConnections int + readTimeout time.Duration + writeTimeout time.Duration mu sync.Mutex listener net.Listener @@ -48,6 +62,14 @@ type Server struct { wg sync.WaitGroup } +type registerConnectionResult int + +const ( + registerConnectionAccepted registerConnectionResult = iota + registerConnectionServerClosed + registerConnectionLimitExceeded +) + // NewServer 创建网关服务实例,并解析默认监听地址。 func NewServer(options ServerOptions) (*Server, error) { listenAddress := strings.TrimSpace(options.ListenAddress) @@ -69,11 +91,29 @@ func NewServer(options ServerOptions) (*Server, error) { listenFn = transport.Listen } + maxConnections := options.MaxConnections + if maxConnections <= 0 { + maxConnections = DefaultMaxConnections + } + + readTimeout := options.ReadTimeout + if readTimeout <= 0 { + readTimeout = DefaultReadTimeout + } + + writeTimeout := options.WriteTimeout + if writeTimeout <= 0 { + writeTimeout = DefaultWriteTimeout + } + return &Server{ - listenAddress: listenAddress, - logger: logger, - listenFn: listenFn, - conns: make(map[net.Conn]struct{}), + listenAddress: listenAddress, + logger: logger, + listenFn: listenFn, + maxConnections: maxConnections, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + conns: make(map[net.Conn]struct{}), }, nil } @@ -114,12 +154,17 @@ func (s *Server) Serve(ctx context.Context, runtimePort RuntimePort) error { return fmt.Errorf("gateway: accept connection: %w", acceptErr) } - if !s.registerConnection(conn) { + switch s.registerConnection(conn) { + case registerConnectionAccepted: + case registerConnectionServerClosed: + _ = conn.Close() + continue + case registerConnectionLimitExceeded: + s.logger.Printf("reject connection: max connections %d reached", s.maxConnections) _ = conn.Close() continue } - s.wg.Add(1) go func() { defer s.wg.Done() defer s.untrackConnection(conn) @@ -178,15 +223,19 @@ func (s *Server) snapshotConnections() map[net.Conn]struct{} { return copied } -// registerConnection 在服务可用时登记连接,若网关已关闭则拒绝登记。 -func (s *Server) registerConnection(conn net.Conn) bool { +// registerConnection 在服务可用且未超限时登记连接,并原子增加连接处理 WaitGroup 计数。 +func (s *Server) registerConnection(conn net.Conn) registerConnectionResult { s.mu.Lock() defer s.mu.Unlock() if s.listener == nil { - return false + return registerConnectionServerClosed + } + if len(s.conns) >= s.maxConnections { + return registerConnectionLimitExceeded } s.conns[conn] = struct{}{} - return true + s.wg.Add(1) + return registerConnectionAccepted } // untrackConnection 移除已结束连接,避免连接集合持续增长。 @@ -212,6 +261,11 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor default: } + if err := s.applyReadDeadline(conn); err != nil { + s.logger.Printf("set read deadline failed: %v", err) + return + } + frame, err := decodeFrame(reader) if err != nil { if errors.Is(err, io.EOF) { @@ -220,9 +274,13 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor if errors.Is(err, errFrameEmpty) { continue } + if isTimeoutError(err) { + s.logger.Printf("read frame timeout: %v", err) + return + } if errors.Is(err, errFrameTooLarge) { s.logger.Printf("decode frame failed: %v", err) - _ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError( + _ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError( ErrorCodeInvalidFrame, fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize), ))) @@ -230,18 +288,52 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor } s.logger.Printf("decode frame failed: %v", err) - _ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame"))) + _ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame"))) return } response := s.dispatchFrame(ctx, frame, runtimePort) - if err := encoder.Encode(response); err != nil { - s.logger.Printf("write frame failed: %v", err) + if !s.writeFrame(conn, encoder, response) { return } } } +// applyReadDeadline 为当前连接设置下一次读操作超时,避免慢读连接长期占用协程。 +func (s *Server) applyReadDeadline(conn net.Conn) error { + if s.readTimeout <= 0 { + return nil + } + return conn.SetReadDeadline(time.Now().Add(s.readTimeout)) +} + +// applyWriteDeadline 为当前连接设置下一次写操作超时,避免写阻塞导致协程泄漏。 +func (s *Server) applyWriteDeadline(conn net.Conn) error { + if s.writeTimeout <= 0 { + return nil + } + return conn.SetWriteDeadline(time.Now().Add(s.writeTimeout)) +} + +// writeFrame 统一处理响应写回及写超时设置,失败时返回 false 供上层快速终止连接循环。 +func (s *Server) writeFrame(conn net.Conn, encoder *json.Encoder, frame MessageFrame) bool { + if err := s.applyWriteDeadline(conn); err != nil { + s.logger.Printf("set write deadline failed: %v", err) + return false + } + if err := encoder.Encode(frame); err != nil { + s.logger.Printf("write frame failed: %v", err) + return false + } + return true +} + +// isTimeoutError 判断错误是否为网络超时,用于区分慢连接超时与协议错误。 +func isTimeoutError(err error) bool { + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} + // decodeFrame 从连接读取一条 JSON 帧并执行长度与格式校验。 func decodeFrame(reader *bufio.Reader) (MessageFrame, error) { payload, err := readFramePayload(reader, MaxFrameSize) diff --git a/internal/gateway/server_additional_test.go b/internal/gateway/server_additional_test.go index 5689ddec..3fcc7118 100644 --- a/internal/gateway/server_additional_test.go +++ b/internal/gateway/server_additional_test.go @@ -36,11 +36,23 @@ func TestNewServerUsesDefaultsAndOverrides(t *testing.T) { if server.listenFn == nil { t.Fatal("default listen function should not be nil") } + if server.maxConnections != DefaultMaxConnections { + t.Fatalf("default max connections = %d, want %d", server.maxConnections, DefaultMaxConnections) + } + if server.readTimeout != DefaultReadTimeout { + t.Fatalf("default read timeout = %v, want %v", server.readTimeout, DefaultReadTimeout) + } + if server.writeTimeout != DefaultWriteTimeout { + t.Fatalf("default write timeout = %v, want %v", server.writeTimeout, DefaultWriteTimeout) + } customLogger := log.New(io.Discard, "custom", 0) customServer, err := NewServer(ServerOptions{ - ListenAddress: " custom-address ", - Logger: customLogger, + ListenAddress: " custom-address ", + Logger: customLogger, + MaxConnections: 7, + ReadTimeout: 150 * time.Millisecond, + WriteTimeout: 250 * time.Millisecond, listenFn: func(string) (net.Listener, error) { return nil, nil }, @@ -54,6 +66,15 @@ func TestNewServerUsesDefaultsAndOverrides(t *testing.T) { if customServer.logger != customLogger { t.Fatal("custom logger was not used") } + if customServer.maxConnections != 7 { + t.Fatalf("custom max connections = %d, want %d", customServer.maxConnections, 7) + } + if customServer.readTimeout != 150*time.Millisecond { + t.Fatalf("custom read timeout = %v, want %v", customServer.readTimeout, 150*time.Millisecond) + } + if customServer.writeTimeout != 250*time.Millisecond { + t.Fatalf("custom write timeout = %v, want %v", customServer.writeTimeout, 250*time.Millisecond) + } } func TestNewServerReturnsDefaultAddressError(t *testing.T) { @@ -324,6 +345,84 @@ func TestServerHandleConnectionInvalidJSONFrame(t *testing.T) { } } +func TestRegisterConnectionRejectsWhenLimitExceeded(t *testing.T) { + server := &Server{ + listener: &simpleListener{}, + maxConnections: 1, + conns: make(map[net.Conn]struct{}), + } + + conn1Server, conn1Client := net.Pipe() + defer conn1Client.Close() + defer conn1Server.Close() + if got := server.registerConnection(conn1Server); got != registerConnectionAccepted { + t.Fatalf("first register result = %v, want accepted", got) + } + + conn2Server, conn2Client := net.Pipe() + defer conn2Client.Close() + defer conn2Server.Close() + if got := server.registerConnection(conn2Server); got != registerConnectionLimitExceeded { + t.Fatalf("second register result = %v, want limit exceeded", got) + } + + server.untrackConnection(conn1Server) + server.wg.Done() +} + +func TestServerHandleConnectionReadTimeoutClosesConnection(t *testing.T) { + server := &Server{ + logger: log.New(io.Discard, "", 0), + readTimeout: 20 * time.Millisecond, + } + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("handleConnection should exit after read timeout") + } + + var buf [1]byte + _, err := clientConn.Read(buf[:]) + if !errors.Is(err, io.EOF) && (err == nil || !strings.Contains(err.Error(), "closed pipe")) { + t.Fatalf("expected closed connection after timeout, got %v", err) + } + _ = clientConn.Close() +} + +func TestServerHandleConnectionWriteTimeoutClosesConnection(t *testing.T) { + server := &Server{ + logger: log.New(io.Discard, "", 0), + readTimeout: time.Second, + writeTimeout: 20 * time.Millisecond, + } + serverConn, clientConn := net.Pipe() + done := make(chan struct{}) + go func() { + defer close(done) + server.handleConnection(context.Background(), serverConn, nil) + }() + + _, err := io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"write-timeout"}`+"\n") + if err != nil { + t.Fatalf("write request: %v", err) + } + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("handleConnection should exit after write timeout") + } + + _ = clientConn.Close() +} + type failingReader struct{} func (r *failingReader) Read(_ []byte) (int, error) { diff --git a/internal/gateway/transport/listen_windows_acl_test.go b/internal/gateway/transport/listen_windows_acl_test.go index c92b8ae9..6d10e7ab 100644 --- a/internal/gateway/transport/listen_windows_acl_test.go +++ b/internal/gateway/transport/listen_windows_acl_test.go @@ -13,8 +13,6 @@ import ( ) func TestDefaultListenAddressWindows(t *testing.T) { - t.Parallel() - address, err := DefaultListenAddress() if err != nil { t.Fatalf("default listen address: %v", err) @@ -25,8 +23,6 @@ func TestDefaultListenAddressWindows(t *testing.T) { } func TestNewCleanupListenerBranches(t *testing.T) { - t.Parallel() - base := &stubNetListener{} if got := newCleanupListener(base, nil); got != base { t.Fatal("expected original listener when cleanup is nil") @@ -48,8 +44,6 @@ func TestNewCleanupListenerBranches(t *testing.T) { } func TestBuildRestrictedPipeSecurityDescriptorContainsExpectedACEs(t *testing.T) { - t.Parallel() - sddl, err := buildRestrictedPipeSecurityDescriptor() if err != nil { t.Fatalf("build restricted descriptor: %v", err) diff --git a/internal/gateway/transport/listen_windows_test.go b/internal/gateway/transport/listen_windows_test.go index 338f0b6b..826734ec 100644 --- a/internal/gateway/transport/listen_windows_test.go +++ b/internal/gateway/transport/listen_windows_test.go @@ -13,8 +13,6 @@ import ( ) func TestListenNamedPipeAcceptsConnection(t *testing.T) { - t.Parallel() - pipePath := fmt.Sprintf(`\\.\pipe\neocode-gateway-test-%d`, time.Now().UnixNano()) listener, err := Listen(pipePath) if err != nil { @@ -53,8 +51,6 @@ func TestListenNamedPipeAcceptsConnection(t *testing.T) { } func TestNewRestrictedPipeConfigContainsExpectedSIDs(t *testing.T) { - t.Parallel() - config, err := newRestrictedPipeConfig() if err != nil { t.Fatalf("new restricted pipe config: %v", err)