Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 112 additions & 20 deletions internal/gateway/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@ import (
"os"
"strings"
"sync"
"time"

"neo-code/internal/gateway/transport"
)

const (
// MaxFrameSize 定义单条 JSON 帧允许的最大字节数,避免异常输入导致内存放大。
MaxFrameSize int64 = 1 << 20 // 1 MiB

// DefaultMaxConnections 定义服务允许的最大并发连接数,超过上限的连接会被快速拒绝。
DefaultMaxConnections = 128
// DefaultReadTimeout 定义单次读帧的最大等待时间,避免慢连接长期占用资源。
DefaultReadTimeout = 30 * time.Second
// DefaultWriteTimeout 定义单次写帧的最大等待时间,避免写阻塞占用处理协程。
DefaultWriteTimeout = 30 * time.Second
)

var (
Expand All @@ -31,23 +39,37 @@ var (

// ServerOptions 描述网关服务启动所需的可选配置。
type ServerOptions struct {
ListenAddress string
Logger *log.Logger
listenFn func(address string) (net.Listener, error)
ListenAddress string
Logger *log.Logger
MaxConnections int
ReadTimeout time.Duration
WriteTimeout time.Duration
listenFn func(address string) (net.Listener, error)
}

// Server 提供基于本地 IPC 的网关服务骨架实现。
type Server struct {
listenAddress string
logger *log.Logger
listenFn func(address string) (net.Listener, error)
listenAddress string
logger *log.Logger
listenFn func(address string) (net.Listener, error)
maxConnections int
readTimeout time.Duration
writeTimeout time.Duration

mu sync.Mutex
listener net.Listener
conns map[net.Conn]struct{}
wg sync.WaitGroup
}

type registerConnectionResult int

const (
registerConnectionAccepted registerConnectionResult = iota
registerConnectionServerClosed
registerConnectionLimitExceeded
)

// NewServer 创建网关服务实例,并解析默认监听地址。
func NewServer(options ServerOptions) (*Server, error) {
listenAddress := strings.TrimSpace(options.ListenAddress)
Expand All @@ -69,11 +91,29 @@ func NewServer(options ServerOptions) (*Server, error) {
listenFn = transport.Listen
}

maxConnections := options.MaxConnections
if maxConnections <= 0 {
maxConnections = DefaultMaxConnections
}

readTimeout := options.ReadTimeout
if readTimeout <= 0 {
readTimeout = DefaultReadTimeout
}

writeTimeout := options.WriteTimeout
if writeTimeout <= 0 {
writeTimeout = DefaultWriteTimeout
}

return &Server{
listenAddress: listenAddress,
logger: logger,
listenFn: listenFn,
conns: make(map[net.Conn]struct{}),
listenAddress: listenAddress,
logger: logger,
listenFn: listenFn,
maxConnections: maxConnections,
readTimeout: readTimeout,
writeTimeout: writeTimeout,
conns: make(map[net.Conn]struct{}),
}, nil
}

Expand Down Expand Up @@ -114,12 +154,17 @@ func (s *Server) Serve(ctx context.Context, runtimePort RuntimePort) error {
return fmt.Errorf("gateway: accept connection: %w", acceptErr)
}

if !s.registerConnection(conn) {
switch s.registerConnection(conn) {
case registerConnectionAccepted:
case registerConnectionServerClosed:
_ = conn.Close()
continue
case registerConnectionLimitExceeded:
s.logger.Printf("reject connection: max connections %d reached", s.maxConnections)
_ = conn.Close()
continue
}

s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.untrackConnection(conn)
Expand Down Expand Up @@ -178,15 +223,19 @@ func (s *Server) snapshotConnections() map[net.Conn]struct{} {
return copied
}

// registerConnection 在服务可用时登记连接,若网关已关闭则拒绝登记
func (s *Server) registerConnection(conn net.Conn) bool {
// registerConnection 在服务可用且未超限时登记连接,并原子增加连接处理 WaitGroup 计数
func (s *Server) registerConnection(conn net.Conn) registerConnectionResult {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener == nil {
return false
return registerConnectionServerClosed
}
if len(s.conns) >= s.maxConnections {
return registerConnectionLimitExceeded
}
s.conns[conn] = struct{}{}
return true
s.wg.Add(1)
return registerConnectionAccepted
}

// untrackConnection 移除已结束连接,避免连接集合持续增长。
Expand All @@ -212,6 +261,11 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor
default:
}

if err := s.applyReadDeadline(conn); err != nil {
s.logger.Printf("set read deadline failed: %v", err)
return
}

frame, err := decodeFrame(reader)
if err != nil {
if errors.Is(err, io.EOF) {
Expand All @@ -220,28 +274,66 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn, runtimePor
if errors.Is(err, errFrameEmpty) {
continue
}
if isTimeoutError(err) {
s.logger.Printf("read frame timeout: %v", err)
return
}
if errors.Is(err, errFrameTooLarge) {
s.logger.Printf("decode frame failed: %v", err)
_ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError(
_ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError(
ErrorCodeInvalidFrame,
fmt.Sprintf("frame exceeds max size %d bytes", MaxFrameSize),
)))
return
}

s.logger.Printf("decode frame failed: %v", err)
_ = encoder.Encode(errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame")))
_ = s.writeFrame(conn, encoder, errorFrame(MessageFrame{}, NewFrameError(ErrorCodeInvalidFrame, "invalid json frame")))
return
}

response := s.dispatchFrame(ctx, frame, runtimePort)
if err := encoder.Encode(response); err != nil {
s.logger.Printf("write frame failed: %v", err)
if !s.writeFrame(conn, encoder, response) {
return
}
}
}

// applyReadDeadline 为当前连接设置下一次读操作超时,避免慢读连接长期占用协程。
func (s *Server) applyReadDeadline(conn net.Conn) error {
if s.readTimeout <= 0 {
return nil
}
return conn.SetReadDeadline(time.Now().Add(s.readTimeout))
}

// applyWriteDeadline 为当前连接设置下一次写操作超时,避免写阻塞导致协程泄漏。
func (s *Server) applyWriteDeadline(conn net.Conn) error {
if s.writeTimeout <= 0 {
return nil
}
return conn.SetWriteDeadline(time.Now().Add(s.writeTimeout))
}

// writeFrame 统一处理响应写回及写超时设置,失败时返回 false 供上层快速终止连接循环。
func (s *Server) writeFrame(conn net.Conn, encoder *json.Encoder, frame MessageFrame) bool {
if err := s.applyWriteDeadline(conn); err != nil {
s.logger.Printf("set write deadline failed: %v", err)
return false
}
if err := encoder.Encode(frame); err != nil {
s.logger.Printf("write frame failed: %v", err)
return false
}
return true
}

// isTimeoutError 判断错误是否为网络超时,用于区分慢连接超时与协议错误。
func isTimeoutError(err error) bool {
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}

// decodeFrame 从连接读取一条 JSON 帧并执行长度与格式校验。
func decodeFrame(reader *bufio.Reader) (MessageFrame, error) {
payload, err := readFramePayload(reader, MaxFrameSize)
Expand Down
103 changes: 101 additions & 2 deletions internal/gateway/server_additional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,23 @@ func TestNewServerUsesDefaultsAndOverrides(t *testing.T) {
if server.listenFn == nil {
t.Fatal("default listen function should not be nil")
}
if server.maxConnections != DefaultMaxConnections {
t.Fatalf("default max connections = %d, want %d", server.maxConnections, DefaultMaxConnections)
}
if server.readTimeout != DefaultReadTimeout {
t.Fatalf("default read timeout = %v, want %v", server.readTimeout, DefaultReadTimeout)
}
if server.writeTimeout != DefaultWriteTimeout {
t.Fatalf("default write timeout = %v, want %v", server.writeTimeout, DefaultWriteTimeout)
}

customLogger := log.New(io.Discard, "custom", 0)
customServer, err := NewServer(ServerOptions{
ListenAddress: " custom-address ",
Logger: customLogger,
ListenAddress: " custom-address ",
Logger: customLogger,
MaxConnections: 7,
ReadTimeout: 150 * time.Millisecond,
WriteTimeout: 250 * time.Millisecond,
listenFn: func(string) (net.Listener, error) {
return nil, nil
},
Expand All @@ -54,6 +66,15 @@ func TestNewServerUsesDefaultsAndOverrides(t *testing.T) {
if customServer.logger != customLogger {
t.Fatal("custom logger was not used")
}
if customServer.maxConnections != 7 {
t.Fatalf("custom max connections = %d, want %d", customServer.maxConnections, 7)
}
if customServer.readTimeout != 150*time.Millisecond {
t.Fatalf("custom read timeout = %v, want %v", customServer.readTimeout, 150*time.Millisecond)
}
if customServer.writeTimeout != 250*time.Millisecond {
t.Fatalf("custom write timeout = %v, want %v", customServer.writeTimeout, 250*time.Millisecond)
}
}

func TestNewServerReturnsDefaultAddressError(t *testing.T) {
Expand Down Expand Up @@ -324,6 +345,84 @@ func TestServerHandleConnectionInvalidJSONFrame(t *testing.T) {
}
}

func TestRegisterConnectionRejectsWhenLimitExceeded(t *testing.T) {
server := &Server{
listener: &simpleListener{},
maxConnections: 1,
conns: make(map[net.Conn]struct{}),
}

conn1Server, conn1Client := net.Pipe()
defer conn1Client.Close()
defer conn1Server.Close()
if got := server.registerConnection(conn1Server); got != registerConnectionAccepted {
t.Fatalf("first register result = %v, want accepted", got)
}

conn2Server, conn2Client := net.Pipe()
defer conn2Client.Close()
defer conn2Server.Close()
if got := server.registerConnection(conn2Server); got != registerConnectionLimitExceeded {
t.Fatalf("second register result = %v, want limit exceeded", got)
}

server.untrackConnection(conn1Server)
server.wg.Done()
}

func TestServerHandleConnectionReadTimeoutClosesConnection(t *testing.T) {
server := &Server{
logger: log.New(io.Discard, "", 0),
readTimeout: 20 * time.Millisecond,
}
serverConn, clientConn := net.Pipe()
done := make(chan struct{})
go func() {
defer close(done)
server.handleConnection(context.Background(), serverConn, nil)
}()

select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("handleConnection should exit after read timeout")
}

var buf [1]byte
_, err := clientConn.Read(buf[:])
if !errors.Is(err, io.EOF) && (err == nil || !strings.Contains(err.Error(), "closed pipe")) {
t.Fatalf("expected closed connection after timeout, got %v", err)
}
_ = clientConn.Close()
}

func TestServerHandleConnectionWriteTimeoutClosesConnection(t *testing.T) {
server := &Server{
logger: log.New(io.Discard, "", 0),
readTimeout: time.Second,
writeTimeout: 20 * time.Millisecond,
}
serverConn, clientConn := net.Pipe()
done := make(chan struct{})
go func() {
defer close(done)
server.handleConnection(context.Background(), serverConn, nil)
}()

_, err := io.WriteString(clientConn, `{"type":"request","action":"ping","request_id":"write-timeout"}`+"\n")
if err != nil {
t.Fatalf("write request: %v", err)
}

select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("handleConnection should exit after write timeout")
}

_ = clientConn.Close()
}

type failingReader struct{}

func (r *failingReader) Read(_ []byte) (int, error) {
Expand Down
6 changes: 0 additions & 6 deletions internal/gateway/transport/listen_windows_acl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import (
)

func TestDefaultListenAddressWindows(t *testing.T) {
t.Parallel()

address, err := DefaultListenAddress()
if err != nil {
t.Fatalf("default listen address: %v", err)
Expand All @@ -25,8 +23,6 @@ func TestDefaultListenAddressWindows(t *testing.T) {
}

func TestNewCleanupListenerBranches(t *testing.T) {
t.Parallel()

base := &stubNetListener{}
if got := newCleanupListener(base, nil); got != base {
t.Fatal("expected original listener when cleanup is nil")
Expand All @@ -48,8 +44,6 @@ func TestNewCleanupListenerBranches(t *testing.T) {
}

func TestBuildRestrictedPipeSecurityDescriptorContainsExpectedACEs(t *testing.T) {
t.Parallel()

sddl, err := buildRestrictedPipeSecurityDescriptor()
if err != nil {
t.Fatalf("build restricted descriptor: %v", err)
Expand Down
4 changes: 0 additions & 4 deletions internal/gateway/transport/listen_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import (
)

func TestListenNamedPipeAcceptsConnection(t *testing.T) {
t.Parallel()

pipePath := fmt.Sprintf(`\\.\pipe\neocode-gateway-test-%d`, time.Now().UnixNano())
listener, err := Listen(pipePath)
if err != nil {
Expand Down Expand Up @@ -53,8 +51,6 @@ func TestListenNamedPipeAcceptsConnection(t *testing.T) {
}

func TestNewRestrictedPipeConfigContainsExpectedSIDs(t *testing.T) {
t.Parallel()

config, err := newRestrictedPipeConfig()
if err != nil {
t.Fatalf("new restricted pipe config: %v", err)
Expand Down
Loading