From 9cf7099852698eb40daac8016f39c70d486817ee Mon Sep 17 00:00:00 2001 From: xgopilot Date: Fri, 24 Apr 2026 05:05:17 +0000 Subject: [PATCH] fix(gateway): resolve remaining review risks and tighten checks Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: pionxe <148670367+pionxe@users.noreply.github.com> --- Makefile | 7 +- .../gateway/adapters/urlscheme/dispatcher.go | 221 ++++++++++-------- .../adapters/urlscheme/dispatcher_test.go | 136 +++++------ internal/gateway/launcher/launcher.go | 101 ++++++-- internal/gateway/launcher/launcher_test.go | 117 +++++++--- scripts/check_gateway_docs/main.go | 112 +++++++++ scripts/check_gateway_docs/main_test.go | 112 +++++++++ 7 files changed, 593 insertions(+), 213 deletions(-) create mode 100644 scripts/check_gateway_docs/main.go create mode 100644 scripts/check_gateway_docs/main_test.go diff --git a/Makefile b/Makefile index 3b5260e6..6cea1f23 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,14 @@ .PHONY: install-skills docs-gateway docs-gateway-check +GATEWAY_DOCS_GENERATOR := go run -tags gatewaydocgen ./scripts/generate_gateway_rpc_examples.go + install-skills: @./scripts/install_skills.sh docs-gateway: - @go run -tags gatewaydocgen ./scripts/generate_gateway_rpc_examples.go + @$(GATEWAY_DOCS_GENERATOR) docs-gateway-check: - @go run -tags gatewaydocgen ./scripts/generate_gateway_rpc_examples.go + @$(GATEWAY_DOCS_GENERATOR) + @go run ./scripts/check_gateway_docs @git diff --exit-code -- docs/generated/gateway-rpc-examples.json diff --git a/internal/gateway/adapters/urlscheme/dispatcher.go b/internal/gateway/adapters/urlscheme/dispatcher.go index 95495a41..81444a88 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher.go +++ b/internal/gateway/adapters/urlscheme/dispatcher.go @@ -166,31 +166,17 @@ func (d *Dispatcher) Dispatch(ctx context.Context, request DispatchRequest) (Dis if err != nil { return DispatchResult{}, err } - if strings.TrimSpace(rpcResponse.JSONRPC) != protocol.JSONRPCVersion { - return DispatchResult{}, newDispatchError( - ErrorCodeUnexpectedResponse, - "unexpected response jsonrpc version", - ) - } - if !rawJSONMessageEqual(rpcResponse.ID, rpcRequest.ID) { - return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, "rpc correlation failed: id mismatch") - } - if rpcResponse.Error != nil && rpcResponse.Result != nil { - return DispatchResult{}, newDispatchError( - ErrorCodeUnexpectedResponse, - "unexpected response payload: both result and error are present", - ) - } - if rpcResponse.Error != nil { - return DispatchResult{}, toDispatchErrorFromJSONRPC(rpcResponse.Error) - } - if rpcResponse.Result == nil { - return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, "gateway response missing result payload") - } - - responseFrame, err := decodeResponseFrameResult(rpcResponse.Result) + responseFrame, err := validateRPCFrameResponse( + rpcResponse, + rpcRequest.ID, + "unexpected response jsonrpc version", + "rpc correlation failed: id mismatch", + "unexpected response payload: both result and error are present", + "gateway response missing result payload", + "decode response frame: %v", + ) if err != nil { - return DispatchResult{}, newDispatchError(ErrorCodeUnexpectedResponse, fmt.Sprintf("decode response frame: %v", err)) + return DispatchResult{}, err } if responseFrame.Action != requestFrame.Action || responseFrame.RequestID != requestFrame.RequestID { return DispatchResult{}, newDispatchError( @@ -238,21 +224,17 @@ func (d *Dispatcher) authenticate(ctx context.Context, conn net.Conn, token stri if err != nil { return err } - if strings.TrimSpace(authResponse.JSONRPC) != protocol.JSONRPCVersion { - return newDispatchError(ErrorCodeUnexpectedResponse, "unexpected auth response jsonrpc version") - } - if !rawJSONMessageEqual(authResponse.ID, authRequest.ID) { - return newDispatchError(ErrorCodeUnexpectedResponse, "rpc correlation failed: auth id mismatch") - } - if authResponse.Error != nil { - return toDispatchErrorFromJSONRPC(authResponse.Error) - } - if authResponse.Result == nil { - return newDispatchError(ErrorCodeUnexpectedResponse, "gateway auth response missing result payload") - } - frame, err := decodeResponseFrameResult(authResponse.Result) + frame, err := validateRPCFrameResponse( + authResponse, + authRequest.ID, + "unexpected auth response jsonrpc version", + "rpc correlation failed: auth id mismatch", + "unexpected response payload: both result and error are present", + "gateway auth response missing result payload", + "decode auth response frame: %v", + ) if err != nil { - return newDispatchError(ErrorCodeUnexpectedResponse, fmt.Sprintf("decode auth response frame: %v", err)) + return err } if frame.Type != gateway.FrameTypeAck || frame.Action != gateway.FrameActionAuthenticate || frame.RequestID != authRequestID { return newDispatchError(ErrorCodeUnexpectedResponse, "unexpected auth response frame") @@ -347,76 +329,79 @@ func (d *Dispatcher) launchGateway(ctx context.Context, listenAddress string, re spec, err := resolveLaunchSpecFn() if err != nil { - d.emitLaunchDecisionLog(launchDecisionLogEntry{ - RequestID: requestID, - Method: string(protocol.MethodWakeOpenURL), - Source: "url-dispatch", - Status: "launch_failed", - GatewayCode: ErrorCodeGatewayUnavailable, - ListenAddress: listenAddress, - AuthMode: resolveAuthMode(authToken), - Message: err.Error(), - }) + d.emitLaunchFailureLog(requestID, listenAddress, authToken, launcher.LaunchSpec{}, err) return err } - d.emitLaunchDecisionLog(launchDecisionLogEntry{ - RequestID: requestID, - Method: string(protocol.MethodWakeOpenURL), - Source: "url-dispatch", - Status: "launch_attempt", - ListenAddress: listenAddress, - AuthMode: resolveAuthMode(authToken), - LaunchMode: spec.LaunchMode, - ResolvedExec: spec.Executable, - }) + d.emitLaunchDecisionLog(newLaunchDecisionLogEntry( + requestID, + listenAddress, + authToken, + "launch_attempt", + "", + spec, + "", + )) launchSpec := spec launchSpec.Args = buildGatewayLaunchArgs(spec.Args, listenAddress) if err := startGatewayFn(launchSpec); err != nil { - d.emitLaunchDecisionLog(launchDecisionLogEntry{ - RequestID: requestID, - Method: string(protocol.MethodWakeOpenURL), - Source: "url-dispatch", - Status: "launch_failed", - GatewayCode: ErrorCodeGatewayUnavailable, - ListenAddress: listenAddress, - AuthMode: resolveAuthMode(authToken), - LaunchMode: spec.LaunchMode, - ResolvedExec: spec.Executable, - Message: err.Error(), - }) + d.emitLaunchFailureLog(requestID, listenAddress, authToken, spec, err) return err } if err := d.waitGatewayReady(ctx, listenAddress); err != nil { - d.emitLaunchDecisionLog(launchDecisionLogEntry{ - RequestID: requestID, - Method: string(protocol.MethodWakeOpenURL), - Source: "url-dispatch", - Status: "launch_failed", - GatewayCode: ErrorCodeGatewayUnavailable, - ListenAddress: listenAddress, - AuthMode: resolveAuthMode(authToken), - LaunchMode: spec.LaunchMode, - ResolvedExec: spec.Executable, - Message: err.Error(), - }) + d.emitLaunchFailureLog(requestID, listenAddress, authToken, spec, err) return err } - d.emitLaunchDecisionLog(launchDecisionLogEntry{ - RequestID: requestID, - Method: string(protocol.MethodWakeOpenURL), - Source: "url-dispatch", - Status: "launch_ready", - ListenAddress: listenAddress, - AuthMode: resolveAuthMode(authToken), - LaunchMode: spec.LaunchMode, - ResolvedExec: spec.Executable, - }) + d.emitLaunchDecisionLog(newLaunchDecisionLogEntry( + requestID, + listenAddress, + authToken, + "launch_ready", + "", + spec, + "", + )) return nil } +// validateRPCFrameResponse 统一校验 JSON-RPC 基础字段并解码结果帧,保持调度与鉴权分支一致。 +func validateRPCFrameResponse( + response protocol.JSONRPCResponse, + expectedID json.RawMessage, + versionMismatchMessage string, + idMismatchMessage string, + dualPayloadMessage string, + missingResultMessage string, + decodeFrameMessageFormat string, +) (gateway.MessageFrame, error) { + if strings.TrimSpace(response.JSONRPC) != protocol.JSONRPCVersion { + return gateway.MessageFrame{}, newDispatchError(ErrorCodeUnexpectedResponse, versionMismatchMessage) + } + if !rawJSONMessageEqual(response.ID, expectedID) { + return gateway.MessageFrame{}, newDispatchError(ErrorCodeUnexpectedResponse, idMismatchMessage) + } + if response.Error != nil && response.Result != nil { + return gateway.MessageFrame{}, newDispatchError(ErrorCodeUnexpectedResponse, dualPayloadMessage) + } + if response.Error != nil { + return gateway.MessageFrame{}, toDispatchErrorFromJSONRPC(response.Error) + } + if response.Result == nil { + return gateway.MessageFrame{}, newDispatchError(ErrorCodeUnexpectedResponse, missingResultMessage) + } + + frame, err := decodeResponseFrameResult(response.Result) + if err != nil { + return gateway.MessageFrame{}, newDispatchError( + ErrorCodeUnexpectedResponse, + fmt.Sprintf(decodeFrameMessageFormat, err), + ) + } + return frame, nil +} + // buildGatewayLaunchArgs 构造自动拉起参数,确保子进程监听地址与调度重拨地址一致。 func buildGatewayLaunchArgs(baseArgs []string, listenAddress string) []string { args := append([]string(nil), baseArgs...) @@ -438,12 +423,17 @@ func (d *Dispatcher) waitGatewayReady(ctx context.Context, listenAddress string) sleepFn = time.Sleep } - deadline := nowFn().Add(defaultGatewayLaunchTimeout) + startTime := nowFn() + deadline := startTime.Add(defaultGatewayLaunchTimeout) if ctx != nil { if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) { deadline = ctxDeadline } } + effectiveTimeout := deadline.Sub(startTime) + if effectiveTimeout < 0 { + effectiveTimeout = 0 + } for { if err := ensureDispatchContextActive(ctx); err != nil { @@ -455,7 +445,7 @@ func (d *Dispatcher) waitGatewayReady(ctx context.Context, listenAddress string) return nil } if !nowFn().Before(deadline) { - return fmt.Errorf("gateway did not become reachable within %s", defaultGatewayLaunchTimeout) + return fmt.Errorf("gateway did not become reachable within %s", effectiveTimeout) } sleepFn(defaultGatewayLaunchRetryInterval) } @@ -474,6 +464,49 @@ func (d *Dispatcher) emitLaunchDecisionLog(entry launchDecisionLogEntry) { d.logger.Print(string(raw)) } +// newLaunchDecisionLogEntry 构造统一的网关拉起日志字段,避免各分支重复拼装。 +func newLaunchDecisionLogEntry( + requestID string, + listenAddress string, + authToken string, + status string, + gatewayCode string, + spec launcher.LaunchSpec, + message string, +) launchDecisionLogEntry { + return launchDecisionLogEntry{ + RequestID: requestID, + Method: string(protocol.MethodWakeOpenURL), + Source: "url-dispatch", + Status: status, + GatewayCode: gatewayCode, + ListenAddress: listenAddress, + AuthMode: resolveAuthMode(authToken), + LaunchMode: spec.LaunchMode, + ResolvedExec: spec.Executable, + Message: message, + } +} + +// emitLaunchFailureLog 输出统一的启动失败日志,保持失败分支字段稳定。 +func (d *Dispatcher) emitLaunchFailureLog( + requestID string, + listenAddress string, + authToken string, + spec launcher.LaunchSpec, + err error, +) { + d.emitLaunchDecisionLog(newLaunchDecisionLogEntry( + requestID, + listenAddress, + authToken, + "launch_failed", + ErrorCodeGatewayUnavailable, + spec, + err.Error(), + )) +} + // resolveAuthMode 归一化调度鉴权模式,便于日志与兼容性测试稳定断言。 func resolveAuthMode(authToken string) string { if strings.TrimSpace(authToken) == "" { diff --git a/internal/gateway/adapters/urlscheme/dispatcher_test.go b/internal/gateway/adapters/urlscheme/dispatcher_test.go index d193a867..9057e264 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_test.go @@ -20,6 +20,43 @@ import ( "neo-code/internal/gateway/transport" ) +// newStubDispatcher 创建测试用调度器,统一默认依赖并允许按需覆盖。 +func newStubDispatcher(overrides func(*Dispatcher)) *Dispatcher { + dispatcher := &Dispatcher{ + resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, + dialFn: func(string) (net.Conn, error) { return &stubDispatchConn{}, nil }, + requestIDFn: func() string { return "wake-test" }, + } + if overrides != nil { + overrides(dispatcher) + } + return dispatcher +} + +// assertDispatchErrorCode 校验错误会被映射为指定的 DispatchError 码。 +func assertDispatchErrorCode(t *testing.T, err error, wantCode string) *DispatchError { + t.Helper() + + var dispatchErr *DispatchError + if !errors.As(err, &dispatchErr) { + t.Fatalf("error type = %T, want *DispatchError", err) + } + if dispatchErr.Code != wantCode { + t.Fatalf("error code = %q, want %q", dispatchErr.Code, wantCode) + } + return dispatchErr +} + +// assertDispatchErrorMessageContains 校验结构化错误包含预期消息片段。 +func assertDispatchErrorMessageContains(t *testing.T, err error, wantCode string, wantMessage string) { + t.Helper() + + dispatchErr := assertDispatchErrorCode(t, err, wantCode) + if !strings.Contains(dispatchErr.Message, wantMessage) { + t.Fatalf("error message = %q, want contains %q", dispatchErr.Message, wantMessage) + } +} + func TestDispatcherDispatchSuccess(t *testing.T) { serverConn, clientConn := net.Pipe() t.Cleanup(func() { @@ -27,17 +64,14 @@ func TestDispatcherDispatchSuccess(t *testing.T) { _ = clientConn.Close() }) - dispatcher := &Dispatcher{ - resolveListenAddressFn: func(string) (string, error) { - return "stub://gateway", nil - }, - dialFn: func(string) (net.Conn, error) { + dispatcher := newStubDispatcher(func(dispatcher *Dispatcher) { + dispatcher.dialFn = func(string) (net.Conn, error) { return clientConn, nil - }, - requestIDFn: func() string { + } + dispatcher.requestIDFn = func() string { return "wake-1" - }, - } + } + }) done := make(chan struct{}) go func() { @@ -115,11 +149,10 @@ func TestDispatcherDispatchReturnsGatewayError(t *testing.T) { _ = clientConn.Close() }) - dispatcher := &Dispatcher{ - resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, - dialFn: func(string) (net.Conn, error) { return clientConn, nil }, - requestIDFn: func() string { return "wake-2" }, - } + dispatcher := newStubDispatcher(func(dispatcher *Dispatcher) { + dispatcher.dialFn = func(string) (net.Conn, error) { return clientConn, nil } + dispatcher.requestIDFn = func() string { return "wake-2" } + }) go func() { decoder := json.NewDecoder(serverConn) @@ -144,13 +177,7 @@ func TestDispatcherDispatchReturnsGatewayError(t *testing.T) { t.Fatal("expected gateway error") } - var dispatchErr *DispatchError - if !errors.As(err, &dispatchErr) { - t.Fatalf("error type = %T, want *DispatchError", err) - } - if dispatchErr.Code != gateway.ErrorCodeInvalidAction.String() { - t.Fatalf("error code = %q, want %q", dispatchErr.Code, gateway.ErrorCodeInvalidAction.String()) - } + assertDispatchErrorCode(t, err, gateway.ErrorCodeInvalidAction.String()) } func TestDispatcherDispatchReturnsUnexpectedResponseError(t *testing.T) { @@ -160,11 +187,10 @@ func TestDispatcherDispatchReturnsUnexpectedResponseError(t *testing.T) { _ = clientConn.Close() }) - dispatcher := &Dispatcher{ - resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, - dialFn: func(string) (net.Conn, error) { return clientConn, nil }, - requestIDFn: func() string { return "wake-3" }, - } + dispatcher := newStubDispatcher(func(dispatcher *Dispatcher) { + dispatcher.dialFn = func(string) (net.Conn, error) { return clientConn, nil } + dispatcher.requestIDFn = func() string { return "wake-3" } + }) go func() { decoder := json.NewDecoder(serverConn) @@ -188,13 +214,7 @@ func TestDispatcherDispatchReturnsUnexpectedResponseError(t *testing.T) { if err == nil { t.Fatal("expected unexpected response error") } - var dispatchErr *DispatchError - if !errors.As(err, &dispatchErr) { - t.Fatalf("error type = %T, want *DispatchError", err) - } - if dispatchErr.Code != ErrorCodeUnexpectedResponse { - t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) - } + assertDispatchErrorCode(t, err, ErrorCodeUnexpectedResponse) } func TestDispatcherDispatchReturnsCorrelationMismatchError(t *testing.T) { @@ -204,11 +224,10 @@ func TestDispatcherDispatchReturnsCorrelationMismatchError(t *testing.T) { _ = clientConn.Close() }) - dispatcher := &Dispatcher{ - resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, - dialFn: func(string) (net.Conn, error) { return clientConn, nil }, - requestIDFn: func() string { return "wake-9" }, - } + dispatcher := newStubDispatcher(func(dispatcher *Dispatcher) { + dispatcher.dialFn = func(string) (net.Conn, error) { return clientConn, nil } + dispatcher.requestIDFn = func() string { return "wake-9" } + }) go func() { decoder := json.NewDecoder(serverConn) @@ -232,26 +251,16 @@ func TestDispatcherDispatchReturnsCorrelationMismatchError(t *testing.T) { if err == nil { t.Fatal("expected correlation mismatch error") } - var dispatchErr *DispatchError - if !errors.As(err, &dispatchErr) { - t.Fatalf("error type = %T, want *DispatchError", err) - } - if dispatchErr.Code != ErrorCodeUnexpectedResponse { - t.Fatalf("error code = %q, want %q", dispatchErr.Code, ErrorCodeUnexpectedResponse) - } - if !strings.Contains(dispatchErr.Message, "frame correlation failed") { - t.Fatalf("error message = %q, want correlation failure", dispatchErr.Message) - } + assertDispatchErrorMessageContains(t, err, ErrorCodeUnexpectedResponse, "frame correlation failed") } func TestDispatcherDispatchInputAndDialErrors(t *testing.T) { - dispatcher := &Dispatcher{ - resolveListenAddressFn: func(string) (string, error) { return "stub://gateway", nil }, - dialFn: func(string) (net.Conn, error) { + dispatcher := newStubDispatcher(func(dispatcher *Dispatcher) { + dispatcher.dialFn = func(string) (net.Conn, error) { return nil, errors.New("dial failed") - }, - requestIDFn: func() string { return "wake-4" }, - } + } + dispatcher.requestIDFn = func() string { return "wake-4" } + }) _, parseErr := dispatcher.Dispatch(context.Background(), DispatchRequest{ RawURL: "http://review?path=README.md", @@ -259,13 +268,7 @@ func TestDispatcherDispatchInputAndDialErrors(t *testing.T) { if parseErr == nil { t.Fatal("expected parse error") } - var parseDispatchErr *DispatchError - if !errors.As(parseErr, &parseDispatchErr) { - t.Fatalf("parse error type = %T, want *DispatchError", parseErr) - } - if parseDispatchErr.Code != "invalid_scheme" { - t.Fatalf("parse error code = %q, want %q", parseDispatchErr.Code, "invalid_scheme") - } + assertDispatchErrorCode(t, parseErr, "invalid_scheme") _, dialErr := dispatcher.Dispatch(context.Background(), DispatchRequest{ RawURL: "neocode://review?path=README.md", @@ -273,13 +276,7 @@ func TestDispatcherDispatchInputAndDialErrors(t *testing.T) { if dialErr == nil { t.Fatal("expected dial error") } - var dialDispatchErr *DispatchError - if !errors.As(dialErr, &dialDispatchErr) { - t.Fatalf("dial error type = %T, want *DispatchError", dialErr) - } - if dialDispatchErr.Code != ErrorCodeGatewayUnavailable { - t.Fatalf("dial error code = %q, want %q", dialDispatchErr.Code, ErrorCodeGatewayUnavailable) - } + assertDispatchErrorCode(t, dialErr, ErrorCodeGatewayUnavailable) } func TestDispatcherDialGatewayWithSingleLaunchFallback(t *testing.T) { @@ -1343,6 +1340,9 @@ func TestDispatcherWaitGatewayReadyBranches(t *testing.T) { if !strings.Contains(err.Error(), "did not become reachable") && !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected timeout-related error, got %v", err) } + if !errors.Is(err, context.DeadlineExceeded) && !strings.Contains(err.Error(), "40ms") { + t.Fatalf("error = %v, want contains %q when timeout message is returned", err, "40ms") + } if sleepCalls != 0 { t.Fatalf("sleepCalls = %d, want %d", sleepCalls, 0) } diff --git a/internal/gateway/launcher/launcher.go b/internal/gateway/launcher/launcher.go index f29b7a89..b0b846d9 100644 --- a/internal/gateway/launcher/launcher.go +++ b/internal/gateway/launcher/launcher.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "strings" ) @@ -56,41 +57,97 @@ func resolveGatewayLaunchSpecWithDeps( options ResolveOptions, lookPathFn func(string) (string, error), ) (LaunchSpec, error) { - resolveByLookup := func(binary string) (string, error) { - resolved, err := lookPathFn(strings.TrimSpace(binary)) - if err != nil { - return "", fmt.Errorf("resolve executable %q: %w", strings.TrimSpace(binary), err) - } - return strings.TrimSpace(resolved), nil - } - explicitBinary := strings.TrimSpace(options.ExplicitBinary) if explicitBinary != "" { - resolved, err := resolveByLookup(explicitBinary) + if err := validateExplicitGatewayBinary(explicitBinary); err != nil { + return LaunchSpec{}, err + } + spec, err := resolveLaunchSpecCandidate( + lookPathFn, + explicitBinary, + LaunchModeExplicitPath, + nil, + "explicit gateway binary", + ) if err != nil { return LaunchSpec{}, err } - return LaunchSpec{ - LaunchMode: LaunchModeExplicitPath, - Executable: resolved, - }, nil + return spec, nil } - if resolved, err := resolveByLookup("neocode-gateway"); err == nil { - return LaunchSpec{ - LaunchMode: LaunchModePathBinary, - Executable: resolved, - }, nil + resolvedPathBinary, err := resolveExecutablePath(lookPathFn, "neocode-gateway") + if err == nil { + return resolveLaunchSpecFromResolvedPath( + resolvedPathBinary, + LaunchModePathBinary, + nil, + "PATH neocode-gateway", + ) } - resolvedFallbackExecutable, err := resolveByLookup("neocode") + return resolveLaunchSpecCandidate( + lookPathFn, + "neocode", + LaunchModeFallbackSubcommand, + []string{"gateway"}, + "PATH neocode", + ) +} + +// resolveLaunchSpecCandidate 统一处理可执行查找、绝对路径校验与 LaunchSpec 构造。 +func resolveLaunchSpecCandidate( + lookPathFn func(string) (string, error), + binary string, + launchMode string, + args []string, + source string, +) (LaunchSpec, error) { + resolvedPath, err := resolveExecutablePath(lookPathFn, binary) if err != nil { return LaunchSpec{}, err } + return resolveLaunchSpecFromResolvedPath(resolvedPath, launchMode, args, source) +} +// resolveLaunchSpecFromResolvedPath 基于已解析的路径构造启动规格,并保留绝对路径校验。 +func resolveLaunchSpecFromResolvedPath( + resolvedPath string, + launchMode string, + args []string, + source string, +) (LaunchSpec, error) { + if err := validateResolvedExecutablePath(resolvedPath, source); err != nil { + return LaunchSpec{}, err + } return LaunchSpec{ - LaunchMode: LaunchModeFallbackSubcommand, - Executable: resolvedFallbackExecutable, - Args: []string{"gateway"}, + LaunchMode: launchMode, + Executable: resolvedPath, + Args: append([]string(nil), args...), }, nil } + +// resolveExecutablePath 统一处理可执行路径查找与空白归一化。 +func resolveExecutablePath(lookPathFn func(string) (string, error), binary string) (string, error) { + trimmedBinary := strings.TrimSpace(binary) + resolvedPath, err := lookPathFn(trimmedBinary) + if err != nil { + return "", fmt.Errorf("resolve executable %q: %w", trimmedBinary, err) + } + return strings.TrimSpace(resolvedPath), nil +} + +// validateExplicitGatewayBinary 校验显式配置的网关二进制路径,禁止使用相对路径降低 PATH 劫持风险。 +func validateExplicitGatewayBinary(explicitBinary string) error { + if !filepath.IsAbs(explicitBinary) { + return fmt.Errorf("explicit gateway binary must be an absolute path: %q", explicitBinary) + } + return nil +} + +// validateResolvedExecutablePath 校验解析后的可执行路径必须为绝对路径,避免执行不受控相对路径目标。 +func validateResolvedExecutablePath(resolvedPath string, source string) error { + if !filepath.IsAbs(resolvedPath) { + return fmt.Errorf("resolved executable from %s is not an absolute path: %q", source, resolvedPath) + } + return nil +} diff --git a/internal/gateway/launcher/launcher_test.go b/internal/gateway/launcher/launcher_test.go index 9738cf5e..afa22c31 100644 --- a/internal/gateway/launcher/launcher_test.go +++ b/internal/gateway/launcher/launcher_test.go @@ -6,10 +6,26 @@ import ( "path/filepath" "reflect" "runtime" + "strings" "testing" "time" ) +// assertLaunchSpecEqual 校验解析出的启动规格,保持测试断言结构一致。 +func assertLaunchSpecEqual(t *testing.T, spec LaunchSpec, want LaunchSpec) { + t.Helper() + + if spec.LaunchMode != want.LaunchMode { + t.Fatalf("launch mode = %q, want %q", spec.LaunchMode, want.LaunchMode) + } + if spec.Executable != want.Executable { + t.Fatalf("executable = %q, want %q", spec.Executable, want.Executable) + } + if !reflect.DeepEqual(spec.Args, want.Args) { + t.Fatalf("args = %#v, want %#v", spec.Args, want.Args) + } +} + func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { t.Run("explicit binary has highest priority", func(t *testing.T) { spec, err := resolveGatewayLaunchSpecWithDeps( @@ -24,15 +40,10 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { if err != nil { t.Fatalf("resolveGatewayLaunchSpecWithDeps() error = %v", err) } - if spec.LaunchMode != LaunchModeExplicitPath { - t.Fatalf("launch mode = %q, want %q", spec.LaunchMode, LaunchModeExplicitPath) - } - if spec.Executable != "/opt/tools/neocode-gateway" { - t.Fatalf("executable = %q, want %q", spec.Executable, "/opt/tools/neocode-gateway") - } - if len(spec.Args) != 0 { - t.Fatalf("args = %#v, want empty", spec.Args) - } + assertLaunchSpecEqual(t, spec, LaunchSpec{ + LaunchMode: LaunchModeExplicitPath, + Executable: "/opt/tools/neocode-gateway", + }) }) t.Run("path binary preferred over fallback", func(t *testing.T) { @@ -48,15 +59,10 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { if err != nil { t.Fatalf("resolveGatewayLaunchSpecWithDeps() error = %v", err) } - if spec.LaunchMode != LaunchModePathBinary { - t.Fatalf("launch mode = %q, want %q", spec.LaunchMode, LaunchModePathBinary) - } - if spec.Executable != "/usr/local/bin/neocode-gateway" { - t.Fatalf("executable = %q, want %q", spec.Executable, "/usr/local/bin/neocode-gateway") - } - if len(spec.Args) != 0 { - t.Fatalf("args = %#v, want empty", spec.Args) - } + assertLaunchSpecEqual(t, spec, LaunchSpec{ + LaunchMode: LaunchModePathBinary, + Executable: "/usr/local/bin/neocode-gateway", + }) }) t.Run("fallback to neocode subcommand", func(t *testing.T) { @@ -76,15 +82,11 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { if err != nil { t.Fatalf("resolveGatewayLaunchSpecWithDeps() error = %v", err) } - if spec.LaunchMode != LaunchModeFallbackSubcommand { - t.Fatalf("launch mode = %q, want %q", spec.LaunchMode, LaunchModeFallbackSubcommand) - } - if spec.Executable != "/usr/local/bin/neocode" { - t.Fatalf("executable = %q, want %q", spec.Executable, "/usr/local/bin/neocode") - } - if !reflect.DeepEqual(spec.Args, []string{"gateway"}) { - t.Fatalf("args = %#v, want %#v", spec.Args, []string{"gateway"}) - } + assertLaunchSpecEqual(t, spec, LaunchSpec{ + LaunchMode: LaunchModeFallbackSubcommand, + Executable: "/usr/local/bin/neocode", + Args: []string{"gateway"}, + }) }) t.Run("explicit binary lookup failure returns error", func(t *testing.T) { @@ -99,6 +101,67 @@ func TestResolveGatewayLaunchSpecWithDeps(t *testing.T) { } }) + t.Run("explicit binary must be absolute path", func(t *testing.T) { + lookupCalled := false + _, err := resolveGatewayLaunchSpecWithDeps( + ResolveOptions{ExplicitBinary: "neocode-gateway"}, + func(string) (string, error) { + lookupCalled = true + return "", nil + }, + ) + if err == nil { + t.Fatal("expected explicit path validation error") + } + if lookupCalled { + t.Fatal("lookPath should not be called for invalid explicit path") + } + }) + + t.Run("path binary resolution rejects non-absolute path", func(t *testing.T) { + _, err := resolveGatewayLaunchSpecWithDeps( + ResolveOptions{}, + func(binary string) (string, error) { + switch binary { + case "neocode-gateway": + return "neocode-gateway", nil + case "neocode": + return "/usr/local/bin/neocode", nil + default: + return "", errors.New("unexpected lookup") + } + }, + ) + if err == nil { + t.Fatal("expected non-absolute path resolution error") + } + if !strings.Contains(err.Error(), "not an absolute path") { + t.Fatalf("error = %v, want contains %q", err, "not an absolute path") + } + }) + + t.Run("fallback binary resolution rejects non-absolute path", func(t *testing.T) { + _, err := resolveGatewayLaunchSpecWithDeps( + ResolveOptions{}, + func(binary string) (string, error) { + switch binary { + case "neocode-gateway": + return "", errors.New("not found") + case "neocode": + return "neocode", nil + default: + return "", errors.New("unexpected lookup") + } + }, + ) + if err == nil { + t.Fatal("expected non-absolute fallback path resolution error") + } + if !strings.Contains(err.Error(), "not an absolute path") { + t.Fatalf("error = %v, want contains %q", err, "not an absolute path") + } + }) + t.Run("fallback fails when neocode is unavailable", func(t *testing.T) { _, err := resolveGatewayLaunchSpecWithDeps( ResolveOptions{}, diff --git a/scripts/check_gateway_docs/main.go b/scripts/check_gateway_docs/main.go new file mode 100644 index 00000000..960d0705 --- /dev/null +++ b/scripts/check_gateway_docs/main.go @@ -0,0 +1,112 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" +) + +const ( + gatewayExamplesPath = "docs/generated/gateway-rpc-examples.json" + gatewayRPCDocPath = "docs/gateway-rpc-api.md" +) + +// main 执行 Gateway RPC 文档一致性校验,确保生成示例与主文档的关键方法声明不漂移。 +func main() { + if err := checkGatewayRPCDocConsistency(gatewayExamplesPath, gatewayRPCDocPath); err != nil { + fmt.Fprintf(os.Stderr, "gateway docs consistency check failed: %v\n", err) + os.Exit(1) + } + fmt.Printf("verified gateway docs consistency: %s <-> %s\n", gatewayExamplesPath, gatewayRPCDocPath) +} + +// checkGatewayRPCDocConsistency 校验示例 JSON 中的 gateway 方法在主文档中均有对应 Method 小节。 +func checkGatewayRPCDocConsistency(examplesPath, docPath string) error { + examples, err := readGatewayExamples(examplesPath) + if err != nil { + return err + } + + docContent, err := readGatewayRPCDoc(docPath) + if err != nil { + return err + } + if !containsAnyPathReference(docContent, pathReferenceCandidates(examplesPath)) { + return fmt.Errorf("rpc doc %q must reference generated examples file %q", docPath, examplesPath) + } + + missingSections := collectMissingMethodSections(docContent, collectGatewayMethods(examples)) + if len(missingSections) > 0 { + return fmt.Errorf("rpc doc %q is missing sections for generated methods: %s", docPath, strings.Join(missingSections, ", ")) + } + return nil +} + +// readGatewayExamples 读取并解析生成的示例文件,统一错误包装。 +func readGatewayExamples(examplesPath string) (map[string]json.RawMessage, error) { + rawExamples, err := os.ReadFile(examplesPath) + if err != nil { + return nil, fmt.Errorf("read examples file %q: %w", examplesPath, err) + } + + var examples map[string]json.RawMessage + if err := json.Unmarshal(rawExamples, &examples); err != nil { + return nil, fmt.Errorf("decode examples file %q: %w", examplesPath, err) + } + return examples, nil +} + +// readGatewayRPCDoc 读取 Gateway RPC 主文档内容。 +func readGatewayRPCDoc(docPath string) (string, error) { + rawDoc, err := os.ReadFile(docPath) + if err != nil { + return "", fmt.Errorf("read rpc doc %q: %w", docPath, err) + } + return string(rawDoc), nil +} + +// pathReferenceCandidates 返回示例文件可能出现的文档引用形式,兼容绝对路径与仓库相对路径。 +func pathReferenceCandidates(examplesPath string) []string { + normalizedInput := filepath.ToSlash(examplesPath) + return []string{ + normalizedInput, + filepath.ToSlash(filepath.Join("docs", "generated", filepath.Base(examplesPath))), + } +} + +// containsAnyPathReference 判断文档是否包含任意一个合法引用路径。 +func containsAnyPathReference(content string, candidates []string) bool { + for _, candidate := range candidates { + if strings.Contains(content, candidate) { + return true + } + } + return false +} + +// collectMissingMethodSections 收集文档中缺失的方法小节标题,便于稳定输出错误信息。 +func collectMissingMethodSections(docContent string, methods []string) []string { + missingSections := make([]string, 0) + for _, method := range methods { + heading := "## Method: " + method + if !strings.Contains(docContent, heading) { + missingSections = append(missingSections, heading) + } + } + return missingSections +} + +// collectGatewayMethods 从生成示例键中提取 gateway.* 方法名并排序,便于稳定校验与报错。 +func collectGatewayMethods(examples map[string]json.RawMessage) []string { + methods := make([]string, 0, len(examples)) + for key := range examples { + if strings.HasPrefix(key, "gateway.") { + methods = append(methods, key) + } + } + sort.Strings(methods) + return methods +} diff --git a/scripts/check_gateway_docs/main_test.go b/scripts/check_gateway_docs/main_test.go new file mode 100644 index 00000000..8632e015 --- /dev/null +++ b/scripts/check_gateway_docs/main_test.go @@ -0,0 +1,112 @@ +package main + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +// writeGatewayDocFixtures 写入文档校验测试所需的示例与文档文件。 +func writeGatewayDocFixtures(t *testing.T, examples string, doc string) (string, string) { + t.Helper() + + tempDir := t.TempDir() + examplesPath := filepath.Join(tempDir, "gateway-rpc-examples.json") + docPath := filepath.Join(tempDir, "gateway-rpc-api.md") + if err := os.WriteFile(examplesPath, []byte(examples), 0o644); err != nil { + t.Fatalf("write examples: %v", err) + } + if err := os.WriteFile(docPath, []byte(doc), 0o644); err != nil { + t.Fatalf("write doc: %v", err) + } + return examplesPath, docPath +} + +func TestCheckGatewayRPCDocConsistency(t *testing.T) { + t.Run("success when methods and generated path are in doc", func(t *testing.T) { + examples := `{ + "gateway.bindStream": {}, + "gateway.run": {}, + "common.error": {} +} +` + doc := strings.Join([]string{ + "# Gateway RPC API", + "", + "产物:docs/generated/gateway-rpc-examples.json", + "", + "## Method: gateway.bindStream", + "", + "## Method: gateway.run", + }, "\n") + examplesPath, docPath := writeGatewayDocFixtures(t, examples, doc) + + if err := checkGatewayRPCDocConsistency(examplesPath, docPath); err != nil { + t.Fatalf("checkGatewayRPCDocConsistency() error = %v", err) + } + }) + + t.Run("fails when doc misses generated path reference", func(t *testing.T) { + examples := `{"gateway.run":{}}` + doc := "## Method: gateway.run\n" + examplesPath, docPath := writeGatewayDocFixtures(t, examples, doc) + + err := checkGatewayRPCDocConsistency(examplesPath, docPath) + if err == nil { + t.Fatal("expected generated path reference error") + } + if !strings.Contains(err.Error(), "must reference generated examples file") { + t.Fatalf("error = %v, want contains %q", err, "must reference generated examples file") + } + }) + + t.Run("fails when doc misses method sections", func(t *testing.T) { + examples := `{"gateway.bindStream":{},"gateway.run":{}}` + doc := strings.Join([]string{ + "docs/generated/gateway-rpc-examples.json", + "## Method: gateway.run", + }, "\n") + examplesPath, docPath := writeGatewayDocFixtures(t, examples, doc) + + err := checkGatewayRPCDocConsistency(examplesPath, docPath) + if err == nil { + t.Fatal("expected missing method section error") + } + if !strings.Contains(err.Error(), "## Method: gateway.bindStream") { + t.Fatalf("error = %v, want contains %q", err, "## Method: gateway.bindStream") + } + }) +} + +func TestCollectGatewayMethods(t *testing.T) { + methods := collectGatewayMethods(map[string]json.RawMessage{ + "common.error": nil, + "gateway.run": nil, + "gateway.bindStream": nil, + }) + + want := []string{"gateway.bindStream", "gateway.run"} + if len(methods) != len(want) { + t.Fatalf("len(methods) = %d, want %d", len(methods), len(want)) + } + for index := range want { + if methods[index] != want[index] { + t.Fatalf("methods[%d] = %q, want %q", index, methods[index], want[index]) + } + } +} + +func TestCollectMissingMethodSections(t *testing.T) { + missing := collectMissingMethodSections("## Method: gateway.run", []string{"gateway.bindStream", "gateway.run"}) + want := []string{"## Method: gateway.bindStream"} + if len(missing) != len(want) { + t.Fatalf("len(missing) = %d, want %d", len(missing), len(want)) + } + for index := range want { + if missing[index] != want[index] { + t.Fatalf("missing[%d] = %q, want %q", index, missing[index], want[index]) + } + } +}