diff --git a/cmd/asb-api/observability.go b/cmd/asb-api/observability.go index d93820e..8a8516c 100644 --- a/cmd/asb-api/observability.go +++ b/cmd/asb-api/observability.go @@ -11,10 +11,10 @@ import ( ) func newObservedHandler(logger *slog.Logger, metrics *observability.Metrics, next http.Handler) http.Handler { - observed := httpkit.WithRequestID(observability.RequestLoggingMiddleware(logger, metrics)(next)) if metrics == nil { - return observed + return httpkit.WithRequestID(next) } + observed := httpkit.WithRequestID(observability.RequestLoggingMiddleware(logger, metrics)(next)) metricsHandler := metrics.Handler() return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/cmd/asb-api/observability_test.go b/cmd/asb-api/observability_test.go index f9c85b5..6e0e3ae 100644 --- a/cmd/asb-api/observability_test.go +++ b/cmd/asb-api/observability_test.go @@ -91,6 +91,26 @@ func TestNewObservedHandlerRecordsRequestMetrics(t *testing.T) { } } +func TestNewObservedHandlerAllowsNilMetrics(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/v1/test", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + + handler := newObservedHandler(discardLogger(), nil, mux) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/v1/test", nil)) + + if recorder.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusNoContent) + } + if got := recorder.Header().Get("X-Request-Id"); got == "" { + t.Fatal("expected X-Request-Id response header") + } +} + func TestRegisterRuntimeMetricsRegistersDBStats(t *testing.T) { t.Parallel()