diff --git a/.gitignore b/.gitignore index d0aa5e7..4395d0e 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ coverage.html # Editor/IDE # .idea/ # .vscode/ +cmd/argus/argus diff --git a/cmd/argus/main.go b/cmd/argus/main.go index 8aa7600..f707c94 100644 --- a/cmd/argus/main.go +++ b/cmd/argus/main.go @@ -90,7 +90,9 @@ func main() { "status": "healthy", } - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + slog.Error("Failed to encode health response", "error", err) + } }) // Version endpoint @@ -105,7 +107,9 @@ func main() { "service": "argus", } - json.NewEncoder(w).Encode(response) + if err := json.NewEncoder(w).Encode(response); err != nil { + slog.Error("Failed to encode version response", "error", err) + } }) // Initialize v1 API with database-agnostic repository diff --git a/docs/API.md b/docs/API.md index fedae1a..adbd8ca 100644 --- a/docs/API.md +++ b/docs/API.md @@ -32,12 +32,14 @@ Complete API reference for integrating Argus into your microservices architectur | `actorId` | string | Required | Actor identifier (email, UUID, service name) | | `targetType` | string | Required | Target type: `SERVICE` or `RESOURCE` | | `traceId` | string (UUID) | Optional | Trace ID for distributed tracing | -| `eventType` | string | Optional | Custom event type (e.g., `MANAGEMENT_EVENT`) | -| `eventAction` | string | Optional | Action: `CREATE`, `READ`, `UPDATE`, `DELETE` | +| `eventType` | string | Optional | Custom event type (e.g. `MANAGEMENT_EVENT`) | +| `action` | string | Required | Action: `CREATE`, `READ`, `UPDATE`, `DELETE` | | `targetId` | string | Optional | Target identifier | -| `requestMetadata` | object | Optional | Request payload (without PII/sensitive data) | -| `responseMetadata` | object | Optional | Response or error details | -| `additionalMetadata` | object | Optional | Additional context-specific data | +| `metadata` | object | Optional | Consolidated context-specific data | +| `message` | string | Optional | Raw message or payload for signing (Base64) | +| `signature` | string | Optional | Base64 encoded digital signature | +| `signatureAlgorithm` | string | Optional | Signature algorithm used (e.g. `RS256`, `EdDSA`) | +| `publicKeyId` | string | Optional | Identifier for the key used to sign the event | **Example Request:** @@ -48,14 +50,16 @@ curl -X POST http://localhost:3001/api/audit-logs \ "traceId": "550e8400-e29b-41d4-a716-446655440000", "timestamp": "2024-01-20T10:00:00Z", "eventType": "MANAGEMENT_EVENT", - "eventAction": "READ", + "action": "READ", "status": "SUCCESS", "actorType": "SERVICE", "actorId": "my-service", "targetType": "SERVICE", "targetId": "target-service", - "requestMetadata": {"schemaId": "schema-123"}, - "responseMetadata": {"decision": "ALLOWED"} + "metadata": {"schemaId": "schema-123", "decision": "ALLOWED"}, + "signature": "base64-encoded-signature", + "signatureAlgorithm": "RS256", + "publicKeyId": "nsw-key-1" }' ``` @@ -72,6 +76,8 @@ curl -X POST http://localhost:3001/api/audit-logs \ "actorId": "my-service", "targetType": "SERVICE", "targetId": "target-service", + "action": "READ", + "metadata": {"schemaId": "schema-123", "decision": "ALLOWED"}, "createdAt": "2024-01-20T10:00:00.123456Z" } ``` @@ -96,7 +102,7 @@ curl -X POST http://localhost:3001/api/audit-logs \ | ------------- | ------------- | -------- | ------- | ----------------------------------------- | | `traceId` | string (UUID) | Optional | - | Filter by trace ID | | `eventType` | string | Optional | - | Filter by event type | -| `eventAction` | string | Optional | - | Filter by event action | +| `action` | string | Optional | - | Filter by event action | | `status` | string | Optional | - | Filter by status (`SUCCESS` or `FAILURE`) | | `limit` | integer | Optional | 100 | Max results per page (1-1000) | | `offset` | integer | Optional | 0 | Number of results to skip | @@ -132,6 +138,8 @@ curl http://localhost:3001/api/audit-logs?eventType=MANAGEMENT_EVENT&status=SUCC "actorId": "my-service", "targetType": "SERVICE", "targetId": "target-service", + "action": "READ", + "metadata": {"schemaId": "schema-123", "decision": "ALLOWED"}, "createdAt": "2024-01-20T10:00:00.123456Z" } ], @@ -245,11 +253,9 @@ Always use RFC3339 format (ISO 8601): ```json { - "requestMetadata": { + "metadata": { "schemaId": "schema-123", - "requestedFields": ["name", "address"] - }, - "responseMetadata": { + "requestedFields": ["name", "address"], "decision": "ALLOWED", "fieldsReturned": 2 } @@ -263,7 +269,7 @@ Always log failed operations: ```json { "status": "FAILURE", - "responseMetadata": { + "metadata": { "error": "operation_failed", "errorMessage": "Resource not found", "errorCode": "404" diff --git a/internal/api/v1/database/database.go b/internal/api/v1/database/database.go index 4e28f22..39ed8cc 100644 --- a/internal/api/v1/database/database.go +++ b/internal/api/v1/database/database.go @@ -21,10 +21,10 @@ type AuditRepository interface { // AuditLogFilters represents query filters for retrieving audit logs type AuditLogFilters struct { - TraceID *string - EventType *string - EventAction *string - Status *string - Limit int - Offset int + TraceID *string + EventType *string + Action *string + Status *string + Limit int + Offset int } diff --git a/internal/api/v1/database/gorm_repository.go b/internal/api/v1/database/gorm_repository.go index fcc24e0..44a3298 100644 --- a/internal/api/v1/database/gorm_repository.go +++ b/internal/api/v1/database/gorm_repository.go @@ -16,6 +16,9 @@ type GormRepository struct { // NewGormRepository creates a new repository (works with SQLite or PostgreSQL) func NewGormRepository(db *gorm.DB) *GormRepository { + // Run pre-migration helper to handle breaking changes safely + runPreMigration(db) + // Auto-migrate the audit_logs table if err := db.AutoMigrate(&models.AuditLog{}); err != nil { // Log migration error but don't fail service creation @@ -25,6 +28,30 @@ func NewGormRepository(db *gorm.DB) *GormRepository { return &GormRepository{db: db} } +// runPreMigration handles breaking schema changes before GORM's AutoMigrate +func runPreMigration(db *gorm.DB) { + if !db.Migrator().HasTable("audit_logs") { + return + } + + // Handle column rename: event_action -> action + if db.Migrator().HasColumn("audit_logs", "event_action") && !db.Migrator().HasColumn("audit_logs", "action") { + slog.Info("Renaming column event_action to action in audit_logs table") + if err := db.Migrator().RenameColumn("audit_logs", "event_action", "action"); err != nil { + slog.Error("Failed to rename column event_action to action", "error", err) + } + } + + // Handle NULLs for non-nullable conversion to avoid migration failures + slog.Info("Ensuring no NULL values in event_type and action columns before applying NOT NULL constraint") + if err := db.Exec("UPDATE audit_logs SET event_type = '' WHERE event_type IS NULL").Error; err != nil { + slog.Warn("Failed to update NULL event_type values", "error", err) + } + if err := db.Exec("UPDATE audit_logs SET action = '' WHERE action IS NULL").Error; err != nil { + slog.Warn("Failed to update NULL action values", "error", err) + } +} + // CreateAuditLog creates a new audit log entry func (r *GormRepository) CreateAuditLog(ctx context.Context, log *models.AuditLog) (*models.AuditLog, error) { result := r.db.WithContext(ctx).Create(log) @@ -64,8 +91,8 @@ func (r *GormRepository) GetAuditLogs(ctx context.Context, filters *AuditLogFilt if filters.EventType != nil && *filters.EventType != "" { query = query.Where("event_type = ?", *filters.EventType) } - if filters.EventAction != nil && *filters.EventAction != "" { - query = query.Where("event_action = ?", *filters.EventAction) + if filters.Action != nil && *filters.Action != "" { + query = query.Where("action = ?", *filters.Action) } if filters.Status != nil && *filters.Status != "" { query = query.Where("status = ?", *filters.Status) diff --git a/internal/api/v1/handlers/audit_handler.go b/internal/api/v1/handlers/audit_handler.go index 99c74b2..b6bb6a6 100644 --- a/internal/api/v1/handlers/audit_handler.go +++ b/internal/api/v1/handlers/audit_handler.go @@ -34,6 +34,14 @@ func (h *AuditHandler) CreateAuditLog(w http.ResponseWriter, r *http.Request) { return } + // Validation for signed events + if req.Signature != "" || req.PublicKeyID != "" || req.SignatureAlgorithm != "" { + if req.Signature == "" || req.PublicKeyID == "" || req.SignatureAlgorithm == "" { + utils.RespondWithError(w, http.StatusBadRequest, "Invalid signed event: signature, publicKeyId, and signatureAlgorithm must all be provided if any are present", nil) + return + } + } + // Validation is handled by the service layer (auditLog.Validate()) auditLog, err := h.service.CreateAuditLog(r.Context(), &req) if err != nil { @@ -94,7 +102,7 @@ func (h *AuditHandler) GetAuditLogs(w http.ResponseWriter, r *http.Request) { logs, total, err := h.service.GetAuditLogs(r.Context(), traceIDPtr, eventTypePtr, limit, offset) if err != nil { - // Check if it's a validation error (e.g., invalid traceId format from service layer) + // Check if it's a validation error (e.g. invalid traceId format from service layer) if services.IsValidationError(err) { utils.RespondWithError(w, http.StatusBadRequest, "Invalid query parameters", err) return diff --git a/internal/api/v1/models/audit_log.go b/internal/api/v1/models/audit_log.go index 2e83629..267c241 100644 --- a/internal/api/v1/models/audit_log.go +++ b/internal/api/v1/models/audit_log.go @@ -119,9 +119,9 @@ type AuditLog struct { TraceID *uuid.UUID `gorm:"index:idx_audit_logs_trace_id" json:"traceId,omitempty"` // Event Classification - Status string `gorm:"type:varchar(20);not null;index:idx_audit_logs_status" json:"status"` - EventType *string `gorm:"type:varchar(50)" json:"eventType,omitempty"` // e.g., MANAGEMENT_EVENT, USER_MANAGEMENT (user-defined custom names) - EventAction *string `gorm:"type:varchar(50)" json:"eventAction,omitempty"` // e.g., CREATE, READ, UPDATE, DELETE + Status string `gorm:"type:varchar(20);not null;index:idx_audit_logs_status" json:"status"` + EventType string `gorm:"type:varchar(50);not null;default:''" json:"eventType,omitempty"` // e.g. MANAGEMENT_EVENT, USER_MANAGEMENT + Action string `gorm:"type:varchar(50);not null;default:''" json:"action,omitempty"` // e.g. CREATE, READ, UPDATE, DELETE // Actor Information (unified approach) ActorType string `gorm:"type:varchar(50);not null" json:"actorType"` @@ -133,9 +133,13 @@ type AuditLog struct { // Metadata (Payload without PII/sensitive data) // Using JSONBRawMessage to properly handle PostgreSQL JSONB and SQLite TEXT - RequestMetadata JSONBRawMessage `gorm:"type:jsonb" json:"requestMetadata,omitempty"` // Request payload without PII/sensitive data - ResponseMetadata JSONBRawMessage `gorm:"type:jsonb" json:"responseMetadata,omitempty"` // Response or Error details - AdditionalMetadata JSONBRawMessage `gorm:"type:jsonb" json:"additionalMetadata,omitempty"` // Additional context-specific data + Message JSONBRawMessage `gorm:"type:blob" json:"message,omitempty"` // Raw message or payload (e.g. for signing) + Metadata JSONBRawMessage `gorm:"type:jsonb" json:"metadata,omitempty"` // Consolidated metadata + + // Security & Non-Repudiation + Signature string `gorm:"type:text" json:"signature,omitempty"` + SignatureAlgorithm string `gorm:"type:varchar(50)" json:"signatureAlgorithm,omitempty"` + PublicKeyID string `gorm:"type:varchar(255)" json:"publicKeyId,omitempty"` // BaseModel provides CreatedAt BaseModel @@ -204,33 +208,25 @@ func (l *AuditLog) Validate() error { } } - // Validate event_type if provided (nullable field, using config's O(1) validation method) - if l.EventType != nil && *l.EventType != "" { + // Validate event_type if provided + if l.EventType != "" { if enumConfig != nil { - if !enumConfig.IsValidEventType(*l.EventType) { - return fmt.Errorf("invalid eventType: %s", *l.EventType) - } - } else { - // Fallback to default event types when config is not loaded - // Use config.DefaultEnums to avoid duplication (access fields directly to avoid copying sync.Once) - if !contains(config.DefaultEnums.EventTypes, *l.EventType) { - return fmt.Errorf("invalid eventType: %s", *l.EventType) + if !enumConfig.IsValidEventType(l.EventType) { + return fmt.Errorf("invalid eventType: %s", l.EventType) } + } else if !contains(config.DefaultEnums.EventTypes, l.EventType) { + return fmt.Errorf("invalid eventType: %s", l.EventType) } } - // Validate event_action if provided (nullable field, using config's O(1) validation method) - if l.EventAction != nil && *l.EventAction != "" { + // Validate action if provided + if l.Action != "" { if enumConfig != nil { - if !enumConfig.IsValidEventAction(*l.EventAction) { - return fmt.Errorf("invalid eventAction: %s", *l.EventAction) - } - } else { - // Fallback to default actions when config is not loaded - // Use config.DefaultEnums to avoid duplication (access fields directly to avoid copying sync.Once) - if !contains(config.DefaultEnums.EventActions, *l.EventAction) { - return fmt.Errorf("invalid eventAction: %s", *l.EventAction) + if !enumConfig.IsValidEventAction(l.Action) { + return fmt.Errorf("invalid action: %s", l.Action) } + } else if !contains(config.DefaultEnums.EventActions, l.Action) { + return fmt.Errorf("invalid action: %s", l.Action) } } diff --git a/internal/api/v1/models/audit_log_test.go b/internal/api/v1/models/audit_log_test.go index 562808b..7f2d21c 100644 --- a/internal/api/v1/models/audit_log_test.go +++ b/internal/api/v1/models/audit_log_test.go @@ -106,7 +106,7 @@ func TestAuditLog_Validate_WithConfig(t *testing.T) { name: "Valid event type from config", log: AuditLog{ Status: StatusSuccess, - EventType: stringPtr("MANAGEMENT_EVENT"), + EventType: "MANAGEMENT_EVENT", ActorType: "SERVICE", ActorID: "service-1", TargetType: "SERVICE", @@ -118,7 +118,7 @@ func TestAuditLog_Validate_WithConfig(t *testing.T) { name: "Invalid event type (not in config)", log: AuditLog{ Status: StatusSuccess, - EventType: stringPtr("INVALID_EVENT"), + EventType: "INVALID_EVENT", ActorType: "SERVICE", ActorID: "service-1", TargetType: "SERVICE", @@ -129,24 +129,24 @@ func TestAuditLog_Validate_WithConfig(t *testing.T) { { name: "Valid event action from config", log: AuditLog{ - Status: StatusSuccess, - EventAction: stringPtr("CREATE"), - ActorType: "SERVICE", - ActorID: "service-1", - TargetType: "SERVICE", - TargetID: stringPtr("service-2"), + Status: StatusSuccess, + Action: "CREATE", + ActorType: "SERVICE", + ActorID: "service-1", + TargetType: "SERVICE", + TargetID: stringPtr("service-2"), }, wantErr: false, }, { name: "Invalid event action (not in config)", log: AuditLog{ - Status: StatusSuccess, - EventAction: stringPtr("INVALID_ACTION"), - ActorType: "SERVICE", - ActorID: "service-1", - TargetType: "SERVICE", - TargetID: stringPtr("service-2"), + Status: StatusSuccess, + Action: "INVALID_ACTION", + ActorType: "SERVICE", + ActorID: "service-1", + TargetType: "SERVICE", + TargetID: stringPtr("service-2"), }, wantErr: true, }, @@ -264,15 +264,15 @@ func TestAuditLog_Validate_EdgeCases(t *testing.T) { wantErr: true, }, { - name: "Valid with nil optional fields", + name: "Valid with empty optional fields", log: AuditLog{ - Status: StatusSuccess, - ActorType: "SERVICE", - ActorID: "service-1", - TargetType: "SERVICE", - EventType: nil, - EventAction: nil, - TargetID: nil, + Status: StatusSuccess, + ActorType: "SERVICE", + ActorID: "service-1", + TargetType: "SERVICE", + EventType: "", + Action: "", + TargetID: nil, }, wantErr: false, }, diff --git a/internal/api/v1/models/request_dtos.go b/internal/api/v1/models/request_dtos.go index 9507b7b..e99aeda 100644 --- a/internal/api/v1/models/request_dtos.go +++ b/internal/api/v1/models/request_dtos.go @@ -10,9 +10,9 @@ type CreateAuditLogRequest struct { Timestamp string `json:"timestamp" validate:"required"` // ISO 8601 format, required // Event Classification - EventType *string `json:"eventType,omitempty"` // MANAGEMENT_EVENT, USER_MANAGEMENT (user-defined custom names) - EventAction *string `json:"eventAction,omitempty"` // CREATE, READ, UPDATE, DELETE - Status string `json:"status" validate:"required"` // SUCCESS, FAILURE + EventType string `json:"eventType,omitempty"` // MANAGEMENT_EVENT, USER_MANAGEMENT + Action string `json:"action" validate:"required"` // CREATE, READ, UPDATE, DELETE + Status string `json:"status" validate:"required"` // SUCCESS, FAILURE // Actor Information (unified approach) ActorType string `json:"actorType" validate:"required"` // SERVICE, ADMIN, MEMBER, SYSTEM @@ -22,10 +22,12 @@ type CreateAuditLogRequest struct { TargetType string `json:"targetType" validate:"required"` // SERVICE, RESOURCE TargetID *string `json:"targetId,omitempty"` // resource_id or service_name - // Metadata (Payload without PII/sensitive data) - // Using JSONBRawMessage instead of json.RawMessage to avoid type conversion - // JSONBRawMessage implements json.Unmarshaler, so it works seamlessly with JSON decoding - RequestMetadata JSONBRawMessage `json:"requestMetadata,omitempty"` // Request payload without PII/sensitive data - ResponseMetadata JSONBRawMessage `json:"responseMetadata,omitempty"` // Response or Error details - AdditionalMetadata JSONBRawMessage `json:"additionalMetadata,omitempty"` // Additional context-specific data + // Metadata + Message []byte `json:"message,omitempty"` // Raw message or payload for signing + Metadata map[string]interface{} `json:"metadata,omitempty"` // Consolidated metadata + + // Security & Non-Repudiation + Signature string `json:"signature,omitempty"` + SignatureAlgorithm string `json:"signatureAlgorithm,omitempty"` + PublicKeyID string `json:"publicKeyId,omitempty"` } diff --git a/internal/api/v1/models/response_dtos.go b/internal/api/v1/models/response_dtos.go index 36fcfde..c028d7f 100644 --- a/internal/api/v1/models/response_dtos.go +++ b/internal/api/v1/models/response_dtos.go @@ -13,9 +13,9 @@ type AuditLogResponse struct { Timestamp time.Time `json:"timestamp"` TraceID *uuid.UUID `json:"traceId,omitempty"` - EventType *string `json:"eventType,omitempty"` - EventAction *string `json:"eventAction,omitempty"` - Status string `json:"status"` + EventType string `json:"eventType,omitempty"` + Action string `json:"action,omitempty"` + Status string `json:"status"` ActorType string `json:"actorType"` ActorID string `json:"actorId"` @@ -23,9 +23,12 @@ type AuditLogResponse struct { TargetType string `json:"targetType"` TargetID *string `json:"targetId,omitempty"` - RequestMetadata json.RawMessage `json:"requestMetadata,omitempty"` - ResponseMetadata json.RawMessage `json:"responseMetadata,omitempty"` - AdditionalMetadata json.RawMessage `json:"additionalMetadata,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + + Message json.RawMessage `json:"message,omitempty"` + Signature string `json:"signature,omitempty"` + SignatureAlgorithm string `json:"signatureAlgorithm,omitempty"` + PublicKeyID string `json:"publicKeyId,omitempty"` CreatedAt time.Time `json:"createdAt"` } @@ -39,23 +42,23 @@ type GetAuditLogsResponse struct { } // ToAuditLogResponse converts an AuditLog model to an AuditLogResponse -// This encapsulates the mapping logic to keep handlers clean and reduce maintenance risk -// Converts JSONBRawMessage (database type) to json.RawMessage (API type) for proper separation of concerns func ToAuditLogResponse(log AuditLog) AuditLogResponse { return AuditLogResponse{ ID: log.ID, Timestamp: log.Timestamp, TraceID: log.TraceID, EventType: log.EventType, - EventAction: log.EventAction, + Action: log.Action, Status: log.Status, ActorType: log.ActorType, ActorID: log.ActorID, TargetType: log.TargetType, TargetID: log.TargetID, - RequestMetadata: json.RawMessage(log.RequestMetadata), - ResponseMetadata: json.RawMessage(log.ResponseMetadata), - AdditionalMetadata: json.RawMessage(log.AdditionalMetadata), + Metadata: json.RawMessage(log.Metadata), + Message: json.RawMessage(log.Message), + Signature: log.Signature, + SignatureAlgorithm: log.SignatureAlgorithm, + PublicKeyID: log.PublicKeyID, CreatedAt: log.CreatedAt, } } diff --git a/internal/api/v1/services/audit_service.go b/internal/api/v1/services/audit_service.go index ecd3b68..107f6dd 100644 --- a/internal/api/v1/services/audit_service.go +++ b/internal/api/v1/services/audit_service.go @@ -2,6 +2,7 @@ package services import ( "context" + "encoding/json" "fmt" "time" @@ -22,18 +23,26 @@ func NewAuditService(repo database.AuditRepository) *AuditService { // CreateAuditLog creates a new audit log entry from a request func (s *AuditService) CreateAuditLog(ctx context.Context, req *v1models.CreateAuditLogRequest) (*v1models.AuditLog, error) { + // Marshal metadata to JSONB + metaBytes, err := json.Marshal(req.Metadata) + if err != nil { + return nil, fmt.Errorf("%w: failed to marshal metadata: %w", ErrValidation, err) + } + // Convert request to model auditLog := &v1models.AuditLog{ EventType: req.EventType, - EventAction: req.EventAction, + Action: req.Action, Status: req.Status, ActorType: req.ActorType, ActorID: req.ActorID, TargetType: req.TargetType, TargetID: req.TargetID, - RequestMetadata: req.RequestMetadata, - ResponseMetadata: req.ResponseMetadata, - AdditionalMetadata: req.AdditionalMetadata, + Message: v1models.JSONBRawMessage(req.Message), + Metadata: v1models.JSONBRawMessage(metaBytes), + Signature: req.Signature, + SignatureAlgorithm: req.SignatureAlgorithm, + PublicKeyID: req.PublicKeyID, } // Parse and validate timestamp (required) diff --git a/internal/api/v1/services/audit_service_test.go b/internal/api/v1/services/audit_service_test.go index ed118eb..a66ca6f 100644 --- a/internal/api/v1/services/audit_service_test.go +++ b/internal/api/v1/services/audit_service_test.go @@ -31,20 +31,20 @@ func TestAuditService_CreateAuditLog(t *testing.T) { ActorID: "service-a", TargetType: "SERVICE", TargetID: v1testutil.StringPtr("service-b"), - EventType: v1testutil.StringPtr("MANAGEMENT_EVENT"), + EventType: "MANAGEMENT_EVENT", }, wantErr: false, }, { name: "Valid request with ADMIN actor", req: &v1models.CreateAuditLogRequest{ - Timestamp: time.Now().UTC().Format(time.RFC3339), - Status: v1models.StatusSuccess, - ActorType: "ADMIN", - ActorID: "admin@example.com", - TargetType: "RESOURCE", - TargetID: v1testutil.StringPtr("user-123"), - EventAction: v1testutil.StringPtr("CREATE"), + Timestamp: time.Now().UTC().Format(time.RFC3339), + Status: v1models.StatusSuccess, + ActorType: "ADMIN", + ActorID: "admin@example.com", + TargetType: "RESOURCE", + TargetID: v1testutil.StringPtr("user-123"), + Action: "CREATE", }, wantErr: false, }, @@ -94,7 +94,7 @@ func TestAuditService_CreateAuditLog(t *testing.T) { ActorID: "service-1", TargetType: "SERVICE", TargetID: v1testutil.StringPtr("service-2"), - EventType: v1testutil.StringPtr("INVALID_EVENT"), + EventType: "INVALID_EVENT", }, wantErr: true, }, @@ -136,13 +136,13 @@ func TestAuditService_CreateAuditLog(t *testing.T) { { name: "Invalid event action", req: &v1models.CreateAuditLogRequest{ - Timestamp: time.Now().UTC().Format(time.RFC3339), - Status: v1models.StatusSuccess, - ActorType: "SERVICE", - ActorID: "service-1", - TargetType: "SERVICE", - TargetID: v1testutil.StringPtr("service-1"), - EventAction: v1testutil.StringPtr("INVALID_ACTION"), + Timestamp: time.Now().UTC().Format(time.RFC3339), + Status: v1models.StatusSuccess, + ActorType: "SERVICE", + ActorID: "service-1", + TargetType: "SERVICE", + TargetID: v1testutil.StringPtr("service-1"), + Action: "INVALID_ACTION", }, wantErr: true, }, diff --git a/internal/api/v1/testutil/mock_repository.go b/internal/api/v1/testutil/mock_repository.go index 8037f66..2cf47c9 100644 --- a/internal/api/v1/testutil/mock_repository.go +++ b/internal/api/v1/testutil/mock_repository.go @@ -80,14 +80,14 @@ func (m *MockRepository) GetAuditLogs(ctx context.Context, filters *database.Aud // Filter by EventType if matches && filters.EventType != nil && *filters.EventType != "" { - if log.EventType == nil || *log.EventType != *filters.EventType { + if log.EventType != *filters.EventType { matches = false } } - // Filter by EventAction - if matches && filters.EventAction != nil && *filters.EventAction != "" { - if log.EventAction == nil || *log.EventAction != *filters.EventAction { + // Filter by Action + if matches && filters.Action != nil && *filters.Action != "" { + if log.Action != *filters.Action { matches = false } } diff --git a/internal/api/v1/testutil/mock_repository_test.go b/internal/api/v1/testutil/mock_repository_test.go index a252808..463e67a 100644 --- a/internal/api/v1/testutil/mock_repository_test.go +++ b/internal/api/v1/testutil/mock_repository_test.go @@ -25,6 +25,7 @@ func TestMockRepository_GetAuditLogsByTraceID(t *testing.T) { eventType1 := "MANAGEMENT_EVENT" eventType2 := "USER_MANAGEMENT" eventType3 := "DATA_FETCH" + eventAction1 := "CREATE" // Logs for traceID1 (should be returned in chronological order) log1 := &v1models.AuditLog{ @@ -32,7 +33,7 @@ func TestMockRepository_GetAuditLogsByTraceID(t *testing.T) { Timestamp: time.Now().Add(-2 * time.Hour), TraceID: &traceID1, Status: v1models.StatusSuccess, - EventType: &eventType1, + EventType: eventType1, ActorType: "SERVICE", ActorID: "service-a", TargetType: "SERVICE", @@ -43,7 +44,7 @@ func TestMockRepository_GetAuditLogsByTraceID(t *testing.T) { Timestamp: time.Now().Add(-1 * time.Hour), TraceID: &traceID1, Status: v1models.StatusSuccess, - EventType: &eventType2, + EventType: eventType2, ActorType: "SERVICE", ActorID: "service-b", TargetType: "SERVICE", @@ -54,7 +55,7 @@ func TestMockRepository_GetAuditLogsByTraceID(t *testing.T) { Timestamp: time.Now(), TraceID: &traceID1, Status: v1models.StatusSuccess, - EventType: &eventType3, + EventType: eventType3, ActorType: "SERVICE", ActorID: "service-c", TargetType: "SERVICE", @@ -66,10 +67,12 @@ func TestMockRepository_GetAuditLogsByTraceID(t *testing.T) { Timestamp: time.Now(), TraceID: &traceID2, Status: v1models.StatusSuccess, - EventType: &eventType1, + EventType: eventType1, + Action: eventAction1, // Added Action field ActorType: "SERVICE", ActorID: "service-a", TargetType: "SERVICE", + TargetID: nil, // Added TargetID field } // Create logs in the repository @@ -126,48 +129,48 @@ func TestMockRepository_GetAuditLogs_Filtering(t *testing.T) { // Create various audit logs logs := []*v1models.AuditLog{ { - ID: uuid.New(), - Timestamp: time.Now().Add(-3 * time.Hour), - TraceID: &traceID1, - Status: v1models.StatusSuccess, - EventType: &eventType1, - EventAction: &eventAction1, - ActorType: "SERVICE", - ActorID: "service-a", - TargetType: "SERVICE", + ID: uuid.New(), + Timestamp: time.Now().Add(-3 * time.Hour), + TraceID: &traceID1, + Status: v1models.StatusSuccess, + EventType: eventType1, + Action: eventAction1, + ActorType: "SERVICE", + ActorID: "service-a", + TargetType: "SERVICE", }, { - ID: uuid.New(), - Timestamp: time.Now().Add(-2 * time.Hour), - TraceID: &traceID1, - Status: v1models.StatusSuccess, - EventType: &eventType2, - EventAction: &eventAction1, - ActorType: "SERVICE", - ActorID: "service-b", - TargetType: "SERVICE", + ID: uuid.New(), + Timestamp: time.Now().Add(-2 * time.Hour), + TraceID: &traceID1, + Status: v1models.StatusSuccess, + EventType: eventType2, + Action: eventAction1, + ActorType: "SERVICE", + ActorID: "service-b", + TargetType: "SERVICE", }, { - ID: uuid.New(), - Timestamp: time.Now().Add(-1 * time.Hour), - TraceID: &traceID2, - Status: v1models.StatusFailure, - EventType: &eventType1, - EventAction: &eventAction2, - ActorType: "SERVICE", - ActorID: "service-a", - TargetType: "SERVICE", + ID: uuid.New(), + Timestamp: time.Now().Add(-1 * time.Hour), + TraceID: &traceID2, + Status: v1models.StatusFailure, + EventType: eventType1, + Action: eventAction2, + ActorType: "SERVICE", + ActorID: "service-a", + TargetType: "SERVICE", }, { - ID: uuid.New(), - Timestamp: time.Now(), - TraceID: &traceID2, - Status: v1models.StatusSuccess, - EventType: &eventType2, - EventAction: &eventAction1, - ActorType: "SERVICE", - ActorID: "service-b", - TargetType: "SERVICE", + ID: uuid.New(), + Timestamp: time.Now(), + TraceID: &traceID2, + Status: v1models.StatusSuccess, + EventType: eventType2, + Action: eventAction1, + ActorType: "SERVICE", + ActorID: "service-b", + TargetType: "SERVICE", }, } @@ -202,8 +205,7 @@ func TestMockRepository_GetAuditLogs_Filtering(t *testing.T) { assert.Equal(t, int64(2), total, "Should find 2 logs with eventType1") assert.Len(t, result, 2, "Should return 2 logs") for _, log := range result { - assert.NotNil(t, log.EventType) - assert.Equal(t, eventType1, *log.EventType) + assert.Equal(t, eventType1, log.EventType) } }) @@ -220,17 +222,16 @@ func TestMockRepository_GetAuditLogs_Filtering(t *testing.T) { }) // Test: Filter by EventAction - t.Run("FilterByEventAction", func(t *testing.T) { + t.Run("FilterByAction", func(t *testing.T) { filters := &database.AuditLogFilters{ - EventAction: &eventAction1, + Action: &eventAction1, } result, total, err := mockRepo.GetAuditLogs(ctx, filters) require.NoError(t, err) - assert.Equal(t, int64(3), total, "Should find 3 logs with eventAction1") + assert.Equal(t, int64(3), total, "Should find 3 logs with Action1") assert.Len(t, result, 3, "Should return 3 logs") for _, log := range result { - assert.NotNil(t, log.EventAction) - assert.Equal(t, eventAction1, *log.EventAction) + assert.Equal(t, eventAction1, log.Action) } }) @@ -245,7 +246,7 @@ func TestMockRepository_GetAuditLogs_Filtering(t *testing.T) { assert.Equal(t, int64(1), total, "Should find 1 log matching both filters") assert.Len(t, result, 1, "Should return 1 log") assert.Equal(t, traceID1, *result[0].TraceID) - assert.Equal(t, eventType1, *result[0].EventType) + assert.Equal(t, eventType1, result[0].EventType) }) // Test: No filters (should return all logs) @@ -271,7 +272,7 @@ func TestMockRepository_GetAuditLogs_Pagination(t *testing.T) { Timestamp: time.Now().Add(time.Duration(i) * time.Minute), TraceID: &traceID, Status: v1models.StatusSuccess, - EventType: &eventType, + EventType: eventType, ActorType: "SERVICE", ActorID: "test-service", TargetType: "SERVICE", @@ -354,7 +355,7 @@ func TestMockRepository_GetAuditLogs_Ordering(t *testing.T) { Timestamp: now.Add(-2 * time.Hour), TraceID: &traceID, Status: v1models.StatusSuccess, - EventType: &eventType, + EventType: eventType, ActorType: "SERVICE", ActorID: "test-service", TargetType: "SERVICE", @@ -364,7 +365,7 @@ func TestMockRepository_GetAuditLogs_Ordering(t *testing.T) { Timestamp: now, TraceID: &traceID, Status: v1models.StatusSuccess, - EventType: &eventType, + EventType: eventType, ActorType: "SERVICE", ActorID: "test-service", TargetType: "SERVICE", @@ -374,7 +375,7 @@ func TestMockRepository_GetAuditLogs_Ordering(t *testing.T) { Timestamp: now.Add(-1 * time.Hour), TraceID: &traceID, Status: v1models.StatusSuccess, - EventType: &eventType, + EventType: eventType, ActorType: "SERVICE", ActorID: "test-service", TargetType: "SERVICE", diff --git a/internal/config/config.go b/internal/config/config.go index 965df5d..2195198 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -84,21 +84,12 @@ func LoadEnums(configPath string) (*AuditEnums, error) { return GetDefaultEnums(), nil } - // Use defaults for any missing enum arrays - // This matches the opendif-core audit-service approach for consistency + // Merge YAML config with defaults to ensure all required core enums are available enums := &config.Enums - if len(enums.EventTypes) == 0 { - enums.EventTypes = DefaultEnums.EventTypes - } - if len(enums.EventActions) == 0 { - enums.EventActions = DefaultEnums.EventActions - } - if len(enums.ActorTypes) == 0 { - enums.ActorTypes = DefaultEnums.ActorTypes - } - if len(enums.TargetTypes) == 0 { - enums.TargetTypes = DefaultEnums.TargetTypes - } + enums.EventTypes = mergeUniqueStrings(enums.EventTypes, DefaultEnums.EventTypes) + enums.EventActions = mergeUniqueStrings(enums.EventActions, DefaultEnums.EventActions) + enums.ActorTypes = mergeUniqueStrings(enums.ActorTypes, DefaultEnums.ActorTypes) + enums.TargetTypes = mergeUniqueStrings(enums.TargetTypes, DefaultEnums.TargetTypes) // Initialize maps for O(1) validation lookups enums.InitializeMaps() @@ -184,3 +175,23 @@ func GetEnvOrDefault(key, defaultValue string) string { } return defaultValue } + +// mergeUniqueStrings merges two string slices and removes duplicates +func mergeUniqueStrings(a, b []string) []string { + seen := make(map[string]struct{}, len(a)+len(b)) + result := make([]string, 0, len(a)+len(b)) + + for _, s := range a { + if _, ok := seen[s]; !ok { + seen[s] = struct{}{} + result = append(result, s) + } + } + for _, s := range b { + if _, ok := seen[s]; !ok { + seen[s] = struct{}{} + result = append(result, s) + } + } + return result +} diff --git a/internal/database/client_test.go b/internal/database/client_test.go index 58911e1..1de4b4e 100644 --- a/internal/database/client_test.go +++ b/internal/database/client_test.go @@ -17,25 +17,25 @@ import ( func TestNewDatabaseConfig(t *testing.T) { // Clean up environment variables after test defer func() { - os.Unsetenv("DB_TYPE") - os.Unsetenv("DB_PATH") - os.Unsetenv("DB_HOST") - os.Unsetenv("DB_PORT") - os.Unsetenv("DB_USERNAME") - os.Unsetenv("DB_PASSWORD") - os.Unsetenv("DB_NAME") - os.Unsetenv("DB_SSLMODE") - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") - os.Unsetenv("DB_CONN_MAX_LIFETIME") - os.Unsetenv("DB_CONN_MAX_IDLE_TIME") + _ = os.Unsetenv("DB_TYPE") + _ = os.Unsetenv("DB_PATH") + _ = os.Unsetenv("DB_HOST") + _ = os.Unsetenv("DB_PORT") + _ = os.Unsetenv("DB_USERNAME") + _ = os.Unsetenv("DB_PASSWORD") + _ = os.Unsetenv("DB_NAME") + _ = os.Unsetenv("DB_SSLMODE") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_CONN_MAX_LIFETIME") + _ = os.Unsetenv("DB_CONN_MAX_IDLE_TIME") }() t.Run("Test 1: No Configuration - In-Memory SQLite", func(t *testing.T) { // Ensure no env vars are set - os.Unsetenv("DB_TYPE") - os.Unsetenv("DB_PATH") - os.Unsetenv("DB_HOST") + _ = os.Unsetenv("DB_TYPE") + _ = os.Unsetenv("DB_PATH") + _ = os.Unsetenv("DB_HOST") config := NewDatabaseConfig() @@ -56,32 +56,32 @@ func TestNewDatabaseConfig(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping database") - sqlDB.Close() + _ = sqlDB.Close() }) t.Run("Test 2: SQLite with Custom Path", func(t *testing.T) { // Clear any existing env vars first - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") - os.Unsetenv("DB_HOST") - os.Unsetenv("DB_PORT") - os.Unsetenv("DB_USERNAME") - os.Unsetenv("DB_PASSWORD") - os.Unsetenv("DB_NAME") - os.Unsetenv("DB_SSLMODE") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_HOST") + _ = os.Unsetenv("DB_PORT") + _ = os.Unsetenv("DB_USERNAME") + _ = os.Unsetenv("DB_PASSWORD") + _ = os.Unsetenv("DB_NAME") + _ = os.Unsetenv("DB_SSLMODE") // Create a temporary directory for custom path tempDir, err := os.MkdirTemp("", "audit_test_custom") require.NoError(t, err, "Should create temp directory") - defer os.RemoveAll(tempDir) + defer func() { _ = os.RemoveAll(tempDir) }() customPath := filepath.Join(tempDir, "custom_audit.db") // Set SQLite-specific env vars - os.Setenv("DB_TYPE", "sqlite") - os.Setenv("DB_PATH", customPath) - os.Setenv("DB_MAX_OPEN_CONNS", "10") - os.Setenv("DB_MAX_IDLE_CONNS", "5") + _ = os.Setenv("DB_TYPE", "sqlite") + _ = os.Setenv("DB_PATH", customPath) + _ = os.Setenv("DB_MAX_OPEN_CONNS", "10") + _ = os.Setenv("DB_MAX_IDLE_CONNS", "5") config := NewDatabaseConfig() @@ -100,7 +100,7 @@ func TestNewDatabaseConfig(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping database") - sqlDB.Close() + _ = sqlDB.Close() // Verify database file was created assert.FileExists(t, customPath, "Database file should be created at custom path") @@ -108,20 +108,20 @@ func TestNewDatabaseConfig(t *testing.T) { t.Run("Test 2a: DB_PATH Alone (No DB_TYPE) - Should Use File-Based SQLite", func(t *testing.T) { // Clear all database env vars - os.Unsetenv("DB_TYPE") - os.Unsetenv("DB_HOST") - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_TYPE") + _ = os.Unsetenv("DB_HOST") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") // Create a temporary directory for custom path tempDir, err := os.MkdirTemp("", "audit_test_dbpath_only") require.NoError(t, err, "Should create temp directory") - defer os.RemoveAll(tempDir) + defer func() { _ = os.RemoveAll(tempDir) }() customPath := filepath.Join(tempDir, "dbpath_only.db") // Set ONLY DB_PATH (no DB_TYPE) - os.Setenv("DB_PATH", customPath) + _ = os.Setenv("DB_PATH", customPath) config := NewDatabaseConfig() @@ -138,7 +138,7 @@ func TestNewDatabaseConfig(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping database") - sqlDB.Close() + _ = sqlDB.Close() // Verify database file was created assert.FileExists(t, customPath, "Database file should be created") @@ -146,13 +146,13 @@ func TestNewDatabaseConfig(t *testing.T) { t.Run("Test 2b: DB_TYPE=sqlite Without DB_PATH - Should Use Default Path", func(t *testing.T) { // Clear all env vars - os.Unsetenv("DB_PATH") - os.Unsetenv("DB_HOST") - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_PATH") + _ = os.Unsetenv("DB_HOST") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") // Set ONLY DB_TYPE=sqlite (no DB_PATH) - os.Setenv("DB_TYPE", "sqlite") + _ = os.Setenv("DB_TYPE", "sqlite") config := NewDatabaseConfig() @@ -169,17 +169,17 @@ func TestNewDatabaseConfig(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping database") - sqlDB.Close() + _ = sqlDB.Close() }) t.Run("Test 3: SQLite In-Memory", func(t *testing.T) { // Clear any existing env vars first - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") // Set SQLite in-memory configuration - os.Setenv("DB_TYPE", "sqlite") - os.Setenv("DB_PATH", ":memory:") + _ = os.Setenv("DB_TYPE", "sqlite") + _ = os.Setenv("DB_PATH", ":memory:") config := NewDatabaseConfig() @@ -196,24 +196,24 @@ func TestNewDatabaseConfig(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping in-memory database") - sqlDB.Close() + _ = sqlDB.Close() }) t.Run("Test 4: PostgreSQL Configuration", func(t *testing.T) { // Clear any existing env vars first - os.Unsetenv("DB_PATH") - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_PATH") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") // Set PostgreSQL-specific env vars - os.Setenv("DB_TYPE", "postgres") - os.Setenv("DB_HOST", "localhost") - os.Setenv("DB_PORT", "5432") - os.Setenv("DB_USERNAME", "testuser") - os.Setenv("DB_PASSWORD", "testpass") - os.Setenv("DB_NAME", "testdb") - os.Setenv("DB_SSLMODE", "disable") - os.Setenv("DB_MAX_OPEN_CONNS", "50") + _ = os.Setenv("DB_TYPE", "postgres") + _ = os.Setenv("DB_HOST", "localhost") + _ = os.Setenv("DB_PORT", "5432") + _ = os.Setenv("DB_USERNAME", "testuser") + _ = os.Setenv("DB_PASSWORD", "testpass") + _ = os.Setenv("DB_NAME", "testdb") + _ = os.Setenv("DB_SSLMODE", "disable") + _ = os.Setenv("DB_MAX_OPEN_CONNS", "50") config := NewDatabaseConfig() @@ -259,23 +259,23 @@ func TestNewDatabaseConfig(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping database") - sqlDB.Close() + _ = sqlDB.Close() } }) t.Run("Test 5: Unknown DB_TYPE Defaults to File-Based SQLite", func(t *testing.T) { // Clear any existing env vars that might affect SQLite defaults - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") - os.Unsetenv("DB_HOST") - os.Unsetenv("DB_PORT") - os.Unsetenv("DB_USERNAME") - os.Unsetenv("DB_PASSWORD") - os.Unsetenv("DB_NAME") - os.Unsetenv("DB_SSLMODE") - os.Unsetenv("DB_PATH") - - os.Setenv("DB_TYPE", "unknown_db") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_HOST") + _ = os.Unsetenv("DB_PORT") + _ = os.Unsetenv("DB_USERNAME") + _ = os.Unsetenv("DB_PASSWORD") + _ = os.Unsetenv("DB_NAME") + _ = os.Unsetenv("DB_SSLMODE") + _ = os.Unsetenv("DB_PATH") + + _ = os.Setenv("DB_TYPE", "unknown_db") config := NewDatabaseConfig() @@ -292,13 +292,13 @@ func TestNewDatabaseConfig(t *testing.T) { t.Run("Test 5a: DB_HOST Set Without DB_TYPE=postgres - Should Use In-Memory SQLite", func(t *testing.T) { // Clear all env vars - os.Unsetenv("DB_TYPE") - os.Unsetenv("DB_PATH") - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_TYPE") + _ = os.Unsetenv("DB_PATH") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") // Set DB_HOST but NOT DB_TYPE=postgres - os.Setenv("DB_HOST", "localhost") + _ = os.Setenv("DB_HOST", "localhost") // Capture log output to verify warning is logged var logBuffer bytes.Buffer @@ -321,13 +321,13 @@ func TestNewDatabaseConfig(t *testing.T) { t.Run("Test 5b: DB_HOST Set With DB_TYPE=sqlite - Should Use File-Based SQLite", func(t *testing.T) { // Clear all env vars - os.Unsetenv("DB_PATH") - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_PATH") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") // Set DB_TYPE=sqlite and DB_HOST (DB_HOST should be ignored/warned) - os.Setenv("DB_TYPE", "sqlite") - os.Setenv("DB_HOST", "localhost") + _ = os.Setenv("DB_TYPE", "sqlite") + _ = os.Setenv("DB_HOST", "localhost") // Capture log output to verify warning is logged var logBuffer bytes.Buffer @@ -349,19 +349,19 @@ func TestNewDatabaseConfig(t *testing.T) { t.Run("Test 6: PostgreSQL with Special Characters in Password", func(t *testing.T) { // Clear any existing env vars first - os.Unsetenv("DB_PATH") - os.Unsetenv("DB_MAX_OPEN_CONNS") - os.Unsetenv("DB_MAX_IDLE_CONNS") + _ = os.Unsetenv("DB_PATH") + _ = os.Unsetenv("DB_MAX_OPEN_CONNS") + _ = os.Unsetenv("DB_MAX_IDLE_CONNS") // Set PostgreSQL with special characters in password specialPassword := "p@ss w#rd!123" - os.Setenv("DB_TYPE", "postgres") - os.Setenv("DB_HOST", "localhost") - os.Setenv("DB_PORT", "5432") - os.Setenv("DB_USERNAME", "testuser") - os.Setenv("DB_PASSWORD", specialPassword) - os.Setenv("DB_NAME", "testdb") - os.Setenv("DB_SSLMODE", "disable") + _ = os.Setenv("DB_TYPE", "postgres") + _ = os.Setenv("DB_HOST", "localhost") + _ = os.Setenv("DB_PORT", "5432") + _ = os.Setenv("DB_USERNAME", "testuser") + _ = os.Setenv("DB_PASSWORD", specialPassword) + _ = os.Setenv("DB_NAME", "testdb") + _ = os.Setenv("DB_SSLMODE", "disable") config := NewDatabaseConfig() @@ -405,7 +405,7 @@ func TestNewDatabaseConfig(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping database") - sqlDB.Close() + _ = sqlDB.Close() } }) } @@ -416,7 +416,7 @@ func TestConnectGormDB(t *testing.T) { if err != nil { t.Fatalf("Failed to create temp dir: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { _ = os.RemoveAll(tempDir) }() dbPath := filepath.Join(tempDir, "test.db") @@ -438,7 +438,7 @@ func TestConnectGormDB(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping database") - sqlDB.Close() + _ = sqlDB.Close() }) t.Run("Connect to In-Memory SQLite", func(t *testing.T) { @@ -459,7 +459,7 @@ func TestConnectGormDB(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping in-memory database") - sqlDB.Close() + _ = sqlDB.Close() }) t.Run("Connect to PostgreSQL (graceful failure if not available)", func(t *testing.T) { @@ -488,7 +488,7 @@ func TestConnectGormDB(t *testing.T) { sqlDB, err := db.DB() require.NoError(t, err, "Should be able to get underlying sql.DB") assert.NoError(t, sqlDB.Ping(), "Should be able to ping database") - sqlDB.Close() + _ = sqlDB.Close() } }) } diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go index cf39dd5..5b72d32 100644 --- a/internal/middleware/cors_test.go +++ b/internal/middleware/cors_test.go @@ -11,7 +11,7 @@ import ( func TestDefaultCORSConfig(t *testing.T) { t.Run("DefaultOrigins", func(t *testing.T) { - os.Unsetenv("CORS_ALLOWED_ORIGINS") + _ = os.Unsetenv("CORS_ALLOWED_ORIGINS") config := DefaultCORSConfig() assert.Contains(t, config.AllowedOrigins, "http://localhost:5173") @@ -22,8 +22,8 @@ func TestDefaultCORSConfig(t *testing.T) { }) t.Run("CustomOriginsFromEnv", func(t *testing.T) { - os.Setenv("CORS_ALLOWED_ORIGINS", "http://example.com,https://test.com") - defer os.Unsetenv("CORS_ALLOWED_ORIGINS") + _ = os.Setenv("CORS_ALLOWED_ORIGINS", "http://example.com,https://test.com") + defer func() { _ = os.Unsetenv("CORS_ALLOWED_ORIGINS") }() config := DefaultCORSConfig() @@ -32,8 +32,8 @@ func TestDefaultCORSConfig(t *testing.T) { }) t.Run("WildcardOrigin", func(t *testing.T) { - os.Setenv("CORS_ALLOWED_ORIGINS", "*") - defer os.Unsetenv("CORS_ALLOWED_ORIGINS") + _ = os.Setenv("CORS_ALLOWED_ORIGINS", "*") + defer func() { _ = os.Unsetenv("CORS_ALLOWED_ORIGINS") }() config := DefaultCORSConfig() diff --git a/pkg/audit/client.go b/pkg/audit/client.go index 0a15360..dcfd145 100644 --- a/pkg/audit/client.go +++ b/pkg/audit/client.go @@ -3,13 +3,16 @@ package audit import ( "bytes" "context" + "crypto" "encoding/json" + "fmt" "io" "log/slog" "net/http" "net/url" "os" "strings" + "sync" "time" ) @@ -20,45 +23,105 @@ const ( DefaultHTTPTimeout = 10 * time.Second ) +// Config defines the configuration for the Audit Client +type Config struct { + BaseURL string + Signer SignPayloadFunc + PublicKeyID string + SignatureAlgorithm string // e.g. "RS256", "EdDSA" + WorkerCount int // Number of background workers, defaults to 5 + QueueSize int // Size of the internal channel, defaults to 100 + HTTPTimeout time.Duration // Defaults to 10s +} + // Client is a client for sending audit events to the audit service type Client struct { - baseURL string - httpClient *http.Client - enabled bool + baseURL string + httpClient *http.Client + enabled bool + signer SignPayloadFunc + publicKeyID string + signatureAlgorithm string + queue chan *AuditLogRequest + quit chan struct{} + wg sync.WaitGroup } -// NewClient creates a new audit client +// NewClient creates a new audit client using the provided configuration. // Audit can be disabled by: // - Setting ENABLE_AUDIT=false environment variable -// - Providing an empty baseURL +// - Providing an empty baseURL in config // // When disabled, all LogEvent calls will be no-ops. -func NewClient(baseURL string) *Client { - enabled := isAuditEnabled(baseURL) +func NewClient(cfg Config) *Client { + enabled := isAuditEnabled(cfg.BaseURL) if !enabled { slog.Info("Audit client disabled", "reason", "ENABLE_AUDIT=false or audit service URL not configured", "impact", "Services will continue running but audit events will not be logged") return &Client{ - baseURL: "", - httpClient: nil, - enabled: false, + enabled: false, } } - slog.Info("Audit client initialized", "baseURL", baseURL) - return &Client{ - baseURL: baseURL, + // Algorithm Hardening: Validate SignatureAlgorithm if a signer is provided + if cfg.Signer != nil { + switch cfg.SignatureAlgorithm { + case "RS256", "EdDSA": + // Valid + default: + slog.Error("Unsupported signature algorithm", "algorithm", cfg.SignatureAlgorithm) + return &Client{ + enabled: false, + } + } + } + + workerCount := cfg.WorkerCount + if workerCount <= 0 { + workerCount = 5 + } + + queueSize := cfg.QueueSize + if queueSize <= 0 { + queueSize = 100 + } + + timeout := cfg.HTTPTimeout + if timeout <= 0 { + timeout = DefaultHTTPTimeout + } + + c := &Client{ + baseURL: cfg.BaseURL, httpClient: &http.Client{ - Timeout: DefaultHTTPTimeout, + Timeout: timeout, Transport: &http.Transport{ MaxIdleConns: 100, MaxIdleConnsPerHost: 10, }, }, - enabled: true, + enabled: true, + signer: cfg.Signer, + publicKeyID: cfg.PublicKeyID, + signatureAlgorithm: cfg.SignatureAlgorithm, + queue: make(chan *AuditLogRequest, queueSize), + quit: make(chan struct{}), } + + // Start background workers + for i := 0; i < workerCount; i++ { + c.wg.Add(1) + go c.worker() + } + + slog.Info("Audit client initialized with async workers", + "baseURL", cfg.BaseURL, + "workers", workerCount, + "queueSize", queueSize) + + return c } // IsEnabled returns whether the audit client is enabled @@ -66,17 +129,138 @@ func (c *Client) IsEnabled() bool { return c.enabled } -// LogEvent sends an audit event to the audit service asynchronously (fire-and-forget) -// This function returns immediately and logs the event in a background goroutine. +// LogEvent sends an audit event to the audit service asynchronously via worker queue. func (c *Client) LogEvent(ctx context.Context, event *AuditLogRequest) { // Skip if audit client is not enabled - if !c.enabled || c.httpClient == nil { + if !c.enabled { + return + } + + // Push to queue + select { + case c.queue <- event: return + default: + slog.Warn("Audit queue full, dropping event", "action", event.Action) + } +} + +// LogSignedEvent logs an audit event that has already been signed. +// This is an alias for LogEvent intended for semantically clearer logging of signed events. +func (c *Client) LogSignedEvent(ctx context.Context, event *AuditLogRequest) { + c.LogEvent(ctx, event) +} + +// SignEvent generates a cryptographic signature for the given request +// using the registered SignPayloadFunc. +func (c *Client) SignEvent(event *AuditLogRequest) error { + if c.signer == nil { + return fmt.Errorf("no signer registered with the client") + } + + payload, err := CanonicalizeRequest(event) + if err != nil { + return fmt.Errorf("failed to canonicalize event: %w", err) + } + + // Using context.Background() for the signing callback as the original caller's context + // may have expired if called from the background worker. + sigBase64, err := c.signer(context.Background(), payload) + if err != nil { + return fmt.Errorf("failed to sign event: %w", err) + } + + event.Signature = sigBase64 + event.SignatureAlgorithm = c.signatureAlgorithm + event.PublicKeyID = c.publicKeyID + + return nil +} + +// Close gracefully shuts down the client, flushing the queue. +func (c *Client) Close(ctx context.Context) error { + if !c.enabled { + return nil } + close(c.quit) + close(c.queue) - // Log asynchronously (fire-and-forget) using background context - // Using background context ensures the request completes even if the original context is cancelled - go c.logEvent(context.Background(), event) + // Wait for workers to finish, but honor context timeout if provided + done := make(chan struct{}) + go func() { + c.wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// VerifyIntegrity verifies the signature of a log request. +// It uses the public key provided by the caller to verify the signature. +func (c *Client) VerifyIntegrity(event *AuditLogRequest, publicKey crypto.PublicKey) (bool, error) { + if event.Signature == "" { + return false, fmt.Errorf("event has no signature") + } + + payload, err := CanonicalizeRequest(event) + if err != nil { + return false, fmt.Errorf("failed to canonicalize event: %w", err) + } + + err = VerifyPayload(payload, event.Signature, event.SignatureAlgorithm, publicKey) + if err != nil { + return false, fmt.Errorf("verification failed: %w", err) + } + + return true, nil +} + +func (c *Client) worker() { + defer c.wg.Done() + for { + select { + case event, ok := <-c.queue: + if !ok { + return + } + + // Create a context with timeout for this specific event's processing + // Using the HTTP client's timeout as the base + ctx, cancel := context.WithTimeout(context.Background(), c.httpClient.Timeout) + + // Automatic signing if required + if event.ShouldSign { + var signErr error + for attempt := 1; attempt <= 3; attempt++ { + if signErr = c.SignEvent(event); signErr == nil { + break + } + slog.Warn("Failed to sign event in worker, retrying", + "attempt", attempt, + "maxAttempts", 3, + "error", signErr) + + // Small backoff before retry (optional, but good practice) + time.Sleep(100 * time.Millisecond) + } + + if signErr != nil { + slog.Error("Failed to sign event in worker after retries, dropping event", "error", signErr) + cancel() + continue + } + } + c.logEvent(ctx, event) + cancel() + case <-c.quit: + return + } + } } // logEvent sends the audit event to the audit service API @@ -130,12 +314,12 @@ func (c *Client) logEvent(ctx context.Context, event *AuditLogRequest) { } slog.Info("Audit event logged successfully", - "eventType", event.EventType, + "action", event.Action, "actorType", event.ActorType, "actorId", event.ActorID, "targetType", event.TargetType, "status", event.Status, - "additionalMetadata", string(event.AdditionalMetadata)) + "metadata", event.Metadata) } // isAuditEnabled checks if audit logging is enabled via environment variable diff --git a/pkg/audit/interface.go b/pkg/audit/interface.go index abd2549..e1c7207 100644 --- a/pkg/audit/interface.go +++ b/pkg/audit/interface.go @@ -1,6 +1,9 @@ package audit -import "context" +import ( + "context" + "crypto" +) // Auditor is the primary interface for audit logging operations. // This interface provides a clean abstraction for audit capabilities, @@ -8,24 +11,29 @@ import "context" // // Implementations should handle: // - Asynchronous logging (fire-and-forget) -// - Graceful degradation when audit service is unavailable +// - Graceful degradation when the audit service is unavailable // - Thread-safe operations type Auditor interface { - // LogEvent logs an audit event asynchronously. - // The implementation should handle the event in a background goroutine - // to avoid blocking the calling code. - // - // If the audit service is disabled or unavailable, this method should - // return immediately without error (graceful degradation). + // LogEvent queues a standard audit event for asynchronous processing LogEvent(ctx context.Context, event *AuditLogRequest) - // IsEnabled returns whether audit logging is currently enabled. - // This can be used by callers to skip expensive audit event preparation - // when audit logging is disabled. + // SignEvent generates a digital signature for the audit request + // using the registered SignPayloadFunc + SignEvent(event *AuditLogRequest) error + + // LogSignedEvent queues an event that already contains a cryptographic signature + LogSignedEvent(ctx context.Context, event *AuditLogRequest) + + // VerifyIntegrity validates a log's signature using a provided public key + VerifyIntegrity(event *AuditLogRequest, publicKey crypto.PublicKey) (bool, error) + + // IsEnabled returns true if the auditor is configured to process events IsEnabled() bool + + // Close gracefully shuts down the client and flushes the queue + Close(ctx context.Context) error } // AuditClient is an alias for Auditor to maintain backward compatibility. // Deprecated: Use Auditor instead. This will be removed in a future version. type AuditClient = Auditor - diff --git a/pkg/audit/models.go b/pkg/audit/models.go index 81ae21d..ee8d610 100644 --- a/pkg/audit/models.go +++ b/pkg/audit/models.go @@ -1,11 +1,6 @@ package audit -import ( - "encoding/json" -) - // AuditLogRequest represents the request payload for creating an audit log -// Any service can use this without importing the full audit service implementation type AuditLogRequest struct { // Trace & Correlation TraceID *string `json:"traceId,omitempty"` // UUID string, nullable for standalone events @@ -14,22 +9,27 @@ type AuditLogRequest struct { Timestamp string `json:"timestamp"` // ISO 8601 format, required // Event Classification - EventType *string `json:"eventType,omitempty"` // MANAGEMENT_EVENT, USER_MANAGEMENT (user-defined custom names) - EventAction *string `json:"eventAction,omitempty"` // CREATE, READ, UPDATE, DELETE - Status string `json:"status"` // SUCCESS, FAILURE + EventType string `json:"eventType,omitempty"` // MANAGEMENT_EVENT, USER_MANAGEMENT + Action string `json:"action"` // CREATE, READ, UPDATE, DELETE (Required by server) + Status string `json:"status"` // SUCCESS, FAILURE - // Actor Information (unified approach) + // Actor Information ActorType string `json:"actorType"` // SERVICE, ADMIN, MEMBER, SYSTEM - ActorID string `json:"actorId"` // email, uuid, or service-name (required) + ActorID string `json:"actorId"` // email, uuid, or service-name - // Target Information (unified approach) + // Target Information TargetType string `json:"targetType"` // SERVICE, RESOURCE TargetID *string `json:"targetId,omitempty"` // resource_id or service_name - // Metadata (Payload without PII/sensitive data) - RequestMetadata json.RawMessage `json:"requestMetadata,omitempty"` // Request payload without PII/sensitive data - ResponseMetadata json.RawMessage `json:"responseMetadata,omitempty"` // Response or Error details - AdditionalMetadata json.RawMessage `json:"additionalMetadata,omitempty"` // Additional context-specific data + // Payload & Metadata + Message []byte `json:"message"` // Specific blob for NSW/NPQS + Metadata map[string]interface{} `json:"metadata,omitempty"` // Consolidated metadata + + // Security & Non-Repudiation + ShouldSign bool `json:"-"` // Internal flag to trigger signing + Signature string `json:"signature,omitempty"` + SignatureAlgorithm string `json:"signatureAlgorithm,omitempty"` + PublicKeyID string `json:"publicKeyId,omitempty"` } // Audit log status constants diff --git a/pkg/audit/security.go b/pkg/audit/security.go new file mode 100644 index 0000000..270edb0 --- /dev/null +++ b/pkg/audit/security.go @@ -0,0 +1,89 @@ +package audit + +import ( + "context" + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" +) + +// SignPayloadFunc is a strategy for signing payloads. +// It allows the client to provide its own signing logic (e.g. KMS, File-based, etc.) +// without exposing private keys to the audit library. +type SignPayloadFunc func(ctx context.Context, payload []byte) (signature string, err error) + +// CanonicalizeRequest serializes the AuditLogRequest deterministically. +// It ensures that signature fields are not included in the payload that gets signed. +func CanonicalizeRequest(event *AuditLogRequest) ([]byte, error) { + // Create a shallow copy and clear signature fields + eventCopy := *event + eventCopy.Signature = "" + eventCopy.SignatureAlgorithm = "" + eventCopy.PublicKeyID = "" + + // json.Marshal guarantees struct fields are serialized in declaration order. + // For maps (the new Metadata field), keys are sorted alphabetically. + return json.Marshal(&eventCopy) +} + +// SignPayload hashes and signs the canonical payload using the provided signer. +func SignPayload(payload []byte, signer crypto.Signer) (string, string, error) { + hash := sha256.Sum256(payload) + + switch signer.(type) { + case *rsa.PrivateKey: + sig, err := signer.Sign(rand.Reader, hash[:], crypto.SHA256) + if err != nil { + return "", "", fmt.Errorf("rsa signing failed: %w", err) + } + return base64.StdEncoding.EncodeToString(sig), "RS256", nil + case ed25519.PrivateKey: + // Ed25519 ignores the hash if passed, or uses the full message. + // Standard crypto.Signer.Sign for Ed25519 expects the original message, not a hash. + sig, err := signer.Sign(rand.Reader, payload, crypto.Hash(0)) + if err != nil { + return "", "", fmt.Errorf("ed25519 signing failed: %w", err) + } + return base64.StdEncoding.EncodeToString(sig), "EdDSA", nil + default: + return "", "", errors.New("unsupported signer type (only RSA and Ed25519 are supported)") + } +} + +// VerifyPayload verifies the signature of the payload using the provided public key +func VerifyPayload(payload []byte, signatureBase64 string, alg string, publicKey crypto.PublicKey) error { + sig, err := base64.StdEncoding.DecodeString(signatureBase64) + if err != nil { + return fmt.Errorf("invalid base64 signature: %w", err) + } + + hash := sha256.Sum256(payload) + + switch pk := publicKey.(type) { + case *rsa.PublicKey: + if alg != "RS256" { + return fmt.Errorf("algorithm mismatch: expected RS256 for RSA public key, got %s", alg) + } + err := rsa.VerifyPKCS1v15(pk, crypto.SHA256, hash[:], sig) + if err != nil { + return fmt.Errorf("rsa verification failed: %w", err) + } + return nil + case ed25519.PublicKey: + if alg != "EdDSA" { + return fmt.Errorf("algorithm mismatch: expected EdDSA for Ed25519 public key, got %s", alg) + } + if !ed25519.Verify(pk, payload, sig) { + return errors.New("ed25519 verification failed") + } + return nil + default: + return errors.New("unsupported public key type (only RSA and Ed25519 are supported)") + } +} diff --git a/pkg/audit/security_test.go b/pkg/audit/security_test.go new file mode 100644 index 0000000..92b27a4 --- /dev/null +++ b/pkg/audit/security_test.go @@ -0,0 +1,377 @@ +package audit + +import ( + "bytes" + "context" + "crypto" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "testing" + "time" +) + +func TestCanonicalizeRequest(t *testing.T) { + req := &AuditLogRequest{ + TraceID: func(s string) *string { return &s }("trace-123"), + Timestamp: "2023-01-01T00:00:00Z", + Status: StatusSuccess, + ActorType: "SERVICE", + ActorID: "actor-1", + TargetType: "SERVICE", + TargetID: func(s string) *string { return &s }("target-1"), + Metadata: map[string]interface{}{"key": "value"}, + Signature: "should-be-stripped", + SignatureAlgorithm: "RS256", + PublicKeyID: "key-1", + } + + b, err := CanonicalizeRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var parsed map[string]interface{} + if err := json.Unmarshal(b, &parsed); err != nil { + t.Fatalf("failed to decode canonical json: %v", err) + } + + if _, ok := parsed["signature"]; ok { + t.Errorf("expected signature to be stripped") + } + if _, ok := parsed["signature_algorithm"]; ok { + t.Errorf("expected signature_algorithm to be stripped") + } + if _, ok := parsed["public_key_id"]; ok { + t.Errorf("expected public_key_id to be stripped") + } +} + +func TestSignAndVerify_RSA(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate RSA key: %v", err) + } + + payload := []byte(`{"event":"test"}`) + + sigBase64, alg, err := SignPayload(payload, privateKey) + if err != nil { + t.Fatalf("signing failed: %v", err) + } + + if alg != "RS256" { + t.Errorf("expected algorithm RS256, got %s", alg) + } + + err = VerifyPayload(payload, sigBase64, alg, &privateKey.PublicKey) + if err != nil { + t.Errorf("verification failed: %v", err) + } +} + +func TestSignAndVerify_Ed25519(t *testing.T) { + // Generate an Ed25519 key pair + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate Ed25519 key: %v", err) + } + + payload := []byte(`{"event":"test"}`) + + // Using the private key which implements crypto.Signer + sigBase64, alg, err := SignPayload(payload, priv) + if err != nil { + t.Fatalf("signing failed: %v", err) + } + + if alg != "EdDSA" { + t.Errorf("expected algorithm EdDSA, got %s", alg) + } + + err = VerifyPayload(payload, sigBase64, alg, pub) + if err != nil { + t.Errorf("verification failed: %v", err) + } +} + +func TestVerifyPayload_Mismatch(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate RSA key: %v", err) + } + t.Log("Successfully generated first RSA key") + + payload := []byte(`{"event":"test"}`) + sigBase64, alg, err := SignPayload(payload, privateKey) + if err != nil { + t.Fatalf("signing failed: %v", err) + } + t.Logf("Successfully signed payload with %s", alg) + + privateKey2, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate second RSA key: %v", err) + } + t.Log("Successfully generated second RSA key") + + err = VerifyPayload(payload, sigBase64, alg, &privateKey2.PublicKey) + if err == nil { + t.Errorf("expected verification to fail with wrong key") + } else { + t.Logf("Verification failed as expected: %v", err) + } +} + +func TestVerifyPayload_Tampering(t *testing.T) { + privateKey, _ := rsa.GenerateKey(rand.Reader, 2048) + payload := []byte(`{"event":"test"}`) + sigBase64, alg, _ := SignPayload(payload, privateKey) + + // Tamper with payload + tamperedPayload := []byte(`{"event":"test!"}`) + err := VerifyPayload(tamperedPayload, sigBase64, alg, &privateKey.PublicKey) + if err == nil { + t.Errorf("expected verification to fail for tampered payload") + } +} + +func TestVerifyPayload_InvalidSignature(t *testing.T) { + publicKey, _, _ := ed25519.GenerateKey(rand.Reader) + payload := []byte(`{"event":"test"}`) + + // Completely invalid base64 + err := VerifyPayload(payload, "!!!", "EdDSA", publicKey) + if err == nil { + t.Errorf("expected error for invalid base64 signature") + } + + // Valid base64 but invalid signature data + err = VerifyPayload(payload, "bm90IGEgc2lnbmF0dXJl", "EdDSA", publicKey) + if err == nil { + t.Errorf("expected verification to fail for invalid signature data") + } +} + +func TestVerifyPayload_AlgorithmMismatch(t *testing.T) { + privateKey, _ := rsa.GenerateKey(rand.Reader, 2048) + payload := []byte(`{"event":"test"}`) + sigBase64, _, _ := SignPayload(payload, privateKey) + + // Try to verify RSA signature using EdDSA algorithm identifier + err := VerifyPayload(payload, sigBase64, "EdDSA", &privateKey.PublicKey) + if err == nil { + t.Errorf("expected error for algorithm mismatch (RSA key with EdDSA alg)") + } + + pub, _, _ := ed25519.GenerateKey(rand.Reader) + // Try to verify with Ed25519 key but RS256 algorithm + err = VerifyPayload(payload, sigBase64, "RS256", pub) + if err == nil { + t.Errorf("expected error for algorithm mismatch (Ed25519 key with RS256 alg)") + } +} + +func TestSignPayload_UnsupportedKey(t *testing.T) { + payload := []byte(`{"event":"test"}`) + // A mock signer that is not RSA or Ed25519 + type unsupportedSigner struct{ crypto.Signer } + _, _, err := SignPayload(payload, unsupportedSigner{}) + if err == nil { + t.Errorf("expected error for unsupported signer type") + } +} + +func TestCanonicalizeRequest_Consistency(t *testing.T) { + req1 := &AuditLogRequest{ActorID: "actor", Status: StatusSuccess} + req2 := &AuditLogRequest{ActorID: "actor", Status: StatusSuccess, Signature: "sig"} + + b1, _ := CanonicalizeRequest(req1) + b2, _ := CanonicalizeRequest(req2) + + if string(b1) != string(b2) { + t.Errorf("canonicalization should be invariant to signature fields") + } +} + +func TestClient_InterfaceImplementation(t *testing.T) { + // This test ensures Client implements Auditor interface + var _ Auditor = (*Client)(nil) +} + +func TestClient_SignAndVerify(t *testing.T) { + ctx := context.Background() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + signer := func(ctx context.Context, payload []byte) (string, error) { + sig, _, err := SignPayload(payload, priv) + return sig, err + } + + client := NewClient(Config{ + BaseURL: "http://localhost:8080", + Signer: signer, + PublicKeyID: "test-key-1", + SignatureAlgorithm: "RS256", + }) + defer client.Close(ctx) + + req := &AuditLogRequest{ + ActorID: "test-actor", + Status: StatusSuccess, + } + + // Sign + err = client.SignEvent(req) + if err != nil { + t.Fatalf("SignEvent failed: %v", err) + } + + if req.Signature == "" { + t.Error("expected signature to be populated") + } + if req.PublicKeyID != "test-key-1" { + t.Errorf("expected PublicKeyID test-key-1, got %s", req.PublicKeyID) + } + if req.SignatureAlgorithm != "RS256" { + t.Errorf("expected SignatureAlgorithm RS256, got %s", req.SignatureAlgorithm) + } + + // Verify + ok, err := client.VerifyIntegrity(req, &priv.PublicKey) + if err != nil { + t.Errorf("VerifyIntegrity failed: %v", err) + } + if !ok { + t.Error("expected integrity to be verified") + } + + // Tamper + req.ActorID = "hacker" + ok, err = client.VerifyIntegrity(req, &priv.PublicKey) + if err == nil && ok { + t.Error("expected verification to fail after tampering") + } +} +func TestCanonicalizeRequest_DeepSorting(t *testing.T) { + // Nested JSON in Metadata should be sorted canonically + req1 := &AuditLogRequest{ + Metadata: map[string]interface{}{"z": "last", "a": "first", "m": "middle"}, + } + req2 := &AuditLogRequest{ + Metadata: map[string]interface{}{"a": "first", "m": "middle", "z": "last"}, + } + + b1, err := CanonicalizeRequest(req1) + if err != nil { + t.Fatalf("unexpected error req1: %v", err) + } + b2, err := CanonicalizeRequest(req2) + if err != nil { + t.Fatalf("unexpected error req2: %v", err) + } + + if string(b1) != string(b2) { + t.Errorf("expected canonicalization to sort keys in Metadata") + } + + // Verify it contains the sorted keys + if !bytes.Contains(b1, []byte(`"a":"first","m":"middle","z":"last"`)) { + t.Errorf("expected sorted metadata in output, got %s", string(b1)) + } +} + +func TestClient_AlgorithmValidation(t *testing.T) { + signer := func(ctx context.Context, payload []byte) (string, error) { + return "sig", nil + } + + t.Run("Valid RS256", func(t *testing.T) { + client := NewClient(Config{ + BaseURL: "http://localhost:8080", + Signer: signer, + SignatureAlgorithm: "RS256", + }) + if !client.IsEnabled() { + t.Error("expected client to be enabled with RS256") + } + }) + + t.Run("Valid EdDSA", func(t *testing.T) { + client := NewClient(Config{ + BaseURL: "http://localhost:8080", + Signer: signer, + SignatureAlgorithm: "EdDSA", + }) + if !client.IsEnabled() { + t.Error("expected client to be enabled with EdDSA") + } + }) + + t.Run("Invalid Algorithm", func(t *testing.T) { + client := NewClient(Config{ + BaseURL: "http://localhost:8080", + Signer: signer, + SignatureAlgorithm: "MD5", // Unsupported/Insecure + }) + if client.IsEnabled() { + t.Error("expected client to be disabled with unsupported algorithm") + } + }) +} + +func TestClient_SignRetry(t *testing.T) { + ctx := context.Background() + attempts := 0 + signer := func(ctx context.Context, payload []byte) (string, error) { + attempts++ + if attempts < 3 { + return "", fmt.Errorf("transient error") + } + return "final-sig", nil + } + + client := NewClient(Config{ + BaseURL: "http://localhost:8080", + Signer: signer, + SignatureAlgorithm: "RS256", + WorkerCount: 1, + }) + + req := &AuditLogRequest{ + ActorID: "retry-actor", + ShouldSign: true, + } + + // Push to queue + client.LogEvent(ctx, req) + + // Wait for attempts to reach 3 (3 retries within the worker) + // The worker loop does: try 1, fail, try 2, fail, try 3, success. + // So attempts should be 3. + success := false + for i := 0; i < 20; i++ { + if attempts >= 3 { + success = true + break + } + time.Sleep(100 * time.Millisecond) + } + + if !success { + t.Errorf("expected 3 attempts, got %d", attempts) + } + + // Double check that it didn't do a 4th attempt + time.Sleep(200 * time.Millisecond) + if attempts > 3 { + t.Errorf("expected exactly 3 attempts, got %d", attempts) + } + + _ = client.Close(ctx) +}