diff --git a/.env.example b/.env.example
index 78b0cfa6..b8800fae 100644
--- a/.env.example
+++ b/.env.example
@@ -37,3 +37,9 @@ VITE_API_BASE_URL=http://localhost:8000
ADMIN_USERNAME=
ADMIN_EMAIL=
ADMIN_PASSWORD=
+
+# ==================== AI Service User (optional) ====================
+# Internal service account for AI-generated content. If not set,
+# a random password is generated on each startup (the AI user never logs in).
+AI_USER_USERNAME=xiaoshiguang
+AI_USER_PASSWORD=
diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go
index 69e259f2..3ae29253 100644
--- a/backend/cmd/server/main.go
+++ b/backend/cmd/server/main.go
@@ -1,8 +1,11 @@
package main
import (
+ "crypto/rand"
+ "encoding/hex"
"fmt"
"log"
+ "os"
"github.com/gin-gonic/gin"
"github.com/momshell/backend/internal/config"
@@ -74,7 +77,7 @@ func main() {
firecrawlClient = firecrawl.NewClient(cfg.FirecrawlAPIKey)
}
- chatService := service.NewChatService(chatClient, chatRepo, userRepo, firecrawlClient)
+ chatService := service.NewChatService(chatClient, chatRepo, userRepo, firecrawlClient, cfg.JWTSecretKey)
echoService := service.NewEchoService(chatClient, echoRepo, userRepo)
photoService := service.NewPhotoService(photoRepo, userRepo, chatClient, cfg.ImageModel)
whisperService := service.NewWhisperService(whisperRepo, userRepo, chatClient)
@@ -136,10 +139,21 @@ func main() {
router.Setup(
r, cfg,
adminHandler.IsAdmin,
- authHandler, questionHandler, answerHandler,
- commentHandler, interactionHandler, tagHandler,
- chatHandler, echoHandler, userHandler, adminHandler,
- photoHandler, whisperHandler, taskHandler,
+ &router.Handlers{
+ Auth: authHandler,
+ Question: questionHandler,
+ Answer: answerHandler,
+ Comment: commentHandler,
+ Interaction: interactionHandler,
+ Tag: tagHandler,
+ Chat: chatHandler,
+ Echo: echoHandler,
+ User: userHandler,
+ Admin: adminHandler,
+ Photo: photoHandler,
+ Whisper: whisperHandler,
+ Task: taskHandler,
+ },
)
// Start server
@@ -186,19 +200,35 @@ func createInitialAdmin(cfg *config.Config, userRepo *repository.UserRepo) {
}
func ensureAIUser(userRepo *repository.UserRepo) string {
- user, err := userRepo.FindByUsernameOrEmail("xiaoshiguang")
+ aiUsername := os.Getenv("AI_USER_USERNAME")
+ if aiUsername == "" {
+ aiUsername = "xiaoshiguang"
+ log.Println("[WARN] AI_USER_USERNAME not set, using default: xiaoshiguang")
+ }
+ aiPasswd := os.Getenv("AI_USER_PASSWORD")
+ if aiPasswd == "" {
+ // Generate a random password; AI user never logs in interactively.
+ b := make([]byte, 16)
+ if _, err := rand.Read(b); err != nil {
+ log.Fatalf("failed to generate random AI user password: %v", err)
+ }
+ aiPasswd = hex.EncodeToString(b)
+ log.Println("[WARN] AI_USER_PASSWORD not set, using random password")
+ }
+
+ user, err := userRepo.FindByUsernameOrEmail(aiUsername)
if err == nil {
return user.ID
}
- hash, err := password.Hash("ai-user-no-login")
+ hash, err := password.Hash(aiPasswd)
if err != nil {
log.Printf("Failed to hash AI user password: %v", err)
return ""
}
aiUser := &model.User{
- Username: "xiaoshiguang",
+ Username: aiUsername,
Email: "ai@momshell.com",
PasswordHash: hash,
Nickname: "小石光",
diff --git a/backend/internal/handler/admin.go b/backend/internal/handler/admin.go
index a6bcb24d..b22fe11a 100644
--- a/backend/internal/handler/admin.go
+++ b/backend/internal/handler/admin.go
@@ -1,6 +1,7 @@
package handler
import (
+ "log"
"net/http"
"github.com/gin-gonic/gin"
@@ -142,6 +143,9 @@ func (h *AdminHandler) UpdateConfig(c *gin.Context) {
return
}
+ adminID := middleware.GetUserID(c)
+ log.Printf("[SECURITY] admin_config_change | ip=%s | user=%s | detail=config update requested", c.ClientIP(), adminID)
+
if err := h.adminService.UpdateConfig(req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
diff --git a/backend/internal/handler/auth.go b/backend/internal/handler/auth.go
index e4265636..0eed8b66 100644
--- a/backend/internal/handler/auth.go
+++ b/backend/internal/handler/auth.go
@@ -66,7 +66,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
- resp, err := h.authService.Login(req)
+ resp, err := h.authService.Login(req, c.ClientIP())
if err != nil {
status := http.StatusUnauthorized
if err.Error() == "账号已禁用" || err.Error() == "账号已被封禁" {
@@ -184,8 +184,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}
- // TODO: In production, send email with reset link instead of logging
- _ = token // token would be sent via email
+ // In a production deployment, the token would be sent via email with a reset link.
+ // For now, it is intentionally unused to avoid leaking it in the response.
+ _ = token
c.JSON(http.StatusOK, gin.H{
"message": "如果该邮箱已注册,将收到重置密码邮件",
})
diff --git a/backend/internal/handler/interaction.go b/backend/internal/handler/interaction.go
index 87eaea9c..4d609ec8 100644
--- a/backend/internal/handler/interaction.go
+++ b/backend/internal/handler/interaction.go
@@ -26,7 +26,7 @@ func (h *InteractionHandler) CreateLike(c *gin.Context) {
}
userID := middleware.GetUserID(c)
- isLiked, newCount, err := h.communityService.ToggleLike(userID, req.TargetType, req.TargetID)
+ isLiked, newCount, err := h.communityService.AddLike(userID, req.TargetType, req.TargetID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -44,7 +44,7 @@ func (h *InteractionHandler) DeleteLike(c *gin.Context) {
}
userID := middleware.GetUserID(c)
- isLiked, newCount, err := h.communityService.ToggleLike(userID, req.TargetType, req.TargetID)
+ isLiked, newCount, err := h.communityService.RemoveLike(userID, req.TargetType, req.TargetID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
diff --git a/backend/internal/middleware/auth.go b/backend/internal/middleware/auth.go
index ffda4b5d..55cd5553 100644
--- a/backend/internal/middleware/auth.go
+++ b/backend/internal/middleware/auth.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "log"
"net/http"
"strings"
@@ -82,15 +83,18 @@ func extractUserID(c *gin.Context, cfg *config.Config) (string, error) {
// Check token blacklist (revoked tokens)
if TokenBlacklist.IsBlacklisted(tokenStr) {
+ log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=unknown | detail=token is blacklisted", c.ClientIP())
return "", pkgjwt.ErrInvalidToken
}
claims, err := pkgjwt.ParseToken(tokenStr, cfg.JWTSecretKey)
if err != nil {
+ log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=unknown | detail=%v", c.ClientIP(), err)
return "", err
}
if claims.Type != "access" {
+ log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=%s | detail=invalid token type: %s", c.ClientIP(), claims.Subject, claims.Type)
return "", pkgjwt.ErrInvalidToken
}
diff --git a/backend/internal/middleware/cors.go b/backend/internal/middleware/cors.go
index 4b295647..0edb4d98 100644
--- a/backend/internal/middleware/cors.go
+++ b/backend/internal/middleware/cors.go
@@ -11,23 +11,24 @@ import (
func CORS(cfg *config.Config) gin.HandlerFunc {
corsConfig := cors.Config{
- AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
- AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Access-Token"},
- ExposeHeaders: []string{"Content-Length"},
- AllowCredentials: true,
- MaxAge: 12 * time.Hour,
+ AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
+ AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Access-Token"},
+ ExposeHeaders: []string{"Content-Length"},
+ MaxAge: 12 * time.Hour,
}
trimmed := strings.TrimSpace(cfg.CORSOrigins)
if trimmed == "" || trimmed == "*" {
- // Allow all origins (compatible with AllowCredentials)
+ // Allow all origins but disable credentials to prevent CSRF-like attacks
corsConfig.AllowOriginFunc = func(origin string) bool { return true }
+ corsConfig.AllowCredentials = false
} else {
origins := strings.Split(trimmed, ",")
for i := range origins {
origins[i] = strings.TrimSpace(origins[i])
}
corsConfig.AllowOrigins = origins
+ corsConfig.AllowCredentials = true
}
return cors.New(corsConfig)
diff --git a/backend/internal/middleware/security.go b/backend/internal/middleware/security.go
index c98c4275..058b8861 100644
--- a/backend/internal/middleware/security.go
+++ b/backend/internal/middleware/security.go
@@ -1,6 +1,10 @@
package middleware
-import "github.com/gin-gonic/gin"
+import (
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
// SecurityHeaders adds standard security headers to all responses.
func SecurityHeaders() gin.HandlerFunc {
@@ -10,6 +14,14 @@ func SecurityHeaders() gin.HandlerFunc {
c.Header("X-XSS-Protection", "1; mode=block")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("Permissions-Policy", "camera=(self), microphone=(), geolocation=()")
+ c.Header("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
+
+ // Add cache-control headers for API responses
+ if strings.HasPrefix(c.Request.URL.Path, "/api/") {
+ c.Header("Cache-Control", "no-store")
+ c.Header("Pragma", "no-cache")
+ }
+
c.Next()
}
}
diff --git a/backend/internal/middleware/tokenblacklist.go b/backend/internal/middleware/tokenblacklist.go
index 1aeb2ec1..24689a43 100644
--- a/backend/internal/middleware/tokenblacklist.go
+++ b/backend/internal/middleware/tokenblacklist.go
@@ -1,6 +1,7 @@
package middleware
import (
+ "log"
"sync"
"time"
)
@@ -22,7 +23,7 @@ func (b *tokenBlacklistStore) Add(token string, expiry time.Time) {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.tokens) >= maxBlacklistSize {
- // Emergency cleanup: remove expired tokens immediately
+ // Emergency cleanup: remove expired tokens
now := time.Now()
for t, exp := range b.tokens {
if now.After(exp) {
@@ -30,9 +31,22 @@ func (b *tokenBlacklistStore) Add(token string, expiry time.Time) {
}
}
}
- // If still at capacity after cleanup, skip adding (token will simply not be blacklisted)
+ // If still at capacity after cleanup, evict the oldest entry (FIFO)
if len(b.tokens) >= maxBlacklistSize {
- return
+ log.Printf("[SECURITY] token_blacklist_full | size=%d | evicting oldest entry to make room", len(b.tokens))
+ var oldestToken string
+ var oldestExpiry time.Time
+ first := true
+ for t, exp := range b.tokens {
+ if first || exp.Before(oldestExpiry) {
+ oldestToken = t
+ oldestExpiry = exp
+ first = false
+ }
+ }
+ if oldestToken != "" {
+ delete(b.tokens, oldestToken)
+ }
}
b.tokens[token] = expiry
}
diff --git a/backend/internal/repository/chat.go b/backend/internal/repository/chat.go
index 9b6dbf1e..b413f878 100644
--- a/backend/internal/repository/chat.go
+++ b/backend/internal/repository/chat.go
@@ -38,7 +38,7 @@ func (r *ChatRepo) Create(m *model.ChatMemory) error {
}
// UpdateSummaryAndTurns updates only the summary and turns fields.
-func (r *ChatRepo) UpdateSummaryAndTurns(userID string, summary string, turns string) error {
+func (r *ChatRepo) UpdateSummaryAndTurns(userID, summary, turns string) error {
return r.db.Model(&model.ChatMemory{}).
Where(whereUserID, userID).
Updates(map[string]any{
@@ -139,7 +139,7 @@ func (r *ChatRepo) FactExistsByContentFamily(familyIDs []string, content string)
return count > 0, err
}
-func (r *ChatRepo) DeleteFactsByContentLikeFamily(familyIDs []string, phrases []string) error {
+func (r *ChatRepo) DeleteFactsByContentLikeFamily(familyIDs, phrases []string) error {
for _, phrase := range phrases {
phrase = strings.TrimSpace(phrase)
if phrase == "" {
diff --git a/backend/internal/router/router.go b/backend/internal/router/router.go
index 83cf42d2..bbdb9a28 100644
--- a/backend/internal/router/router.go
+++ b/backend/internal/router/router.go
@@ -1,6 +1,9 @@
package router
import (
+ "net/http"
+ "path/filepath"
+ "strings"
"time"
"github.com/gin-gonic/gin"
@@ -9,23 +12,33 @@ import (
"github.com/momshell/backend/internal/middleware"
)
+const (
+ routeUsers = "/users"
+ routeUserID = "/users/:id"
+)
+
+// Handlers groups all handler dependencies for route setup.
+type Handlers struct {
+ Auth *handler.AuthHandler
+ Question *handler.QuestionHandler
+ Answer *handler.AnswerHandler
+ Comment *handler.CommentHandler
+ Interaction *handler.InteractionHandler
+ Tag *handler.TagHandler
+ Chat *handler.ChatHandler
+ Echo *handler.EchoHandler
+ User *handler.UserHandler
+ Admin *handler.AdminHandler
+ Photo *handler.PhotoHandler
+ Whisper *handler.WhisperHandler
+ Task *handler.TaskHandler
+}
+
func Setup(
r *gin.Engine,
cfg *config.Config,
isAdmin middleware.AdminChecker,
- authHandler *handler.AuthHandler,
- questionHandler *handler.QuestionHandler,
- answerHandler *handler.AnswerHandler,
- commentHandler *handler.CommentHandler,
- interactionHandler *handler.InteractionHandler,
- tagHandler *handler.TagHandler,
- chatHandler *handler.ChatHandler,
- echoHandler *handler.EchoHandler,
- userHandler *handler.UserHandler,
- adminHandler *handler.AdminHandler,
- photoHandler *handler.PhotoHandler,
- whisperHandler *handler.WhisperHandler,
- taskHandler *handler.TaskHandler,
+ h *Handlers,
) {
// Rate limiters
authLimiter := middleware.RateLimit(10, 1*time.Minute) // 10 req/min for auth
@@ -37,30 +50,30 @@ func Setup(
c.JSON(200, gin.H{"status": "ok"})
})
- // Serve uploaded files
- r.Static("/uploads", "./uploads")
+ // Serve uploaded files with security restrictions
+ r.GET("/uploads/*filepath", secureStaticHandler("./uploads"))
// Admin panel (HTML page, no auth required for serving the page)
- r.GET("/admin", adminHandler.ServeAdminPage)
+ r.GET("/admin", h.Admin.ServeAdminPage)
api := r.Group("/api/v1", generalLimiter)
// ==================== Auth ====================
auth := api.Group("/auth")
{
- auth.POST("/register", authLimiter, authHandler.Register)
- auth.POST("/login", authLimiter, authHandler.Login)
- auth.POST("/refresh", authLimiter, authHandler.Refresh)
- auth.POST("/forgot-password", authLimiter, authHandler.ForgotPassword)
- auth.POST("/reset-password", authLimiter, authHandler.ResetPassword)
+ auth.POST("/register", authLimiter, h.Auth.Register)
+ auth.POST("/login", authLimiter, h.Auth.Login)
+ auth.POST("/refresh", authLimiter, h.Auth.Refresh)
+ auth.POST("/forgot-password", authLimiter, h.Auth.ForgotPassword)
+ auth.POST("/reset-password", authLimiter, h.Auth.ResetPassword)
authRequired := auth.Group("", middleware.AuthRequired(cfg))
{
- authRequired.POST("/change-password", authHandler.ChangePassword)
- authRequired.GET("/me", authHandler.GetMe)
- authRequired.PATCH("/me/role", authHandler.UpdateRole)
- authRequired.PATCH("/me/tutorial", authHandler.CompleteTutorial)
- authRequired.POST("/logout", authHandler.Logout)
+ authRequired.POST("/change-password", h.Auth.ChangePassword)
+ authRequired.GET("/me", h.Auth.GetMe)
+ authRequired.PATCH("/me/role", h.Auth.UpdateRole)
+ authRequired.PATCH("/me/tutorial", h.Auth.CompleteTutorial)
+ authRequired.POST("/logout", h.Auth.Logout)
}
}
@@ -70,145 +83,199 @@ func Setup(
// Questions (optional auth for read, required for write)
questions := community.Group("/questions")
{
- questions.GET("", middleware.AuthOptional(cfg), questionHandler.List)
- questions.GET("/hot", middleware.AuthOptional(cfg), questionHandler.ListHot)
- questions.GET("/channel/:channel", middleware.AuthOptional(cfg), questionHandler.ListByChannel)
- questions.GET("/:id", middleware.AuthOptional(cfg), questionHandler.Get)
+ questions.GET("", middleware.AuthOptional(cfg), h.Question.List)
+ questions.GET("/hot", middleware.AuthOptional(cfg), h.Question.ListHot)
+ questions.GET("/channel/:channel", middleware.AuthOptional(cfg), h.Question.ListByChannel)
+ questions.GET("/:id", middleware.AuthOptional(cfg), h.Question.Get)
questionsAuth := questions.Group("", middleware.AuthRequired(cfg))
{
- questionsAuth.POST("", questionHandler.Create)
- questionsAuth.PUT("/:id", questionHandler.Update)
- questionsAuth.DELETE("/:id", questionHandler.Delete)
+ questionsAuth.POST("", h.Question.Create)
+ questionsAuth.PUT("/:id", h.Question.Update)
+ questionsAuth.DELETE("/:id", h.Question.Delete)
}
// Answers under questions
- questions.GET("/:id/answers", middleware.AuthOptional(cfg), answerHandler.List)
- questions.POST("/:id/answers", middleware.AuthRequired(cfg), answerHandler.Create)
+ questions.GET("/:id/answers", middleware.AuthOptional(cfg), h.Answer.List)
+ questions.POST("/:id/answers", middleware.AuthRequired(cfg), h.Answer.Create)
}
// Answers (update/delete by answer ID)
answers := community.Group("/answers")
{
- answers.PUT("/:id", middleware.AuthRequired(cfg), answerHandler.Update)
- answers.DELETE("/:id", middleware.AuthRequired(cfg), answerHandler.Delete)
+ answers.PUT("/:id", middleware.AuthRequired(cfg), h.Answer.Update)
+ answers.DELETE("/:id", middleware.AuthRequired(cfg), h.Answer.Delete)
// Comments under answers
- answers.GET("/:id/comments", middleware.AuthOptional(cfg), commentHandler.List)
- answers.POST("/:id/comments", middleware.AuthRequired(cfg), commentHandler.Create)
+ answers.GET("/:id/comments", middleware.AuthOptional(cfg), h.Comment.List)
+ answers.POST("/:id/comments", middleware.AuthRequired(cfg), h.Comment.Create)
}
// Comments (update/delete by comment ID)
comments := community.Group("/comments")
{
- comments.PUT("/:id", middleware.AuthRequired(cfg), commentHandler.Update)
- comments.DELETE("/:id", middleware.AuthRequired(cfg), commentHandler.Delete)
+ comments.PUT("/:id", middleware.AuthRequired(cfg), h.Comment.Update)
+ comments.DELETE("/:id", middleware.AuthRequired(cfg), h.Comment.Delete)
}
// Likes
likes := community.Group("/likes", middleware.AuthRequired(cfg))
{
- likes.POST("", interactionHandler.CreateLike)
- likes.DELETE("", interactionHandler.DeleteLike)
+ likes.POST("", h.Interaction.CreateLike)
+ likes.DELETE("", h.Interaction.DeleteLike)
}
// Collections
collections := community.Group("/collections", middleware.AuthRequired(cfg))
{
- collections.POST("", interactionHandler.CreateCollection)
- collections.DELETE("/:id", interactionHandler.DeleteCollection)
- collections.GET("/my", interactionHandler.GetMyCollections)
+ collections.POST("", h.Interaction.CreateCollection)
+ collections.DELETE("/:id", h.Interaction.DeleteCollection)
+ collections.GET("/my", h.Interaction.GetMyCollections)
}
// Tags
tags := community.Group("/tags")
{
- tags.GET("", tagHandler.List)
- tags.GET("/hot", tagHandler.ListHot)
- tags.POST("", middleware.AdminRequired(cfg, isAdmin), tagHandler.Create)
+ tags.GET("", h.Tag.List)
+ tags.GET("/hot", h.Tag.ListHot)
+ tags.POST("", middleware.AdminRequired(cfg, isAdmin), h.Tag.Create)
}
// User profile (community context)
- users := community.Group("/users", middleware.AuthRequired(cfg))
+ users := community.Group(routeUsers, middleware.AuthRequired(cfg))
{
- users.GET("/me", userHandler.GetMe)
- users.PUT("/me", userHandler.UpdateMe)
- users.POST("/me/avatar", userHandler.UploadAvatar)
- users.POST("/me/shell-code", userHandler.GenerateShellCode)
- users.POST("/me/bind", userHandler.BindPartner)
- users.DELETE("/me/bind", userHandler.UnbindPartner)
- users.GET("/me/questions", userHandler.GetMyQuestions)
- users.GET("/me/answers", userHandler.GetMyAnswers)
+ users.GET("/me", h.User.GetMe)
+ users.PUT("/me", h.User.UpdateMe)
+ users.POST("/me/avatar", h.User.UploadAvatar)
+ users.POST("/me/shell-code", h.User.GenerateShellCode)
+ users.POST("/me/bind", h.User.BindPartner)
+ users.DELETE("/me/bind", h.User.UnbindPartner)
+ users.GET("/me/questions", h.User.GetMyQuestions)
+ users.GET("/me/answers", h.User.GetMyAnswers)
}
}
// ==================== Companion (AI Chat) ====================
companion := api.Group("/companion")
{
- companion.POST("/chat", aiLimiter, middleware.AuthOptional(cfg), chatHandler.Chat)
- companion.GET("/profile", middleware.AuthOptional(cfg), chatHandler.GetProfile)
- companion.GET("/memories", middleware.AuthRequired(cfg), chatHandler.GetMemories)
- companion.DELETE("/memories/:id", middleware.AuthRequired(cfg), chatHandler.DeleteMemory)
- companion.GET("/history", middleware.AuthRequired(cfg), chatHandler.GetHistory)
- companion.DELETE("/history", middleware.AuthRequired(cfg), chatHandler.ClearHistory)
+ companion.POST("/chat", aiLimiter, middleware.AuthOptional(cfg), h.Chat.Chat)
+ companion.GET("/profile", middleware.AuthOptional(cfg), h.Chat.GetProfile)
+ companion.GET("/memories", middleware.AuthRequired(cfg), h.Chat.GetMemories)
+ companion.DELETE("/memories/:id", middleware.AuthRequired(cfg), h.Chat.DeleteMemory)
+ companion.GET("/history", middleware.AuthRequired(cfg), h.Chat.GetHistory)
+ companion.DELETE("/history", middleware.AuthRequired(cfg), h.Chat.ClearHistory)
}
echo := api.Group("/echo", middleware.AuthRequired(cfg))
{
- echo.GET("/identity-tags", echoHandler.GetIdentityTags)
- echo.POST("/identity-tags", echoHandler.CreateIdentityTag)
- echo.DELETE("/identity-tags/:id", echoHandler.DeleteIdentityTag)
+ echo.GET("/identity-tags", h.Echo.GetIdentityTags)
+ echo.POST("/identity-tags", h.Echo.CreateIdentityTag)
+ echo.DELETE("/identity-tags/:id", h.Echo.DeleteIdentityTag)
- echo.GET("/memoirs", echoHandler.GetMemoirs)
- echo.POST("/memoirs/generate", aiLimiter, echoHandler.GenerateMemoir)
- echo.POST("/memoirs/:id/rate", echoHandler.RateMemoir)
+ echo.GET("/memoirs", h.Echo.GetMemoirs)
+ echo.POST("/memoirs/generate", aiLimiter, h.Echo.GenerateMemoir)
+ echo.POST("/memoirs/:id/rate", h.Echo.RateMemoir)
}
// ==================== Photos ====================
photos := api.Group("/photos", middleware.AuthRequired(cfg))
{
- photos.GET("", photoHandler.List)
- photos.POST("/upload", photoHandler.Upload)
- photos.POST("/generate", aiLimiter, photoHandler.Generate)
- photos.PUT("/wall", photoHandler.BatchUpdateWall)
- photos.PUT("/:id", photoHandler.Update)
- photos.DELETE("/:id", photoHandler.Delete)
- photos.PUT("/:id/wall", photoHandler.ToggleWall)
+ photos.GET("", h.Photo.List)
+ photos.POST("/upload", h.Photo.Upload)
+ photos.POST("/generate", aiLimiter, h.Photo.Generate)
+ photos.PUT("/wall", h.Photo.BatchUpdateWall)
+ photos.PUT("/:id", h.Photo.Update)
+ photos.DELETE("/:id", h.Photo.Delete)
+ photos.PUT("/:id/wall", h.Photo.ToggleWall)
}
// ==================== Whisper (Heart Words) ====================
whisper := api.Group("/whisper", middleware.AuthRequired(cfg))
{
- whisper.POST("", whisperHandler.Create)
- whisper.GET("", whisperHandler.List)
- whisper.GET("/tips", aiLimiter, whisperHandler.Tips)
+ whisper.POST("", h.Whisper.Create)
+ whisper.GET("", h.Whisper.List)
+ whisper.GET("/tips", aiLimiter, h.Whisper.Tips)
}
// ==================== Tasks ====================
tasks := api.Group("/tasks", middleware.AuthRequired(cfg))
{
- tasks.GET("/daily", taskHandler.DailyTasks)
- tasks.POST("/:id/complete", taskHandler.Complete)
- tasks.GET("/partner", taskHandler.PartnerTasks)
- tasks.POST("/:id/score", taskHandler.Score)
- tasks.POST("/:id/reject", taskHandler.Reject)
- tasks.GET("/stats", taskHandler.Stats)
- tasks.GET("/baby-age", taskHandler.GetBabyAge)
- tasks.PUT("/baby-age", taskHandler.SetBabyAge)
+ tasks.GET("/daily", h.Task.DailyTasks)
+ tasks.POST("/:id/complete", h.Task.Complete)
+ tasks.GET("/partner", h.Task.PartnerTasks)
+ tasks.POST("/:id/score", h.Task.Score)
+ tasks.POST("/:id/reject", h.Task.Reject)
+ tasks.GET("/stats", h.Task.Stats)
+ tasks.GET("/baby-age", h.Task.GetBabyAge)
+ tasks.PUT("/baby-age", h.Task.SetBabyAge)
}
// ==================== Admin ====================
adminAPI := api.Group("/admin", middleware.AdminRequired(cfg, isAdmin))
{
- adminAPI.GET("/stats", adminHandler.GetStats)
- adminAPI.GET("/users", adminHandler.ListUsers)
- adminAPI.GET("/users/:id", adminHandler.GetUser)
- adminAPI.POST("/users", adminHandler.CreateUser)
- adminAPI.PATCH("/users/:id", adminHandler.UpdateUser)
- adminAPI.DELETE("/users/:id", adminHandler.DeleteUser)
- adminAPI.GET("/config", adminHandler.GetConfig)
- adminAPI.PATCH("/config", adminHandler.UpdateConfig)
- adminAPI.GET("/photos", adminHandler.ListPhotos)
- adminAPI.DELETE("/photos/:id", adminHandler.DeletePhoto)
+ adminAPI.GET("/stats", h.Admin.GetStats)
+ adminAPI.GET(routeUsers, h.Admin.ListUsers)
+ adminAPI.GET(routeUserID, h.Admin.GetUser)
+ adminAPI.POST(routeUsers, h.Admin.CreateUser)
+ adminAPI.PATCH(routeUserID, h.Admin.UpdateUser)
+ adminAPI.DELETE(routeUserID, h.Admin.DeleteUser)
+ adminAPI.GET("/config", h.Admin.GetConfig)
+ adminAPI.PATCH("/config", h.Admin.UpdateConfig)
+ adminAPI.GET("/photos", h.Admin.ListPhotos)
+ adminAPI.DELETE("/photos/:id", h.Admin.DeletePhoto)
+ }
+}
+
+// allowedImageExts defines image extensions that are safe to serve inline.
+var allowedImageExts = map[string]bool{
+ ".jpg": true,
+ ".jpeg": true,
+ ".png": true,
+ ".gif": true,
+ ".webp": true,
+}
+
+// blockedExts defines file extensions that must never be served.
+var blockedExts = map[string]bool{
+ ".svg": true,
+ ".html": true,
+ ".htm": true,
+ ".xml": true,
+ ".js": true,
+ ".css": true,
+}
+
+// secureStaticHandler returns a handler that serves files from the given root
+// directory with security headers and file type restrictions.
+func secureStaticHandler(root string) gin.HandlerFunc {
+ fs := http.Dir(root)
+ fileServer := http.StripPrefix("/uploads", http.FileServer(fs))
+
+ return func(c *gin.Context) {
+ // Clean the path to prevent traversal attacks
+ reqPath := filepath.Clean(c.Param("filepath"))
+
+ // Ensure the cleaned path does not escape the root
+ if strings.Contains(reqPath, "..") {
+ c.AbortWithStatus(http.StatusForbidden)
+ return
+ }
+
+ ext := strings.ToLower(filepath.Ext(reqPath))
+
+ // Block dangerous file types
+ if blockedExts[ext] {
+ c.AbortWithStatus(http.StatusForbidden)
+ return
+ }
+
+ // Set security headers
+ c.Header("X-Content-Type-Options", "nosniff")
+
+ if allowedImageExts[ext] {
+ c.Header("Content-Disposition", "inline")
+ }
+
+ fileServer.ServeHTTP(c.Writer, c.Request)
}
}
diff --git a/backend/internal/service/auth.go b/backend/internal/service/auth.go
index fd4870f2..a3a7ba2a 100644
--- a/backend/internal/service/auth.go
+++ b/backend/internal/service/auth.go
@@ -3,6 +3,7 @@ package service
import (
"errors"
"fmt"
+ "log"
"github.com/momshell/backend/internal/config"
"github.com/momshell/backend/internal/dto"
@@ -13,6 +14,11 @@ import (
"gorm.io/gorm"
)
+const (
+ errPasswordHashFailed = "密码加密失败: %w"
+ errInvalidRefreshToken = "无效的刷新令牌"
+)
+
type AuthService struct {
cfg *config.Config
userRepo *repository.UserRepo
@@ -35,7 +41,7 @@ func (s *AuthService) Register(req dto.RegisterRequest) (*dto.UserResponse, erro
// Hash password
hash, err := password.Hash(req.Password)
if err != nil {
- return nil, fmt.Errorf("密码加密失败: %w", err)
+ return nil, fmt.Errorf(errPasswordHashFailed, err)
}
user := &model.User{
@@ -56,42 +62,50 @@ func (s *AuthService) Register(req dto.RegisterRequest) (*dto.UserResponse, erro
return s.buildUserResponse(user), nil
}
-func (s *AuthService) Login(req dto.LoginRequest) (*dto.TokenResponse, error) {
+func (s *AuthService) Login(req dto.LoginRequest, clientIP string) (*dto.TokenResponse, error) {
user, err := s.userRepo.FindByUsernameOrEmail(req.Login)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
+ log.Printf("[SECURITY] login_failed | ip=%s | user=%s | detail=user not found", clientIP, req.Login)
return nil, errors.New("用户名或密码错误")
}
return nil, fmt.Errorf("查询用户失败: %w", err)
}
if !password.Verify(req.Password, user.PasswordHash) {
+ log.Printf("[SECURITY] login_failed | ip=%s | user=%s | detail=wrong password", clientIP, req.Login)
return nil, errors.New("用户名或密码错误")
}
if !user.IsActive {
+ log.Printf("[SECURITY] login_failed | ip=%s | user=%s | detail=account disabled", clientIP, req.Login)
return nil, errors.New("账号已禁用")
}
if user.IsBanned {
+ log.Printf("[SECURITY] login_failed | ip=%s | user=%s | detail=account banned", clientIP, req.Login)
return nil, errors.New("账号已被封禁")
}
+ if user.IsAdmin {
+ log.Printf("[SECURITY] admin_login | ip=%s | user=%s | detail=admin login successful", clientIP, req.Login)
+ }
+
return s.generateTokens(user.ID)
}
func (s *AuthService) RefreshToken(refreshToken string) (*dto.TokenResponse, error) {
claims, err := pkgjwt.ParseToken(refreshToken, s.cfg.JWTSecretKey)
if err != nil {
- return nil, errors.New("无效的刷新令牌")
+ return nil, errors.New(errInvalidRefreshToken)
}
if claims.Type != "refresh" {
- return nil, errors.New("无效的刷新令牌")
+ return nil, errors.New(errInvalidRefreshToken)
}
user, err := s.userRepo.FindByID(claims.Subject)
if err != nil || !user.IsActive || user.IsBanned {
- return nil, errors.New("无效的刷新令牌")
+ return nil, errors.New(errInvalidRefreshToken)
}
return s.generateTokens(user.ID)
@@ -117,7 +131,7 @@ func (s *AuthService) ChangePassword(userID, oldPassword, newPassword string) er
hash, err := password.Hash(newPassword)
if err != nil {
- return fmt.Errorf("密码加密失败: %w", err)
+ return fmt.Errorf(errPasswordHashFailed, err)
}
return s.userRepo.UpdatePassword(userID, hash)
@@ -147,7 +161,7 @@ func (s *AuthService) ResetPassword(token, newPassword string) error {
hash, err := password.Hash(newPassword)
if err != nil {
- return fmt.Errorf("密码加密失败: %w", err)
+ return fmt.Errorf(errPasswordHashFailed, err)
}
return s.userRepo.UpdatePassword(userID, hash)
diff --git a/backend/internal/service/chat.go b/backend/internal/service/chat.go
index 81496ae0..0b3608bb 100644
--- a/backend/internal/service/chat.go
+++ b/backend/internal/service/chat.go
@@ -2,6 +2,9 @@ package service
import (
"context"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
"encoding/json"
"fmt"
"log"
@@ -217,6 +220,7 @@ type ChatService struct {
chatRepo *repository.ChatRepo
userRepo *repository.UserRepo
firecrawl *firecrawl.Client
+ jwtSecret string
// In-memory storage for guest sessions
mu sync.RWMutex
guestMemory map[string][]map[string]interface{}
@@ -224,12 +228,13 @@ type ChatService struct {
guestLastAccess map[string]time.Time
}
-func NewChatService(client *openai.Client, chatRepo *repository.ChatRepo, userRepo *repository.UserRepo, fc *firecrawl.Client) *ChatService {
+func NewChatService(client *openai.Client, chatRepo *repository.ChatRepo, userRepo *repository.UserRepo, fc *firecrawl.Client, jwtSecret string) *ChatService {
return &ChatService{
client: client,
chatRepo: chatRepo,
userRepo: userRepo,
firecrawl: fc,
+ jwtSecret: jwtSecret,
guestMemory: make(map[string][]map[string]interface{}),
guestProfiles: make(map[string]map[string]interface{}),
guestLastAccess: make(map[string]time.Time),
@@ -265,49 +270,62 @@ func (s *ChatService) Chat(ctx context.Context, msg dto.UserMessage, userID stri
}
// appendWebSearchResults appends web search results to the system prompt if available.
-func appendWebSearchResults(systemPrompt string, webResults string) string {
+func appendWebSearchResults(systemPrompt, webResults string) string {
if webResults == "" {
return systemPrompt
}
return systemPrompt + "\n\n## 联网搜索参考\n" + webResults + "\n日常聊天不需要引用来源。仅在提供专业性建议时才引用,引用时直接写出具体来源名称(如「根据XX的一篇文章...」),不要使用[来源1]这样的标注。不确定的信息请标明。"
}
-func (s *ChatService) chatAuthenticated(ctx context.Context, msg dto.UserMessage, userID string) (*dto.VisualResponse, error) {
- // Look up user role and partner info
- role := model.RoleMom
- isAdmin := false
- var partnerID string
- var partnerRole model.UserRole
- if user, err := s.userRepo.FindByID(userID); err == nil {
- role = user.Role
- isAdmin = user.IsAdmin
- if user.PartnerID != nil && *user.PartnerID != "" {
- partnerID = *user.PartnerID
- if partner, err := s.userRepo.FindByID(partnerID); err == nil {
- partnerRole = partner.Role
- }
+// chatUserContext holds resolved user context for authenticated chat.
+type chatUserContext struct {
+ role model.UserRole
+ isAdmin bool
+ partnerID string
+ partnerRole model.UserRole
+}
+
+// resolveChatUserContext looks up the user and partner information for chat.
+func (s *ChatService) resolveChatUserContext(userID string) chatUserContext {
+ ctx := chatUserContext{role: model.RoleMom}
+ user, err := s.userRepo.FindByID(userID)
+ if err != nil {
+ return ctx
+ }
+ ctx.role = user.Role
+ ctx.isAdmin = user.IsAdmin
+ if user.PartnerID != nil && *user.PartnerID != "" {
+ ctx.partnerID = *user.PartnerID
+ if partner, pErr := s.userRepo.FindByID(ctx.partnerID); pErr == nil {
+ ctx.partnerRole = partner.Role
}
}
- pronoun := pronounFor(role)
+ return ctx
+}
+
+func (s *ChatService) chatAuthenticated(ctx context.Context, msg dto.UserMessage, userID string) (*dto.VisualResponse, error) {
+ // Look up user role and partner info
+ uc := s.resolveChatUserContext(userID)
+ pronoun := pronounFor(uc.role)
familyIDs := []string{userID}
- if partnerID != "" {
- familyIDs = append(familyIDs, partnerID)
+ if uc.partnerID != "" {
+ familyIDs = append(familyIDs, uc.partnerID)
}
// Load memory from DB (per-user, not shared)
profile, turns, summary := s.loadUserMemory(userID)
// Load structured facts for prompt (family-scoped)
- factsText, deletedFactsText := s.loadFactsForPrompt(userID, familyIDs, role, partnerRole)
+ factsText, deletedFactsText := s.loadFactsForPrompt(userID, familyIDs, uc.role, uc.partnerRole)
- systemPrompt := fmt.Sprintf(getCompanionPrompt(role, isAdmin),
+ systemPrompt := fmt.Sprintf(getCompanionPrompt(uc.role, uc.isAdmin),
formatProfile(profile, pronoun, factsText),
formatTurns(turns, summary, pronoun),
)
// Update memory section header for family mode
- if partnerID != "" {
+ if uc.partnerID != "" {
for _, old := range []string{"你记得关于她的重要信息", "你记得关于他的重要信息", "你记得关于对方的重要信息"} {
systemPrompt = strings.Replace(systemPrompt, old, "你记得关于这个家庭的重要信息", 1)
}
@@ -394,8 +412,16 @@ func (s *ChatService) chatGuest(ctx context.Context, msg dto.UserMessage) (*dto.
if msg.SessionID != nil {
sessionID = *msg.SessionID
}
+
+ // Validate HMAC signature if session_id is provided
+ if sessionID != "" {
+ if !s.verifySessionID(sessionID) {
+ // Invalid signature; generate a new signed session
+ sessionID = ""
+ }
+ }
if sessionID == "" {
- sessionID = uuid.New().String()
+ sessionID = s.signSessionID(uuid.New().String())
}
s.mu.Lock()
@@ -425,7 +451,7 @@ func (s *ChatService) chatGuest(ctx context.Context, msg dto.UserMessage) (*dto.
{Role: "user", Content: msg.Content},
}
- rawContent, err := s.client.Chat(context.Background(), messages)
+ rawContent, err := s.client.Chat(ctx, messages)
if err != nil {
return nil, fmt.Errorf("AI 服务调用失败: %w", err)
}
@@ -449,6 +475,31 @@ func (s *ChatService) chatGuest(ctx context.Context, msg dto.UserMessage) (*dto.
return buildVisualResponse(parsed, memoryUpdated), nil
}
+// --- Guest session HMAC signing ---
+
+// signSessionID creates a signed session ID in the format "uuid:hmac_hex".
+func (s *ChatService) signSessionID(id string) string {
+ mac := hmac.New(sha256.New, []byte(s.jwtSecret))
+ mac.Write([]byte(id))
+ sig := hex.EncodeToString(mac.Sum(nil))
+ return id + ":" + sig
+}
+
+// verifySessionID validates that a session ID has a valid HMAC signature.
+func (s *ChatService) verifySessionID(sessionID string) bool {
+ idx := strings.LastIndex(sessionID, ":")
+ if idx < 0 || idx == len(sessionID)-1 {
+ return false
+ }
+ id := sessionID[:idx]
+ sig := sessionID[idx+1:]
+
+ mac := hmac.New(sha256.New, []byte(s.jwtSecret))
+ mac.Write([]byte(id))
+ expected := hex.EncodeToString(mac.Sum(nil))
+ return hmac.Equal([]byte(sig), []byte(expected))
+}
+
// --- Phase 2: Conversation Summary ---
func (s *ChatService) generateAndSaveSummary(userID string, existingSummary string, oldTurns []map[string]interface{}) {
@@ -499,6 +550,45 @@ func (s *ChatService) generateAndSaveSummary(userID string, existingSummary stri
// --- Phase 3: Structured Memory Facts ---
+// parseFactItem extracts content and category from a single fact item.
+func parseFactItem(v interface{}) (content, aiCategory string) {
+ switch item := v.(type) {
+ case map[string]interface{}:
+ c, _ := item["content"].(string)
+ content = strings.TrimSpace(c)
+ cat, _ := item["category"].(string)
+ aiCategory = strings.TrimSpace(cat)
+ case string:
+ content = strings.TrimSpace(item)
+ }
+ return content, aiCategory
+}
+
+// saveSingleFact creates a fact if it does not already exist. Returns true if saved.
+func (s *ChatService) saveSingleFact(userID, content, aiCategory string) bool {
+ exists, err := s.chatRepo.FactExistsByContent(userID, content)
+ if err != nil {
+ log.Printf("[ChatService] failed to check fact existence: %v", err)
+ return false
+ }
+ if exists {
+ return false
+ }
+
+ category := resolveFactCategory(aiCategory, content)
+ fact := &model.ChatMemoryFact{
+ UserID: userID,
+ OwnerUserID: userID,
+ Content: content,
+ Category: category,
+ }
+ if err := s.chatRepo.CreateFact(fact); err != nil {
+ log.Printf("[ChatService] failed to save fact for user %s: %v", userID, err)
+ return false
+ }
+ return true
+}
+
func (s *ChatService) saveFactsFromExtract(userID string, extract interface{}) bool {
if extract == nil {
return false
@@ -514,47 +604,11 @@ func (s *ChatService) saveFactsFromExtract(userID string, extract interface{}) b
saved := false
for _, v := range facts {
- var content string
- var aiCategory string
-
- switch item := v.(type) {
- case map[string]interface{}:
- // New structured format: {"content": "...", "category": "..."}
- c, _ := item["content"].(string)
- content = strings.TrimSpace(c)
- cat, _ := item["category"].(string)
- aiCategory = strings.TrimSpace(cat)
- case string:
- // Legacy string format
- content = strings.TrimSpace(item)
- default:
- continue
- }
-
+ content, aiCategory := parseFactItem(v)
if content == "" {
continue
}
-
- // Skip if identical fact already exists for this user (including soft-deleted)
- exists, err := s.chatRepo.FactExistsByContent(userID, content)
- if err != nil {
- log.Printf("[ChatService] failed to check fact existence: %v", err)
- continue
- }
- if exists {
- continue
- }
-
- category := resolveFactCategory(aiCategory, content)
- fact := &model.ChatMemoryFact{
- UserID: userID,
- OwnerUserID: userID,
- Content: content,
- Category: category,
- }
- if err := s.chatRepo.CreateFact(fact); err != nil {
- log.Printf("[ChatService] failed to save fact for user %s: %v", userID, err)
- } else {
+ if s.saveSingleFact(userID, content, aiCategory) {
saved = true
}
}
@@ -626,7 +680,7 @@ func (s *ChatService) processMemoryCorrections(familyIDs []string, extract inter
}
// resolveFactCategory uses the AI-provided category if valid, otherwise falls back to keyword detection.
-func resolveFactCategory(aiCategory string, content string) model.FactCategory {
+func resolveFactCategory(aiCategory, content string) model.FactCategory {
switch model.FactCategory(aiCategory) {
case model.FactCategoryPersonalInfo, model.FactCategoryFamily,
model.FactCategoryInterest, model.FactCategoryConcern,
@@ -677,7 +731,18 @@ func categorizeFactContent(content string) model.FactCategory {
return model.FactCategoryOther
}
-func (s *ChatService) loadFactsForPrompt(userID string, familyIDs []string, userRole model.UserRole, partnerRole model.UserRole) (string, string) {
+// factLabel returns the display label for a fact in family mode.
+func factLabel(f model.ChatMemoryFact, userID string, userRole, partnerRole model.UserRole) string {
+ if f.Category == model.FactCategoryFamily {
+ return "家庭"
+ }
+ if f.OwnerUserID == userID {
+ return pronounFor(userRole)
+ }
+ return pronounFor(partnerRole)
+}
+
+func (s *ChatService) loadFactsForPrompt(userID string, familyIDs []string, userRole, partnerRole model.UserRole) (string, string) {
hasPartner := len(familyIDs) > 1
facts, err := s.chatRepo.FindFactsByFamilyIDs(familyIDs)
@@ -690,14 +755,7 @@ func (s *ChatService) loadFactsForPrompt(userID string, familyIDs []string, user
var sb strings.Builder
for _, f := range facts {
if hasPartner {
- var label string
- if f.Category == model.FactCategoryFamily {
- label = "家庭"
- } else if f.OwnerUserID == userID {
- label = pronounFor(userRole)
- } else {
- label = pronounFor(partnerRole)
- }
+ label := factLabel(f, userID, userRole, partnerRole)
fmt.Fprintf(&sb, " · [%s] %s\n", label, f.Content)
} else {
fmt.Fprintf(&sb, " · %s\n", f.Content)
@@ -822,6 +880,10 @@ func (s *ChatService) searchWebForChat(ctx context.Context, userMessage string)
}
var sb strings.Builder
for _, r := range results {
+ // Validate URL scheme before including in output
+ if !strings.HasPrefix(r.URL, "https://") && !strings.HasPrefix(r.URL, "http://") {
+ continue
+ }
content := r.Markdown
if content == "" {
content = r.Description
diff --git a/backend/internal/service/community.go b/backend/internal/service/community.go
index 4c329e14..c23d9ebb 100644
--- a/backend/internal/service/community.go
+++ b/backend/internal/service/community.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "html"
"time"
"github.com/momshell/backend/internal/dto"
@@ -12,6 +13,11 @@ import (
"gorm.io/gorm"
)
+const (
+ errExpertPostSourceRequired = "专家帖必须标注来源依据"
+ errContentModerationFailed = "内容审核未通过: %s"
+)
+
type CommunityService struct {
questionRepo *repository.QuestionRepo
answerRepo *repository.AnswerRepo
@@ -251,6 +257,10 @@ func (s *CommunityService) GetQuestion(questionID, currentUserID string) (*dto.Q
}
func (s *CommunityService) CreateQuestion(req dto.QuestionCreate, author *model.User) (*model.Question, error) {
+ // Sanitize user content
+ req.Title = sanitizeHTML(req.Title)
+ req.Content = sanitizeHTML(req.Content)
+
// Content moderation
titleDecision := s.moderation.ModerateText(req.Title)
if titleDecision.Result == model.ModerationRejected {
@@ -259,7 +269,7 @@ func (s *CommunityService) CreateQuestion(req dto.QuestionCreate, author *model.
contentDecision := s.moderation.ModerateText(req.Content)
if contentDecision.Result == model.ModerationRejected {
- return nil, fmt.Errorf("内容审核未通过: %s", derefStr(contentDecision.Reason))
+ return nil, fmt.Errorf(errContentModerationFailed, derefStr(contentDecision.Reason))
}
// Determine status
@@ -311,6 +321,8 @@ func (s *CommunityService) moderateAndUpdateTextField(q *model.Question, newText
if newText == nil {
return nil
}
+ sanitized := sanitizeHTML(*newText)
+ newText = &sanitized
decision := s.moderation.ModerateText(*newText)
if decision.Result == model.ModerationRejected {
return fmt.Errorf("%s审核未通过: %s", fieldName, derefStr(decision.Reason))
@@ -444,6 +456,9 @@ func (s *CommunityService) GetAnswers(questionID string, params dto.AnswerListPa
}
func (s *CommunityService) CreateAnswer(questionID string, req dto.AnswerCreate, author *model.User) (*model.Answer, error) {
+ // Sanitize user content
+ req.Content = sanitizeHTML(req.Content)
+
// Check question exists
_, err := s.questionRepo.FindByID(questionID)
if err != nil {
@@ -456,14 +471,14 @@ func (s *CommunityService) CreateAnswer(questionID string, req dto.AnswerCreate,
return nil, errors.New("仅认证专业人士可发布专家帖")
}
if req.Sources == "" {
- return nil, errors.New("专家帖必须标注来源依据")
+ return nil, errors.New(errExpertPostSourceRequired)
}
}
// Content moderation
decision := s.moderation.ModerateText(req.Content)
if decision.Result == model.ModerationRejected {
- return nil, fmt.Errorf("内容审核未通过: %s", derefStr(decision.Reason))
+ return nil, fmt.Errorf(errContentModerationFailed, derefStr(decision.Reason))
}
status := model.StatusPublished
@@ -504,44 +519,45 @@ func (s *CommunityService) CreateAnswer(questionID string, req dto.AnswerCreate,
return answer, nil
}
-func (s *CommunityService) UpdateAnswer(answerID string, req dto.AnswerUpdate, user *model.User) (*model.Answer, error) {
- a, err := s.answerRepo.FindByID(answerID)
- if err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- return nil, errors.New("回答不存在")
- }
- return nil, err
+// updateAnswerContent moderates and applies content update to an answer.
+func (s *CommunityService) updateAnswerContent(a *model.Answer, content *string) error {
+ if content == nil {
+ return nil
}
-
- if a.AuthorID != user.ID && !user.IsAdmin {
- return nil, errors.New("无权修改此回答")
+ sanitized := sanitizeHTML(*content)
+ content = &sanitized
+ decision := s.moderation.ModerateText(*content)
+ if decision.Result == model.ModerationRejected {
+ return fmt.Errorf(errContentModerationFailed, derefStr(decision.Reason))
+ }
+ a.Content = *content
+ if decision.Result == model.ModerationNeedManualReview {
+ a.Status = model.StatusPendingReview
}
+ return nil
+}
- if req.Content != nil {
- decision := s.moderation.ModerateText(*req.Content)
- if decision.Result == model.ModerationRejected {
- return nil, fmt.Errorf("内容审核未通过: %s", derefStr(decision.Reason))
- }
- a.Content = *req.Content
- if decision.Result == model.ModerationNeedManualReview {
- a.Status = model.StatusPendingReview
- }
+// resolveExpertPostSources determines the sources value for an expert post update.
+func resolveExpertPostSources(reqSources *string, existingSources *string) string {
+ if reqSources != nil {
+ return *reqSources
+ }
+ if existingSources != nil {
+ return *existingSources
}
+ return ""
+}
- // Update expert post fields
+// updateExpertPostFields handles the expert post flag and sources fields on answer update.
+func (s *CommunityService) updateExpertPostFields(a *model.Answer, req dto.AnswerUpdate, user *model.User) error {
if req.IsExpertPost != nil {
if *req.IsExpertPost {
if !s.IsCertifiedProfessional(user) && !user.IsAdmin {
- return nil, errors.New("仅认证专业人士可发布专家帖")
- }
- sources := ""
- if req.Sources != nil {
- sources = *req.Sources
- } else if a.Sources != nil {
- sources = *a.Sources
+ return errors.New("仅认证专业人士可发布专家帖")
}
+ sources := resolveExpertPostSources(req.Sources, a.Sources)
if sources == "" {
- return nil, errors.New("专家帖必须标注来源依据")
+ return errors.New(errExpertPostSourceRequired)
}
a.IsExpertPost = true
a.Sources = &sources
@@ -551,10 +567,33 @@ func (s *CommunityService) UpdateAnswer(answerID string, req dto.AnswerUpdate, u
}
if req.Sources != nil && (req.IsExpertPost == nil || !*req.IsExpertPost) {
if a.IsExpertPost && *req.Sources == "" {
- return nil, errors.New("专家帖必须标注来源依据")
+ return errors.New(errExpertPostSourceRequired)
}
a.Sources = req.Sources
}
+ return nil
+}
+
+func (s *CommunityService) UpdateAnswer(answerID string, req dto.AnswerUpdate, user *model.User) (*model.Answer, error) {
+ a, err := s.answerRepo.FindByID(answerID)
+ if err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return nil, errors.New("回答不存在")
+ }
+ return nil, err
+ }
+
+ if a.AuthorID != user.ID && !user.IsAdmin {
+ return nil, errors.New("无权修改此回答")
+ }
+
+ if err := s.updateAnswerContent(a, req.Content); err != nil {
+ return nil, err
+ }
+
+ if err := s.updateExpertPostFields(a, req, user); err != nil {
+ return nil, err
+ }
a.UpdatedAt = time.Now()
if err := s.answerRepo.Update(a); err != nil {
@@ -641,6 +680,9 @@ func (s *CommunityService) GetComments(answerID, currentUserID string) ([]dto.Co
}
func (s *CommunityService) CreateComment(answerID string, req dto.CommentCreate, user *model.User) (*dto.CommentListItem, error) {
+ // Sanitize user content
+ req.Content = sanitizeHTML(req.Content)
+
// Check answer exists
_, err := s.answerRepo.FindByID(answerID)
if err != nil {
@@ -709,6 +751,8 @@ func (s *CommunityService) UpdateComment(commentID string, req dto.CommentUpdate
return nil, errors.New("无权修改此评论")
}
+ req.Content = sanitizeHTML(req.Content)
+
decision := s.moderation.ModerateText(req.Content)
if decision.Result == model.ModerationRejected {
return nil, fmt.Errorf("评论审核未通过: %s", derefStr(decision.Reason))
@@ -782,6 +826,48 @@ func (s *CommunityService) ToggleLike(userID, targetType, targetID string) (bool
return true, count, nil
}
+// AddLike creates a like only if one doesn't already exist. Returns the
+// current state and count without toggling.
+func (s *CommunityService) AddLike(userID, targetType, targetID string) (bool, int, error) {
+ _, err := s.interactionRepo.FindLike(userID, targetType, targetID)
+ if err == nil {
+ // Already liked — return current state (idempotent)
+ count := s.getTargetLikeCount(targetType, targetID)
+ return true, count, nil
+ }
+
+ like := &model.Like{
+ UserID: userID,
+ TargetType: targetType,
+ TargetID: targetID,
+ }
+ if err := s.interactionRepo.CreateLike(like); err != nil {
+ return false, 0, err
+ }
+ s.updateTargetLikeCount(targetType, targetID, 1)
+ count := s.getTargetLikeCount(targetType, targetID)
+ return true, count, nil
+}
+
+// RemoveLike removes a like only if one exists. Returns the current state
+// and count without toggling. Calling DELETE on an already-unliked resource
+// is a no-op (idempotent).
+func (s *CommunityService) RemoveLike(userID, targetType, targetID string) (bool, int, error) {
+ _, err := s.interactionRepo.FindLike(userID, targetType, targetID)
+ if err != nil {
+ // Not liked — return current state (idempotent)
+ count := s.getTargetLikeCount(targetType, targetID)
+ return false, count, nil
+ }
+
+ if err := s.interactionRepo.DeleteLike(userID, targetType, targetID); err != nil {
+ return false, 0, err
+ }
+ s.updateTargetLikeCount(targetType, targetID, -1)
+ count := s.getTargetLikeCount(targetType, targetID)
+ return false, count, nil
+}
+
func (s *CommunityService) updateTargetLikeCount(targetType, targetID string, delta int) {
switch targetType {
case "question":
@@ -964,3 +1050,8 @@ func derefStr(s *string) string {
}
return *s
}
+
+// sanitizeHTML escapes HTML entities in user-provided content to prevent XSS.
+func sanitizeHTML(s string) string {
+ return html.EscapeString(s)
+}
diff --git a/backend/internal/service/echo.go b/backend/internal/service/echo.go
index 58b7b9a5..2f7b0cff 100644
--- a/backend/internal/service/echo.go
+++ b/backend/internal/service/echo.go
@@ -15,9 +15,10 @@ import (
)
const (
- memoirDefaultTitle = "一段温柔的回响"
- memoirKeyTitle = "title"
- memoirKeyContent = "content"
+ memoirDefaultTitle = "一段温柔的回响"
+ memoirKeyTitle = "title"
+ memoirKeyContent = "content"
+ trimWhitespaceChars = " \t\n\r"
)
type EchoService struct {
@@ -319,13 +320,13 @@ func parseMemoirLLMResponse(content string) map[string]interface{} {
func cleanMemoirText(s string) string {
s = strings.TrimSpace(s)
// Remove trailing markdown code fences and JSON braces
- s = strings.TrimRight(s, " \t\n\r")
+ s = strings.TrimRight(s, trimWhitespaceChars)
for {
trimmed := s
trimmed = strings.TrimSuffix(trimmed, "```")
trimmed = strings.TrimSuffix(trimmed, "}")
trimmed = strings.TrimSuffix(trimmed, `"`)
- trimmed = strings.TrimRight(trimmed, " \t\n\r")
+ trimmed = strings.TrimRight(trimmed, trimWhitespaceChars)
if trimmed == s {
break
}
@@ -337,7 +338,7 @@ func cleanMemoirText(s string) string {
trimmed = strings.TrimPrefix(trimmed, "```json")
trimmed = strings.TrimPrefix(trimmed, "```")
trimmed = strings.TrimPrefix(trimmed, "{")
- trimmed = strings.TrimLeft(trimmed, " \t\n\r")
+ trimmed = strings.TrimLeft(trimmed, trimWhitespaceChars)
if trimmed == s {
break
}
diff --git a/backend/internal/service/task.go b/backend/internal/service/task.go
index 5dede137..aae0646b 100644
--- a/backend/internal/service/task.go
+++ b/backend/internal/service/task.go
@@ -334,7 +334,7 @@ func (s *TaskService) createTasksForUser(user *model.User, userID string, date t
}
// SetBabyAge sets the baby age stage for the user and immediately regenerates tasks.
-func (s *TaskService) SetBabyAge(userID string, ageStage string) error {
+func (s *TaskService) SetBabyAge(userID, ageStage string) error {
user, err := s.userRepo.FindByID(userID)
if err != nil {
return errors.New(errUserNotFound)
diff --git a/backend/internal/service/task_ai.go b/backend/internal/service/task_ai.go
index 11bce66c..fcdfdf88 100644
--- a/backend/internal/service/task_ai.go
+++ b/backend/internal/service/task_ai.go
@@ -41,6 +41,24 @@ func coupleKey(a, b string) string {
return b + "-" + a
}
+// resolveAgeStageFromChatMemory tries to infer the age stage from chat memory facts.
+func resolveAgeStageFromChatMemory(chatRepo *repository.ChatRepo, familyIDs []string) (string, bool) {
+ facts, err := chatRepo.FindFactsByFamilyIDs(familyIDs)
+ if err != nil {
+ return "", false
+ }
+ for _, f := range facts {
+ lower := strings.ToLower(f.Content)
+ if strings.Contains(lower, "宝宝") || strings.Contains(lower, "baby") || strings.Contains(lower, "孩子") {
+ stage := inferAgeStageFromMemory(f.Content)
+ if stage != "" {
+ return stage, true
+ }
+ }
+ }
+ return "", false
+}
+
// resolveAgeStage returns the baby age stage for the couple, checking:
// 1. The user's own BabyAgeStage
// 2. The partner's BabyAgeStage
@@ -69,17 +87,8 @@ func resolveAgeStage(
if user.PartnerID != nil {
familyIDs = append(familyIDs, *user.PartnerID)
}
- facts, err := chatRepo.FindFactsByFamilyIDs(familyIDs)
- if err == nil {
- for _, f := range facts {
- lower := strings.ToLower(f.Content)
- if strings.Contains(lower, "宝宝") || strings.Contains(lower, "baby") || strings.Contains(lower, "孩子") {
- stage := inferAgeStageFromMemory(f.Content)
- if stage != "" {
- return stage, "memory"
- }
- }
- }
+ if s, found := resolveAgeStageFromChatMemory(chatRepo, familyIDs); found {
+ return s, "memory"
}
}
diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go
index 6979e1f0..ceb2e561 100644
--- a/backend/internal/service/user.go
+++ b/backend/internal/service/user.go
@@ -11,7 +11,10 @@ import (
"gorm.io/gorm"
)
-const errUserServiceUserNotFound = "用户不存在"
+const (
+ errUserServiceUserNotFound = "用户不存在"
+ whereIDEquals = "id = ?"
+)
type UserService struct {
db *gorm.DB
@@ -86,6 +89,21 @@ func (s *UserService) GetProfile(userID string) (*dto.UserProfile, error) {
return profile, nil
}
+// validateRoleChange checks if a role update is allowed for the user.
+func validateRoleChange(user *model.User, newRoleStr string) error {
+ newRole := model.UserRole(newRoleStr)
+ if !model.FamilyRoles[newRole] {
+ return errors.New("角色只能是: mom, dad")
+ }
+ if model.ProfessionalRoles[user.Role] {
+ return errors.New("认证专业人员不能修改角色")
+ }
+ if user.PartnerID != nil {
+ return errors.New("已绑定伴侣,无法更改身份")
+ }
+ return nil
+}
+
// applyProfileFieldUpdates applies the individual field updates from the request to the user model.
// Returns an error if any validation fails (e.g. duplicate username/email, invalid role).
func (s *UserService) applyProfileFieldUpdates(user *model.User, req dto.UserProfileUpdate) error {
@@ -114,17 +132,10 @@ func (s *UserService) applyProfileFieldUpdates(user *model.User, req dto.UserPro
}
if req.Role != nil {
- newRole := model.UserRole(*req.Role)
- if !model.FamilyRoles[newRole] {
- return errors.New("角色只能是: mom, dad")
- }
- if model.ProfessionalRoles[user.Role] {
- return errors.New("认证专业人员不能修改角色")
- }
- if user.PartnerID != nil {
- return errors.New("已绑定伴侣,无法更改身份")
+ if err := validateRoleChange(user, *req.Role); err != nil {
+ return err
}
- user.Role = newRole
+ user.Role = model.UserRole(*req.Role)
}
return nil
@@ -178,7 +189,7 @@ func (s *UserService) GenerateShellCode(userID string) (*dto.UserProfile, error)
}
// BindPartner binds a 守护者 (dad) to a 溯源者 (mom) via shell code.
-func (s *UserService) BindPartner(userID string, shellCode string) (*dto.UserProfile, error) {
+func (s *UserService) BindPartner(userID, shellCode string) (*dto.UserProfile, error) {
user, err := s.userRepo.FindByID(userID)
if err != nil {
return nil, errors.New(errUserServiceUserNotFound)
@@ -211,10 +222,10 @@ func (s *UserService) BindPartner(userID string, shellCode string) (*dto.UserPro
// Bind both sides in a transaction
err = s.db.Transaction(func(tx *gorm.DB) error {
- if err := tx.Model(&model.User{}).Where("id = ?", user.ID).Update("partner_id", partner.ID).Error; err != nil {
+ if err := tx.Model(&model.User{}).Where(whereIDEquals, user.ID).Update("partner_id", partner.ID).Error; err != nil {
return err
}
- if err := tx.Model(&model.User{}).Where("id = ?", partner.ID).Update("partner_id", user.ID).Error; err != nil {
+ if err := tx.Model(&model.User{}).Where(whereIDEquals, partner.ID).Update("partner_id", user.ID).Error; err != nil {
return err
}
return nil
@@ -241,13 +252,13 @@ func (s *UserService) UnbindPartner(userID string) (*dto.UserProfile, error) {
// Unbind both sides, clear shell code
err = s.db.Transaction(func(tx *gorm.DB) error {
- if err := tx.Model(&model.User{}).Where("id = ?", userID).Updates(map[string]interface{}{
+ if err := tx.Model(&model.User{}).Where(whereIDEquals, userID).Updates(map[string]interface{}{
"partner_id": nil,
"shell_code": nil,
}).Error; err != nil {
return err
}
- if err := tx.Model(&model.User{}).Where("id = ?", partnerID).Updates(map[string]interface{}{
+ if err := tx.Model(&model.User{}).Where(whereIDEquals, partnerID).Updates(map[string]interface{}{
"partner_id": nil,
"shell_code": nil,
}).Error; err != nil {
diff --git a/frontend/index.html b/frontend/index.html
index 90a40b82..c4400867 100755
--- a/frontend/index.html
+++ b/frontend/index.html
@@ -6,6 +6,7 @@
+
diff --git a/frontend/src/App.vue b/frontend/src/App.vue
index 5e7d3317..18362a62 100644
--- a/frontend/src/App.vue
+++ b/frontend/src/App.vue
@@ -15,7 +15,7 @@