Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions engine/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package engine

const (
RedisEventReset = "reset"
RedisEventRoundFinish = "round_finish"
RedisQueueTasks = "tasks"
RedisQueueResults = "results"
RedisChannelEvents = "events"
)
4 changes: 0 additions & 4 deletions engine/db/services.go

This file was deleted.

26 changes: 13 additions & 13 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (se *ScoringEngine) Start() {
Password: os.Getenv("REDIS_PASSWORD"),
})

events := rdb.Subscribe(context.Background(), "events")
events := rdb.Subscribe(context.Background(), RedisChannelEvents)
defer events.Close()
eventsChannel := events.Channel()

Expand All @@ -130,7 +130,7 @@ func (se *ScoringEngine) Start() {
select {
case msg := <-eventsChannel:
slog.Info("Received message", "message", msg.Payload)
if msg.Payload == "reset" {
if msg.Payload == RedisEventReset {
slog.Info("Engine loop reset event received while waiting, quitting...")
return
} else {
Expand Down Expand Up @@ -159,7 +159,7 @@ func (se *ScoringEngine) Start() {
slog.Info(fmt.Sprintf("Round %d complete", se.CurrentRound))
se.CurrentRound++

se.RedisClient.Publish(context.Background(), "events", "round_finish")
se.RedisClient.Publish(context.Background(), RedisChannelEvents, RedisEventRoundFinish)
slog.Info(fmt.Sprintf("Round %d will start in %s, sleeping...", se.CurrentRound, time.Until(se.NextRoundStartTime).String()))
time.Sleep(time.Until(se.NextRoundStartTime))
}
Expand All @@ -185,13 +185,13 @@ func waitForReset() {
Password: os.Getenv("REDIS_PASSWORD"),
})

events := rdb.Subscribe(context.Background(), "events")
events := rdb.Subscribe(context.Background(), RedisChannelEvents)
defer events.Close()
eventsChannel := events.Channel()

for msg := range eventsChannel {
slog.Info("Received message", "message", msg.Payload)
if msg.Payload == "reset" {
if msg.Payload == RedisEventReset {
slog.Info("Reset event received, quitting...")
return
} else {
Expand Down Expand Up @@ -311,7 +311,7 @@ func (se *ScoringEngine) ResetScores() error {

// Flush Redis queues
ctx := context.Background()
keysToDelete := []string{"tasks", "results"}
keysToDelete := []string{RedisQueueTasks, RedisQueueResults}
for _, key := range keysToDelete {
if err := se.RedisClient.Del(ctx, key).Err(); err != nil {
slog.Error("Failed to clear Redis queue", "queue", key, "error", err)
Expand All @@ -320,7 +320,7 @@ func (se *ScoringEngine) ResetScores() error {
}

// Reset engine state
se.RedisClient.Publish(context.Background(), "events", "reset")
se.RedisClient.Publish(context.Background(), RedisChannelEvents, RedisEventReset)

se.CurrentRound = 1
se.uptimeMu.Lock()
Expand Down Expand Up @@ -361,7 +361,7 @@ func (se *ScoringEngine) rvb() error {
Password: os.Getenv("REDIS_PASSWORD"),
})

events := rdb.Subscribe(context.Background(), "events")
events := rdb.Subscribe(context.Background(), RedisChannelEvents)
defer events.Close()
eventsChannel := events.Channel()
//
Expand Down Expand Up @@ -389,10 +389,10 @@ func (se *ScoringEngine) rvb() error {
}

// Clear any stale tasks from previous rounds before enqueuing new ones
staleTasks := se.RedisClient.LLen(ctx, "tasks").Val()
staleTasks := se.RedisClient.LLen(ctx, RedisQueueTasks).Val()
if staleTasks > 0 {
slog.Warn("Clearing stale tasks from queue", "count", staleTasks, "round", se.CurrentRound)
se.RedisClient.Del(ctx, "tasks")
se.RedisClient.Del(ctx, RedisQueueTasks)
}

// 1) Enqueue
Expand Down Expand Up @@ -457,7 +457,7 @@ func (se *ScoringEngine) rvb() error {
slog.Error("failed to marshal service task", "error", err)
continue
}
se.RedisClient.RPush(ctx, "tasks", payload)
se.RedisClient.RPush(ctx, RedisQueueTasks, payload)
runners++
}
}
Expand All @@ -474,14 +474,14 @@ COLLECTION:
select {
case msg := <-eventsChannel:
slog.Info("Received message", "message", msg.Payload)
if msg.Payload == "reset" {
if msg.Payload == RedisEventReset {
slog.Info("Reset event received, quitting...")
return fmt.Errorf("reset event received")
} else {
continue
}
default:
val, err := se.RedisClient.BLPop(timeoutCtx, time.Until(se.NextRoundStartTime), "results").Result()
val, err := se.RedisClient.BLPop(timeoutCtx, time.Until(se.NextRoundStartTime), RedisQueueResults).Result()
if err == redis.Nil {
slog.Warn("Timeout waiting for results", "remaining", runners-i, "collected", i, "expected", runners)
results = []checks.Result{}
Expand Down
8 changes: 4 additions & 4 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ func runApp(err error) int {
slog.Info("runner started", "runner_id", runnerID, "redis_addr", redisAddr)

go func() {
events := rdb.Subscribe(context.Background(), "events")
events := rdb.Subscribe(context.Background(), engine.RedisChannelEvents)
defer events.Close()
eventsChannel := events.Channel()

for msg := range eventsChannel {
slog.Info("received message", "payload", msg.Payload)
if msg.Payload == "reset" {
if msg.Payload == engine.RedisEventReset {
slog.Info("reset event received, quitting")
os.Exit(0)
} else {
Expand Down Expand Up @@ -88,7 +88,7 @@ func runApp(err error) int {

func getNextTask(ctx context.Context, rdb *redis.Client) (*engine.Task, error) {
// Block until we get a task from the "tasks" list
val, err := rdb.BLPop(ctx, 0, "tasks").Result()
val, err := rdb.BLPop(ctx, 0, engine.RedisQueueTasks).Result()
if err != nil {
return nil, fmt.Errorf("failed to pop task: %w", err)
}
Expand Down Expand Up @@ -240,7 +240,7 @@ func handleTask(ctx context.Context, rdb *redis.Client, runner checks.Runner, ta
return
}

if err := rdb.RPush(ctx, "results", resultJSON).Err(); err != nil {
if err := rdb.RPush(ctx, engine.RedisQueueResults, resultJSON).Err(); err != nil {
slog.Error("failed to push result to Redis", "error", err)
return
}
Expand Down
4 changes: 0 additions & 4 deletions www/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ func ExportScores(w http.ResponseWriter, r *http.Request) {
WriteJSON(w, http.StatusOK, data)
}

func ExportConfig(w http.ResponseWriter, r *http.Request) {

}

func GetActiveTasks(w http.ResponseWriter, r *http.Request) {
tasks, err := eng.GetActiveTasks()
if err != nil {
Expand Down
9 changes: 5 additions & 4 deletions www/api/announcements.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path"
"path/filepath"
"quotient/engine/db"
"quotient/www/auth"
"slices"
"time"

Expand All @@ -24,8 +25,8 @@ func GetAnnouncements(w http.ResponseWriter, r *http.Request) {

// if not admin filter out announcements that are not open yet
req_roles := r.Context().Value("roles").([]string)
if !slices.Contains(req_roles, "admin") {
if slices.Contains(req_roles, "red") && !conf.UISettings.ShowAnnouncementsForRedTeam {
if !slices.Contains(req_roles, auth.RoleAdmin) {
if slices.Contains(req_roles, auth.RoleRed) && !conf.UISettings.ShowAnnouncementsForRedTeam {
WriteJSON(w, http.StatusForbidden, map[string]any{"error": "Forbidden"})
return
}
Expand Down Expand Up @@ -75,8 +76,8 @@ func DownloadAnnouncementFile(w http.ResponseWriter, r *http.Request) {

// if not admin check if the announcement is open
req_roles := r.Context().Value("roles").([]string)
if !slices.Contains(req_roles, "admin") {
if slices.Contains(req_roles, "red") && !conf.UISettings.ShowAnnouncementsForRedTeam {
if !slices.Contains(req_roles, auth.RoleAdmin) {
if slices.Contains(req_roles, auth.RoleRed) && !conf.UISettings.ShowAnnouncementsForRedTeam {
WriteJSON(w, http.StatusForbidden, map[string]any{"error": "Forbidden"})
return
}
Expand Down
23 changes: 12 additions & 11 deletions www/api/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"os"
"quotient/engine/config"
"quotient/www/auth"
"strings"
"time"

Expand Down Expand Up @@ -98,14 +99,14 @@ func Login(w http.ResponseWriter, r *http.Request) {
}

// check credentials
auth, err := auth(form.Username, form.Password)
authResult, err := authenticateUser(form.Username, form.Password)
if err != nil {
WriteJSON(w, http.StatusUnauthorized, map[string]any{"error": "Incorrect username/password"})
slog.Info("Failed logon", "username", form.Username)
return
}

cookie, err := CookieEncoder.Encode(COOKIENAME, auth)
cookie, err := CookieEncoder.Encode(COOKIENAME, authResult)
if err != nil {
WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "Authentication error"})
slog.Error(err.Error())
Expand Down Expand Up @@ -194,7 +195,7 @@ func Authenticate(w http.ResponseWriter, r *http.Request) (string, []string) {
return username, roles
}

func auth(username string, password string) (map[string]any, error) {
func authenticateUser(username string, password string) (map[string]any, error) {
for _, admin := range conf.Admin {
if username == admin.Name && password == admin.Pw {
return map[string]any{"username": username, "authSource": "local"}, nil
Expand Down Expand Up @@ -315,22 +316,22 @@ func findRolesByUsername(username string, authSource string) ([]string, error) {
if authSource == "local" {
for _, admin := range conf.Admin {
if username == admin.Name {
roles = append(roles, "admin")
roles = append(roles, auth.RoleAdmin)
}
}
for _, red := range conf.Red {
if username == red.Name {
roles = append(roles, "red")
roles = append(roles, auth.RoleRed)
}
}
for _, team := range conf.Team {
if username == team.Name {
roles = append(roles, "team")
roles = append(roles, auth.RoleTeam)
}
}
for _, inject := range conf.Inject {
if username == inject.Name {
roles = append(roles, "inject")
roles = append(roles, auth.RoleInject)
}
}

Expand Down Expand Up @@ -371,19 +372,19 @@ func findRolesByUsername(username string, authSource string) ([]string, error) {
for _, entry := range sr.Entries {
for _, memberOf := range entry.GetAttributeValues("memberOf") {
if memberOf == conf.LdapSettings.LdapAdminGroupDn {
roles = append(roles, "admin")
roles = append(roles, auth.RoleAdmin)
}

if memberOf == conf.LdapSettings.LdapRedGroupDn {
roles = append(roles, "red")
roles = append(roles, auth.RoleRed)
}

if memberOf == conf.LdapSettings.LdapTeamGroupDn {
roles = append(roles, "team")
roles = append(roles, auth.RoleTeam)
}

if memberOf == conf.LdapSettings.LdapInjectGroupDn {
roles = append(roles, "inject")
roles = append(roles, auth.RoleInject)
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion www/api/graphs.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"net/http"
"quotient/engine/db"
"quotient/www/auth"
"slices"
)

Expand Down Expand Up @@ -221,7 +222,7 @@ func GetUptimeStatus(w http.ResponseWriter, r *http.Request) {
func shouldScrub(r *http.Request) bool {
if r.Context().Value("roles") != nil {
req_roles := r.Context().Value("roles").([]string)
if slices.Contains(req_roles, "admin") {
if slices.Contains(req_roles, auth.RoleAdmin) {
return false
}
}
Expand Down
8 changes: 5 additions & 3 deletions www/api/injects.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"slices"
"time"

"quotient/www/auth"

"gorm.io/gorm"
)

Expand All @@ -24,7 +26,7 @@ func GetInjects(w http.ResponseWriter, r *http.Request) {

// if not admin filter out injects that are not open yet
req_roles := r.Context().Value("roles").([]string)
if !slices.Contains(req_roles, "admin") {
if !slices.Contains(req_roles, auth.RoleAdmin) {
openInjects := make([]db.InjectSchema, 0)
for _, a := range data {
if time.Now().After(a.OpenTime) {
Expand All @@ -40,7 +42,7 @@ func GetInjects(w http.ResponseWriter, r *http.Request) {
WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": err.Error()})
return
}
if !slices.Contains(req_roles, "admin") {
if !slices.Contains(req_roles, auth.RoleAdmin) {
var mySubmissions []db.SubmissionSchema
for _, submission := range data[i].Submissions {
if submission.Team.Name == r.Context().Value("username") {
Expand Down Expand Up @@ -88,7 +90,7 @@ func DownloadInjectFile(w http.ResponseWriter, r *http.Request) {

// if not admin, check if the inject is open
req_roles := r.Context().Value("roles").([]string)
if !slices.Contains(req_roles, "admin") && time.Now().Before(inject.OpenTime) {
if !slices.Contains(req_roles, auth.RoleAdmin) && time.Now().Before(inject.OpenTime) {
WriteJSON(w, http.StatusNotFound, map[string]any{"error": "Inject not found"})
return
}
Expand Down
Loading
Loading