diff --git a/engine/constants.go b/engine/constants.go new file mode 100644 index 00000000..7b4404d3 --- /dev/null +++ b/engine/constants.go @@ -0,0 +1,9 @@ +package engine + +const ( + RedisEventReset = "reset" + RedisEventRoundFinish = "round_finish" + RedisQueueTasks = "tasks" + RedisQueueResults = "results" + RedisChannelEvents = "events" +) diff --git a/engine/db/services.go b/engine/db/services.go deleted file mode 100644 index 49fce576..00000000 --- a/engine/db/services.go +++ /dev/null @@ -1,4 +0,0 @@ -package db - -type ServiceSchema struct { -} diff --git a/engine/engine.go b/engine/engine.go index 6d7d2ef6..5302ac64 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -118,7 +118,7 @@ func (se *ScoringEngine) Start() { Password: os.Getenv("REDIS_PASSWORD"), }) - events := rdb.Subscribe(context.Background(), "events") + events := rdb.Subscribe(context.Background(), RedisChannelEvents) defer events.Close() eventsChannel := events.Channel() @@ -130,7 +130,7 @@ func (se *ScoringEngine) Start() { select { case msg := <-eventsChannel: slog.Info("Received message", "message", msg.Payload) - if msg.Payload == "reset" { + if msg.Payload == RedisEventReset { slog.Info("Engine loop reset event received while waiting, quitting...") return } else { @@ -159,7 +159,7 @@ func (se *ScoringEngine) Start() { slog.Info(fmt.Sprintf("Round %d complete", se.CurrentRound)) se.CurrentRound++ - se.RedisClient.Publish(context.Background(), "events", "round_finish") + se.RedisClient.Publish(context.Background(), RedisChannelEvents, RedisEventRoundFinish) slog.Info(fmt.Sprintf("Round %d will start in %s, sleeping...", se.CurrentRound, time.Until(se.NextRoundStartTime).String())) time.Sleep(time.Until(se.NextRoundStartTime)) } @@ -185,13 +185,13 @@ func waitForReset() { Password: os.Getenv("REDIS_PASSWORD"), }) - events := rdb.Subscribe(context.Background(), "events") + events := rdb.Subscribe(context.Background(), RedisChannelEvents) defer events.Close() eventsChannel := events.Channel() for msg := range eventsChannel { slog.Info("Received message", "message", msg.Payload) - if msg.Payload == "reset" { + if msg.Payload == RedisEventReset { slog.Info("Reset event received, quitting...") return } else { @@ -311,7 +311,7 @@ func (se *ScoringEngine) ResetScores() error { // Flush Redis queues ctx := context.Background() - keysToDelete := []string{"tasks", "results"} + keysToDelete := []string{RedisQueueTasks, RedisQueueResults} for _, key := range keysToDelete { if err := se.RedisClient.Del(ctx, key).Err(); err != nil { slog.Error("Failed to clear Redis queue", "queue", key, "error", err) @@ -320,7 +320,7 @@ func (se *ScoringEngine) ResetScores() error { } // Reset engine state - se.RedisClient.Publish(context.Background(), "events", "reset") + se.RedisClient.Publish(context.Background(), RedisChannelEvents, RedisEventReset) se.CurrentRound = 1 se.uptimeMu.Lock() @@ -361,7 +361,7 @@ func (se *ScoringEngine) rvb() error { Password: os.Getenv("REDIS_PASSWORD"), }) - events := rdb.Subscribe(context.Background(), "events") + events := rdb.Subscribe(context.Background(), RedisChannelEvents) defer events.Close() eventsChannel := events.Channel() // @@ -389,10 +389,10 @@ func (se *ScoringEngine) rvb() error { } // Clear any stale tasks from previous rounds before enqueuing new ones - staleTasks := se.RedisClient.LLen(ctx, "tasks").Val() + staleTasks := se.RedisClient.LLen(ctx, RedisQueueTasks).Val() if staleTasks > 0 { slog.Warn("Clearing stale tasks from queue", "count", staleTasks, "round", se.CurrentRound) - se.RedisClient.Del(ctx, "tasks") + se.RedisClient.Del(ctx, RedisQueueTasks) } // 1) Enqueue @@ -457,7 +457,7 @@ func (se *ScoringEngine) rvb() error { slog.Error("failed to marshal service task", "error", err) continue } - se.RedisClient.RPush(ctx, "tasks", payload) + se.RedisClient.RPush(ctx, RedisQueueTasks, payload) runners++ } } @@ -474,14 +474,14 @@ COLLECTION: select { case msg := <-eventsChannel: slog.Info("Received message", "message", msg.Payload) - if msg.Payload == "reset" { + if msg.Payload == RedisEventReset { slog.Info("Reset event received, quitting...") return fmt.Errorf("reset event received") } else { continue } default: - val, err := se.RedisClient.BLPop(timeoutCtx, time.Until(se.NextRoundStartTime), "results").Result() + val, err := se.RedisClient.BLPop(timeoutCtx, time.Until(se.NextRoundStartTime), RedisQueueResults).Result() if err == redis.Nil { slog.Warn("Timeout waiting for results", "remaining", runners-i, "collected", i, "expected", runners) results = []checks.Result{} diff --git a/runner/runner.go b/runner/runner.go index 76373ca4..cb20147c 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -54,13 +54,13 @@ func runApp(err error) int { slog.Info("runner started", "runner_id", runnerID, "redis_addr", redisAddr) go func() { - events := rdb.Subscribe(context.Background(), "events") + events := rdb.Subscribe(context.Background(), engine.RedisChannelEvents) defer events.Close() eventsChannel := events.Channel() for msg := range eventsChannel { slog.Info("received message", "payload", msg.Payload) - if msg.Payload == "reset" { + if msg.Payload == engine.RedisEventReset { slog.Info("reset event received, quitting") os.Exit(0) } else { @@ -88,7 +88,7 @@ func runApp(err error) int { func getNextTask(ctx context.Context, rdb *redis.Client) (*engine.Task, error) { // Block until we get a task from the "tasks" list - val, err := rdb.BLPop(ctx, 0, "tasks").Result() + val, err := rdb.BLPop(ctx, 0, engine.RedisQueueTasks).Result() if err != nil { return nil, fmt.Errorf("failed to pop task: %w", err) } @@ -240,7 +240,7 @@ func handleTask(ctx context.Context, rdb *redis.Client, runner checks.Runner, ta return } - if err := rdb.RPush(ctx, "results", resultJSON).Err(); err != nil { + if err := rdb.RPush(ctx, engine.RedisQueueResults, resultJSON).Err(); err != nil { slog.Error("failed to push result to Redis", "error", err) return } diff --git a/www/api/admin.go b/www/api/admin.go index 61ccceab..faac5845 100644 --- a/www/api/admin.go +++ b/www/api/admin.go @@ -131,10 +131,6 @@ func ExportScores(w http.ResponseWriter, r *http.Request) { WriteJSON(w, http.StatusOK, data) } -func ExportConfig(w http.ResponseWriter, r *http.Request) { - -} - func GetActiveTasks(w http.ResponseWriter, r *http.Request) { tasks, err := eng.GetActiveTasks() if err != nil { diff --git a/www/api/announcements.go b/www/api/announcements.go index 5a638735..0fdc656f 100644 --- a/www/api/announcements.go +++ b/www/api/announcements.go @@ -9,6 +9,7 @@ import ( "path" "path/filepath" "quotient/engine/db" + "quotient/www/auth" "slices" "time" @@ -24,8 +25,8 @@ func GetAnnouncements(w http.ResponseWriter, r *http.Request) { // if not admin filter out announcements that are not open yet req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") { - if slices.Contains(req_roles, "red") && !conf.UISettings.ShowAnnouncementsForRedTeam { + if !slices.Contains(req_roles, auth.RoleAdmin) { + if slices.Contains(req_roles, auth.RoleRed) && !conf.UISettings.ShowAnnouncementsForRedTeam { WriteJSON(w, http.StatusForbidden, map[string]any{"error": "Forbidden"}) return } @@ -75,8 +76,8 @@ func DownloadAnnouncementFile(w http.ResponseWriter, r *http.Request) { // if not admin check if the announcement is open req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") { - if slices.Contains(req_roles, "red") && !conf.UISettings.ShowAnnouncementsForRedTeam { + if !slices.Contains(req_roles, auth.RoleAdmin) { + if slices.Contains(req_roles, auth.RoleRed) && !conf.UISettings.ShowAnnouncementsForRedTeam { WriteJSON(w, http.StatusForbidden, map[string]any{"error": "Forbidden"}) return } diff --git a/www/api/authentication.go b/www/api/authentication.go index fd93024d..6b62311b 100644 --- a/www/api/authentication.go +++ b/www/api/authentication.go @@ -11,6 +11,7 @@ import ( "net/http" "os" "quotient/engine/config" + "quotient/www/auth" "strings" "time" @@ -98,14 +99,14 @@ func Login(w http.ResponseWriter, r *http.Request) { } // check credentials - auth, err := auth(form.Username, form.Password) + authResult, err := authenticateUser(form.Username, form.Password) if err != nil { WriteJSON(w, http.StatusUnauthorized, map[string]any{"error": "Incorrect username/password"}) slog.Info("Failed logon", "username", form.Username) return } - cookie, err := CookieEncoder.Encode(COOKIENAME, auth) + cookie, err := CookieEncoder.Encode(COOKIENAME, authResult) if err != nil { WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "Authentication error"}) slog.Error(err.Error()) @@ -194,7 +195,7 @@ func Authenticate(w http.ResponseWriter, r *http.Request) (string, []string) { return username, roles } -func auth(username string, password string) (map[string]any, error) { +func authenticateUser(username string, password string) (map[string]any, error) { for _, admin := range conf.Admin { if username == admin.Name && password == admin.Pw { return map[string]any{"username": username, "authSource": "local"}, nil @@ -315,22 +316,22 @@ func findRolesByUsername(username string, authSource string) ([]string, error) { if authSource == "local" { for _, admin := range conf.Admin { if username == admin.Name { - roles = append(roles, "admin") + roles = append(roles, auth.RoleAdmin) } } for _, red := range conf.Red { if username == red.Name { - roles = append(roles, "red") + roles = append(roles, auth.RoleRed) } } for _, team := range conf.Team { if username == team.Name { - roles = append(roles, "team") + roles = append(roles, auth.RoleTeam) } } for _, inject := range conf.Inject { if username == inject.Name { - roles = append(roles, "inject") + roles = append(roles, auth.RoleInject) } } @@ -371,19 +372,19 @@ func findRolesByUsername(username string, authSource string) ([]string, error) { for _, entry := range sr.Entries { for _, memberOf := range entry.GetAttributeValues("memberOf") { if memberOf == conf.LdapSettings.LdapAdminGroupDn { - roles = append(roles, "admin") + roles = append(roles, auth.RoleAdmin) } if memberOf == conf.LdapSettings.LdapRedGroupDn { - roles = append(roles, "red") + roles = append(roles, auth.RoleRed) } if memberOf == conf.LdapSettings.LdapTeamGroupDn { - roles = append(roles, "team") + roles = append(roles, auth.RoleTeam) } if memberOf == conf.LdapSettings.LdapInjectGroupDn { - roles = append(roles, "inject") + roles = append(roles, auth.RoleInject) } } } diff --git a/www/api/graphs.go b/www/api/graphs.go index 7b31fa3c..043b2e75 100644 --- a/www/api/graphs.go +++ b/www/api/graphs.go @@ -3,6 +3,7 @@ package api import ( "net/http" "quotient/engine/db" + "quotient/www/auth" "slices" ) @@ -221,7 +222,7 @@ func GetUptimeStatus(w http.ResponseWriter, r *http.Request) { func shouldScrub(r *http.Request) bool { if r.Context().Value("roles") != nil { req_roles := r.Context().Value("roles").([]string) - if slices.Contains(req_roles, "admin") { + if slices.Contains(req_roles, auth.RoleAdmin) { return false } } diff --git a/www/api/injects.go b/www/api/injects.go index 9294541a..600279f4 100644 --- a/www/api/injects.go +++ b/www/api/injects.go @@ -12,6 +12,8 @@ import ( "slices" "time" + "quotient/www/auth" + "gorm.io/gorm" ) @@ -24,7 +26,7 @@ func GetInjects(w http.ResponseWriter, r *http.Request) { // if not admin filter out injects that are not open yet req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") { + if !slices.Contains(req_roles, auth.RoleAdmin) { openInjects := make([]db.InjectSchema, 0) for _, a := range data { if time.Now().After(a.OpenTime) { @@ -40,7 +42,7 @@ func GetInjects(w http.ResponseWriter, r *http.Request) { WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": err.Error()}) return } - if !slices.Contains(req_roles, "admin") { + if !slices.Contains(req_roles, auth.RoleAdmin) { var mySubmissions []db.SubmissionSchema for _, submission := range data[i].Submissions { if submission.Team.Name == r.Context().Value("username") { @@ -88,7 +90,7 @@ func DownloadInjectFile(w http.ResponseWriter, r *http.Request) { // if not admin, check if the inject is open req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") && time.Now().Before(inject.OpenTime) { + if !slices.Contains(req_roles, auth.RoleAdmin) && time.Now().Before(inject.OpenTime) { WriteJSON(w, http.StatusNotFound, map[string]any{"error": "Inject not found"}) return } diff --git a/www/api/oidc.go b/www/api/oidc.go index 41e25c8a..d1d937e3 100644 --- a/www/api/oidc.go +++ b/www/api/oidc.go @@ -10,6 +10,7 @@ import ( "fmt" "log/slog" "net/http" + "quotient/www/auth" "slices" "strings" "sync" @@ -338,7 +339,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { // Redirect to appropriate dashboard redirectURL := "/announcements" - if slices.Contains(roles, "red") { + if slices.Contains(roles, auth.RoleRed) { redirectURL = "/graphs" } @@ -386,7 +387,7 @@ func mapGroupsToRoles(groups []string) []string { // Check admin groups for _, adminGroup := range conf.OIDCSettings.OIDCAdminGroups { if matchesGroup(groups, adminGroup) { - roles = append(roles, "admin") + roles = append(roles, auth.RoleAdmin) break } } @@ -394,7 +395,7 @@ func mapGroupsToRoles(groups []string) []string { // Check inject groups for _, injectGroup := range conf.OIDCSettings.OIDCInjectGroups { if matchesGroup(groups, injectGroup) { - roles = append(roles, "inject") + roles = append(roles, auth.RoleInject) break } } @@ -402,7 +403,7 @@ func mapGroupsToRoles(groups []string) []string { // Check red groups for _, redGroup := range conf.OIDCSettings.OIDCRedGroups { if matchesGroup(groups, redGroup) { - roles = append(roles, "red") + roles = append(roles, auth.RoleRed) break } } @@ -410,7 +411,7 @@ func mapGroupsToRoles(groups []string) []string { // Check team groups for _, teamGroup := range conf.OIDCSettings.OIDCTeamGroups { if matchesGroup(groups, teamGroup) { - roles = append(roles, "team") + roles = append(roles, auth.RoleTeam) break } } @@ -437,21 +438,21 @@ func matchesGroup(userGroups []string, configGroup string) bool { func getRefreshTokenExpiry(roles []string) int { defaultExpiry := 86400 // 24 hours in seconds - if slices.Contains(roles, "admin") { + if slices.Contains(roles, auth.RoleAdmin) { expiry := conf.OIDCSettings.OIDCRefreshTokenExpiryAdmin if expiry == 0 { return defaultExpiry } return expiry } - if slices.Contains(roles, "red") { + if slices.Contains(roles, auth.RoleRed) { expiry := conf.OIDCSettings.OIDCRefreshTokenExpiryRed if expiry == 0 { return defaultExpiry } return expiry } - if slices.Contains(roles, "inject") { + if slices.Contains(roles, auth.RoleInject) { expiry := conf.OIDCSettings.OIDCRefreshTokenExpiryInject if expiry == 0 { return defaultExpiry diff --git a/www/api/pcrs.go b/www/api/pcrs.go index da8074ff..557af129 100644 --- a/www/api/pcrs.go +++ b/www/api/pcrs.go @@ -6,13 +6,14 @@ import ( "log/slog" "net/http" "quotient/engine/db" + "quotient/www/auth" "slices" "strconv" ) func GetCredlists(w http.ResponseWriter, r *http.Request) { req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") && !conf.MiscSettings.EasyPCR { + if !slices.Contains(req_roles, auth.RoleAdmin) && !conf.MiscSettings.EasyPCR { WriteJSON(w, http.StatusForbidden, map[string]any{"error": "PCR self service not allowed"}) return } @@ -102,7 +103,7 @@ func CreatePcr(w http.ResponseWriter, r *http.Request) { } req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") { + if !slices.Contains(req_roles, auth.RoleAdmin) { if conf.MiscSettings.EasyPCR { me, err := db.GetTeamByUsername(r.Context().Value("username").(string)) if err != nil { @@ -158,7 +159,7 @@ func ResetPcr(w http.ResponseWriter, r *http.Request) { return } req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") { + if !slices.Contains(req_roles, auth.RoleAdmin) { if conf.MiscSettings.EasyPCR { me, err := db.GetTeamByUsername(r.Context().Value("username").(string)) if err != nil { diff --git a/www/api/red.go b/www/api/red.go index 33af6e10..3616d20c 100644 --- a/www/api/red.go +++ b/www/api/red.go @@ -166,10 +166,6 @@ func CreateVector(w http.ResponseWriter, r *http.Request) { WriteJSON(w, http.StatusCreated, map[string]any{"message": "Vector created successfully"}) } -func EditVector(w http.ResponseWriter, r *http.Request) { - -} - func CreateAttack(w http.ResponseWriter, r *http.Request) { if err := r.ParseMultipartForm(10 << 20); err != nil { WriteJSON(w, http.StatusBadRequest, map[string]any{"error": "Failed to parse multipart form"}) @@ -246,6 +242,3 @@ func CreateAttack(w http.ResponseWriter, r *http.Request) { WriteJSON(w, http.StatusCreated, map[string]any{"message": "Attack created successfully"}) } - -func EditAttack(w http.ResponseWriter, r *http.Request) { -} diff --git a/www/api/services.go b/www/api/services.go index 32707ff6..2a702e27 100644 --- a/www/api/services.go +++ b/www/api/services.go @@ -5,6 +5,7 @@ import ( "log/slog" "net/http" "quotient/engine/db" + "quotient/www/auth" "slices" "strconv" "strings" @@ -17,7 +18,7 @@ func GetTeams(w http.ResponseWriter, r *http.Request) { return } req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") { + if !slices.Contains(req_roles, auth.RoleAdmin) { username := r.Context().Value("username").(string) // Check if this is an OIDC user @@ -127,7 +128,7 @@ func GetTeamSummary(w http.ResponseWriter, r *http.Request) { teamID := uint(temp) req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") { + if !slices.Contains(req_roles, auth.RoleAdmin) { myTeamID, err := getUserTeamID(r.Context().Value("username").(string)) if err != nil { slog.Error("Failed to get user's team", "username", r.Context().Value("username").(string), "err", err) @@ -186,7 +187,7 @@ func GetServiceAll(w http.ResponseWriter, r *http.Request) { serviceID := r.PathValue("service_name") req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") { + if !slices.Contains(req_roles, auth.RoleAdmin) { myTeamID, err := getUserTeamID(r.Context().Value("username").(string)) if err != nil { slog.Error("Failed to get user's team", "username", r.Context().Value("username").(string), "err", err) @@ -207,7 +208,7 @@ func GetServiceAll(w http.ResponseWriter, r *http.Request) { // Remove debug and error fields for non-admins // Red team never sees credentials, blue team only if ShowDebugToBlueTeam is enabled - if !slices.Contains(req_roles, "admin") && (slices.Contains(req_roles, "red") || !conf.MiscSettings.ShowDebugToBlueTeam) { + if !slices.Contains(req_roles, auth.RoleAdmin) && (slices.Contains(req_roles, auth.RoleRed) || !conf.MiscSettings.ShowDebugToBlueTeam) { for i := range service { service[i].Debug = "" service[i].Error = "" @@ -216,15 +217,3 @@ func GetServiceAll(w http.ResponseWriter, r *http.Request) { WriteJSON(w, http.StatusOK, service) } - -func CreateService(w http.ResponseWriter, r *http.Request) { - -} - -func UpdateService(w http.ResponseWriter, r *http.Request) { - -} - -func DeleteService(w http.ResponseWriter, r *http.Request) { - -} diff --git a/www/api/submissions.go b/www/api/submissions.go index d9784644..bca293e1 100644 --- a/www/api/submissions.go +++ b/www/api/submissions.go @@ -9,6 +9,7 @@ import ( "net/http" "path/filepath" "quotient/engine/db" + "quotient/www/auth" "slices" "strconv" "time" @@ -137,7 +138,7 @@ func DownloadSubmissionFile(w http.ResponseWriter, r *http.Request) { } req_roles := r.Context().Value("roles").([]string) - if !slices.Contains(req_roles, "admin") && !slices.Contains(req_roles, "inject") && team.ID != teamID { + if !slices.Contains(req_roles, auth.RoleAdmin) && !slices.Contains(req_roles, auth.RoleInject) && team.ID != teamID { WriteJSON(w, http.StatusForbidden, map[string]any{"error": "Forbidden"}) return } diff --git a/www/auth/roles.go b/www/auth/roles.go new file mode 100644 index 00000000..6528e8f6 --- /dev/null +++ b/www/auth/roles.go @@ -0,0 +1,9 @@ +package auth + +const ( + RoleAdmin = "admin" + RoleTeam = "team" + RoleRed = "red" + RoleInject = "inject" + RoleAnonymous = "anonymous" +) diff --git a/www/middleware/authentication.go b/www/middleware/authentication.go index 1fedebe6..694ac219 100644 --- a/www/middleware/authentication.go +++ b/www/middleware/authentication.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "quotient/www/api" + "quotient/www/auth" "slices" "strings" ) @@ -21,7 +22,7 @@ func Authentication(roles ...string) Middleware { username, user_roles := api.Authenticate(w, r) if username == "" { - if slices.Contains(roles, "anonymous") { + if slices.Contains(roles, auth.RoleAnonymous) { next(w, r) return } diff --git a/www/router.go b/www/router.go index ea0660b3..1ee136e2 100644 --- a/www/router.go +++ b/www/router.go @@ -10,6 +10,7 @@ import ( "quotient/engine" "quotient/engine/config" "quotient/www/api" + "quotient/www/auth" "quotient/www/middleware" ) @@ -45,7 +46,7 @@ func (router *Router) Start() { mux.Handle("/static/assets/", http.StripPrefix("/static/assets/", http.FileServer(http.Dir("./static/assets")))) - UNAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Cors, middleware.Authentication("anonymous", "team", "admin", "red")) + UNAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Cors, middleware.Authentication(auth.RoleAnonymous, auth.RoleTeam, auth.RoleAdmin, auth.RoleRed)) // public API routes mux.HandleFunc("POST /api/login", api.Login) @@ -69,7 +70,7 @@ func (router *Router) Start() { | | ******************************************/ - ALLAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Authentication("team", "admin", "red", "inject")) + ALLAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Authentication(auth.RoleTeam, auth.RoleAdmin, auth.RoleRed, auth.RoleInject)) // general auth API routes mux.HandleFunc("GET /api/logout", ALLAUTH(api.Logout)) @@ -79,15 +80,13 @@ func (router *Router) Start() { // general auth WWW routes mux.HandleFunc("GET /logout", ALLAUTH(router.LogoutPage)) mux.HandleFunc("GET /announcements", ALLAUTH(router.AnnouncementsPage)) - // mux.HandleFunc("GET /graphs", ALLAUTH(router.GraphPage)) - /****************************************** | | | TEAM ROUTES | | | ******************************************/ - TEAMAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Authentication("team", "admin", "inject")) + TEAMAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Authentication(auth.RoleTeam, auth.RoleAdmin, auth.RoleInject)) // team auth API routes mux.HandleFunc("GET /api/teams", TEAMAUTH(api.GetTeams)) mux.HandleFunc("GET /api/metadata", TEAMAUTH(api.GetMetadata)) @@ -113,22 +112,15 @@ func (router *Router) Start() { | RED ROUTES | | | ******************************************/ - REDAUTH := middleware.MiddlewareChain(middleware.Authentication("red", "admin")) + REDAUTH := middleware.MiddlewareChain(middleware.Authentication(auth.RoleRed, auth.RoleAdmin)) // red auth API routes mux.HandleFunc("GET /api/red", REDAUTH(api.GetRed)) - // mux.HandleFunc("POST /api/red/vuln", REDAUTH(api.CreatePcr)) mux.HandleFunc("POST /api/red/box", REDAUTH(api.CreateBox)) mux.HandleFunc("POST /api/red/vector", REDAUTH(api.CreateVector)) mux.HandleFunc("POST /api/red/attack", REDAUTH(api.CreateAttack)) mux.HandleFunc("POST /api/red/box/{id}", REDAUTH(api.EditBox)) - mux.HandleFunc("POST /api/red/vector/{id}", REDAUTH(api.EditVector)) - mux.HandleFunc("POST /api/red/attack/{id}", REDAUTH(api.EditAttack)) - - // mux.HandleFunc("DELETE /api/red/box/{id}", REDAUTH(api.DeleteBox)) - // mux.HandleFunc("DELETE /api/red/vector/{id}", REDAUTH(api.DeleteVector)) - // mux.HandleFunc("DELETE /api/red/attack/{id}", REDAUTH(api.DeleteAttack)) // red auth WWW routes mux.HandleFunc("GET /red", REDAUTH(router.RedPage)) @@ -139,7 +131,7 @@ func (router *Router) Start() { | | ******************************************/ - INJECTAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Authentication("admin", "inject")) + INJECTAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Authentication(auth.RoleAdmin, auth.RoleInject)) // admin auth API routes mux.HandleFunc("POST /api/announcements/create", INJECTAUTH(api.CreateAnnouncement)) mux.HandleFunc("POST /api/announcements/{id}", INJECTAUTH(api.UpdateAnnouncement)) @@ -150,11 +142,7 @@ func (router *Router) Start() { mux.HandleFunc("DELETE /api/injects/{id}", INJECTAUTH(api.DeleteInject)) mux.HandleFunc("GET /api/injects/{id}/submissions/download", INJECTAUTH(api.DownloadAllSubmissions)) - // router.HandleFunc("POST /api/engine/service/create", ADMINAUTH(api.CreateService)) - // router.HandleFunc("POST /api/engine/service/update", ADMINAUTH(api.UpdateService)) - // router.HandleFunc("DELETE /api/engine/service/delete", ADMINAUTH(api.DeleteService)) - - ADMINAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Authentication("admin")) + ADMINAUTH := middleware.MiddlewareChain(middleware.Logging, middleware.Authentication(auth.RoleAdmin)) mux.HandleFunc("POST /api/engine/pause", ADMINAUTH(api.PauseEngine)) mux.HandleFunc("GET /api/engine/reset", ADMINAUTH(api.ResetScores)) mux.HandleFunc("GET /api/engine", ADMINAUTH(api.GetEngine)) @@ -165,7 +153,6 @@ func (router *Router) Start() { mux.HandleFunc("POST /api/admin/teamchecks", ADMINAUTH(api.UpdateTeamChecks)) mux.HandleFunc("GET /api/engine/export/scores", ADMINAUTH(api.ExportScores)) - mux.HandleFunc("GET /api/engine/export/config", ADMINAUTH(api.ExportConfig)) // admin-only PCR routes mux.HandleFunc("GET /api/pcrs", ADMINAUTH(api.GetPcrs))