diff --git a/docs/swagger/docs.go b/docs/swagger/docs.go index 49a8d4108..5f4e4ed2f 100644 --- a/docs/swagger/docs.go +++ b/docs/swagger/docs.go @@ -8680,12 +8680,38 @@ const docTemplate = `{ "domain.GatewayRoute": { "type": "object", "properties": { + "allowed_cidrs": { + "description": "IPs allowed to access (empty = all)", + "type": "array", + "items": { + "type": "string" + } + }, + "blocked_cidrs": { + "description": "IPs blocked from access", + "type": "array", + "items": { + "type": "string" + } + }, "created_at": { "type": "string" }, + "dial_timeout": { + "description": "TCP dial timeout in milliseconds", + "type": "integer" + }, "id": { "type": "string" }, + "idle_conn_timeout": { + "description": "Idle connection timeout in milliseconds", + "type": "integer" + }, + "max_body_size": { + "description": "Max request body size in bytes", + "type": "integer" + }, "methods": { "description": "New: HTTP methods to match (empty = all)", "type": "array", @@ -8723,6 +8749,14 @@ const docTemplate = `{ "description": "Maximum allowed requests per second per IP", "type": "integer" }, + "require_tls": { + "description": "Force HTTPS for backend", + "type": "boolean" + }, + "response_header_timeout": { + "description": "Time to receive headers in milliseconds", + "type": "integer" + }, "strip_prefix": { "description": "If true, removes path_prefix from request before forwarding", "type": "boolean" @@ -8734,6 +8768,10 @@ const docTemplate = `{ "tenant_id": { "type": "string" }, + "tls_skip_verify": { + "description": "Skip TLS verification for backend", + "type": "boolean" + }, "updated_at": { "type": "string" }, @@ -11036,6 +11074,27 @@ const docTemplate = `{ "target_url" ], "properties": { + "allowed_cidrs": { + "type": "array", + "items": { + "type": "string" + } + }, + "blocked_cidrs": { + "type": "array", + "items": { + "type": "string" + } + }, + "dial_timeout": { + "type": "integer" + }, + "idle_conn_timeout": { + "type": "integer" + }, + "max_body_size": { + "type": "integer" + }, "methods": { "type": "array", "items": { @@ -11054,11 +11113,20 @@ const docTemplate = `{ "rate_limit": { "type": "integer" }, + "require_tls": { + "type": "boolean" + }, + "response_header_timeout": { + "type": "integer" + }, "strip_prefix": { "type": "boolean" }, "target_url": { "type": "string" + }, + "tls_skip_verify": { + "type": "boolean" } } }, diff --git a/docs/swagger/swagger.json b/docs/swagger/swagger.json index a4e653353..10a8e6945 100644 --- a/docs/swagger/swagger.json +++ b/docs/swagger/swagger.json @@ -8672,12 +8672,38 @@ "domain.GatewayRoute": { "type": "object", "properties": { + "allowed_cidrs": { + "description": "IPs allowed to access (empty = all)", + "type": "array", + "items": { + "type": "string" + } + }, + "blocked_cidrs": { + "description": "IPs blocked from access", + "type": "array", + "items": { + "type": "string" + } + }, "created_at": { "type": "string" }, + "dial_timeout": { + "description": "TCP dial timeout in milliseconds", + "type": "integer" + }, "id": { "type": "string" }, + "idle_conn_timeout": { + "description": "Idle connection timeout in milliseconds", + "type": "integer" + }, + "max_body_size": { + "description": "Max request body size in bytes", + "type": "integer" + }, "methods": { "description": "New: HTTP methods to match (empty = all)", "type": "array", @@ -8715,6 +8741,14 @@ "description": "Maximum allowed requests per second per IP", "type": "integer" }, + "require_tls": { + "description": "Force HTTPS for backend", + "type": "boolean" + }, + "response_header_timeout": { + "description": "Time to receive headers in milliseconds", + "type": "integer" + }, "strip_prefix": { "description": "If true, removes path_prefix from request before forwarding", "type": "boolean" @@ -8726,6 +8760,10 @@ "tenant_id": { "type": "string" }, + "tls_skip_verify": { + "description": "Skip TLS verification for backend", + "type": "boolean" + }, "updated_at": { "type": "string" }, @@ -11028,6 +11066,27 @@ "target_url" ], "properties": { + "allowed_cidrs": { + "type": "array", + "items": { + "type": "string" + } + }, + "blocked_cidrs": { + "type": "array", + "items": { + "type": "string" + } + }, + "dial_timeout": { + "type": "integer" + }, + "idle_conn_timeout": { + "type": "integer" + }, + "max_body_size": { + "type": "integer" + }, "methods": { "type": "array", "items": { @@ -11046,11 +11105,20 @@ "rate_limit": { "type": "integer" }, + "require_tls": { + "type": "boolean" + }, + "response_header_timeout": { + "type": "integer" + }, "strip_prefix": { "type": "boolean" }, "target_url": { "type": "string" + }, + "tls_skip_verify": { + "type": "boolean" } } }, diff --git a/docs/swagger/swagger.yaml b/docs/swagger/swagger.yaml index 5dba3664b..4b903592e 100644 --- a/docs/swagger/swagger.yaml +++ b/docs/swagger/swagger.yaml @@ -401,10 +401,29 @@ definitions: type: object domain.GatewayRoute: properties: + allowed_cidrs: + description: IPs allowed to access (empty = all) + items: + type: string + type: array + blocked_cidrs: + description: IPs blocked from access + items: + type: string + type: array created_at: type: string + dial_timeout: + description: TCP dial timeout in milliseconds + type: integer id: type: string + idle_conn_timeout: + description: Idle connection timeout in milliseconds + type: integer + max_body_size: + description: Max request body size in bytes + type: integer methods: description: 'New: HTTP methods to match (empty = all)' items: @@ -432,6 +451,12 @@ definitions: rate_limit: description: Maximum allowed requests per second per IP type: integer + require_tls: + description: Force HTTPS for backend + type: boolean + response_header_timeout: + description: Time to receive headers in milliseconds + type: integer strip_prefix: description: If true, removes path_prefix from request before forwarding type: boolean @@ -440,6 +465,9 @@ definitions: type: string tenant_id: type: string + tls_skip_verify: + description: Skip TLS verification for backend + type: boolean updated_at: type: string user_id: @@ -2099,6 +2127,20 @@ definitions: type: object httphandlers.CreateRouteRequest: properties: + allowed_cidrs: + items: + type: string + type: array + blocked_cidrs: + items: + type: string + type: array + dial_timeout: + type: integer + idle_conn_timeout: + type: integer + max_body_size: + type: integer methods: items: type: string @@ -2111,10 +2153,16 @@ definitions: type: integer rate_limit: type: integer + require_tls: + type: boolean + response_header_timeout: + type: integer strip_prefix: type: boolean target_url: type: string + tls_skip_verify: + type: boolean required: - name - path_prefix diff --git a/internal/api/setup/router.go b/internal/api/setup/router.go index 86552be61..423650e43 100644 --- a/internal/api/setup/router.go +++ b/internal/api/setup/router.go @@ -106,7 +106,7 @@ func InitHandlers(svcs *Services, cfg *platform.Config, logger *slog.Logger) *Ha Queue: httphandlers.NewQueueHandler(svcs.Queue), Notify: httphandlers.NewNotifyHandler(svcs.Notify), Cron: httphandlers.NewCronHandler(svcs.Cron), - Gateway: httphandlers.NewGatewayHandler(svcs.Gateway), + Gateway: httphandlers.NewGatewayHandler(svcs.Gateway, logger), Container: httphandlers.NewContainerHandler(svcs.Container), Pipeline: httphandlers.NewPipelineHandler(svcs.Pipeline), Health: httphandlers.NewHealthHandler(svcs.Health), diff --git a/internal/core/domain/gateway.go b/internal/core/domain/gateway.go index ec3fd3fbd..dc6b76467 100644 --- a/internal/core/domain/gateway.go +++ b/internal/core/domain/gateway.go @@ -9,21 +9,29 @@ import ( // GatewayRoute defines an ingress rule for mapping external HTTP traffic to internal resources. type GatewayRoute struct { - ID uuid.UUID `json:"id"` - UserID uuid.UUID `json:"user_id"` - TenantID uuid.UUID `json:"tenant_id"` - Name string `json:"name"` - PathPrefix string `json:"path_prefix"` // Legacy: Request path to match (e.g., "/api/v1") - PathPattern string `json:"path_pattern"` // New: Pattern with {params} - PatternType string `json:"pattern_type"` // "prefix" or "pattern" - ParamNames []string `json:"param_names"` // Extracted parameter names - TargetURL string `json:"target_url"` // Internal destination (e.g., "http://service-a:8080") - Methods []string `json:"methods"` // New: HTTP methods to match (empty = all) - StripPrefix bool `json:"strip_prefix"` // If true, removes path_prefix from request before forwarding - RateLimit int `json:"rate_limit"` // Maximum allowed requests per second per IP - Priority int `json:"priority"` // Manual priority for tie-breaking - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + TenantID uuid.UUID `json:"tenant_id"` + Name string `json:"name"` + PathPrefix string `json:"path_prefix"` // Legacy: Request path to match (e.g., "/api/v1") + PathPattern string `json:"path_pattern"` // New: Pattern with {params} + PatternType string `json:"pattern_type"` // "prefix" or "pattern" + ParamNames []string `json:"param_names"` // Extracted parameter names + TargetURL string `json:"target_url"` // Internal destination (e.g., "http://service-a:8080") + Methods []string `json:"methods"` // New: HTTP methods to match (empty = all) + StripPrefix bool `json:"strip_prefix"` // If true, removes path_prefix from request before forwarding + RateLimit int `json:"rate_limit"` // Maximum allowed requests per second per IP + DialTimeout int64 `json:"dial_timeout,omitempty"` // TCP dial timeout in milliseconds + ResponseHeaderTimeout int64 `json:"response_header_timeout,omitempty"` // Time to receive headers in milliseconds + IdleConnTimeout int64 `json:"idle_conn_timeout,omitempty"` // Idle connection timeout in milliseconds + TLSSkipVerify bool `json:"tls_skip_verify,omitempty"` // Skip TLS verification for backend + RequireTLS bool `json:"require_tls,omitempty"` // Force HTTPS for backend + AllowedCIDRs []string `json:"allowed_cidrs,omitempty"` // IPs allowed to access (empty = all) + BlockedCIDRs []string `json:"blocked_cidrs,omitempty"` // IPs blocked from access + MaxBodySize int64 `json:"max_body_size,omitempty"` // Max request body size in bytes + Priority int `json:"priority"` // Manual priority for tie-breaking + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // RouteMatch represents a successful route pattern match. diff --git a/internal/core/ports/gateway.go b/internal/core/ports/gateway.go index eafe375ea..43e1643dd 100644 --- a/internal/core/ports/gateway.go +++ b/internal/core/ports/gateway.go @@ -28,13 +28,21 @@ type GatewayRepository interface { // CreateRouteParams holds parameters for creating a new route. type CreateRouteParams struct { - Name string - Pattern string - Target string - Methods []string - StripPrefix bool - RateLimit int - Priority int + Name string + Pattern string + Target string + Methods []string + StripPrefix bool + RateLimit int + DialTimeout int64 + ResponseHeaderTimeout int64 + IdleConnTimeout int64 + TLSSkipVerify bool + RequireTLS bool + AllowedCIDRs []string + BlockedCIDRs []string + MaxBodySize int64 + Priority int } // GatewayService provides business logic for managing the API gateway and ingress traffic. @@ -48,5 +56,6 @@ type GatewayService interface { // RefreshRoutes reloads all routes and pre-compiles matchers. RefreshRoutes(ctx context.Context) error // GetProxy finds the appropriate backend for the given path and method. - GetProxy(method, path string) (*httputil.ReverseProxy, map[string]string, bool) + // Returns proxy, route, path params, and found flag. + GetProxy(method, path string) (*httputil.ReverseProxy, *domain.GatewayRoute, map[string]string, bool) } diff --git a/internal/core/services/gateway.go b/internal/core/services/gateway.go index dfdbf8598..fdad56b5f 100644 --- a/internal/core/services/gateway.go +++ b/internal/core/services/gateway.go @@ -3,8 +3,10 @@ package services import ( "context" + "crypto/tls" "fmt" "log/slog" + "net" "net/http" "net/http/httputil" "net/url" @@ -44,7 +46,9 @@ func NewGatewayService(repo ports.GatewayRepository, rbacSvc ports.RBACService, logger: logger, } // Initial load - _ = s.RefreshRoutes(context.Background()) + if err := s.RefreshRoutes(context.Background()); err != nil { + s.logger.Error("failed to refresh routes on startup", "error", err) + } return s } @@ -69,21 +73,29 @@ func (s *GatewayService) CreateRoute(ctx context.Context, params ports.CreateRou } route := &domain.GatewayRoute{ - ID: uuid.New(), - UserID: userID, - TenantID: tenantID, - Name: params.Name, - PathPrefix: params.Pattern, // Use pattern as prefix for backward compatibility where possible - PathPattern: params.Pattern, - PatternType: patternType, - ParamNames: paramNames, - TargetURL: params.Target, - Methods: params.Methods, - StripPrefix: params.StripPrefix, - RateLimit: params.RateLimit, - Priority: params.Priority, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), + ID: uuid.New(), + UserID: userID, + TenantID: tenantID, + Name: params.Name, + PathPrefix: params.Pattern, + PathPattern: params.Pattern, + PatternType: patternType, + ParamNames: paramNames, + TargetURL: params.Target, + Methods: params.Methods, + StripPrefix: params.StripPrefix, + RateLimit: params.RateLimit, + DialTimeout: params.DialTimeout, + ResponseHeaderTimeout: params.ResponseHeaderTimeout, + IdleConnTimeout: params.IdleConnTimeout, + TLSSkipVerify: params.TLSSkipVerify, + RequireTLS: params.RequireTLS, + AllowedCIDRs: params.AllowedCIDRs, + BlockedCIDRs: params.BlockedCIDRs, + MaxBodySize: params.MaxBodySize, + Priority: params.Priority, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), } if err := s.repo.CreateRoute(ctx, route); err != nil { @@ -98,7 +110,9 @@ func (s *GatewayService) CreateRoute(ctx context.Context, params ports.CreateRou s.logger.Warn("failed to log audit event", "action", "gateway.route_create", "route_id", route.ID, "error", err) } - _ = s.RefreshRoutes(ctx) + if err := s.RefreshRoutes(ctx); err != nil { + s.logger.Warn("failed to refresh routes after create", "route_id", route.ID, "error", err) + } return route, nil } @@ -150,6 +164,7 @@ func (s *GatewayService) RefreshRoutes(ctx context.Context) error { for _, r := range routes { proxy, err := s.createReverseProxy(r) if err != nil { + s.logger.Error("failed to create reverse proxy for route", "route_id", r.ID, "route_name", r.Name, "target_url", r.TargetURL, "error", err) continue } @@ -180,6 +195,32 @@ func (s *GatewayService) createReverseProxy(route *domain.GatewayRoute) (*httput } proxy := httputil.NewSingleHostReverseProxy(target) + + // Configure custom transport with timeouts and TLS + dialTimeout := time.Duration(route.DialTimeout) * time.Millisecond + if dialTimeout <= 0 { + dialTimeout = 5 * time.Second + } + responseHeaderTimeout := time.Duration(route.ResponseHeaderTimeout) * time.Millisecond + if responseHeaderTimeout <= 0 { + responseHeaderTimeout = 30 * time.Second + } + idleConnTimeout := time.Duration(route.IdleConnTimeout) * time.Millisecond + if idleConnTimeout <= 0 { + idleConnTimeout = 90 * time.Second + } + + proxy.Transport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: dialTimeout, + KeepAlive: 30 * time.Second, + }).DialContext, + ResponseHeaderTimeout: responseHeaderTimeout, + IdleConnTimeout: idleConnTimeout, + TLSClientConfig: s.buildTLSConfig(route), + TLSHandshakeTimeout: 10 * time.Second, + } + originalDirector := proxy.Director proxy.Director = func(req *http.Request) { if route.StripPrefix { @@ -199,6 +240,18 @@ func (s *GatewayService) createReverseProxy(route *domain.GatewayRoute) (*httput return proxy, nil } +func (s *GatewayService) buildTLSConfig(route *domain.GatewayRoute) *tls.Config { + cfg := &tls.Config{ + InsecureSkipVerify: route.TLSSkipVerify, //nolint:gosec // User-controlled option for development/testing + } + // Always set baseline TLS 1.2, raise to 1.3 if RequireTLS + cfg.MinVersion = tls.VersionTLS12 + if route.RequireTLS { + cfg.MinVersion = tls.VersionTLS13 + } + return cfg +} + func (s *GatewayService) sortRoutes(routes []*domain.GatewayRoute) { // Sort routes by specificity (longer literal prefixes and higher priority first) sort.Slice(routes, func(i, j int) bool { @@ -210,7 +263,7 @@ func (s *GatewayService) sortRoutes(routes []*domain.GatewayRoute) { // ProxyHandler is handled in the API layer for now -func (s *GatewayService) GetProxy(method, path string) (*httputil.ReverseProxy, map[string]string, bool) { +func (s *GatewayService) GetProxy(method, path string) (*httputil.ReverseProxy, *domain.GatewayRoute, map[string]string, bool) { s.proxyMu.RLock() defer s.proxyMu.RUnlock() @@ -226,10 +279,10 @@ func (s *GatewayService) GetProxy(method, path string) (*httputil.ReverseProxy, } if bestMatch != nil { - return s.proxies[bestMatch.Route.ID], bestMatch.Params, true + return s.proxies[bestMatch.Route.ID], bestMatch.Route, bestMatch.Params, true } - return nil, nil, false + return nil, nil, nil, false } func (s *GatewayService) checkRouteMatch(route *domain.GatewayRoute, method, path string) *domain.RouteMatch { diff --git a/internal/handlers/gateway_handler.go b/internal/handlers/gateway_handler.go index 1f4b45382..b92a42e06 100644 --- a/internal/handlers/gateway_handler.go +++ b/internal/handlers/gateway_handler.go @@ -2,11 +2,18 @@ package httphandlers import ( + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "log/slog" + "net" "net/http" "strings" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/poyrazk/thecloud/internal/core/domain" "github.com/poyrazk/thecloud/internal/core/ports" "github.com/poyrazk/thecloud/internal/errors" "github.com/poyrazk/thecloud/pkg/httputil" @@ -14,23 +21,32 @@ import ( // CreateRouteRequest define the payload for creating a route. type CreateRouteRequest struct { - Name string `json:"name" binding:"required"` - PathPrefix string `json:"path_prefix" binding:"required"` - TargetURL string `json:"target_url" binding:"required"` - Methods []string `json:"methods"` - StripPrefix bool `json:"strip_prefix"` - RateLimit int `json:"rate_limit"` - Priority int `json:"priority"` + Name string `json:"name" binding:"required"` + PathPrefix string `json:"path_prefix" binding:"required"` + TargetURL string `json:"target_url" binding:"required"` + Methods []string `json:"methods"` + StripPrefix bool `json:"strip_prefix"` + RateLimit int `json:"rate_limit"` + DialTimeout int64 `json:"dial_timeout"` + ResponseHeaderTimeout int64 `json:"response_header_timeout"` + IdleConnTimeout int64 `json:"idle_conn_timeout"` + TLSSkipVerify bool `json:"tls_skip_verify"` + RequireTLS bool `json:"require_tls"` + AllowedCIDRs []string `json:"allowed_cidrs"` + BlockedCIDRs []string `json:"blocked_cidrs"` + MaxBodySize int64 `json:"max_body_size"` + Priority int `json:"priority"` } // GatewayHandler handles API gateway HTTP endpoints. type GatewayHandler struct { - svc ports.GatewayService + svc ports.GatewayService + logger *slog.Logger } // NewGatewayHandler constructs a GatewayHandler. -func NewGatewayHandler(svc ports.GatewayService) *GatewayHandler { - return &GatewayHandler{svc: svc} +func NewGatewayHandler(svc ports.GatewayService, logger *slog.Logger) *GatewayHandler { + return &GatewayHandler{svc: svc, logger: logger} } // CreateRoute establishes a new ingress mapping @@ -56,14 +72,28 @@ func (h *GatewayHandler) CreateRoute(c *gin.Context) { req.RateLimit = 100 } + // Validate TLS settings + if req.RequireTLS && req.TLSSkipVerify { + httputil.Error(c, errors.New(errors.InvalidInput, "cannot set both require_tls and tls_skip_verify")) + return + } + params := ports.CreateRouteParams{ - Name: req.Name, - Pattern: req.PathPrefix, - Target: req.TargetURL, - Methods: req.Methods, - StripPrefix: req.StripPrefix, - RateLimit: req.RateLimit, - Priority: req.Priority, + Name: req.Name, + Pattern: req.PathPrefix, + Target: req.TargetURL, + Methods: req.Methods, + StripPrefix: req.StripPrefix, + RateLimit: req.RateLimit, + DialTimeout: req.DialTimeout, + ResponseHeaderTimeout: req.ResponseHeaderTimeout, + IdleConnTimeout: req.IdleConnTimeout, + TLSSkipVerify: req.TLSSkipVerify, + RequireTLS: req.RequireTLS, + AllowedCIDRs: req.AllowedCIDRs, + BlockedCIDRs: req.BlockedCIDRs, + MaxBodySize: req.MaxBodySize, + Priority: req.Priority, } route, err := h.svc.CreateRoute(c.Request.Context(), params) @@ -125,12 +155,32 @@ func (h *GatewayHandler) Proxy(c *gin.Context) { path = "/" + path } - proxy, params, ok := h.svc.GetProxy(c.Request.Method, path) + proxy, route, params, ok := h.svc.GetProxy(c.Request.Method, path) if !ok { - c.JSON(http.StatusNotFound, gin.H{"error": "No route found for " + path}) + httputil.Error(c, errors.New(errors.NotFound, "No route found for "+path)) + return + } + + // Apply IP allowlist/denylist (nil route means no route-specific rules apply) + if route != nil && !h.checkCIDR(c, route) { return } + // Apply request size limit - reject oversized requests before proxying + if route != nil && route.MaxBodySize > 0 { + if c.Request.ContentLength > route.MaxBodySize { + httputil.Error(c, errors.New(errors.InvalidInput, "request body too large")) + return + } + // For chunked bodies, pre-read and enforce limit + if c.Request.ContentLength < 0 { + c.Request.Body = &limitedReader{ + ReadCloser: c.Request.Body, + limit: route.MaxBodySize, + } + } + } + // Inject parameters into request context for downstream services if needed if len(params) > 0 { for k, v := range params { @@ -138,5 +188,133 @@ func (h *GatewayHandler) Proxy(c *gin.Context) { } } + // Inject trace headers + h.injectTraceHeaders(c) + proxy.ServeHTTP(c.Writer, c.Request) } + +func (h *GatewayHandler) injectTraceHeaders(c *gin.Context) { + requestID := c.GetHeader("X-Request-ID") + if requestID == "" { + requestID = uuid.New().String() + } + c.Request.Header.Set("X-Request-ID", requestID) + c.Header("X-Request-ID", requestID) + + // W3C TraceContext - preserve incoming trace headers if present + inboundTraceParent := c.GetHeader("traceparent") + if inboundTraceParent != "" { + c.Request.Header.Set("traceparent", inboundTraceParent) + c.Header("traceparent", inboundTraceParent) + inboundTraceState := c.GetHeader("tracestate") + if inboundTraceState != "" { + c.Request.Header.Set("tracestate", inboundTraceState) + c.Header("tracestate", inboundTraceState) + } + return + } + + // No inbound traceparent - generate new trace context + traceID := generateTraceID() + spanID := generateSpanID() + c.Request.Header.Set("traceparent", fmt.Sprintf("00-%s-%s-01", traceID, spanID)) + c.Request.Header.Set("tracestate", "") + c.Header("traceparent", fmt.Sprintf("00-%s-%s-01", traceID, spanID)) + c.Header("tracestate", "") +} + +func generateTraceID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + // crypto/rand.Read rarely fails, but handle it gracefully + return uuid.New().String() + } + return hex.EncodeToString(b) +} + +func generateSpanID() string { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return uuid.New().String()[:16] + } + return hex.EncodeToString(b) +} + +func (h *GatewayHandler) checkCIDR(c *gin.Context, route *domain.GatewayRoute) bool { + clientIP := net.ParseIP(c.ClientIP()) + if clientIP == nil { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "access denied: invalid client IP"}) + return false + } + + // Check blocked CIDRs first (takes precedence) + for _, cidrStr := range route.BlockedCIDRs { + _, ipNet, err := net.ParseCIDR(cidrStr) + if err != nil { + if h.logger != nil { + h.logger.Warn("invalid blocked CIDR", "cidr", cidrStr, "error", err) + } + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "access denied: misconfigured blocked CIDR"}) + return false + } + if ipNet.Contains(clientIP) { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "access denied"}) + return false + } + } + + // If allowlist is non-empty, only allow matched IPs + if len(route.AllowedCIDRs) > 0 { + allowed := false + for _, cidrStr := range route.AllowedCIDRs { + _, ipNet, err := net.ParseCIDR(cidrStr) + if err != nil { + if h.logger != nil { + h.logger.Warn("invalid allowed CIDR", "cidr", cidrStr, "error", err) + } + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "access denied: misconfigured allowed CIDR"}) + return false + } + if ipNet.Contains(clientIP) { + allowed = true + break + } + } + if !allowed { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "access denied"}) + return false + } + } + + return true +} + +// limitedReader wraps an io.ReadCloser and enforces a byte limit. +type limitedReader struct { + io.ReadCloser + limit int64 + read int64 +} + +// Read enforces the byte limit. When the limit is reached, io.EOF is returned +// even if the underlying reader returned an error (error shadowing for limit enforcement). +func (l *limitedReader) Read(p []byte) (n int, err error) { + if l.read >= l.limit { + return 0, io.EOF + } + toRead := l.limit - l.read + if int64(len(p)) > toRead { + p = p[:toRead] + } + n, err = l.ReadCloser.Read(p) + l.read += int64(n) + if l.read >= l.limit && err == nil { + err = io.EOF + } + return +} + +func (l *limitedReader) Close() error { + return l.ReadCloser.Close() +} diff --git a/internal/handlers/gateway_handler_cidr_test.go b/internal/handlers/gateway_handler_cidr_test.go new file mode 100644 index 000000000..c89025c26 --- /dev/null +++ b/internal/handlers/gateway_handler_cidr_test.go @@ -0,0 +1,66 @@ +package httphandlers + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/stretchr/testify/assert" +) + +func TestCheckCIDR(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + remoteAddr string + blockedCIDRs []string + allowedCIDRs []string + expectedResult bool + }{ + { + name: "no restrictions should allow all", + remoteAddr: "10.0.0.1:12345", + blockedCIDRs: []string{}, + allowedCIDRs: []string{}, + expectedResult: true, + }, + { + name: "empty CIDR lists should allow all", + remoteAddr: "10.0.0.1:12345", + blockedCIDRs: nil, + allowedCIDRs: nil, + expectedResult: true, + }, + { + name: "invalid IP should be denied (fail closed)", + remoteAddr: "invalid-ip:12345", + blockedCIDRs: nil, + allowedCIDRs: nil, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + gin.SetMode(gin.TestMode) + handler := &GatewayHandler{logger: nil} + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + if tt.remoteAddr != "" { + c.Request.RemoteAddr = tt.remoteAddr + } + + route := &domain.GatewayRoute{ + BlockedCIDRs: tt.blockedCIDRs, + AllowedCIDRs: tt.allowedCIDRs, + } + + result := handler.checkCIDR(c, route) + assert.Equal(t, tt.expectedResult, result, tt.name) + }) + } +} \ No newline at end of file diff --git a/internal/handlers/gateway_handler_test.go b/internal/handlers/gateway_handler_test.go index 2f51b41db..f55c22fe8 100644 --- a/internal/handlers/gateway_handler_test.go +++ b/internal/handlers/gateway_handler_test.go @@ -49,21 +49,21 @@ func (m *mockGatewayService) CreateRoute(ctx context.Context, params ports.Creat return r0, args.Error(1) } -func (m *mockGatewayService) GetProxy(method, path string) (*httputil.ReverseProxy, map[string]string, bool) { +func (m *mockGatewayService) GetProxy(method, path string) (*httputil.ReverseProxy, *domain.GatewayRoute, map[string]string, bool) { args := m.Called(method, path) if args.Get(0) == nil { - return nil, nil, args.Bool(2) + return nil, nil, nil, args.Bool(3) + } + var route *domain.GatewayRoute + if r := args.Get(1); r != nil { + route = r.(*domain.GatewayRoute) } var params map[string]string - if p := args.Get(1); p != nil { - var ok bool - params, ok = p.(map[string]string) - if !ok { - params = nil - } + if p := args.Get(2); p != nil { + params = p.(map[string]string) } r0, _ := args.Get(0).(*httputil.ReverseProxy) - return r0, params, args.Bool(2) + return r0, route, params, args.Bool(3) } func (m *mockGatewayService) ListRoutes(ctx context.Context) ([]*domain.GatewayRoute, error) { @@ -88,7 +88,7 @@ func (m *mockGatewayService) RefreshRoutes(ctx context.Context) error { func setupGatewayHandlerTest(_ *testing.T) (*mockGatewayService, *GatewayHandler, *gin.Engine) { gin.SetMode(gin.TestMode) svc := new(mockGatewayService) - handler := NewGatewayHandler(svc) + handler := NewGatewayHandler(svc, nil) r := gin.New() return svc, handler, r } @@ -168,7 +168,7 @@ func TestGatewayHandlerProxyNotFound(t *testing.T) { r.Any(gwProxyPath, handler.Proxy) - svc.On("GetProxy", "GET", "/unknown").Return(nil, nil, false) + svc.On("GetProxy", "GET", "/unknown").Return(nil, nil, nil, false) req, err := http.NewRequest(http.MethodGet, "/gw/unknown", nil) require.NoError(t, err) @@ -200,7 +200,7 @@ func TestGatewayHandlerProxySuccess(t *testing.T) { // Gateway Handler implementation: c.Request.URL.Path = c.Param("proxy")? or just calls ServeHTTP. // If GatewayHandler calls `proxy.ServeHTTP(w, c.Request)`, the request path "/gw/api" is sent to target. // Test server expects any path. - svc.On("GetProxy", "GET", "/api").Return(proxy, map[string]string{}, true) + svc.On("GetProxy", "GET", "/api").Return(proxy, (*domain.GatewayRoute)(nil), map[string]string{}, true) req, err := http.NewRequest(http.MethodGet, gwAPITestPath, nil) require.NoError(t, err) @@ -223,7 +223,7 @@ func TestGatewayHandlerProxyWithoutSlash(t *testing.T) { defer ts.Close() targetURL, _ := url.Parse(ts.URL) - svc.On("GetProxy", "GET", "/api").Return(httputil.NewSingleHostReverseProxy(targetURL), map[string]string{}, true) + svc.On("GetProxy", "GET", "/api").Return(httputil.NewSingleHostReverseProxy(targetURL), (*domain.GatewayRoute)(nil), map[string]string{}, true) req, err := http.NewRequest(http.MethodGet, gwAPITestPath, nil) require.NoError(t, err) @@ -246,7 +246,7 @@ func TestGatewayHandlerProxyWithSlash(t *testing.T) { defer ts.Close() targetURL, _ := url.Parse(ts.URL) - svc.On("GetProxy", "GET", "//api").Return(httputil.NewSingleHostReverseProxy(targetURL), map[string]string{}, true) + svc.On("GetProxy", "GET", "//api").Return(httputil.NewSingleHostReverseProxy(targetURL), (*domain.GatewayRoute)(nil), map[string]string{}, true) req, err := http.NewRequest(http.MethodGet, "/gw//api", nil) require.NoError(t, err) @@ -296,7 +296,7 @@ func TestGatewayHandlerListError(t *testing.T) { func TestGatewayHandlerProxyParamWithoutSlash(t *testing.T) { t.Parallel() mockSvc := new(mockGatewayService) - handler := NewGatewayHandler(mockSvc) + handler := NewGatewayHandler(mockSvc, nil) gin.SetMode(gin.TestMode) // Manually create context to pass parameter without slash @@ -313,7 +313,7 @@ func TestGatewayHandlerProxyParamWithoutSlash(t *testing.T) { targetURL, _ := url.Parse(ts.URL) // Expect GetProxy to be called with "/api" (slash added) - mockSvc.On("GetProxy", "GET", "/api").Return(httputil.NewSingleHostReverseProxy(targetURL), map[string]string{}, true) + mockSvc.On("GetProxy", "GET", "/api").Return(httputil.NewSingleHostReverseProxy(targetURL), (*domain.GatewayRoute)(nil), map[string]string{}, true) handler.Proxy(c) diff --git a/internal/handlers/gateway_handler_trace_test.go b/internal/handlers/gateway_handler_trace_test.go new file mode 100644 index 000000000..d109fdd3a --- /dev/null +++ b/internal/handlers/gateway_handler_trace_test.go @@ -0,0 +1,43 @@ +package httphandlers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTraceIDGeneration(t *testing.T) { + tests := []struct { + name string + generator func() string + expectedLen int + checkUnique bool + }{ + { + name: "trace ID has correct length", + generator: generateTraceID, + expectedLen: 32, + checkUnique: true, + }, + { + name: "span ID has correct length", + generator: generateSpanID, + expectedLen: 16, + checkUnique: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + id := tt.generator() + assert.Len(t, id, tt.expectedLen, "ID should have expected length") + + if tt.checkUnique { + id2 := tt.generator() + assert.NotEqual(t, id, id2, "consecutive IDs should be unique") + } + }) + } +} \ No newline at end of file diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index 982a56043..83ebc34a4 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -8,12 +8,14 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/google/uuid" "golang.org/x/time/rate" ) // IPRateLimiter manages rate limiters for different IPs/clients type IPRateLimiter struct { ips map[string]*rate.Limiter + routes map[uuid.UUID]map[string]*rate.Limiter // per-route limiters mu sync.RWMutex rate rate.Limit burst int @@ -24,6 +26,7 @@ type IPRateLimiter struct { func NewIPRateLimiter(r rate.Limit, b int, logger *slog.Logger) *IPRateLimiter { i := &IPRateLimiter{ ips: make(map[string]*rate.Limiter), + routes: make(map[uuid.UUID]map[string]*rate.Limiter), rate: r, burst: b, logger: logger, @@ -49,6 +52,29 @@ func (i *IPRateLimiter) GetLimiter(key string) *rate.Limiter { return limiter } +// GetRouteLimiter returns a rate limiter for a specific route and client key. +// This enables per-route rate limiting while maintaining per-client tracking. +// The r and burst parameters specify the per-route rate limits. +func (i *IPRateLimiter) GetRouteLimiter(routeID uuid.UUID, key string, r rate.Limit, burst int) *rate.Limiter { + i.mu.Lock() + defer i.mu.Unlock() + + if i.routes[routeID] == nil { + i.routes[routeID] = make(map[string]*rate.Limiter) + } + + limiter, exists := i.routes[routeID][key] + if !exists { + limiter = rate.NewLimiter(r, burst) + i.routes[routeID][key] = limiter + return limiter + } + + // Update existing limiter with new rate/burst if different + limiter.SetLimit(r) + return limiter +} + // cleanupLoop removes old entries (rudimentary GC) func (i *IPRateLimiter) cleanupLoop() { for { @@ -57,6 +83,7 @@ func (i *IPRateLimiter) cleanupLoop() { // Start fresh every cleanup cycle for simplicity // A production robust implementation would track last access time i.ips = make(map[string]*rate.Limiter) + i.routes = make(map[uuid.UUID]map[string]*rate.Limiter) i.mu.Unlock() } }