Skip to content
Merged

Dev #168

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,9 @@ VITE_API_BASE_URL=http://localhost:8000
ADMIN_USERNAME=
ADMIN_EMAIL=
ADMIN_PASSWORD=

# ==================== AI Service User (optional) ====================
# Internal service account for AI-generated content. If not set,
# a random password is generated on each startup (the AI user never logs in).
AI_USER_USERNAME=xiaoshiguang
AI_USER_PASSWORD=
46 changes: 38 additions & 8 deletions backend/cmd/server/main.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package main

import (
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"os"

"github.com/gin-gonic/gin"
"github.com/momshell/backend/internal/config"
Expand Down Expand Up @@ -74,7 +77,7 @@ func main() {
firecrawlClient = firecrawl.NewClient(cfg.FirecrawlAPIKey)
}

chatService := service.NewChatService(chatClient, chatRepo, userRepo, firecrawlClient)
chatService := service.NewChatService(chatClient, chatRepo, userRepo, firecrawlClient, cfg.JWTSecretKey)
echoService := service.NewEchoService(chatClient, echoRepo, userRepo)
photoService := service.NewPhotoService(photoRepo, userRepo, chatClient, cfg.ImageModel)
whisperService := service.NewWhisperService(whisperRepo, userRepo, chatClient)
Expand Down Expand Up @@ -136,10 +139,21 @@ func main() {
router.Setup(
r, cfg,
adminHandler.IsAdmin,
authHandler, questionHandler, answerHandler,
commentHandler, interactionHandler, tagHandler,
chatHandler, echoHandler, userHandler, adminHandler,
photoHandler, whisperHandler, taskHandler,
&router.Handlers{
Auth: authHandler,
Question: questionHandler,
Answer: answerHandler,
Comment: commentHandler,
Interaction: interactionHandler,
Tag: tagHandler,
Chat: chatHandler,
Echo: echoHandler,
User: userHandler,
Admin: adminHandler,
Photo: photoHandler,
Whisper: whisperHandler,
Task: taskHandler,
},
)

// Start server
Expand Down Expand Up @@ -186,19 +200,35 @@ func createInitialAdmin(cfg *config.Config, userRepo *repository.UserRepo) {
}

func ensureAIUser(userRepo *repository.UserRepo) string {
user, err := userRepo.FindByUsernameOrEmail("xiaoshiguang")
aiUsername := os.Getenv("AI_USER_USERNAME")
if aiUsername == "" {
aiUsername = "xiaoshiguang"
log.Println("[WARN] AI_USER_USERNAME not set, using default: xiaoshiguang")
}
aiPasswd := os.Getenv("AI_USER_PASSWORD")
if aiPasswd == "" {
// Generate a random password; AI user never logs in interactively.
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
log.Fatalf("failed to generate random AI user password: %v", err)
}
aiPasswd = hex.EncodeToString(b)
log.Println("[WARN] AI_USER_PASSWORD not set, using random password")
}

user, err := userRepo.FindByUsernameOrEmail(aiUsername)
if err == nil {
return user.ID
}

hash, err := password.Hash("ai-user-no-login")
hash, err := password.Hash(aiPasswd)
if err != nil {
log.Printf("Failed to hash AI user password: %v", err)
return ""
}

aiUser := &model.User{
Username: "xiaoshiguang",
Username: aiUsername,
Email: "ai@momshell.com",
PasswordHash: hash,
Nickname: "小石光",
Expand Down
4 changes: 4 additions & 0 deletions backend/internal/handler/admin.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handler

import (
"log"
"net/http"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -142,6 +143,9 @@ func (h *AdminHandler) UpdateConfig(c *gin.Context) {
return
}

adminID := middleware.GetUserID(c)
log.Printf("[SECURITY] admin_config_change | ip=%s | user=%s | detail=config update requested", c.ClientIP(), adminID)

if err := h.adminService.UpdateConfig(req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
Expand Down
7 changes: 4 additions & 3 deletions backend/internal/handler/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}

resp, err := h.authService.Login(req)
resp, err := h.authService.Login(req, c.ClientIP())
if err != nil {
status := http.StatusUnauthorized
if err.Error() == "账号已禁用" || err.Error() == "账号已被封禁" {
Expand Down Expand Up @@ -184,8 +184,9 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) {
return
}

// TODO: In production, send email with reset link instead of logging
_ = token // token would be sent via email
// In a production deployment, the token would be sent via email with a reset link.
// For now, it is intentionally unused to avoid leaking it in the response.
_ = token
c.JSON(http.StatusOK, gin.H{
"message": "如果该邮箱已注册,将收到重置密码邮件",
})
Expand Down
4 changes: 2 additions & 2 deletions backend/internal/handler/interaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func (h *InteractionHandler) CreateLike(c *gin.Context) {
}

userID := middleware.GetUserID(c)
isLiked, newCount, err := h.communityService.ToggleLike(userID, req.TargetType, req.TargetID)
isLiked, newCount, err := h.communityService.AddLike(userID, req.TargetType, req.TargetID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
Expand All @@ -44,7 +44,7 @@ func (h *InteractionHandler) DeleteLike(c *gin.Context) {
}

userID := middleware.GetUserID(c)
isLiked, newCount, err := h.communityService.ToggleLike(userID, req.TargetType, req.TargetID)
isLiked, newCount, err := h.communityService.RemoveLike(userID, req.TargetType, req.TargetID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
Expand Down
4 changes: 4 additions & 0 deletions backend/internal/middleware/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"log"
"net/http"
"strings"

Expand Down Expand Up @@ -82,15 +83,18 @@ func extractUserID(c *gin.Context, cfg *config.Config) (string, error) {

// Check token blacklist (revoked tokens)
if TokenBlacklist.IsBlacklisted(tokenStr) {
log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=unknown | detail=token is blacklisted", c.ClientIP())
return "", pkgjwt.ErrInvalidToken
}

claims, err := pkgjwt.ParseToken(tokenStr, cfg.JWTSecretKey)
if err != nil {
log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=unknown | detail=%v", c.ClientIP(), err)
return "", err
}

if claims.Type != "access" {
log.Printf("[SECURITY] jwt_validation_failed | ip=%s | user=%s | detail=invalid token type: %s", c.ClientIP(), claims.Subject, claims.Type)
return "", pkgjwt.ErrInvalidToken
}

Expand Down
13 changes: 7 additions & 6 deletions backend/internal/middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,24 @@ import (

func CORS(cfg *config.Config) gin.HandlerFunc {
corsConfig := cors.Config{
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Access-Token"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
MaxAge: 12 * time.Hour,
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Access-Token"},
ExposeHeaders: []string{"Content-Length"},
MaxAge: 12 * time.Hour,
}

trimmed := strings.TrimSpace(cfg.CORSOrigins)
if trimmed == "" || trimmed == "*" {
// Allow all origins (compatible with AllowCredentials)
// Allow all origins but disable credentials to prevent CSRF-like attacks
corsConfig.AllowOriginFunc = func(origin string) bool { return true }
corsConfig.AllowCredentials = false
} else {
origins := strings.Split(trimmed, ",")
for i := range origins {
origins[i] = strings.TrimSpace(origins[i])
}
corsConfig.AllowOrigins = origins
corsConfig.AllowCredentials = true
}

return cors.New(corsConfig)
Expand Down
14 changes: 13 additions & 1 deletion backend/internal/middleware/security.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package middleware

import "github.com/gin-gonic/gin"
import (
"strings"

"github.com/gin-gonic/gin"
)

// SecurityHeaders adds standard security headers to all responses.
func SecurityHeaders() gin.HandlerFunc {
Expand All @@ -10,6 +14,14 @@ func SecurityHeaders() gin.HandlerFunc {
c.Header("X-XSS-Protection", "1; mode=block")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("Permissions-Policy", "camera=(self), microphone=(), geolocation=()")
c.Header("Strict-Transport-Security", "max-age=63072000; includeSubDomains")

// Add cache-control headers for API responses
if strings.HasPrefix(c.Request.URL.Path, "/api/") {
c.Header("Cache-Control", "no-store")
c.Header("Pragma", "no-cache")
}

c.Next()
}
}
20 changes: 17 additions & 3 deletions backend/internal/middleware/tokenblacklist.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"log"
"sync"
"time"
)
Expand All @@ -22,17 +23,30 @@ func (b *tokenBlacklistStore) Add(token string, expiry time.Time) {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.tokens) >= maxBlacklistSize {
// Emergency cleanup: remove expired tokens immediately
// Emergency cleanup: remove expired tokens
now := time.Now()
for t, exp := range b.tokens {
if now.After(exp) {
delete(b.tokens, t)
}
}
}
// If still at capacity after cleanup, skip adding (token will simply not be blacklisted)
// If still at capacity after cleanup, evict the oldest entry (FIFO)
if len(b.tokens) >= maxBlacklistSize {
return
log.Printf("[SECURITY] token_blacklist_full | size=%d | evicting oldest entry to make room", len(b.tokens))
var oldestToken string
var oldestExpiry time.Time
first := true
for t, exp := range b.tokens {
if first || exp.Before(oldestExpiry) {
oldestToken = t
oldestExpiry = exp
first = false
}
}
if oldestToken != "" {
delete(b.tokens, oldestToken)
}
}
b.tokens[token] = expiry
}
Expand Down
4 changes: 2 additions & 2 deletions backend/internal/repository/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (r *ChatRepo) Create(m *model.ChatMemory) error {
}

// UpdateSummaryAndTurns updates only the summary and turns fields.
func (r *ChatRepo) UpdateSummaryAndTurns(userID string, summary string, turns string) error {
func (r *ChatRepo) UpdateSummaryAndTurns(userID, summary, turns string) error {
return r.db.Model(&model.ChatMemory{}).
Where(whereUserID, userID).
Updates(map[string]any{
Expand Down Expand Up @@ -139,7 +139,7 @@ func (r *ChatRepo) FactExistsByContentFamily(familyIDs []string, content string)
return count > 0, err
}

func (r *ChatRepo) DeleteFactsByContentLikeFamily(familyIDs []string, phrases []string) error {
func (r *ChatRepo) DeleteFactsByContentLikeFamily(familyIDs, phrases []string) error {
for _, phrase := range phrases {
phrase = strings.TrimSpace(phrase)
if phrase == "" {
Expand Down
Loading
Loading