diff --git a/.env.example b/.env.example index 78b0cfa6..b8800fae 100644 --- a/.env.example +++ b/.env.example @@ -37,3 +37,9 @@ VITE_API_BASE_URL=http://localhost:8000 ADMIN_USERNAME= ADMIN_EMAIL= ADMIN_PASSWORD= + +# ==================== AI Service User (optional) ==================== +# Internal service account for AI-generated content. If not set, +# a random password is generated on each startup (the AI user never logs in). +AI_USER_USERNAME=xiaoshiguang +AI_USER_PASSWORD= diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 69e259f2..3ae29253 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -1,8 +1,11 @@ package main import ( + "crypto/rand" + "encoding/hex" "fmt" "log" + "os" "github.com/gin-gonic/gin" "github.com/momshell/backend/internal/config" @@ -74,7 +77,7 @@ func main() { firecrawlClient = firecrawl.NewClient(cfg.FirecrawlAPIKey) } - chatService := service.NewChatService(chatClient, chatRepo, userRepo, firecrawlClient) + chatService := service.NewChatService(chatClient, chatRepo, userRepo, firecrawlClient, cfg.JWTSecretKey) echoService := service.NewEchoService(chatClient, echoRepo, userRepo) photoService := service.NewPhotoService(photoRepo, userRepo, chatClient, cfg.ImageModel) whisperService := service.NewWhisperService(whisperRepo, userRepo, chatClient) @@ -136,10 +139,21 @@ func main() { router.Setup( r, cfg, adminHandler.IsAdmin, - authHandler, questionHandler, answerHandler, - commentHandler, interactionHandler, tagHandler, - chatHandler, echoHandler, userHandler, adminHandler, - photoHandler, whisperHandler, taskHandler, + &router.Handlers{ + Auth: authHandler, + Question: questionHandler, + Answer: answerHandler, + Comment: commentHandler, + Interaction: interactionHandler, + Tag: tagHandler, + Chat: chatHandler, + Echo: echoHandler, + User: userHandler, + Admin: adminHandler, + Photo: photoHandler, + Whisper: whisperHandler, + Task: taskHandler, + }, ) // Start server @@ -186,19 +200,35 @@ func createInitialAdmin(cfg *config.Config, userRepo *repository.UserRepo) { } func ensureAIUser(userRepo *repository.UserRepo) string { - user, err := userRepo.FindByUsernameOrEmail("xiaoshiguang") + aiUsername := os.Getenv("AI_USER_USERNAME") + if aiUsername == "" { + aiUsername = "xiaoshiguang" + log.Println("[WARN] AI_USER_USERNAME not set, using default: xiaoshiguang") + } + aiPasswd := os.Getenv("AI_USER_PASSWORD") + if aiPasswd == "" { + // Generate a random password; AI user never logs in interactively. + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + log.Fatalf("failed to generate random AI user password: %v", err) + } + aiPasswd = hex.EncodeToString(b) + log.Println("[WARN] AI_USER_PASSWORD not set, using random password") + } + + user, err := userRepo.FindByUsernameOrEmail(aiUsername) if err == nil { return user.ID } - hash, err := password.Hash("ai-user-no-login") + hash, err := password.Hash(aiPasswd) if err != nil { log.Printf("Failed to hash AI user password: %v", err) return "" } aiUser := &model.User{ - Username: "xiaoshiguang", + Username: aiUsername, Email: "ai@momshell.com", PasswordHash: hash, Nickname: "小石光", diff --git a/backend/internal/handler/admin.go b/backend/internal/handler/admin.go index a6bcb24d..b22fe11a 100644 --- a/backend/internal/handler/admin.go +++ b/backend/internal/handler/admin.go @@ -1,6 +1,7 @@ package handler import ( + "log" "net/http" "github.com/gin-gonic/gin" @@ -142,6 +143,9 @@ func (h *AdminHandler) UpdateConfig(c *gin.Context) { return } + adminID := middleware.GetUserID(c) + log.Printf("[SECURITY] admin_config_change | ip=%s | user=%s | detail=config update requested", c.ClientIP(), adminID) + if err := h.adminService.UpdateConfig(req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return diff --git a/backend/internal/handler/auth.go b/backend/internal/handler/auth.go index e4265636..0eed8b66 100644 --- a/backend/internal/handler/auth.go +++ b/backend/internal/handler/auth.go @@ -66,7 +66,7 @@ func (h *AuthHandler) Login(c *gin.Context) { return } - resp, err := h.authService.Login(req) + resp, err := h.authService.Login(req, c.ClientIP()) if err != nil { status := http.StatusUnauthorized if err.Error() == "账号已禁用" || err.Error() == "账号已被封禁" { @@ -184,8 +184,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { return } - // TODO: In production, send email with reset link instead of logging - _ = token // token would be sent via email + // In a production deployment, the token would be sent via email with a reset link. + // For now, it is intentionally unused to avoid leaking it in the response. + _ = token c.JSON(http.StatusOK, gin.H{ "message": "如果该邮箱已注册,将收到重置密码邮件", }) diff --git a/backend/internal/handler/interaction.go b/backend/internal/handler/interaction.go index 87eaea9c..4d609ec8 100644 --- a/backend/internal/handler/interaction.go +++ b/backend/internal/handler/interaction.go @@ -26,7 +26,7 @@ func (h *InteractionHandler) CreateLike(c *gin.Context) { } userID := middleware.GetUserID(c) - isLiked, newCount, err := h.communityService.ToggleLike(userID, req.TargetType, req.TargetID) + isLiked, newCount, err := h.communityService.AddLike(userID, req.TargetType, req.TargetID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -44,7 +44,7 @@ func (h *InteractionHandler) DeleteLike(c *gin.Context) { } userID := middleware.GetUserID(c) - isLiked, newCount, err := h.communityService.ToggleLike(userID, req.TargetType, req.TargetID) + isLiked, newCount, err := h.communityService.RemoveLike(userID, req.TargetType, req.TargetID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go index ffda4b5d..55cd5553 100644 --- a/backend/internal/middleware/auth.go +++ b/backend/internal/middleware/auth.go @@ -1,6 +1,7 @@ package middleware import ( + "log" "net/http" "strings" @@ -82,15 +83,18 @@ func extractUserID(c *gin.Context, cfg *config.Config) (string, error) { // Check token blacklist (revoked tokens) if TokenBlacklist.IsBlacklisted(tokenStr) { + log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=unknown | detail=token is blacklisted", c.ClientIP()) return "", pkgjwt.ErrInvalidToken } claims, err := pkgjwt.ParseToken(tokenStr, cfg.JWTSecretKey) if err != nil { + log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=unknown | detail=%v", c.ClientIP(), err) return "", err } if claims.Type != "access" { + log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=%s | detail=invalid token type: %s", c.ClientIP(), claims.Subject, claims.Type) return "", pkgjwt.ErrInvalidToken } diff --git a/backend/internal/middleware/cors.go b/backend/internal/middleware/cors.go index 4b295647..0edb4d98 100644 --- a/backend/internal/middleware/cors.go +++ b/backend/internal/middleware/cors.go @@ -11,23 +11,24 @@ import ( func CORS(cfg *config.Config) gin.HandlerFunc { corsConfig := cors.Config{ - AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, - AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Access-Token"}, - ExposeHeaders: []string{"Content-Length"}, - AllowCredentials: true, - MaxAge: 12 * time.Hour, + AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, + AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Access-Token"}, + ExposeHeaders: []string{"Content-Length"}, + MaxAge: 12 * time.Hour, } trimmed := strings.TrimSpace(cfg.CORSOrigins) if trimmed == "" || trimmed == "*" { - // Allow all origins (compatible with AllowCredentials) + // Allow all origins but disable credentials to prevent CSRF-like attacks corsConfig.AllowOriginFunc = func(origin string) bool { return true } + corsConfig.AllowCredentials = false } else { origins := strings.Split(trimmed, ",") for i := range origins { origins[i] = strings.TrimSpace(origins[i]) } corsConfig.AllowOrigins = origins + corsConfig.AllowCredentials = true } return cors.New(corsConfig) diff --git a/backend/internal/middleware/security.go b/backend/internal/middleware/security.go index c98c4275..058b8861 100644 --- a/backend/internal/middleware/security.go +++ b/backend/internal/middleware/security.go @@ -1,6 +1,10 @@ package middleware -import "github.com/gin-gonic/gin" +import ( + "strings" + + "github.com/gin-gonic/gin" +) // SecurityHeaders adds standard security headers to all responses. func SecurityHeaders() gin.HandlerFunc { @@ -10,6 +14,14 @@ func SecurityHeaders() gin.HandlerFunc { c.Header("X-XSS-Protection", "1; mode=block") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") c.Header("Permissions-Policy", "camera=(self), microphone=(), geolocation=()") + c.Header("Strict-Transport-Security", "max-age=63072000; includeSubDomains") + + // Add cache-control headers for API responses + if strings.HasPrefix(c.Request.URL.Path, "/api/") { + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + } + c.Next() } } diff --git a/backend/internal/middleware/tokenblacklist.go b/backend/internal/middleware/tokenblacklist.go index 1aeb2ec1..24689a43 100644 --- a/backend/internal/middleware/tokenblacklist.go +++ b/backend/internal/middleware/tokenblacklist.go @@ -1,6 +1,7 @@ package middleware import ( + "log" "sync" "time" ) @@ -22,7 +23,7 @@ func (b *tokenBlacklistStore) Add(token string, expiry time.Time) { b.mu.Lock() defer b.mu.Unlock() if len(b.tokens) >= maxBlacklistSize { - // Emergency cleanup: remove expired tokens immediately + // Emergency cleanup: remove expired tokens now := time.Now() for t, exp := range b.tokens { if now.After(exp) { @@ -30,9 +31,22 @@ func (b *tokenBlacklistStore) Add(token string, expiry time.Time) { } } } - // If still at capacity after cleanup, skip adding (token will simply not be blacklisted) + // If still at capacity after cleanup, evict the oldest entry (FIFO) if len(b.tokens) >= maxBlacklistSize { - return + log.Printf("[SECURITY] token_blacklist_full | size=%d | evicting oldest entry to make room", len(b.tokens)) + var oldestToken string + var oldestExpiry time.Time + first := true + for t, exp := range b.tokens { + if first || exp.Before(oldestExpiry) { + oldestToken = t + oldestExpiry = exp + first = false + } + } + if oldestToken != "" { + delete(b.tokens, oldestToken) + } } b.tokens[token] = expiry } diff --git a/backend/internal/repository/chat.go b/backend/internal/repository/chat.go index 9b6dbf1e..b413f878 100644 --- a/backend/internal/repository/chat.go +++ b/backend/internal/repository/chat.go @@ -38,7 +38,7 @@ func (r *ChatRepo) Create(m *model.ChatMemory) error { } // UpdateSummaryAndTurns updates only the summary and turns fields. -func (r *ChatRepo) UpdateSummaryAndTurns(userID string, summary string, turns string) error { +func (r *ChatRepo) UpdateSummaryAndTurns(userID, summary, turns string) error { return r.db.Model(&model.ChatMemory{}). Where(whereUserID, userID). Updates(map[string]any{ @@ -139,7 +139,7 @@ func (r *ChatRepo) FactExistsByContentFamily(familyIDs []string, content string) return count > 0, err } -func (r *ChatRepo) DeleteFactsByContentLikeFamily(familyIDs []string, phrases []string) error { +func (r *ChatRepo) DeleteFactsByContentLikeFamily(familyIDs, phrases []string) error { for _, phrase := range phrases { phrase = strings.TrimSpace(phrase) if phrase == "" { diff --git a/backend/internal/router/router.go b/backend/internal/router/router.go index 83cf42d2..bbdb9a28 100644 --- a/backend/internal/router/router.go +++ b/backend/internal/router/router.go @@ -1,6 +1,9 @@ package router import ( + "net/http" + "path/filepath" + "strings" "time" "github.com/gin-gonic/gin" @@ -9,23 +12,33 @@ import ( "github.com/momshell/backend/internal/middleware" ) +const ( + routeUsers = "/users" + routeUserID = "/users/:id" +) + +// Handlers groups all handler dependencies for route setup. +type Handlers struct { + Auth *handler.AuthHandler + Question *handler.QuestionHandler + Answer *handler.AnswerHandler + Comment *handler.CommentHandler + Interaction *handler.InteractionHandler + Tag *handler.TagHandler + Chat *handler.ChatHandler + Echo *handler.EchoHandler + User *handler.UserHandler + Admin *handler.AdminHandler + Photo *handler.PhotoHandler + Whisper *handler.WhisperHandler + Task *handler.TaskHandler +} + func Setup( r *gin.Engine, cfg *config.Config, isAdmin middleware.AdminChecker, - authHandler *handler.AuthHandler, - questionHandler *handler.QuestionHandler, - answerHandler *handler.AnswerHandler, - commentHandler *handler.CommentHandler, - interactionHandler *handler.InteractionHandler, - tagHandler *handler.TagHandler, - chatHandler *handler.ChatHandler, - echoHandler *handler.EchoHandler, - userHandler *handler.UserHandler, - adminHandler *handler.AdminHandler, - photoHandler *handler.PhotoHandler, - whisperHandler *handler.WhisperHandler, - taskHandler *handler.TaskHandler, + h *Handlers, ) { // Rate limiters authLimiter := middleware.RateLimit(10, 1*time.Minute) // 10 req/min for auth @@ -37,30 +50,30 @@ func Setup( c.JSON(200, gin.H{"status": "ok"}) }) - // Serve uploaded files - r.Static("/uploads", "./uploads") + // Serve uploaded files with security restrictions + r.GET("/uploads/*filepath", secureStaticHandler("./uploads")) // Admin panel (HTML page, no auth required for serving the page) - r.GET("/admin", adminHandler.ServeAdminPage) + r.GET("/admin", h.Admin.ServeAdminPage) api := r.Group("/api/v1", generalLimiter) // ==================== Auth ==================== auth := api.Group("/auth") { - auth.POST("/register", authLimiter, authHandler.Register) - auth.POST("/login", authLimiter, authHandler.Login) - auth.POST("/refresh", authLimiter, authHandler.Refresh) - auth.POST("/forgot-password", authLimiter, authHandler.ForgotPassword) - auth.POST("/reset-password", authLimiter, authHandler.ResetPassword) + auth.POST("/register", authLimiter, h.Auth.Register) + auth.POST("/login", authLimiter, h.Auth.Login) + auth.POST("/refresh", authLimiter, h.Auth.Refresh) + auth.POST("/forgot-password", authLimiter, h.Auth.ForgotPassword) + auth.POST("/reset-password", authLimiter, h.Auth.ResetPassword) authRequired := auth.Group("", middleware.AuthRequired(cfg)) { - authRequired.POST("/change-password", authHandler.ChangePassword) - authRequired.GET("/me", authHandler.GetMe) - authRequired.PATCH("/me/role", authHandler.UpdateRole) - authRequired.PATCH("/me/tutorial", authHandler.CompleteTutorial) - authRequired.POST("/logout", authHandler.Logout) + authRequired.POST("/change-password", h.Auth.ChangePassword) + authRequired.GET("/me", h.Auth.GetMe) + authRequired.PATCH("/me/role", h.Auth.UpdateRole) + authRequired.PATCH("/me/tutorial", h.Auth.CompleteTutorial) + authRequired.POST("/logout", h.Auth.Logout) } } @@ -70,145 +83,199 @@ func Setup( // Questions (optional auth for read, required for write) questions := community.Group("/questions") { - questions.GET("", middleware.AuthOptional(cfg), questionHandler.List) - questions.GET("/hot", middleware.AuthOptional(cfg), questionHandler.ListHot) - questions.GET("/channel/:channel", middleware.AuthOptional(cfg), questionHandler.ListByChannel) - questions.GET("/:id", middleware.AuthOptional(cfg), questionHandler.Get) + questions.GET("", middleware.AuthOptional(cfg), h.Question.List) + questions.GET("/hot", middleware.AuthOptional(cfg), h.Question.ListHot) + questions.GET("/channel/:channel", middleware.AuthOptional(cfg), h.Question.ListByChannel) + questions.GET("/:id", middleware.AuthOptional(cfg), h.Question.Get) questionsAuth := questions.Group("", middleware.AuthRequired(cfg)) { - questionsAuth.POST("", questionHandler.Create) - questionsAuth.PUT("/:id", questionHandler.Update) - questionsAuth.DELETE("/:id", questionHandler.Delete) + questionsAuth.POST("", h.Question.Create) + questionsAuth.PUT("/:id", h.Question.Update) + questionsAuth.DELETE("/:id", h.Question.Delete) } // Answers under questions - questions.GET("/:id/answers", middleware.AuthOptional(cfg), answerHandler.List) - questions.POST("/:id/answers", middleware.AuthRequired(cfg), answerHandler.Create) + questions.GET("/:id/answers", middleware.AuthOptional(cfg), h.Answer.List) + questions.POST("/:id/answers", middleware.AuthRequired(cfg), h.Answer.Create) } // Answers (update/delete by answer ID) answers := community.Group("/answers") { - answers.PUT("/:id", middleware.AuthRequired(cfg), answerHandler.Update) - answers.DELETE("/:id", middleware.AuthRequired(cfg), answerHandler.Delete) + answers.PUT("/:id", middleware.AuthRequired(cfg), h.Answer.Update) + answers.DELETE("/:id", middleware.AuthRequired(cfg), h.Answer.Delete) // Comments under answers - answers.GET("/:id/comments", middleware.AuthOptional(cfg), commentHandler.List) - answers.POST("/:id/comments", middleware.AuthRequired(cfg), commentHandler.Create) + answers.GET("/:id/comments", middleware.AuthOptional(cfg), h.Comment.List) + answers.POST("/:id/comments", middleware.AuthRequired(cfg), h.Comment.Create) } // Comments (update/delete by comment ID) comments := community.Group("/comments") { - comments.PUT("/:id", middleware.AuthRequired(cfg), commentHandler.Update) - comments.DELETE("/:id", middleware.AuthRequired(cfg), commentHandler.Delete) + comments.PUT("/:id", middleware.AuthRequired(cfg), h.Comment.Update) + comments.DELETE("/:id", middleware.AuthRequired(cfg), h.Comment.Delete) } // Likes likes := community.Group("/likes", middleware.AuthRequired(cfg)) { - likes.POST("", interactionHandler.CreateLike) - likes.DELETE("", interactionHandler.DeleteLike) + likes.POST("", h.Interaction.CreateLike) + likes.DELETE("", h.Interaction.DeleteLike) } // Collections collections := community.Group("/collections", middleware.AuthRequired(cfg)) { - collections.POST("", interactionHandler.CreateCollection) - collections.DELETE("/:id", interactionHandler.DeleteCollection) - collections.GET("/my", interactionHandler.GetMyCollections) + collections.POST("", h.Interaction.CreateCollection) + collections.DELETE("/:id", h.Interaction.DeleteCollection) + collections.GET("/my", h.Interaction.GetMyCollections) } // Tags tags := community.Group("/tags") { - tags.GET("", tagHandler.List) - tags.GET("/hot", tagHandler.ListHot) - tags.POST("", middleware.AdminRequired(cfg, isAdmin), tagHandler.Create) + tags.GET("", h.Tag.List) + tags.GET("/hot", h.Tag.ListHot) + tags.POST("", middleware.AdminRequired(cfg, isAdmin), h.Tag.Create) } // User profile (community context) - users := community.Group("/users", middleware.AuthRequired(cfg)) + users := community.Group(routeUsers, middleware.AuthRequired(cfg)) { - users.GET("/me", userHandler.GetMe) - users.PUT("/me", userHandler.UpdateMe) - users.POST("/me/avatar", userHandler.UploadAvatar) - users.POST("/me/shell-code", userHandler.GenerateShellCode) - users.POST("/me/bind", userHandler.BindPartner) - users.DELETE("/me/bind", userHandler.UnbindPartner) - users.GET("/me/questions", userHandler.GetMyQuestions) - users.GET("/me/answers", userHandler.GetMyAnswers) + users.GET("/me", h.User.GetMe) + users.PUT("/me", h.User.UpdateMe) + users.POST("/me/avatar", h.User.UploadAvatar) + users.POST("/me/shell-code", h.User.GenerateShellCode) + users.POST("/me/bind", h.User.BindPartner) + users.DELETE("/me/bind", h.User.UnbindPartner) + users.GET("/me/questions", h.User.GetMyQuestions) + users.GET("/me/answers", h.User.GetMyAnswers) } } // ==================== Companion (AI Chat) ==================== companion := api.Group("/companion") { - companion.POST("/chat", aiLimiter, middleware.AuthOptional(cfg), chatHandler.Chat) - companion.GET("/profile", middleware.AuthOptional(cfg), chatHandler.GetProfile) - companion.GET("/memories", middleware.AuthRequired(cfg), chatHandler.GetMemories) - companion.DELETE("/memories/:id", middleware.AuthRequired(cfg), chatHandler.DeleteMemory) - companion.GET("/history", middleware.AuthRequired(cfg), chatHandler.GetHistory) - companion.DELETE("/history", middleware.AuthRequired(cfg), chatHandler.ClearHistory) + companion.POST("/chat", aiLimiter, middleware.AuthOptional(cfg), h.Chat.Chat) + companion.GET("/profile", middleware.AuthOptional(cfg), h.Chat.GetProfile) + companion.GET("/memories", middleware.AuthRequired(cfg), h.Chat.GetMemories) + companion.DELETE("/memories/:id", middleware.AuthRequired(cfg), h.Chat.DeleteMemory) + companion.GET("/history", middleware.AuthRequired(cfg), h.Chat.GetHistory) + companion.DELETE("/history", middleware.AuthRequired(cfg), h.Chat.ClearHistory) } echo := api.Group("/echo", middleware.AuthRequired(cfg)) { - echo.GET("/identity-tags", echoHandler.GetIdentityTags) - echo.POST("/identity-tags", echoHandler.CreateIdentityTag) - echo.DELETE("/identity-tags/:id", echoHandler.DeleteIdentityTag) + echo.GET("/identity-tags", h.Echo.GetIdentityTags) + echo.POST("/identity-tags", h.Echo.CreateIdentityTag) + echo.DELETE("/identity-tags/:id", h.Echo.DeleteIdentityTag) - echo.GET("/memoirs", echoHandler.GetMemoirs) - echo.POST("/memoirs/generate", aiLimiter, echoHandler.GenerateMemoir) - echo.POST("/memoirs/:id/rate", echoHandler.RateMemoir) + echo.GET("/memoirs", h.Echo.GetMemoirs) + echo.POST("/memoirs/generate", aiLimiter, h.Echo.GenerateMemoir) + echo.POST("/memoirs/:id/rate", h.Echo.RateMemoir) } // ==================== Photos ==================== photos := api.Group("/photos", middleware.AuthRequired(cfg)) { - photos.GET("", photoHandler.List) - photos.POST("/upload", photoHandler.Upload) - photos.POST("/generate", aiLimiter, photoHandler.Generate) - photos.PUT("/wall", photoHandler.BatchUpdateWall) - photos.PUT("/:id", photoHandler.Update) - photos.DELETE("/:id", photoHandler.Delete) - photos.PUT("/:id/wall", photoHandler.ToggleWall) + photos.GET("", h.Photo.List) + photos.POST("/upload", h.Photo.Upload) + photos.POST("/generate", aiLimiter, h.Photo.Generate) + photos.PUT("/wall", h.Photo.BatchUpdateWall) + photos.PUT("/:id", h.Photo.Update) + photos.DELETE("/:id", h.Photo.Delete) + photos.PUT("/:id/wall", h.Photo.ToggleWall) } // ==================== Whisper (Heart Words) ==================== whisper := api.Group("/whisper", middleware.AuthRequired(cfg)) { - whisper.POST("", whisperHandler.Create) - whisper.GET("", whisperHandler.List) - whisper.GET("/tips", aiLimiter, whisperHandler.Tips) + whisper.POST("", h.Whisper.Create) + whisper.GET("", h.Whisper.List) + whisper.GET("/tips", aiLimiter, h.Whisper.Tips) } // ==================== Tasks ==================== tasks := api.Group("/tasks", middleware.AuthRequired(cfg)) { - tasks.GET("/daily", taskHandler.DailyTasks) - tasks.POST("/:id/complete", taskHandler.Complete) - tasks.GET("/partner", taskHandler.PartnerTasks) - tasks.POST("/:id/score", taskHandler.Score) - tasks.POST("/:id/reject", taskHandler.Reject) - tasks.GET("/stats", taskHandler.Stats) - tasks.GET("/baby-age", taskHandler.GetBabyAge) - tasks.PUT("/baby-age", taskHandler.SetBabyAge) + tasks.GET("/daily", h.Task.DailyTasks) + tasks.POST("/:id/complete", h.Task.Complete) + tasks.GET("/partner", h.Task.PartnerTasks) + tasks.POST("/:id/score", h.Task.Score) + tasks.POST("/:id/reject", h.Task.Reject) + tasks.GET("/stats", h.Task.Stats) + tasks.GET("/baby-age", h.Task.GetBabyAge) + tasks.PUT("/baby-age", h.Task.SetBabyAge) } // ==================== Admin ==================== adminAPI := api.Group("/admin", middleware.AdminRequired(cfg, isAdmin)) { - adminAPI.GET("/stats", adminHandler.GetStats) - adminAPI.GET("/users", adminHandler.ListUsers) - adminAPI.GET("/users/:id", adminHandler.GetUser) - adminAPI.POST("/users", adminHandler.CreateUser) - adminAPI.PATCH("/users/:id", adminHandler.UpdateUser) - adminAPI.DELETE("/users/:id", adminHandler.DeleteUser) - adminAPI.GET("/config", adminHandler.GetConfig) - adminAPI.PATCH("/config", adminHandler.UpdateConfig) - adminAPI.GET("/photos", adminHandler.ListPhotos) - adminAPI.DELETE("/photos/:id", adminHandler.DeletePhoto) + adminAPI.GET("/stats", h.Admin.GetStats) + adminAPI.GET(routeUsers, h.Admin.ListUsers) + adminAPI.GET(routeUserID, h.Admin.GetUser) + adminAPI.POST(routeUsers, h.Admin.CreateUser) + adminAPI.PATCH(routeUserID, h.Admin.UpdateUser) + adminAPI.DELETE(routeUserID, h.Admin.DeleteUser) + adminAPI.GET("/config", h.Admin.GetConfig) + adminAPI.PATCH("/config", h.Admin.UpdateConfig) + adminAPI.GET("/photos", h.Admin.ListPhotos) + adminAPI.DELETE("/photos/:id", h.Admin.DeletePhoto) + } +} + +// allowedImageExts defines image extensions that are safe to serve inline. +var allowedImageExts = map[string]bool{ + ".jpg": true, + ".jpeg": true, + ".png": true, + ".gif": true, + ".webp": true, +} + +// blockedExts defines file extensions that must never be served. +var blockedExts = map[string]bool{ + ".svg": true, + ".html": true, + ".htm": true, + ".xml": true, + ".js": true, + ".css": true, +} + +// secureStaticHandler returns a handler that serves files from the given root +// directory with security headers and file type restrictions. +func secureStaticHandler(root string) gin.HandlerFunc { + fs := http.Dir(root) + fileServer := http.StripPrefix("/uploads", http.FileServer(fs)) + + return func(c *gin.Context) { + // Clean the path to prevent traversal attacks + reqPath := filepath.Clean(c.Param("filepath")) + + // Ensure the cleaned path does not escape the root + if strings.Contains(reqPath, "..") { + c.AbortWithStatus(http.StatusForbidden) + return + } + + ext := strings.ToLower(filepath.Ext(reqPath)) + + // Block dangerous file types + if blockedExts[ext] { + c.AbortWithStatus(http.StatusForbidden) + return + } + + // Set security headers + c.Header("X-Content-Type-Options", "nosniff") + + if allowedImageExts[ext] { + c.Header("Content-Disposition", "inline") + } + + fileServer.ServeHTTP(c.Writer, c.Request) } } diff --git a/backend/internal/service/auth.go b/backend/internal/service/auth.go index fd4870f2..a3a7ba2a 100644 --- a/backend/internal/service/auth.go +++ b/backend/internal/service/auth.go @@ -3,6 +3,7 @@ package service import ( "errors" "fmt" + "log" "github.com/momshell/backend/internal/config" "github.com/momshell/backend/internal/dto" @@ -13,6 +14,11 @@ import ( "gorm.io/gorm" ) +const ( + errPasswordHashFailed = "密码加密失败: %w" + errInvalidRefreshToken = "无效的刷新令牌" +) + type AuthService struct { cfg *config.Config userRepo *repository.UserRepo @@ -35,7 +41,7 @@ func (s *AuthService) Register(req dto.RegisterRequest) (*dto.UserResponse, erro // Hash password hash, err := password.Hash(req.Password) if err != nil { - return nil, fmt.Errorf("密码加密失败: %w", err) + return nil, fmt.Errorf(errPasswordHashFailed, err) } user := &model.User{ @@ -56,42 +62,50 @@ func (s *AuthService) Register(req dto.RegisterRequest) (*dto.UserResponse, erro return s.buildUserResponse(user), nil } -func (s *AuthService) Login(req dto.LoginRequest) (*dto.TokenResponse, error) { +func (s *AuthService) Login(req dto.LoginRequest, clientIP string) (*dto.TokenResponse, error) { user, err := s.userRepo.FindByUsernameOrEmail(req.Login) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { + log.Printf("[SECURITY] login_failed | ip=%s | user=%s | detail=user not found", clientIP, req.Login) return nil, errors.New("用户名或密码错误") } return nil, fmt.Errorf("查询用户失败: %w", err) } if !password.Verify(req.Password, user.PasswordHash) { + log.Printf("[SECURITY] login_failed | ip=%s | user=%s | detail=wrong password", clientIP, req.Login) return nil, errors.New("用户名或密码错误") } if !user.IsActive { + log.Printf("[SECURITY] login_failed | ip=%s | user=%s | detail=account disabled", clientIP, req.Login) return nil, errors.New("账号已禁用") } if user.IsBanned { + log.Printf("[SECURITY] login_failed | ip=%s | user=%s | detail=account banned", clientIP, req.Login) return nil, errors.New("账号已被封禁") } + if user.IsAdmin { + log.Printf("[SECURITY] admin_login | ip=%s | user=%s | detail=admin login successful", clientIP, req.Login) + } + return s.generateTokens(user.ID) } func (s *AuthService) RefreshToken(refreshToken string) (*dto.TokenResponse, error) { claims, err := pkgjwt.ParseToken(refreshToken, s.cfg.JWTSecretKey) if err != nil { - return nil, errors.New("无效的刷新令牌") + return nil, errors.New(errInvalidRefreshToken) } if claims.Type != "refresh" { - return nil, errors.New("无效的刷新令牌") + return nil, errors.New(errInvalidRefreshToken) } user, err := s.userRepo.FindByID(claims.Subject) if err != nil || !user.IsActive || user.IsBanned { - return nil, errors.New("无效的刷新令牌") + return nil, errors.New(errInvalidRefreshToken) } return s.generateTokens(user.ID) @@ -117,7 +131,7 @@ func (s *AuthService) ChangePassword(userID, oldPassword, newPassword string) er hash, err := password.Hash(newPassword) if err != nil { - return fmt.Errorf("密码加密失败: %w", err) + return fmt.Errorf(errPasswordHashFailed, err) } return s.userRepo.UpdatePassword(userID, hash) @@ -147,7 +161,7 @@ func (s *AuthService) ResetPassword(token, newPassword string) error { hash, err := password.Hash(newPassword) if err != nil { - return fmt.Errorf("密码加密失败: %w", err) + return fmt.Errorf(errPasswordHashFailed, err) } return s.userRepo.UpdatePassword(userID, hash) diff --git a/backend/internal/service/chat.go b/backend/internal/service/chat.go index 81496ae0..0b3608bb 100644 --- a/backend/internal/service/chat.go +++ b/backend/internal/service/chat.go @@ -2,6 +2,9 @@ package service import ( "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "log" @@ -217,6 +220,7 @@ type ChatService struct { chatRepo *repository.ChatRepo userRepo *repository.UserRepo firecrawl *firecrawl.Client + jwtSecret string // In-memory storage for guest sessions mu sync.RWMutex guestMemory map[string][]map[string]interface{} @@ -224,12 +228,13 @@ type ChatService struct { guestLastAccess map[string]time.Time } -func NewChatService(client *openai.Client, chatRepo *repository.ChatRepo, userRepo *repository.UserRepo, fc *firecrawl.Client) *ChatService { +func NewChatService(client *openai.Client, chatRepo *repository.ChatRepo, userRepo *repository.UserRepo, fc *firecrawl.Client, jwtSecret string) *ChatService { return &ChatService{ client: client, chatRepo: chatRepo, userRepo: userRepo, firecrawl: fc, + jwtSecret: jwtSecret, guestMemory: make(map[string][]map[string]interface{}), guestProfiles: make(map[string]map[string]interface{}), guestLastAccess: make(map[string]time.Time), @@ -265,49 +270,62 @@ func (s *ChatService) Chat(ctx context.Context, msg dto.UserMessage, userID stri } // appendWebSearchResults appends web search results to the system prompt if available. -func appendWebSearchResults(systemPrompt string, webResults string) string { +func appendWebSearchResults(systemPrompt, webResults string) string { if webResults == "" { return systemPrompt } return systemPrompt + "\n\n## 联网搜索参考\n" + webResults + "\n日常聊天不需要引用来源。仅在提供专业性建议时才引用,引用时直接写出具体来源名称(如「根据XX的一篇文章...」),不要使用[来源1]这样的标注。不确定的信息请标明。" } -func (s *ChatService) chatAuthenticated(ctx context.Context, msg dto.UserMessage, userID string) (*dto.VisualResponse, error) { - // Look up user role and partner info - role := model.RoleMom - isAdmin := false - var partnerID string - var partnerRole model.UserRole - if user, err := s.userRepo.FindByID(userID); err == nil { - role = user.Role - isAdmin = user.IsAdmin - if user.PartnerID != nil && *user.PartnerID != "" { - partnerID = *user.PartnerID - if partner, err := s.userRepo.FindByID(partnerID); err == nil { - partnerRole = partner.Role - } +// chatUserContext holds resolved user context for authenticated chat. +type chatUserContext struct { + role model.UserRole + isAdmin bool + partnerID string + partnerRole model.UserRole +} + +// resolveChatUserContext looks up the user and partner information for chat. +func (s *ChatService) resolveChatUserContext(userID string) chatUserContext { + ctx := chatUserContext{role: model.RoleMom} + user, err := s.userRepo.FindByID(userID) + if err != nil { + return ctx + } + ctx.role = user.Role + ctx.isAdmin = user.IsAdmin + if user.PartnerID != nil && *user.PartnerID != "" { + ctx.partnerID = *user.PartnerID + if partner, pErr := s.userRepo.FindByID(ctx.partnerID); pErr == nil { + ctx.partnerRole = partner.Role } } - pronoun := pronounFor(role) + return ctx +} + +func (s *ChatService) chatAuthenticated(ctx context.Context, msg dto.UserMessage, userID string) (*dto.VisualResponse, error) { + // Look up user role and partner info + uc := s.resolveChatUserContext(userID) + pronoun := pronounFor(uc.role) familyIDs := []string{userID} - if partnerID != "" { - familyIDs = append(familyIDs, partnerID) + if uc.partnerID != "" { + familyIDs = append(familyIDs, uc.partnerID) } // Load memory from DB (per-user, not shared) profile, turns, summary := s.loadUserMemory(userID) // Load structured facts for prompt (family-scoped) - factsText, deletedFactsText := s.loadFactsForPrompt(userID, familyIDs, role, partnerRole) + factsText, deletedFactsText := s.loadFactsForPrompt(userID, familyIDs, uc.role, uc.partnerRole) - systemPrompt := fmt.Sprintf(getCompanionPrompt(role, isAdmin), + systemPrompt := fmt.Sprintf(getCompanionPrompt(uc.role, uc.isAdmin), formatProfile(profile, pronoun, factsText), formatTurns(turns, summary, pronoun), ) // Update memory section header for family mode - if partnerID != "" { + if uc.partnerID != "" { for _, old := range []string{"你记得关于她的重要信息", "你记得关于他的重要信息", "你记得关于对方的重要信息"} { systemPrompt = strings.Replace(systemPrompt, old, "你记得关于这个家庭的重要信息", 1) } @@ -394,8 +412,16 @@ func (s *ChatService) chatGuest(ctx context.Context, msg dto.UserMessage) (*dto. if msg.SessionID != nil { sessionID = *msg.SessionID } + + // Validate HMAC signature if session_id is provided + if sessionID != "" { + if !s.verifySessionID(sessionID) { + // Invalid signature; generate a new signed session + sessionID = "" + } + } if sessionID == "" { - sessionID = uuid.New().String() + sessionID = s.signSessionID(uuid.New().String()) } s.mu.Lock() @@ -425,7 +451,7 @@ func (s *ChatService) chatGuest(ctx context.Context, msg dto.UserMessage) (*dto. {Role: "user", Content: msg.Content}, } - rawContent, err := s.client.Chat(context.Background(), messages) + rawContent, err := s.client.Chat(ctx, messages) if err != nil { return nil, fmt.Errorf("AI 服务调用失败: %w", err) } @@ -449,6 +475,31 @@ func (s *ChatService) chatGuest(ctx context.Context, msg dto.UserMessage) (*dto. return buildVisualResponse(parsed, memoryUpdated), nil } +// --- Guest session HMAC signing --- + +// signSessionID creates a signed session ID in the format "uuid:hmac_hex". +func (s *ChatService) signSessionID(id string) string { + mac := hmac.New(sha256.New, []byte(s.jwtSecret)) + mac.Write([]byte(id)) + sig := hex.EncodeToString(mac.Sum(nil)) + return id + ":" + sig +} + +// verifySessionID validates that a session ID has a valid HMAC signature. +func (s *ChatService) verifySessionID(sessionID string) bool { + idx := strings.LastIndex(sessionID, ":") + if idx < 0 || idx == len(sessionID)-1 { + return false + } + id := sessionID[:idx] + sig := sessionID[idx+1:] + + mac := hmac.New(sha256.New, []byte(s.jwtSecret)) + mac.Write([]byte(id)) + expected := hex.EncodeToString(mac.Sum(nil)) + return hmac.Equal([]byte(sig), []byte(expected)) +} + // --- Phase 2: Conversation Summary --- func (s *ChatService) generateAndSaveSummary(userID string, existingSummary string, oldTurns []map[string]interface{}) { @@ -499,6 +550,45 @@ func (s *ChatService) generateAndSaveSummary(userID string, existingSummary stri // --- Phase 3: Structured Memory Facts --- +// parseFactItem extracts content and category from a single fact item. +func parseFactItem(v interface{}) (content, aiCategory string) { + switch item := v.(type) { + case map[string]interface{}: + c, _ := item["content"].(string) + content = strings.TrimSpace(c) + cat, _ := item["category"].(string) + aiCategory = strings.TrimSpace(cat) + case string: + content = strings.TrimSpace(item) + } + return content, aiCategory +} + +// saveSingleFact creates a fact if it does not already exist. Returns true if saved. +func (s *ChatService) saveSingleFact(userID, content, aiCategory string) bool { + exists, err := s.chatRepo.FactExistsByContent(userID, content) + if err != nil { + log.Printf("[ChatService] failed to check fact existence: %v", err) + return false + } + if exists { + return false + } + + category := resolveFactCategory(aiCategory, content) + fact := &model.ChatMemoryFact{ + UserID: userID, + OwnerUserID: userID, + Content: content, + Category: category, + } + if err := s.chatRepo.CreateFact(fact); err != nil { + log.Printf("[ChatService] failed to save fact for user %s: %v", userID, err) + return false + } + return true +} + func (s *ChatService) saveFactsFromExtract(userID string, extract interface{}) bool { if extract == nil { return false @@ -514,47 +604,11 @@ func (s *ChatService) saveFactsFromExtract(userID string, extract interface{}) b saved := false for _, v := range facts { - var content string - var aiCategory string - - switch item := v.(type) { - case map[string]interface{}: - // New structured format: {"content": "...", "category": "..."} - c, _ := item["content"].(string) - content = strings.TrimSpace(c) - cat, _ := item["category"].(string) - aiCategory = strings.TrimSpace(cat) - case string: - // Legacy string format - content = strings.TrimSpace(item) - default: - continue - } - + content, aiCategory := parseFactItem(v) if content == "" { continue } - - // Skip if identical fact already exists for this user (including soft-deleted) - exists, err := s.chatRepo.FactExistsByContent(userID, content) - if err != nil { - log.Printf("[ChatService] failed to check fact existence: %v", err) - continue - } - if exists { - continue - } - - category := resolveFactCategory(aiCategory, content) - fact := &model.ChatMemoryFact{ - UserID: userID, - OwnerUserID: userID, - Content: content, - Category: category, - } - if err := s.chatRepo.CreateFact(fact); err != nil { - log.Printf("[ChatService] failed to save fact for user %s: %v", userID, err) - } else { + if s.saveSingleFact(userID, content, aiCategory) { saved = true } } @@ -626,7 +680,7 @@ func (s *ChatService) processMemoryCorrections(familyIDs []string, extract inter } // resolveFactCategory uses the AI-provided category if valid, otherwise falls back to keyword detection. -func resolveFactCategory(aiCategory string, content string) model.FactCategory { +func resolveFactCategory(aiCategory, content string) model.FactCategory { switch model.FactCategory(aiCategory) { case model.FactCategoryPersonalInfo, model.FactCategoryFamily, model.FactCategoryInterest, model.FactCategoryConcern, @@ -677,7 +731,18 @@ func categorizeFactContent(content string) model.FactCategory { return model.FactCategoryOther } -func (s *ChatService) loadFactsForPrompt(userID string, familyIDs []string, userRole model.UserRole, partnerRole model.UserRole) (string, string) { +// factLabel returns the display label for a fact in family mode. +func factLabel(f model.ChatMemoryFact, userID string, userRole, partnerRole model.UserRole) string { + if f.Category == model.FactCategoryFamily { + return "家庭" + } + if f.OwnerUserID == userID { + return pronounFor(userRole) + } + return pronounFor(partnerRole) +} + +func (s *ChatService) loadFactsForPrompt(userID string, familyIDs []string, userRole, partnerRole model.UserRole) (string, string) { hasPartner := len(familyIDs) > 1 facts, err := s.chatRepo.FindFactsByFamilyIDs(familyIDs) @@ -690,14 +755,7 @@ func (s *ChatService) loadFactsForPrompt(userID string, familyIDs []string, user var sb strings.Builder for _, f := range facts { if hasPartner { - var label string - if f.Category == model.FactCategoryFamily { - label = "家庭" - } else if f.OwnerUserID == userID { - label = pronounFor(userRole) - } else { - label = pronounFor(partnerRole) - } + label := factLabel(f, userID, userRole, partnerRole) fmt.Fprintf(&sb, " · [%s] %s\n", label, f.Content) } else { fmt.Fprintf(&sb, " · %s\n", f.Content) @@ -822,6 +880,10 @@ func (s *ChatService) searchWebForChat(ctx context.Context, userMessage string) } var sb strings.Builder for _, r := range results { + // Validate URL scheme before including in output + if !strings.HasPrefix(r.URL, "https://") && !strings.HasPrefix(r.URL, "http://") { + continue + } content := r.Markdown if content == "" { content = r.Description diff --git a/backend/internal/service/community.go b/backend/internal/service/community.go index 4c329e14..c23d9ebb 100644 --- a/backend/internal/service/community.go +++ b/backend/internal/service/community.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "html" "time" "github.com/momshell/backend/internal/dto" @@ -12,6 +13,11 @@ import ( "gorm.io/gorm" ) +const ( + errExpertPostSourceRequired = "专家帖必须标注来源依据" + errContentModerationFailed = "内容审核未通过: %s" +) + type CommunityService struct { questionRepo *repository.QuestionRepo answerRepo *repository.AnswerRepo @@ -251,6 +257,10 @@ func (s *CommunityService) GetQuestion(questionID, currentUserID string) (*dto.Q } func (s *CommunityService) CreateQuestion(req dto.QuestionCreate, author *model.User) (*model.Question, error) { + // Sanitize user content + req.Title = sanitizeHTML(req.Title) + req.Content = sanitizeHTML(req.Content) + // Content moderation titleDecision := s.moderation.ModerateText(req.Title) if titleDecision.Result == model.ModerationRejected { @@ -259,7 +269,7 @@ func (s *CommunityService) CreateQuestion(req dto.QuestionCreate, author *model. contentDecision := s.moderation.ModerateText(req.Content) if contentDecision.Result == model.ModerationRejected { - return nil, fmt.Errorf("内容审核未通过: %s", derefStr(contentDecision.Reason)) + return nil, fmt.Errorf(errContentModerationFailed, derefStr(contentDecision.Reason)) } // Determine status @@ -311,6 +321,8 @@ func (s *CommunityService) moderateAndUpdateTextField(q *model.Question, newText if newText == nil { return nil } + sanitized := sanitizeHTML(*newText) + newText = &sanitized decision := s.moderation.ModerateText(*newText) if decision.Result == model.ModerationRejected { return fmt.Errorf("%s审核未通过: %s", fieldName, derefStr(decision.Reason)) @@ -444,6 +456,9 @@ func (s *CommunityService) GetAnswers(questionID string, params dto.AnswerListPa } func (s *CommunityService) CreateAnswer(questionID string, req dto.AnswerCreate, author *model.User) (*model.Answer, error) { + // Sanitize user content + req.Content = sanitizeHTML(req.Content) + // Check question exists _, err := s.questionRepo.FindByID(questionID) if err != nil { @@ -456,14 +471,14 @@ func (s *CommunityService) CreateAnswer(questionID string, req dto.AnswerCreate, return nil, errors.New("仅认证专业人士可发布专家帖") } if req.Sources == "" { - return nil, errors.New("专家帖必须标注来源依据") + return nil, errors.New(errExpertPostSourceRequired) } } // Content moderation decision := s.moderation.ModerateText(req.Content) if decision.Result == model.ModerationRejected { - return nil, fmt.Errorf("内容审核未通过: %s", derefStr(decision.Reason)) + return nil, fmt.Errorf(errContentModerationFailed, derefStr(decision.Reason)) } status := model.StatusPublished @@ -504,44 +519,45 @@ func (s *CommunityService) CreateAnswer(questionID string, req dto.AnswerCreate, return answer, nil } -func (s *CommunityService) UpdateAnswer(answerID string, req dto.AnswerUpdate, user *model.User) (*model.Answer, error) { - a, err := s.answerRepo.FindByID(answerID) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, errors.New("回答不存在") - } - return nil, err +// updateAnswerContent moderates and applies content update to an answer. +func (s *CommunityService) updateAnswerContent(a *model.Answer, content *string) error { + if content == nil { + return nil } - - if a.AuthorID != user.ID && !user.IsAdmin { - return nil, errors.New("无权修改此回答") + sanitized := sanitizeHTML(*content) + content = &sanitized + decision := s.moderation.ModerateText(*content) + if decision.Result == model.ModerationRejected { + return fmt.Errorf(errContentModerationFailed, derefStr(decision.Reason)) + } + a.Content = *content + if decision.Result == model.ModerationNeedManualReview { + a.Status = model.StatusPendingReview } + return nil +} - if req.Content != nil { - decision := s.moderation.ModerateText(*req.Content) - if decision.Result == model.ModerationRejected { - return nil, fmt.Errorf("内容审核未通过: %s", derefStr(decision.Reason)) - } - a.Content = *req.Content - if decision.Result == model.ModerationNeedManualReview { - a.Status = model.StatusPendingReview - } +// resolveExpertPostSources determines the sources value for an expert post update. +func resolveExpertPostSources(reqSources *string, existingSources *string) string { + if reqSources != nil { + return *reqSources + } + if existingSources != nil { + return *existingSources } + return "" +} - // Update expert post fields +// updateExpertPostFields handles the expert post flag and sources fields on answer update. +func (s *CommunityService) updateExpertPostFields(a *model.Answer, req dto.AnswerUpdate, user *model.User) error { if req.IsExpertPost != nil { if *req.IsExpertPost { if !s.IsCertifiedProfessional(user) && !user.IsAdmin { - return nil, errors.New("仅认证专业人士可发布专家帖") - } - sources := "" - if req.Sources != nil { - sources = *req.Sources - } else if a.Sources != nil { - sources = *a.Sources + return errors.New("仅认证专业人士可发布专家帖") } + sources := resolveExpertPostSources(req.Sources, a.Sources) if sources == "" { - return nil, errors.New("专家帖必须标注来源依据") + return errors.New(errExpertPostSourceRequired) } a.IsExpertPost = true a.Sources = &sources @@ -551,10 +567,33 @@ func (s *CommunityService) UpdateAnswer(answerID string, req dto.AnswerUpdate, u } if req.Sources != nil && (req.IsExpertPost == nil || !*req.IsExpertPost) { if a.IsExpertPost && *req.Sources == "" { - return nil, errors.New("专家帖必须标注来源依据") + return errors.New(errExpertPostSourceRequired) } a.Sources = req.Sources } + return nil +} + +func (s *CommunityService) UpdateAnswer(answerID string, req dto.AnswerUpdate, user *model.User) (*model.Answer, error) { + a, err := s.answerRepo.FindByID(answerID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("回答不存在") + } + return nil, err + } + + if a.AuthorID != user.ID && !user.IsAdmin { + return nil, errors.New("无权修改此回答") + } + + if err := s.updateAnswerContent(a, req.Content); err != nil { + return nil, err + } + + if err := s.updateExpertPostFields(a, req, user); err != nil { + return nil, err + } a.UpdatedAt = time.Now() if err := s.answerRepo.Update(a); err != nil { @@ -641,6 +680,9 @@ func (s *CommunityService) GetComments(answerID, currentUserID string) ([]dto.Co } func (s *CommunityService) CreateComment(answerID string, req dto.CommentCreate, user *model.User) (*dto.CommentListItem, error) { + // Sanitize user content + req.Content = sanitizeHTML(req.Content) + // Check answer exists _, err := s.answerRepo.FindByID(answerID) if err != nil { @@ -709,6 +751,8 @@ func (s *CommunityService) UpdateComment(commentID string, req dto.CommentUpdate return nil, errors.New("无权修改此评论") } + req.Content = sanitizeHTML(req.Content) + decision := s.moderation.ModerateText(req.Content) if decision.Result == model.ModerationRejected { return nil, fmt.Errorf("评论审核未通过: %s", derefStr(decision.Reason)) @@ -782,6 +826,48 @@ func (s *CommunityService) ToggleLike(userID, targetType, targetID string) (bool return true, count, nil } +// AddLike creates a like only if one doesn't already exist. Returns the +// current state and count without toggling. +func (s *CommunityService) AddLike(userID, targetType, targetID string) (bool, int, error) { + _, err := s.interactionRepo.FindLike(userID, targetType, targetID) + if err == nil { + // Already liked — return current state (idempotent) + count := s.getTargetLikeCount(targetType, targetID) + return true, count, nil + } + + like := &model.Like{ + UserID: userID, + TargetType: targetType, + TargetID: targetID, + } + if err := s.interactionRepo.CreateLike(like); err != nil { + return false, 0, err + } + s.updateTargetLikeCount(targetType, targetID, 1) + count := s.getTargetLikeCount(targetType, targetID) + return true, count, nil +} + +// RemoveLike removes a like only if one exists. Returns the current state +// and count without toggling. Calling DELETE on an already-unliked resource +// is a no-op (idempotent). +func (s *CommunityService) RemoveLike(userID, targetType, targetID string) (bool, int, error) { + _, err := s.interactionRepo.FindLike(userID, targetType, targetID) + if err != nil { + // Not liked — return current state (idempotent) + count := s.getTargetLikeCount(targetType, targetID) + return false, count, nil + } + + if err := s.interactionRepo.DeleteLike(userID, targetType, targetID); err != nil { + return false, 0, err + } + s.updateTargetLikeCount(targetType, targetID, -1) + count := s.getTargetLikeCount(targetType, targetID) + return false, count, nil +} + func (s *CommunityService) updateTargetLikeCount(targetType, targetID string, delta int) { switch targetType { case "question": @@ -964,3 +1050,8 @@ func derefStr(s *string) string { } return *s } + +// sanitizeHTML escapes HTML entities in user-provided content to prevent XSS. +func sanitizeHTML(s string) string { + return html.EscapeString(s) +} diff --git a/backend/internal/service/echo.go b/backend/internal/service/echo.go index 58b7b9a5..2f7b0cff 100644 --- a/backend/internal/service/echo.go +++ b/backend/internal/service/echo.go @@ -15,9 +15,10 @@ import ( ) const ( - memoirDefaultTitle = "一段温柔的回响" - memoirKeyTitle = "title" - memoirKeyContent = "content" + memoirDefaultTitle = "一段温柔的回响" + memoirKeyTitle = "title" + memoirKeyContent = "content" + trimWhitespaceChars = " \t\n\r" ) type EchoService struct { @@ -319,13 +320,13 @@ func parseMemoirLLMResponse(content string) map[string]interface{} { func cleanMemoirText(s string) string { s = strings.TrimSpace(s) // Remove trailing markdown code fences and JSON braces - s = strings.TrimRight(s, " \t\n\r") + s = strings.TrimRight(s, trimWhitespaceChars) for { trimmed := s trimmed = strings.TrimSuffix(trimmed, "```") trimmed = strings.TrimSuffix(trimmed, "}") trimmed = strings.TrimSuffix(trimmed, `"`) - trimmed = strings.TrimRight(trimmed, " \t\n\r") + trimmed = strings.TrimRight(trimmed, trimWhitespaceChars) if trimmed == s { break } @@ -337,7 +338,7 @@ func cleanMemoirText(s string) string { trimmed = strings.TrimPrefix(trimmed, "```json") trimmed = strings.TrimPrefix(trimmed, "```") trimmed = strings.TrimPrefix(trimmed, "{") - trimmed = strings.TrimLeft(trimmed, " \t\n\r") + trimmed = strings.TrimLeft(trimmed, trimWhitespaceChars) if trimmed == s { break } diff --git a/backend/internal/service/task.go b/backend/internal/service/task.go index 5dede137..aae0646b 100644 --- a/backend/internal/service/task.go +++ b/backend/internal/service/task.go @@ -334,7 +334,7 @@ func (s *TaskService) createTasksForUser(user *model.User, userID string, date t } // SetBabyAge sets the baby age stage for the user and immediately regenerates tasks. -func (s *TaskService) SetBabyAge(userID string, ageStage string) error { +func (s *TaskService) SetBabyAge(userID, ageStage string) error { user, err := s.userRepo.FindByID(userID) if err != nil { return errors.New(errUserNotFound) diff --git a/backend/internal/service/task_ai.go b/backend/internal/service/task_ai.go index 11bce66c..fcdfdf88 100644 --- a/backend/internal/service/task_ai.go +++ b/backend/internal/service/task_ai.go @@ -41,6 +41,24 @@ func coupleKey(a, b string) string { return b + "-" + a } +// resolveAgeStageFromChatMemory tries to infer the age stage from chat memory facts. +func resolveAgeStageFromChatMemory(chatRepo *repository.ChatRepo, familyIDs []string) (string, bool) { + facts, err := chatRepo.FindFactsByFamilyIDs(familyIDs) + if err != nil { + return "", false + } + for _, f := range facts { + lower := strings.ToLower(f.Content) + if strings.Contains(lower, "宝宝") || strings.Contains(lower, "baby") || strings.Contains(lower, "孩子") { + stage := inferAgeStageFromMemory(f.Content) + if stage != "" { + return stage, true + } + } + } + return "", false +} + // resolveAgeStage returns the baby age stage for the couple, checking: // 1. The user's own BabyAgeStage // 2. The partner's BabyAgeStage @@ -69,17 +87,8 @@ func resolveAgeStage( if user.PartnerID != nil { familyIDs = append(familyIDs, *user.PartnerID) } - facts, err := chatRepo.FindFactsByFamilyIDs(familyIDs) - if err == nil { - for _, f := range facts { - lower := strings.ToLower(f.Content) - if strings.Contains(lower, "宝宝") || strings.Contains(lower, "baby") || strings.Contains(lower, "孩子") { - stage := inferAgeStageFromMemory(f.Content) - if stage != "" { - return stage, "memory" - } - } - } + if s, found := resolveAgeStageFromChatMemory(chatRepo, familyIDs); found { + return s, "memory" } } diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 6979e1f0..ceb2e561 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -11,7 +11,10 @@ import ( "gorm.io/gorm" ) -const errUserServiceUserNotFound = "用户不存在" +const ( + errUserServiceUserNotFound = "用户不存在" + whereIDEquals = "id = ?" +) type UserService struct { db *gorm.DB @@ -86,6 +89,21 @@ func (s *UserService) GetProfile(userID string) (*dto.UserProfile, error) { return profile, nil } +// validateRoleChange checks if a role update is allowed for the user. +func validateRoleChange(user *model.User, newRoleStr string) error { + newRole := model.UserRole(newRoleStr) + if !model.FamilyRoles[newRole] { + return errors.New("角色只能是: mom, dad") + } + if model.ProfessionalRoles[user.Role] { + return errors.New("认证专业人员不能修改角色") + } + if user.PartnerID != nil { + return errors.New("已绑定伴侣,无法更改身份") + } + return nil +} + // applyProfileFieldUpdates applies the individual field updates from the request to the user model. // Returns an error if any validation fails (e.g. duplicate username/email, invalid role). func (s *UserService) applyProfileFieldUpdates(user *model.User, req dto.UserProfileUpdate) error { @@ -114,17 +132,10 @@ func (s *UserService) applyProfileFieldUpdates(user *model.User, req dto.UserPro } if req.Role != nil { - newRole := model.UserRole(*req.Role) - if !model.FamilyRoles[newRole] { - return errors.New("角色只能是: mom, dad") - } - if model.ProfessionalRoles[user.Role] { - return errors.New("认证专业人员不能修改角色") - } - if user.PartnerID != nil { - return errors.New("已绑定伴侣,无法更改身份") + if err := validateRoleChange(user, *req.Role); err != nil { + return err } - user.Role = newRole + user.Role = model.UserRole(*req.Role) } return nil @@ -178,7 +189,7 @@ func (s *UserService) GenerateShellCode(userID string) (*dto.UserProfile, error) } // BindPartner binds a 守护者 (dad) to a 溯源者 (mom) via shell code. -func (s *UserService) BindPartner(userID string, shellCode string) (*dto.UserProfile, error) { +func (s *UserService) BindPartner(userID, shellCode string) (*dto.UserProfile, error) { user, err := s.userRepo.FindByID(userID) if err != nil { return nil, errors.New(errUserServiceUserNotFound) @@ -211,10 +222,10 @@ func (s *UserService) BindPartner(userID string, shellCode string) (*dto.UserPro // Bind both sides in a transaction err = s.db.Transaction(func(tx *gorm.DB) error { - if err := tx.Model(&model.User{}).Where("id = ?", user.ID).Update("partner_id", partner.ID).Error; err != nil { + if err := tx.Model(&model.User{}).Where(whereIDEquals, user.ID).Update("partner_id", partner.ID).Error; err != nil { return err } - if err := tx.Model(&model.User{}).Where("id = ?", partner.ID).Update("partner_id", user.ID).Error; err != nil { + if err := tx.Model(&model.User{}).Where(whereIDEquals, partner.ID).Update("partner_id", user.ID).Error; err != nil { return err } return nil @@ -241,13 +252,13 @@ func (s *UserService) UnbindPartner(userID string) (*dto.UserProfile, error) { // Unbind both sides, clear shell code err = s.db.Transaction(func(tx *gorm.DB) error { - if err := tx.Model(&model.User{}).Where("id = ?", userID).Updates(map[string]interface{}{ + if err := tx.Model(&model.User{}).Where(whereIDEquals, userID).Updates(map[string]interface{}{ "partner_id": nil, "shell_code": nil, }).Error; err != nil { return err } - if err := tx.Model(&model.User{}).Where("id = ?", partnerID).Updates(map[string]interface{}{ + if err := tx.Model(&model.User{}).Where(whereIDEquals, partnerID).Updates(map[string]interface{}{ "partner_id": nil, "shell_code": nil, }).Error; err != nil { diff --git a/frontend/index.html b/frontend/index.html index 90a40b82..c4400867 100755 --- a/frontend/index.html +++ b/frontend/index.html @@ -6,6 +6,7 @@ + diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 5e7d3317..18362a62 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -15,7 +15,7 @@