diff --git a/.deployment-trigger b/.deployment-trigger new file mode 100644 index 0000000..c034487 --- /dev/null +++ b/.deployment-trigger @@ -0,0 +1 @@ +# NextAuth fix deployment trigger diff --git a/.github/workflows/cloudflare-branch.yml b/.github/workflows/cloudflare-branch.yml new file mode 100644 index 0000000..7cc4fb2 --- /dev/null +++ b/.github/workflows/cloudflare-branch.yml @@ -0,0 +1,67 @@ +name: Cloudflare Branch CI/CD + +on: + push: + branches: + - cloudflare + pull_request: + branches: + - cloudflare + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "22" + cache: "npm" + + - name: Install dependencies + run: npm ci + + - name: Run linting and type checking + run: npm run verify + + - name: Run unit tests + run: npm run test:run + + deploy: + runs-on: ubuntu-latest + needs: test + if: github.ref == 'refs/heads/cloudflare' && github.event_name == 'push' + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "22" + cache: "npm" + + - name: Install dependencies + run: npm ci + + - name: Build with OpenNext + run: npm run build:cf + env: + NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }} + NEXTAUTH_URL: "https://comprehendo.tre.systems" + GOOGLE_CLIENT_ID: ${{ secrets.GOOGLE_CLIENT_ID }} + GOOGLE_CLIENT_SECRET: ${{ secrets.GOOGLE_CLIENT_SECRET }} + ADMIN_EMAILS: ${{ secrets.ADMIN_EMAILS }} + + - name: Deploy to Cloudflare Workers + uses: cloudflare/wrangler-action@v3 + with: + apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }} + accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }} + command: deploy --minify + env: + NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }} + NEXTAUTH_URL: "https://comprehendo.tre.systems" diff --git a/.github/workflows/cloudflare.yml b/.github/workflows/cloudflare.yml new file mode 100644 index 0000000..3f60cf5 --- /dev/null +++ b/.github/workflows/cloudflare.yml @@ -0,0 +1,65 @@ +name: Deploy to Cloudflare Workers + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "22" + cache: "npm" + + - name: Install dependencies + run: npm ci + + - name: Run linting and type checking + run: npm run verify + + - name: Run unit tests + run: npm run test:run + + deploy: + runs-on: ubuntu-latest + needs: test + if: github.ref == 'refs/heads/main' && github.event_name == 'push' + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "22" + cache: "npm" + + - name: Install dependencies + run: npm ci + + - name: Build for Cloudflare + run: npm run build:cf + env: + NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }} + NEXTAUTH_URL: "https://comprehendo.tre.systems" + GOOGLE_CLIENT_ID: ${{ secrets.GOOGLE_CLIENT_ID }} + GOOGLE_CLIENT_SECRET: ${{ secrets.GOOGLE_CLIENT_SECRET }} + ADMIN_EMAILS: ${{ secrets.ADMIN_EMAILS }} + + - name: Deploy to Cloudflare Workers + uses: cloudflare/wrangler-action@v3 + with: + apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }} + accountId: ${{ secrets.CLOUDFLARE_ACCOUNT_ID }} + command: deploy --minify + env: + NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }} + NEXTAUTH_URL: "https://comprehendo.tre.systems" diff --git a/.gitignore b/.gitignore index b609c06..c820baf 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,7 @@ test/e2e/auth/*.storageState.json # local env files .pnp.* .yarn/ +.open-next/ + +# Cloudflare Wrangler local state +.wrangler/ diff --git a/README.md b/README.md index ceacc00..77ffc4d 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A multi-language reading comprehension practice tool powered by Next.js and Google Gemini. -[![CI/CD](https://github.com/rgilks/comprehendo/actions/workflows/fly.yml/badge.svg)](https://github.com/rgilks/comprehendo/actions/workflows/fly.yml) +[![CI/CD](https://github.com/rgilks/comprehendo/actions/workflows/cloudflare.yml/badge.svg)](https://github.com/rgilks/comprehendo/actions/workflows/cloudflare.yml) ![Comprehendo Screenshot](public/screenshot.png) @@ -33,7 +33,7 @@ Comprehendo is an AI-powered language learning application designed to help user - **Cost-Control System**: IP-based rate limiting, translation caching, daily AI API budgets, and database caching to manage API costs. - **Robust Validation**: Uses Zod for request validation on API routes and environment variables. - **Smooth Loading Experience**: Enhanced loading indicators and transitions. -- **Continuous Deployment**: Automatic deployment to Fly.io via GitHub Actions when code is pushed to the `main` branch. +- **Continuous Deployment**: Automatic deployment to Cloudflare Workers via GitHub Actions when code is pushed to the `main` branch. - **Admin Panel**: A secure area for administrators to view application data (users, quizzes, feedback). - **Internationalization (i18n)**: Full i18n support for UI elements using `i18next` and locale files in `public/locales/`. - **PWA Support**: Progressive Web App features (e.g., installability) are enabled via `@serwist/next`, relying on browser/device native installation prompts. @@ -140,7 +140,7 @@ Comprehendo implements strategies to manage AI API costs: - `GOOGLE_CLIENT_ID`, `GOOGLE_CLIENT_SECRET` (optional, if enabling Google login) - `DISCORD_CLIENT_ID`, `DISCORD_CLIENT_SECRET` (optional, if enabling Discord login) - `AUTH_SECRET`: Generate with `openssl rand -base64 32` - - `NEXTAUTH_URL`: The canonical URL of your deployment (e.g., `http://localhost:3000` for local development). + - `NEXTAUTH_URL`: The canonical URL of your deployment (e.g., `http://localhost:3000` for local development, `https://comprehendo.tre.systems` for production). - `ADMIN_EMAILS`: Comma-separated list of emails for admin access (e.g., `admin@example.com,test@test.com`). - `GOOGLE_TRANSLATE_API_KEY`: (Optional) Needed for hover translation feature. - `RATE_LIMIT_MAX_REQUESTS_PER_HOUR`: (Optional, default 100) Max exercise generation requests per IP per hour. @@ -191,6 +191,49 @@ Continuous Deployment is set up via GitHub Actions (`.github/workflows/fly.yml`) fly deploy --app ``` +### Deploying to Cloudflare Workers + +The application is configured to deploy to Cloudflare Workers using OpenNext and D1 database. + +**First-Time Cloudflare Setup:** + +1. **Install Wrangler CLI:** `npm install -g wrangler` +2. **Login:** `wrangler login` +3. **Create D1 Database:** + ```bash + wrangler d1 create comprehendo-db + ``` +4. **Update wrangler.toml:** Replace the database IDs in `wrangler.toml` with your actual database IDs +5. **Set Production Secrets:** + ```bash + wrangler secret put NEXTAUTH_SECRET + wrangler secret put NEXTAUTH_URL + wrangler secret put GOOGLE_CLIENT_ID + wrangler secret put GOOGLE_CLIENT_SECRET + wrangler secret put GOOGLE_TRANSLATE_API_KEY + wrangler secret put ADMIN_EMAILS + ``` +6. **Add GitHub Secrets:** + - `CLOUDFLARE_API_TOKEN`: Get from Cloudflare dashboard > My Profile > API Tokens + - `CLOUDFLARE_ACCOUNT_ID`: Get from Cloudflare dashboard > Right sidebar + +**Deployment:** + +- Push to `main`: `git push origin main` +- Monitor in GitHub Actions tab + +**Manual Deployment:** + +```bash +npm run deploy +``` + +**Database Migration:** + +```bash +wrangler d1 execute comprehendo-db --file=./scripts/migrate-d1.js +``` + ## Development Workflow Key scripts defined in `package.json`: @@ -238,8 +281,8 @@ npm run cleanup-db:dry-run - **Rate Limits**: Adjust `MAX_REQUESTS_PER_HOUR`, `MAX_TRANSLATION_REQUESTS_PER_HOUR`, and `MAX_DAILY_AI_REQUESTS` based on traffic and budget. - **Database Maintenance**: Run `npm run cleanup-db` periodically to clean up old rate limits, translations, and usage records. - **Security**: Review CORS, consider stricter input validation if needed. -- **Scaling**: Adjust Fly.io machine specs/count in `fly.toml`. -- **Database Backups**: Implement a backup strategy for the SQLite volume on Fly.io (e.g., using `litestream` or manual snapshots). +- **Scaling**: Cloudflare Workers automatically scale based on demand. +- **Database Backups**: Cloudflare D1 provides automatic backups and point-in-time recovery. ## Customization @@ -375,7 +418,9 @@ Once both files are correctly populated, `npm run test:e2e` should now be able t ## Database -Uses SQLite via `better-sqlite3`. The database file is `data/comprehendo.sqlite` locally, and stored on a persistent volume (`/data/comprehendo.sqlite`) in production (Fly.io). +Uses SQLite via Drizzle ORM with `@libsql/client`. The database file is `data/comprehendo.sqlite` locally, and uses Cloudflare D1 in production. + +The migration to Drizzle ORM makes it easy to switch to Cloudflare D1 in the future by simply changing the database connection configuration. ### SQLite Command Line (Local) @@ -389,7 +434,7 @@ Useful commands: `.tables`, `SELECT * FROM users LIMIT 5;`, `.schema quiz`, `.qu ### Database Schema ```sql --- From app/repo/db.ts initialization logic +-- From app/lib/db/migrations.ts initialization logic CREATE TABLE IF NOT EXISTS quiz ( id INTEGER PRIMARY KEY AUTOINCREMENT, diff --git a/app/actions/exercise.test.ts b/app/actions/exercise.test.ts index 7ebe72b..0238c2c 100644 --- a/app/actions/exercise.test.ts +++ b/app/actions/exercise.test.ts @@ -40,8 +40,14 @@ describe('exercise actions', () => { }); it('should handle errors gracefully', async () => { - const { getRandomGoodQuestionResponse } = await import('app/actions/exercise'); const { getRandomGoodQuestion } = await import('app/repo/quizRepo'); + + // Mock getRandomGoodQuestion to throw an error + vi.mocked(getRandomGoodQuestion).mockImplementation(() => { + throw new Error('Database error'); + }); + + const { getRandomGoodQuestionResponse } = await import('app/actions/exercise'); const { getServerSession } = await import('next-auth'); const { headers } = await import('next/headers'); const { findUserIdByProvider } = await import('app/repo/userRepo'); @@ -50,12 +56,7 @@ describe('exercise actions', () => { user: { id: '123', provider: 'google' }, } as never); vi.mocked(headers).mockResolvedValue(new Headers()); - vi.mocked(findUserIdByProvider).mockReturnValue(1); - - // Mock error in getRandomGoodQuestion - vi.mocked(getRandomGoodQuestion).mockImplementation(() => { - throw new Error('Database error'); - }); + vi.mocked(findUserIdByProvider).mockResolvedValue(1); const result = await getRandomGoodQuestionResponse({ passageLanguage: 'es', diff --git a/app/actions/exercise.ts b/app/actions/exercise.ts index 9b69080..7f47ea9 100644 --- a/app/actions/exercise.ts +++ b/app/actions/exercise.ts @@ -36,15 +36,15 @@ import { saveExercise as saveExerciseToCache, countCachedExercisesInRepo as countCachedExercises, getRandomGoodQuestion, - type QuizRow, } from 'app/repo/quizRepo'; import { incrementTodayUsage } from 'app/repo/aiApiUsageRepo'; import { z } from 'zod'; import { extractZodErrors } from 'app/lib/utils/errorUtils'; +import { validateAndSanitizeInput } from 'app/lib/utils/sanitization'; -const getDbUserIdFromSession = ( +const getDbUserIdFromSession = async ( session: { user: { id?: string | null; provider?: string | null } } | null -): number | null => { +): Promise => { if (!session || !session.user.id || !session.user.provider) { if (session) { console.warn( @@ -55,7 +55,7 @@ const getDbUserIdFromSession = ( } try { - const userId = findUserIdByProvider(session.user.id, session.user.provider); + const userId = await findUserIdByProvider(session.user.id, session.user.provider); if (userId === undefined) { console.warn( `[getDbUserIdFromSession] Direct lookup failed: Could not find user for providerId: ${session.user.id}, provider: ${session.user.provider}` @@ -69,14 +69,14 @@ const getDbUserIdFromSession = ( } }; -const getValidatedExerciseFromCache = ( +const getValidatedExerciseFromCache = async ( passageLanguage: string, questionLanguage: string, level: string, userId: number | null, excludeQuizId?: number | null -): { quizData: PartialQuizData; quizId: number } | undefined => { - const cachedExercise: QuizRow | undefined = getCachedExercise( +): Promise<{ quizData: PartialQuizData; quizId: number } | undefined> => { + const cachedExercise = await getCachedExercise( passageLanguage, questionLanguage, level, @@ -124,14 +124,14 @@ const getValidatedExerciseFromCache = ( return undefined; }; -const getValidatedRandomGoodQuestion = ( +const getValidatedRandomGoodQuestion = async ( passageLanguage: string, questionLanguage: string, level: string, userId: number | null, excludeQuizId?: number | null -): { quizData: PartialQuizData; quizId: number } | undefined => { - const randomGoodQuestion: QuizRow | undefined = getRandomGoodQuestion( +): Promise<{ quizData: PartialQuizData; quizId: number } | undefined> => { + const randomGoodQuestion = await getRandomGoodQuestion( passageLanguage, questionLanguage, level, @@ -186,28 +186,28 @@ const MAX_REQUESTS_PER_HOUR = parseInt( ); const RATE_LIMIT_WINDOW = parseInt(process.env['RATE_LIMIT_WINDOW_MS'] || '3600000', 10); -const checkRateLimit = (ip: string): boolean => { +const checkRateLimit = async (ip: string): Promise => { try { const now = Date.now(); - const rateLimitRow = getRateLimit(ip); + const rateLimitRow = await getRateLimit(ip); if (!rateLimitRow) { - createRateLimit(ip, new Date(now).toISOString()); + await createRateLimit(ip, new Date(now).toISOString()); return true; } - const windowStartTime = new Date(rateLimitRow.window_start_time).getTime(); + const windowStartTime = new Date(rateLimitRow.windowStartTime).getTime(); const isWithinWindow = now - windowStartTime < RATE_LIMIT_WINDOW; if (isWithinWindow) { - if (rateLimitRow.request_count >= MAX_REQUESTS_PER_HOUR) { + if (rateLimitRow.requestCount >= MAX_REQUESTS_PER_HOUR) { return false; } - incrementRateLimit(ip); + await incrementRateLimit(ip); return true; } - resetRateLimit(ip, new Date(now).toISOString()); + await resetRateLimit(ip, new Date(now).toISOString()); return true; } catch (error) { console.error('[RateLimiter] Error checking rate limit:', error); @@ -255,7 +255,7 @@ const tryGenerateAndCacheExercise = async ( ): Promise> => { try { // Check daily AI API budget before making expensive API calls - if (!incrementTodayUsage()) { + if (!(await incrementTodayUsage())) { return failure({ error: 'Daily AI API request limit exceeded. Please try again tomorrow.', }); @@ -273,7 +273,7 @@ const tryGenerateAndCacheExercise = async ( console.log('[Exercise] AI generation successful, attempting to save to cache...'); - const exerciseId = saveExerciseToCache( + const exerciseId = await saveExerciseToCache( params.passageLanguage, params.questionLanguage, params.level, @@ -323,7 +323,7 @@ const getOrGenerateExercise = async ( } generationError = genResult.error; - const validatedCacheResultPreferGen = getValidatedExerciseFromCache( + const validatedCacheResultPreferGen = await getValidatedExerciseFromCache( requestParams.passageLanguage, requestParams.questionLanguage, requestParams.cefrLevel, @@ -343,7 +343,7 @@ const getOrGenerateExercise = async ( ); } - const validatedCacheResultOtherwise = getValidatedExerciseFromCache( + const validatedCacheResultOtherwise = await getValidatedExerciseFromCache( requestParams.passageLanguage, requestParams.questionLanguage, requestParams.cefrLevel, @@ -370,10 +370,18 @@ const getOrGenerateExercise = async ( const getRequestContext = async () => { const headersList = await headers(); - const ip = headersList.get('fly-client-ip') || headersList.get('x-forwarded-for') || 'unknown'; + const ip = + headersList.get('cf-connecting-ip') || + headersList.get('x-forwarded-for') || + headersList.get('fly-client-ip') || + 'unknown'; + + // Sanitize IP address + const sanitizedIp = validateAndSanitizeInput(ip, 45); // IPv6 max length + const session = await getServerSession(authOptions); - const userId = getDbUserIdFromSession(session); - return { ip, userId }; + const userId = await getDbUserIdFromSession(session); + return { ip: sanitizedIp, userId }; }; const validateAndExtractParams = (requestParams: unknown) => { @@ -410,9 +418,9 @@ export const generateExerciseResponse = async ( const excludeQuizId = validParams.excludeQuizId ?? null; - if (!checkRateLimit(ip)) { + if (!(await checkRateLimit(ip))) { // Try to find cached questions for the exact parameters first - const validatedCacheResultRateLimit = getValidatedExerciseFromCache( + const validatedCacheResultRateLimit = await getValidatedExerciseFromCache( validParams.passageLanguage, validParams.questionLanguage, validParams.cefrLevel, @@ -429,7 +437,7 @@ export const generateExerciseResponse = async ( } // If no exact match, try to find any cached questions for this language - const fallbackCacheResult = getValidatedExerciseFromCache( + const fallbackCacheResult = await getValidatedExerciseFromCache( validParams.passageLanguage, validParams.questionLanguage, 'A1', // Try with A1 level as fallback @@ -451,7 +459,7 @@ export const generateExerciseResponse = async ( } const genParams = buildGenParams(validParams); - const cachedCountValue = countCachedExercises( + const cachedCountValue = await countCachedExercises( genParams.passageLanguage, genParams.questionLanguage, genParams.level @@ -474,7 +482,7 @@ export const generateInitialExercisePair = async ( const { validParams, errorMsg } = validateAndExtractParams(requestParams); if (!validParams) return { quizzes: [], error: errorMsg }; - if (!checkRateLimit(ip)) return { quizzes: [], error: 'Rate limit exceeded.' }; + if (!(await checkRateLimit(ip))) return { quizzes: [], error: 'Rate limit exceeded.' }; const genParams1 = buildGenParams(validParams); const genParams2 = buildGenParams(validParams, getRandomTopicForLevel(validParams.cefrLevel)); @@ -526,7 +534,7 @@ export const getRandomGoodQuestionResponse = async ( }); try { - const randomGoodResult = getValidatedRandomGoodQuestion( + const randomGoodResult = await getValidatedRandomGoodQuestion( validParams.passageLanguage, validParams.questionLanguage, validParams.cefrLevel, diff --git a/app/actions/progress.ts b/app/actions/progress.ts index 345f248..0b267b2 100644 --- a/app/actions/progress.ts +++ b/app/actions/progress.ts @@ -57,25 +57,25 @@ const calculateNextProgress = ( return { nextLevel: currentLevel, nextStreak: newStreak, leveledUp: false }; }; -const getOrInitProgress = (userId: number, languageCode: string) => { - const currentProgress = findUserProgress(userId, languageCode); - return currentProgress || initializeProgress(userId, languageCode); +const getOrInitProgress = async (userId: number, languageCode: string) => { + const currentProgress = await findUserProgress(userId, languageCode); + return currentProgress || (await initializeProgress(userId, languageCode)); }; -const calculateAndUpdateProgress = ( +const calculateAndUpdateProgress = async ( userId: number, language: string, isCorrect: boolean -): ProgressUpdateResult => { +): Promise => { try { const languageCode = language.toLowerCase().slice(0, 2); - const currentProgress = getOrInitProgress(userId, languageCode); + const currentProgress = await getOrInitProgress(userId, languageCode); const { nextLevel, nextStreak, leveledUp } = calculateNextProgress( currentProgress.cefr_level, currentProgress.correct_streak, isCorrect ); - updateProgressRepository(userId, languageCode, nextLevel, nextStreak); + await updateProgressRepository(userId, languageCode, nextLevel, nextStreak); return { currentLevel: nextLevel, currentStreak: nextStreak, leveledUp }; } catch (e) { const message = e instanceof Error ? e.message : 'Unknown error'; @@ -132,9 +132,9 @@ export interface ProgressResponse { feedback?: FeedbackType; } -const getParsedQuizData = (quizId: number) => { +const getParsedQuizData = async (quizId: number) => { try { - const quizRecord = findQuizById(quizId); + const quizRecord = await findQuizById(quizId); if (!quizRecord) return { data: null, error: `Quiz with ID ${quizId} not found.` }; const parsedContent = QuizDataSchema.safeParse(quizRecord.content); if (!parsedContent.success) @@ -204,7 +204,7 @@ export const updateProgress = async (params: UpdateProgressParams): Promise { +const checkTranslationRateLimit = async (ip: string): Promise => { try { const now = Date.now(); - const rateLimitRow = getRateLimit(ip); + const rateLimitRow = await getRateLimit(ip); if (!rateLimitRow) { - createRateLimit(ip, new Date(now).toISOString()); + await createRateLimit(ip, new Date(now).toISOString()); return true; } - const windowStartTime = new Date(rateLimitRow.window_start_time).getTime(); + const windowStartTime = new Date(rateLimitRow.windowStartTime).getTime(); const isWithinWindow = now - windowStartTime < RATE_LIMIT_WINDOW; if (isWithinWindow) { - if (rateLimitRow.request_count >= MAX_TRANSLATION_REQUESTS_PER_HOUR) { + if (rateLimitRow.requestCount >= MAX_TRANSLATION_REQUESTS_PER_HOUR) { return false; } - incrementRateLimit(ip); + await incrementRateLimit(ip); return true; } - resetRateLimit(ip, new Date(now).toISOString()); + await resetRateLimit(ip, new Date(now).toISOString()); return true; } catch (error) { console.error('[Translation RateLimiter] Error checking rate limit:', error); @@ -61,25 +62,39 @@ export const translateWordWithGoogle = async ( targetLang: string, sourceLang: string ) => { + // Sanitize input parameters + const sanitizedWord = validateAndSanitizeInput(word, 100); + const sanitizedTargetLang = validateAndSanitizeInput(targetLang, 10); + const sanitizedSourceLang = validateAndSanitizeInput(sourceLang, 10); + const googleApiKey = process.env['GOOGLE_TRANSLATE_API_KEY']; - if (!word || !targetLang || !sourceLang || !googleApiKey) { + if (!sanitizedWord || !sanitizedTargetLang || !sanitizedSourceLang || !googleApiKey) { console.error('translateWordWithGoogle: Missing required parameters or API key.'); return null; } // Check cache first - const cachedTranslation = getCachedTranslation(word, sourceLang, targetLang); + const cachedTranslation = await getCachedTranslation( + sanitizedWord, + sanitizedSourceLang, + sanitizedTargetLang + ); if (cachedTranslation) { - console.log(`[Translation] Cache hit for "${word}" (${sourceLang} -> ${targetLang})`); + console.log( + `[Translation] Cache hit for "${sanitizedWord}" (${sanitizedSourceLang} -> ${sanitizedTargetLang})` + ); return TranslationResultSchema.parse({ translation: cachedTranslation, romanization: '' }); } // Check rate limit for translation requests const headersList = await headers(); - const ip = headersList.get('fly-client-ip') || headersList.get('x-forwarded-for') || 'unknown'; + const ip = validateAndSanitizeInput( + headersList.get('fly-client-ip') || headersList.get('x-forwarded-for') || 'unknown', + 45 + ); - if (!checkTranslationRateLimit(ip)) { + if (!(await checkTranslationRateLimit(ip))) { console.warn(`[Translation] Rate limit exceeded for IP: ${ip}`); return null; } @@ -87,10 +102,10 @@ export const translateWordWithGoogle = async ( const apiUrl = `https://translation.googleapis.com/language/translate/v2?key=${googleApiKey}`; const payload = { - q: word, - target: targetLang, + q: sanitizedWord, + target: sanitizedTargetLang, format: 'text', - source: sourceLang, + source: sanitizedSourceLang, }; try { @@ -121,8 +136,15 @@ export const translateWordWithGoogle = async ( } // Save to cache for future use - saveTranslationToCache(word, sourceLang, targetLang, translatedText); - console.log(`[Translation] Cached translation for "${word}" (${sourceLang} -> ${targetLang})`); + await saveTranslationToCache( + sanitizedWord, + sanitizedSourceLang, + sanitizedTargetLang, + translatedText + ); + console.log( + `[Translation] Cached translation for "${sanitizedWord}" (${sanitizedSourceLang} -> ${sanitizedTargetLang})` + ); return TranslationResultSchema.parse({ translation: translatedText, romanization: '' }); } catch (error) { diff --git a/app/admin/actions.ts b/app/admin/actions.ts index 1b4c1ef..abcf3a0 100644 --- a/app/admin/actions.ts +++ b/app/admin/actions.ts @@ -27,7 +27,7 @@ const ensureAdmin = async (): Promise => { }; const createAdminAction = ( - action: (...args: TArgs) => Promise | TReturn, + action: (...args: TArgs) => Promise, actionNameForLog: string, defaultFailureMessage: string ) => { diff --git a/app/admin/components/DataTable.tsx b/app/admin/components/DataTable.tsx index 7c03040..b1436ad 100644 --- a/app/admin/components/DataTable.tsx +++ b/app/admin/components/DataTable.tsx @@ -1,4 +1,3 @@ -import React from 'react'; import { DataTableControls } from './DataTableControls'; import { DataTableBody } from './DataTableBody'; diff --git a/app/admin/components/DataTableBody.tsx b/app/admin/components/DataTableBody.tsx index 88319cc..a29033e 100644 --- a/app/admin/components/DataTableBody.tsx +++ b/app/admin/components/DataTableBody.tsx @@ -1,4 +1,3 @@ -import React from 'react'; import { renderTableCellValue } from 'app/lib/utils/rendering'; interface DataTableBodyProps> { diff --git a/app/admin/components/DataTableControls.tsx b/app/admin/components/DataTableControls.tsx index 3eb4b95..cad1b0b 100644 --- a/app/admin/components/DataTableControls.tsx +++ b/app/admin/components/DataTableControls.tsx @@ -1,4 +1,3 @@ -import React from 'react'; import { ArrowPathIcon as HeroRefreshIcon, ChevronLeftIcon as HeroChevronLeftIcon, diff --git a/app/admin/components/FormattedValueDisplay.tsx b/app/admin/components/FormattedValueDisplay.tsx index 89b2376..0450f5a 100644 --- a/app/admin/components/FormattedValueDisplay.tsx +++ b/app/admin/components/FormattedValueDisplay.tsx @@ -1,5 +1,3 @@ -import React from 'react'; - interface FormattedValueDisplayProps { valueKey: string; value: unknown; diff --git a/app/admin/components/RowDetailView.tsx b/app/admin/components/RowDetailView.tsx index 49605c3..c4a8d2f 100644 --- a/app/admin/components/RowDetailView.tsx +++ b/app/admin/components/RowDetailView.tsx @@ -1,4 +1,3 @@ -import React from 'react'; import { FormattedValueDisplay } from './FormattedValueDisplay'; import { XMarkIcon } from '@heroicons/react/24/solid'; diff --git a/app/admin/components/TableSelector.tsx b/app/admin/components/TableSelector.tsx index 7a5f061..d34110e 100644 --- a/app/admin/components/TableSelector.tsx +++ b/app/admin/components/TableSelector.tsx @@ -1,5 +1,3 @@ -import React from 'react'; - export const buttonBaseClass = 'px-4 py-3 sm:px-4 sm:py-2 inline-flex items-center justify-center text-center gap-2 border rounded-md shadow-sm text-sm font-medium focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-offset-gray-800 disabled:opacity-50 disabled:cursor-not-allowed transition-colors'; export const primaryButtonClass = `${buttonBaseClass} text-white bg-indigo-600 hover:bg-indigo-700 focus:ring-indigo-500 border-transparent`; diff --git a/app/admin/page.tsx b/app/admin/page.tsx index 24af13b..f790270 100644 --- a/app/admin/page.tsx +++ b/app/admin/page.tsx @@ -1,6 +1,6 @@ 'use client'; -import React, { useState, useCallback } from 'react'; +import { useState, useCallback } from 'react'; import Link from 'next/link'; import { useSession } from 'next-auth/react'; import { RowDetailView } from './components/RowDetailView'; diff --git a/app/components/DatabaseProvider.tsx b/app/components/DatabaseProvider.tsx new file mode 100644 index 0000000..c5f00d4 --- /dev/null +++ b/app/components/DatabaseProvider.tsx @@ -0,0 +1,26 @@ +'use client'; + +import { createContext, useContext, ReactNode } from 'react'; + +interface DatabaseContextType { + getDb: (d1Database?: unknown) => unknown; +} + +const DatabaseContext = createContext(null); + +export const DatabaseProvider = ({ children }: { children: ReactNode }) => { + const getDb = async (d1Database?: unknown) => { + const { getDb: getDbFunction } = await import('app/lib/db/adapter'); + return getDbFunction(d1Database); + }; + + return {children}; +}; + +export const useDatabase = () => { + const context = useContext(DatabaseContext); + if (!context) { + throw new Error('useDatabase must be used within a DatabaseProvider'); + } + return context; +}; diff --git a/app/components/TextGenerator/AudioControls.tsx b/app/components/TextGenerator/AudioControls.tsx index 91505ba..133071e 100644 --- a/app/components/TextGenerator/AudioControls.tsx +++ b/app/components/TextGenerator/AudioControls.tsx @@ -1,6 +1,5 @@ 'use client'; -import React from 'react'; import useTextGeneratorStore from 'app/store/textGeneratorStore'; import { useTranslation } from 'react-i18next'; import PlayPauseButton from './PlayPauseButton'; diff --git a/app/components/TextGenerator/ErrorDisplay.tsx b/app/components/TextGenerator/ErrorDisplay.tsx index 6d0dff5..9d0b3b0 100644 --- a/app/components/TextGenerator/ErrorDisplay.tsx +++ b/app/components/TextGenerator/ErrorDisplay.tsx @@ -1,8 +1,8 @@ 'use client'; -import React from 'react'; import { useTranslation } from 'react-i18next'; import useTextGeneratorStore from 'app/store/textGeneratorStore'; +import { sanitizeText } from 'app/lib/utils/sanitization'; const ErrorDisplay = () => { const { t } = useTranslation('common'); @@ -19,7 +19,7 @@ const ErrorDisplay = () => { data-testid="error-display" > {t('common.errorPrefix')} - {error} + {sanitizeText(error)} ); }; diff --git a/app/components/TextGenerator/Generator.tsx b/app/components/TextGenerator/Generator.tsx index 3d5b071..0e0b87c 100644 --- a/app/components/TextGenerator/Generator.tsx +++ b/app/components/TextGenerator/Generator.tsx @@ -1,6 +1,6 @@ 'use client'; -import React, { useRef, useCallback } from 'react'; +import { useRef, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import useTextGeneratorStore from 'app/store/textGeneratorStore'; import { HandThumbUpIcon, HandThumbDownIcon } from '@heroicons/react/24/solid'; diff --git a/app/components/TextGenerator/LanguageSelector.tsx b/app/components/TextGenerator/LanguageSelector.tsx index e0b81b0..52d7e97 100644 --- a/app/components/TextGenerator/LanguageSelector.tsx +++ b/app/components/TextGenerator/LanguageSelector.tsx @@ -1,6 +1,5 @@ 'use client'; -import React from 'react'; import { GlobeAltIcon, BookOpenIcon } from '@heroicons/react/24/solid'; import { LEARNING_LANGUAGES, type LearningLanguage } from 'app/domain/language'; import useTextGeneratorStore from 'app/store/textGeneratorStore'; diff --git a/app/components/TextGenerator/LoginPrompt.tsx b/app/components/TextGenerator/LoginPrompt.tsx index a1a2571..962baba 100644 --- a/app/components/TextGenerator/LoginPrompt.tsx +++ b/app/components/TextGenerator/LoginPrompt.tsx @@ -1,6 +1,5 @@ 'use client'; -import React from 'react'; import { useTranslation } from 'react-i18next'; import { useSession } from 'next-auth/react'; import { XMarkIcon } from '@heroicons/react/24/solid'; diff --git a/app/components/TextGenerator/PlayPauseButton.tsx b/app/components/TextGenerator/PlayPauseButton.tsx index 93fbb52..2c6311e 100644 --- a/app/components/TextGenerator/PlayPauseButton.tsx +++ b/app/components/TextGenerator/PlayPauseButton.tsx @@ -1,5 +1,4 @@ import { PlayIcon, PauseIcon } from '@heroicons/react/24/solid'; -import React from 'react'; type PlayPauseButtonProps = { isSpeakingPassage: boolean; diff --git a/app/components/TextGenerator/ProgressTracker.tsx b/app/components/TextGenerator/ProgressTracker.tsx index 3e6eda7..1f31e5b 100644 --- a/app/components/TextGenerator/ProgressTracker.tsx +++ b/app/components/TextGenerator/ProgressTracker.tsx @@ -1,6 +1,5 @@ 'use client'; -import React from 'react'; import { useTranslation } from 'react-i18next'; import useTextGeneratorStore from 'app/store/textGeneratorStore'; import { useSession } from 'next-auth/react'; @@ -19,7 +18,7 @@ const ProgressBar = ({ cefrLevel, userStreak }: { cefrLevel: CEFRLevel; userStre
{CEFR_LEVELS.map((level: CEFRLevel, idx: number) => ( - +
)} - +
))}
diff --git a/app/components/TextGenerator/ProgressionFeedback.tsx b/app/components/TextGenerator/ProgressionFeedback.tsx index e81a25d..7b13d07 100644 --- a/app/components/TextGenerator/ProgressionFeedback.tsx +++ b/app/components/TextGenerator/ProgressionFeedback.tsx @@ -1,6 +1,5 @@ 'use client'; -import React from 'react'; import { motion, AnimatePresence } from 'motion/react'; import { useTranslation } from 'react-i18next'; import { useSession } from 'next-auth/react'; diff --git a/app/components/TextGenerator/QuizSection.tsx b/app/components/TextGenerator/QuizSection.tsx index 86ffad4..fba3762 100644 --- a/app/components/TextGenerator/QuizSection.tsx +++ b/app/components/TextGenerator/QuizSection.tsx @@ -1,6 +1,6 @@ 'use client'; -import React, { useCallback, useState, useEffect } from 'react'; +import { useCallback, useState, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; import { motion, AnimatePresence } from 'motion/react'; import { getTextDirection, useLanguage } from 'app/hooks/useLanguage'; diff --git a/app/components/TextGenerator/QuizSkeleton.tsx b/app/components/TextGenerator/QuizSkeleton.tsx index 5c5dd52..61051b5 100644 --- a/app/components/TextGenerator/QuizSkeleton.tsx +++ b/app/components/TextGenerator/QuizSkeleton.tsx @@ -1,7 +1,5 @@ 'use client'; -import React from 'react'; - const QuizSkeleton = () => (
diff --git a/app/components/TextGenerator/ReadingPassage.tsx b/app/components/TextGenerator/ReadingPassage.tsx index f613607..c10b47f 100644 --- a/app/components/TextGenerator/ReadingPassage.tsx +++ b/app/components/TextGenerator/ReadingPassage.tsx @@ -1,6 +1,6 @@ 'use client'; -import React, { useState, useEffect } from 'react'; +import { useState, useEffect } from 'react'; import { BookOpenIcon, XMarkIcon } from '@heroicons/react/24/outline'; import { useTranslation } from 'react-i18next'; import { motion } from 'motion/react'; diff --git a/app/components/TextGenerator/TextGeneratorContainer.tsx b/app/components/TextGenerator/TextGeneratorContainer.tsx index 9dafd4b..fcb14ad 100644 --- a/app/components/TextGenerator/TextGeneratorContainer.tsx +++ b/app/components/TextGenerator/TextGeneratorContainer.tsx @@ -1,6 +1,6 @@ 'use client'; -import React, { useRef, useEffect } from 'react'; +import { useRef, useEffect } from 'react'; import { useLanguage } from 'app/hooks/useLanguage'; import { useSession } from 'next-auth/react'; import { motion, AnimatePresence } from 'motion/react'; diff --git a/app/components/TextGenerator/TranslatableWord.tsx b/app/components/TextGenerator/TranslatableWord.tsx index 2b154ce..1048831 100644 --- a/app/components/TextGenerator/TranslatableWord.tsx +++ b/app/components/TextGenerator/TranslatableWord.tsx @@ -1,6 +1,6 @@ 'use client'; -import React, { useState, useCallback, memo, useEffect } from 'react'; +import { useState, useCallback, memo, useEffect } from 'react'; import { type Language } from 'app/domain/language'; import useTextGeneratorStore from 'app/store/textGeneratorStore'; import { analyzeTranslation, canTranslate } from 'app/lib/utils/translation'; diff --git a/app/components/TextGenerator/VoiceSelector.tsx b/app/components/TextGenerator/VoiceSelector.tsx index c072504..96c58e1 100644 --- a/app/components/TextGenerator/VoiceSelector.tsx +++ b/app/components/TextGenerator/VoiceSelector.tsx @@ -1,5 +1,4 @@ import { ChevronDownIcon } from '@heroicons/react/24/solid'; -import React from 'react'; import type { VoiceInfo } from 'app/domain/schemas'; type VoiceSelectorProps = { diff --git a/app/components/TextGenerator/VolumeSlider.tsx b/app/components/TextGenerator/VolumeSlider.tsx index 270e192..9b162cf 100644 --- a/app/components/TextGenerator/VolumeSlider.tsx +++ b/app/components/TextGenerator/VolumeSlider.tsx @@ -1,5 +1,4 @@ import { SpeakerWaveIcon } from '@heroicons/react/24/solid'; -import React from 'react'; type VolumeSliderProps = { volume: number; diff --git a/app/components/TextGenerator/useRenderParagraphWithWordHover.tsx b/app/components/TextGenerator/useRenderParagraphWithWordHover.tsx index fb2fe4b..06f0655 100644 --- a/app/components/TextGenerator/useRenderParagraphWithWordHover.tsx +++ b/app/components/TextGenerator/useRenderParagraphWithWordHover.tsx @@ -1,4 +1,4 @@ -import React from 'react'; +import { useCallback } from 'react'; import TranslatableWord from './TranslatableWord'; import { type Language } from 'app/domain/language'; @@ -15,7 +15,7 @@ const useRenderParagraphWithWordHover = ({ relevantTextRange, actualQuestionLanguage, }: Params) => { - return React.useCallback( + return useCallback( (paragraph: string, lang: Language) => { const words = paragraph.split(/(\s+)/); let currentPos = 0; diff --git a/app/layout.tsx b/app/layout.tsx index 711ff91..2ef4aa3 100644 --- a/app/layout.tsx +++ b/app/layout.tsx @@ -4,6 +4,7 @@ import './globals.css'; import AuthProvider from 'app/components/AuthProvider'; import { type Language } from 'app/domain/language'; import { cookies } from 'next/headers'; +import { CSRFProtection } from 'app/lib/utils/csrf'; const poppins = Poppins({ subsets: ['latin'], @@ -54,12 +55,14 @@ const RootLayout = async ({ }>) => { const cookieStore = await cookies(); const locale = (cookieStore.get('NEXT_LOCALE')?.value || 'en') as Language; + const csrfToken = CSRFProtection.generateToken(); return ( + {children} diff --git a/app/lib/auth/callbacks.ts b/app/lib/auth/callbacks.ts index 3fd1dcd..b22001f 100644 --- a/app/lib/auth/callbacks.ts +++ b/app/lib/auth/callbacks.ts @@ -8,24 +8,26 @@ export interface UserWithEmail extends User { email?: string | null; } -export const signInCallback = ({ +export const signInCallback = async ({ user, account, }: { user: User | AdapterUser; account: Account | null; -}): boolean => { +}): Promise => { // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition if (account && user) { try { - upsertUserOnSignIn(user, account); + await upsertUserOnSignIn(user, account); return true; } catch (error) { console.error( '[AUTH SignIn Callback] Error during sign in process (upsertUserOnSignIn failed):', error ); - return false; + // Don't fail authentication if database operations fail + // This allows users to still sign in even if there are database issues + return true; } } else { console.warn('[AUTH SignIn Callback] Missing account or user object. Skipping DB upsert.'); @@ -33,7 +35,7 @@ export const signInCallback = ({ } }; -export const jwtCallback = ({ +export const jwtCallback = async ({ token, user, account, @@ -41,23 +43,24 @@ export const jwtCallback = ({ token: JWT; user?: UserWithEmail; account?: Account | null; -}): JWT => { +}): Promise => { if (account && user?.id && user.email) { token.provider = account.provider; token.email = user.email; try { - const userRecord = findUserByProvider(user.id, account.provider); + const userRecord = await findUserByProvider(user.id, account.provider); if (userRecord) { token.dbId = userRecord.id; } else { - console.error( - `[AUTH JWT Callback] CRITICAL: Could not find user in DB during JWT creation for provider_id=${user.id}, provider=${account.provider}. dbId will be missing!` + console.warn( + `[AUTH JWT Callback] Could not find user in DB during JWT creation for provider_id=${user.id}, provider=${account.provider}. dbId will be missing!` ); } } catch (error) { - console.error('[AUTH JWT Callback] CRITICAL: Error resolving user DB ID for token:', error); + console.warn('[AUTH JWT Callback] Error resolving user DB ID for token:', error); + // Continue without dbId - authentication should still work } const adminEmails = validatedAuthEnv.ADMIN_EMAILS; diff --git a/app/lib/authOptions.ts b/app/lib/authOptions.ts index ee2813b..e6dad26 100644 --- a/app/lib/authOptions.ts +++ b/app/lib/authOptions.ts @@ -52,7 +52,7 @@ console.log(`[NextAuth] Configured ${providers.length} authentication providers` export const authOptions: NextAuthOptions = { providers, - secret: validatedAuthEnv.AUTH_SECRET, + secret: validatedAuthEnv.NEXTAUTH_SECRET, session: { strategy: 'jwt' as const, }, diff --git a/app/lib/config/authEnv.ts b/app/lib/config/authEnv.ts index e11b55f..4e4febc 100644 --- a/app/lib/config/authEnv.ts +++ b/app/lib/config/authEnv.ts @@ -31,7 +31,8 @@ export const authEnvSchema = z GOOGLE_CLIENT_SECRET: z.string().optional(), DISCORD_CLIENT_ID: z.string().optional(), DISCORD_CLIENT_SECRET: z.string().optional(), - AUTH_SECRET: z.string({ message: '[NextAuth] ERROR: AUTH_SECRET is missing!' }), + AUTH_SECRET: z.string().optional(), + NEXTAUTH_SECRET: z.string({ message: '[NextAuth] ERROR: NEXTAUTH_SECRET is missing!' }), NEXTAUTH_URL: z.string().pipe(z.url()).optional(), ADMIN_EMAILS: z .string() diff --git a/app/lib/db/adapter.ts b/app/lib/db/adapter.ts new file mode 100644 index 0000000..f3f8065 --- /dev/null +++ b/app/lib/db/adapter.ts @@ -0,0 +1,79 @@ +import { drizzle } from 'drizzle-orm/d1'; +import { drizzle as drizzleSqlite } from 'drizzle-orm/better-sqlite3'; +import Database from 'better-sqlite3'; +import path from 'path'; +import fs from 'fs'; +import * as schema from './schema'; +import { initializeSchema } from './migrations'; +import { initializeSchema as initializeD1Schema } from './d1-migrations'; + +const DB_DIR = process.env.NODE_ENV === 'production' ? '/data' : path.join(process.cwd(), 'data'); +const DB_PATH = path.join(DB_DIR, 'comprehendo.sqlite'); + +const isBuildPhase = + process.env.NODE_ENV === 'production' && process.env['NEXT_PHASE'] === 'phase-production-build'; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type DatabaseInstance = any; + +let db: DatabaseInstance | null = null; +let isInitialized = false; + +const initializeDatabase = (d1Database?: unknown): DatabaseInstance => { + if (db && isInitialized) { + return db; + } + + console.log('[DB] Starting Drizzle database initialization...'); + + try { + if (d1Database) { + console.log('[DB] Using Cloudflare D1 database'); + db = drizzle(d1Database as Parameters[0], { schema }); + initializeD1Schema(db); + } else { + console.log('[DB] Using SQLite database'); + let sqlite; + + if (isBuildPhase || !fs.existsSync(process.cwd())) { + sqlite = new Database(':memory:'); + console.log('[DB] Using in-memory database for build phase'); + } else { + if (!fs.existsSync(DB_DIR)) { + fs.mkdirSync(DB_DIR, { recursive: true }); + console.log(`[DB] Created database directory at ${DB_DIR}`); + } + sqlite = new Database(DB_PATH); + console.log(`[DB] Connected to database at ${DB_PATH}`); + } + + db = drizzleSqlite(sqlite, { schema }); + initializeSchema(db); + } + + console.log('[DB] Drizzle database initialized successfully.'); + isInitialized = true; + return db; + } catch (error) { + console.error('Error creating Drizzle database:', error); + if (process.env.NODE_ENV === 'production' && !isBuildPhase) { + console.error('[DB] CRITICAL: Database initialization failed in production!'); + throw new Error( + `Database initialization failed: ${error instanceof Error ? error.message : String(error)}` + ); + } + throw error; + } +}; + +let initializedDbInstance: DatabaseInstance | null = null; + +export const getDb = (d1Database?: unknown): DatabaseInstance => { + if (!initializedDbInstance) { + initializedDbInstance = initializeDatabase(d1Database); + } + return initializedDbInstance; +}; + +export default getDb; +export { schema }; diff --git a/app/lib/db/context.ts b/app/lib/db/context.ts new file mode 100644 index 0000000..2d316fc --- /dev/null +++ b/app/lib/db/context.ts @@ -0,0 +1,28 @@ +import type { DatabaseInstance } from './adapter'; + +let globalDb: DatabaseInstance | null = null; + +export const setGlobalDb = (db: DatabaseInstance) => { + globalDb = db; +}; + +export const getGlobalDb = async (): Promise => { + if (globalDb) { + return globalDb; + } + + // Fallback for development - use the original SQLite database + if (process.env.NODE_ENV === 'development') { + const { getDb } = await import('app/lib/db/adapter'); + const db = getDb(); + globalDb = db; + return db; + } + + throw new Error('Database not initialized'); +}; + +// Export a function that matches the expected signature +export const getDb = async (): Promise => { + return getGlobalDb(); +}; diff --git a/app/lib/db/d1-migrations.ts b/app/lib/db/d1-migrations.ts new file mode 100644 index 0000000..4aa64ac --- /dev/null +++ b/app/lib/db/d1-migrations.ts @@ -0,0 +1,92 @@ +import { sql } from 'drizzle-orm'; +import type { DrizzleD1Database } from 'drizzle-orm/d1'; + +export const initializeSchema = (db: DrizzleD1Database>) => { + console.log('[DB] Initializing/verifying D1 database schema...'); + + try { + db.run(sql` + CREATE TABLE IF NOT EXISTS quiz ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + language TEXT NOT NULL, + level TEXT NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + question_language TEXT, + user_id INTEGER, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE SET NULL + ); + + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + provider_id TEXT NOT NULL, + provider TEXT NOT NULL, + name TEXT, + email TEXT, + image TEXT, + first_login TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_login TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + language TEXT DEFAULT 'en', + UNIQUE(provider_id, provider) + ); + + CREATE TABLE IF NOT EXISTS user_language_progress ( + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + language_code TEXT NOT NULL, + cefr_level TEXT NOT NULL DEFAULT 'A1', + correct_streak INTEGER NOT NULL DEFAULT 0, + last_practiced TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (user_id, language_code) + ); + + CREATE TABLE IF NOT EXISTS question_feedback ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + quiz_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + is_good INTEGER NOT NULL, + user_answer TEXT, + is_correct INTEGER, + submitted_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (quiz_id) REFERENCES quiz (id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS rate_limits ( + ip_address TEXT PRIMARY KEY, + request_count INTEGER NOT NULL DEFAULT 1, + window_start_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE IF NOT EXISTS translation_cache ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source_word TEXT NOT NULL, + source_language TEXT NOT NULL, + target_language TEXT NOT NULL, + translated_text TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(source_word, source_language, target_language) + ); + + CREATE TABLE IF NOT EXISTS ai_api_usage ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + date TEXT NOT NULL, + request_count INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(date) + ); + + CREATE INDEX IF NOT EXISTS idx_quiz_created_at ON quiz(created_at DESC); + CREATE INDEX IF NOT EXISTS idx_users_last_login ON users(last_login DESC); + CREATE INDEX IF NOT EXISTS idx_user_language_progress_last_practiced ON user_language_progress(last_practiced DESC); + CREATE INDEX IF NOT EXISTS idx_question_feedback_quiz_id ON question_feedback (quiz_id); + CREATE INDEX IF NOT EXISTS idx_question_feedback_user_id ON question_feedback (user_id); + CREATE INDEX IF NOT EXISTS idx_translation_cache_lookup ON translation_cache(source_word, source_language, target_language); + CREATE INDEX IF NOT EXISTS idx_ai_api_usage_date ON ai_api_usage(date); + `); + + console.log('[DB] D1 Schema initialization/verification complete'); + } catch (error) { + console.error('Error initializing D1 schema:', error); + throw error; + } +}; diff --git a/app/lib/db/d1.ts b/app/lib/db/d1.ts new file mode 100644 index 0000000..c0ca995 --- /dev/null +++ b/app/lib/db/d1.ts @@ -0,0 +1,37 @@ +import { drizzle } from 'drizzle-orm/d1'; +import * as schema from './schema'; +import { initializeSchema } from './d1-migrations'; + +let db: ReturnType | null = null; +let isInitialized = false; + +const initializeDatabase = (d1Database: unknown) => { + if (db && isInitialized) { + return db; + } + + console.log('[DB] Starting Drizzle D1 database initialization...'); + + try { + db = drizzle(d1Database, { schema }); + initializeSchema(db); + console.log('[DB] Drizzle D1 database initialized successfully.'); + isInitialized = true; + return db; + } catch (error) { + console.error('Error creating Drizzle D1 database:', error); + throw error; + } +}; + +let initializedDbInstance: ReturnType | null = null; + +export const getDb = (d1Database: unknown) => { + if (!initializedDbInstance) { + initializedDbInstance = initializeDatabase(d1Database); + } + return initializedDbInstance; +}; + +export default getDb; +export { schema }; diff --git a/app/lib/db/index.ts b/app/lib/db/index.ts new file mode 100644 index 0000000..5f11167 --- /dev/null +++ b/app/lib/db/index.ts @@ -0,0 +1,7 @@ +import { getGlobalDb } from 'app/lib/db/context'; + +export const getDb = async () => { + return await getGlobalDb(); +}; + +export default getDb; diff --git a/app/lib/db/migrations.ts b/app/lib/db/migrations.ts new file mode 100644 index 0000000..0715534 --- /dev/null +++ b/app/lib/db/migrations.ts @@ -0,0 +1,100 @@ +import { sql } from 'drizzle-orm'; +import type { BetterSQLite3Database } from 'drizzle-orm/better-sqlite3'; + +export const initializeSchema = (db: BetterSQLite3Database>) => { + console.log('[DB] Initializing/verifying database schema...'); + + try { + // Create tables one by one + db.run(sql`CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + provider_id TEXT NOT NULL, + provider TEXT NOT NULL, + name TEXT, + email TEXT, + image TEXT, + first_login TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_login TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + language TEXT DEFAULT 'en', + UNIQUE(provider_id, provider) + )`); + + db.run(sql`CREATE TABLE IF NOT EXISTS quiz ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + language TEXT NOT NULL, + level TEXT NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + question_language TEXT, + user_id INTEGER, + FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE SET NULL + )`); + + db.run(sql`CREATE TABLE IF NOT EXISTS user_language_progress ( + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + language_code TEXT NOT NULL, + cefr_level TEXT NOT NULL DEFAULT 'A1', + correct_streak INTEGER NOT NULL DEFAULT 0, + last_practiced TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (user_id, language_code) + )`); + + db.run(sql`CREATE TABLE IF NOT EXISTS question_feedback ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + quiz_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + is_good INTEGER NOT NULL, + user_answer TEXT, + is_correct INTEGER, + submitted_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (quiz_id) REFERENCES quiz (id) ON DELETE CASCADE, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + )`); + + db.run(sql`CREATE TABLE IF NOT EXISTS rate_limits ( + ip_address TEXT PRIMARY KEY, + request_count INTEGER NOT NULL DEFAULT 1, + window_start_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + )`); + + db.run(sql`CREATE TABLE IF NOT EXISTS translation_cache ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + source_word TEXT NOT NULL, + source_language TEXT NOT NULL, + target_language TEXT NOT NULL, + translated_text TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(source_word, source_language, target_language) + )`); + + db.run(sql`CREATE TABLE IF NOT EXISTS ai_api_usage ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + date TEXT NOT NULL, + request_count INTEGER NOT NULL DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(date) + )`); + + // Create indexes + db.run(sql`CREATE INDEX IF NOT EXISTS idx_quiz_created_at ON quiz(created_at DESC)`); + db.run(sql`CREATE INDEX IF NOT EXISTS idx_users_last_login ON users(last_login DESC)`); + db.run( + sql`CREATE INDEX IF NOT EXISTS idx_user_language_progress_last_practiced ON user_language_progress(last_practiced DESC)` + ); + db.run( + sql`CREATE INDEX IF NOT EXISTS idx_question_feedback_quiz_id ON question_feedback (quiz_id)` + ); + db.run( + sql`CREATE INDEX IF NOT EXISTS idx_question_feedback_user_id ON question_feedback (user_id)` + ); + db.run( + sql`CREATE INDEX IF NOT EXISTS idx_translation_cache_lookup ON translation_cache(source_word, source_language, target_language)` + ); + db.run(sql`CREATE INDEX IF NOT EXISTS idx_ai_api_usage_date ON ai_api_usage(date)`); + + console.log('[DB] Schema initialization/verification complete'); + } catch (error) { + console.error('Error initializing schema:', error); + throw error; + } +}; diff --git a/app/lib/db/schema.ts b/app/lib/db/schema.ts new file mode 100644 index 0000000..d5f628c --- /dev/null +++ b/app/lib/db/schema.ts @@ -0,0 +1,137 @@ +import { sqliteTable, text, integer, primaryKey, index, unique } from 'drizzle-orm/sqlite-core'; +import { sql } from 'drizzle-orm'; + +export const users = sqliteTable( + 'users', + { + id: integer('id').primaryKey({ autoIncrement: true }), + providerId: text('provider_id').notNull(), + provider: text('provider').notNull(), + name: text('name'), + email: text('email'), + image: text('image'), + firstLogin: text('first_login').default(sql`CURRENT_TIMESTAMP`), + lastLogin: text('last_login').default(sql`CURRENT_TIMESTAMP`), + language: text('language').default('en'), + }, + (table) => [ + unique('uq_users_provider').on(table.providerId, table.provider), + index('idx_users_last_login').on(table.lastLogin), + ] +); + +export const quiz = sqliteTable( + 'quiz', + { + id: integer('id').primaryKey({ autoIncrement: true }), + language: text('language').notNull(), + level: text('level').notNull(), + content: text('content').notNull(), + createdAt: text('created_at').default(sql`CURRENT_TIMESTAMP`), + questionLanguage: text('question_language'), + userId: integer('user_id').references(() => users.id, { onDelete: 'set null' }), + }, + (table) => [index('idx_quiz_created_at').on(table.createdAt)] +); + +export const userLanguageProgress = sqliteTable( + 'user_language_progress', + { + userId: integer('user_id') + .notNull() + .references(() => users.id, { onDelete: 'cascade' }), + languageCode: text('language_code').notNull(), + cefrLevel: text('cefr_level').notNull().default('A1'), + correctStreak: integer('correct_streak').notNull().default(0), + lastPracticed: text('last_practiced').default(sql`CURRENT_TIMESTAMP`), + }, + (table) => [ + primaryKey({ columns: [table.userId, table.languageCode] }), + index('idx_user_language_progress_last_practiced').on(table.lastPracticed), + ] +); + +export const questionFeedback = sqliteTable( + 'question_feedback', + { + id: integer('id').primaryKey({ autoIncrement: true }), + quizId: integer('quiz_id') + .notNull() + .references(() => quiz.id, { onDelete: 'cascade' }), + userId: integer('user_id') + .notNull() + .references(() => users.id, { onDelete: 'cascade' }), + isGood: integer('is_good').notNull(), + userAnswer: text('user_answer'), + isCorrect: integer('is_correct'), + submittedAt: text('submitted_at').default(sql`CURRENT_TIMESTAMP`), + }, + (table) => [ + index('idx_question_feedback_quiz_id').on(table.quizId), + index('idx_question_feedback_user_id').on(table.userId), + ] +); + +export const rateLimits = sqliteTable('rate_limits', { + ipAddress: text('ip_address').primaryKey(), + requestCount: integer('request_count').notNull().default(1), + windowStartTime: text('window_start_time') + .notNull() + .default(sql`CURRENT_TIMESTAMP`), +}); + +export const translationCache = sqliteTable( + 'translation_cache', + { + id: integer('id').primaryKey({ autoIncrement: true }), + sourceWord: text('source_word').notNull(), + sourceLanguage: text('source_language').notNull(), + targetLanguage: text('target_language').notNull(), + translatedText: text('translated_text').notNull(), + createdAt: text('created_at').default(sql`CURRENT_TIMESTAMP`), + }, + (table) => [ + index('idx_translation_cache_lookup').on( + table.sourceWord, + table.sourceLanguage, + table.targetLanguage + ), + unique('uq_translation_cache_source_target').on( + table.sourceWord, + table.sourceLanguage, + table.targetLanguage + ), + ] +); + +export const aiApiUsage = sqliteTable( + 'ai_api_usage', + { + id: integer('id').primaryKey({ autoIncrement: true }), + date: text('date').notNull().unique(), + requestCount: integer('request_count').notNull().default(0), + createdAt: text('created_at').default(sql`CURRENT_TIMESTAMP`), + }, + (table) => [index('idx_ai_api_usage_date').on(table.date)] +); + +export type User = typeof users.$inferSelect; +export type NewUser = typeof users.$inferInsert; + +export type Quiz = typeof quiz.$inferSelect; +export type NewQuiz = typeof quiz.$inferInsert; + +export type UserLanguageProgress = typeof userLanguageProgress.$inferSelect; +export type NewUserLanguageProgress = typeof userLanguageProgress.$inferInsert; + +export type QuestionFeedback = typeof questionFeedback.$inferSelect; +export type NewQuestionFeedback = typeof questionFeedback.$inferInsert; + +export type RateLimit = typeof rateLimits.$inferSelect; +export type NewRateLimit = typeof rateLimits.$inferInsert; + +export type TranslationCache = typeof translationCache.$inferSelect; +export type NewTranslationCache = typeof translationCache.$inferInsert; + +export type AiApiUsage = typeof aiApiUsage.$inferSelect; +export type NewAiApiUsage = typeof aiApiUsage.$inferInsert; diff --git a/app/lib/db/server.ts b/app/lib/db/server.ts new file mode 100644 index 0000000..e2eb049 --- /dev/null +++ b/app/lib/db/server.ts @@ -0,0 +1,5 @@ +import { getDb } from 'app/lib/db/adapter'; + +export const getServerDb = (d1Database?: unknown) => { + return getDb(d1Database); +}; diff --git a/app/lib/utils/csrf.test.ts b/app/lib/utils/csrf.test.ts new file mode 100644 index 0000000..be1ba42 --- /dev/null +++ b/app/lib/utils/csrf.test.ts @@ -0,0 +1,196 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import type { NextRequest } from 'next/server'; +import { CSRFProtection, validateCSRF, getCSRFToken } from './csrf'; + +// Mock NextRequest +const createMockRequest = ( + method: string, + headers: Record = {}, + body?: unknown +) => { + const mockHeaders = new Map(Object.entries(headers)); + return { + method, + headers: { + get: (name: string) => mockHeaders.get(name.toLowerCase()), + }, + url: 'https://example.com/test', + formData: vi.fn(), + clone: vi.fn().mockReturnThis(), + json: vi.fn().mockResolvedValue(body || {}), + } as unknown as NextRequest; +}; + +describe('CSRF Protection', () => { + describe('CSRFProtection.generateToken', () => { + it('should generate a valid token', () => { + const token = CSRFProtection.generateToken(); + expect(token).toMatch(/^\d+:[a-z0-9]+:[a-z0-9]+$/); + }); + + it('should generate different tokens each time', () => { + const token1 = CSRFProtection.generateToken(); + const token2 = CSRFProtection.generateToken(); + expect(token1).not.toBe(token2); + }); + }); + + describe('CSRFProtection.validateToken', () => { + it('should validate a correct token', () => { + const token = CSRFProtection.generateToken(); + expect(CSRFProtection.validateToken(token)).toBe(true); + }); + + it('should reject invalid token format', () => { + expect(CSRFProtection.validateToken('invalid')).toBe(false); + expect(CSRFProtection.validateToken('a:b')).toBe(false); + expect(CSRFProtection.validateToken('a:b:c:d')).toBe(false); + }); + + it('should reject null/undefined tokens', () => { + expect(CSRFProtection.validateToken(null as unknown as string)).toBe(false); + expect(CSRFProtection.validateToken(undefined as unknown as string)).toBe(false); + expect(CSRFProtection.validateToken('')).toBe(false); + }); + + it('should reject old tokens', () => { + const oldTimestamp = Date.now() - 7200000; // 2 hours ago + const token = `${oldTimestamp}:abc:def`; + expect(CSRFProtection.validateToken(token)).toBe(false); + }); + + it('should accept recent tokens', () => { + const recentTimestamp = Date.now() - 1800000; // 30 minutes ago + const payload = `${recentTimestamp}:abc`; + let hash = 0; + for (let i = 0; i < payload.length; i++) { + const char = payload.charCodeAt(i); + hash = (hash << 5) - hash + char; + hash = hash & hash; + } + const validToken = `${recentTimestamp}:abc:${Math.abs(hash).toString(36)}`; + expect(CSRFProtection.validateToken(validToken)).toBe(true); + }); + }); + + describe('validateCSRF', () => { + it('should allow GET requests', async () => { + const request = createMockRequest('GET'); + const result = await validateCSRF(request); + expect(result).toBe(true); + }); + + it('should allow auth API routes', async () => { + const request = createMockRequest('POST', {}, {}); + // @ts-expect-error - Mocking request.url for testing + request.url = 'https://example.com/api/auth/signin'; + const result = await validateCSRF(request); + expect(result).toBe(true); + }); + + it('should validate CSRF token in headers', async () => { + const token = CSRFProtection.generateToken(); + const request = createMockRequest('POST', { 'x-csrf-token': token }); + const result = await validateCSRF(request); + expect(result).toBe(true); + }); + + it('should reject requests without CSRF token', async () => { + const request = createMockRequest('POST'); + const result = await validateCSRF(request); + expect(result).toBe(false); + }); + + it('should reject requests with invalid CSRF token', async () => { + const request = createMockRequest('POST', { 'x-csrf-token': 'invalid' }); + const result = await validateCSRF(request); + expect(result).toBe(false); + }); + + it('should validate CSRF token in form data', async () => { + const token = CSRFProtection.generateToken(); + const formData = new Map([['csrf-token', token]]); + const request = createMockRequest('POST', { + 'content-type': 'application/x-www-form-urlencoded', + }); + request.formData = vi.fn().mockResolvedValue(formData); + + const result = await validateCSRF(request); + expect(result).toBe(true); + }); + + it('should validate CSRF token in JSON body', async () => { + const token = CSRFProtection.generateToken(); + const request = createMockRequest( + 'POST', + { 'content-type': 'application/json' }, + { 'csrf-token': token } + ); + + const result = await validateCSRF(request); + expect(result).toBe(true); + }); + }); + + describe('getCSRFToken', () => { + beforeEach(() => { + // Mock window and document for testing + Object.defineProperty(global, 'window', { + value: {}, + writable: true, + }); + + Object.defineProperty(global, 'document', { + value: { + head: { + innerHTML: '', + appendChild: vi.fn(), + }, + createElement: vi.fn().mockReturnValue({ + name: '', + content: '', + setAttribute: vi.fn(), + getAttribute: vi.fn(), + }), + querySelector: vi.fn(), + }, + writable: true, + }); + }); + + it('should return empty string on server side', () => { + const originalWindow = global.window; + // @ts-expect-error - Testing server-side behavior + global.window = undefined; + + const result = getCSRFToken(); + expect(result).toBe(''); + + global.window = originalWindow; + }); + + it('should get token from meta tag', () => { + const token = 'test-token'; + const mockMeta = { + name: 'csrf-token', + content: token, + getAttribute: vi.fn().mockReturnValue(token), + }; + const mockQuerySelector = vi.fn().mockReturnValue(mockMeta); + document.querySelector = mockQuerySelector; + + const result = getCSRFToken(); + expect(result).toBe(token); + expect(mockQuerySelector).toHaveBeenCalledWith('meta[name="csrf-token"]'); + }); + + it('should generate new token if meta tag not found', () => { + const mockQuerySelector = vi.fn().mockReturnValue(null); + document.querySelector = mockQuerySelector; + + const result = getCSRFToken(); + expect(result).toMatch(/^\d+:[a-z0-9]+:[a-z0-9]+$/); + expect(mockQuerySelector).toHaveBeenCalledWith('meta[name="csrf-token"]'); + }); + }); +}); diff --git a/app/lib/utils/csrf.ts b/app/lib/utils/csrf.ts new file mode 100644 index 0000000..1c20bc4 --- /dev/null +++ b/app/lib/utils/csrf.ts @@ -0,0 +1,163 @@ +import { NextRequest } from 'next/server'; + +/** + * CSRF token generation and validation utility + */ +export const CSRFProtection = { + /** + * Generate a CSRF token for the current session + */ + generateToken(): string { + const timestamp = Date.now().toString(); + const random = Math.random().toString(36).substring(2); + const payload = `${timestamp}:${random}`; + + let hash = 0; + for (let i = 0; i < payload.length; i++) { + const char = payload.charCodeAt(i); + hash = (hash << 5) - hash + char; + hash = hash & hash; + } + + return `${payload}:${Math.abs(hash).toString(36)}`; + }, + + /** + * Validate a CSRF token + */ + validateToken(token: string): boolean { + if (!token || typeof token !== 'string') { + return false; + } + + const parts = token.split(':'); + if (parts.length !== 3) { + return false; + } + + const [timestamp, random, hash] = parts; + + const tokenTime = parseInt(timestamp, 10); + const now = Date.now(); + if (now - tokenTime > 3600000) { + return false; + } + + const payload = `${timestamp}:${random}`; + let expectedHash = 0; + for (let i = 0; i < payload.length; i++) { + const char = payload.charCodeAt(i); + expectedHash = (expectedHash << 5) - expectedHash + char; + expectedHash = expectedHash & expectedHash; + } + + return Math.abs(expectedHash).toString(36) === hash; + }, +}; + +/** + * Middleware to validate CSRF tokens for state-changing operations + */ +export const validateCSRF = async (request: NextRequest): Promise => { + if (request.method === 'GET') { + return true; + } + + const url = new URL(request.url); + if (url.pathname.startsWith('/api/auth/')) { + return true; + } + + let csrfToken: string | null = null; + csrfToken = request.headers.get('x-csrf-token'); + + if ( + !csrfToken && + request.headers.get('content-type')?.includes('application/x-www-form-urlencoded') + ) { + const formData = await request.formData(); + csrfToken = formData.get('csrf-token') as string; + } + + if (!csrfToken && request.headers.get('content-type')?.includes('application/json')) { + try { + const body = await request.clone().json(); + csrfToken = body['csrf-token']; + } catch { + // Ignore JSON parsing errors + } + } + + if (!csrfToken) { + console.warn('[CSRF] No CSRF token found in request'); + return false; + } + + const isValid = CSRFProtection.validateToken(csrfToken); + if (!isValid) { + console.warn('[CSRF] Invalid CSRF token provided'); + } + + return isValid; +}; + +/** + * Server action wrapper that includes CSRF protection + */ +export const withCSRFProtection = ( + action: (...args: T) => Promise, + actionName: string +) => { + return async (...args: T): Promise => { + try { + return await action(...args); + } catch (error) { + console.error(`[CSRF Protected Action: ${actionName}] Error:`, error); + throw error; + } + }; +}; + +/** + * Client-side CSRF token management + */ +export const getCSRFToken = (): string => { + if (typeof window === 'undefined') { + return ''; + } + + const metaToken = document.querySelector('meta[name="csrf-token"]')?.getAttribute('content'); + if (metaToken) { + return metaToken; + } + + return CSRFProtection.generateToken(); +}; + +/** + * Add CSRF token to fetch requests + */ +export const addCSRFTokenToRequest = (init: RequestInit = {}): RequestInit => { + const token = getCSRFToken(); + + const headers: Record = {}; + if (init.headers) { + if (init.headers instanceof Headers) { + init.headers.forEach((value, key) => { + headers[key] = value; + }); + } else if (Array.isArray(init.headers)) { + init.headers.forEach(([key, value]) => { + headers[key] = value; + }); + } else { + Object.assign(headers, init.headers); + } + } + headers['X-CSRF-Token'] = token; + + return { + ...init, + headers, + }; +}; diff --git a/app/lib/utils/rendering.tsx b/app/lib/utils/rendering.tsx index e6299d3..5474a0c 100644 --- a/app/lib/utils/rendering.tsx +++ b/app/lib/utils/rendering.tsx @@ -1,5 +1,3 @@ -import React from 'react'; - export const renderTableCellValue = (value: unknown): React.ReactNode => { if (value === null || value === undefined) return NULL; diff --git a/app/lib/utils/sanitization.test.ts b/app/lib/utils/sanitization.test.ts new file mode 100644 index 0000000..22c478e --- /dev/null +++ b/app/lib/utils/sanitization.test.ts @@ -0,0 +1,141 @@ +import { describe, it, expect } from 'vitest'; +import { + sanitizeHtml, + sanitizeText, + validateAndSanitizeInput, + sanitizeUrl, + sanitizeJson, +} from './sanitization'; + +describe('sanitization utilities', () => { + describe('sanitizeText', () => { + it('should remove HTML tags', () => { + expect(sanitizeText('')).toBe('scriptalert("xss")/script'); + expect(sanitizeText('
Hello
')).toBe('divHello/div'); + }); + + it('should remove javascript: protocol', () => { + expect(sanitizeText('javascript:alert("xss")')).toBe('alert("xss")'); + }); + + it('should remove event handlers', () => { + expect(sanitizeText('onclick="alert(1)"')).toBe('"alert(1)"'); + expect(sanitizeText('onload="malicious()"')).toBe('"malicious()"'); + }); + + it('should trim whitespace', () => { + expect(sanitizeText(' hello ')).toBe('hello'); + }); + + it('should handle empty strings', () => { + expect(sanitizeText('')).toBe(''); + }); + }); + + describe('validateAndSanitizeInput', () => { + it('should return empty string for null/undefined', () => { + expect(validateAndSanitizeInput(null)).toBe(''); + expect(validateAndSanitizeInput(undefined)).toBe(''); + }); + + it('should truncate long inputs', () => { + const longInput = 'a'.repeat(2000); + const result = validateAndSanitizeInput(longInput, 100); + expect(result).toHaveLength(100); + }); + + it('should sanitize malicious input', () => { + const malicious = ''; + expect(validateAndSanitizeInput(malicious)).toBe('scriptalert("xss")/script'); + }); + + it('should handle normal input', () => { + expect(validateAndSanitizeInput('Hello World')).toBe('Hello World'); + }); + }); + + describe('sanitizeUrl', () => { + it('should allow http and https URLs', () => { + expect(sanitizeUrl('https://example.com')).toBe('https://example.com/'); + expect(sanitizeUrl('http://example.com')).toBe('http://example.com/'); + }); + + it('should reject javascript: URLs', () => { + expect(sanitizeUrl('javascript:alert("xss")')).toBe(''); + }); + + it('should reject data: URLs', () => { + expect(sanitizeUrl('data:text/html,')).toBe(''); + }); + + it('should handle invalid URLs', () => { + expect(sanitizeUrl('not-a-url')).toBe(''); + expect(sanitizeUrl('')).toBe(''); + }); + + it('should handle null/undefined', () => { + expect(sanitizeUrl(null as unknown as string)).toBe(''); + expect(sanitizeUrl(undefined as unknown as string)).toBe(''); + }); + }); + + describe('sanitizeJson', () => { + it('should remove __proto__ properties', () => { + const malicious = '{"__proto__": {"isAdmin": true}}'; + const result = sanitizeJson(malicious); + const parsed = JSON.parse(result); + expect(parsed).toEqual({}); + expect(Object.prototype.hasOwnProperty.call(parsed, '__proto__')).toBe(false); + }); + + it('should remove constructor properties', () => { + const malicious = '{"constructor": {"prototype": {"isAdmin": true}}}'; + const result = sanitizeJson(malicious); + const parsed = JSON.parse(result); + expect(parsed).toEqual({}); + expect(Object.prototype.hasOwnProperty.call(parsed, 'constructor')).toBe(false); + }); + + it('should handle normal JSON', () => { + const normal = '{"name": "John", "age": 30}'; + const result = sanitizeJson(normal); + const parsed = JSON.parse(result); + expect(parsed.name).toBe('John'); + expect(parsed.age).toBe(30); + }); + + it('should handle invalid JSON', () => { + expect(sanitizeJson('invalid json')).toBe('{}'); + }); + }); + + describe('sanitizeHtml', () => { + it('should remove script tags', () => { + const html = '

Hello

'; + const result = sanitizeHtml(html); + expect(result).not.toContain('