diff --git a/go.mod b/go.mod index 81afa7a..6547d5d 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 github.com/robfig/cron/v3 v3.0.1 + golang.org/x/crypto v0.32.0 golang.org/x/oauth2 v0.34.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -60,7 +61,6 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/crypto v0.32.0 // indirect golang.org/x/net v0.34.0 // indirect golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/internal/api/apikeys.go b/internal/api/apikeys.go new file mode 100644 index 0000000..ed28d10 --- /dev/null +++ b/internal/api/apikeys.go @@ -0,0 +1,198 @@ +package api + +import ( + "net/http" + "strconv" + "time" + + "github.com/flatrun/agent/internal/auth" + "github.com/gin-gonic/gin" +) + +func (s *Server) getAPIKeyWithAuth(c *gin.Context) (*auth.APIKey, bool) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid API key ID"}) + return nil, false + } + + key, err := s.authManager.GetAPIKey(id) + if err == auth.ErrAPIKeyNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "API key not found"}) + return nil, false + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get API key"}) + return nil, false + } + + actor := auth.GetActorFromContext(c) + if actor.Role != auth.RoleAdmin && (actor.User == nil || actor.User.ID != key.UserID) { + c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"}) + return nil, false + } + + return key, true +} + +func (s *Server) listAPIKeys(c *gin.Context) { + actor := auth.GetActorFromContext(c) + if actor == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"}) + return + } + + var keys []auth.APIKey + var err error + + if actor.Role == auth.RoleAdmin { + keys, err = s.authManager.GetAllAPIKeys() + } else if actor.User != nil { + keys, err = s.authManager.GetAPIKeysByUser(actor.User.ID) + } else { + c.JSON(http.StatusForbidden, gin.H{"error": "Cannot list API keys"}) + return + } + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list API keys"}) + return + } + + response := make([]gin.H, 0, len(keys)) + for _, k := range keys { + response = append(response, apiKeyToResponse(&k)) + } + + c.JSON(http.StatusOK, gin.H{"api_keys": response}) +} + +func (s *Server) getAPIKey(c *gin.Context) { + key, ok := s.getAPIKeyWithAuth(c) + if !ok { + return + } + + c.JSON(http.StatusOK, gin.H{"api_key": apiKeyToResponse(key)}) +} + +func (s *Server) createAPIKey(c *gin.Context) { + actor := auth.GetActorFromContext(c) + if actor == nil || actor.User == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"}) + return + } + + var req struct { + Name string `json:"name" binding:"required"` + Description string `json:"description"` + Role auth.Role `json:"role"` + Permissions []string `json:"permissions"` + Deployments []string `json:"deployments"` + ExpiresIn int `json:"expires_in"` + UserID int64 `json:"user_id"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + userID := actor.User.ID + if req.UserID > 0 && actor.Role == auth.RoleAdmin { + userID = req.UserID + } + + if req.Role != "" && !req.Role.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid role"}) + return + } + + if actor.Role != auth.RoleAdmin { + if req.Role == auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Cannot create admin API key"}) + return + } + for _, p := range req.Permissions { + if !actor.HasPermission(auth.Permission(p)) { + c.JSON(http.StatusForbidden, gin.H{"error": "Cannot grant permission you don't have: " + p}) + return + } + } + } + + var expiresAt time.Time + if req.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(req.ExpiresIn) * time.Second) + } + + key, plainKey, err := s.authManager.CreateAPIKey( + userID, + req.Name, + req.Description, + req.Role, + req.Permissions, + req.Deployments, + expiresAt, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create API key"}) + return + } + + response := apiKeyToResponse(key) + response["key"] = plainKey + + c.JSON(http.StatusCreated, gin.H{ + "api_key": response, + "message": "Save this key securely. It will not be shown again.", + }) +} + +func (s *Server) deleteAPIKey(c *gin.Context) { + key, ok := s.getAPIKeyWithAuth(c) + if !ok { + return + } + + if err := s.authManager.DeleteAPIKey(key.ID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete API key"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "API key deleted"}) +} + +func (s *Server) revokeAPIKey(c *gin.Context) { + key, ok := s.getAPIKeyWithAuth(c) + if !ok { + return + } + + if err := s.authManager.DeactivateAPIKey(key.ID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to revoke API key"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "API key revoked"}) +} + +func apiKeyToResponse(k *auth.APIKey) gin.H { + return gin.H{ + "id": k.ID, + "key_id": k.KeyID, + "user_id": k.UserID, + "name": k.Name, + "description": k.Description, + "key_prefix": k.KeyPrefix, + "role": k.Role, + "permissions": k.Permissions, + "deployments": k.Deployments, + "expires_at": k.ExpiresAt, + "last_used_at": k.LastUsedAt, + "last_used_ip": k.LastUsedIP, + "is_active": k.IsActive, + "created_at": k.CreatedAt, + } +} diff --git a/internal/api/apikeys_test.go b/internal/api/apikeys_test.go new file mode 100644 index 0000000..55456ac --- /dev/null +++ b/internal/api/apikeys_test.go @@ -0,0 +1,423 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/pkg/config" + "github.com/gin-gonic/gin" +) + +func setupAPIKeyTestServer(t *testing.T) (*Server, *gin.Engine, func()) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "apikey_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + cfg := &config.Config{ + DeploymentsPath: tmpDir, + Auth: config.AuthConfig{ + Enabled: true, + JWTSecret: "test-jwt-secret-key-for-testing", + APIKeys: []string{"legacy-test-key"}, + }, + } + + os.Setenv("FLATRUN_ADMIN_PASSWORD", "testadminpass") + + authManager, err := auth.NewManager(tmpDir, &cfg.Auth) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create auth manager: %v", err) + } + + server := &Server{ + config: cfg, + authManager: authManager, + } + + router := gin.New() + authMiddleware := auth.NewMiddlewareWithManager(&cfg.Auth, authManager) + + api := router.Group("/api") + api.POST("/auth/login", authMiddleware.Login) + + protected := api.Group("") + protected.Use(authMiddleware.RequireAuth()) + { + apikeys := protected.Group("/apikeys") + apikeys.Use(authMiddleware.RequirePermission(auth.PermAPIKeysRead)) + { + apikeys.GET("", server.listAPIKeys) + apikeys.GET("/:id", server.getAPIKey) + } + + apikeysWrite := protected.Group("/apikeys") + apikeysWrite.Use(authMiddleware.RequirePermission(auth.PermAPIKeysWrite)) + { + apikeysWrite.POST("", server.createAPIKey) + } + + apikeysDelete := protected.Group("/apikeys") + apikeysDelete.Use(authMiddleware.RequirePermission(auth.PermAPIKeysDelete)) + { + apikeysDelete.DELETE("/:id", server.deleteAPIKey) + apikeysDelete.POST("/:id/revoke", server.revokeAPIKey) + } + } + + cleanup := func() { + authManager.Close() + os.RemoveAll(tmpDir) + os.Unsetenv("FLATRUN_ADMIN_PASSWORD") + } + + return server, router, cleanup +} + +func apiKeyLogin(t *testing.T, router *gin.Engine, username, password string) string { + body := map[string]string{ + "username": username, + "password": password, + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Login failed: %d - %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + return resp["token"].(string) +} + +func TestCreateAPIKey(t *testing.T) { + _, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + token := apiKeyLogin(t, router, "admin", "testadminpass") + + body := map[string]interface{}{ + "name": "Test API Key", + "description": "For testing purposes", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/apikeys", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected status 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + apiKey, ok := resp["api_key"].(map[string]interface{}) + if !ok { + t.Fatal("Expected api_key object in response") + } + + if apiKey["name"] != "Test API Key" { + t.Errorf("Expected name 'Test API Key', got %v", apiKey["name"]) + } + + apiKeyResp := resp["api_key"].(map[string]interface{}) + plainKey, ok := apiKeyResp["key"].(string) + if !ok || plainKey == "" { + t.Error("Expected key to be returned on creation") + } + + if len(plainKey) < 10 { + t.Error("Plain key should be a significant length") + } +} + +func TestCreateAPIKeyWithExpiration(t *testing.T) { + _, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + token := apiKeyLogin(t, router, "admin", "testadminpass") + + body := map[string]interface{}{ + "name": "Expiring Key", + "expires_in": 86400, + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/apikeys", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected status 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + apiKey, ok := resp["api_key"].(map[string]interface{}) + if !ok { + t.Fatal("Response missing api_key") + } + if apiKey["expires_at"] == nil || apiKey["expires_at"] == "" { + t.Error("Expected expires_at to be set") + } +} + +func TestCreateAPIKeyWithRoleOverride(t *testing.T) { + _, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + token := apiKeyLogin(t, router, "admin", "testadminpass") + + body := map[string]interface{}{ + "name": "Viewer Key", + "role": "viewer", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/apikeys", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected status 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + apiKey := resp["api_key"].(map[string]interface{}) + if apiKey["role"] != "viewer" { + t.Errorf("Expected role 'viewer', got %v", apiKey["role"]) + } +} + +func TestListAPIKeys(t *testing.T) { + server, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + admin, _ := server.authManager.GetUserByUsername("admin") + _, _, _ = server.authManager.CreateAPIKey(admin.ID, "Key 1", "", "", nil, nil, time.Time{}) + _, _, _ = server.authManager.CreateAPIKey(admin.ID, "Key 2", "", "", nil, nil, time.Time{}) + + token := apiKeyLogin(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodGet, "/api/apikeys", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + keys, ok := resp["api_keys"].([]interface{}) + if !ok { + t.Fatal("Expected api_keys array in response") + } + + if len(keys) < 2 { + t.Errorf("Expected at least 2 keys, got %d", len(keys)) + } +} + +func TestGetAPIKey(t *testing.T) { + server, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + admin, _ := server.authManager.GetUserByUsername("admin") + key, _, _ := server.authManager.CreateAPIKey(admin.ID, "Specific Key", "Description", "", nil, nil, time.Time{}) + + token := apiKeyLogin(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/apikeys/%d", key.ID), nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + apiKey := resp["api_key"].(map[string]interface{}) + if apiKey["name"] != "Specific Key" { + t.Errorf("Expected name 'Specific Key', got %v", apiKey["name"]) + } +} + +func TestDeleteAPIKey(t *testing.T) { + server, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + admin, _ := server.authManager.GetUserByUsername("admin") + key, _, _ := server.authManager.CreateAPIKey(admin.ID, "To Delete", "", "", nil, nil, time.Time{}) + + token := apiKeyLogin(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/api/apikeys/%d", key.ID), nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + _, err := server.authManager.GetAPIKey(key.ID) + if err != auth.ErrAPIKeyNotFound { + t.Error("API key should be deleted") + } +} + +func TestRevokeAPIKey(t *testing.T) { + server, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + admin, _ := server.authManager.GetUserByUsername("admin") + key, plainKey, _ := server.authManager.CreateAPIKey(admin.ID, "To Revoke", "", "", nil, nil, time.Time{}) + + _, _, err := server.authManager.ValidateAPIKey(plainKey) + if err != nil { + t.Fatalf("Key should be valid before revoke: %v", err) + } + + token := apiKeyLogin(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/apikeys/%d/revoke", key.ID), nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + _, _, err = server.authManager.ValidateAPIKey(plainKey) + if err != auth.ErrAPIKeyInactive { + t.Errorf("Expected ErrAPIKeyInactive after revoke, got: %v", err) + } +} + +func TestOperatorCanAccessOwnAPIKeys(t *testing.T) { + server, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + operator, _ := server.authManager.CreateUser("operator", "", "operatorpass", auth.RoleOperator, nil) + + _, _, _ = server.authManager.CreateAPIKey(operator.ID, "Operator's Key", "", "", nil, nil, time.Time{}) + + token := apiKeyLogin(t, router, "operator", "operatorpass") + + req := httptest.NewRequest(http.MethodGet, "/api/apikeys", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + keys := resp["api_keys"].([]interface{}) + if len(keys) != 1 { + t.Errorf("Operator should see their own 1 key, got %d", len(keys)) + } +} + +func TestViewerCannotCreateAPIKey(t *testing.T) { + server, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + _, _ = server.authManager.CreateUser("viewer", "", "viewerpass", auth.RoleViewer, nil) + + token := apiKeyLogin(t, router, "viewer", "viewerpass") + + body := map[string]string{"name": "Viewer Key"} + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/apikeys", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected status 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestCreateAPIKeyWithDeploymentScope(t *testing.T) { + _, router, cleanup := setupAPIKeyTestServer(t) + defer cleanup() + + token := apiKeyLogin(t, router, "admin", "testadminpass") + + body := map[string]interface{}{ + "name": "Scoped Key", + "deployments": []string{"app-a", "app-b"}, + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/apikeys", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected status 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + apiKey := resp["api_key"].(map[string]interface{}) + deployments := apiKey["deployments"].([]interface{}) + if len(deployments) != 2 { + t.Errorf("Expected 2 deployments, got %d", len(deployments)) + } +} diff --git a/internal/api/server.go b/internal/api/server.go index c40faaa..78d945a 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -53,6 +53,7 @@ type Server struct { networksManager *networks.Manager pluginRegistry *plugins.Registry authMiddleware *auth.Middleware + authManager *auth.Manager proxyOrchestrator *proxy.Orchestrator filesManager *files.Manager servicesManager *system.ServicesManager @@ -88,6 +89,15 @@ func New(cfg *config.Config, configPath string) *Server { pluginRegistry := plugins.NewRegistry(pluginsDir) _ = pluginRegistry.LoadFromDisk() authMiddleware := auth.NewMiddleware(&cfg.Auth) + + var authManager *auth.Manager + if authMgr, authErr := auth.NewManager(cfg.DeploymentsPath, &cfg.Auth); authErr != nil { + log.Printf("Warning: Failed to initialize auth manager: %v", authErr) + } else { + authManager = authMgr + authMiddleware.SetManager(authManager) + } + proxyOrchestrator := proxy.NewOrchestrator(cfg) filesManager := files.NewManager(cfg.DeploymentsPath) servicesManager := system.NewServicesManager() @@ -168,6 +178,7 @@ func New(cfg *config.Config, configPath string) *Server { networksManager: networksManager, pluginRegistry: pluginRegistry, authMiddleware: authMiddleware, + authManager: authManager, proxyOrchestrator: proxyOrchestrator, filesManager: filesManager, servicesManager: servicesManager, @@ -215,187 +226,253 @@ func (s *Server) setupRoutes() { protected.Use(s.auditMiddleware.Capture()) } { - protected.GET("/deployments", s.listDeployments) - protected.GET("/deployments/:name", s.getDeployment) - protected.POST("/deployments", s.createDeployment) - protected.PUT("/deployments/:name", s.updateDeployment) - protected.PUT("/deployments/:name/metadata", s.updateDeploymentMetadata) - protected.DELETE("/deployments/:name", s.deleteDeployment) - protected.POST("/deployments/:name/start", s.startDeployment) - protected.POST("/deployments/:name/stop", s.stopDeployment) - protected.POST("/deployments/:name/restart", s.restartDeployment) - protected.POST("/deployments/:name/rebuild", s.rebuildDeployment) - protected.POST("/deployments/:name/pull", s.pullDeploymentImage) - protected.GET("/deployments/:name/images", s.getDeploymentImages) - protected.POST("/deployments/:name/actions/:actionId", s.executeQuickAction) - protected.GET("/deployments/:name/logs", s.getDeploymentLogs) - protected.GET("/deployments/:name/compose", s.getDeploymentCompose) - protected.GET("/networks", s.listNetworks) - protected.POST("/networks", s.createNetwork) - protected.DELETE("/networks/:name", s.deleteNetwork) - protected.POST("/networks/:name/connect", s.connectContainer) - protected.POST("/networks/:name/disconnect", s.disconnectContainer) - protected.GET("/certificates", s.listCertificates) - protected.POST("/certificates", s.requestCertificate) - protected.POST("/certificates/renew", s.renewCertificates) - protected.DELETE("/certificates/:domain", s.deleteCertificate) - - protected.GET("/proxy/status/:name", s.getProxyStatus) - protected.POST("/proxy/setup/:name", s.setupProxy) - protected.DELETE("/proxy/:name", s.teardownProxy) - protected.GET("/proxy/vhosts", s.listVirtualHosts) - protected.POST("/proxy/sync", s.syncAllProxies) - protected.POST("/deployments/:name/ssl/disable", s.disableSSL) - - protected.GET("/settings", s.getSettings) - protected.PUT("/settings", s.updateSettings) - protected.PUT("/settings/security", s.updateSecuritySettings) - protected.GET("/subdomain/generate", s.generateSubdomain) - protected.GET("/plugins", s.listPlugins) - protected.GET("/plugins/:name", s.getPlugin) - protected.POST("/plugins/:name/deployments", s.createPluginDeployment) - protected.GET("/templates", s.listTemplates) - protected.GET("/templates/categories", s.getTemplateCategories) - protected.POST("/templates/refresh", s.refreshTemplates) - protected.GET("/templates/:id/compose", s.getTemplateCompose) - protected.POST("/templates/:id/generate", s.generateTemplateCompose) - protected.POST("/compose/update", s.updateCompose) - protected.GET("/stats", s.getSystemStats) - protected.GET("/containers", s.listContainers) - protected.POST("/containers/:id/start", s.startContainer) - protected.POST("/containers/:id/stop", s.stopContainer) - protected.POST("/containers/:id/restart", s.restartContainer) - protected.DELETE("/containers/:id", s.removeContainer) - protected.GET("/containers/:id/logs", s.getContainerLogs) - protected.GET("/containers/:id/stats", s.getContainerStats) - protected.GET("/containers/stats", s.getAllContainerStats) - protected.POST("/containers/:id/exec", s.containerExecHTTP) - protected.GET("/deployments/:name/stats", s.getDeploymentContainerStats) - protected.GET("/images", s.listImages) - protected.DELETE("/images/:id", s.removeImage) - protected.POST("/images/pull", s.pullImage) - protected.GET("/volumes", s.listVolumes) - protected.POST("/volumes", s.createVolume) - protected.DELETE("/volumes/:name", s.removeVolume) - protected.POST("/volumes/prune", s.pruneVolumes) - protected.GET("/ports", s.listPorts) - protected.POST("/ports/:pid/kill", s.killProcess) - - protected.GET("/system/services", s.listSystemServices) - protected.POST("/system/services/:name/start", s.startSystemService) - protected.POST("/system/services/:name/stop", s.stopSystemService) - protected.POST("/system/services/:name/restart", s.restartSystemService) - - protected.GET("/deployments/:name/files", s.listDeploymentFiles) - protected.GET("/deployments/:name/files/*path", s.getDeploymentFile) - protected.POST("/deployments/:name/files/*path", s.uploadDeploymentFile) - protected.DELETE("/deployments/:name/files/*path", s.deleteDeploymentFile) - protected.POST("/deployments/:name/mkdir/*path", s.createDeploymentDir) - protected.GET("/deployments/:name/files-info", s.getDeploymentFilesInfo) - - protected.GET("/deployments/:name/env", s.getDeploymentEnv) - protected.PUT("/deployments/:name/env", s.updateDeploymentEnv) - - protected.POST("/databases/test", s.testDatabaseConnection) - protected.POST("/databases/list", s.listDatabasesInServer) - protected.POST("/databases/tables", s.listDatabaseTables) - protected.POST("/databases/tables/data", s.queryTableData) - protected.POST("/databases/tables/schema", s.describeTable) - protected.POST("/databases/query", s.executeDatabaseQuery) - protected.POST("/databases/users", s.listDatabaseUsers) - protected.POST("/databases/users/by-database", s.listUsersByDatabase) - protected.POST("/databases/create", s.createDatabaseInServer) - protected.POST("/databases/delete", s.deleteDatabaseInServer) - protected.POST("/databases/users/create", s.createDatabaseUser) - protected.POST("/databases/users/delete", s.deleteDatabaseUser) - protected.POST("/databases/privileges/grant", s.grantDatabasePrivileges) - - protected.GET("/infrastructure", s.listInfrastructure) - protected.GET("/infrastructure/stats", s.getInfraStats) - protected.GET("/infrastructure/:name", s.getInfraService) - protected.POST("/infrastructure/:name/start", s.startInfraService) - protected.POST("/infrastructure/:name/stop", s.stopInfraService) - protected.POST("/infrastructure/:name/restart", s.restartInfraService) - protected.GET("/infrastructure/:name/logs", s.getInfraServiceLogs) - protected.POST("/infrastructure/migrate/:name", s.migrateToInfrastructure) - - protected.GET("/registries", s.listRegistryTypes) - protected.GET("/registries/:slug", s.getRegistryType) - protected.POST("/registries", s.createRegistryType) - protected.PUT("/registries/:slug", s.updateRegistryType) - protected.DELETE("/registries/:slug", s.deleteRegistryType) - - protected.GET("/credentials", s.listCredentials) - protected.GET("/credentials/:id", s.getCredential) - protected.POST("/credentials", s.createCredential) - protected.PUT("/credentials/:id", s.updateCredential) - protected.DELETE("/credentials/:id", s.deleteCredential) - protected.POST("/credentials/:id/test", s.testCredential) + // Deployment endpoints + protected.GET("/deployments", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.listDeployments) + protected.GET("/deployments/:name", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeployment) + protected.POST("/deployments", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.createDeployment) + protected.PUT("/deployments/:name", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.updateDeployment) + protected.PUT("/deployments/:name/metadata", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.updateDeploymentMetadata) + protected.DELETE("/deployments/:name", s.authMiddleware.RequirePermission(auth.PermDeploymentsDelete), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelAdmin), s.deleteDeployment) + protected.POST("/deployments/:name/start", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.startDeployment) + protected.POST("/deployments/:name/stop", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.stopDeployment) + protected.POST("/deployments/:name/restart", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.restartDeployment) + protected.POST("/deployments/:name/rebuild", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.rebuildDeployment) + protected.POST("/deployments/:name/pull", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.pullDeploymentImage) + protected.GET("/deployments/:name/images", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentImages) + protected.POST("/deployments/:name/actions/:actionId", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.executeQuickAction) + protected.GET("/deployments/:name/logs", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentLogs) + protected.GET("/deployments/:name/compose", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentCompose) + + // Network endpoints + protected.GET("/networks", s.authMiddleware.RequirePermission(auth.PermNetworksRead), s.listNetworks) + protected.POST("/networks", s.authMiddleware.RequirePermission(auth.PermNetworksWrite), s.createNetwork) + protected.DELETE("/networks/:name", s.authMiddleware.RequirePermission(auth.PermNetworksDelete), s.deleteNetwork) + protected.POST("/networks/:name/connect", s.authMiddleware.RequirePermission(auth.PermNetworksWrite), s.connectContainer) + protected.POST("/networks/:name/disconnect", s.authMiddleware.RequirePermission(auth.PermNetworksWrite), s.disconnectContainer) + + // Certificate endpoints + protected.GET("/certificates", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.listCertificates) + protected.POST("/certificates", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.requestCertificate) + protected.POST("/certificates/renew", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.renewCertificates) + protected.DELETE("/certificates/:domain", s.authMiddleware.RequirePermission(auth.PermCertificatesDelete), s.deleteCertificate) + + // Proxy endpoints + protected.GET("/proxy/status/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.getProxyStatus) + protected.POST("/proxy/setup/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.setupProxy) + protected.DELETE("/proxy/:name", s.authMiddleware.RequirePermission(auth.PermCertificatesDelete), s.teardownProxy) + protected.GET("/proxy/vhosts", s.authMiddleware.RequirePermission(auth.PermCertificatesRead), s.listVirtualHosts) + protected.POST("/proxy/sync", s.authMiddleware.RequirePermission(auth.PermCertificatesWrite), s.syncAllProxies) + protected.POST("/deployments/:name/ssl/disable", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.disableSSL) + + // Settings endpoints + protected.GET("/settings", s.authMiddleware.RequirePermission(auth.PermSettingsRead), s.getSettings) + protected.PUT("/settings", s.authMiddleware.RequirePermission(auth.PermSettingsWrite), s.updateSettings) + protected.PUT("/settings/security", s.authMiddleware.RequirePermission(auth.PermSettingsWrite), s.updateSecuritySettings) + + // Compose, stats, subdomain (deployment-scoped) + protected.GET("/subdomain/generate", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.generateSubdomain) + protected.POST("/compose/update", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.updateCompose) + protected.GET("/stats", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.getSystemStats) + + // Template and plugin endpoints + protected.GET("/plugins", s.authMiddleware.RequirePermission(auth.PermTemplatesRead), s.listPlugins) + protected.GET("/plugins/:name", s.authMiddleware.RequirePermission(auth.PermTemplatesRead), s.getPlugin) + protected.POST("/plugins/:name/deployments", s.authMiddleware.RequirePermission(auth.PermTemplatesWrite), s.createPluginDeployment) + protected.GET("/templates", s.authMiddleware.RequirePermission(auth.PermTemplatesRead), s.listTemplates) + protected.GET("/templates/categories", s.authMiddleware.RequirePermission(auth.PermTemplatesRead), s.getTemplateCategories) + protected.POST("/templates/refresh", s.authMiddleware.RequirePermission(auth.PermTemplatesWrite), s.refreshTemplates) + protected.GET("/templates/:id/compose", s.authMiddleware.RequirePermission(auth.PermTemplatesRead), s.getTemplateCompose) + protected.POST("/templates/:id/generate", s.authMiddleware.RequirePermission(auth.PermTemplatesWrite), s.generateTemplateCompose) + + // Container endpoints + protected.GET("/containers", s.authMiddleware.RequirePermission(auth.PermContainersRead), s.listContainers) + protected.POST("/containers/:id/start", s.authMiddleware.RequirePermission(auth.PermContainersWrite), s.startContainer) + protected.POST("/containers/:id/stop", s.authMiddleware.RequirePermission(auth.PermContainersWrite), s.stopContainer) + protected.POST("/containers/:id/restart", s.authMiddleware.RequirePermission(auth.PermContainersWrite), s.restartContainer) + protected.DELETE("/containers/:id", s.authMiddleware.RequirePermission(auth.PermContainersDelete), s.removeContainer) + protected.GET("/containers/:id/logs", s.authMiddleware.RequirePermission(auth.PermContainersRead), s.getContainerLogs) + protected.GET("/containers/:id/stats", s.authMiddleware.RequirePermission(auth.PermContainersRead), s.getContainerStats) + protected.GET("/containers/stats", s.authMiddleware.RequirePermission(auth.PermContainersRead), s.getAllContainerStats) + protected.POST("/containers/:id/exec", s.authMiddleware.RequirePermission(auth.PermContainersWrite), s.containerExecHTTP) + protected.GET("/deployments/:name/stats", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentContainerStats) + + // Image endpoints + protected.GET("/images", s.authMiddleware.RequirePermission(auth.PermImagesRead), s.listImages) + protected.DELETE("/images/:id", s.authMiddleware.RequirePermission(auth.PermImagesDelete), s.removeImage) + protected.POST("/images/pull", s.authMiddleware.RequirePermission(auth.PermImagesWrite), s.pullImage) + + // Volume endpoints + protected.GET("/volumes", s.authMiddleware.RequirePermission(auth.PermVolumesRead), s.listVolumes) + protected.POST("/volumes", s.authMiddleware.RequirePermission(auth.PermVolumesWrite), s.createVolume) + protected.DELETE("/volumes/:name", s.authMiddleware.RequirePermission(auth.PermVolumesDelete), s.removeVolume) + protected.POST("/volumes/prune", s.authMiddleware.RequirePermission(auth.PermVolumesWrite), s.pruneVolumes) + + // Port endpoints + protected.GET("/ports", s.authMiddleware.RequirePermission(auth.PermSystemRead), s.listPorts) + protected.POST("/ports/:pid/kill", s.authMiddleware.RequirePermission(auth.PermSystemWrite), s.killProcess) + + // System service endpoints + protected.GET("/system/services", s.authMiddleware.RequirePermission(auth.PermSystemRead), s.listSystemServices) + protected.POST("/system/services/:name/start", s.authMiddleware.RequirePermission(auth.PermSystemWrite), s.startSystemService) + protected.POST("/system/services/:name/stop", s.authMiddleware.RequirePermission(auth.PermSystemWrite), s.stopSystemService) + protected.POST("/system/services/:name/restart", s.authMiddleware.RequirePermission(auth.PermSystemWrite), s.restartSystemService) + + // Deployment file endpoints + protected.GET("/deployments/:name/files", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.listDeploymentFiles) + protected.GET("/deployments/:name/files/*path", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentFile) + protected.POST("/deployments/:name/files/*path", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.uploadDeploymentFile) + protected.DELETE("/deployments/:name/files/*path", s.authMiddleware.RequirePermission(auth.PermDeploymentsDelete), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelAdmin), s.deleteDeploymentFile) + protected.POST("/deployments/:name/mkdir/*path", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.createDeploymentDir) + protected.GET("/deployments/:name/files-info", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentFilesInfo) + + // Deployment environment endpoints + protected.GET("/deployments/:name/env", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentEnv) + protected.PUT("/deployments/:name/env", s.authMiddleware.RequirePermission(auth.PermDeploymentsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.updateDeploymentEnv) + + // Database endpoints + protected.POST("/databases/test", s.authMiddleware.RequirePermission(auth.PermDatabasesRead), s.testDatabaseConnection) + protected.POST("/databases/list", s.authMiddleware.RequirePermission(auth.PermDatabasesRead), s.listDatabasesInServer) + protected.POST("/databases/tables", s.authMiddleware.RequirePermission(auth.PermDatabasesRead), s.listDatabaseTables) + protected.POST("/databases/tables/data", s.authMiddleware.RequirePermission(auth.PermDatabasesRead), s.queryTableData) + protected.POST("/databases/tables/schema", s.authMiddleware.RequirePermission(auth.PermDatabasesRead), s.describeTable) + protected.POST("/databases/query", s.authMiddleware.RequirePermission(auth.PermDatabasesWrite), s.executeDatabaseQuery) + protected.POST("/databases/users", s.authMiddleware.RequirePermission(auth.PermDatabasesRead), s.listDatabaseUsers) + protected.POST("/databases/users/by-database", s.authMiddleware.RequirePermission(auth.PermDatabasesRead), s.listUsersByDatabase) + protected.POST("/databases/create", s.authMiddleware.RequirePermission(auth.PermDatabasesWrite), s.createDatabaseInServer) + protected.POST("/databases/delete", s.authMiddleware.RequirePermission(auth.PermDatabasesWrite), s.deleteDatabaseInServer) + protected.POST("/databases/users/create", s.authMiddleware.RequirePermission(auth.PermDatabasesWrite), s.createDatabaseUser) + protected.POST("/databases/users/delete", s.authMiddleware.RequirePermission(auth.PermDatabasesWrite), s.deleteDatabaseUser) + protected.POST("/databases/privileges/grant", s.authMiddleware.RequirePermission(auth.PermDatabasesWrite), s.grantDatabasePrivileges) + + // Infrastructure endpoints + protected.GET("/infrastructure", s.authMiddleware.RequirePermission(auth.PermInfrastructureRead), s.listInfrastructure) + protected.GET("/infrastructure/stats", s.authMiddleware.RequirePermission(auth.PermInfrastructureRead), s.getInfraStats) + protected.GET("/infrastructure/:name", s.authMiddleware.RequirePermission(auth.PermInfrastructureRead), s.getInfraService) + protected.POST("/infrastructure/:name/start", s.authMiddleware.RequirePermission(auth.PermInfrastructureWrite), s.startInfraService) + protected.POST("/infrastructure/:name/stop", s.authMiddleware.RequirePermission(auth.PermInfrastructureWrite), s.stopInfraService) + protected.POST("/infrastructure/:name/restart", s.authMiddleware.RequirePermission(auth.PermInfrastructureWrite), s.restartInfraService) + protected.GET("/infrastructure/:name/logs", s.authMiddleware.RequirePermission(auth.PermInfrastructureRead), s.getInfraServiceLogs) + protected.POST("/infrastructure/migrate/:name", s.authMiddleware.RequirePermission(auth.PermInfrastructureWrite), s.migrateToInfrastructure) + + // Registry endpoints + protected.GET("/registries", s.authMiddleware.RequirePermission(auth.PermRegistriesRead), s.listRegistryTypes) + protected.GET("/registries/:slug", s.authMiddleware.RequirePermission(auth.PermRegistriesRead), s.getRegistryType) + protected.POST("/registries", s.authMiddleware.RequirePermission(auth.PermRegistriesWrite), s.createRegistryType) + protected.PUT("/registries/:slug", s.authMiddleware.RequirePermission(auth.PermRegistriesWrite), s.updateRegistryType) + protected.DELETE("/registries/:slug", s.authMiddleware.RequirePermission(auth.PermRegistriesDelete), s.deleteRegistryType) + + // Credential endpoints + protected.GET("/credentials", s.authMiddleware.RequirePermission(auth.PermRegistriesRead), s.listCredentials) + protected.GET("/credentials/:id", s.authMiddleware.RequirePermission(auth.PermRegistriesRead), s.getCredential) + protected.POST("/credentials", s.authMiddleware.RequirePermission(auth.PermRegistriesWrite), s.createCredential) + protected.PUT("/credentials/:id", s.authMiddleware.RequirePermission(auth.PermRegistriesWrite), s.updateCredential) + protected.DELETE("/credentials/:id", s.authMiddleware.RequirePermission(auth.PermRegistriesDelete), s.deleteCredential) + protected.POST("/credentials/:id/test", s.authMiddleware.RequirePermission(auth.PermRegistriesRead), s.testCredential) // Security endpoints - protected.GET("/security/stats", s.getSecurityStats) - protected.GET("/security/events", s.listSecurityEvents) - protected.GET("/security/events/:id", s.getSecurityEvent) - protected.POST("/security/cleanup", s.cleanupSecurityEvents) - protected.GET("/security/blocked-ips", s.listBlockedIPs) - protected.POST("/security/blocked-ips", s.blockIP) - protected.DELETE("/security/blocked-ips/:ip", s.unblockIP) - protected.GET("/security/ips/:ip/events", s.getEventsByIP) - protected.GET("/security/protected-routes", s.listProtectedRoutes) - protected.POST("/security/protected-routes", s.addProtectedRoute) - protected.PUT("/security/protected-routes/:id", s.updateProtectedRoute) - protected.DELETE("/security/protected-routes/:id", s.deleteProtectedRoute) - protected.GET("/security/whitelist", s.listWhitelist) - protected.POST("/security/whitelist", s.addWhitelistEntry) - protected.DELETE("/security/whitelist/:id", s.removeWhitelistEntry) - protected.GET("/security/realtime-capture", s.getRealtimeCaptureStatus) - protected.PUT("/security/realtime-capture", s.setRealtimeCaptureStatus) - protected.GET("/security/health", s.getSecurityHealth) - protected.POST("/security/refresh", s.refreshSecurityScripts) - protected.GET("/deployments/:name/security", s.getDeploymentSecurity) - protected.PUT("/deployments/:name/security", s.updateDeploymentSecurity) - protected.GET("/deployments/:name/security/events", s.getDeploymentSecurityEvents) + protected.GET("/security/stats", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.getSecurityStats) + protected.GET("/security/events", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.listSecurityEvents) + protected.GET("/security/events/:id", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.getSecurityEvent) + protected.POST("/security/cleanup", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.cleanupSecurityEvents) + protected.GET("/security/blocked-ips", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.listBlockedIPs) + protected.POST("/security/blocked-ips", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.blockIP) + protected.DELETE("/security/blocked-ips/:ip", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.unblockIP) + protected.GET("/security/ips/:ip/events", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.getEventsByIP) + protected.GET("/security/protected-routes", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.listProtectedRoutes) + protected.POST("/security/protected-routes", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.addProtectedRoute) + protected.PUT("/security/protected-routes/:id", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.updateProtectedRoute) + protected.DELETE("/security/protected-routes/:id", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.deleteProtectedRoute) + protected.GET("/security/whitelist", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.listWhitelist) + protected.POST("/security/whitelist", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.addWhitelistEntry) + protected.DELETE("/security/whitelist/:id", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.removeWhitelistEntry) + protected.GET("/security/realtime-capture", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.getRealtimeCaptureStatus) + protected.PUT("/security/realtime-capture", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.setRealtimeCaptureStatus) + protected.GET("/security/health", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.getSecurityHealth) + protected.POST("/security/refresh", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.refreshSecurityScripts) + protected.GET("/deployments/:name/security", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentSecurity) + protected.PUT("/deployments/:name/security", s.authMiddleware.RequirePermission(auth.PermSecurityWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.updateDeploymentSecurity) + protected.GET("/deployments/:name/security/events", s.authMiddleware.RequirePermission(auth.PermSecurityRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentSecurityEvents) // Traffic endpoints - protected.GET("/traffic/logs", s.getTrafficLogs) - protected.GET("/traffic/stats", s.getTrafficStats) - protected.GET("/traffic/unknown-domains", s.getUnknownDomainStats) - protected.POST("/traffic/cleanup", s.cleanupTrafficLogs) - protected.GET("/deployments/:name/traffic", s.getDeploymentTrafficStats) + protected.GET("/traffic/logs", s.authMiddleware.RequirePermission(auth.PermTrafficRead), s.getTrafficLogs) + protected.GET("/traffic/stats", s.authMiddleware.RequirePermission(auth.PermTrafficRead), s.getTrafficStats) + protected.GET("/traffic/unknown-domains", s.authMiddleware.RequirePermission(auth.PermTrafficRead), s.getUnknownDomainStats) + protected.POST("/traffic/cleanup", s.authMiddleware.RequirePermission(auth.PermTrafficWrite), s.cleanupTrafficLogs) + protected.GET("/deployments/:name/traffic", s.authMiddleware.RequirePermission(auth.PermDeploymentsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentTrafficStats) // Backup endpoints - protected.GET("/backups", s.listBackups) - protected.GET("/backups/:id", s.getBackup) - protected.POST("/backups", s.createBackup) - protected.DELETE("/backups/:id", s.deleteBackup) - protected.GET("/backups/:id/download", s.downloadBackup) - protected.GET("/deployments/:name/backups", s.listDeploymentBackups) - protected.POST("/deployments/:name/backups", s.createDeploymentBackup) - protected.GET("/deployments/:name/backup-config", s.getDeploymentBackupConfig) - protected.PUT("/deployments/:name/backup-config", s.updateDeploymentBackupConfig) - protected.POST("/backups/:id/restore", s.restoreBackup) - protected.GET("/backups/jobs", s.listBackupJobs) - protected.GET("/backups/jobs/:id", s.getBackupJob) + protected.GET("/backups", s.authMiddleware.RequirePermission(auth.PermBackupsRead), s.listBackups) + protected.GET("/backups/:id", s.authMiddleware.RequirePermission(auth.PermBackupsRead), s.getBackup) + protected.POST("/backups", s.authMiddleware.RequirePermission(auth.PermBackupsWrite), s.createBackup) + protected.DELETE("/backups/:id", s.authMiddleware.RequirePermission(auth.PermBackupsDelete), s.deleteBackup) + protected.GET("/backups/:id/download", s.authMiddleware.RequirePermission(auth.PermBackupsRead), s.downloadBackup) + protected.GET("/deployments/:name/backups", s.authMiddleware.RequirePermission(auth.PermBackupsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.listDeploymentBackups) + protected.POST("/deployments/:name/backups", s.authMiddleware.RequirePermission(auth.PermBackupsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.createDeploymentBackup) + protected.GET("/deployments/:name/backup-config", s.authMiddleware.RequirePermission(auth.PermBackupsRead), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelRead), s.getDeploymentBackupConfig) + protected.PUT("/deployments/:name/backup-config", s.authMiddleware.RequirePermission(auth.PermBackupsWrite), s.authMiddleware.RequireDeploymentAccess(auth.AccessLevelWrite), s.updateDeploymentBackupConfig) + protected.POST("/backups/:id/restore", s.authMiddleware.RequirePermission(auth.PermBackupsWrite), s.restoreBackup) + protected.GET("/backups/jobs", s.authMiddleware.RequirePermission(auth.PermBackupsRead), s.listBackupJobs) + protected.GET("/backups/jobs/:id", s.authMiddleware.RequirePermission(auth.PermBackupsRead), s.getBackupJob) // Scheduler endpoints - protected.GET("/scheduler/tasks", s.listScheduledTasks) - protected.GET("/scheduler/tasks/:id", s.getScheduledTask) - protected.POST("/scheduler/tasks", s.createScheduledTask) - protected.PUT("/scheduler/tasks/:id", s.updateScheduledTask) - protected.DELETE("/scheduler/tasks/:id", s.deleteScheduledTask) - protected.POST("/scheduler/tasks/:id/run", s.runTaskNow) - protected.GET("/scheduler/tasks/:id/executions", s.getTaskExecutions) - protected.GET("/scheduler/executions", s.getRecentExecutions) + protected.GET("/scheduler/tasks", s.authMiddleware.RequirePermission(auth.PermSchedulerRead), s.listScheduledTasks) + protected.GET("/scheduler/tasks/:id", s.authMiddleware.RequirePermission(auth.PermSchedulerRead), s.getScheduledTask) + protected.POST("/scheduler/tasks", s.authMiddleware.RequirePermission(auth.PermSchedulerWrite), s.createScheduledTask) + protected.PUT("/scheduler/tasks/:id", s.authMiddleware.RequirePermission(auth.PermSchedulerWrite), s.updateScheduledTask) + protected.DELETE("/scheduler/tasks/:id", s.authMiddleware.RequirePermission(auth.PermSchedulerDelete), s.deleteScheduledTask) + protected.POST("/scheduler/tasks/:id/run", s.authMiddleware.RequirePermission(auth.PermSchedulerWrite), s.runTaskNow) + protected.GET("/scheduler/tasks/:id/executions", s.authMiddleware.RequirePermission(auth.PermSchedulerRead), s.getTaskExecutions) + protected.GET("/scheduler/executions", s.authMiddleware.RequirePermission(auth.PermSchedulerRead), s.getRecentExecutions) // Audit endpoints - protected.GET("/audit/events", s.listAuditEvents) - protected.GET("/audit/events/:id", s.getAuditEvent) - protected.GET("/audit/stats", s.getAuditStats) - protected.POST("/audit/export", s.exportAuditEvents) - protected.DELETE("/audit/cleanup", s.cleanupAuditEvents) + protected.GET("/audit/events", s.authMiddleware.RequirePermission(auth.PermAuditRead), s.listAuditEvents) + protected.GET("/audit/events/:id", s.authMiddleware.RequirePermission(auth.PermAuditRead), s.getAuditEvent) + protected.GET("/audit/stats", s.authMiddleware.RequirePermission(auth.PermAuditRead), s.getAuditStats) + protected.POST("/audit/export", s.authMiddleware.RequirePermission(auth.PermAuditRead), s.exportAuditEvents) + protected.DELETE("/audit/cleanup", s.authMiddleware.RequirePermission(auth.PermSettingsWrite), s.cleanupAuditEvents) + + // User management endpoints (require auth manager) + if s.authManager != nil { + // Current user endpoints (any authenticated user) + protected.GET("/users/me", s.getCurrentUser) + protected.PUT("/users/me", s.updateCurrentUser) + protected.PUT("/users/me/password", s.updateCurrentUserPassword) + + // User management (admin only) + usersGroup := protected.Group("/users") + usersGroup.Use(s.authMiddleware.RequirePermission(auth.PermUsersRead)) + { + usersGroup.GET("", s.listUsers) + usersGroup.GET("/:id", s.getUser) + usersGroup.POST("", s.authMiddleware.RequirePermission(auth.PermUsersWrite), s.createUser) + usersGroup.PUT("/:id", s.authMiddleware.RequirePermission(auth.PermUsersWrite), s.updateUser) + usersGroup.DELETE("/:id", s.authMiddleware.RequirePermission(auth.PermUsersDelete), s.deleteUser) + + // User deployment access + usersGroup.GET("/:id/deployments", s.getUserDeployments) + usersGroup.POST("/:id/deployments", s.authMiddleware.RequirePermission(auth.PermUsersWrite), s.assignUserDeployment) + usersGroup.PUT("/:id/deployments/:name", s.authMiddleware.RequirePermission(auth.PermUsersWrite), s.updateUserDeployment) + usersGroup.DELETE("/:id/deployments/:name", s.authMiddleware.RequirePermission(auth.PermUsersWrite), s.removeUserDeployment) + } + + // API key management + apiKeysGroup := protected.Group("/apikeys") + apiKeysGroup.Use(s.authMiddleware.RequirePermission(auth.PermAPIKeysRead)) + { + apiKeysGroup.GET("", s.listAPIKeys) + apiKeysGroup.GET("/:id", s.getAPIKey) + apiKeysGroup.POST("", s.authMiddleware.RequirePermission(auth.PermAPIKeysWrite), s.createAPIKey) + apiKeysGroup.DELETE("/:id", s.authMiddleware.RequirePermission(auth.PermAPIKeysDelete), s.deleteAPIKey) + apiKeysGroup.POST("/:id/revoke", s.authMiddleware.RequirePermission(auth.PermAPIKeysDelete), s.revokeAPIKey) + } + + // Get users with access to a deployment + protected.GET("/deployments/:name/users", s.authMiddleware.RequirePermission(auth.PermUsersRead), s.getDeploymentUsers) + } // DNS plugin routes dnsGroup := protected.Group("/dns") + dnsGroup.Use(s.authMiddleware.RequirePermission(auth.PermDNSRead)) { dnsGroup.GET("/providers", s.listDNSProviders) @@ -438,14 +515,10 @@ func (s *Server) Stop() error { } func (s *Server) healthCheck(c *gin.Context) { - stats, _ := s.manager.GetStats() - c.JSON(http.StatusOK, gin.H{ - "status": "healthy", - "agent": "flatrun", - "version": version.Get(), - "deployments_path": s.config.DeploymentsPath, - "stats": stats, + "status": "healthy", + "agent": "flatrun", + "version": version.Get(), }) } @@ -458,6 +531,17 @@ func (s *Server) listDeployments(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + var filtered []models.Deployment + for _, d := range deployments { + if actor.CanAccessDeployment(d.Name, auth.AccessLevelRead) { + filtered = append(filtered, d) + } + } + deployments = filtered + } + c.JSON(http.StatusOK, gin.H{ "deployments": deployments, "path": s.manager.BasePath(), @@ -577,6 +661,13 @@ func (s *Server) createDeployment(c *gin.Context) { return } + if s.authManager != nil { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.User != nil && actor.Role != auth.RoleAdmin { + _ = s.authManager.AssignDeployment(actor.User.ID, req.Name, auth.AccessLevelAdmin, actor.User.ID) + } + } + var dbEnvVars []EnvVar if req.UseSharedDatabase && s.config.Infrastructure.Database.Enabled { dbResult, err := s.createDatabaseForDeployment(req.Name) @@ -1673,6 +1764,13 @@ func (s *Server) createPluginDeployment(c *gin.Context) { return } + if s.authManager != nil { + actor := auth.GetActorFromContext(c) + if actor != nil && actor.User != nil && actor.Role != auth.RoleAdmin { + _ = s.authManager.AssignDeployment(actor.User.ID, req.Name, auth.AccessLevelAdmin, actor.User.ID) + } + } + c.JSON(http.StatusCreated, gin.H{ "message": "Deployment created", "deployment": result, @@ -2904,7 +3002,7 @@ func (s *Server) setupProxyWithRetry(deployment *models.Deployment, maxRetries i } func (s *Server) getSystemStats(c *gin.Context) { - stats, err := s.manager.GetStats() + deployments, err := s.manager.ListDeployments() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": err.Error(), @@ -2912,6 +3010,37 @@ func (s *Server) getSystemStats(c *gin.Context) { return } + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + var filtered []models.Deployment + for _, d := range deployments { + if actor.CanAccessDeployment(d.Name, auth.AccessLevelRead) { + filtered = append(filtered, d) + } + } + deployments = filtered + } + + depStats := gin.H{ + "total_deployments": len(deployments), + "running": 0, + "stopped": 0, + "error": 0, + "unknown": 0, + } + for _, d := range deployments { + switch d.Status { + case "running": + depStats["running"] = depStats["running"].(int) + 1 + case "stopped": + depStats["stopped"] = depStats["stopped"].(int) + 1 + case "error": + depStats["error"] = depStats["error"].(int) + 1 + default: + depStats["unknown"] = depStats["unknown"].(int) + 1 + } + } + containerStats, _ := s.networksManager.GetContainerStats() imageStats, _ := s.networksManager.GetImageStats() volumeStats, _ := s.networksManager.GetVolumeStats() @@ -2929,7 +3058,7 @@ func (s *Server) getSystemStats(c *gin.Context) { systemStats, _ := system.GetSystemStats() c.JSON(http.StatusOK, gin.H{ - "deployments": stats, + "deployments": depStats, "containers": containerStats, "images": imageStats, "volumes": volumeStats, diff --git a/internal/api/user_deployments.go b/internal/api/user_deployments.go new file mode 100644 index 0000000..e35db5c --- /dev/null +++ b/internal/api/user_deployments.go @@ -0,0 +1,180 @@ +package api + +import ( + "net/http" + "strconv" + + "github.com/flatrun/agent/internal/auth" + "github.com/gin-gonic/gin" +) + +func (s *Server) getUserDeployments(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + deployments, err := s.authManager.GetUserDeployments(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user deployments"}) + return + } + + c.JSON(http.StatusOK, gin.H{"deployments": deploymentsToResponse(deployments)}) +} + +func (s *Server) assignUserDeployment(c *gin.Context) { + idStr := c.Param("id") + userID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + var req struct { + DeploymentName string `json:"deployment_name" binding:"required"` + AccessLevel string `json:"access_level" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + if !auth.ValidAccessLevel(req.AccessLevel) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid access level. Must be read, write, or admin"}) + return + } + + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + if !actor.CanAccessDeployment(req.DeploymentName, auth.AccessLevelAdmin) { + c.JSON(http.StatusForbidden, gin.H{"error": "No admin access to this deployment"}) + return + } + } + + grantedBy := int64(0) + if actor != nil && actor.User != nil { + grantedBy = actor.User.ID + } + + if err := s.authManager.AssignDeployment(userID, req.DeploymentName, req.AccessLevel, grantedBy); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to assign deployment"}) + return + } + + c.JSON(http.StatusCreated, gin.H{ + "message": "Deployment access granted", + "deployment_name": req.DeploymentName, + "access_level": req.AccessLevel, + }) +} + +func (s *Server) updateUserDeployment(c *gin.Context) { + idStr := c.Param("id") + userID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + deploymentName := c.Param("name") + if deploymentName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Deployment name required"}) + return + } + + var req struct { + AccessLevel string `json:"access_level" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + if !auth.ValidAccessLevel(req.AccessLevel) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid access level"}) + return + } + + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin { + if !actor.CanAccessDeployment(deploymentName, auth.AccessLevelAdmin) { + c.JSON(http.StatusForbidden, gin.H{"error": "No admin access to this deployment"}) + return + } + } + + if err := s.authManager.UpdateUserDeployment(userID, deploymentName, req.AccessLevel); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update deployment access"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "message": "Deployment access updated", + "deployment_name": deploymentName, + "access_level": req.AccessLevel, + }) +} + +func (s *Server) removeUserDeployment(c *gin.Context) { + idStr := c.Param("id") + userID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + deploymentName := c.Param("name") + if deploymentName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Deployment name required"}) + return + } + + if err := s.authManager.RemoveDeploymentAccess(userID, deploymentName); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to remove deployment access"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "Deployment access removed"}) +} + +func (s *Server) getDeploymentUsers(c *gin.Context) { + deploymentName := c.Param("name") + if deploymentName == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Deployment name required"}) + return + } + + deploymentUsers, err := s.authManager.GetDeploymentUsers(deploymentName) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get deployment users"}) + return + } + + users := make([]gin.H, 0, len(deploymentUsers)) + for _, du := range deploymentUsers { + user, err := s.authManager.GetUser(du.UserID) + if err != nil { + continue + } + users = append(users, gin.H{ + "user_id": du.UserID, + "username": user.Username, + "email": user.Email, + "role": user.Role, + "access_level": du.AccessLevel, + "granted_by": du.GrantedBy, + "created_at": du.CreatedAt, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "deployment_name": deploymentName, + "users": users, + }) +} diff --git a/internal/api/user_deployments_test.go b/internal/api/user_deployments_test.go new file mode 100644 index 0000000..f87df54 --- /dev/null +++ b/internal/api/user_deployments_test.go @@ -0,0 +1,396 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/pkg/config" + "github.com/gin-gonic/gin" +) + +func setupUserDeploymentsTestServer(t *testing.T) (*Server, *gin.Engine, func()) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "user_deployments_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + cfg := &config.Config{ + DeploymentsPath: tmpDir, + Auth: config.AuthConfig{ + Enabled: true, + JWTSecret: "test-jwt-secret-key-for-testing", + APIKeys: []string{"legacy-test-key"}, + }, + } + + os.Setenv("FLATRUN_ADMIN_PASSWORD", "testadminpass") + + authManager, err := auth.NewManager(tmpDir, &cfg.Auth) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create auth manager: %v", err) + } + + server := &Server{ + config: cfg, + authManager: authManager, + } + + router := gin.New() + authMiddleware := auth.NewMiddlewareWithManager(&cfg.Auth, authManager) + + api := router.Group("/api") + api.POST("/auth/login", authMiddleware.Login) + + protected := api.Group("") + protected.Use(authMiddleware.RequireAuth()) + { + usersRead := protected.Group("") + usersRead.Use(authMiddleware.RequirePermission(auth.PermUsersRead)) + { + usersRead.GET("/users/:id/deployments", server.getUserDeployments) + usersRead.GET("/deployments/:name/users", server.getDeploymentUsers) + } + + usersWrite := protected.Group("") + usersWrite.Use(authMiddleware.RequirePermission(auth.PermUsersWrite)) + { + usersWrite.POST("/users/:id/deployments", server.assignUserDeployment) + usersWrite.PUT("/users/:id/deployments/:name", server.updateUserDeployment) + usersWrite.DELETE("/users/:id/deployments/:name", server.removeUserDeployment) + } + } + + cleanup := func() { + authManager.Close() + os.RemoveAll(tmpDir) + os.Unsetenv("FLATRUN_ADMIN_PASSWORD") + } + + return server, router, cleanup +} + +func depLogin(t *testing.T, router *gin.Engine, username, password string) string { + body := map[string]string{ + "username": username, + "password": password, + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Login failed: %d - %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + return resp["token"].(string) +} + +func TestAssignUserDeployment(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + user, _ := server.authManager.CreateUser("deployuser", "", "password", auth.RoleOperator, nil) + + token := depLogin(t, router, "admin", "testadminpass") + + body := map[string]string{ + "deployment_name": "my-app", + "access_level": "write", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/users/%d/deployments", user.ID), bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected status 201, got %d: %s", w.Code, w.Body.String()) + } + + deployments, _ := server.authManager.GetUserDeployments(user.ID) + if len(deployments) != 1 { + t.Fatalf("Expected 1 deployment, got %d", len(deployments)) + } + + if deployments[0].DeploymentName != "my-app" { + t.Errorf("Expected deployment 'my-app', got %s", deployments[0].DeploymentName) + } + + if deployments[0].AccessLevel != "write" { + t.Errorf("Expected access level 'write', got %s", deployments[0].AccessLevel) + } +} + +func TestGetUserDeployments(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + user, _ := server.authManager.CreateUser("listuser", "", "password", auth.RoleOperator, nil) + admin, _ := server.authManager.GetUserByUsername("admin") + + _ = server.authManager.AssignDeployment(user.ID, "app-a", "read", admin.ID) + _ = server.authManager.AssignDeployment(user.ID, "app-b", "write", admin.ID) + + token := depLogin(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/users/%d/deployments", user.ID), nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + deployments, ok := resp["deployments"].([]interface{}) + if !ok { + t.Fatal("Expected deployments array in response") + } + + if len(deployments) != 2 { + t.Errorf("Expected 2 deployments, got %d", len(deployments)) + } +} + +func TestUpdateUserDeployment(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + user, _ := server.authManager.CreateUser("updateuser", "", "password", auth.RoleOperator, nil) + admin, _ := server.authManager.GetUserByUsername("admin") + + _ = server.authManager.AssignDeployment(user.ID, "my-app", "read", admin.ID) + + token := depLogin(t, router, "admin", "testadminpass") + + body := map[string]string{ + "access_level": "admin", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/users/%d/deployments/my-app", user.ID), bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + depMap, _ := server.authManager.GetUserDeploymentsMap(user.ID) + if depMap["my-app"] != "admin" { + t.Errorf("Expected access level 'admin', got %s", depMap["my-app"]) + } +} + +func TestRemoveUserDeployment(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + user, _ := server.authManager.CreateUser("removeuser", "", "password", auth.RoleOperator, nil) + admin, _ := server.authManager.GetUserByUsername("admin") + + _ = server.authManager.AssignDeployment(user.ID, "to-remove", "write", admin.ID) + + token := depLogin(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/api/users/%d/deployments/to-remove", user.ID), nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + deployments, _ := server.authManager.GetUserDeployments(user.ID) + if len(deployments) != 0 { + t.Errorf("Expected 0 deployments after removal, got %d", len(deployments)) + } +} + +func TestGetDeploymentUsers(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + user1, _ := server.authManager.CreateUser("depuser1", "", "password", auth.RoleOperator, nil) + user2, _ := server.authManager.CreateUser("depuser2", "", "password", auth.RoleViewer, nil) + admin, _ := server.authManager.GetUserByUsername("admin") + + _ = server.authManager.AssignDeployment(user1.ID, "shared-app", "write", admin.ID) + _ = server.authManager.AssignDeployment(user2.ID, "shared-app", "read", admin.ID) + + token := depLogin(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodGet, "/api/deployments/shared-app/users", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + users, ok := resp["users"].([]interface{}) + if !ok { + t.Fatal("Expected users array in response") + } + + if len(users) != 2 { + t.Errorf("Expected 2 users, got %d", len(users)) + } +} + +func TestOperatorCannotAssignDeployment(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + _, _ = server.authManager.CreateUser("operator", "", "operatorpass", auth.RoleOperator, nil) + + token := depLogin(t, router, "operator", "operatorpass") + + body := map[string]string{ + "deployment_name": "my-app", + "access_level": "read", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/users/1/deployments", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected status 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestViewerCannotViewUserDeployments(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + _, _ = server.authManager.CreateUser("viewer", "", "viewerpass", auth.RoleViewer, nil) + + token := depLogin(t, router, "viewer", "viewerpass") + + req := httptest.NewRequest(http.MethodGet, "/api/users/1/deployments", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected status 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestAssignDeploymentInvalidAccessLevel(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + user, _ := server.authManager.CreateUser("badaccess", "", "password", auth.RoleOperator, nil) + + token := depLogin(t, router, "admin", "testadminpass") + + body := map[string]string{ + "deployment_name": "my-app", + "access_level": "superadmin", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/users/%d/deployments", user.ID), bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestAssignMultipleDeployments(t *testing.T) { + server, router, cleanup := setupUserDeploymentsTestServer(t) + defer cleanup() + + user, _ := server.authManager.CreateUser("multiuser", "", "password", auth.RoleOperator, nil) + + token := depLogin(t, router, "admin", "testadminpass") + + assignments := []struct { + name string + level string + }{ + {"app-a", "read"}, + {"app-b", "write"}, + {"app-c", "admin"}, + } + + for _, a := range assignments { + body := map[string]string{ + "deployment_name": a.name, + "access_level": a.level, + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/users/%d/deployments", user.ID), bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Failed to assign %s: %d - %s", a.name, w.Code, w.Body.String()) + } + } + + depMap, _ := server.authManager.GetUserDeploymentsMap(user.ID) + + if len(depMap) != 3 { + t.Errorf("Expected 3 deployments, got %d", len(depMap)) + } + + if depMap["app-a"] != "read" { + t.Errorf("Expected app-a access 'read', got %s", depMap["app-a"]) + } + if depMap["app-b"] != "write" { + t.Errorf("Expected app-b access 'write', got %s", depMap["app-b"]) + } + if depMap["app-c"] != "admin" { + t.Errorf("Expected app-c access 'admin', got %s", depMap["app-c"]) + } +} diff --git a/internal/api/users.go b/internal/api/users.go new file mode 100644 index 0000000..0b46bff --- /dev/null +++ b/internal/api/users.go @@ -0,0 +1,284 @@ +package api + +import ( + "net/http" + "strconv" + + "github.com/flatrun/agent/internal/auth" + "github.com/gin-gonic/gin" +) + +func (s *Server) listUsers(c *gin.Context) { + users, err := s.authManager.GetUsers() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list users"}) + return + } + + response := make([]gin.H, 0, len(users)) + for _, u := range users { + resp := userToResponse(&u) + deps, _ := s.authManager.GetUserDeployments(u.ID) + resp["deployment_count"] = len(deps) + response = append(response, resp) + } + + c.JSON(http.StatusOK, gin.H{"users": response}) +} + +func (s *Server) getUser(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + user, err := s.authManager.GetUser(id) + if err == auth.ErrUserNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"}) + return + } + + deployments, _ := s.authManager.GetUserDeployments(user.ID) + + c.JSON(http.StatusOK, gin.H{ + "user": userToResponse(user), + "deployments": deploymentsToResponse(deployments), + }) +} + +func (s *Server) createUser(c *gin.Context) { + var req struct { + Username string `json:"username" binding:"required"` + Email string `json:"email"` + Password string `json:"password" binding:"required"` + Role auth.Role `json:"role" binding:"required"` + Permissions []string `json:"permissions"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + if !req.Role.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid role. Must be admin, operator, or viewer"}) + return + } + + actor := auth.GetActorFromContext(c) + if actor != nil && actor.Role != auth.RoleAdmin && req.Role == auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Only admins can create admin users"}) + return + } + + user, err := s.authManager.CreateUser(req.Username, req.Email, req.Password, req.Role, req.Permissions) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user: " + err.Error()}) + return + } + + c.JSON(http.StatusCreated, gin.H{"user": userToResponse(user)}) +} + +func (s *Server) updateUser(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + user, err := s.authManager.GetUser(id) + if err == auth.ErrUserNotFound { + c.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user"}) + return + } + + var req struct { + Username string `json:"username"` + Email string `json:"email"` + Role auth.Role `json:"role"` + Permissions *[]string `json:"permissions"` + IsActive *bool `json:"is_active"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + if req.Username != "" { + user.Username = req.Username + } + if req.Email != "" { + user.Email = req.Email + } + if req.Role != "" { + if !req.Role.IsValid() { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid role"}) + return + } + actor := auth.GetActorFromContext(c) + if req.Role == auth.RoleAdmin && actor != nil && actor.Role != auth.RoleAdmin { + c.JSON(http.StatusForbidden, gin.H{"error": "Only admins can assign the admin role"}) + return + } + user.Role = req.Role + } + if req.Permissions != nil { + user.Permissions = *req.Permissions + } + if req.IsActive != nil { + user.IsActive = *req.IsActive + } + + if err := s.authManager.UpdateUser(user); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update user"}) + return + } + + c.JSON(http.StatusOK, gin.H{"user": userToResponse(user)}) +} + +func (s *Server) deleteUser(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + actor := auth.GetActorFromContext(c) + if actor != nil && actor.UserID == id { + c.JSON(http.StatusBadRequest, gin.H{"error": "Cannot delete your own account"}) + return + } + + if err := s.authManager.DeleteUser(id, actor.UserID); err != nil { + if err == auth.ErrCannotDeleteSelf { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete user"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "User deleted"}) +} + +func (s *Server) getCurrentUser(c *gin.Context) { + actor := auth.GetActorFromContext(c) + if actor == nil || actor.User == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"}) + return + } + + deployments, _ := s.authManager.GetUserDeployments(actor.User.ID) + permissions := auth.EffectivePermissions(actor.User, actor.Role) + + c.JSON(http.StatusOK, gin.H{ + "user": userToResponse(actor.User), + "permissions": permissions, + "deployments": deploymentsToResponse(deployments), + }) +} + +func (s *Server) updateCurrentUser(c *gin.Context) { + actor := auth.GetActorFromContext(c) + if actor == nil || actor.User == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"}) + return + } + + var req struct { + Email string `json:"email"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + user := actor.User + if req.Email != "" { + user.Email = req.Email + } + + if err := s.authManager.UpdateUser(user); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update profile"}) + return + } + + c.JSON(http.StatusOK, gin.H{"user": userToResponse(user)}) +} + +func (s *Server) updateCurrentUserPassword(c *gin.Context) { + actor := auth.GetActorFromContext(c) + if actor == nil || actor.User == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Not authenticated"}) + return + } + + var req struct { + CurrentPassword string `json:"current_password" binding:"required"` + NewPassword string `json:"new_password" binding:"required"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + if !auth.VerifyPassword(req.CurrentPassword, actor.User.PasswordHash) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Current password is incorrect"}) + return + } + + if err := s.authManager.UpdatePassword(actor.User.ID, req.NewPassword); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update password"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "Password updated"}) +} + +func userToResponse(u *auth.User) gin.H { + resp := gin.H{ + "id": u.ID, + "uid": u.UID, + "username": u.Username, + "email": u.Email, + "role": u.Role, + "is_active": u.IsActive, + "created_at": u.CreatedAt, + "updated_at": u.UpdatedAt, + "last_login_at": u.LastLoginAt, + } + if len(u.Permissions) > 0 { + resp["permissions"] = u.Permissions + } + return resp +} + +func deploymentsToResponse(deployments []auth.UserDeployment) []gin.H { + result := make([]gin.H, 0, len(deployments)) + for _, d := range deployments { + result = append(result, gin.H{ + "deployment_name": d.DeploymentName, + "access_level": d.AccessLevel, + "granted_by": d.GrantedBy, + "created_at": d.CreatedAt, + }) + } + return result +} diff --git a/internal/api/users_test.go b/internal/api/users_test.go new file mode 100644 index 0000000..bb8cd93 --- /dev/null +++ b/internal/api/users_test.go @@ -0,0 +1,464 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/flatrun/agent/internal/auth" + "github.com/flatrun/agent/pkg/config" + "github.com/gin-gonic/gin" +) + +func setupTestServer(t *testing.T) (*Server, *gin.Engine, func()) { + gin.SetMode(gin.TestMode) + + tmpDir, err := os.MkdirTemp("", "api_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + cfg := &config.Config{ + DeploymentsPath: tmpDir, + Auth: config.AuthConfig{ + Enabled: true, + JWTSecret: "test-jwt-secret-key-for-testing", + APIKeys: []string{"legacy-test-key"}, + }, + } + + os.Setenv("FLATRUN_ADMIN_PASSWORD", "testadminpass") + + authManager, err := auth.NewManager(tmpDir, &cfg.Auth) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create auth manager: %v", err) + } + + server := &Server{ + config: cfg, + authManager: authManager, + } + + router := gin.New() + authMiddleware := auth.NewMiddlewareWithManager(&cfg.Auth, authManager) + + api := router.Group("/api") + api.POST("/auth/login", authMiddleware.Login) + + protected := api.Group("") + protected.Use(authMiddleware.RequireAuth()) + { + protected.GET("/users/me", server.getCurrentUser) + protected.PUT("/users/me", server.updateCurrentUser) + protected.PUT("/users/me/password", server.updateCurrentUserPassword) + + admin := protected.Group("") + admin.Use(authMiddleware.RequirePermission(auth.PermUsersRead)) + { + admin.GET("/users", server.listUsers) + admin.GET("/users/:id", server.getUser) + } + + adminWrite := protected.Group("") + adminWrite.Use(authMiddleware.RequirePermission(auth.PermUsersWrite)) + { + adminWrite.POST("/users", server.createUser) + adminWrite.PUT("/users/:id", server.updateUser) + } + + adminDelete := protected.Group("") + adminDelete.Use(authMiddleware.RequirePermission(auth.PermUsersDelete)) + { + adminDelete.DELETE("/users/:id", server.deleteUser) + } + } + + cleanup := func() { + authManager.Close() + os.RemoveAll(tmpDir) + os.Unsetenv("FLATRUN_ADMIN_PASSWORD") + } + + return server, router, cleanup +} + +func loginAndGetToken(t *testing.T, router *gin.Engine, username, password string) string { + body := map[string]string{ + "username": username, + "password": password, + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Login failed: %d - %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + token, ok := resp["token"].(string) + if !ok { + t.Fatal("Token not returned in login response") + } + + return token +} + +func TestListUsersAsAdmin(t *testing.T) { + _, router, cleanup := setupTestServer(t) + defer cleanup() + + token := loginAndGetToken(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodGet, "/api/users", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + users, ok := resp["users"].([]interface{}) + if !ok { + t.Fatal("Expected users array in response") + } + + if len(users) < 1 { + t.Error("Expected at least 1 user (admin)") + } +} + +func TestCreateUserAsAdmin(t *testing.T) { + _, router, cleanup := setupTestServer(t) + defer cleanup() + + token := loginAndGetToken(t, router, "admin", "testadminpass") + + body := map[string]string{ + "username": "newuser", + "email": "new@example.com", + "password": "newpassword123", + "role": "operator", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Errorf("Expected status 201, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + user, ok := resp["user"].(map[string]interface{}) + if !ok { + t.Fatal("Expected user object in response") + } + + if user["username"] != "newuser" { + t.Errorf("Expected username 'newuser', got %v", user["username"]) + } + + if user["role"] != "operator" { + t.Errorf("Expected role 'operator', got %v", user["role"]) + } +} + +func TestCreateUserInvalidRole(t *testing.T) { + _, router, cleanup := setupTestServer(t) + defer cleanup() + + token := loginAndGetToken(t, router, "admin", "testadminpass") + + body := map[string]string{ + "username": "baduser", + "password": "password", + "role": "superadmin", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestGetCurrentUser(t *testing.T) { + _, router, cleanup := setupTestServer(t) + defer cleanup() + + token := loginAndGetToken(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodGet, "/api/users/me", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + user, ok := resp["user"].(map[string]interface{}) + if !ok { + t.Fatal("Expected user object in response") + } + + if user["username"] != "admin" { + t.Errorf("Expected username 'admin', got %v", user["username"]) + } + + if user["role"] != "admin" { + t.Errorf("Expected role 'admin', got %v", user["role"]) + } + + perms, ok := resp["permissions"].([]interface{}) + if !ok { + t.Fatal("Expected permissions array in response") + } + + if len(perms) == 0 { + t.Error("Admin should have permissions") + } +} + +func TestUpdateCurrentUserPassword(t *testing.T) { + server, router, cleanup := setupTestServer(t) + defer cleanup() + + _, _ = server.authManager.CreateUser("passuser", "", "oldpassword", auth.RoleViewer, nil) + + token := loginAndGetToken(t, router, "passuser", "oldpassword") + + body := map[string]string{ + "current_password": "oldpassword", + "new_password": "newpassword123", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPut, "/api/users/me/password", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + newToken := loginAndGetToken(t, router, "passuser", "newpassword123") + if newToken == "" { + t.Error("Should be able to login with new password") + } +} + +func TestUpdateCurrentUserPasswordWrongCurrent(t *testing.T) { + server, router, cleanup := setupTestServer(t) + defer cleanup() + + _, _ = server.authManager.CreateUser("wrongpass", "", "correctpassword", auth.RoleViewer, nil) + + token := loginAndGetToken(t, router, "wrongpass", "correctpassword") + + body := map[string]string{ + "current_password": "wrongpassword", + "new_password": "newpassword", + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPut, "/api/users/me/password", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestViewerCannotAccessUsersList(t *testing.T) { + server, router, cleanup := setupTestServer(t) + defer cleanup() + + _, _ = server.authManager.CreateUser("viewer", "", "viewerpass", auth.RoleViewer, nil) + + token := loginAndGetToken(t, router, "viewer", "viewerpass") + + req := httptest.NewRequest(http.MethodGet, "/api/users", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected status 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestOperatorCannotAccessUsersList(t *testing.T) { + server, router, cleanup := setupTestServer(t) + defer cleanup() + + _, _ = server.authManager.CreateUser("operator", "", "operatorpass", auth.RoleOperator, nil) + + token := loginAndGetToken(t, router, "operator", "operatorpass") + + req := httptest.NewRequest(http.MethodGet, "/api/users", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("Expected status 403, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUpdateUser(t *testing.T) { + server, router, cleanup := setupTestServer(t) + defer cleanup() + + user, _ := server.authManager.CreateUser("updateme", "", "password", auth.RoleViewer, nil) + + token := loginAndGetToken(t, router, "admin", "testadminpass") + + body := map[string]interface{}{ + "role": "operator", + "email": "updated@example.com", + "is_active": false, + } + jsonBody, _ := json.Marshal(body) + + req := httptest.NewRequest(http.MethodPut, "/api/users/"+itoa(user.ID), bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + _ = json.Unmarshal(w.Body.Bytes(), &resp) + + respUser := resp["user"].(map[string]interface{}) + if respUser["role"] != "operator" { + t.Errorf("Expected role 'operator', got %v", respUser["role"]) + } + if respUser["email"] != "updated@example.com" { + t.Errorf("Expected email 'updated@example.com', got %v", respUser["email"]) + } + if respUser["is_active"] != false { + t.Errorf("Expected is_active false, got %v", respUser["is_active"]) + } +} + +func TestDeleteUser(t *testing.T) { + server, router, cleanup := setupTestServer(t) + defer cleanup() + + user, _ := server.authManager.CreateUser("deleteme", "", "password", auth.RoleViewer, nil) + + token := loginAndGetToken(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodDelete, "/api/users/"+itoa(user.ID), nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + _, err := server.authManager.GetUser(user.ID) + if err != auth.ErrUserNotFound { + t.Error("User should be deleted") + } +} + +func TestDeleteUserCannotDeleteSelf(t *testing.T) { + _, router, cleanup := setupTestServer(t) + defer cleanup() + + token := loginAndGetToken(t, router, "admin", "testadminpass") + + req := httptest.NewRequest(http.MethodDelete, "/api/users/1", nil) + req.Header.Set("Authorization", "Bearer "+token) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestUnauthorizedAccess(t *testing.T) { + _, router, cleanup := setupTestServer(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/users", nil) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestInvalidToken(t *testing.T) { + _, router, cleanup := setupTestServer(t) + defer cleanup() + + req := httptest.NewRequest(http.MethodGet, "/api/users/me", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d: %s", w.Code, w.Body.String()) + } +} + +func itoa(n int64) string { + return fmt.Sprintf("%d", n) +} diff --git a/internal/auth/apikey.go b/internal/auth/apikey.go new file mode 100644 index 0000000..dddd68a --- /dev/null +++ b/internal/auth/apikey.go @@ -0,0 +1,62 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" +) + +const ( + apiKeyLength = 32 + keyIDLength = 12 + keyPrefix = "fr_" +) + +func GenerateAPIKey() (plainKey string, keyHash string, keyID string, prefix string, err error) { + keyBytes := make([]byte, apiKeyLength) + if _, err := rand.Read(keyBytes); err != nil { + return "", "", "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + + plainKey = keyPrefix + base64.RawURLEncoding.EncodeToString(keyBytes) + + hash := sha256.Sum256([]byte(plainKey)) + keyHash = hex.EncodeToString(hash[:]) + + idBytes := make([]byte, keyIDLength/2) + if _, err := rand.Read(idBytes); err != nil { + return "", "", "", "", fmt.Errorf("failed to generate key ID: %w", err) + } + keyID = hex.EncodeToString(idBytes) + + if len(plainKey) >= 12 { + prefix = plainKey[:12] + "..." + } else { + prefix = plainKey + "..." + } + + return plainKey, keyHash, keyID, prefix, nil +} + +func HashAPIKey(key string) string { + hash := sha256.Sum256([]byte(key)) + return hex.EncodeToString(hash[:]) +} + +func GenerateUID() (string, error) { + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} + +func GenerateSessionID() (string, error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(bytes), nil +} diff --git a/internal/auth/apikey_test.go b/internal/auth/apikey_test.go new file mode 100644 index 0000000..4d2db69 --- /dev/null +++ b/internal/auth/apikey_test.go @@ -0,0 +1,111 @@ +package auth + +import ( + "strings" + "testing" +) + +func TestGenerateAPIKey(t *testing.T) { + plainKey, keyHash, keyID, prefix, err := GenerateAPIKey() + if err != nil { + t.Fatalf("GenerateAPIKey failed: %v", err) + } + + if plainKey == "" { + t.Error("GenerateAPIKey returned empty plainKey") + } + + if !strings.HasPrefix(plainKey, "fr_") { + t.Error("API key should start with 'fr_' prefix") + } + + if keyHash == "" { + t.Error("GenerateAPIKey returned empty keyHash") + } + + if keyID == "" { + t.Error("GenerateAPIKey returned empty keyID") + } + + if prefix == "" { + t.Error("GenerateAPIKey returned empty prefix") + } + + if !strings.HasSuffix(prefix, "...") { + t.Error("Prefix should end with '...'") + } +} + +func TestGenerateAPIKeyUniqueness(t *testing.T) { + key1, hash1, id1, _, _ := GenerateAPIKey() + key2, hash2, id2, _, _ := GenerateAPIKey() + + if key1 == key2 { + t.Error("GenerateAPIKey should generate unique keys") + } + + if hash1 == hash2 { + t.Error("GenerateAPIKey should generate unique hashes") + } + + if id1 == id2 { + t.Error("GenerateAPIKey should generate unique IDs") + } +} + +func TestHashAPIKey(t *testing.T) { + key := "fr_test_api_key_12345" + + hash1 := HashAPIKey(key) + hash2 := HashAPIKey(key) + + if hash1 != hash2 { + t.Error("HashAPIKey should return same hash for same key") + } + + if hash1 == key { + t.Error("HashAPIKey should not return plaintext key") + } + + if len(hash1) != 64 { + t.Errorf("SHA-256 hash should be 64 hex characters, got %d", len(hash1)) + } +} + +func TestGenerateUID(t *testing.T) { + uid1, err := GenerateUID() + if err != nil { + t.Fatalf("GenerateUID failed: %v", err) + } + + uid2, _ := GenerateUID() + + if uid1 == "" { + t.Error("GenerateUID returned empty string") + } + + if uid1 == uid2 { + t.Error("GenerateUID should generate unique values") + } + + if len(uid1) != 32 { + t.Errorf("UID should be 32 hex characters, got %d", len(uid1)) + } +} + +func TestGenerateSessionID(t *testing.T) { + sid1, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID failed: %v", err) + } + + sid2, _ := GenerateSessionID() + + if sid1 == "" { + t.Error("GenerateSessionID returned empty string") + } + + if sid1 == sid2 { + t.Error("GenerateSessionID should generate unique values") + } +} diff --git a/internal/auth/db.go b/internal/auth/db.go new file mode 100644 index 0000000..f35000c --- /dev/null +++ b/internal/auth/db.go @@ -0,0 +1,730 @@ +package auth + +import ( + "database/sql" + "os" + "path/filepath" + "sync" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +type DB struct { + conn *sql.DB + path string + mu sync.RWMutex +} + +func NewAuthDB(deploymentsPath string) (*DB, error) { + dbDir := filepath.Join(deploymentsPath, ".flatrun") + if err := os.MkdirAll(dbDir, 0755); err != nil { + return nil, err + } + + dbPath := filepath.Join(dbDir, "auth.db") + conn, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_busy_timeout=5000") + if err != nil { + return nil, err + } + + conn.SetMaxOpenConns(10) + conn.SetMaxIdleConns(5) + conn.SetConnMaxLifetime(time.Hour) + + db := &DB{conn: conn, path: dbPath} + if err := db.migrate(); err != nil { + conn.Close() + return nil, err + } + + return db, nil +} + +func (db *DB) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + return db.conn.Close() +} + +func (db *DB) migrate() error { + schema := ` + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + uid TEXT UNIQUE NOT NULL, + username TEXT UNIQUE NOT NULL, + email TEXT, + password_hash TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'viewer', + is_active BOOLEAN DEFAULT TRUE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + last_login_at DATETIME + ); + + CREATE TABLE IF NOT EXISTS api_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key_id TEXT UNIQUE NOT NULL, + user_id INTEGER NOT NULL, + name TEXT NOT NULL, + description TEXT, + key_hash TEXT NOT NULL, + key_prefix TEXT NOT NULL, + role TEXT, + permissions TEXT, + deployments TEXT, + expires_at DATETIME, + last_used_at DATETIME, + last_used_ip TEXT, + is_active BOOLEAN DEFAULT TRUE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT UNIQUE NOT NULL, + user_id INTEGER NOT NULL, + api_key_id INTEGER, + token_hash TEXT NOT NULL, + expires_at DATETIME NOT NULL, + revoked_at DATETIME, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + client_ip TEXT, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS user_deployments ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + deployment_name TEXT NOT NULL, + access_level TEXT NOT NULL DEFAULT 'read', + granted_by INTEGER, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE, + UNIQUE(user_id, deployment_name) + ); + + CREATE INDEX IF NOT EXISTS idx_users_username ON users(username); + CREATE INDEX IF NOT EXISTS idx_users_email ON users(email); + CREATE INDEX IF NOT EXISTS idx_users_uid ON users(uid); + CREATE INDEX IF NOT EXISTS idx_api_keys_user ON api_keys(user_id); + CREATE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash); + CREATE INDEX IF NOT EXISTS idx_sessions_user ON sessions(user_id); + CREATE INDEX IF NOT EXISTS idx_sessions_token ON sessions(token_hash); + CREATE INDEX IF NOT EXISTS idx_user_deployments_user ON user_deployments(user_id); + CREATE INDEX IF NOT EXISTS idx_user_deployments_deployment ON user_deployments(deployment_name); + ` + + _, err := db.conn.Exec(schema) + if err != nil { + return err + } + + // Add permissions column to users table if missing (ignore error if column exists) + _, _ = db.conn.Exec(`ALTER TABLE users ADD COLUMN permissions TEXT`) + + return nil +} + +func (db *DB) CreateUser(user *User) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + result, err := db.conn.Exec(` + INSERT INTO users (uid, username, email, password_hash, role, permissions, is_active, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + user.UID, user.Username, user.Email, user.PasswordHash, user.Role, + user.GetPermissionsJSON(), user.IsActive, user.CreatedAt, user.UpdatedAt, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +func (db *DB) GetUserByID(id int64) (*User, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var u User + var email, perms sql.NullString + var lastLogin sql.NullTime + + err := db.conn.QueryRow(` + SELECT id, uid, username, email, password_hash, role, permissions, is_active, created_at, updated_at, last_login_at + FROM users WHERE id = ?`, id).Scan( + &u.ID, &u.UID, &u.Username, &email, &u.PasswordHash, &u.Role, &perms, + &u.IsActive, &u.CreatedAt, &u.UpdatedAt, &lastLogin, + ) + if err != nil { + return nil, err + } + + u.Email = email.String + u.Permissions = ParsePermissionsJSON(perms.String) + if lastLogin.Valid { + u.LastLoginAt = lastLogin.Time + } + return &u, nil +} + +func (db *DB) GetUserByUID(uid string) (*User, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var u User + var email, perms sql.NullString + var lastLogin sql.NullTime + + err := db.conn.QueryRow(` + SELECT id, uid, username, email, password_hash, role, permissions, is_active, created_at, updated_at, last_login_at + FROM users WHERE uid = ?`, uid).Scan( + &u.ID, &u.UID, &u.Username, &email, &u.PasswordHash, &u.Role, &perms, + &u.IsActive, &u.CreatedAt, &u.UpdatedAt, &lastLogin, + ) + if err != nil { + return nil, err + } + + u.Email = email.String + u.Permissions = ParsePermissionsJSON(perms.String) + if lastLogin.Valid { + u.LastLoginAt = lastLogin.Time + } + return &u, nil +} + +func (db *DB) GetUserByUsername(username string) (*User, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var u User + var email, perms sql.NullString + var lastLogin sql.NullTime + + err := db.conn.QueryRow(` + SELECT id, uid, username, email, password_hash, role, permissions, is_active, created_at, updated_at, last_login_at + FROM users WHERE username = ?`, username).Scan( + &u.ID, &u.UID, &u.Username, &email, &u.PasswordHash, &u.Role, &perms, + &u.IsActive, &u.CreatedAt, &u.UpdatedAt, &lastLogin, + ) + if err != nil { + return nil, err + } + + u.Email = email.String + u.Permissions = ParsePermissionsJSON(perms.String) + if lastLogin.Valid { + u.LastLoginAt = lastLogin.Time + } + return &u, nil +} + +func (db *DB) GetUsers() ([]User, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + rows, err := db.conn.Query(` + SELECT id, uid, username, email, password_hash, role, permissions, is_active, created_at, updated_at, last_login_at + FROM users ORDER BY created_at DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var u User + var email, perms sql.NullString + var lastLogin sql.NullTime + + if err := rows.Scan( + &u.ID, &u.UID, &u.Username, &email, &u.PasswordHash, &u.Role, &perms, + &u.IsActive, &u.CreatedAt, &u.UpdatedAt, &lastLogin, + ); err != nil { + return nil, err + } + + u.Email = email.String + u.Permissions = ParsePermissionsJSON(perms.String) + if lastLogin.Valid { + u.LastLoginAt = lastLogin.Time + } + users = append(users, u) + } + return users, nil +} + +func (db *DB) UpdateUser(user *User) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(` + UPDATE users SET username = ?, email = ?, role = ?, permissions = ?, is_active = ?, updated_at = ? + WHERE id = ?`, + user.Username, user.Email, user.Role, user.GetPermissionsJSON(), user.IsActive, time.Now(), user.ID, + ) + return err +} + +func (db *DB) UpdateUserPassword(id int64, passwordHash string) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?`, + passwordHash, time.Now(), id) + return err +} + +func (db *DB) UpdateUserLastLogin(id int64) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`UPDATE users SET last_login_at = ? WHERE id = ?`, time.Now(), id) + return err +} + +func (db *DB) DeleteUser(id int64) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`DELETE FROM users WHERE id = ?`, id) + return err +} + +func (db *DB) CountUsers() (int, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var count int + err := db.conn.QueryRow(`SELECT COUNT(*) FROM users`).Scan(&count) + return count, err +} + +func (db *DB) CreateAPIKey(key *APIKey) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + var roleVal sql.NullString + if key.Role != "" { + roleVal = sql.NullString{String: string(key.Role), Valid: true} + } + + var expiresAt sql.NullTime + if !key.ExpiresAt.IsZero() { + expiresAt = sql.NullTime{Time: key.ExpiresAt, Valid: true} + } + + result, err := db.conn.Exec(` + INSERT INTO api_keys (key_id, user_id, name, description, key_hash, key_prefix, role, permissions, deployments, expires_at, is_active, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + key.KeyID, key.UserID, key.Name, key.Description, key.KeyHash, key.KeyPrefix, + roleVal, key.GetPermissionsJSON(), key.GetDeploymentsJSON(), expiresAt, key.IsActive, key.CreatedAt, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +func (db *DB) GetAPIKeyByID(id int64) (*APIKey, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var k APIKey + var desc, role, perms, deps, lastIP sql.NullString + var expiresAt, lastUsed sql.NullTime + + err := db.conn.QueryRow(` + SELECT id, key_id, user_id, name, description, key_hash, key_prefix, role, permissions, deployments, + expires_at, last_used_at, last_used_ip, is_active, created_at + FROM api_keys WHERE id = ?`, id).Scan( + &k.ID, &k.KeyID, &k.UserID, &k.Name, &desc, &k.KeyHash, &k.KeyPrefix, &role, &perms, &deps, + &expiresAt, &lastUsed, &lastIP, &k.IsActive, &k.CreatedAt, + ) + if err != nil { + return nil, err + } + + k.Description = desc.String + k.Role = Role(role.String) + k.Permissions = ParsePermissionsJSON(perms.String) + k.Deployments = ParseDeploymentsJSON(deps.String) + k.LastUsedIP = lastIP.String + if expiresAt.Valid { + k.ExpiresAt = expiresAt.Time + } + if lastUsed.Valid { + k.LastUsedAt = lastUsed.Time + } + return &k, nil +} + +func (db *DB) GetAPIKeyByHash(hash string) (*APIKey, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var k APIKey + var desc, role, perms, deps, lastIP sql.NullString + var expiresAt, lastUsed sql.NullTime + + err := db.conn.QueryRow(` + SELECT id, key_id, user_id, name, description, key_hash, key_prefix, role, permissions, deployments, + expires_at, last_used_at, last_used_ip, is_active, created_at + FROM api_keys WHERE key_hash = ?`, hash).Scan( + &k.ID, &k.KeyID, &k.UserID, &k.Name, &desc, &k.KeyHash, &k.KeyPrefix, &role, &perms, &deps, + &expiresAt, &lastUsed, &lastIP, &k.IsActive, &k.CreatedAt, + ) + if err != nil { + return nil, err + } + + k.Description = desc.String + k.Role = Role(role.String) + k.Permissions = ParsePermissionsJSON(perms.String) + k.Deployments = ParseDeploymentsJSON(deps.String) + k.LastUsedIP = lastIP.String + if expiresAt.Valid { + k.ExpiresAt = expiresAt.Time + } + if lastUsed.Valid { + k.LastUsedAt = lastUsed.Time + } + return &k, nil +} + +func (db *DB) GetAPIKeysByUserID(userID int64) ([]APIKey, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + rows, err := db.conn.Query(` + SELECT id, key_id, user_id, name, description, key_hash, key_prefix, role, permissions, deployments, + expires_at, last_used_at, last_used_ip, is_active, created_at + FROM api_keys WHERE user_id = ? ORDER BY created_at DESC`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var keys []APIKey + for rows.Next() { + var k APIKey + var desc, role, perms, deps, lastIP sql.NullString + var expiresAt, lastUsed sql.NullTime + + if err := rows.Scan( + &k.ID, &k.KeyID, &k.UserID, &k.Name, &desc, &k.KeyHash, &k.KeyPrefix, &role, &perms, &deps, + &expiresAt, &lastUsed, &lastIP, &k.IsActive, &k.CreatedAt, + ); err != nil { + return nil, err + } + + k.Description = desc.String + k.Role = Role(role.String) + k.Permissions = ParsePermissionsJSON(perms.String) + k.Deployments = ParseDeploymentsJSON(deps.String) + k.LastUsedIP = lastIP.String + if expiresAt.Valid { + k.ExpiresAt = expiresAt.Time + } + if lastUsed.Valid { + k.LastUsedAt = lastUsed.Time + } + keys = append(keys, k) + } + return keys, nil +} + +func (db *DB) GetAllAPIKeys() ([]APIKey, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + rows, err := db.conn.Query(` + SELECT id, key_id, user_id, name, description, key_hash, key_prefix, role, permissions, deployments, + expires_at, last_used_at, last_used_ip, is_active, created_at + FROM api_keys ORDER BY created_at DESC`) + if err != nil { + return nil, err + } + defer rows.Close() + + var keys []APIKey + for rows.Next() { + var k APIKey + var desc, role, perms, deps, lastIP sql.NullString + var expiresAt, lastUsed sql.NullTime + + if err := rows.Scan( + &k.ID, &k.KeyID, &k.UserID, &k.Name, &desc, &k.KeyHash, &k.KeyPrefix, &role, &perms, &deps, + &expiresAt, &lastUsed, &lastIP, &k.IsActive, &k.CreatedAt, + ); err != nil { + return nil, err + } + + k.Description = desc.String + k.Role = Role(role.String) + k.Permissions = ParsePermissionsJSON(perms.String) + k.Deployments = ParseDeploymentsJSON(deps.String) + k.LastUsedIP = lastIP.String + if expiresAt.Valid { + k.ExpiresAt = expiresAt.Time + } + if lastUsed.Valid { + k.LastUsedAt = lastUsed.Time + } + keys = append(keys, k) + } + return keys, nil +} + +func (db *DB) UpdateAPIKeyLastUsed(id int64, ip string) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`UPDATE api_keys SET last_used_at = ?, last_used_ip = ? WHERE id = ?`, + time.Now(), ip, id) + return err +} + +func (db *DB) DeleteAPIKey(id int64) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`DELETE FROM api_keys WHERE id = ?`, id) + return err +} + +func (db *DB) DeactivateAPIKey(id int64) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`UPDATE api_keys SET is_active = FALSE WHERE id = ?`, id) + return err +} + +func (db *DB) CreateSession(session *Session) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + var apiKeyID sql.NullInt64 + if session.APIKeyID > 0 { + apiKeyID = sql.NullInt64{Int64: session.APIKeyID, Valid: true} + } + + result, err := db.conn.Exec(` + INSERT INTO sessions (session_id, user_id, api_key_id, token_hash, expires_at, created_at, client_ip) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + session.SessionID, session.UserID, apiKeyID, session.TokenHash, + session.ExpiresAt, session.CreatedAt, session.ClientIP, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +func (db *DB) GetSessionByID(sessionID string) (*Session, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var s Session + var apiKeyID sql.NullInt64 + var revokedAt sql.NullTime + var clientIP sql.NullString + + err := db.conn.QueryRow(` + SELECT id, session_id, user_id, api_key_id, token_hash, expires_at, revoked_at, created_at, client_ip + FROM sessions WHERE session_id = ?`, sessionID).Scan( + &s.ID, &s.SessionID, &s.UserID, &apiKeyID, &s.TokenHash, + &s.ExpiresAt, &revokedAt, &s.CreatedAt, &clientIP, + ) + if err != nil { + return nil, err + } + + if apiKeyID.Valid { + s.APIKeyID = apiKeyID.Int64 + } + if revokedAt.Valid { + s.RevokedAt = revokedAt.Time + } + s.ClientIP = clientIP.String + return &s, nil +} + +func (db *DB) GetSessionByTokenHash(hash string) (*Session, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + var s Session + var apiKeyID sql.NullInt64 + var revokedAt sql.NullTime + var clientIP sql.NullString + + err := db.conn.QueryRow(` + SELECT id, session_id, user_id, api_key_id, token_hash, expires_at, revoked_at, created_at, client_ip + FROM sessions WHERE token_hash = ? AND revoked_at IS NULL AND expires_at > ?`, + hash, time.Now()).Scan( + &s.ID, &s.SessionID, &s.UserID, &apiKeyID, &s.TokenHash, + &s.ExpiresAt, &revokedAt, &s.CreatedAt, &clientIP, + ) + if err != nil { + return nil, err + } + + if apiKeyID.Valid { + s.APIKeyID = apiKeyID.Int64 + } + s.ClientIP = clientIP.String + return &s, nil +} + +func (db *DB) RevokeSession(sessionID string) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`UPDATE sessions SET revoked_at = ? WHERE session_id = ?`, + time.Now(), sessionID) + return err +} + +func (db *DB) RevokeUserSessions(userID int64) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`UPDATE sessions SET revoked_at = ? WHERE user_id = ? AND revoked_at IS NULL`, + time.Now(), userID) + return err +} + +func (db *DB) CleanupExpiredSessions() (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + result, err := db.conn.Exec(`DELETE FROM sessions WHERE expires_at < ?`, time.Now()) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +func (db *DB) CreateUserDeployment(ud *UserDeployment) (int64, error) { + db.mu.Lock() + defer db.mu.Unlock() + + var grantedBy sql.NullInt64 + if ud.GrantedBy > 0 { + grantedBy = sql.NullInt64{Int64: ud.GrantedBy, Valid: true} + } + + result, err := db.conn.Exec(` + INSERT INTO user_deployments (user_id, deployment_name, access_level, granted_by, created_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(user_id, deployment_name) DO UPDATE SET access_level = ?, granted_by = ?`, + ud.UserID, ud.DeploymentName, ud.AccessLevel, grantedBy, ud.CreatedAt, + ud.AccessLevel, grantedBy, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +func (db *DB) GetUserDeployments(userID int64) ([]UserDeployment, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + rows, err := db.conn.Query(` + SELECT id, user_id, deployment_name, access_level, granted_by, created_at + FROM user_deployments WHERE user_id = ? ORDER BY deployment_name`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var deployments []UserDeployment + for rows.Next() { + var ud UserDeployment + var grantedBy sql.NullInt64 + + if err := rows.Scan(&ud.ID, &ud.UserID, &ud.DeploymentName, &ud.AccessLevel, &grantedBy, &ud.CreatedAt); err != nil { + return nil, err + } + + if grantedBy.Valid { + ud.GrantedBy = grantedBy.Int64 + } + deployments = append(deployments, ud) + } + return deployments, nil +} + +func (db *DB) GetUserDeploymentsMap(userID int64) (map[string]string, error) { + deployments, err := db.GetUserDeployments(userID) + if err != nil { + return nil, err + } + + m := make(map[string]string) + for _, d := range deployments { + m[d.DeploymentName] = d.AccessLevel + } + return m, nil +} + +func (db *DB) GetDeploymentUsers(deploymentName string) ([]UserDeployment, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + rows, err := db.conn.Query(` + SELECT id, user_id, deployment_name, access_level, granted_by, created_at + FROM user_deployments WHERE deployment_name = ?`, deploymentName) + if err != nil { + return nil, err + } + defer rows.Close() + + var deployments []UserDeployment + for rows.Next() { + var ud UserDeployment + var grantedBy sql.NullInt64 + + if err := rows.Scan(&ud.ID, &ud.UserID, &ud.DeploymentName, &ud.AccessLevel, &grantedBy, &ud.CreatedAt); err != nil { + return nil, err + } + + if grantedBy.Valid { + ud.GrantedBy = grantedBy.Int64 + } + deployments = append(deployments, ud) + } + return deployments, nil +} + +func (db *DB) UpdateUserDeployment(userID int64, deploymentName, accessLevel string) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(` + UPDATE user_deployments SET access_level = ? WHERE user_id = ? AND deployment_name = ?`, + accessLevel, userID, deploymentName) + return err +} + +func (db *DB) DeleteUserDeployment(userID int64, deploymentName string) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`DELETE FROM user_deployments WHERE user_id = ? AND deployment_name = ?`, + userID, deploymentName) + return err +} + +func (db *DB) DeleteAllUserDeployments(userID int64) error { + db.mu.Lock() + defer db.mu.Unlock() + + _, err := db.conn.Exec(`DELETE FROM user_deployments WHERE user_id = ?`, userID) + return err +} diff --git a/internal/auth/db_test.go b/internal/auth/db_test.go new file mode 100644 index 0000000..8a2526b --- /dev/null +++ b/internal/auth/db_test.go @@ -0,0 +1,499 @@ +package auth + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func setupTestDB(t *testing.T) (*DB, func()) { + tmpDir, err := os.MkdirTemp("", "auth_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + db, err := NewAuthDB(tmpDir) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create DB: %v", err) + } + + cleanup := func() { + db.Close() + os.RemoveAll(tmpDir) + } + + return db, cleanup +} + +func TestNewAuthDB(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + if db == nil { + t.Fatal("NewAuthDB returned nil") + } + + if db.conn == nil { + t.Fatal("DB connection is nil") + } +} + +func TestAuthDBPath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "auth_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + db, err := NewAuthDB(tmpDir) + if err != nil { + t.Fatalf("Failed to create DB: %v", err) + } + defer db.Close() + + expectedPath := filepath.Join(tmpDir, ".flatrun", "auth.db") + if db.path != expectedPath { + t.Errorf("DB path = %s, want %s", db.path, expectedPath) + } + + if _, err := os.Stat(expectedPath); os.IsNotExist(err) { + t.Error("Database file was not created") + } +} + +func TestCreateAndGetUser(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "test-uid-123", + Username: "testuser", + Email: "test@example.com", + PasswordHash: "hashedpassword", + Role: RoleOperator, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + id, err := db.CreateUser(user) + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + + if id <= 0 { + t.Error("CreateUser should return positive ID") + } + + retrieved, err := db.GetUserByID(id) + if err != nil { + t.Fatalf("GetUserByID failed: %v", err) + } + + if retrieved.Username != user.Username { + t.Errorf("Username = %s, want %s", retrieved.Username, user.Username) + } + + if retrieved.Role != user.Role { + t.Errorf("Role = %s, want %s", retrieved.Role, user.Role) + } +} + +func TestGetUserByUsername(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "test-uid-456", + Username: "findme", + PasswordHash: "hash", + Role: RoleViewer, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + _, _ = db.CreateUser(user) + + found, err := db.GetUserByUsername("findme") + if err != nil { + t.Fatalf("GetUserByUsername failed: %v", err) + } + + if found.Username != "findme" { + t.Errorf("Username = %s, want findme", found.Username) + } +} + +func TestGetUserByUID(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "unique-uid-789", + Username: "uiduser", + PasswordHash: "hash", + Role: RoleAdmin, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + _, _ = db.CreateUser(user) + + found, err := db.GetUserByUID("unique-uid-789") + if err != nil { + t.Fatalf("GetUserByUID failed: %v", err) + } + + if found.UID != "unique-uid-789" { + t.Errorf("UID = %s, want unique-uid-789", found.UID) + } +} + +func TestUpdateUser(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "update-uid", + Username: "updateme", + PasswordHash: "hash", + Role: RoleViewer, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + id, _ := db.CreateUser(user) + user.ID = id + user.Role = RoleOperator + user.Email = "updated@example.com" + + err := db.UpdateUser(user) + if err != nil { + t.Fatalf("UpdateUser failed: %v", err) + } + + updated, _ := db.GetUserByID(id) + if updated.Role != RoleOperator { + t.Errorf("Role = %s, want operator", updated.Role) + } + if updated.Email != "updated@example.com" { + t.Errorf("Email = %s, want updated@example.com", updated.Email) + } +} + +func TestDeleteUser(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "delete-uid", + Username: "deleteme", + PasswordHash: "hash", + Role: RoleViewer, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + id, _ := db.CreateUser(user) + + err := db.DeleteUser(id) + if err != nil { + t.Fatalf("DeleteUser failed: %v", err) + } + + _, err = db.GetUserByID(id) + if err == nil { + t.Error("GetUserByID should fail after deletion") + } +} + +func TestCountUsers(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + count, _ := db.CountUsers() + if count != 0 { + t.Errorf("Initial count = %d, want 0", count) + } + + _, _ = db.CreateUser(&User{ + UID: "count-uid-1", + Username: "user1", + PasswordHash: "hash", + Role: RoleViewer, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + + _, _ = db.CreateUser(&User{ + UID: "count-uid-2", + Username: "user2", + PasswordHash: "hash", + Role: RoleViewer, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + + count, _ = db.CountUsers() + if count != 2 { + t.Errorf("Count = %d, want 2", count) + } +} + +func TestCreateAndGetAPIKey(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "apikey-user-uid", + Username: "apikeyuser", + PasswordHash: "hash", + Role: RoleOperator, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + userID, _ := db.CreateUser(user) + + apiKey := &APIKey{ + KeyID: "key-id-123", + UserID: userID, + Name: "Test Key", + KeyHash: "hashed-key-value", + KeyPrefix: "fr_test...", + IsActive: true, + CreatedAt: time.Now(), + } + + id, err := db.CreateAPIKey(apiKey) + if err != nil { + t.Fatalf("CreateAPIKey failed: %v", err) + } + + retrieved, err := db.GetAPIKeyByID(id) + if err != nil { + t.Fatalf("GetAPIKeyByID failed: %v", err) + } + + if retrieved.Name != "Test Key" { + t.Errorf("Name = %s, want Test Key", retrieved.Name) + } + + if retrieved.UserID != userID { + t.Errorf("UserID = %d, want %d", retrieved.UserID, userID) + } +} + +func TestGetAPIKeyByHash(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "hash-user-uid", + Username: "hashuser", + PasswordHash: "hash", + Role: RoleOperator, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + userID, _ := db.CreateUser(user) + + apiKey := &APIKey{ + KeyID: "hash-key-id", + UserID: userID, + Name: "Hash Key", + KeyHash: "unique-hash-value", + KeyPrefix: "fr_hash...", + IsActive: true, + CreatedAt: time.Now(), + } + _, _ = db.CreateAPIKey(apiKey) + + found, err := db.GetAPIKeyByHash("unique-hash-value") + if err != nil { + t.Fatalf("GetAPIKeyByHash failed: %v", err) + } + + if found.KeyHash != "unique-hash-value" { + t.Errorf("KeyHash = %s, want unique-hash-value", found.KeyHash) + } +} + +func TestUserDeployments(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "deploy-user-uid", + Username: "deployuser", + PasswordHash: "hash", + Role: RoleOperator, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + userID, _ := db.CreateUser(user) + + ud := &UserDeployment{ + UserID: userID, + DeploymentName: "my-app", + AccessLevel: "write", + CreatedAt: time.Now(), + } + + _, err := db.CreateUserDeployment(ud) + if err != nil { + t.Fatalf("CreateUserDeployment failed: %v", err) + } + + deployments, err := db.GetUserDeployments(userID) + if err != nil { + t.Fatalf("GetUserDeployments failed: %v", err) + } + + if len(deployments) != 1 { + t.Fatalf("Expected 1 deployment, got %d", len(deployments)) + } + + if deployments[0].DeploymentName != "my-app" { + t.Errorf("DeploymentName = %s, want my-app", deployments[0].DeploymentName) + } + + if deployments[0].AccessLevel != "write" { + t.Errorf("AccessLevel = %s, want write", deployments[0].AccessLevel) + } +} + +func TestUserDeploymentsMap(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "map-user-uid", + Username: "mapuser", + PasswordHash: "hash", + Role: RoleOperator, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + userID, _ := db.CreateUser(user) + + _, _ = db.CreateUserDeployment(&UserDeployment{ + UserID: userID, + DeploymentName: "app-a", + AccessLevel: "read", + CreatedAt: time.Now(), + }) + + _, _ = db.CreateUserDeployment(&UserDeployment{ + UserID: userID, + DeploymentName: "app-b", + AccessLevel: "admin", + CreatedAt: time.Now(), + }) + + depMap, err := db.GetUserDeploymentsMap(userID) + if err != nil { + t.Fatalf("GetUserDeploymentsMap failed: %v", err) + } + + if depMap["app-a"] != "read" { + t.Errorf("app-a access = %s, want read", depMap["app-a"]) + } + + if depMap["app-b"] != "admin" { + t.Errorf("app-b access = %s, want admin", depMap["app-b"]) + } +} + +func TestDeleteUserDeployment(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "del-deploy-uid", + Username: "deldeployuser", + PasswordHash: "hash", + Role: RoleOperator, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + userID, _ := db.CreateUser(user) + + _, _ = db.CreateUserDeployment(&UserDeployment{ + UserID: userID, + DeploymentName: "to-remove", + AccessLevel: "read", + CreatedAt: time.Now(), + }) + + err := db.DeleteUserDeployment(userID, "to-remove") + if err != nil { + t.Fatalf("DeleteUserDeployment failed: %v", err) + } + + deployments, _ := db.GetUserDeployments(userID) + if len(deployments) != 0 { + t.Errorf("Expected 0 deployments after deletion, got %d", len(deployments)) + } +} + +func TestSessionOperations(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + user := &User{ + UID: "session-user-uid", + Username: "sessionuser", + PasswordHash: "hash", + Role: RoleOperator, + IsActive: true, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + userID, _ := db.CreateUser(user) + + session := &Session{ + SessionID: "session-123", + UserID: userID, + TokenHash: "token-hash-value", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + ClientIP: "127.0.0.1", + } + + _, err := db.CreateSession(session) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + + found, err := db.GetSessionByTokenHash("token-hash-value") + if err != nil { + t.Fatalf("GetSessionByTokenHash failed: %v", err) + } + + if found.SessionID != "session-123" { + t.Errorf("SessionID = %s, want session-123", found.SessionID) + } + + err = db.RevokeSession("session-123") + if err != nil { + t.Fatalf("RevokeSession failed: %v", err) + } + + _, err = db.GetSessionByTokenHash("token-hash-value") + if err == nil { + t.Error("Revoked session should not be returned") + } +} diff --git a/internal/auth/manager.go b/internal/auth/manager.go new file mode 100644 index 0000000..9f4bed6 --- /dev/null +++ b/internal/auth/manager.go @@ -0,0 +1,419 @@ +package auth + +import ( + "database/sql" + "errors" + "fmt" + "log" + "os" + "time" + + "github.com/flatrun/agent/pkg/config" +) + +var ( + ErrUserNotFound = errors.New("user not found") + ErrUserExists = errors.New("user already exists") + ErrInvalidPassword = errors.New("invalid password") + ErrAPIKeyNotFound = errors.New("api key not found") + ErrAPIKeyExpired = errors.New("api key has expired") + ErrAPIKeyInactive = errors.New("api key is inactive") + ErrSessionNotFound = errors.New("session not found") + ErrSessionExpired = errors.New("session has expired") + ErrSessionRevoked = errors.New("session has been revoked") + ErrInvalidRole = errors.New("invalid role") + ErrUserInactive = errors.New("user account is inactive") + ErrCannotDeleteSelf = errors.New("cannot delete your own account") +) + +type Manager struct { + db *DB + config *config.AuthConfig +} + +func NewManager(deploymentsPath string, cfg *config.AuthConfig) (*Manager, error) { + db, err := NewAuthDB(deploymentsPath) + if err != nil { + return nil, fmt.Errorf("failed to initialize auth database: %w", err) + } + + m := &Manager{ + db: db, + config: cfg, + } + + if err := m.ensureAdminUser(); err != nil { + log.Printf("Warning: failed to ensure admin user: %v", err) + } + + return m, nil +} + +func (m *Manager) Close() error { + return m.db.Close() +} + +func (m *Manager) ensureAdminUser() error { + count, err := m.db.CountUsers() + if err != nil { + return err + } + + if count > 0 { + return nil + } + + adminPassword := os.Getenv("FLATRUN_ADMIN_PASSWORD") + if adminPassword == "" { + adminPassword = "admin" + log.Println("WARNING: No users exist and FLATRUN_ADMIN_PASSWORD not set. Creating admin user with default password 'admin'. Please change this immediately!") + } + + _, err = m.CreateUser("admin", "", adminPassword, RoleAdmin, nil) + if err != nil { + return fmt.Errorf("failed to create admin user: %w", err) + } + + log.Println("Created initial admin user") + return nil +} + +func (m *Manager) CreateUser(username, email, password string, role Role, permissions []string) (*User, error) { + if !role.IsValid() { + return nil, ErrInvalidRole + } + + passwordHash, err := HashPassword(password) + if err != nil { + return nil, fmt.Errorf("failed to hash password: %w", err) + } + + uid, err := GenerateUID() + if err != nil { + return nil, fmt.Errorf("failed to generate UID: %w", err) + } + + now := time.Now() + user := &User{ + UID: uid, + Username: username, + Email: email, + PasswordHash: passwordHash, + Role: role, + Permissions: permissions, + IsActive: true, + CreatedAt: now, + UpdatedAt: now, + } + + id, err := m.db.CreateUser(user) + if err != nil { + return nil, err + } + + user.ID = id + return user, nil +} + +func (m *Manager) GetUser(id int64) (*User, error) { + user, err := m.db.GetUserByID(id) + if err == sql.ErrNoRows { + return nil, ErrUserNotFound + } + return user, err +} + +func (m *Manager) GetUserByUID(uid string) (*User, error) { + user, err := m.db.GetUserByUID(uid) + if err == sql.ErrNoRows { + return nil, ErrUserNotFound + } + return user, err +} + +func (m *Manager) GetUserByUsername(username string) (*User, error) { + user, err := m.db.GetUserByUsername(username) + if err == sql.ErrNoRows { + return nil, ErrUserNotFound + } + return user, err +} + +func (m *Manager) GetUsers() ([]User, error) { + return m.db.GetUsers() +} + +func (m *Manager) UpdateUser(user *User) error { + if !user.Role.IsValid() { + return ErrInvalidRole + } + return m.db.UpdateUser(user) +} + +func (m *Manager) UpdatePassword(userID int64, newPassword string) error { + hash, err := HashPassword(newPassword) + if err != nil { + return err + } + return m.db.UpdateUserPassword(userID, hash) +} + +func (m *Manager) DeleteUser(id int64, actorID int64) error { + if id == actorID { + return ErrCannotDeleteSelf + } + return m.db.DeleteUser(id) +} + +func (m *Manager) ValidateCredentials(username, password string) (*User, error) { + user, err := m.db.GetUserByUsername(username) + if err == sql.ErrNoRows { + return nil, ErrUserNotFound + } + if err != nil { + return nil, err + } + + if !user.IsActive { + return nil, ErrUserInactive + } + + if !VerifyPassword(password, user.PasswordHash) { + return nil, ErrInvalidPassword + } + + _ = m.db.UpdateUserLastLogin(user.ID) + return user, nil +} + +func (m *Manager) CreateAPIKey(userID int64, name, description string, role Role, permissions, deployments []string, expiresAt time.Time) (*APIKey, string, error) { + plainKey, keyHash, keyID, prefix, err := GenerateAPIKey() + if err != nil { + return nil, "", err + } + + key := &APIKey{ + KeyID: keyID, + UserID: userID, + Name: name, + Description: description, + KeyHash: keyHash, + KeyPrefix: prefix, + Role: role, + Permissions: permissions, + Deployments: deployments, + ExpiresAt: expiresAt, + IsActive: true, + CreatedAt: time.Now(), + } + + id, err := m.db.CreateAPIKey(key) + if err != nil { + return nil, "", err + } + + key.ID = id + return key, plainKey, nil +} + +func (m *Manager) GetAPIKey(id int64) (*APIKey, error) { + key, err := m.db.GetAPIKeyByID(id) + if err == sql.ErrNoRows { + return nil, ErrAPIKeyNotFound + } + return key, err +} + +func (m *Manager) GetAPIKeysByUser(userID int64) ([]APIKey, error) { + return m.db.GetAPIKeysByUserID(userID) +} + +func (m *Manager) GetAllAPIKeys() ([]APIKey, error) { + return m.db.GetAllAPIKeys() +} + +func (m *Manager) ValidateAPIKey(plainKey string) (*APIKey, *User, error) { + hash := HashAPIKey(plainKey) + key, err := m.db.GetAPIKeyByHash(hash) + if err == sql.ErrNoRows { + return nil, nil, ErrAPIKeyNotFound + } + if err != nil { + return nil, nil, err + } + + if !key.IsActive { + return nil, nil, ErrAPIKeyInactive + } + + if !key.ExpiresAt.IsZero() && key.ExpiresAt.Before(time.Now()) { + return nil, nil, ErrAPIKeyExpired + } + + user, err := m.db.GetUserByID(key.UserID) + if err != nil { + return nil, nil, err + } + + if !user.IsActive { + return nil, nil, ErrUserInactive + } + + return key, user, nil +} + +func (m *Manager) UpdateAPIKeyLastUsed(keyID int64, ip string) error { + return m.db.UpdateAPIKeyLastUsed(keyID, ip) +} + +func (m *Manager) DeleteAPIKey(id int64) error { + return m.db.DeleteAPIKey(id) +} + +func (m *Manager) DeactivateAPIKey(id int64) error { + return m.db.DeactivateAPIKey(id) +} + +func (m *Manager) CreateSession(userID int64, apiKeyID int64, sessionID, tokenHash, clientIP string, expiresAt time.Time) (*Session, error) { + if sessionID == "" { + var err error + sessionID, err = GenerateSessionID() + if err != nil { + return nil, err + } + } + + session := &Session{ + SessionID: sessionID, + UserID: userID, + APIKeyID: apiKeyID, + TokenHash: tokenHash, + ExpiresAt: expiresAt, + CreatedAt: time.Now(), + ClientIP: clientIP, + } + + id, err := m.db.CreateSession(session) + if err != nil { + return nil, err + } + + session.ID = id + return session, nil +} + +func (m *Manager) GetSessionByToken(tokenHash string) (*Session, error) { + session, err := m.db.GetSessionByTokenHash(tokenHash) + if err == sql.ErrNoRows { + return nil, ErrSessionNotFound + } + return session, err +} + +func (m *Manager) GetSessionByID(sessionID string) (*Session, error) { + session, err := m.db.GetSessionByID(sessionID) + if err == sql.ErrNoRows { + return nil, ErrSessionNotFound + } + return session, err +} + +func (m *Manager) RevokeSession(sessionID string) error { + return m.db.RevokeSession(sessionID) +} + +func (m *Manager) RevokeUserSessions(userID int64) error { + return m.db.RevokeUserSessions(userID) +} + +func (m *Manager) CleanupExpiredSessions() (int64, error) { + return m.db.CleanupExpiredSessions() +} + +func (m *Manager) AssignDeployment(userID int64, deploymentName, accessLevel string, grantedBy int64) error { + ud := &UserDeployment{ + UserID: userID, + DeploymentName: deploymentName, + AccessLevel: accessLevel, + GrantedBy: grantedBy, + CreatedAt: time.Now(), + } + _, err := m.db.CreateUserDeployment(ud) + return err +} + +func (m *Manager) GetUserDeployments(userID int64) ([]UserDeployment, error) { + return m.db.GetUserDeployments(userID) +} + +func (m *Manager) GetUserDeploymentsMap(userID int64) (map[string]string, error) { + return m.db.GetUserDeploymentsMap(userID) +} + +func (m *Manager) GetDeploymentUsers(deploymentName string) ([]UserDeployment, error) { + return m.db.GetDeploymentUsers(deploymentName) +} + +func (m *Manager) UpdateUserDeployment(userID int64, deploymentName, accessLevel string) error { + return m.db.UpdateUserDeployment(userID, deploymentName, accessLevel) +} + +func (m *Manager) RemoveDeploymentAccess(userID int64, deploymentName string) error { + return m.db.DeleteUserDeployment(userID, deploymentName) +} + +func (m *Manager) BuildActorContext(user *User, apiKey *APIKey) (*ActorContext, error) { + actor := &ActorContext{ + User: user, + APIKey: apiKey, + } + + if user != nil { + actor.UserID = user.ID + actor.Role = user.Role + + deployments, err := m.GetUserDeploymentsMap(user.ID) + if err != nil { + return nil, err + } + actor.Deployments = deployments + } + + if user != nil && len(user.Permissions) > 0 { + actor.Permissions = user.Permissions + } + + if apiKey != nil { + actor.Type = "api_key" + + if apiKey.Role != "" { + actor.Role = apiKey.Role + } + + if len(apiKey.Permissions) > 0 { + actor.Permissions = apiKey.Permissions + } + } else if user != nil { + actor.Type = "user" + } + + return actor, nil +} + +func (m *Manager) ValidateLegacyAPIKey(key string) bool { + for _, validKey := range m.config.APIKeys { + if key == validKey { + return true + } + } + return false +} + +func (m *Manager) GetLegacyKeyIndex(key string) int { + for i, validKey := range m.config.APIKeys { + if key == validKey { + return i + } + } + return -1 +} diff --git a/internal/auth/manager_test.go b/internal/auth/manager_test.go new file mode 100644 index 0000000..3fe135e --- /dev/null +++ b/internal/auth/manager_test.go @@ -0,0 +1,555 @@ +package auth + +import ( + "os" + "testing" + "time" + + "github.com/flatrun/agent/pkg/config" +) + +func setupTestManager(t *testing.T) (*Manager, func()) { + tmpDir, err := os.MkdirTemp("", "manager_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + cfg := &config.AuthConfig{ + JWTSecret: "test-secret", + APIKeys: []string{"legacy-key-1", "legacy-key-2"}, + } + + os.Setenv("FLATRUN_ADMIN_PASSWORD", "testadminpass") + defer os.Unsetenv("FLATRUN_ADMIN_PASSWORD") + + manager, err := NewManager(tmpDir, cfg) + if err != nil { + os.RemoveAll(tmpDir) + t.Fatalf("Failed to create manager: %v", err) + } + + cleanup := func() { + manager.Close() + os.RemoveAll(tmpDir) + } + + return manager, cleanup +} + +func TestNewManager(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + if manager == nil { + t.Fatal("NewManager returned nil") + } +} + +func TestManagerCreatesAdminUser(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "manager_admin_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + os.Setenv("FLATRUN_ADMIN_PASSWORD", "secureadminpass") + defer os.Unsetenv("FLATRUN_ADMIN_PASSWORD") + + cfg := &config.AuthConfig{JWTSecret: "test"} + manager, err := NewManager(tmpDir, cfg) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + admin, err := manager.GetUserByUsername("admin") + if err != nil { + t.Fatalf("Admin user should exist: %v", err) + } + + if admin.Role != RoleAdmin { + t.Errorf("Admin user role = %s, want admin", admin.Role) + } + + if !admin.IsActive { + t.Error("Admin user should be active") + } +} + +func TestManagerCreateUser(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, err := manager.CreateUser("testuser", "test@example.com", "password123", RoleOperator, nil) + if err != nil { + t.Fatalf("CreateUser failed: %v", err) + } + + if user.Username != "testuser" { + t.Errorf("Username = %s, want testuser", user.Username) + } + + if user.Role != RoleOperator { + t.Errorf("Role = %s, want operator", user.Role) + } + + if user.UID == "" { + t.Error("UID should be generated") + } +} + +func TestManagerCreateUserInvalidRole(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + _, err := manager.CreateUser("baduser", "", "pass", Role("invalid"), nil) + if err != ErrInvalidRole { + t.Errorf("Expected ErrInvalidRole, got %v", err) + } +} + +func TestManagerGetUser(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + created, _ := manager.CreateUser("findme", "", "pass", RoleViewer, nil) + + found, err := manager.GetUser(created.ID) + if err != nil { + t.Fatalf("GetUser failed: %v", err) + } + + if found.Username != "findme" { + t.Errorf("Username = %s, want findme", found.Username) + } +} + +func TestManagerGetUserNotFound(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + _, err := manager.GetUser(99999) + if err != ErrUserNotFound { + t.Errorf("Expected ErrUserNotFound, got %v", err) + } +} + +func TestManagerValidateCredentials(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + _, _ = manager.CreateUser("authuser", "", "correctpassword", RoleOperator, nil) + + user, err := manager.ValidateCredentials("authuser", "correctpassword") + if err != nil { + t.Fatalf("ValidateCredentials failed: %v", err) + } + + if user.Username != "authuser" { + t.Errorf("Username = %s, want authuser", user.Username) + } +} + +func TestManagerValidateCredentialsWrongPassword(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + _, _ = manager.CreateUser("authuser", "", "correctpassword", RoleOperator, nil) + + _, err := manager.ValidateCredentials("authuser", "wrongpassword") + if err != ErrInvalidPassword { + t.Errorf("Expected ErrInvalidPassword, got %v", err) + } +} + +func TestManagerValidateCredentialsUserNotFound(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + _, err := manager.ValidateCredentials("nonexistent", "password") + if err != ErrUserNotFound { + t.Errorf("Expected ErrUserNotFound, got %v", err) + } +} + +func TestManagerValidateCredentialsInactiveUser(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("inactive", "", "pass", RoleViewer, nil) + user.IsActive = false + _ = manager.UpdateUser(user) + + _, err := manager.ValidateCredentials("inactive", "pass") + if err != ErrUserInactive { + t.Errorf("Expected ErrUserInactive, got %v", err) + } +} + +func TestManagerDeleteUser(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("todelete", "", "pass", RoleViewer, nil) + actorID := int64(99999) + + err := manager.DeleteUser(user.ID, actorID) + if err != nil { + t.Fatalf("DeleteUser failed: %v", err) + } + + _, err = manager.GetUser(user.ID) + if err != ErrUserNotFound { + t.Error("User should be deleted") + } +} + +func TestManagerDeleteUserCannotDeleteSelf(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("selfdelete", "", "pass", RoleAdmin, nil) + + err := manager.DeleteUser(user.ID, user.ID) + if err != ErrCannotDeleteSelf { + t.Errorf("Expected ErrCannotDeleteSelf, got %v", err) + } +} + +func TestManagerUpdatePassword(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("passchange", "", "oldpass", RoleViewer, nil) + + err := manager.UpdatePassword(user.ID, "newpassword") + if err != nil { + t.Fatalf("UpdatePassword failed: %v", err) + } + + _, err = manager.ValidateCredentials("passchange", "newpassword") + if err != nil { + t.Error("Should be able to login with new password") + } + + _, err = manager.ValidateCredentials("passchange", "oldpass") + if err != ErrInvalidPassword { + t.Error("Old password should not work") + } +} + +func TestManagerCreateAPIKey(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("keyowner", "", "pass", RoleOperator, nil) + + key, plainKey, err := manager.CreateAPIKey(user.ID, "Test Key", "Testing", "", nil, nil, time.Time{}) + if err != nil { + t.Fatalf("CreateAPIKey failed: %v", err) + } + + if plainKey == "" { + t.Error("Plain key should be returned") + } + + if key.Name != "Test Key" { + t.Errorf("Name = %s, want Test Key", key.Name) + } + + if key.UserID != user.ID { + t.Errorf("UserID = %d, want %d", key.UserID, user.ID) + } +} + +func TestManagerValidateAPIKey(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("keyuser", "", "pass", RoleOperator, nil) + _, plainKey, _ := manager.CreateAPIKey(user.ID, "Valid Key", "", "", nil, nil, time.Time{}) + + key, foundUser, err := manager.ValidateAPIKey(plainKey) + if err != nil { + t.Fatalf("ValidateAPIKey failed: %v", err) + } + + if key.Name != "Valid Key" { + t.Errorf("Key name = %s, want Valid Key", key.Name) + } + + if foundUser.ID != user.ID { + t.Errorf("User ID = %d, want %d", foundUser.ID, user.ID) + } +} + +func TestManagerValidateAPIKeyExpired(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("expiredkey", "", "pass", RoleOperator, nil) + _, plainKey, _ := manager.CreateAPIKey(user.ID, "Expired Key", "", "", nil, nil, time.Now().Add(-1*time.Hour)) + + _, _, err := manager.ValidateAPIKey(plainKey) + if err != ErrAPIKeyExpired { + t.Errorf("Expected ErrAPIKeyExpired, got %v", err) + } +} + +func TestManagerValidateAPIKeyInactive(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("inactivekey", "", "pass", RoleOperator, nil) + key, plainKey, _ := manager.CreateAPIKey(user.ID, "Inactive Key", "", "", nil, nil, time.Time{}) + + _ = manager.DeactivateAPIKey(key.ID) + + _, _, err := manager.ValidateAPIKey(plainKey) + if err != ErrAPIKeyInactive { + t.Errorf("Expected ErrAPIKeyInactive, got %v", err) + } +} + +func TestManagerValidateAPIKeyNotFound(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + _, _, err := manager.ValidateAPIKey("fr_nonexistent_key") + if err != ErrAPIKeyNotFound { + t.Errorf("Expected ErrAPIKeyNotFound, got %v", err) + } +} + +func TestManagerSessionLifecycle(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("sessionuser", "", "pass", RoleViewer, nil) + + tokenHash := HashAPIKey("test-token") + session, err := manager.CreateSession(user.ID, 0, "", tokenHash, "127.0.0.1", time.Now().Add(24*time.Hour)) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + + found, err := manager.GetSessionByToken(tokenHash) + if err != nil { + t.Fatalf("GetSessionByToken failed: %v", err) + } + + if found.SessionID != session.SessionID { + t.Error("Session ID mismatch") + } + + err = manager.RevokeSession(session.SessionID) + if err != nil { + t.Fatalf("RevokeSession failed: %v", err) + } + + _, err = manager.GetSessionByToken(tokenHash) + if err == nil { + t.Error("Revoked session should not be found") + } +} + +func TestManagerDeploymentAccess(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("deployuser", "", "pass", RoleOperator, nil) + admin, _ := manager.GetUserByUsername("admin") + + err := manager.AssignDeployment(user.ID, "my-app", "write", admin.ID) + if err != nil { + t.Fatalf("AssignDeployment failed: %v", err) + } + + deployments, err := manager.GetUserDeployments(user.ID) + if err != nil { + t.Fatalf("GetUserDeployments failed: %v", err) + } + + if len(deployments) != 1 { + t.Fatalf("Expected 1 deployment, got %d", len(deployments)) + } + + if deployments[0].DeploymentName != "my-app" { + t.Errorf("DeploymentName = %s, want my-app", deployments[0].DeploymentName) + } + + if deployments[0].AccessLevel != "write" { + t.Errorf("AccessLevel = %s, want write", deployments[0].AccessLevel) + } +} + +func TestManagerDeploymentAccessMap(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("mapuser", "", "pass", RoleOperator, nil) + admin, _ := manager.GetUserByUsername("admin") + + _ = manager.AssignDeployment(user.ID, "app-a", "read", admin.ID) + _ = manager.AssignDeployment(user.ID, "app-b", "write", admin.ID) + + depMap, err := manager.GetUserDeploymentsMap(user.ID) + if err != nil { + t.Fatalf("GetUserDeploymentsMap failed: %v", err) + } + + if depMap["app-a"] != "read" { + t.Errorf("app-a access = %s, want read", depMap["app-a"]) + } + + if depMap["app-b"] != "write" { + t.Errorf("app-b access = %s, want write", depMap["app-b"]) + } +} + +func TestManagerUpdateDeploymentAccess(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("updateaccess", "", "pass", RoleOperator, nil) + admin, _ := manager.GetUserByUsername("admin") + + _ = manager.AssignDeployment(user.ID, "my-app", "read", admin.ID) + + err := manager.UpdateUserDeployment(user.ID, "my-app", "admin") + if err != nil { + t.Fatalf("UpdateUserDeployment failed: %v", err) + } + + depMap, _ := manager.GetUserDeploymentsMap(user.ID) + if depMap["my-app"] != "admin" { + t.Errorf("Access level = %s, want admin", depMap["my-app"]) + } +} + +func TestManagerRemoveDeploymentAccess(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("removeaccess", "", "pass", RoleOperator, nil) + admin, _ := manager.GetUserByUsername("admin") + + _ = manager.AssignDeployment(user.ID, "to-remove", "write", admin.ID) + + err := manager.RemoveDeploymentAccess(user.ID, "to-remove") + if err != nil { + t.Fatalf("RemoveDeploymentAccess failed: %v", err) + } + + deployments, _ := manager.GetUserDeployments(user.ID) + if len(deployments) != 0 { + t.Errorf("Expected 0 deployments, got %d", len(deployments)) + } +} + +func TestManagerBuildActorContext(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("actoruser", "", "pass", RoleOperator, nil) + admin, _ := manager.GetUserByUsername("admin") + _ = manager.AssignDeployment(user.ID, "my-app", "write", admin.ID) + + actor, err := manager.BuildActorContext(user, nil) + if err != nil { + t.Fatalf("BuildActorContext failed: %v", err) + } + + if actor.Type != "user" { + t.Errorf("Type = %s, want user", actor.Type) + } + + if actor.Role != RoleOperator { + t.Errorf("Role = %s, want operator", actor.Role) + } + + if actor.Deployments["my-app"] != "write" { + t.Error("Deployments should include my-app with write access") + } +} + +func TestManagerBuildActorContextWithAPIKey(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user, _ := manager.CreateUser("apikeyactor", "", "pass", RoleOperator, nil) + key, _, _ := manager.CreateAPIKey(user.ID, "Test Key", "", RoleViewer, []string{"deployments:read"}, nil, time.Time{}) + + fetchedKey, _ := manager.GetAPIKey(key.ID) + + actor, err := manager.BuildActorContext(user, fetchedKey) + if err != nil { + t.Fatalf("BuildActorContext failed: %v", err) + } + + if actor.Type != "api_key" { + t.Errorf("Type = %s, want api_key", actor.Type) + } + + if actor.Role != RoleViewer { + t.Errorf("API key role override should make Role = viewer, got %s", actor.Role) + } +} + +func TestManagerLegacyAPIKeys(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + if !manager.ValidateLegacyAPIKey("legacy-key-1") { + t.Error("Should validate legacy-key-1") + } + + if !manager.ValidateLegacyAPIKey("legacy-key-2") { + t.Error("Should validate legacy-key-2") + } + + if manager.ValidateLegacyAPIKey("invalid-key") { + t.Error("Should not validate invalid key") + } +} + +func TestManagerGetLegacyKeyIndex(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + idx := manager.GetLegacyKeyIndex("legacy-key-1") + if idx != 0 { + t.Errorf("Index = %d, want 0", idx) + } + + idx = manager.GetLegacyKeyIndex("legacy-key-2") + if idx != 1 { + t.Errorf("Index = %d, want 1", idx) + } + + idx = manager.GetLegacyKeyIndex("nonexistent") + if idx != -1 { + t.Errorf("Index = %d, want -1", idx) + } +} + +func TestManagerGetDeploymentUsers(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + user1, _ := manager.CreateUser("depuser1", "", "pass", RoleOperator, nil) + user2, _ := manager.CreateUser("depuser2", "", "pass", RoleViewer, nil) + admin, _ := manager.GetUserByUsername("admin") + + _ = manager.AssignDeployment(user1.ID, "shared-app", "write", admin.ID) + _ = manager.AssignDeployment(user2.ID, "shared-app", "read", admin.ID) + + users, err := manager.GetDeploymentUsers("shared-app") + if err != nil { + t.Fatalf("GetDeploymentUsers failed: %v", err) + } + + if len(users) != 2 { + t.Errorf("Expected 2 users, got %d", len(users)) + } +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index c6af4e2..3e5b59b 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -1,8 +1,11 @@ package auth import ( + "crypto/sha256" "crypto/subtle" + "encoding/hex" "fmt" + "log" "net/http" "strings" "time" @@ -14,22 +17,37 @@ import ( ) type Claims struct { - Username string `json:"username"` + Username string `json:"username"` + UserID int64 `json:"user_id,omitempty"` + SessionID string `json:"session_id,omitempty"` jwt.RegisteredClaims } type Middleware struct { - config *config.AuthConfig + config *config.AuthConfig + manager *Manager } func NewMiddleware(cfg *config.AuthConfig) *Middleware { return &Middleware{config: cfg} } +func NewMiddlewareWithManager(cfg *config.AuthConfig, manager *Manager) *Middleware { + return &Middleware{config: cfg, manager: manager} +} + +func (m *Middleware) SetManager(manager *Manager) { + m.manager = manager +} + func (m *Middleware) RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { if !m.config.Enabled { c.Set(contextkeys.ActorType, "anonymous") + c.Set(contextkeys.Actor, &ActorContext{ + Type: "anonymous", + Role: RoleAdmin, + }) c.Next() return } @@ -58,20 +76,17 @@ func (m *Middleware) RequireAuth() gin.HandlerFunc { switch scheme { case "bearer": if claims := m.validateJWTWithClaims(token); claims != nil { - c.Set(contextkeys.ActorType, "jwt") - c.Set(contextkeys.ActorID, claims.Username) - c.Set(contextkeys.ActorName, claims.Username) - c.Next() - return + if err := m.setJWTContext(c, claims, token); err == nil { + c.Next() + return + } } - if keyIndex := m.validateAPIKeyWithIndex(token); keyIndex >= 0 { - m.setAPIKeyContext(c, token, keyIndex) + if m.handleAPIKey(c, token) { c.Next() return } case "apikey": - if keyIndex := m.validateAPIKeyWithIndex(token); keyIndex >= 0 { - m.setAPIKeyContext(c, token, keyIndex) + if m.handleAPIKey(c, token) { c.Next() return } @@ -90,7 +105,71 @@ func (m *Middleware) RequireAuth() gin.HandlerFunc { } } -func (m *Middleware) setAPIKeyContext(c *gin.Context, token string, keyIndex int) { +func (m *Middleware) handleAPIKey(c *gin.Context, token string) bool { + if m.manager != nil { + apiKey, user, err := m.manager.ValidateAPIKey(token) + if err == nil { + actor, err := m.manager.BuildActorContext(user, apiKey) + if err == nil { + c.Set(contextkeys.ActorType, "api_key") + c.Set(contextkeys.ActorID, fmt.Sprintf("key_%s", apiKey.KeyID)) + c.Set(contextkeys.ActorName, user.Username) + c.Set(contextkeys.APIKeyPrefix, apiKey.KeyPrefix) + c.Set(contextkeys.Actor, actor) + + go func() { _ = m.manager.UpdateAPIKeyLastUsed(apiKey.ID, c.ClientIP()) }() + return true + } + } + } + + if keyIndex := m.validateAPIKeyWithIndex(token); keyIndex >= 0 { + m.setLegacyAPIKeyContext(c, token, keyIndex) + log.Printf("Warning: Legacy API key used. Consider migrating to user-based API keys.") + return true + } + + return false +} + +func (m *Middleware) setJWTContext(c *gin.Context, claims *Claims, token string) error { + c.Set(contextkeys.ActorType, "jwt") + c.Set(contextkeys.ActorID, claims.Username) + c.Set(contextkeys.ActorName, claims.Username) + + if m.manager != nil && claims.SessionID != "" { + session, err := m.manager.GetSessionByID(claims.SessionID) + if err != nil || !session.RevokedAt.IsZero() { + return fmt.Errorf("session revoked or invalid") + } + } + + if m.manager != nil && claims.UserID > 0 { + user, err := m.manager.GetUser(claims.UserID) + if err != nil { + return err + } + + if !user.IsActive { + return ErrUserInactive + } + + actor, err := m.manager.BuildActorContext(user, nil) + if err != nil { + return err + } + c.Set(contextkeys.Actor, actor) + } else { + c.Set(contextkeys.Actor, &ActorContext{ + Type: "jwt", + Role: RoleAdmin, + }) + } + + return nil +} + +func (m *Middleware) setLegacyAPIKeyContext(c *gin.Context, token string, keyIndex int) { c.Set(contextkeys.ActorType, "api_key") c.Set(contextkeys.ActorID, fmt.Sprintf("key_%d", keyIndex)) if len(token) >= 8 { @@ -98,6 +177,11 @@ func (m *Middleware) setAPIKeyContext(c *gin.Context, token string, keyIndex int } else { c.Set(contextkeys.APIKeyPrefix, token+"...") } + + c.Set(contextkeys.Actor, &ActorContext{ + Type: "legacy_key", + Role: RoleAdmin, + }) } func (m *Middleware) validateJWTWithClaims(tokenString string) *Claims { @@ -171,6 +255,24 @@ func (m *Middleware) GenerateJWT(username string) (string, error) { return token.SignedString([]byte(m.config.JWTSecret)) } +func (m *Middleware) GenerateJWTForUser(user *User, sessionID string) (string, error) { + claims := Claims{ + Username: user.Username, + UserID: user.ID, + SessionID: sessionID, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "flatrun-agent", + Subject: user.UID, + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(m.config.JWTSecret)) +} + func (m *Middleware) ValidateToken(c *gin.Context) { if !m.config.Enabled { c.JSON(http.StatusOK, gin.H{ @@ -207,6 +309,16 @@ func (m *Middleware) ValidateToken(c *gin.Context) { return } + if m.manager != nil { + if _, _, err := m.manager.ValidateAPIKey(token); err == nil { + c.JSON(http.StatusOK, gin.H{ + "valid": true, + "message": "Token is valid", + }) + return + } + } + c.JSON(http.StatusUnauthorized, gin.H{ "valid": false, "error": "Invalid token", @@ -223,7 +335,9 @@ func (m *Middleware) Login(c *gin.Context) { } var req struct { - APIKey string `json:"api_key"` + APIKey string `json:"api_key"` + Username string `json:"username"` + Password string `json:"password"` } if err := c.ShouldBindJSON(&req); err != nil { @@ -233,6 +347,118 @@ func (m *Middleware) Login(c *gin.Context) { return } + if req.Username != "" && req.Password != "" && m.manager != nil { + user, err := m.manager.ValidateCredentials(req.Username, req.Password) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Invalid username or password", + }) + return + } + + sessionID, err := GenerateSessionID() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate session", + }) + return + } + + token, err := m.GenerateJWTForUser(user, sessionID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate token", + }) + return + } + + tokenHash := sha256.Sum256([]byte(token)) + _, _ = m.manager.CreateSession(user.ID, 0, sessionID, hex.EncodeToString(tokenHash[:]), c.ClientIP(), time.Now().Add(24*time.Hour)) + + deployments, _ := m.manager.GetUserDeployments(user.ID) + depAccess := make([]gin.H, 0, len(deployments)) + for _, d := range deployments { + depAccess = append(depAccess, gin.H{ + "deployment_name": d.DeploymentName, + "access_level": d.AccessLevel, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "token": token, + "expires_in": 86400, + "token_type": "Bearer", + "user": userResponse(user), + "permissions": EffectivePermissions(user, user.Role), + "deployments": depAccess, + }) + return + } + + if req.APIKey == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "error": "API key or username/password required", + }) + return + } + + if m.manager != nil { + apiKey, user, err := m.manager.ValidateAPIKey(req.APIKey) + if err == nil { + sessionID, err := GenerateSessionID() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate session", + }) + return + } + + token, err := m.GenerateJWTForUser(user, sessionID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Failed to generate token", + }) + return + } + + tokenHash := sha256.Sum256([]byte(token)) + _, _ = m.manager.CreateSession(user.ID, apiKey.ID, sessionID, hex.EncodeToString(tokenHash[:]), c.ClientIP(), time.Now().Add(24*time.Hour)) + + role := user.Role + if apiKey.Role != "" { + role = apiKey.Role + } + + perms := GetRolePermissions(role) + if len(apiKey.Permissions) > 0 { + customPerms := make([]Permission, 0, len(apiKey.Permissions)) + for _, p := range apiKey.Permissions { + customPerms = append(customPerms, Permission(p)) + } + perms = customPerms + } + + deployments, _ := m.manager.GetUserDeployments(user.ID) + depAccess := make([]gin.H, 0, len(deployments)) + for _, d := range deployments { + depAccess = append(depAccess, gin.H{ + "deployment_name": d.DeploymentName, + "access_level": d.AccessLevel, + }) + } + + c.JSON(http.StatusOK, gin.H{ + "token": token, + "expires_in": 86400, + "token_type": "Bearer", + "user": userResponse(user), + "permissions": perms, + "deployments": depAccess, + }) + return + } + } + if !m.validateAPIKey(req.APIKey) { c.JSON(http.StatusUnauthorized, gin.H{ "error": "Invalid API key", @@ -249,9 +475,10 @@ func (m *Middleware) Login(c *gin.Context) { } c.JSON(http.StatusOK, gin.H{ - "token": token, - "expires_in": 86400, - "token_type": "Bearer", + "token": token, + "expires_in": 86400, + "token_type": "Bearer", + "permissions": GetRolePermissions(RoleAdmin), }) } @@ -271,3 +498,108 @@ func (m *Middleware) ValidateTokenString(token string) bool { func (m *Middleware) IsAuthEnabled() bool { return m.config.Enabled } + +func (m *Middleware) RequirePermission(perms ...Permission) gin.HandlerFunc { + return func(c *gin.Context) { + actorVal, exists := c.Get(contextkeys.Actor) + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Not authenticated", + }) + c.Abort() + return + } + + actor, ok := actorVal.(*ActorContext) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Invalid actor context", + }) + c.Abort() + return + } + + for _, perm := range perms { + if !actor.HasPermission(perm) { + c.JSON(http.StatusForbidden, gin.H{ + "error": fmt.Sprintf("Permission denied: %s required", perm), + }) + c.Abort() + return + } + } + + c.Next() + } +} + +func (m *Middleware) RequireDeploymentAccess(level string) gin.HandlerFunc { + return func(c *gin.Context) { + actorVal, exists := c.Get(contextkeys.Actor) + if !exists { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": "Not authenticated", + }) + c.Abort() + return + } + + actor, ok := actorVal.(*ActorContext) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "Invalid actor context", + }) + c.Abort() + return + } + + deploymentName := c.Param("name") + if deploymentName == "" { + deploymentName = c.Param("deployment") + } + + if deploymentName == "" { + c.Next() + return + } + + if !actor.CanAccessDeployment(deploymentName, level) { + c.JSON(http.StatusForbidden, gin.H{ + "error": "No access to this deployment", + }) + c.Abort() + return + } + + c.Next() + } +} + +func GetActorFromContext(c *gin.Context) *ActorContext { + actorVal, exists := c.Get(contextkeys.Actor) + if !exists { + return nil + } + actor, ok := actorVal.(*ActorContext) + if !ok { + return nil + } + return actor +} + +func userResponse(u *User) gin.H { + resp := gin.H{ + "id": u.ID, + "uid": u.UID, + "username": u.Username, + "email": u.Email, + "role": u.Role, + "is_active": u.IsActive, + "created_at": u.CreatedAt, + "last_login_at": u.LastLoginAt, + } + if len(u.Permissions) > 0 { + resp["permissions"] = u.Permissions + } + return resp +} diff --git a/internal/auth/models.go b/internal/auth/models.go new file mode 100644 index 0000000..fdcdccb --- /dev/null +++ b/internal/auth/models.go @@ -0,0 +1,193 @@ +package auth + +import ( + "encoding/json" + "time" +) + +type Role string + +const ( + RoleAdmin Role = "admin" + RoleOperator Role = "operator" + RoleViewer Role = "viewer" +) + +const ( + AccessLevelRead = "read" + AccessLevelWrite = "write" + AccessLevelAdmin = "admin" +) + +func ValidAccessLevel(level string) bool { + return level == AccessLevelRead || level == AccessLevelWrite || level == AccessLevelAdmin +} + +func (r Role) IsValid() bool { + switch r { + case RoleAdmin, RoleOperator, RoleViewer: + return true + } + return false +} + +type User struct { + ID int64 `json:"id"` + UID string `json:"uid"` + Username string `json:"username"` + Email string `json:"email,omitempty"` + PasswordHash string `json:"-"` + Role Role `json:"role"` + Permissions []string `json:"permissions,omitempty"` + IsActive bool `json:"is_active"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + LastLoginAt time.Time `json:"last_login_at,omitempty"` +} + +func (u *User) GetPermissionsJSON() string { + if len(u.Permissions) == 0 { + return "" + } + b, _ := json.Marshal(u.Permissions) + return string(b) +} + +type APIKey struct { + ID int64 `json:"id"` + KeyID string `json:"key_id"` + UserID int64 `json:"user_id"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + KeyHash string `json:"-"` + KeyPrefix string `json:"key_prefix"` + Role Role `json:"role,omitempty"` + Permissions []string `json:"permissions,omitempty"` + Deployments []string `json:"deployments,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + LastUsedAt time.Time `json:"last_used_at,omitempty"` + LastUsedIP string `json:"last_used_ip,omitempty"` + IsActive bool `json:"is_active"` + CreatedAt time.Time `json:"created_at"` +} + +type Session struct { + ID int64 `json:"id"` + SessionID string `json:"session_id"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id,omitempty"` + TokenHash string `json:"-"` + ExpiresAt time.Time `json:"expires_at"` + RevokedAt time.Time `json:"revoked_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + ClientIP string `json:"client_ip,omitempty"` +} + +type UserDeployment struct { + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + DeploymentName string `json:"deployment_name"` + AccessLevel string `json:"access_level"` + GrantedBy int64 `json:"granted_by,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type ActorContext struct { + Type string `json:"type"` + UserID int64 `json:"user_id,omitempty"` + User *User `json:"user,omitempty"` + APIKey *APIKey `json:"api_key,omitempty"` + Role Role `json:"role"` + Permissions []string `json:"permissions,omitempty"` + Deployments map[string]string `json:"deployments,omitempty"` +} + +func (a *ActorContext) HasPermission(p Permission) bool { + if a.Role == RoleAdmin { + return true + } + + rolePerms := GetRolePermissions(a.Role) + for _, rp := range rolePerms { + if rp == p { + return true + } + } + + for _, ep := range a.Permissions { + if Permission(ep) == p { + return true + } + } + + return false +} + +func (a *ActorContext) CanAccessDeployment(name string, requiredLevel string) bool { + if a.Role == RoleAdmin { + return true + } + + if a.APIKey != nil && len(a.APIKey.Deployments) > 0 { + found := false + for _, d := range a.APIKey.Deployments { + if d == name { + found = true + break + } + } + if !found { + return false + } + } + + level, ok := a.Deployments[name] + if !ok { + return false + } + + return accessLevelSufficient(level, requiredLevel) +} + +func accessLevelSufficient(has, required string) bool { + levels := map[string]int{ + AccessLevelRead: 1, + AccessLevelWrite: 2, + AccessLevelAdmin: 3, + } + return levels[has] >= levels[required] +} + +func (a *APIKey) GetPermissionsJSON() string { + if len(a.Permissions) == 0 { + return "" + } + b, _ := json.Marshal(a.Permissions) + return string(b) +} + +func (a *APIKey) GetDeploymentsJSON() string { + if len(a.Deployments) == 0 { + return "" + } + b, _ := json.Marshal(a.Deployments) + return string(b) +} + +func ParsePermissionsJSON(s string) []string { + if s == "" { + return nil + } + var perms []string + _ = json.Unmarshal([]byte(s), &perms) + return perms +} + +func ParseDeploymentsJSON(s string) []string { + if s == "" { + return nil + } + var deps []string + _ = json.Unmarshal([]byte(s), &deps) + return deps +} diff --git a/internal/auth/models_test.go b/internal/auth/models_test.go new file mode 100644 index 0000000..d8dc1f5 --- /dev/null +++ b/internal/auth/models_test.go @@ -0,0 +1,234 @@ +package auth + +import ( + "testing" +) + +func TestActorContextHasPermission(t *testing.T) { + tests := []struct { + name string + actor *ActorContext + permission Permission + want bool + }{ + { + name: "admin has all permissions", + actor: &ActorContext{Role: RoleAdmin}, + permission: PermUsersDelete, + want: true, + }, + { + name: "viewer has read permission", + actor: &ActorContext{Role: RoleViewer}, + permission: PermDeploymentsRead, + want: true, + }, + { + name: "viewer cannot write", + actor: &ActorContext{Role: RoleViewer}, + permission: PermDeploymentsWrite, + want: false, + }, + { + name: "explicit permission overrides role", + actor: &ActorContext{ + Role: RoleViewer, + Permissions: []string{string(PermUsersWrite)}, + }, + permission: PermUsersWrite, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.actor.HasPermission(tt.permission); got != tt.want { + t.Errorf("ActorContext.HasPermission() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestActorContextCanAccessDeployment(t *testing.T) { + tests := []struct { + name string + actor *ActorContext + deploymentName string + requiredLevel string + want bool + }{ + { + name: "admin can access any deployment", + actor: &ActorContext{Role: RoleAdmin}, + deploymentName: "any-deployment", + requiredLevel: "admin", + want: true, + }, + { + name: "user with read access can read", + actor: &ActorContext{ + Role: RoleOperator, + Deployments: map[string]string{"my-app": "read"}, + }, + deploymentName: "my-app", + requiredLevel: "read", + want: true, + }, + { + name: "user with read access cannot write", + actor: &ActorContext{ + Role: RoleOperator, + Deployments: map[string]string{"my-app": "read"}, + }, + deploymentName: "my-app", + requiredLevel: "write", + want: false, + }, + { + name: "user with write access can read", + actor: &ActorContext{ + Role: RoleOperator, + Deployments: map[string]string{"my-app": "write"}, + }, + deploymentName: "my-app", + requiredLevel: "read", + want: true, + }, + { + name: "user with write access can write", + actor: &ActorContext{ + Role: RoleOperator, + Deployments: map[string]string{"my-app": "write"}, + }, + deploymentName: "my-app", + requiredLevel: "write", + want: true, + }, + { + name: "user cannot access unassigned deployment", + actor: &ActorContext{ + Role: RoleOperator, + Deployments: map[string]string{"my-app": "write"}, + }, + deploymentName: "other-app", + requiredLevel: "read", + want: false, + }, + { + name: "api key scoped to deployment restricts access", + actor: &ActorContext{ + Role: RoleOperator, + Deployments: map[string]string{"app-a": "write", "app-b": "write"}, + APIKey: &APIKey{Deployments: []string{"app-a"}}, + }, + deploymentName: "app-b", + requiredLevel: "read", + want: false, + }, + { + name: "api key scoped to deployment allows access", + actor: &ActorContext{ + Role: RoleOperator, + Deployments: map[string]string{"app-a": "write"}, + APIKey: &APIKey{Deployments: []string{"app-a"}}, + }, + deploymentName: "app-a", + requiredLevel: "write", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.actor.CanAccessDeployment(tt.deploymentName, tt.requiredLevel); got != tt.want { + t.Errorf("ActorContext.CanAccessDeployment() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAccessLevelSufficient(t *testing.T) { + tests := []struct { + has string + required string + want bool + }{ + {"read", "read", true}, + {"write", "read", true}, + {"write", "write", true}, + {"admin", "read", true}, + {"admin", "write", true}, + {"admin", "admin", true}, + {"read", "write", false}, + {"read", "admin", false}, + {"write", "admin", false}, + } + + for _, tt := range tests { + t.Run(tt.has+"_vs_"+tt.required, func(t *testing.T) { + if got := accessLevelSufficient(tt.has, tt.required); got != tt.want { + t.Errorf("accessLevelSufficient(%s, %s) = %v, want %v", tt.has, tt.required, got, tt.want) + } + }) + } +} + +func TestAPIKeyGetPermissionsJSON(t *testing.T) { + key := &APIKey{Permissions: []string{"deployments:read", "deployments:write"}} + + json := key.GetPermissionsJSON() + if json == "" { + t.Error("GetPermissionsJSON returned empty for non-empty permissions") + } + + emptyKey := &APIKey{} + if emptyKey.GetPermissionsJSON() != "" { + t.Error("GetPermissionsJSON should return empty for nil permissions") + } +} + +func TestAPIKeyGetDeploymentsJSON(t *testing.T) { + key := &APIKey{Deployments: []string{"app-a", "app-b"}} + + json := key.GetDeploymentsJSON() + if json == "" { + t.Error("GetDeploymentsJSON returned empty for non-empty deployments") + } + + emptyKey := &APIKey{} + if emptyKey.GetDeploymentsJSON() != "" { + t.Error("GetDeploymentsJSON should return empty for nil deployments") + } +} + +func TestParsePermissionsJSON(t *testing.T) { + json := `["deployments:read","deployments:write"]` + perms := ParsePermissionsJSON(json) + + if len(perms) != 2 { + t.Errorf("ParsePermissionsJSON returned %d items, want 2", len(perms)) + } + + if perms[0] != "deployments:read" { + t.Errorf("First permission = %s, want deployments:read", perms[0]) + } + + empty := ParsePermissionsJSON("") + if empty != nil { + t.Error("ParsePermissionsJSON should return nil for empty string") + } +} + +func TestParseDeploymentsJSON(t *testing.T) { + json := `["app-a","app-b"]` + deps := ParseDeploymentsJSON(json) + + if len(deps) != 2 { + t.Errorf("ParseDeploymentsJSON returned %d items, want 2", len(deps)) + } + + empty := ParseDeploymentsJSON("") + if empty != nil { + t.Error("ParseDeploymentsJSON should return nil for empty string") + } +} diff --git a/internal/auth/password.go b/internal/auth/password.go new file mode 100644 index 0000000..5964914 --- /dev/null +++ b/internal/auth/password.go @@ -0,0 +1,20 @@ +package auth + +import ( + "golang.org/x/crypto/bcrypt" +) + +const bcryptCost = 12 + +func HashPassword(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost) + if err != nil { + return "", err + } + return string(bytes), nil +} + +func VerifyPassword(password, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} diff --git a/internal/auth/password_test.go b/internal/auth/password_test.go new file mode 100644 index 0000000..f7f00cf --- /dev/null +++ b/internal/auth/password_test.go @@ -0,0 +1,65 @@ +package auth + +import ( + "testing" +) + +func TestHashPassword(t *testing.T) { + password := "testpassword123" + + hash, err := HashPassword(password) + if err != nil { + t.Fatalf("HashPassword failed: %v", err) + } + + if hash == "" { + t.Error("HashPassword returned empty hash") + } + + if hash == password { + t.Error("HashPassword returned plaintext password") + } +} + +func TestHashPasswordDifferentHashes(t *testing.T) { + password := "testpassword123" + + hash1, _ := HashPassword(password) + hash2, _ := HashPassword(password) + + if hash1 == hash2 { + t.Error("HashPassword should generate different hashes for same password (due to salt)") + } +} + +func TestVerifyPassword(t *testing.T) { + password := "testpassword123" + + hash, err := HashPassword(password) + if err != nil { + t.Fatalf("HashPassword failed: %v", err) + } + + if !VerifyPassword(password, hash) { + t.Error("VerifyPassword should return true for correct password") + } +} + +func TestVerifyPasswordWrong(t *testing.T) { + password := "testpassword123" + wrongPassword := "wrongpassword" + + hash, _ := HashPassword(password) + + if VerifyPassword(wrongPassword, hash) { + t.Error("VerifyPassword should return false for wrong password") + } +} + +func TestVerifyPasswordEmpty(t *testing.T) { + hash, _ := HashPassword("somepassword") + + if VerifyPassword("", hash) { + t.Error("VerifyPassword should return false for empty password") + } +} diff --git a/internal/auth/permissions.go b/internal/auth/permissions.go new file mode 100644 index 0000000..a8d1780 --- /dev/null +++ b/internal/auth/permissions.go @@ -0,0 +1,194 @@ +package auth + +type Permission string + +const ( + PermDeploymentsRead Permission = "deployments:read" + PermDeploymentsWrite Permission = "deployments:write" + PermDeploymentsDelete Permission = "deployments:delete" + + PermCertificatesRead Permission = "certificates:read" + PermCertificatesWrite Permission = "certificates:write" + PermCertificatesDelete Permission = "certificates:delete" + + PermNetworksRead Permission = "networks:read" + PermNetworksWrite Permission = "networks:write" + PermNetworksDelete Permission = "networks:delete" + + PermSecurityRead Permission = "security:read" + PermSecurityWrite Permission = "security:write" + + PermBackupsRead Permission = "backups:read" + PermBackupsWrite Permission = "backups:write" + PermBackupsDelete Permission = "backups:delete" + + PermUsersRead Permission = "users:read" + PermUsersWrite Permission = "users:write" + PermUsersDelete Permission = "users:delete" + + PermAPIKeysRead Permission = "apikeys:read" + PermAPIKeysWrite Permission = "apikeys:write" + PermAPIKeysDelete Permission = "apikeys:delete" + + PermSettingsRead Permission = "settings:read" + PermSettingsWrite Permission = "settings:write" + + PermAuditRead Permission = "audit:read" + + PermContainersRead Permission = "containers:read" + PermContainersWrite Permission = "containers:write" + PermContainersDelete Permission = "containers:delete" + + PermImagesRead Permission = "images:read" + PermImagesWrite Permission = "images:write" + PermImagesDelete Permission = "images:delete" + + PermVolumesRead Permission = "volumes:read" + PermVolumesWrite Permission = "volumes:write" + PermVolumesDelete Permission = "volumes:delete" + + PermDatabasesRead Permission = "databases:read" + PermDatabasesWrite Permission = "databases:write" + PermDatabasesDelete Permission = "databases:delete" + + PermInfrastructureRead Permission = "infrastructure:read" + PermInfrastructureWrite Permission = "infrastructure:write" + + PermSchedulerRead Permission = "scheduler:read" + PermSchedulerWrite Permission = "scheduler:write" + PermSchedulerDelete Permission = "scheduler:delete" + + PermSystemRead Permission = "system:read" + PermSystemWrite Permission = "system:write" + + PermDNSRead Permission = "dns:read" + PermDNSWrite Permission = "dns:write" + + PermRegistriesRead Permission = "registries:read" + PermRegistriesWrite Permission = "registries:write" + PermRegistriesDelete Permission = "registries:delete" + + PermTemplatesRead Permission = "templates:read" + PermTemplatesWrite Permission = "templates:write" + + PermTrafficRead Permission = "traffic:read" + PermTrafficWrite Permission = "traffic:write" +) + +var adminPermissions = []Permission{ + PermDeploymentsRead, PermDeploymentsWrite, PermDeploymentsDelete, + PermCertificatesRead, PermCertificatesWrite, PermCertificatesDelete, + PermNetworksRead, PermNetworksWrite, PermNetworksDelete, + PermSecurityRead, PermSecurityWrite, + PermBackupsRead, PermBackupsWrite, PermBackupsDelete, + PermUsersRead, PermUsersWrite, PermUsersDelete, + PermAPIKeysRead, PermAPIKeysWrite, PermAPIKeysDelete, + PermSettingsRead, PermSettingsWrite, + PermAuditRead, + PermContainersRead, PermContainersWrite, PermContainersDelete, + PermImagesRead, PermImagesWrite, PermImagesDelete, + PermVolumesRead, PermVolumesWrite, PermVolumesDelete, + PermDatabasesRead, PermDatabasesWrite, PermDatabasesDelete, + PermInfrastructureRead, PermInfrastructureWrite, + PermSchedulerRead, PermSchedulerWrite, PermSchedulerDelete, + PermSystemRead, PermSystemWrite, + PermDNSRead, PermDNSWrite, + PermRegistriesRead, PermRegistriesWrite, PermRegistriesDelete, + PermTemplatesRead, PermTemplatesWrite, + PermTrafficRead, PermTrafficWrite, +} + +var operatorPermissions = []Permission{ + PermDeploymentsRead, PermDeploymentsWrite, + PermCertificatesRead, PermCertificatesWrite, + PermNetworksRead, + PermSecurityRead, + PermBackupsRead, PermBackupsWrite, + PermAPIKeysRead, PermAPIKeysWrite, PermAPIKeysDelete, + PermSettingsRead, + PermContainersRead, PermContainersWrite, + PermImagesRead, PermImagesWrite, + PermVolumesRead, PermVolumesWrite, + PermDatabasesRead, PermDatabasesWrite, + PermInfrastructureRead, PermInfrastructureWrite, + PermSchedulerRead, PermSchedulerWrite, + PermSystemRead, PermSystemWrite, + PermDNSRead, PermDNSWrite, + PermRegistriesRead, PermRegistriesWrite, + PermTemplatesRead, + PermTrafficRead, +} + +var viewerPermissions = []Permission{ + PermDeploymentsRead, + PermCertificatesRead, + PermNetworksRead, + PermSecurityRead, + PermBackupsRead, + PermAPIKeysRead, + PermSettingsRead, + PermContainersRead, + PermImagesRead, + PermVolumesRead, + PermDatabasesRead, + PermInfrastructureRead, + PermSchedulerRead, + PermSystemRead, + PermDNSRead, + PermRegistriesRead, + PermTemplatesRead, + PermTrafficRead, +} + +func GetRolePermissions(role Role) []Permission { + switch role { + case RoleAdmin: + return adminPermissions + case RoleOperator: + return operatorPermissions + case RoleViewer: + return viewerPermissions + default: + return nil + } +} + +func HasPermission(role Role, explicitPerms []string, required Permission) bool { + if role == RoleAdmin { + return true + } + + rolePerms := GetRolePermissions(role) + for _, p := range rolePerms { + if p == required { + return true + } + } + + for _, p := range explicitPerms { + if Permission(p) == required { + return true + } + } + + return false +} + +func GetAllPermissions() []Permission { + return adminPermissions +} + +func (p Permission) String() string { + return string(p) +} + +func EffectivePermissions(user *User, role Role) []Permission { + if user != nil && len(user.Permissions) > 0 { + perms := make([]Permission, 0, len(user.Permissions)) + for _, p := range user.Permissions { + perms = append(perms, Permission(p)) + } + return perms + } + return GetRolePermissions(role) +} diff --git a/internal/auth/permissions_test.go b/internal/auth/permissions_test.go new file mode 100644 index 0000000..0c16ffb --- /dev/null +++ b/internal/auth/permissions_test.go @@ -0,0 +1,188 @@ +package auth + +import ( + "testing" +) + +func TestRoleIsValid(t *testing.T) { + tests := []struct { + role Role + valid bool + }{ + {RoleAdmin, true}, + {RoleOperator, true}, + {RoleViewer, true}, + {Role("invalid"), false}, + {Role(""), false}, + } + + for _, tt := range tests { + if got := tt.role.IsValid(); got != tt.valid { + t.Errorf("Role(%q).IsValid() = %v, want %v", tt.role, got, tt.valid) + } + } +} + +func TestGetRolePermissions(t *testing.T) { + adminPerms := GetRolePermissions(RoleAdmin) + if len(adminPerms) == 0 { + t.Error("Admin should have permissions") + } + + operatorPerms := GetRolePermissions(RoleOperator) + if len(operatorPerms) == 0 { + t.Error("Operator should have permissions") + } + + viewerPerms := GetRolePermissions(RoleViewer) + if len(viewerPerms) == 0 { + t.Error("Viewer should have permissions") + } + + if len(adminPerms) <= len(operatorPerms) { + t.Error("Admin should have more permissions than operator") + } + + if len(operatorPerms) <= len(viewerPerms) { + t.Error("Operator should have more permissions than viewer") + } + + invalidPerms := GetRolePermissions(Role("invalid")) + if invalidPerms != nil { + t.Error("Invalid role should return nil permissions") + } +} + +func TestAdminHasAllPermissions(t *testing.T) { + allPerms := GetAllPermissions() + + for _, perm := range allPerms { + if !HasPermission(RoleAdmin, nil, perm) { + t.Errorf("Admin should have permission %s", perm) + } + } +} + +func TestViewerCannotWrite(t *testing.T) { + writePerms := []Permission{ + PermDeploymentsWrite, + PermDeploymentsDelete, + PermUsersWrite, + PermUsersDelete, + PermContainersWrite, + PermContainersDelete, + PermImagesWrite, + PermImagesDelete, + PermVolumesWrite, + PermVolumesDelete, + PermDatabasesWrite, + PermDatabasesDelete, + PermInfrastructureWrite, + PermSchedulerWrite, + PermSchedulerDelete, + PermSystemWrite, + PermDNSWrite, + PermRegistriesWrite, + PermRegistriesDelete, + PermTemplatesWrite, + PermTrafficWrite, + } + + for _, perm := range writePerms { + if HasPermission(RoleViewer, nil, perm) { + t.Errorf("Viewer should not have permission %s", perm) + } + } +} + +func TestViewerCanRead(t *testing.T) { + readPerms := []Permission{ + PermDeploymentsRead, + PermCertificatesRead, + PermNetworksRead, + PermContainersRead, + PermImagesRead, + PermVolumesRead, + PermDatabasesRead, + PermInfrastructureRead, + PermSchedulerRead, + PermSystemRead, + PermDNSRead, + PermRegistriesRead, + PermTemplatesRead, + PermTrafficRead, + } + + for _, perm := range readPerms { + if !HasPermission(RoleViewer, nil, perm) { + t.Errorf("Viewer should have permission %s", perm) + } + } +} + +func TestOperatorPermissions(t *testing.T) { + if !HasPermission(RoleOperator, nil, PermDeploymentsWrite) { + t.Error("Operator should be able to write deployments") + } + + if HasPermission(RoleOperator, nil, PermUsersWrite) { + t.Error("Operator should not be able to write users") + } + + if HasPermission(RoleOperator, nil, PermDeploymentsDelete) { + t.Error("Operator should not be able to delete deployments") + } + + // Operator can write new resource groups + operatorWritePerms := []Permission{ + PermContainersWrite, + PermImagesWrite, + PermVolumesWrite, + PermDatabasesWrite, + PermInfrastructureWrite, + PermSchedulerWrite, + PermSystemWrite, + PermDNSWrite, + PermRegistriesWrite, + } + for _, perm := range operatorWritePerms { + if !HasPermission(RoleOperator, nil, perm) { + t.Errorf("Operator should have permission %s", perm) + } + } + + // Operator cannot delete new resource groups + operatorNoDeletePerms := []Permission{ + PermContainersDelete, + PermImagesDelete, + PermVolumesDelete, + PermDatabasesDelete, + PermSchedulerDelete, + PermRegistriesDelete, + } + for _, perm := range operatorNoDeletePerms { + if HasPermission(RoleOperator, nil, perm) { + t.Errorf("Operator should not have permission %s", perm) + } + } +} + +func TestExplicitPermissionsOverride(t *testing.T) { + explicitPerms := []string{string(PermUsersWrite)} + + if !HasPermission(RoleViewer, explicitPerms, PermUsersWrite) { + t.Error("Explicit permission should grant access") + } + + if HasPermission(RoleViewer, explicitPerms, PermUsersDelete) { + t.Error("Should not have permissions not explicitly granted") + } +} + +func TestPermissionString(t *testing.T) { + perm := PermDeploymentsRead + + if perm.String() != "deployments:read" { + t.Errorf("Permission.String() = %s, want deployments:read", perm.String()) + } +} diff --git a/internal/contextkeys/keys.go b/internal/contextkeys/keys.go index f721739..eeff867 100644 --- a/internal/contextkeys/keys.go +++ b/internal/contextkeys/keys.go @@ -6,4 +6,5 @@ const ( ActorName = "audit_actor_name" APIKeyPrefix = "audit_api_key_prefix" RequestID = "audit_request_id" + Actor = "auth_actor_context" )