Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions internal/config/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,18 @@ func TestGatewayApplyDefaultsNilReceivers(t *testing.T) {
var observabilityCfg *GatewayObservabilityConfig
observabilityCfg.ApplyDefaults(GatewayObservabilityConfig{MetricsEnabled: boolPtr(false)})
}

func TestLoadGatewayConfigReadFileError(t *testing.T) {
t.Parallel()

baseDir := t.TempDir()
configPath := filepath.Join(baseDir, configName)
if err := os.Mkdir(configPath, 0o755); err != nil {
t.Fatalf("mkdir config path: %v", err)
}

_, err := LoadGatewayConfig(context.Background(), baseDir)
if err == nil || !strings.Contains(err.Error(), "read gateway config file") {
t.Fatalf("expected read gateway config file error, got %v", err)
}
}
26 changes: 26 additions & 0 deletions internal/gateway/adapters/urlscheme/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,32 @@ func TestDispatcherAuthenticateBranches(t *testing.T) {
t.Fatalf("expected unexpected auth frame error, got %v", err)
}
})

t.Run("auth version mismatch", func(t *testing.T) {
dispatcher := &Dispatcher{
requestIDFn: func() string { return "wake-auth-4" },
}
conn := &stubDispatchConn{
readBuffer: bytes.NewBufferString(`{"jsonrpc":"1.0","id":"wake-auth-4-auth","result":{}}` + "\n"),
}
err := dispatcher.authenticate(context.Background(), conn, "token-1")
if err == nil || !strings.Contains(err.Error(), "jsonrpc version") {
t.Fatalf("expected auth version mismatch error, got %v", err)
}
})

t.Run("auth id mismatch", func(t *testing.T) {
dispatcher := &Dispatcher{
requestIDFn: func() string { return "wake-auth-5" },
}
conn := &stubDispatchConn{
readBuffer: bytes.NewBufferString(`{"jsonrpc":"2.0","id":"other-auth","result":{}}` + "\n"),
}
err := dispatcher.authenticate(context.Background(), conn, "token-1")
if err == nil || !strings.Contains(err.Error(), "id mismatch") {
t.Fatalf("expected auth id mismatch error, got %v", err)
}
})
}

func TestDispatcherDispatchWithAuthHandshake(t *testing.T) {
Expand Down
10 changes: 10 additions & 0 deletions internal/gateway/auth/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,13 @@ func TestDefaultAuthPathAndLoadOrCreateNilManager(t *testing.T) {
t.Fatal("expected nil manager loadOrCreate error")
}
}

func TestManagerGenerateTokenAndWriteCredentialsErrorBranches(t *testing.T) {
t.Run("write credentials fails on directory path", func(t *testing.T) {
dir := t.TempDir()
err := writeCredentials(dir, Credentials{Version: 1, Token: "token"})
if err == nil || !strings.Contains(err.Error(), "write auth file") {
t.Fatalf("expected write auth file error, got %v", err)
}
})
}
89 changes: 89 additions & 0 deletions internal/gateway/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,92 @@ func TestHandleAuthenticateFrameBranches(t *testing.T) {
}
})
}

type invalidJSONMarshaler struct{}

func (invalidJSONMarshaler) MarshalJSON() ([]byte, error) {
return []byte("{"), nil
}

func TestBootstrapDecodeAdditionalBranches(t *testing.T) {
t.Run("decode authenticate pointer nil", func(t *testing.T) {
_, frameErr := decodeAuthenticatePayload((*protocol.AuthenticateParams)(nil))
if frameErr == nil || frameErr.Code != ErrorCodeMissingRequiredField.String() {
t.Fatalf("frameErr = %#v, want missing_required_field", frameErr)
}
})

t.Run("decode authenticate unmarshal error", func(t *testing.T) {
_, frameErr := decodeAuthenticatePayload(invalidJSONMarshaler{})
if frameErr == nil || frameErr.Code != ErrorCodeInvalidFrame.String() {
t.Fatalf("frameErr = %#v, want invalid_frame", frameErr)
}
})

t.Run("decode bind stream unmarshal error", func(t *testing.T) {
_, frameErr := decodeBindStreamParams(invalidJSONMarshaler{})
if frameErr == nil || frameErr.Code != ErrorCodeInvalidFrame.String() {
t.Fatalf("frameErr = %#v, want invalid_frame", frameErr)
}
})

t.Run("decode bind stream pointer nil", func(t *testing.T) {
_, frameErr := decodeBindStreamParams((*protocol.BindStreamParams)(nil))
if frameErr == nil || frameErr.Code != ErrorCodeInvalidFrame.String() {
t.Fatalf("frameErr = %#v, want invalid_frame", frameErr)
}
})

t.Run("decode bind stream map missing session", func(t *testing.T) {
_, frameErr := decodeBindStreamParams(map[string]any{"run_id": "r-1"})
if frameErr == nil || frameErr.Code != ErrorCodeMissingRequiredField.String() {
t.Fatalf("frameErr = %#v, want missing_required_field", frameErr)
}
})

t.Run("decode bind stream default struct", func(t *testing.T) {
payload := struct {
SessionID string `json:"session_id"`
RunID string `json:"run_id"`
Channel string `json:"channel"`
}{
SessionID: "s-1",
RunID: "r-1",
Channel: "ipc",
}
params, frameErr := decodeBindStreamParams(payload)
if frameErr != nil {
t.Fatalf("frameErr = %v", frameErr)
}
if params.SessionID != "s-1" || params.Channel != StreamChannelIPC {
t.Fatalf("params = %#v", params)
}
})

t.Run("decode wake pointer nil", func(t *testing.T) {
_, err := decodeWakeIntent((*protocol.WakeIntent)(nil))
if err == nil {
t.Fatal("expected nil pointer wake payload error")
}
})

t.Run("decode wake unmarshal error", func(t *testing.T) {
_, err := decodeWakeIntent(invalidJSONMarshaler{})
if err == nil {
t.Fatal("expected wake decode error")
}
})

t.Run("decode wake direct struct", func(t *testing.T) {
intent, err := decodeWakeIntent(protocol.WakeIntent{
Action: "review",
Params: map[string]string{"path": "README.md"},
})
if err != nil {
t.Fatalf("decode wake intent: %v", err)
}
if intent.Action != "review" {
t.Fatalf("action = %q, want review", intent.Action)
}
})
}
82 changes: 82 additions & 0 deletions internal/gateway/network_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,88 @@ func TestNetworkServerVersionAndObservabilityAuthHelpers(t *testing.T) {
})
}

func TestNetworkServerObservabilityHandlerBranches(t *testing.T) {
t.Run("prometheus handler branches", func(t *testing.T) {
server := &NetworkServer{}

recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/metrics", nil)
server.handlePrometheusMetrics(recorder, request)
if recorder.Code != http.StatusMethodNotAllowed {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusMethodNotAllowed)
}

recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/metrics", nil)
server.handlePrometheusMetrics(recorder, request)
if recorder.Code != http.StatusServiceUnavailable {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusServiceUnavailable)
}

server.metrics = NewGatewayMetrics()
server.authenticator = staticTokenAuthenticator{token: "token-1"}
recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/metrics", nil)
server.handlePrometheusMetrics(recorder, request)
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnauthorized)
}

recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodGet, "/metrics", nil)
request.Header.Set("Authorization", "Bearer token-1")
server.handlePrometheusMetrics(recorder, request)
if recorder.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK)
}
})

t.Run("json metrics unauthorized", func(t *testing.T) {
server := &NetworkServer{
metrics: NewGatewayMetrics(),
authenticator: staticTokenAuthenticator{token: "token-2"},
allowedOrigins: defaultControlPlaneOrigins(),
}
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/metrics.json", nil)
server.handleJSONMetrics(recorder, request)
if recorder.Code != http.StatusUnauthorized {
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnauthorized)
}
})
}

func TestNetworkServerOriginAndTokenHelpersAdditional(t *testing.T) {
normalized := normalizeControlPlaneOrigins([]string{" HTTP://LOCALHOST ", "", " app://desktop "})
if len(normalized) != 2 {
t.Fatalf("normalized len = %d, want 2", len(normalized))
}
if normalized[0] != "http://localhost" || normalized[1] != "app://desktop" {
t.Fatalf("normalized = %#v", normalized)
}

if originMatchesAllowRule("http://localhost:3000", "") {
t.Fatal("blank allow rule should reject")
}
if !originMatchesAllowRule("http://[::1]:8080", "http://[::1]") {
t.Fatal("ipv6 localhost with port should match")
}

server := &NetworkServer{allowedOrigins: []string{"http://localhost"}}
if err := server.validateWebSocketOrigin(nil); err == nil {
t.Fatal("nil request should be rejected")
}
request := httptest.NewRequest(http.MethodGet, "/ws", nil)
request.Header.Set("Origin", "http://evil.example")
if err := server.validateWebSocketOrigin(request); err == nil {
t.Fatal("disallowed origin should be rejected")
}

if token := extractBearerToken("Basic abc"); token != "" {
t.Fatalf("token = %q, want empty", token)
}
}

func TestNetworkServerCloseInterruptsStreams(t *testing.T) {
server := newTestNetworkServer(t, NetworkServerOptions{})
testContext, cancel := context.WithCancel(context.Background())
Expand Down
26 changes: 26 additions & 0 deletions internal/gateway/protocol/jsonrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,29 @@ func TestNewJSONRPCErrorResponseWithNilIDEncodesNull(t *testing.T) {
t.Fatalf("encoded response id = %#v, want nil", payload["id"])
}
}

func TestJSONRPCHelpersAdditionalBranches(t *testing.T) {
if got := MapGatewayCodeToJSONRPCCode(GatewayCodeInternalError); got != JSONRPCCodeInternalError {
t.Fatalf("internal mapping = %d, want %d", got, JSONRPCCodeInternalError)
}
if got := MapGatewayCodeToJSONRPCCode("unknown-code"); got != JSONRPCCodeInternalError {
t.Fatalf("default mapping = %d, want %d", got, JSONRPCCodeInternalError)
}

if _, err := normalizeJSONRPCID(json.RawMessage(`null`)); err == nil {
t.Fatal("expected null id error")
}
if _, err := normalizeJSONRPCID(json.RawMessage(`" "`)); err == nil {
t.Fatal("expected blank string id error")
}

if _, err := decodeAuthenticateParams(json.RawMessage(`null`)); err == nil {
t.Fatal("expected missing authenticate params error")
}
if _, err := decodeBindStreamParams(json.RawMessage(`null`)); err == nil {
t.Fatal("expected missing bind stream params error")
}
if _, err := decodeWakeIntentParams(json.RawMessage(`null`)); err == nil {
t.Fatal("expected missing wake params error")
}
}
19 changes: 19 additions & 0 deletions internal/gateway/request_logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,22 @@ func TestRequestLatencyMS(t *testing.T) {
t.Fatal("requestStartTime should not return zero time")
}
}

func TestEmitRequestLogUsesContextSource(t *testing.T) {
t.Parallel()

buffer := &bytes.Buffer{}
logger := log.New(buffer, "", 0)
ctx := WithRequestSource(context.Background(), RequestSourceHTTP)

emitRequestLog(ctx, logger, RequestLogEntry{
RequestID: "req-http",
Method: "gateway.ping",
Status: "ok",
})

output := buffer.String()
if !strings.Contains(output, `"source":"http"`) {
t.Fatalf("output = %q, want http source", output)
}
}
34 changes: 34 additions & 0 deletions internal/gateway/rpc_dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,37 @@ func TestDispatchRPCRequestMetricsBranches(t *testing.T) {
t.Fatalf("expected ok request metric, snapshot=%#v", snapshot["gateway_requests_total"])
}
}

func TestDispatchRPCRequestFrameErrorWithoutPayload(t *testing.T) {
originalHandlers := requestFrameHandlers
requestFrameHandlers = map[FrameAction]requestFrameHandler{
FrameActionPing: func(_ context.Context, frame MessageFrame) MessageFrame {
return MessageFrame{Type: FrameTypeError, Action: frame.Action, RequestID: frame.RequestID}
},
}
t.Cleanup(func() { requestFrameHandlers = originalHandlers })

response := dispatchRPCRequest(context.Background(), protocol.JSONRPCRequest{
JSONRPC: protocol.JSONRPCVersion,
ID: json.RawMessage(`"rpc-noerr-1"`),
Method: protocol.MethodGatewayPing,
Params: json.RawMessage(`{}`),
}, nil)
if response.Error == nil {
t.Fatal("expected rpc error response")
}
if code := protocol.GatewayCodeFromJSONRPCError(response.Error); code != ErrorCodeInternalError.String() {
t.Fatalf("gateway_code = %q, want %q", code, ErrorCodeInternalError.String())
}
}

func TestHydrateFrameSessionFromPayloadBranch(t *testing.T) {
frame := hydrateFrameSessionFromConnection(context.Background(), MessageFrame{
Type: FrameTypeRequest,
Action: FrameActionPing,
Payload: map[string]any{"session_id": "s-from-payload"},
})
if frame.SessionID != "s-from-payload" {
t.Fatalf("session_id = %q, want %q", frame.SessionID, "s-from-payload")
}
}
46 changes: 46 additions & 0 deletions internal/gateway/stream_relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,3 +485,49 @@ func registerConnectionForRelayTest(
t.Fatalf("bind connection: %v", bindErr)
}
}

func TestStreamRelayAdditionalCoverageBranches(t *testing.T) {
t.Run("snapshot counts skips nil connection entry", func(t *testing.T) {
relay := NewStreamRelay(StreamRelayOptions{})
relay.mu.Lock()
relay.connections[ConnectionID("nil-entry")] = nil
relay.mu.Unlock()

snapshot := relay.SnapshotConnectionCounts()
if snapshot[StreamChannelIPC] != 0 || snapshot[StreamChannelWS] != 0 || snapshot[StreamChannelSSE] != 0 {
t.Fatalf("unexpected snapshot: %#v", snapshot)
}
})

t.Run("send sync response branches", func(t *testing.T) {
var nilRelay *StreamRelay
if nilRelay.SendJSONRPCResponseSync("cid", protocol.JSONRPCResponse{}) {
t.Fatal("nil relay should return false")
}

relay := NewStreamRelay(StreamRelayOptions{})
if relay.SendJSONRPCResponseSync("", protocol.JSONRPCResponse{}) {
t.Fatal("empty connection id should return false")
}
if relay.SendJSONRPCResponseSync("missing", protocol.JSONRPCResponse{}) {
t.Fatal("missing connection should return false")
}
})

t.Run("update active connection metrics", func(t *testing.T) {
metrics := NewGatewayMetrics()
relay := NewStreamRelay(StreamRelayOptions{Metrics: metrics})
relay.mu.Lock()
relay.connections[ConnectionID("ipc")] = &relayConnection{channel: StreamChannelIPC}
relay.connections[ConnectionID("ws")] = &relayConnection{channel: StreamChannelWS}
relay.connections[ConnectionID("sse")] = &relayConnection{channel: StreamChannelSSE}
relay.connections[ConnectionID("nil")] = nil
relay.updateActiveConnectionMetricsLocked()
relay.mu.Unlock()

entries := metrics.Snapshot()["gateway_connections_active"]
if entries["ipc"] != 1 || entries["ws"] != 1 || entries["sse"] != 1 {
t.Fatalf("unexpected metrics snapshot: %#v", entries)
}
})
}
Loading
Loading