From 5c1c38d7eee12d8e8ff81e8c9390e556e219e5ea Mon Sep 17 00:00:00 2001 From: David-patrick-chuks Date: Sun, 29 Mar 2026 14:04:49 +0100 Subject: [PATCH] feat(middleware): add redis-backed api rate limiting --- backend/src/app.module.ts | 13 +- .../middleware/rate-limit.middleware.ts | 300 ++++++++++++++++ middleware/README.md | 22 ++ middleware/jest.unit.config.js | 20 ++ middleware/package.json | 4 +- middleware/src/security/index.ts | 7 +- middleware/src/security/rate-limit.config.ts | 133 +++++++ .../src/security/rate-limit.middleware.ts | 326 ++++++++++++++++++ .../tests/unit/rate-limit.middleware.spec.ts | 126 +++++++ middleware/tests/utils/mock-express.ts | 6 +- package-lock.json | 47 ++- 11 files changed, 994 insertions(+), 10 deletions(-) create mode 100644 backend/src/common/middleware/rate-limit.middleware.ts create mode 100644 middleware/jest.unit.config.js create mode 100644 middleware/src/security/rate-limit.config.ts create mode 100644 middleware/src/security/rate-limit.middleware.ts create mode 100644 middleware/tests/unit/rate-limit.middleware.spec.ts diff --git a/backend/src/app.module.ts b/backend/src/app.module.ts index 5da1b312..bf24349b 100644 --- a/backend/src/app.module.ts +++ b/backend/src/app.module.ts @@ -22,6 +22,7 @@ import jwtConfig from './auth/authConfig/jwt.config'; import { UsersService } from './users/providers/users.service'; import { GeolocationMiddleware } from './common/middleware/geolocation.middleware'; import { HealthModule } from './health/health.module'; +import { RateLimitMiddleware } from './common/middleware/rate-limit.middleware'; // const ENV = process.env.NODE_ENV; // console.log('NODE_ENV:', process.env.NODE_ENV); @@ -104,7 +105,7 @@ import { HealthModule } from './health/health.module'; HealthModule, ], controllers: [AppController], - providers: [AppService], + providers: [AppService, RateLimitMiddleware], }) export class AppModule implements NestModule { /** @@ -124,5 +125,15 @@ export class AppModule implements NestModule { { path: 'health', method: RequestMethod.GET }, ) .forRoutes('*'); + + consumer + .apply(RateLimitMiddleware) + .exclude( + { path: 'health/(.*)', method: RequestMethod.ALL }, + { path: 'health', method: RequestMethod.ALL }, + { path: 'api', method: RequestMethod.ALL }, + { path: 'docs', method: RequestMethod.ALL }, + ) + .forRoutes('*'); } } diff --git a/backend/src/common/middleware/rate-limit.middleware.ts b/backend/src/common/middleware/rate-limit.middleware.ts new file mode 100644 index 00000000..95664e46 --- /dev/null +++ b/backend/src/common/middleware/rate-limit.middleware.ts @@ -0,0 +1,300 @@ +import { Inject, Injectable, Logger, NestMiddleware } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import { NextFunction, Request, Response } from 'express'; +import Redis from 'ioredis'; +import { REDIS_CLIENT } from '../../redis/redis.constants'; + +interface RateLimitTier { + name: string; + limit: number; + windowMs: number; + burstAllowance: number; + methods?: string[]; + match: (req: Request) => boolean; +} + +interface RateLimitDecision { + allowed: boolean; + remaining: number; + retryAfterMs: number; + resetAtMs: number; + nowMs: number; +} + +const TOKEN_BUCKET_LUA = ` +local key = KEYS[1] +local capacity = tonumber(ARGV[1]) +local refill_per_ms = tonumber(ARGV[2]) +local requested = tonumber(ARGV[3]) +local ttl_ms = tonumber(ARGV[4]) + +local time = redis.call('TIME') +local now_ms = tonumber(time[1]) * 1000 + math.floor(tonumber(time[2]) / 1000) +local values = redis.call('HMGET', key, 'tokens', 'last') + +local tokens = tonumber(values[1]) +local last = tonumber(values[2]) + +if not tokens or not last then + tokens = capacity + last = now_ms +end + +if now_ms > last then + local replenished = (now_ms - last) * refill_per_ms + tokens = math.min(capacity, tokens + replenished) + last = now_ms +end + +local allowed = 0 +local retry_after_ms = 0 + +if tokens >= requested then + allowed = 1 + tokens = tokens - requested +else + retry_after_ms = math.ceil((requested - tokens) / refill_per_ms) +end + +redis.call('HMSET', key, 'tokens', tokens, 'last', last) +redis.call('PEXPIRE', key, ttl_ms) + +local remaining = math.floor(tokens) +local reset_at_ms = now_ms + math.ceil((capacity - tokens) / refill_per_ms) + +return { allowed, remaining, retry_after_ms, reset_at_ms, now_ms } +`; + +@Injectable() +export class RateLimitMiddleware implements NestMiddleware { + private readonly logger = new Logger(RateLimitMiddleware.name); + private readonly whitelistIps: string[]; + private readonly tiers: RateLimitTier[]; + + constructor( + @Inject(REDIS_CLIENT) private readonly redisClient: Redis, + private readonly configService: ConfigService, + ) { + this.whitelistIps = this.parseCsv( + this.configService.get('RATE_LIMIT_WHITELIST_IPS'), + ); + this.tiers = this.createTiers(); + } + + async use(req: Request, res: Response, next: NextFunction): Promise { + const tier = this.tiers.find((candidate) => candidate.match(req)); + if (!tier) { + return next(); + } + + const ip = this.getClientIp(req); + if (this.whitelistIps.includes(ip)) { + return next(); + } + + const userId = this.getUserId(req); + const identity = userId ? `user:${userId}` : `ip:${ip}`; + const key = `ratelimit:${tier.name}:${identity}`; + + try { + const decision = await this.consumeRateLimit(key, tier); + + res.setHeader('X-RateLimit-Limit', String(tier.limit)); + res.setHeader( + 'X-RateLimit-Remaining', + String(Math.max(0, Math.min(tier.limit, decision.remaining))), + ); + res.setHeader( + 'X-RateLimit-Reset', + String(Math.ceil(decision.resetAtMs / 1000)), + ); + + if (decision.allowed) { + return next(); + } + + const retryAfterSeconds = Math.max( + 1, + Math.ceil(decision.retryAfterMs / 1000), + ); + + res.setHeader('Retry-After', String(retryAfterSeconds)); + + this.logger.warn( + `Rate limit violation tier=${tier.name} identity=${identity} method=${req.method} path=${req.path} retry_after_ms=${decision.retryAfterMs}`, + ); + + res.status(429).json({ + statusCode: 429, + errorCode: 'RATE_LIMIT_EXCEEDED', + message: `Rate limit exceeded for ${tier.name}. Please retry later.`, + correlationId: (req as Request & { correlationId?: string }).correlationId, + timestamp: new Date(decision.nowMs).toISOString(), + path: req.originalUrl || req.url, + }); + return; + } catch (error) { + this.logger.error( + `Rate limit store failure on ${req.method} ${req.path}: ${error instanceof Error ? error.message : String(error)}`, + ); + next(); + } + } + + private async consumeRateLimit( + key: string, + tier: RateLimitTier, + ): Promise { + const capacity = tier.limit + tier.burstAllowance; + const refillPerMs = tier.limit / tier.windowMs; + const ttlMs = Math.max(tier.windowMs * 2, 60_000); + + const raw = (await this.redisClient.eval( + TOKEN_BUCKET_LUA, + 1, + key, + capacity, + refillPerMs, + 1, + ttlMs, + )) as [number, number, number, number, number]; + + const [allowed, remaining, retryAfterMs, resetAtMs, nowMs] = raw.map( + Number, + ) as [number, number, number, number, number]; + + return { + allowed: allowed === 1, + remaining, + retryAfterMs, + resetAtMs, + nowMs, + }; + } + + private createTiers(): RateLimitTier[] { + return [ + { + name: 'auth', + limit: this.getNumber('RATE_LIMIT_AUTH_LIMIT', 5), + windowMs: this.getNumber('RATE_LIMIT_AUTH_WINDOW_MS', 15 * 60 * 1000), + burstAllowance: this.getNumber('RATE_LIMIT_AUTH_BURST', 0), + methods: ['POST'], + match: (req) => + req.method === 'POST' && + this.matchesAny(req.path, [ + '/auth/signIn', + '/auth/refreshToken', + '/auth/stellar-wallet-login', + '/auth/google-authentication', + '/auth/forgot-password', + '/auth/reset-password/', + ]), + }, + { + name: 'puzzle-submit', + limit: this.getNumber('RATE_LIMIT_PUZZLE_SUBMIT_LIMIT', 30), + windowMs: this.getNumber( + 'RATE_LIMIT_PUZZLE_SUBMIT_WINDOW_MS', + 60 * 60 * 1000, + ), + burstAllowance: this.getNumber('RATE_LIMIT_PUZZLE_SUBMIT_BURST', 0), + methods: ['POST'], + match: (req) => req.method === 'POST' && req.path === '/progress/submit', + }, + { + name: 'daily-quest-generate', + limit: this.getNumber('RATE_LIMIT_DAILY_QUEST_LIMIT', 2), + windowMs: this.getNumber( + 'RATE_LIMIT_DAILY_QUEST_WINDOW_MS', + 24 * 60 * 60 * 1000, + ), + burstAllowance: this.getNumber('RATE_LIMIT_DAILY_QUEST_BURST', 0), + methods: ['GET'], + match: (req) => req.method === 'GET' && req.path === '/daily-quest', + }, + { + name: 'admin', + limit: this.getNumber('RATE_LIMIT_ADMIN_LIMIT', 1000), + windowMs: this.getNumber('RATE_LIMIT_ADMIN_WINDOW_MS', 60 * 60 * 1000), + burstAllowance: this.getNumber('RATE_LIMIT_ADMIN_BURST', 0), + match: (req) => req.path.startsWith('/admin/'), + }, + { + name: 'public-landing', + limit: this.getNumber('RATE_LIMIT_PUBLIC_LIMIT', 1000), + windowMs: this.getNumber( + 'RATE_LIMIT_PUBLIC_WINDOW_MS', + 60 * 60 * 1000, + ), + burstAllowance: this.getNumber('RATE_LIMIT_PUBLIC_BURST', 0), + methods: ['GET'], + match: (req) => req.method === 'GET' && req.path === '/', + }, + { + name: 'read-only', + limit: this.getNumber('RATE_LIMIT_READ_LIMIT', 300), + windowMs: this.getNumber('RATE_LIMIT_READ_WINDOW_MS', 60 * 60 * 1000), + burstAllowance: this.getNumber('RATE_LIMIT_READ_BURST', 0), + match: (req) => + ['GET', 'HEAD'].includes(req.method) && + req.path !== '/' && + req.path !== '/daily-quest' && + !req.path.startsWith('/admin/') && + !req.path.startsWith('/health') && + !req.path.startsWith('/api') && + !req.path.startsWith('/docs'), + }, + ]; + } + + private matchesAny(path: string, prefixes: string[]): boolean { + return prefixes.some( + (prefix) => path === prefix || path.startsWith(prefix), + ); + } + + private getUserId(req: Request): string | undefined { + const user = (req as Request & { + user?: { userId?: string | number; sub?: string | number; id?: string | number }; + }).user; + + const candidate = user?.userId ?? user?.sub ?? user?.id; + return candidate !== undefined ? String(candidate) : undefined; + } + + private getClientIp(req: Request): string { + const forwarded = req.headers['x-forwarded-for']; + if (typeof forwarded === 'string' && forwarded.length > 0) { + return forwarded.split(',')[0].trim(); + } + + if (Array.isArray(forwarded) && forwarded.length > 0) { + return forwarded[0].split(',')[0].trim(); + } + + return ( + req.ip || + req.socket?.remoteAddress || + (req as Request & { connection?: { remoteAddress?: string } }).connection + ?.remoteAddress || + 'unknown' + ); + } + + private getNumber(key: string, fallback: number): number { + const value = Number(this.configService.get(key)); + return Number.isFinite(value) && value > 0 ? value : fallback; + } + + private parseCsv(value?: string): string[] { + if (!value) { + return []; + } + + return value + .split(',') + .map((entry) => entry.trim()) + .filter(Boolean); + } +} diff --git a/middleware/README.md b/middleware/README.md index 39c04a88..0c0b254e 100644 --- a/middleware/README.md +++ b/middleware/README.md @@ -73,6 +73,28 @@ Common environment variables (expected across middleware in the future) may incl - `JWT_SECRET` - `JWT_EXPIRES_IN` - `BCRYPT_SALT_ROUNDS` +- `RATE_LIMIT_AUTH_LIMIT` +- `RATE_LIMIT_AUTH_WINDOW_MS` +- `RATE_LIMIT_PUZZLE_SUBMIT_LIMIT` +- `RATE_LIMIT_DAILY_QUEST_LIMIT` +- `RATE_LIMIT_READ_LIMIT` +- `RATE_LIMIT_ADMIN_LIMIT` +- `RATE_LIMIT_PUBLIC_LIMIT` +- `RATE_LIMIT_WHITELIST_IPS` + +## Rate Limiting + +The `security` package now includes a Redis-backed `RateLimitMiddleware` with: + +- Per-tier limits for authentication, puzzle submission, daily quest generation, admin routes, read-only routes, and the public landing page +- User ID tracking for authenticated requests, with IP fallback for anonymous traffic +- Whitelisted IP exemptions +- `429 Too Many Requests` responses with `Retry-After` +- `X-RateLimit-Limit`, `X-RateLimit-Remaining`, and `X-RateLimit-Reset` headers +- Configurable burst allowance per tier +- Fail-open behavior when Redis is temporarily unavailable + +Use `createDefaultRateLimitConfig()` from `src/security/rate-limit.config.ts` to build tier settings from environment variables. ## Testing diff --git a/middleware/jest.unit.config.js b/middleware/jest.unit.config.js new file mode 100644 index 00000000..1a09ad08 --- /dev/null +++ b/middleware/jest.unit.config.js @@ -0,0 +1,20 @@ +/** @type {import('jest').Config} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + testMatch: ['**/tests/unit/**/*.test.ts', '**/tests/unit/**/*.spec.ts'], + collectCoverageFrom: ['src/**/*.ts'], + coverageDirectory: 'coverage/unit', + coverageThreshold: { + global: { + branches: 80, + functions: 80, + lines: 80, + statements: 80, + }, + }, + setupFilesAfterEnv: ['/tests/setup.ts'], + moduleNameMapper: { + '^@mindblock/middleware/(.*)$': '/src/$1', + }, +}; diff --git a/middleware/package.json b/middleware/package.json index 3c989583..b56d8373 100644 --- a/middleware/package.json +++ b/middleware/package.json @@ -8,12 +8,12 @@ "scripts": { "build": "tsc -p tsconfig.json", "test": "npm run test:unit && npm run test:integration && npm run test:e2e", - "test:unit": "jest --config jest.unit.config.ts", + "test:unit": "jest --config jest.unit.config.js", "test:integration": "jest --config jest.integration.config.ts", "test:e2e": "jest --config jest.e2e.config.ts", "test:watch": "jest --watch --passWithNoTests", "test:cov": "jest --coverage --passWithNoTests", - "test:unit:cov": "jest --config jest.unit.config.ts --coverage", + "test:unit:cov": "jest --config jest.unit.config.js --coverage", "test:integration:cov": "jest --config jest.integration.config.ts --coverage", "test:e2e:cov": "jest --config jest.e2e.config.ts --coverage", "lint": "eslint -c eslint.config.mjs \"src/**/*.ts\" \"tests/**/*.ts\"", diff --git a/middleware/src/security/index.ts b/middleware/src/security/index.ts index f3e26a5f..f6683fa1 100644 --- a/middleware/src/security/index.ts +++ b/middleware/src/security/index.ts @@ -1,3 +1,4 @@ -// Placeholder: security middleware exports will live here. - -export const __securityPlaceholder = true; +export * from './security-headers.config'; +export * from './security-headers.middleware'; +export * from './rate-limit.config'; +export * from './rate-limit.middleware'; diff --git a/middleware/src/security/rate-limit.config.ts b/middleware/src/security/rate-limit.config.ts new file mode 100644 index 00000000..0825e625 --- /dev/null +++ b/middleware/src/security/rate-limit.config.ts @@ -0,0 +1,133 @@ +import { Request } from 'express'; +import { + createRateLimitTier, + RateLimitTier, +} from './rate-limit.middleware'; + +export interface RateLimitEnvironment { + RATE_LIMIT_AUTH_LIMIT?: string; + RATE_LIMIT_AUTH_WINDOW_MS?: string; + RATE_LIMIT_AUTH_BURST?: string; + RATE_LIMIT_PUZZLE_SUBMIT_LIMIT?: string; + RATE_LIMIT_PUZZLE_SUBMIT_WINDOW_MS?: string; + RATE_LIMIT_PUZZLE_SUBMIT_BURST?: string; + RATE_LIMIT_DAILY_QUEST_LIMIT?: string; + RATE_LIMIT_DAILY_QUEST_WINDOW_MS?: string; + RATE_LIMIT_DAILY_QUEST_BURST?: string; + RATE_LIMIT_READ_LIMIT?: string; + RATE_LIMIT_READ_WINDOW_MS?: string; + RATE_LIMIT_READ_BURST?: string; + RATE_LIMIT_ADMIN_LIMIT?: string; + RATE_LIMIT_ADMIN_WINDOW_MS?: string; + RATE_LIMIT_ADMIN_BURST?: string; + RATE_LIMIT_PUBLIC_LIMIT?: string; + RATE_LIMIT_PUBLIC_WINDOW_MS?: string; + RATE_LIMIT_PUBLIC_BURST?: string; + RATE_LIMIT_WHITELIST_IPS?: string; +} + +export interface RateLimitResolvedConfig { + tiers: RateLimitTier[]; + whitelistIps: string[]; +} + +function parseNumber(value: string | undefined, fallback: number): number { + const parsed = Number(value); + return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback; +} + +function parseCsv(value: string | undefined): string[] { + if (!value) { + return []; + } + + return value + .split(',') + .map((entry) => entry.trim()) + .filter(Boolean); +} + +function isReadOnlyRequest(req: Request): boolean { + return ['GET', 'HEAD'].includes(req.method.toUpperCase()); +} + +export function createDefaultRateLimitConfig( + env: RateLimitEnvironment = process.env, +): RateLimitResolvedConfig { + const tiers: RateLimitTier[] = [ + createRateLimitTier( + 'auth', + parseNumber(env.RATE_LIMIT_AUTH_LIMIT, 5), + parseNumber(env.RATE_LIMIT_AUTH_WINDOW_MS, 15 * 60 * 1000), + { + burstAllowance: parseNumber(env.RATE_LIMIT_AUTH_BURST, 0), + methods: ['POST'], + pathPatterns: [ + '/auth/signIn', + '/auth/stellar-wallet-login', + '/auth/google-authentication', + '/auth/forgot-password', + '/auth/reset-password/*', + '/auth/refreshToken', + ], + }, + ), + createRateLimitTier( + 'puzzle-submit', + parseNumber(env.RATE_LIMIT_PUZZLE_SUBMIT_LIMIT, 30), + parseNumber(env.RATE_LIMIT_PUZZLE_SUBMIT_WINDOW_MS, 60 * 60 * 1000), + { + burstAllowance: parseNumber(env.RATE_LIMIT_PUZZLE_SUBMIT_BURST, 0), + methods: ['POST'], + pathPatterns: ['/progress/submit'], + }, + ), + createRateLimitTier( + 'daily-quest-generate', + parseNumber(env.RATE_LIMIT_DAILY_QUEST_LIMIT, 2), + parseNumber( + env.RATE_LIMIT_DAILY_QUEST_WINDOW_MS, + 24 * 60 * 60 * 1000, + ), + { + burstAllowance: parseNumber(env.RATE_LIMIT_DAILY_QUEST_BURST, 0), + methods: ['GET'], + pathPatterns: ['/daily-quest'], + }, + ), + createRateLimitTier( + 'admin', + parseNumber(env.RATE_LIMIT_ADMIN_LIMIT, 1000), + parseNumber(env.RATE_LIMIT_ADMIN_WINDOW_MS, 60 * 60 * 1000), + { + burstAllowance: parseNumber(env.RATE_LIMIT_ADMIN_BURST, 0), + pathPatterns: ['/admin/**'], + }, + ), + createRateLimitTier( + 'public-landing', + parseNumber(env.RATE_LIMIT_PUBLIC_LIMIT, 1000), + parseNumber(env.RATE_LIMIT_PUBLIC_WINDOW_MS, 60 * 60 * 1000), + { + burstAllowance: parseNumber(env.RATE_LIMIT_PUBLIC_BURST, 0), + methods: ['GET'], + pathPatterns: ['/'], + }, + ), + createRateLimitTier( + 'read-only', + parseNumber(env.RATE_LIMIT_READ_LIMIT, 300), + parseNumber(env.RATE_LIMIT_READ_WINDOW_MS, 60 * 60 * 1000), + { + burstAllowance: parseNumber(env.RATE_LIMIT_READ_BURST, 0), + match: isReadOnlyRequest, + pathPatterns: ['/**'], + }, + ), + ]; + + return { + tiers, + whitelistIps: parseCsv(env.RATE_LIMIT_WHITELIST_IPS), + }; +} diff --git a/middleware/src/security/rate-limit.middleware.ts b/middleware/src/security/rate-limit.middleware.ts new file mode 100644 index 00000000..4f6ab343 --- /dev/null +++ b/middleware/src/security/rate-limit.middleware.ts @@ -0,0 +1,326 @@ +import { Injectable, Logger, NestMiddleware } from '@nestjs/common'; +import { NextFunction, Request, Response } from 'express'; +import micromatch from 'micromatch'; + +export interface RateLimitStore { + eval( + script: string, + numKeys: number, + ...args: Array + ): Promise; +} + +export interface RateLimitRequestIdentity { + key: string; + userId?: string; + ip: string; +} + +export interface RateLimitTier { + name: string; + limit: number; + windowMs: number; + burstAllowance?: number; + methods?: string[]; + pathPatterns?: Array; + match?: (req: Request) => boolean; + keyPrefix?: string; +} + +export interface RateLimitMiddlewareOptions { + store: RateLimitStore; + tiers: RateLimitTier[]; + whitelistIps?: string[]; + trustProxy?: boolean; + keyGenerator?: (req: Request) => RateLimitRequestIdentity; + onViolation?: (context: RateLimitViolationContext) => void; + onStoreError?: (error: unknown, req: Request, tier: RateLimitTier) => void; + logger?: Pick; +} + +export interface RateLimitViolationContext { + req: Request; + tier: RateLimitTier; + identity: RateLimitRequestIdentity; + retryAfterMs: number; + remaining: number; + resetAtMs: number; +} + +export interface RateLimitDecision { + allowed: boolean; + remaining: number; + retryAfterMs: number; + resetAtMs: number; + nowMs: number; +} + +const TOKEN_BUCKET_LUA = ` +local key = KEYS[1] +local capacity = tonumber(ARGV[1]) +local refill_per_ms = tonumber(ARGV[2]) +local requested = tonumber(ARGV[3]) +local ttl_ms = tonumber(ARGV[4]) + +local time = redis.call('TIME') +local now_ms = tonumber(time[1]) * 1000 + math.floor(tonumber(time[2]) / 1000) +local values = redis.call('HMGET', key, 'tokens', 'last') + +local tokens = tonumber(values[1]) +local last = tonumber(values[2]) + +if not tokens or not last then + tokens = capacity + last = now_ms +end + +if now_ms > last then + local replenished = (now_ms - last) * refill_per_ms + tokens = math.min(capacity, tokens + replenished) + last = now_ms +end + +local allowed = 0 +local retry_after_ms = 0 + +if tokens >= requested then + allowed = 1 + tokens = tokens - requested +else + retry_after_ms = math.ceil((requested - tokens) / refill_per_ms) +end + +redis.call('HMSET', key, 'tokens', tokens, 'last', last) +redis.call('PEXPIRE', key, ttl_ms) + +local remaining = math.floor(tokens) +local reset_at_ms = now_ms + math.ceil((capacity - tokens) / refill_per_ms) + +return { allowed, remaining, retry_after_ms, reset_at_ms, now_ms } +`; + +const DEFAULT_LOGGER = new Logger('RateLimitMiddleware'); + +type RawEvalResult = [number, number, number, number, number]; + +function normalizePath(path: string): string { + return path.startsWith('/') ? path : `/${path}`; +} + +function matchesPattern(path: string, pattern: string | RegExp): boolean { + if (pattern instanceof RegExp) { + return pattern.test(path); + } + + return micromatch.isMatch(path, pattern); +} + +function getClientIp(req: Request): string { + const forwarded = req.headers['x-forwarded-for']; + if (typeof forwarded === 'string' && forwarded.length > 0) { + return forwarded.split(',')[0].trim(); + } + + if (Array.isArray(forwarded) && forwarded.length > 0) { + return forwarded[0].split(',')[0].trim(); + } + + return ( + req.ip || + (req.socket?.remoteAddress ?? '') || + (req.connection?.remoteAddress ?? '') || + 'unknown' + ); +} + +function getRequestUserId(req: Request): string | undefined { + const user = (req as Request & { user?: Record }).user; + if (!user) { + return undefined; + } + + const candidates = [user.userId, user.sub, user.id]; + const firstValue = candidates.find( + (candidate): candidate is string | number => + typeof candidate === 'string' || typeof candidate === 'number', + ); + + return firstValue !== undefined ? String(firstValue) : undefined; +} + +function defaultKeyGenerator(req: Request): RateLimitRequestIdentity { + const userId = getRequestUserId(req); + const ip = getClientIp(req); + + if (userId) { + return { + key: `user:${userId}`, + userId, + ip, + }; + } + + return { + key: `ip:${ip}`, + ip, + }; +} + +function matchTier(req: Request, tiers: RateLimitTier[]): RateLimitTier | undefined { + const method = req.method.toUpperCase(); + const path = normalizePath(req.path || req.url || '/'); + + return tiers.find((tier) => { + const methodsMatch = + !tier.methods || + tier.methods.length === 0 || + tier.methods.some((tierMethod) => tierMethod.toUpperCase() === method); + + if (!methodsMatch) { + return false; + } + + if (tier.match && tier.match(req)) { + return true; + } + + if (!tier.pathPatterns || tier.pathPatterns.length === 0) { + return false; + } + + return tier.pathPatterns.some((pattern) => matchesPattern(path, pattern)); + }); +} + +function setRateLimitHeaders( + res: Response, + tier: RateLimitTier, + remaining: number, + resetAtMs: number, +): void { + res.setHeader('X-RateLimit-Limit', String(tier.limit)); + res.setHeader( + 'X-RateLimit-Remaining', + String(Math.max(0, Math.min(tier.limit, remaining))), + ); + res.setHeader('X-RateLimit-Reset', String(Math.ceil(resetAtMs / 1000))); +} + +function getExceededMessage(tier: RateLimitTier): string { + return `Rate limit exceeded for ${tier.name}. Please retry later.`; +} + +export async function consumeRateLimit( + store: RateLimitStore, + key: string, + tier: RateLimitTier, +): Promise { + const capacity = tier.limit + (tier.burstAllowance ?? 0); + const refillPerMs = tier.limit / tier.windowMs; + const ttlMs = Math.max(tier.windowMs * 2, 60_000); + + const raw = (await store.eval( + TOKEN_BUCKET_LUA, + 1, + key, + capacity, + refillPerMs, + 1, + ttlMs, + )) as RawEvalResult; + + const [allowed, remaining, retryAfterMs, resetAtMs, nowMs] = raw.map(Number) as RawEvalResult; + + return { + allowed: allowed === 1, + remaining, + retryAfterMs, + resetAtMs, + nowMs, + }; +} + +@Injectable() +export class RateLimitMiddleware implements NestMiddleware { + private readonly logger: Pick; + + constructor(private readonly options: RateLimitMiddlewareOptions) { + this.logger = options.logger ?? DEFAULT_LOGGER; + } + + async use(req: Request, res: Response, next: NextFunction): Promise { + const tier = matchTier(req, this.options.tiers); + if (!tier) { + return next(); + } + + const ip = getClientIp(req); + if (this.options.whitelistIps?.includes(ip)) { + return next(); + } + + const identity = (this.options.keyGenerator ?? defaultKeyGenerator)(req); + const storageKey = `${tier.keyPrefix ?? 'ratelimit'}:${tier.name}:${identity.key}`; + + try { + const decision = await consumeRateLimit(this.options.store, storageKey, tier); + + setRateLimitHeaders(res, tier, decision.remaining, decision.resetAtMs); + + if (decision.allowed) { + return next(); + } + + const retryAfterSeconds = Math.max( + 1, + Math.ceil(decision.retryAfterMs / 1000), + ); + + res.setHeader('Retry-After', String(retryAfterSeconds)); + + const violationContext: RateLimitViolationContext = { + req, + tier, + identity, + retryAfterMs: decision.retryAfterMs, + remaining: decision.remaining, + resetAtMs: decision.resetAtMs, + }; + + this.options.onViolation?.(violationContext); + this.logger.warn( + `Rate limit violation tier=${tier.name} key=${identity.key} ip=${identity.ip} path=${req.method} ${req.path} retry_after_ms=${decision.retryAfterMs}`, + ); + + res.status(429).json({ + statusCode: 429, + errorCode: 'RATE_LIMIT_EXCEEDED', + message: getExceededMessage(tier), + timestamp: new Date(decision.nowMs).toISOString(), + path: req.originalUrl || req.url, + }); + return; + } catch (error) { + this.options.onStoreError?.(error, req, tier); + this.logger.error( + `Rate limit store failure on ${req.method} ${req.path}: ${error instanceof Error ? error.message : String(error)}`, + ); + next(); + } + } +} + +export function createRateLimitTier( + name: string, + limit: number, + windowMs: number, + options: Omit = {}, +): RateLimitTier { + return { + name, + limit, + windowMs, + ...options, + }; +} + diff --git a/middleware/tests/unit/rate-limit.middleware.spec.ts b/middleware/tests/unit/rate-limit.middleware.spec.ts new file mode 100644 index 00000000..aa495ad3 --- /dev/null +++ b/middleware/tests/unit/rate-limit.middleware.spec.ts @@ -0,0 +1,126 @@ +import { Request } from 'express'; +import { + createRateLimitTier, + RateLimitMiddleware, + RateLimitStore, +} from '../../src/security/rate-limit.middleware'; +import { createMiddlewareTestContext } from '../utils/mock-express'; + +describe('RateLimitMiddleware', () => { + const tiers = [ + createRateLimitTier('auth', 5, 15 * 60 * 1000, { + methods: ['POST'], + pathPatterns: ['/auth/**'], + }), + createRateLimitTier('puzzle-submit', 30, 60 * 60 * 1000, { + methods: ['POST'], + pathPatterns: ['/progress/submit'], + }), + createRateLimitTier('daily-quest-generate', 2, 24 * 60 * 60 * 1000, { + methods: ['GET'], + pathPatterns: ['/daily-quest'], + }), + createRateLimitTier('read-only', 300, 60 * 60 * 1000, { + methods: ['GET'], + pathPatterns: ['/**'], + }), + ]; + + it('applies user-based keys when an authenticated user exists', async () => { + const evalMock = jest.fn().mockResolvedValue([1, 29, 0, 1_710_000_000_000, 1_700_000_000_000]); + const store: RateLimitStore = { eval: evalMock }; + const middleware = new RateLimitMiddleware({ store, tiers }); + const { req, res, next } = createMiddlewareTestContext({ + req: { + method: 'POST', + path: '/progress/submit', + }, + }); + + ((req as unknown) as Request & { user?: Record }).user = { + userId: 'user-123', + }; + + await middleware.use(req as unknown as Request, res as any, next); + + expect(evalMock).toHaveBeenCalled(); + expect(String(evalMock.mock.calls[0][2])).toContain( + 'ratelimit:puzzle-submit:user:user-123', + ); + expect(res.setHeader).toHaveBeenCalledWith('X-RateLimit-Limit', '30'); + expect(next).toHaveBeenCalledTimes(1); + }); + + it('returns 429 with headers when the request is throttled', async () => { + const store: RateLimitStore = { + eval: jest.fn().mockResolvedValue([0, 0, 12_000, 1_710_000_012_000, 1_700_000_000_000]), + }; + const middleware = new RateLimitMiddleware({ store, tiers }); + const { req, res, next } = createMiddlewareTestContext({ + req: { + method: 'POST', + path: '/auth/signIn', + }, + }); + + await middleware.use(req as unknown as Request, res as any, next); + + expect(res.status).toHaveBeenCalledWith(429); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + statusCode: 429, + errorCode: 'RATE_LIMIT_EXCEEDED', + }), + ); + expect(res.setHeader).toHaveBeenCalledWith('Retry-After', '12'); + expect(next).not.toHaveBeenCalled(); + }); + + it('bypasses whitelisted IPs', async () => { + const store: RateLimitStore = { + eval: jest.fn(), + }; + const middleware = new RateLimitMiddleware({ + store, + tiers, + whitelistIps: ['127.0.0.1'], + }); + const { req, res, next } = createMiddlewareTestContext({ + req: { + method: 'POST', + path: '/auth/signIn', + headers: { + 'x-forwarded-for': '127.0.0.1', + }, + }, + }); + + await middleware.use(req as unknown as Request, res as any, next); + + expect(store.eval).not.toHaveBeenCalled(); + expect(next).toHaveBeenCalledTimes(1); + }); + + it('fails open when Redis is unavailable', async () => { + const store: RateLimitStore = { + eval: jest.fn().mockRejectedValue(new Error('redis offline')), + }; + const onStoreError = jest.fn(); + const middleware = new RateLimitMiddleware({ + store, + tiers, + onStoreError, + }); + const { req, next } = createMiddlewareTestContext({ + req: { + method: 'GET', + path: '/puzzles', + }, + }); + + await middleware.use(req as unknown as Request, {} as any, next); + + expect(onStoreError).toHaveBeenCalled(); + expect(next).toHaveBeenCalledTimes(1); + }); +}); diff --git a/middleware/tests/utils/mock-express.ts b/middleware/tests/utils/mock-express.ts index 81ed9bad..9ca80785 100644 --- a/middleware/tests/utils/mock-express.ts +++ b/middleware/tests/utils/mock-express.ts @@ -1,4 +1,4 @@ -import { Response, NextFunction } from 'express'; +import { NextFunction } from 'express'; /** * Mock Express request object with proper typing @@ -38,13 +38,14 @@ export function mockRequest(overrides?: Partial): MockRequest { /** * Mock Express response object with proper typing */ -export interface MockResponse extends Partial { +export interface MockResponse { statusCode: number; statusMessage: string; headersSent: boolean; json: jest.Mock; send: jest.Mock; set: jest.Mock, string?]>; + setHeader: jest.Mock; status: jest.Mock; end: jest.Mock; on: jest.Mock; @@ -61,6 +62,7 @@ export function mockResponse(overrides?: Partial): MockResponse { json: jest.fn().mockReturnThis(), send: jest.fn().mockReturnThis(), set: jest.fn().mockReturnThis(), + setHeader: jest.fn().mockReturnThis(), status: jest.fn().mockReturnThis(), end: jest.fn().mockReturnThis(), on: jest.fn(), diff --git a/package-lock.json b/package-lock.json index 36196d54..ba03a1aa 100644 --- a/package-lock.json +++ b/package-lock.json @@ -774,18 +774,23 @@ "version": "0.1.0", "dependencies": { "@nestjs/common": "^11.0.12", + "@nestjs/config": "^4.0.0", "@types/micromatch": "^4.0.10", "bcrypt": "^6.0.0", "class-transformer": "^0.5.1", "class-validator": "^0.14.1", "express": "^5.1.0", "jsonwebtoken": "^9.0.2", - "micromatch": "^4.0.8" + "micromatch": "^4.0.8", + "prom-client": "^15.1.3", + "stellar-sdk": "^13.1.0" }, "devDependencies": { + "@nestjs/testing": "^11.0.12", "@types/express": "^5.0.0", "@types/jest": "^29.5.14", "@types/node": "^22.10.7", + "@types/supertest": "^6.0.2", "@typescript-eslint/eslint-plugin": "^8.20.0", "@typescript-eslint/parser": "^8.20.0", "eslint": "^9.18.0", @@ -793,6 +798,7 @@ "globals": "^16.0.0", "jest": "^29.7.0", "prettier": "^3.4.2", + "supertest": "^7.0.0", "ts-jest": "^29.2.5", "typescript": "^5.7.3", "typescript-eslint": "^8.20.0" @@ -3685,7 +3691,6 @@ "resolved": "https://registry.npmjs.org/@nestjs/platform-express/-/platform-express-11.1.12.tgz", "integrity": "sha512-GYK/vHI0SGz5m8mxr7v3Urx8b9t78Cf/dj5aJMZlGd9/1D9OI1hAl00BaphjEXINUJ/BQLxIlF2zUjrYsd6enQ==", "license": "MIT", - "peer": true, "dependencies": { "cors": "2.8.5", "express": "5.2.1", @@ -4064,6 +4069,16 @@ "npm": ">=5.10.0" } }, + "node_modules/@opentelemetry/api": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.1.tgz", + "integrity": "sha512-gLyJlPHPZYdAk1JENA9LeHejZe1Ti77/pTeFm/nMXmQH/HFZlcS/O2XJB+L8fkbrNSqhdtlvjBVjxwUYanNH5Q==", + "license": "Apache-2.0", + "peer": true, + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/@paralleldrive/cuid2": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/@paralleldrive/cuid2/-/cuid2-2.3.1.tgz", @@ -7269,6 +7284,12 @@ "file-uri-to-path": "1.0.0" } }, + "node_modules/bintrees": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/bintrees/-/bintrees-1.0.2.tgz", + "integrity": "sha512-VOMgTMwjAaUG580SXn3LacVgjurrbMme7ZZNYGSSV7mmtY6QQRh0Eg3pwIcntQ77DErK1L0NxkbetjcoXzVwKw==", + "license": "MIT" + }, "node_modules/bl": { "version": "4.1.0", "resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz", @@ -15099,6 +15120,19 @@ "dev": true, "license": "MIT" }, + "node_modules/prom-client": { + "version": "15.1.3", + "resolved": "https://registry.npmjs.org/prom-client/-/prom-client-15.1.3.tgz", + "integrity": "sha512-6ZiOBfCywsD4k1BN9IX0uZhF+tJkV8q8llP64G5Hajs4JOeVLPCwpPVcpXy3BwYiUGgyJzsJJQeOIv7+hDSq8g==", + "license": "Apache-2.0", + "dependencies": { + "@opentelemetry/api": "^1.4.0", + "tdigest": "^0.1.1" + }, + "engines": { + "node": "^16 || ^18 || >=20" + } + }, "node_modules/promise-inflight": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/promise-inflight/-/promise-inflight-1.0.1.tgz", @@ -17089,6 +17123,15 @@ "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", "license": "ISC" }, + "node_modules/tdigest": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/tdigest/-/tdigest-0.1.2.tgz", + "integrity": "sha512-+G0LLgjjo9BZX2MfdvPfH+MKLCrxlXSYec5DaPYP1fe6Iyhf0/fSmJ0bFiZ1F8BT6cGXl2LpltQptzjXKWEkKA==", + "license": "MIT", + "dependencies": { + "bintrees": "1.0.2" + } + }, "node_modules/terser": { "version": "5.46.0", "resolved": "https://registry.npmjs.org/terser/-/terser-5.46.0.tgz",