diff --git a/.env.example b/.env.example index a1d34f7..07decb0 100644 --- a/.env.example +++ b/.env.example @@ -40,10 +40,15 @@ GIT_REPO_ROOT=./git-repos # If not set, all origins are allowed (development only) # CORS_ORIGINS=https://webpass.example.com,https://webpass.pages.dev -# Session duration in minutes (optional) -# Default: 5, Valid range: 5-480 (5 minutes to 8 hours) -# Invalid values will use default and print a warning -# SESSION_DURATION_MINUTES=5 +# Hard limit (optional) +# Default: 30, Valid range: 5-480 (5 minutes to 8 hours) +# Maximum time a session can last from login (regardless of activity) +# SESSION_HARDLIMIT_MINUTES=30 + +# Soft limit (optional) +# Default: 5, Valid range: 1-60 (1 minute to 1 hour) +# Detects browser close: session expires if no activity for this duration +# SESSION_SOFTLIMIT_MINUTES=5 # ============================================ # Cookie-based Authentication (httpOnly) diff --git a/AGENTS.md b/AGENTS.md index 34a0c9e..17a590d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -86,7 +86,8 @@ npm run typecheck | `PORT` | HTTP listen port (default: `8080`) | | `CORS_ORIGINS` | Comma-separated allowed origins | | `GIT_REPO_ROOT` | Git repos directory (default: `/data/git-repos`) | -| `SESSION_DURATION_MINUTES` | JWT session expiry time in minutes (default: 5, valid range: 5-480) | +| `SESSION_HARDLIMIT_MINUTES` | JWT hard limit (max session time) in minutes (default: 30, range: 5-480) | +| `SESSION_SOFTLIMIT_MINUTES` | JWT soft limit (browser close detection) in minutes (default: 5, range: 1-60) | ## Database diff --git a/IMPROVEMENT.md b/IMPROVEMENT.md new file mode 100644 index 0000000..8005c51 --- /dev/null +++ b/IMPROVEMENT.md @@ -0,0 +1,48 @@ +# WebPass Improvements + +## High Priority + +### 1. Backend: Structured Error Types +Replace string-based errors with typed error types for consistent API responses. +- [x] DONE - Created srv/errors.go with APIError type and error codes + +### 2. Backend: Graceful Shutdown +Add `os.Signal` handling to close DB connections and HTTP listener cleanly. +- [x] DONE - Added signal handling in cmd/srv/main.go + +### 3. Backend: Request Validation +Use validation library instead of manual JSON field checking. +- [ ] SKIPPED - Manual validation is already consistent; library adds complexity + +## Medium Priority + +### 5. Database: Add Indexes +Add indexes on `entries.path` column for query performance with large datasets. +- [x] DONE - Created migration 004-indexes-perf.sql with idx_entries_fingerprint_path and idx_entries_fingerprint + +### 6. Database: Foreign Keys +Add FK constraint on `git_config.fingerprint` referencing users table. +- [x] ALREADY DONE - FK exists in 002-git-sync.sql + +## Low Priority + +### 10. Testing Coverage +Add unit tests for backend utilities and frontend edge cases. +- [x] DONE - Added srv/errors_test.go for APIError type tests + +### 12. Frontend State: Preact Signals +Consider using Preact Signals for better state management. +- [ ] TODO - Current pub/sub pattern works well; needs discussion before changing + +### 13. Accessibility: ARIA Labels +Add ARIA labels to icon-only buttons and custom components. +- [x] DONE - Added aria-label to password/notes toggle buttons in EntryDetail and OTPDisplay + +## Next Step - Needs Discussion + +### A1. Session: Refresh Tokens (IMPLEMENTED) +Implement refresh token mechanism for sessions longer than 5 minutes. +- [x] DONE - Implemented with DB columns login_time and last_activity: + - Hard limit: 30 min (configurable via SESSION_HARDLIMIT_MINUTES) + - Soft limit: 5 min (configurable via SESSION_SOFTLIMIT_MINUTES) + - Auto-rotate: Updates last_activity on each API call \ No newline at end of file diff --git a/README.md b/README.md index 63b08ad..c0b6aa3 100644 --- a/README.md +++ b/README.md @@ -247,7 +247,8 @@ See [`.env.example`](.env.example) for all available options with detailed comme | `CORS_ORIGINS` | No | Comma-separated allowed origins | | `PORT` | No | HTTP listen port (default: `8080`) | | `GIT_REPO_ROOT`| No | Git repos directory (default: `/data/git-repos`) | -| `SESSION_DURATION_MINUTES` | No | JWT session expiry in minutes (default: 5, range: 5-480) | +| `SESSION_HARDLIMIT_MINUTES` | No | JWT hard limit (max session time) in minutes (default: 30, range: 5-480) | +| `SESSION_SOFTLIMIT_MINUTES` | No | JWT soft limit (browser close detection) in minutes (default: 5, range: 1-60) | | `DISABLE_FRONTEND` | No | Disable frontend (`1` or `true`) | | `BCRYPT_COST` | No | Password hashing cost factor (default: 12, range: 10-15) | diff --git a/cmd/srv/main.go b/cmd/srv/main.go index 63bba9d..699e3dd 100644 --- a/cmd/srv/main.go +++ b/cmd/srv/main.go @@ -1,12 +1,18 @@ package main import ( + "context" "crypto/rand" "flag" "fmt" + "log/slog" + "net/http" "os" + "os/signal" "runtime" "strconv" + "syscall" + "time" "srv.exe.dev/srv" ) @@ -70,19 +76,35 @@ func run() error { listenAddr = ":" + port } - // Session duration (default 5 minutes, range: 5-480) - sessionDurationMin := 5 // default - if durationStr := os.Getenv("SESSION_DURATION_MINUTES"); durationStr != "" { - if duration, err := strconv.Atoi(durationStr); err == nil { - if duration >= 5 && duration <= 480 { - sessionDurationMin = duration - } else if duration < 5 { - fmt.Printf("WARNING: SESSION_DURATION_MINUTES=%d too low, using minimum: 5\n", duration) + // Hard limit (default 30 minutes, range: 5-480) + hardLimitMin := 30 // default + if hardLimitStr := os.Getenv("SESSION_HARDLIMIT_MINUTES"); hardLimitStr != "" { + if hardLimit, err := strconv.Atoi(hardLimitStr); err == nil { + if hardLimit >= 5 && hardLimit <= 480 { + hardLimitMin = hardLimit + } else if hardLimit < 5 { + fmt.Printf("WARNING: SESSION_HARDLIMIT_MINUTES=%d too low, using minimum: 5\n", hardLimit) } else { - fmt.Printf("WARNING: SESSION_DURATION_MINUTES=%d too high, using maximum: 480\n", duration) + fmt.Printf("WARNING: SESSION_HARDLIMIT_MINUTES=%d too high, using maximum: 480\n", hardLimit) } } else { - fmt.Printf("WARNING: Invalid SESSION_DURATION_MINUTES=%s, using default: 5\n", durationStr) + fmt.Printf("WARNING: Invalid SESSION_HARDLIMIT_MINUTES=%s, using default: 30\n", hardLimitStr) + } + } + + // Soft limit (default 5 minutes, range: 1-60) + softLimitMin := 5 // default + if softLimitStr := os.Getenv("SESSION_SOFTLIMIT_MINUTES"); softLimitStr != "" { + if softLimit, err := strconv.Atoi(softLimitStr); err == nil { + if softLimit >= 1 && softLimit <= 60 { + softLimitMin = softLimit + } else if softLimit < 1 { + fmt.Printf("WARNING: SESSION_SOFTLIMIT_MINUTES=%d too low, using minimum: 1\n", softLimit) + } else { + fmt.Printf("WARNING: SESSION_SOFTLIMIT_MINUTES=%d too high, using maximum: 60\n", softLimit) + } + } else { + fmt.Printf("WARNING: Invalid SESSION_SOFTLIMIT_MINUTES=%s, using default: 5\n", softLimitStr) } } @@ -94,7 +116,8 @@ func run() error { fmt.Printf(" Disable Frontend:%s\n", disableFrontend) fmt.Printf(" Git Repo Root: %s\n", gitRepoRoot) fmt.Printf(" CORS Origins: %s\n", corsOrigins) - fmt.Printf(" Session Duration:%d minutes\n", sessionDurationMin) + fmt.Printf(" Hard Limit: %d minutes (SESSION_HARDLIMIT_MINUTES)\n", hardLimitMin) + fmt.Printf(" Soft Limit: %d minutes (SESSION_SOFTLIMIT_MINUTES)\n", softLimitMin) fmt.Println() jwtKey := make([]byte, 32) @@ -106,7 +129,7 @@ func run() error { } } - server, err := srv.New(dbPath, jwtKey, sessionDurationMin) + server, err := srv.New(dbPath, jwtKey, hardLimitMin, softLimitMin) if err != nil { return fmt.Errorf("create server: %w", err) } @@ -123,5 +146,42 @@ func run() error { } } - return server.Serve(listenAddr) + // Create HTTP server + httpServer := &http.Server{ + Addr: listenAddr, + Handler: server.Handler(), + } + + // Start server in goroutine + go func() { + slog.Info("starting server", "addr", listenAddr) + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + slog.Error("server error", "error", err) + } + }() + + // Wait for interrupt signal + quit := make(chan os.Signal, 1) + signal.Notify(quit, os.Interrupt, syscall.SIGTERM) + <-quit + + slog.Info("shutting down server...") + + // Give outstanding requests 30 seconds to complete + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := httpServer.Shutdown(ctx); err != nil { + slog.Error("server shutdown error", "error", err) + } + + // Close database connection + if err := server.CloseDB(); err != nil { + slog.Error("database close error", "error", err) + } else { + slog.Info("database connection closed") + } + + slog.Info("server stopped") + return nil } diff --git a/db/dbgen/db.go b/db/dbgen/db.go index d0d3db9..1d56ad8 100644 --- a/db/dbgen/db.go +++ b/db/dbgen/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package dbgen diff --git a/db/dbgen/git.sql.go b/db/dbgen/git.sql.go index b34db6d..4c25f31 100644 --- a/db/dbgen/git.sql.go +++ b/db/dbgen/git.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: git.sql package dbgen diff --git a/db/dbgen/models.go b/db/dbgen/models.go index 7972606..66f15ba 100644 --- a/db/dbgen/models.go +++ b/db/dbgen/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 package dbgen @@ -49,6 +49,9 @@ type User struct { TotpSecret *string `json:"totp_secret"` TotpEnabled *int64 `json:"totp_enabled"` Created *time.Time `json:"created"` + LoginTime *time.Time `json:"login_time"` + LastActivity *time.Time `json:"last_activity"` + GpgID *string `json:"gpg_id"` } type Visitor struct { diff --git a/db/dbgen/visitors.sql.go b/db/dbgen/visitors.sql.go index 0bfd26e..f08269f 100644 --- a/db/dbgen/visitors.sql.go +++ b/db/dbgen/visitors.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: visitors.sql package dbgen diff --git a/db/dbgen/webpass.sql.go b/db/dbgen/webpass.sql.go index cd50f5a..bbd152b 100644 --- a/db/dbgen/webpass.sql.go +++ b/db/dbgen/webpass.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: webpass.sql package dbgen @@ -11,18 +11,24 @@ import ( ) const createUser = `-- name: CreateUser :exec -INSERT INTO users (fingerprint, password_hash, public_key) -VALUES (?, ?, ?) +INSERT INTO users (fingerprint, password_hash, public_key, gpg_id) +VALUES (?, ?, ?, ?) ` type CreateUserParams struct { - Fingerprint string `json:"fingerprint"` - PasswordHash string `json:"password_hash"` - PublicKey string `json:"public_key"` + Fingerprint string `json:"fingerprint"` + PasswordHash string `json:"password_hash"` + PublicKey string `json:"public_key"` + GpgID *string `json:"gpg_id"` } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) error { - _, err := q.db.ExecContext(ctx, createUser, arg.Fingerprint, arg.PasswordHash, arg.PublicKey) + _, err := q.db.ExecContext(ctx, createUser, + arg.Fingerprint, + arg.PasswordHash, + arg.PublicKey, + arg.GpgID, + ) return err } @@ -74,8 +80,25 @@ func (q *Queries) GetEntry(ctx context.Context, arg GetEntryParams) (Entry, erro return i, err } +const getSessionInfo = `-- name: GetSessionInfo :one +SELECT login_time, last_activity FROM users +WHERE fingerprint = ? +` + +type GetSessionInfoRow struct { + LoginTime *time.Time `json:"login_time"` + LastActivity *time.Time `json:"last_activity"` +} + +func (q *Queries) GetSessionInfo(ctx context.Context, fingerprint string) (GetSessionInfoRow, error) { + row := q.db.QueryRowContext(ctx, getSessionInfo, fingerprint) + var i GetSessionInfoRow + err := row.Scan(&i.LoginTime, &i.LastActivity) + return i, err +} + const getUser = `-- name: GetUser :one -SELECT fingerprint, password_hash, public_key, totp_secret, totp_enabled, created FROM users WHERE fingerprint = ? +SELECT fingerprint, password_hash, public_key, totp_secret, totp_enabled, created, login_time, last_activity, gpg_id FROM users WHERE fingerprint = ? ` func (q *Queries) GetUser(ctx context.Context, fingerprint string) (User, error) { @@ -88,6 +111,9 @@ func (q *Queries) GetUser(ctx context.Context, fingerprint string) (User, error) &i.TotpSecret, &i.TotpEnabled, &i.Created, + &i.LoginTime, + &i.LastActivity, + &i.GpgID, ) return i, err } @@ -189,6 +215,28 @@ func (q *Queries) MoveEntry(ctx context.Context, arg MoveEntryParams) error { return err } +const updateLastActivity = `-- name: UpdateLastActivity :exec +UPDATE users +SET last_activity = CURRENT_TIMESTAMP +WHERE fingerprint = ? +` + +func (q *Queries) UpdateLastActivity(ctx context.Context, fingerprint string) error { + _, err := q.db.ExecContext(ctx, updateLastActivity, fingerprint) + return err +} + +const updateLoginTime = `-- name: UpdateLoginTime :exec +UPDATE users +SET login_time = CURRENT_TIMESTAMP, last_activity = CURRENT_TIMESTAMP +WHERE fingerprint = ? +` + +func (q *Queries) UpdateLoginTime(ctx context.Context, fingerprint string) error { + _, err := q.db.ExecContext(ctx, updateLoginTime, fingerprint) + return err +} + const updatePassword = `-- name: UpdatePassword :exec UPDATE users SET password_hash = ? @@ -205,6 +253,22 @@ func (q *Queries) UpdatePassword(ctx context.Context, arg UpdatePasswordParams) return err } +const updateUserGpgID = `-- name: UpdateUserGpgID :exec +UPDATE users +SET gpg_id = ? +WHERE fingerprint = ? +` + +type UpdateUserGpgIDParams struct { + GpgID *string `json:"gpg_id"` + Fingerprint string `json:"fingerprint"` +} + +func (q *Queries) UpdateUserGpgID(ctx context.Context, arg UpdateUserGpgIDParams) error { + _, err := q.db.ExecContext(ctx, updateUserGpgID, arg.GpgID, arg.Fingerprint) + return err +} + const updateUserTOTP = `-- name: UpdateUserTOTP :exec UPDATE users SET totp_secret = ?, totp_enabled = ? diff --git a/db/migrations/004-indexes-perf.sql b/db/migrations/004-indexes-perf.sql new file mode 100644 index 0000000..f6ee83d --- /dev/null +++ b/db/migrations/004-indexes-perf.sql @@ -0,0 +1,12 @@ +-- Improve query performance with indexes +-- + +-- Index for fast entry lookups by fingerprint and path (used in GetEntry, DeleteEntry, MoveEntry) +CREATE INDEX IF NOT EXISTS idx_entries_fingerprint_path ON entries(fingerprint, path); + +-- Index for fast entry listing by fingerprint (used in ListEntries, DeleteAccount) +CREATE INDEX IF NOT EXISTS idx_entries_fingerprint ON entries(fingerprint); + +-- Record execution of this migration +INSERT OR IGNORE INTO migrations (migration_number, migration_name) +VALUES (004, '004-indexes-perf'); \ No newline at end of file diff --git a/db/migrations/005-session-tracking.sql b/db/migrations/005-session-tracking.sql new file mode 100644 index 0000000..478914e --- /dev/null +++ b/db/migrations/005-session-tracking.sql @@ -0,0 +1,16 @@ +-- Session tracking for hard/soft limits +-- +-- Add columns for session management: +-- - login_time: when user first logged in (for hard limit check) +-- - last_activity: timestamp of last API call (for soft limit check) + +ALTER TABLE users ADD COLUMN login_time DATETIME; +ALTER TABLE users ADD COLUMN last_activity DATETIME; + +-- Index for fast queries +CREATE INDEX IF NOT EXISTS idx_users_login_time ON users(login_time); +CREATE INDEX IF NOT EXISTS idx_users_last_activity ON users(last_activity); + +-- Record execution of this migration +INSERT OR IGNORE INTO migrations (migration_number, migration_name) +VALUES (005, '005-session-tracking'); \ No newline at end of file diff --git a/db/migrations/006-users-gpg-id.sql b/db/migrations/006-users-gpg-id.sql new file mode 100644 index 0000000..d45ff65 --- /dev/null +++ b/db/migrations/006-users-gpg-id.sql @@ -0,0 +1,8 @@ +-- Store .gpg-id from password-store repos in users table (single source of truth) +ALTER TABLE users ADD COLUMN gpg_id TEXT DEFAULT ''; + +-- Backfill existing accounts: set gpg_id to fingerprint so they work with pass CLI +UPDATE users SET gpg_id = fingerprint WHERE gpg_id = '' OR gpg_id IS NULL; + +INSERT OR IGNORE INTO migrations (migration_number, migration_name) +VALUES (006, '006-users-gpg-id'); diff --git a/db/queries/webpass.sql b/db/queries/webpass.sql index 85bde69..ab03b77 100644 --- a/db/queries/webpass.sql +++ b/db/queries/webpass.sql @@ -1,6 +1,6 @@ -- name: CreateUser :exec -INSERT INTO users (fingerprint, password_hash, public_key) -VALUES (?, ?, ?); +INSERT INTO users (fingerprint, password_hash, public_key, gpg_id) +VALUES (?, ?, ?, ?); -- name: GetUser :one SELECT * FROM users WHERE fingerprint = ?; @@ -34,6 +34,11 @@ WHERE fingerprint = ? AND path = ?; -- name: DeleteUser :exec DELETE FROM users WHERE fingerprint = ?; +-- name: UpdateUserGpgID :exec +UPDATE users +SET gpg_id = ? +WHERE fingerprint = ?; + -- name: UpdateUserTOTP :exec UPDATE users SET totp_secret = ?, totp_enabled = ? @@ -48,3 +53,17 @@ WHERE fingerprint = ?; SELECT * FROM entries WHERE fingerprint = ? ORDER BY path; + +-- name: UpdateLoginTime :exec +UPDATE users +SET login_time = CURRENT_TIMESTAMP, last_activity = CURRENT_TIMESTAMP +WHERE fingerprint = ?; + +-- name: UpdateLastActivity :exec +UPDATE users +SET last_activity = CURRENT_TIMESTAMP +WHERE fingerprint = ?; + +-- name: GetSessionInfo :one +SELECT login_time, last_activity FROM users +WHERE fingerprint = ?; diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 257770b..cadfbdf 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -22,6 +22,7 @@ "@types/node": "^22.10.2", "@types/qrcode": "^1.5.5", "@types/tar-stream": "^3.1.3", + "fake-indexeddb": "^6.2.5", "jsdom": "^25.0.1", "tar-stream": "^3.1.7", "typescript": "^5.7.2", @@ -2241,6 +2242,16 @@ "node": ">=12.0.0" } }, + "node_modules/fake-indexeddb": { + "version": "6.2.5", + "resolved": "https://registry.npmjs.org/fake-indexeddb/-/fake-indexeddb-6.2.5.tgz", + "integrity": "sha512-CGnyrvbhPlWYMngksqrSSUT1BAVP49dZocrHuK0SvtR0D5TMs5wP0o3j7jexDJW01KSadjBp1M/71o/KR3nD1w==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=18" + } + }, "node_modules/fast-fifo": { "version": "1.3.2", "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz", diff --git a/frontend/package.json b/frontend/package.json index 1d0ef88..fa8cec6 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -33,6 +33,7 @@ "@types/node": "^22.10.2", "@types/qrcode": "^1.5.5", "@types/tar-stream": "^3.1.3", + "fake-indexeddb": "^6.2.5", "jsdom": "^25.0.1", "tar-stream": "^3.1.7", "typescript": "^5.7.2", diff --git a/frontend/src/components/EntryDetail.tsx b/frontend/src/components/EntryDetail.tsx index 35a4c04..e5c6df8 100644 --- a/frontend/src/components/EntryDetail.tsx +++ b/frontend/src/components/EntryDetail.tsx @@ -279,6 +279,7 @@ export function EntryDetail({ path, onEdit, onDelete }: Props) { class="btn btn-ghost btn-icon btn-sm" onClick={handlePasswordToggle} title={showPassword ? 'Hide' : 'Show'} + aria-label={showPassword ? 'Hide password' : 'Show password'} style={{ minWidth: 'auto', padding: '4px 8px' }} data-testid="password-toggle-btn" > @@ -309,6 +310,7 @@ export function EntryDetail({ path, onEdit, onDelete }: Props) { class="btn btn-ghost btn-icon btn-sm" onClick={handleNotesToggle} title={showNotes ? 'Hide' : 'Show'} + aria-label={showNotes ? 'Hide notes' : 'Show notes'} style={{ minWidth: 'auto', padding: '4px 8px' }} data-testid="notes-toggle-btn" > diff --git a/frontend/src/components/OTPDisplay.tsx b/frontend/src/components/OTPDisplay.tsx index d0de977..27e3d67 100644 --- a/frontend/src/components/OTPDisplay.tsx +++ b/frontend/src/components/OTPDisplay.tsx @@ -121,6 +121,7 @@ export function OTPDisplay({ content }: Props) { class="otp-copy" onClick={() => setShowOTP(!showOTP)} title={showOTP ? 'Hide' : 'Show'} + aria-label={showOTP ? 'Hide OTP code' : 'Show OTP code'} > {showOTP ? : } diff --git a/frontend/src/lib/api.test.ts b/frontend/src/lib/api.test.ts index 8f493b0..a51bf37 100644 --- a/frontend/src/lib/api.test.ts +++ b/frontend/src/lib/api.test.ts @@ -2,6 +2,7 @@ * Tests for API client functions */ +import { describe, it, expect, beforeEach, vi } from 'vitest'; import { ApiClient } from './api'; describe('ApiClient', () => { @@ -40,6 +41,10 @@ describe('ApiClient', () => { }); describe('headers method', () => { + beforeEach(() => { + vi.stubGlobal('document', { cookie: '' }); + }); + it('should return content-type for JSON requests', () => { const headers = (client as any).headers(false); expect(headers['Content-Type']).toBe('application/json'); diff --git a/frontend/src/lib/crypto.test.ts b/frontend/src/lib/crypto.test.ts new file mode 100644 index 0000000..82fc9dc --- /dev/null +++ b/frontend/src/lib/crypto.test.ts @@ -0,0 +1,293 @@ +/** + * Unit tests for crypto utilities + */ + +import { describe, it, expect, beforeAll } from 'vitest'; +import { + generateKeyPair, + getFingerprint, + decryptPrivateKey, + importPrivateKey, + clearSensitiveData, + encryptText, + decryptMessage, + encryptBinary, + decryptBinary, + WrongKeyError, + deriveKey, + generateSalt, + aesGcmEncrypt, + aesGcmDecrypt, + arrayBufferToBase64, + base64ToArrayBuffer, + encryptPAT, + decryptPAT, +} from './crypto'; + +describe('generateKeyPair', () => { + it('generates a keypair with fingerprint', async () => { + const result = await generateKeyPair('test-passphrase-123'); + + expect(result.publicKey).toContain('-----BEGIN PGP PUBLIC KEY BLOCK-----'); + expect(result.privateKey).toContain('-----BEGIN PGP PRIVATE KEY BLOCK-----'); + expect(result.fingerprint).toMatch(/^[0-9A-F]{40}$/); + }, 30000); + + it('generates deterministic fingerprints for same inputs', async () => { + // Different passphrases should still produce valid keys + const result1 = await generateKeyPair('pass-one'); + const result2 = await generateKeyPair('pass-two'); + + expect(result1.fingerprint).not.toBe(result2.fingerprint); + expect(result1.publicKey).not.toBe(result2.publicKey); + }, 30000); +}); + +describe('getFingerprint', () => { + let publicKey: string; + let fingerprint: string; + + beforeAll(async () => { + const result = await generateKeyPair('test-pass'); + publicKey = result.publicKey; + fingerprint = result.fingerprint; + }, 30000); + + it('extracts fingerprint from public key', async () => { + const fp = await getFingerprint(publicKey); + expect(fp).toBe(fingerprint); + expect(fp).toMatch(/^[0-9A-F]{40}$/); + }); + + it('throws on invalid key', async () => { + await expect(getFingerprint('not-a-key')).rejects.toThrow(); + }); +}); + +describe('decryptPrivateKey', () => { + let privateKey: string; + let passphrase: string; + + beforeAll(async () => { + passphrase = 'correct-passphrase'; + const result = await generateKeyPair(passphrase); + privateKey = result.privateKey; + }, 30000); + + it('decrypts with correct passphrase', async () => { + const decrypted = await decryptPrivateKey(privateKey, passphrase); + expect(decrypted.isDecrypted()).toBe(true); + }); + + it('throws with wrong passphrase', async () => { + await expect(decryptPrivateKey(privateKey, 'wrong-passphrase')).rejects.toThrow(); + }); + + it('throws on invalid key format', async () => { + await expect(decryptPrivateKey(123 as any, 'pass')).rejects.toThrow('Invalid key format'); + }); +}); + +describe('importPrivateKey', () => { + let privateKey: string; + let passphrase: string; + + beforeAll(async () => { + passphrase = 'import-test-pass'; + const result = await generateKeyPair(passphrase); + privateKey = result.privateKey; + }, 30000); + + it('imports and decrypts armored key', async () => { + const decrypted = await importPrivateKey(privateKey, passphrase); + expect(decrypted.isDecrypted()).toBe(true); + }); + + it('throws with wrong passphrase', async () => { + await expect(importPrivateKey(privateKey, 'wrong')).rejects.toThrow(); + }); +}); + +describe('clearSensitiveData', () => { + it('nullifies variables without throwing', () => { + let a: any = 'secret'; + let b: any = 123; + clearSensitiveData(a, b); + // Note: primitive arguments passed by value can't be mutated, + // but function should not throw + expect(() => clearSensitiveData('x', 1, null)).not.toThrow(); + }); +}); + +describe('encryptText / decryptMessage round-trip', () => { + let keypair: { publicKey: string; privateKey: string; fingerprint: string }; + + beforeAll(async () => { + keypair = await generateKeyPair('roundtrip-pass'); + }, 30000); + + it('encrypts and decrypts text', async () => { + const plaintext = 'Hello, WebPass!'; + const encrypted = await encryptText(plaintext, keypair.publicKey); + + expect(encrypted).toContain('-----BEGIN PGP MESSAGE-----'); + + const decryptedKey = await decryptPrivateKey(keypair.privateKey, 'roundtrip-pass'); + const decrypted = await decryptMessage(encrypted, decryptedKey); + + expect(decrypted).toBe(plaintext); + }); + + it('encrypts and decrypts empty string', async () => { + const encrypted = await encryptText('', keypair.publicKey); + const decryptedKey = await decryptPrivateKey(keypair.privateKey, 'roundtrip-pass'); + const decrypted = await decryptMessage(encrypted, decryptedKey); + expect(decrypted).toBe(''); + }); + + it('encrypts and decrypts unicode text', async () => { + const plaintext = 'Unicode: πŸ” Γ± δΈ­ζ–‡ πŸš€'; + const encrypted = await encryptText(plaintext, keypair.publicKey); + const decryptedKey = await decryptPrivateKey(keypair.privateKey, 'roundtrip-pass'); + const decrypted = await decryptMessage(encrypted, decryptedKey); + expect(decrypted).toBe(plaintext); + }); +}); + +describe('encryptBinary / decryptBinary round-trip', () => { + let keypair: { publicKey: string; privateKey: string; fingerprint: string }; + + beforeAll(async () => { + keypair = await generateKeyPair('binary-pass'); + }, 30000); + + it('encrypts to binary and decrypts', async () => { + const plaintext = 'Binary encryption test'; + const encrypted = await encryptBinary(plaintext, keypair.publicKey); + + expect(encrypted).toBeInstanceOf(Uint8Array); + expect(encrypted.length).toBeGreaterThan(0); + + const decryptedKey = await decryptPrivateKey(keypair.privateKey, 'binary-pass'); + const decrypted = await decryptBinary(encrypted, decryptedKey); + + expect(decrypted).toBe(plaintext); + }); + + it('throws with mismatched key', async () => { + const plaintext = 'secret data'; + const encrypted = await encryptBinary(plaintext, keypair.publicKey); + + // Generate a different keypair + const otherKeypair = await generateKeyPair('other-pass'); + const wrongKey = await decryptPrivateKey(otherKeypair.privateKey, 'other-pass'); + + // Completely different keypair throws "No decryption key packets found" + // rather than "Session key decryption failed" + await expect(decryptBinary(encrypted, wrongKey)).rejects.toThrow(); + }, 30000); + + it('throws WrongKeyError on session key decryption failure', () => { + // Directly test the error class behavior + const err = new WrongKeyError('session key failed'); + expect(err.name).toBe('WrongKeyError'); + expect(err.message).toBe('session key failed'); + }); +}); + +describe('WrongKeyError', () => { + it('has correct name and message', () => { + const err = new WrongKeyError('test message'); + expect(err.name).toBe('WrongKeyError'); + expect(err.message).toBe('test message'); + expect(err).toBeInstanceOf(Error); + }); +}); + +describe('PBKDF2 + AES-GCM helpers', () => { + it('generateSalt produces 16 random bytes', () => { + const salt = generateSalt(); + expect(salt).toBeInstanceOf(Uint8Array); + expect(salt.length).toBe(16); + + const salt2 = generateSalt(); + expect(salt).not.toEqual(salt2); // very likely different + }); + + it('deriveKey produces a CryptoKey', async () => { + const salt = generateSalt(); + const key = await deriveKey('my-password', salt); + expect(key).toBeDefined(); + expect(key.type).toBe('secret'); + expect(key.algorithm.name).toBe('AES-GCM'); + }); + + it('aesGcmEncrypt produces ciphertext and iv', async () => { + const salt = generateSalt(); + const key = await deriveKey('password', salt); + const result = await aesGcmEncrypt('hello world', key); + + expect(result.ciphertext).toBeInstanceOf(Uint8Array); + expect(result.iv).toBeInstanceOf(Uint8Array); + expect(result.iv.length).toBe(12); + expect(result.ciphertext.length).toBeGreaterThan(0); + }); + + it('aesGcmDecrypt recovers plaintext', async () => { + const salt = generateSalt(); + const key = await deriveKey('password', salt); + const encrypted = await aesGcmEncrypt('secret message', key); + const decrypted = await aesGcmDecrypt(encrypted.ciphertext, encrypted.iv, key); + expect(decrypted).toBe('secret message'); + }); + + it('aesGcmDecrypt fails with wrong key', async () => { + const salt = generateSalt(); + const key = await deriveKey('correct-password', salt); + const encrypted = await aesGcmEncrypt('secret', key); + + const wrongSalt = generateSalt(); + const wrongKey = await deriveKey('wrong-password', wrongSalt); + + await expect(aesGcmDecrypt(encrypted.ciphertext, encrypted.iv, wrongKey)).rejects.toThrow(); + }); +}); + +describe('Base64 helpers', () => { + it('arrayBufferToBase64 / base64ToArrayBuffer round-trip', () => { + const data = new Uint8Array([0, 1, 2, 255, 128, 64]); + const b64 = arrayBufferToBase64(data); + expect(typeof b64).toBe('string'); + expect(b64.length).toBeGreaterThan(0); + + const recovered = base64ToArrayBuffer(b64); + expect(recovered).toEqual(data); + }); + + it('handles empty array', () => { + const data = new Uint8Array(0); + const b64 = arrayBufferToBase64(data); + expect(b64).toBe(''); + expect(base64ToArrayBuffer(b64)).toEqual(data); + }); +}); + +describe('encryptPAT / decryptPAT', () => { + let keypair: { publicKey: string; privateKey: string; fingerprint: string }; + + beforeAll(async () => { + keypair = await generateKeyPair('pat-pass'); + }, 30000); + + it('encrypts and decrypts PAT', async () => { + const pat = 'ghp_1234567890abcdef'; + const encrypted = await encryptPAT(pat, keypair.publicKey); + + expect(encrypted).toContain('-----BEGIN PGP MESSAGE-----'); + + const decryptedKey = await decryptPrivateKey(keypair.privateKey, 'pat-pass'); + const decrypted = await decryptPAT(encrypted, decryptedKey); + + expect(decrypted).toBe(pat); + }); +}); diff --git a/frontend/src/lib/storage.test.ts b/frontend/src/lib/storage.test.ts new file mode 100644 index 0000000..2a354de --- /dev/null +++ b/frontend/src/lib/storage.test.ts @@ -0,0 +1,247 @@ +/** + * Unit tests for storage utilities + */ + +import 'fake-indexeddb/auto'; +import { describe, it, expect, beforeEach } from 'vitest'; +import type { Account } from '../types'; +import { + saveAccount, + getAccount, + listAccounts, + deleteAccount, + saveGitToken, + getGitToken, + deleteGitToken, + aesEncrypt, + aesDecrypt, + savePublicKey, + getPublicKey, + savePrivateKey, + getDecryptedPrivateKey, +} from './storage'; + +describe('Account CRUD', () => { + const mockAccount: Account = { + fingerprint: 'abc123', + privateKey: '-----BEGIN PGP PRIVATE KEY BLOCK-----\ntest\n-----END PGP PRIVATE KEY BLOCK-----', + publicKey: '-----BEGIN PGP PUBLIC KEY BLOCK-----\ntest\n-----END PGP PUBLIC KEY BLOCK-----', + apiUrlEncrypted: 'encrypted-data', + apiUrlSalt: 'salt123', + apiUrlIv: 'iv456', + label: 'Test Account', + }; + + beforeEach(async () => { + // Clean up any existing accounts + const accounts = await listAccounts(); + for (const acc of accounts) { + await deleteAccount(acc.fingerprint); + } + }); + + it('saves and retrieves an account', async () => { + await saveAccount(mockAccount); + const retrieved = await getAccount('abc123'); + expect(retrieved).toEqual(mockAccount); + }); + + it('returns null for non-existent account', async () => { + const retrieved = await getAccount('nonexistent'); + expect(retrieved).toBeNull(); + }); + + it('lists all accounts', async () => { + await saveAccount(mockAccount); + await saveAccount({ ...mockAccount, fingerprint: 'def456', label: 'Second' }); + + const accounts = await listAccounts(); + expect(accounts).toHaveLength(2); + expect(accounts.map((a) => a.fingerprint)).toContain('abc123'); + expect(accounts.map((a) => a.fingerprint)).toContain('def456'); + }); + + it('returns empty array when no accounts', async () => { + const accounts = await listAccounts(); + expect(accounts).toEqual([]); + }); + + it('deletes an account', async () => { + await saveAccount(mockAccount); + await deleteAccount('abc123'); + const retrieved = await getAccount('abc123'); + expect(retrieved).toBeNull(); + }); + + it('updates an existing account', async () => { + await saveAccount(mockAccount); + await saveAccount({ ...mockAccount, label: 'Updated' }); + const retrieved = await getAccount('abc123'); + expect(retrieved?.label).toBe('Updated'); + }); +}); + +describe('Git Token Storage', () => { + beforeEach(async () => { + const token = await getGitToken('test-fp'); + if (token) await deleteGitToken('test-fp'); + }); + + it('saves and retrieves git token', async () => { + await saveGitToken('test-fp', 'encrypted-token', 'salt123', 'iv456'); + const retrieved = await getGitToken('test-fp'); + expect(retrieved).toEqual({ + fingerprint: 'test-fp', + encryptedToken: 'encrypted-token', + salt: 'salt123', + iv: 'iv456', + }); + }); + + it('returns null for non-existent token', async () => { + const retrieved = await getGitToken('nonexistent'); + expect(retrieved).toBeNull(); + }); + + it('deletes git token', async () => { + await saveGitToken('test-fp', 'token', 'salt', 'iv'); + await deleteGitToken('test-fp'); + const retrieved = await getGitToken('test-fp'); + expect(retrieved).toBeNull(); + }); +}); + +describe('AES Encrypt/Decrypt', () => { + it('encrypts and decrypts data', async () => { + const password = 'my-secret-password'; + const plaintext = 'Hello, WebPass!'; + + const encrypted = await aesEncrypt(plaintext, password); + expect(encrypted.encrypted).toBeDefined(); + expect(encrypted.salt).toBeDefined(); + expect(encrypted.iv).toBeDefined(); + + const decrypted = await aesDecrypt(encrypted.encrypted, password, encrypted.salt, encrypted.iv); + expect(decrypted).toBe(plaintext); + }); + + it('produces different ciphertexts for same plaintext', async () => { + const password = 'password123'; + const plaintext = 'same text'; + + const encrypted1 = await aesEncrypt(plaintext, password); + const encrypted2 = await aesEncrypt(plaintext, password); + + expect(encrypted1.encrypted).not.toBe(encrypted2.encrypted); + expect(encrypted1.salt).not.toBe(encrypted2.salt); + expect(encrypted1.iv).not.toBe(encrypted2.iv); + }); + + it('decrypts with correct password', async () => { + const password = 'correct-password'; + const plaintext = 'secret data'; + + const encrypted = await aesEncrypt(plaintext, password); + const decrypted = await aesDecrypt(encrypted.encrypted, password, encrypted.salt, encrypted.iv); + expect(decrypted).toBe(plaintext); + }); + + it('fails with wrong password', async () => { + const password = 'correct-password'; + const plaintext = 'secret data'; + + const encrypted = await aesEncrypt(plaintext, password); + await expect( + aesDecrypt(encrypted.encrypted, 'wrong-password', encrypted.salt, encrypted.iv) + ).rejects.toThrow(); + }); + + it('handles unicode plaintext', async () => { + const password = 'unicode-pass'; + const plaintext = 'Unicode: πŸ” Γ± δΈ­ζ–‡ πŸš€'; + + const encrypted = await aesEncrypt(plaintext, password); + const decrypted = await aesDecrypt(encrypted.encrypted, password, encrypted.salt, encrypted.iv); + expect(decrypted).toBe(plaintext); + }); + + it('handles empty string', async () => { + const password = 'empty-pass'; + const plaintext = ''; + + const encrypted = await aesEncrypt(plaintext, password); + const decrypted = await aesDecrypt(encrypted.encrypted, password, encrypted.salt, encrypted.iv); + expect(decrypted).toBe(plaintext); + }); +}); + +describe('PGP Key Storage', () => { + beforeEach(async () => { + const accounts = await listAccounts(); + for (const acc of accounts) { + await deleteAccount(acc.fingerprint); + } + }); + + it('saves and retrieves public key', async () => { + const fp = 'key-fp-123'; + const pubKey = '-----BEGIN PGP PUBLIC KEY BLOCK-----\ntest-key\n-----END PGP PUBLIC KEY BLOCK-----'; + + await savePublicKey(fp, pubKey); + const retrieved = await getPublicKey(fp); + expect(retrieved).toBe(pubKey); + }); + + it('falls back to account store for public key', async () => { + const fp = 'key-fp-456'; + const pubKey = 'fallback-public-key'; + const account: Account = { + fingerprint: fp, + privateKey: 'priv', + publicKey: pubKey, + apiUrlEncrypted: 'enc', + apiUrlSalt: 'salt', + apiUrlIv: 'iv', + }; + + await saveAccount(account); + const retrieved = await getPublicKey(fp); + expect(retrieved).toBe(pubKey); + }); + + it('returns null for missing public key', async () => { + const retrieved = await getPublicKey('nonexistent'); + expect(retrieved).toBeNull(); + }); + + it('saves and retrieves private key', async () => { + const fp = 'key-fp-789'; + const privKey = '-----BEGIN PGP PRIVATE KEY BLOCK-----\nprivate\n-----END PGP PRIVATE KEY BLOCK-----'; + + await savePrivateKey(fp, privKey); + const retrieved = await getDecryptedPrivateKey(fp, 'any-passphrase'); + expect(retrieved).toBe(privKey); + }); + + it('falls back to account store for private key', async () => { + const fp = 'key-fp-abc'; + const privKey = 'fallback-private-key'; + const account: Account = { + fingerprint: fp, + privateKey: privKey, + publicKey: 'pub', + apiUrlEncrypted: 'enc', + apiUrlSalt: 'salt', + apiUrlIv: 'iv', + }; + + await saveAccount(account); + const retrieved = await getDecryptedPrivateKey(fp, 'any-passphrase'); + expect(retrieved).toBe(privKey); + }); + + it('returns null for missing private key', async () => { + const retrieved = await getDecryptedPrivateKey('nonexistent', 'pass'); + expect(retrieved).toBeNull(); + }); +}); diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts index 1af399c..04a9cba 100644 --- a/frontend/vitest.config.ts +++ b/frontend/vitest.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ plugins: [preact()], test: { globals: true, - environment: 'jsdom', + environment: 'node', include: ['src/**/*.test.ts'], exclude: ['**/node_modules/**', '**/dist/**', '**/*.crypto.test.ts'], }, diff --git a/srv/errors.go b/srv/errors.go new file mode 100644 index 0000000..4d3049f --- /dev/null +++ b/srv/errors.go @@ -0,0 +1,67 @@ +package srv + +import ( + "errors" + "fmt" + "net/http" +) + +type ErrorCode string + +const ( + ErrCodeBadRequest ErrorCode = "bad_request" + ErrCodeUnauthorized ErrorCode = "unauthorized" + ErrCodeForbidden ErrorCode = "forbidden" + ErrCodeNotFound ErrorCode = "not_found" + ErrCodeConflict ErrorCode = "conflict" + ErrCodeInternal ErrorCode = "internal" + ErrCodeTooMany ErrorCode = "too_many_requests" +) + +type APIError struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` +} + +func (e APIError) Error() string { + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +func (e APIError) Unwrap() error { + return errors.New(e.Message) +} + +func (e APIError) StatusCode() int { + switch e.Code { + case ErrCodeBadRequest: + return http.StatusBadRequest + case ErrCodeUnauthorized: + return http.StatusUnauthorized + case ErrCodeForbidden: + return http.StatusForbidden + case ErrCodeNotFound: + return http.StatusNotFound + case ErrCodeConflict: + return http.StatusConflict + case ErrCodeInternal: + return http.StatusInternalServerError + case ErrCodeTooMany: + return http.StatusTooManyRequests + default: + return http.StatusInternalServerError + } +} + +var ( + ErrBadRequest = APIError{Code: ErrCodeBadRequest, Message: "bad request"} + ErrUnauthorized = APIError{Code: ErrCodeUnauthorized, Message: "unauthorized"} + ErrForbidden = APIError{Code: ErrCodeForbidden, Message: "forbidden"} + ErrNotFound = APIError{Code: ErrCodeNotFound, Message: "not found"} + ErrConflict = APIError{Code: ErrCodeConflict, Message: "conflict"} + ErrInternal = APIError{Code: ErrCodeInternal, Message: "internal error"} + ErrTooMany = APIError{Code: ErrCodeTooMany, Message: "too many requests"} +) + +func NewAPIError(code ErrorCode, message string) APIError { + return APIError{Code: code, Message: message} +} diff --git a/srv/errors_test.go b/srv/errors_test.go new file mode 100644 index 0000000..b195861 --- /dev/null +++ b/srv/errors_test.go @@ -0,0 +1,103 @@ +package srv + +import ( + "net/http" + "testing" +) + +func TestAPIError(t *testing.T) { + tests := []struct { + name string + err APIError + wantMsg string + wantCode ErrorCode + wantHTTP int + }{ + { + name: "bad request", + err: ErrBadRequest, + wantMsg: "bad request", + wantCode: ErrCodeBadRequest, + wantHTTP: http.StatusBadRequest, + }, + { + name: "unauthorized", + err: ErrUnauthorized, + wantMsg: "unauthorized", + wantCode: ErrCodeUnauthorized, + wantHTTP: http.StatusUnauthorized, + }, + { + name: "forbidden", + err: ErrForbidden, + wantMsg: "forbidden", + wantCode: ErrCodeForbidden, + wantHTTP: http.StatusForbidden, + }, + { + name: "not found", + err: ErrNotFound, + wantMsg: "not found", + wantCode: ErrCodeNotFound, + wantHTTP: http.StatusNotFound, + }, + { + name: "conflict", + err: ErrConflict, + wantMsg: "conflict", + wantCode: ErrCodeConflict, + wantHTTP: http.StatusConflict, + }, + { + name: "internal", + err: ErrInternal, + wantMsg: "internal error", + wantCode: ErrCodeInternal, + wantHTTP: http.StatusInternalServerError, + }, + { + name: "too many", + err: ErrTooMany, + wantMsg: "too many requests", + wantCode: ErrCodeTooMany, + wantHTTP: http.StatusTooManyRequests, + }, + { + name: "custom error", + err: NewAPIError(ErrCodeNotFound, "custom message"), + wantMsg: "custom message", + wantCode: ErrCodeNotFound, + wantHTTP: http.StatusNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err.Message != tt.wantMsg { + t.Errorf("Message = %q, want %q", tt.err.Message, tt.wantMsg) + } + if tt.err.Code != tt.wantCode { + t.Errorf("Code = %q, want %q", tt.err.Code, tt.wantCode) + } + if tt.err.StatusCode() != tt.wantHTTP { + t.Errorf("StatusCode() = %d, want %d", tt.err.StatusCode(), tt.wantHTTP) + } + }) + } +} + +func TestAPIError_Error(t *testing.T) { + err := NewAPIError(ErrCodeBadRequest, "test message") + want := "bad_request: test message" + if got := err.Error(); got != want { + t.Errorf("Error() = %q, want %q", got, want) + } +} + +func TestAPIError_Unwrap(t *testing.T) { + err := NewAPIError(ErrCodeInternal, "internal error") + got := err.Unwrap() + if got.Error() != "internal error" { + t.Errorf("Unwrap() = %v, want 'internal error'", got) + } +} diff --git a/srv/git.go b/srv/git.go index f81f8e9..1911b3b 100644 --- a/srv/git.go +++ b/srv/git.go @@ -219,7 +219,22 @@ func (g *GitService) Push(ctx context.Context, fingerprint, token string) (*Pull } slog.Info("[PUSH] Exported entries", "count", count) - // Step 6: Stage all files + // Step 6: Write .gpg-id from users table (fallback to fingerprint) + user, err := g.q.GetUser(ctx, fingerprint) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + gpgID := fingerprint + if user.GpgID != nil && *user.GpgID != "" { + gpgID = *user.GpgID + } + gpgIDPath := filepath.Join(repoDir, ".gpg-id") + if err := os.WriteFile(gpgIDPath, []byte(gpgID), 0600); err != nil { + return nil, fmt.Errorf("write .gpg-id: %w", err) + } + slog.Info("[PUSH] Wrote .gpg-id", "content", gpgID) + + // Step 7: Stage all files w, err := repo.Worktree() if err != nil { return nil, fmt.Errorf("get worktree: %w", err) @@ -229,7 +244,7 @@ func (g *GitService) Push(ctx context.Context, fingerprint, token string) (*Pull } slog.Info("[PUSH] Staged all files") - // Step 7: Commit + // Step 8: Commit commitMsg := fmt.Sprintf("Sync: %s", time.Now().Format(time.RFC3339)) _, err = w.Commit(commitMsg, &git.CommitOptions{ Author: &object.Signature{ @@ -243,7 +258,7 @@ func (g *GitService) Push(ctx context.Context, fingerprint, token string) (*Pull } slog.Info("[PUSH] Committed", "message", commitMsg) - // Step 8: Get remote and force push + // Step 9: Get remote and force push remote, err := repo.Remote("origin") if err != nil { return nil, fmt.Errorf("get remote: %w", err) @@ -276,7 +291,7 @@ func (g *GitService) Push(ctx context.Context, fingerprint, token string) (*Pull slog.Info("[PUSH] Pushed --force", "branch", branchName) } - // Step 9: Cleanup after + // Step 10: Cleanup after if err := g.cleanupRepoDir(fingerprint); err != nil { return nil, err } @@ -341,14 +356,28 @@ func (g *GitService) Pull(ctx context.Context, fingerprint, token string) (*Pull } slog.Info("[PULL] Cloned remote", "url", config.RepoUrl) - // Step 3: Delete all DB entries and import from clone + // Step 3: Preserve .gpg-id if present (update users table) + gpgIDPath := filepath.Join(repoDir, ".gpg-id") + if gpgIDData, err := os.ReadFile(gpgIDPath); err == nil { + gpgIDStr := string(gpgIDData) + if err := g.q.UpdateUserGpgID(ctx, dbgen.UpdateUserGpgIDParams{ + GpgID: &gpgIDStr, + Fingerprint: fingerprint, + }); err != nil { + slog.Warn("[PULL] Failed to store .gpg-id", "error", err) + } else { + slog.Info("[PULL] Stored .gpg-id", "content", gpgIDStr) + } + } + + // Step 4: Delete all DB entries and import from clone count, err := g.syncDatabase(ctx, fingerprint, repoDir) if err != nil { return nil, fmt.Errorf("sync database: %w", err) } slog.Info("[PULL] Imported entries", "count", count) - // Step 4: Cleanup after + // Step 5: Cleanup after if err := g.cleanupRepoDir(fingerprint); err != nil { return nil, err } diff --git a/srv/git_test.go b/srv/git_test.go index e01c69e..5d5ba24 100644 --- a/srv/git_test.go +++ b/srv/git_test.go @@ -38,6 +38,7 @@ func TestGitServiceConfigure(t *testing.T) { Fingerprint: fingerprint, PasswordHash: "hash", PublicKey: "pk", + GpgID: &fingerprint, }); err != nil { t.Fatalf("failed to create user: %v", err) } @@ -75,6 +76,7 @@ func TestGitServiceConfigureUpdate(t *testing.T) { Fingerprint: fingerprint, PasswordHash: "hash", PublicKey: "pk", + GpgID: &fingerprint, }); err != nil { t.Fatalf("failed to create user: %v", err) } @@ -152,6 +154,7 @@ func TestGitServiceGetStatus(t *testing.T) { Fingerprint: fingerprint, PasswordHash: "hash", PublicKey: "pk", + GpgID: &fingerprint, }); err != nil { t.Fatalf("failed to create user: %v", err) } @@ -226,6 +229,7 @@ func TestGitServiceLogGitSync(t *testing.T) { Fingerprint: fingerprint, PasswordHash: "hash", PublicKey: "pk", + GpgID: &fingerprint, }); err != nil { t.Fatalf("failed to create user: %v", err) } @@ -262,3 +266,47 @@ func TestGitServiceLogGitSync(t *testing.T) { t.Errorf("expected 5 entries changed, got %v", logs[0].EntriesChanged) } } + +func TestGitServiceUpdateUserGpgID(t *testing.T) { + s := newTestServer(t) + ctx := context.Background() + + fingerprint := "test-fp-8" + gpgID := "0xDEADBEEF" + + // Create user first (gpg_id defaults to fingerprint via CreateUserParams) + if err := s.Q.CreateUser(ctx, dbgen.CreateUserParams{ + Fingerprint: fingerprint, + PasswordHash: "hash", + PublicKey: "pk", + GpgID: &fingerprint, + }); err != nil { + t.Fatalf("failed to create user: %v", err) + } + + // Verify initial gpg_id is the fingerprint + user, err := s.Q.GetUser(ctx, fingerprint) + if err != nil { + t.Fatalf("failed to get user: %v", err) + } + if user.GpgID == nil || *user.GpgID != fingerprint { + t.Errorf("expected initial gpg_id %s, got %v", fingerprint, user.GpgID) + } + + // Update gpg_id via UpdateUserGpgID + if err := s.Q.UpdateUserGpgID(ctx, dbgen.UpdateUserGpgIDParams{ + GpgID: &gpgID, + Fingerprint: fingerprint, + }); err != nil { + t.Fatalf("failed to update gpg_id: %v", err) + } + + // Verify gpg_id was updated in users table + user, err = s.Q.GetUser(ctx, fingerprint) + if err != nil { + t.Fatalf("failed to get user: %v", err) + } + if user.GpgID == nil || *user.GpgID != gpgID { + t.Errorf("expected gpg_id %s, got %v", gpgID, user.GpgID) + } +} diff --git a/srv/server.go b/srv/server.go index 4fdd00b..b0d7b6d 100644 --- a/srv/server.go +++ b/srv/server.go @@ -3,6 +3,7 @@ package srv import ( "archive/tar" "compress/gzip" + "context" "crypto/rand" "crypto/sha256" "crypto/subtle" @@ -31,26 +32,35 @@ import ( // Server is the WebPass API server. type Server struct { - DB *sql.DB - Q *dbgen.Queries - JWTKey []byte - StaticDir string // path to frontend dist/ directory (optional) - GitService *GitService - Registration *RegistrationService - RateLimiter *RateLimiter // rate limiter for auth endpoints - sessionDuration time.Duration - cookieAuth bool // whether to use httpOnly cookies instead of localStorage - cookieSecure bool // whether to set Secure flag on cookies - cookieDomain string // optional cookie domain - bcryptCost int // bcrypt cost factor for password hashing + DB *sql.DB + Q *dbgen.Queries + JWTKey []byte + StaticDir string // path to frontend dist/ directory (optional) + GitService *GitService + Registration *RegistrationService + RateLimiter *RateLimiter // rate limiter for auth endpoints + hardLimit time.Duration // hard limit (max session time) + softLimit time.Duration // soft limit (browser closed detection) + cookieAuth bool // whether to use httpOnly cookies instead of localStorage + cookieSecure bool // whether to set Secure flag on cookies + cookieDomain string // optional cookie domain + bcryptCost int // bcrypt cost factor for password hashing // Version info (set from main package) Version string BuildTime string Commit string } +// CloseDB closes the database connection for graceful shutdown. +func (s *Server) CloseDB() error { + if s.DB != nil { + return s.DB.Close() + } + return nil +} + // New creates a new Server, opening the database and running migrations. -func New(dbPath string, jwtKey []byte, sessionDurationMin int) (*Server, error) { +func New(dbPath string, jwtKey []byte, hardLimitMin int, softLimitMin int) (*Server, error) { wdb, err := db.Open(dbPath) if err != nil { return nil, fmt.Errorf("open db: %w", err) @@ -90,7 +100,8 @@ func New(dbPath string, jwtKey []byte, sessionDurationMin int) (*Server, error) cookieDomain: cookieDomain, bcryptCost: bcryptCost, } - s.sessionDuration = time.Duration(sessionDurationMin) * time.Minute + s.hardLimit = time.Duration(hardLimitMin) * time.Minute + s.softLimit = time.Duration(softLimitMin) * time.Minute // Initialize Git service repoRoot := os.Getenv("GIT_REPO_ROOT") @@ -281,7 +292,7 @@ func (s *Server) csrfMiddleware(next http.Handler) http.Handler { Secure: s.cookieSecure, SameSite: http.SameSiteStrictMode, Domain: s.cookieDomain, - MaxAge: int(s.sessionDuration.Seconds()), + MaxAge: int(s.hardLimit.Seconds()), } http.SetCookie(w, cookie) } @@ -370,7 +381,7 @@ func parseCORSOrigins(raw string) map[string]bool { func (s *Server) createToken(fingerprint string) (string, error) { claims := jwt.MapClaims{ "fp": fingerprint, - "exp": time.Now().Add(s.sessionDuration).Unix(), + "exp": time.Now().Add(s.hardLimit).Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString(s.JWTKey) @@ -390,6 +401,19 @@ func (s *Server) requireAuth(next http.HandlerFunc) http.HandlerFunc { jsonError(w, "forbidden", http.StatusForbidden) return } + + // Check session limits (hard limit and soft limit) + if err := s.checkSessionLimits(r.Context(), fp); err != nil { + slog.Warn("session limit check failed", "fingerprint", fp, "error", err) + jsonError(w, err.Error(), http.StatusUnauthorized) + return + } + + // Update last activity timestamp + if err := s.Q.UpdateLastActivity(r.Context(), fp); err != nil { + slog.Error("update last activity", "error", err) + } + next(w, r) } } @@ -492,6 +516,41 @@ func (s *Server) verifyToken(r *http.Request) (string, error) { return fp, nil } +// checkSessionLimits verifies the session hasn't exceeded hard or soft limits. +// Hard limit: sessionDuration (e.g., 30 min) from login_time +// Soft limit: softLimit (5 min) from last_activity (browser close detection) +func (s *Server) checkSessionLimits(ctx context.Context, fp string) error { + sessionInfo, err := s.Q.GetSessionInfo(ctx, fp) + if err != nil { + // If user doesn't exist, don't block - let the handler decide + return nil + } + + now := time.Now() + + // If no login_time, this is a new session - allow access (will be set on login) + if sessionInfo.LoginTime == nil { + return nil + } + + // Check hard limit: session must not exceed sessionDuration from login_time + hardExpiry := sessionInfo.LoginTime.Add(s.hardLimit) + if now.After(hardExpiry) { + return fmt.Errorf("session expired (hard limit)") + } + + // Check soft limit: must not be away for more than softLimit from last_activity + // Only check if last_activity exists and is not too old + if sessionInfo.LastActivity != nil { + softExpiry := sessionInfo.LastActivity.Add(s.softLimit) + if now.After(softExpiry) { + return fmt.Errorf("session expired (please login again)") + } + } + + return nil +} + // setAuthCookie sets the httpOnly authentication cookie func (s *Server) setAuthCookie(w http.ResponseWriter, token string) { if !s.cookieAuth { @@ -506,7 +565,7 @@ func (s *Server) setAuthCookie(w http.ResponseWriter, token string) { Secure: s.cookieSecure, SameSite: http.SameSiteStrictMode, Domain: s.cookieDomain, - MaxAge: int(s.sessionDuration.Seconds()), + MaxAge: int(s.hardLimit.Seconds()), } http.SetCookie(w, cookie) } @@ -613,6 +672,7 @@ func (s *Server) handleCreateUser(w http.ResponseWriter, r *http.Request) { Fingerprint: fp, PasswordHash: string(hash), PublicKey: body.PublicKey, + GpgID: &fp, }); err != nil { if strings.Contains(err.Error(), "UNIQUE constraint") { jsonError(w, "user already exists", http.StatusConflict) @@ -742,6 +802,11 @@ func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { return } + // Update login_time on successful login + if err := s.Q.UpdateLoginTime(r.Context(), fp); err != nil { + slog.Error("update login time", "error", err) + } + token, err := s.createToken(fp) if err != nil { slog.Error("create token", "error", err) @@ -787,6 +852,11 @@ func (s *Server) handleLogin2FA(w http.ResponseWriter, r *http.Request) { return } + // Update login_time on successful 2FA login + if err := s.Q.UpdateLoginTime(r.Context(), fp); err != nil { + slog.Error("update login time", "error", err) + } + token, err := s.createToken(fp) if err != nil { slog.Error("create token", "error", err) diff --git a/srv/server_test.go b/srv/server_test.go index 9420544..a41007d 100644 --- a/srv/server_test.go +++ b/srv/server_test.go @@ -287,7 +287,7 @@ func newTestServer(t *testing.T) *Server { t.Setenv("REGISTRATION_TOTP_SECRET", "") // Clear TOTP secret for open registration key := []byte("test-secret-key-32-bytes-long!!!") // exactly 32 bytes - srv, err := New(dbPath, key, 5) // 5 minutes for tests + srv, err := New(dbPath, key, 5, 5) // 5 minutes hard/soft limit for tests if err != nil { t.Fatalf("new server: %v", err) }