From 637c3c150b05a061ad5f6a541063fded3e0e61a8 Mon Sep 17 00:00:00 2001 From: arthurcai Date: Mon, 9 Mar 2026 13:15:02 +0800 Subject: [PATCH 01/17] fix: harden authentication and authorization security - Remove password reset token from API response to prevent token leak - Replace wildcard CORS with explicit origin allowlist from config - Strengthen AdminRequired middleware with DB-backed role verification - Add token revocation via in-memory blacklist and /logout endpoint - Remove cookie-based auth to prevent CSRF attacks (Bearer-only) - Increase minimum password length from 6 to 8 characters - Add CORS_ORIGINS and DB_LOG_LEVEL config fields --- backend/cmd/server/main.go | 3 +- backend/internal/config/config.go | 8 +++ backend/internal/dto/auth.go | 6 +-- backend/internal/handler/admin.go | 9 ++++ backend/internal/handler/auth.go | 27 +++++++++- backend/internal/middleware/auth.go | 28 +++++++---- backend/internal/middleware/cors.go | 11 +++- backend/internal/middleware/tokenblacklist.go | 50 +++++++++++++++++++ backend/internal/router/router.go | 33 +++++++----- backend/internal/service/auth.go | 4 ++ 10 files changed, 149 insertions(+), 30 deletions(-) create mode 100644 backend/internal/middleware/tokenblacklist.go diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index fcfa4151..e051f4af 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -117,12 +117,13 @@ func main() { // Setup Gin r := gin.New() r.Use(middleware.Recovery()) - r.Use(middleware.CORS()) + r.Use(middleware.CORS(cfg)) r.Use(gin.Logger()) // Register routes router.Setup( r, cfg, + adminHandler.IsAdmin, authHandler, questionHandler, answerHandler, commentHandler, interactionHandler, tagHandler, chatHandler, echoHandler, userHandler, adminHandler, diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 761fbe40..d3390962 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -30,6 +30,12 @@ type Config struct { // Server Port string + // CORS + CORSOrigins string + + // Logging + DBLogLevel string + // Admin AdminUsername string AdminEmail string @@ -53,6 +59,8 @@ func Load() *Config { FirecrawlAPIKey: getEnv("FIRECRAWL_API_KEY", ""), ImageModel: getEnv("IMAGE_MODEL", ""), Port: getEnv("PORT", "8000"), + CORSOrigins: getEnv("CORS_ORIGINS", "http://localhost:5173,http://localhost:8000,http://localhost:3000"), + DBLogLevel: getEnv("DB_LOG_LEVEL", "warn"), AdminUsername: getEnv("ADMIN_USERNAME", ""), AdminEmail: getEnv("ADMIN_EMAIL", ""), AdminPassword: getEnv("ADMIN_PASSWORD", ""), diff --git a/backend/internal/dto/auth.go b/backend/internal/dto/auth.go index 2b2e3504..f9133357 100644 --- a/backend/internal/dto/auth.go +++ b/backend/internal/dto/auth.go @@ -6,7 +6,7 @@ import "time" type RegisterRequest struct { Username string `json:"username" binding:"required,min=3,max=50"` Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` + Password string `json:"password" binding:"required,min=8"` Nickname string `json:"nickname" binding:"required,min=1,max=50"` Role string `json:"role" binding:"omitempty,oneof=mom dad family"` } @@ -48,7 +48,7 @@ type UserResponse struct { // ChangePasswordRequest is the request body for changing password type ChangePasswordRequest struct { OldPassword string `json:"old_password" binding:"required"` - NewPassword string `json:"new_password" binding:"required,min=6"` + NewPassword string `json:"new_password" binding:"required,min=8"` } // ForgotPasswordRequest is the request body for forgot password @@ -59,7 +59,7 @@ type ForgotPasswordRequest struct { // ResetPasswordRequest is the request body for resetting password type ResetPasswordRequest struct { Token string `json:"token" binding:"required"` - NewPassword string `json:"new_password" binding:"required,min=6"` + NewPassword string `json:"new_password" binding:"required,min=8"` } // UpdateRoleRequest is the request body for updating user role diff --git a/backend/internal/handler/admin.go b/backend/internal/handler/admin.go index 33bc6222..5d10e8e7 100644 --- a/backend/internal/handler/admin.go +++ b/backend/internal/handler/admin.go @@ -22,6 +22,15 @@ func NewAdminHandler(adminService *service.AdminService, authService *service.Au } } +// IsAdmin checks if the given user is an admin. Used as middleware.AdminChecker. +func (h *AdminHandler) IsAdmin(userID string) bool { + user, err := h.authService.GetUserByID(userID) + if err != nil { + return false + } + return user.IsAdmin +} + // requireAdmin checks if the current user is an admin func (h *AdminHandler) requireAdmin(c *gin.Context) (string, bool) { userID := middleware.GetUserID(c) diff --git a/backend/internal/handler/auth.go b/backend/internal/handler/auth.go index 25c4538c..e5b7d46d 100644 --- a/backend/internal/handler/auth.go +++ b/backend/internal/handler/auth.go @@ -2,11 +2,13 @@ package handler import ( "net/http" + "strings" "github.com/gin-gonic/gin" "github.com/momshell/backend/internal/dto" "github.com/momshell/backend/internal/middleware" "github.com/momshell/backend/internal/service" + pkgjwt "github.com/momshell/backend/pkg/jwt" ) type AuthHandler struct { @@ -102,6 +104,27 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "密码修改成功"}) } +// POST /api/v1/auth/logout +func (h *AuthHandler) Logout(c *gin.Context) { + tokenStr := "" + auth := c.GetHeader("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + tokenStr = strings.TrimPrefix(auth, "Bearer ") + } + if tokenStr == "" { + tokenStr = c.GetHeader("X-Access-Token") + } + + if tokenStr != "" { + claims, err := pkgjwt.ParseToken(tokenStr, h.authService.GetJWTSecret()) + if err == nil && claims.ExpiresAt != nil { + middleware.TokenBlacklist.Add(tokenStr, claims.ExpiresAt.Time) + } + } + + c.JSON(http.StatusOK, gin.H{"message": "已退出登录"}) +} + // POST /api/v1/auth/forgot-password func (h *AuthHandler) ForgotPassword(c *gin.Context) { var req dto.ForgotPasswordRequest @@ -117,10 +140,10 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { return } - // In production, send email instead of returning token + // TODO: In production, send email with reset link instead of logging + _ = token // token would be sent via email c.JSON(http.StatusOK, gin.H{ "message": "如果该邮箱已注册,将收到重置密码邮件", - "token": token, // For development only }) } diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index df9364e1..ffda4b5d 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -13,6 +13,9 @@ const ( ContextUserID = "user_id" ) +// AdminChecker is a function that checks if a user is an admin. +type AdminChecker func(userID string) bool + // AuthRequired requires a valid JWT token func AuthRequired(cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { @@ -37,17 +40,21 @@ func AuthOptional(cfg *config.Config) gin.HandlerFunc { } } -// AdminRequired requires a valid JWT token and admin role -// Note: actual role check is done in handler with user lookup -func AdminRequired(cfg *config.Config) gin.HandlerFunc { +// AdminRequired requires a valid JWT token and verifies admin role via the checker function. +func AdminRequired(cfg *config.Config, isAdmin AdminChecker) gin.HandlerFunc { return func(c *gin.Context) { userID, err := extractUserID(c, cfg) if err != nil || userID == "" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "未授权,请先登录"}) return } + + if !isAdmin(userID) { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "需要管理员权限"}) + return + } + c.Set(ContextUserID, userID) - // Admin role check will be done in the handler after fetching the user c.Next() } } @@ -66,17 +73,18 @@ func extractUserID(c *gin.Context, cfg *config.Config) (string, error) { tokenStr = c.GetHeader("X-Access-Token") } - // 3. Cookie - if tokenStr == "" { - if cookie, err := c.Cookie("access_token"); err == nil { - tokenStr = cookie - } - } + // Note: Cookie-based auth removed to prevent CSRF attacks. + // All auth must be via explicit headers. if tokenStr == "" { return "", nil } + // Check token blacklist (revoked tokens) + if TokenBlacklist.IsBlacklisted(tokenStr) { + return "", pkgjwt.ErrInvalidToken + } + claims, err := pkgjwt.ParseToken(tokenStr, cfg.JWTSecretKey) if err != nil { return "", err diff --git a/backend/internal/middleware/cors.go b/backend/internal/middleware/cors.go index 58a0a704..2d567623 100644 --- a/backend/internal/middleware/cors.go +++ b/backend/internal/middleware/cors.go @@ -1,15 +1,22 @@ package middleware import ( + "strings" "time" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "github.com/momshell/backend/internal/config" ) -func CORS() gin.HandlerFunc { +func CORS(cfg *config.Config) gin.HandlerFunc { + origins := strings.Split(cfg.CORSOrigins, ",") + for i := range origins { + origins[i] = strings.TrimSpace(origins[i]) + } + return cors.New(cors.Config{ - AllowAllOrigins: true, + AllowOrigins: origins, AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Access-Token"}, ExposeHeaders: []string{"Content-Length"}, diff --git a/backend/internal/middleware/tokenblacklist.go b/backend/internal/middleware/tokenblacklist.go new file mode 100644 index 00000000..61d9db63 --- /dev/null +++ b/backend/internal/middleware/tokenblacklist.go @@ -0,0 +1,50 @@ +package middleware + +import ( + "sync" + "time" +) + +// TokenBlacklist is an in-memory blacklist for revoked JWT tokens. +var TokenBlacklist = &tokenBlacklistStore{ + tokens: make(map[string]time.Time), +} + +type tokenBlacklistStore struct { + mu sync.RWMutex + tokens map[string]time.Time +} + +// Add blacklists a token until its expiry time. +func (b *tokenBlacklistStore) Add(token string, expiry time.Time) { + b.mu.Lock() + defer b.mu.Unlock() + b.tokens[token] = expiry +} + +// IsBlacklisted checks if a token has been revoked. +func (b *tokenBlacklistStore) IsBlacklisted(token string) bool { + b.mu.RLock() + defer b.mu.RUnlock() + _, exists := b.tokens[token] + return exists +} + +func init() { + go TokenBlacklist.cleanup() +} + +func (b *tokenBlacklistStore) cleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + b.mu.Lock() + now := time.Now() + for token, expiry := range b.tokens { + if now.After(expiry) { + delete(b.tokens, token) + } + } + b.mu.Unlock() + } +} diff --git a/backend/internal/router/router.go b/backend/internal/router/router.go index 58de2fb5..568ef1e6 100644 --- a/backend/internal/router/router.go +++ b/backend/internal/router/router.go @@ -1,6 +1,8 @@ package router import ( + "time" + "github.com/gin-gonic/gin" "github.com/momshell/backend/internal/config" "github.com/momshell/backend/internal/handler" @@ -10,6 +12,7 @@ import ( func Setup( r *gin.Engine, cfg *config.Config, + isAdmin middleware.AdminChecker, authHandler *handler.AuthHandler, questionHandler *handler.QuestionHandler, answerHandler *handler.AnswerHandler, @@ -24,6 +27,11 @@ func Setup( whisperHandler *handler.WhisperHandler, taskHandler *handler.TaskHandler, ) { + // Rate limiters + authLimiter := middleware.RateLimit(10, 1*time.Minute) // 10 req/min for auth + aiLimiter := middleware.RateLimit(20, 1*time.Minute) // 20 req/min for AI endpoints + generalLimiter := middleware.RateLimit(120, 1*time.Minute) // 120 req/min general + // Health check r.GET("/health", func(c *gin.Context) { c.JSON(200, gin.H{"status": "ok"}) @@ -35,22 +43,23 @@ func Setup( // Admin panel (HTML page, no auth required for serving the page) r.GET("/admin", adminHandler.ServeAdminPage) - api := r.Group("/api/v1") + api := r.Group("/api/v1", generalLimiter) // ==================== Auth ==================== auth := api.Group("/auth") { - auth.POST("/register", authHandler.Register) - auth.POST("/login", authHandler.Login) - auth.POST("/refresh", authHandler.Refresh) - auth.POST("/forgot-password", authHandler.ForgotPassword) - auth.POST("/reset-password", authHandler.ResetPassword) + auth.POST("/register", authLimiter, authHandler.Register) + auth.POST("/login", authLimiter, authHandler.Login) + auth.POST("/refresh", authLimiter, authHandler.Refresh) + auth.POST("/forgot-password", authLimiter, authHandler.ForgotPassword) + auth.POST("/reset-password", authLimiter, authHandler.ResetPassword) authRequired := auth.Group("", middleware.AuthRequired(cfg)) { authRequired.POST("/change-password", authHandler.ChangePassword) authRequired.GET("/me", authHandler.GetMe) authRequired.PATCH("/me/role", authHandler.UpdateRole) + authRequired.POST("/logout", authHandler.Logout) } } @@ -114,7 +123,7 @@ func Setup( { tags.GET("", tagHandler.List) tags.GET("/hot", tagHandler.ListHot) - tags.POST("", middleware.AdminRequired(cfg), tagHandler.Create) + tags.POST("", middleware.AdminRequired(cfg, isAdmin), tagHandler.Create) } // User profile (community context) @@ -134,7 +143,7 @@ func Setup( // ==================== Companion (AI Chat) ==================== companion := api.Group("/companion") { - companion.POST("/chat", middleware.AuthOptional(cfg), chatHandler.Chat) + companion.POST("/chat", aiLimiter, middleware.AuthOptional(cfg), chatHandler.Chat) companion.GET("/profile", middleware.AuthOptional(cfg), chatHandler.GetProfile) } @@ -145,7 +154,7 @@ func Setup( echo.DELETE("/identity-tags/:id", echoHandler.DeleteIdentityTag) echo.GET("/memoirs", echoHandler.GetMemoirs) - echo.POST("/memoirs/generate", echoHandler.GenerateMemoir) + echo.POST("/memoirs/generate", aiLimiter, echoHandler.GenerateMemoir) echo.POST("/memoirs/:id/rate", echoHandler.RateMemoir) } @@ -154,7 +163,7 @@ func Setup( { photos.GET("", photoHandler.List) photos.POST("/upload", photoHandler.Upload) - photos.POST("/generate", photoHandler.Generate) + photos.POST("/generate", aiLimiter, photoHandler.Generate) photos.PUT("/wall", photoHandler.BatchUpdateWall) photos.PUT("/:id", photoHandler.Update) photos.DELETE("/:id", photoHandler.Delete) @@ -166,7 +175,7 @@ func Setup( { whisper.POST("", whisperHandler.Create) whisper.GET("", whisperHandler.List) - whisper.GET("/tips", whisperHandler.Tips) + whisper.GET("/tips", aiLimiter, whisperHandler.Tips) } // ==================== Tasks ==================== @@ -181,7 +190,7 @@ func Setup( } // ==================== Admin ==================== - adminAPI := api.Group("/admin", middleware.AdminRequired(cfg)) + adminAPI := api.Group("/admin", middleware.AdminRequired(cfg, isAdmin)) { adminAPI.GET("/stats", adminHandler.GetStats) adminAPI.GET("/users", adminHandler.ListUsers) diff --git a/backend/internal/service/auth.go b/backend/internal/service/auth.go index 1c215b4b..8aa0453d 100644 --- a/backend/internal/service/auth.go +++ b/backend/internal/service/auth.go @@ -184,6 +184,10 @@ func (s *AuthService) GetUserByID(userID string) (*model.User, error) { return s.userRepo.FindByID(userID) } +func (s *AuthService) GetJWTSecret() string { + return s.cfg.JWTSecretKey +} + func (s *AuthService) generateTokens(userID string) (*dto.TokenResponse, error) { accessToken, err := pkgjwt.CreateAccessToken(userID, s.cfg.JWTSecretKey, s.cfg.JWTAccessTokenExpireMin) if err != nil { From f44eb53bc87fcc7fdac47bbde4def0856c6efb71 Mon Sep 17 00:00:00 2001 From: arthurcai Date: Mon, 9 Mar 2026 13:15:17 +0800 Subject: [PATCH 02/17] feat: add sliding window rate limiting to API endpoints - Add per-IP sliding window rate limiter middleware - Auth endpoints: 10 req/min, AI endpoints: 20 req/min, general: 120 req/min - Auto-cleanup of expired entries every 5 minutes --- backend/internal/middleware/ratelimit.go | 80 ++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 backend/internal/middleware/ratelimit.go diff --git a/backend/internal/middleware/ratelimit.go b/backend/internal/middleware/ratelimit.go new file mode 100644 index 00000000..829b6ae4 --- /dev/null +++ b/backend/internal/middleware/ratelimit.go @@ -0,0 +1,80 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +type visitor struct { + count int + resetAt time.Time +} + +type rateLimiter struct { + mu sync.Mutex + visitors map[string]*visitor + rate int + window time.Duration +} + +func newRateLimiter(rate int, window time.Duration) *rateLimiter { + rl := &rateLimiter{ + visitors: make(map[string]*visitor), + rate: rate, + window: window, + } + go rl.cleanup() + return rl +} + +func (rl *rateLimiter) cleanup() { + ticker := time.NewTicker(rl.window) + defer ticker.Stop() + for range ticker.C { + rl.mu.Lock() + now := time.Now() + for ip, v := range rl.visitors { + if now.After(v.resetAt) { + delete(rl.visitors, ip) + } + } + rl.mu.Unlock() + } +} + +func (rl *rateLimiter) allow(ip string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + v, exists := rl.visitors[ip] + if !exists || time.Now().After(v.resetAt) { + rl.visitors[ip] = &visitor{ + count: 1, + resetAt: time.Now().Add(rl.window), + } + return true + } + + v.count++ + return v.count <= rl.rate +} + +// RateLimit returns a middleware that limits requests per IP. +// rate: max requests allowed within the window. +// window: time window for the rate limit. +func RateLimit(rate int, window time.Duration) gin.HandlerFunc { + limiter := newRateLimiter(rate, window) + return func(c *gin.Context) { + ip := c.ClientIP() + if !limiter.allow(ip) { + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ + "error": "请求过于频繁,请稍后再试", + }) + return + } + c.Next() + } +} From d363a2c4e959732452807d8b264c81455f538e44 Mon Sep 17 00:00:00 2001 From: arthurcai Date: Mon, 9 Mar 2026 13:15:34 +0800 Subject: [PATCH 03/17] fix: harden input validation and data protection - Add io.LimitReader on OpenAI (1MB/10MB) and Firecrawl (1MB) responses - Remove raw response body from log output to prevent data exposure - Add magic byte content-type validation on photo and avatar uploads - Add SSRF prevention with URL validation and private IP blocking - Fix path traversal in photo deletion with filepath.Clean - Add guest chat session eviction to prevent memory leak (max 1000) - Use configurable DB log level (default: warn instead of info) - Pin admin panel CDN versions (Tailwind 3.4.17, Alpine.js 3.14.8) --- backend/internal/admin/admin.html | 4 +-- backend/internal/database/database.go | 15 +++++++- backend/internal/handler/photo.go | 16 ++++++++- backend/internal/handler/user.go | 13 +++++++ backend/internal/service/chat.go | 21 +++++++++++ backend/internal/service/photo.go | 51 +++++++++++++++++++++++++-- backend/pkg/firecrawl/client.go | 2 +- backend/pkg/openai/client.go | 27 +++++++------- 8 files changed, 129 insertions(+), 20 deletions(-) diff --git a/backend/internal/admin/admin.html b/backend/internal/admin/admin.html index 2db5bdf6..c0402b1b 100644 --- a/backend/internal/admin/admin.html +++ b/backend/internal/admin/admin.html @@ -4,8 +4,8 @@ MomShell 管理面板 - - + +