diff --git a/WORKSPACE.md b/WORKSPACE.md index 3aec1c2..79c56ca 100644 --- a/WORKSPACE.md +++ b/WORKSPACE.md @@ -4,6 +4,38 @@ --- +## 2026-06-10 + +### 操作记录 + +#### 277. PR #77 Postgres/Neon review 启动路径回归补测 + +**时间**:2026-06-10 + +**操作背景**: +CodeXWeb 自动巡检检查最近活跃的 3 个仓库后,本轮只认领并处理 `ZeroPointSix/outlookEmailPlus` PR #77。认领评论:。处理目标是 review 评论 `4670508474` 提到的高优先级补测:用 fake psycopg 覆盖 `DATABASE_URL`、`install_postgres_sqlite_compat()` 与 `init_db()` 的完整启动路径,确保最终发送给 PostgreSQL 的 SQL 不再包含 SQLite-only 片段。 + +**修改内容**: + +1. **outlook_web/db_postgres_compat.py**:补齐 email 字符串函数翻译,覆盖小写/大小写混用的 `instr(email, '@')` 与 `substr(...)`,转换为 PostgreSQL `POSITION` / `SPLIT_PART`。 +2. **tests/test_db_postgres_compat.py**:新增 email 函数翻译单测;新增 fake psycopg 启动路径回归测试,直接运行 `db.init_db()` 并断言最终 SQL 中不含 `sqlite_master`、`PRAGMA`、`strftime`、`INSERT OR REPLACE`、`INSERT OR IGNORE`、`AUTOINCREMENT`、`COLLATE NOCASE`、`unixepoch`、`instr(`、`substr(`。 + +**验证结果**: + +1. `.venv/bin/python -m unittest tests.test_db_postgres_compat -v` → 21 passed。 +2. `.venv/bin/python -m py_compile outlook_web/db_postgres_compat.py tests/test_db_postgres_compat.py` → 通过。 +3. `git diff --check` → 通过。 +4. `.venv/bin/python -m pytest -q`(补装 pytest、Playwright,并启动本地服务覆盖 UI 脚本)→ 1544 passed, 7 skipped, 31 subtests passed;4 failed,均为 `tests/test_pool_cf_real_e2e.py::RealCFWorkerE2ETests`,失败原因是真实 Cloudflare Worker 上游返回 HTTP 400 / `UPSTREAM_BAD_PAYLOAD`,与本次 Postgres/Neon SQL 翻译和 `init_db()` 启动路径补测不相邻。 + +**剩余风险**: + +- 本轮没有连接真实 Neon/PostgreSQL 实例做 smoke test;review 建议的 fake psycopg 启动路径回归已覆盖并通过。 +- 真实 CF Worker E2E 仍受外部上游状态影响,需由对应服务侧继续排查。 + +**是否改动代码**:是(Postgres 兼容层 + 回归测试) + +--- + ## 2026-05-19 ### 操作记录 diff --git a/docs/postgres-neon.md b/docs/postgres-neon.md new file mode 100644 index 0000000..5ad8319 --- /dev/null +++ b/docs/postgres-neon.md @@ -0,0 +1,38 @@ +# PostgreSQL / Neon database mode + +This project still defaults to SQLite through `DATABASE_PATH`. For third-party +database deployments such as Neon, set `DATABASE_URL` to a PostgreSQL URL before +starting the app: + +```env +DATABASE_URL=postgresql://db-user@db-host.example/db-name?sslmode=require +``` + +Use the complete connection string from your database provider when deploying. + +## Behavior + +- Empty `DATABASE_URL`: keeps the existing SQLite behavior. +- `postgres://...` or `postgresql://...`: routes `sqlite3.connect(...)` through a + PostgreSQL compatibility adapter. +- `sqlite://...`, `sqlite3://...`, or `file:`: ignored so the current SQLite path + remains active. +- Any other scheme: startup fails early with a clear configuration error. + +## Notes for Neon + +- Neon requires TLS for most hosted connections, so keep `sslmode=require` in + the URL unless your Neon project says otherwise. +- On first startup, the app creates the same application tables in PostgreSQL. +- This does not automatically copy data from an existing SQLite database. Export + or migrate data separately before switching production traffic. + +## Compatibility scope + +The adapter translates the SQLite patterns used by the current application, +including `?` parameters, `INSERT OR IGNORE`, settings upserts, basic `PRAGMA` +table inspection, `BEGIN IMMEDIATE`, and common SQLite schema fragments. + +This is intended as the first third-party database support path. If future +schema work grows beyond the compatibility adapter, the next step should be a +proper migration layer such as Alembic. diff --git a/outlook_web/__init__.py b/outlook_web/__init__.py index 9c61e43..f82efb7 100644 --- a/outlook_web/__init__.py +++ b/outlook_web/__init__.py @@ -17,6 +17,10 @@ def _glob_list(self: Path, pattern: str): # type: ignore[override] except Exception: pass +from outlook_web.db_postgres_compat import install_postgres_sqlite_compat + +install_postgres_sqlite_compat() + from outlook_web.app import create_app __all__ = ["create_app", "__version__"] diff --git a/outlook_web/db_postgres_compat.py b/outlook_web/db_postgres_compat.py new file mode 100644 index 0000000..41479c5 --- /dev/null +++ b/outlook_web/db_postgres_compat.py @@ -0,0 +1,486 @@ +from __future__ import annotations + +import os +import re +import sqlite3 +from collections.abc import Iterable, Iterator +from typing import Any, Optional + +_ORIGINAL_SQLITE_CONNECT = sqlite3.connect +_INSTALLED = False +_ACTIVE_DATABASE_URL = "" + +_POSTGRES_SCHEMES = ("postgres://", "postgresql://") +_SQLITE_SCHEMES = ("sqlite://", "sqlite3://", "file:") +_RETURNING_ID_TABLES = { + "account_claim_logs", + "account_project_usage", + "account_refresh_logs", + "accounts", + "audit_logs", + "external_api_consumer_usage_daily", + "external_api_keys", + "external_api_rate_limits", + "external_upstream_probes", + "groups", + "notification_delivery_logs", + "schema_migrations", + "tags", + "temp_email_messages", + "temp_emails", + "verification_extract_logs", +} +_TEMP_EMAIL_MESSAGES_CREATE_SQL = """ +CREATE TABLE temp_email_messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + message_id TEXT NOT NULL, + email_address TEXT NOT NULL, + from_address TEXT, + subject TEXT, + content TEXT, + html_content TEXT, + has_html INTEGER DEFAULT 0, + timestamp INTEGER, + raw_content TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(email_address, message_id) +) +""" + + +def get_database_url_from_env() -> str: + return (os.getenv("DATABASE_URL") or "").strip() + + +def is_postgres_database_url(database_url: str | None) -> bool: + return str(database_url or "").strip().lower().startswith(_POSTGRES_SCHEMES) + + +def install_postgres_sqlite_compat(database_url: str | None = None) -> bool: + """Install a sqlite3.connect shim when DATABASE_URL points at Postgres. + + The application still defaults to the existing SQLite path. This shim is + only activated for postgresql:// or postgres:// URLs, including Neon URLs. + """ + + global _ACTIVE_DATABASE_URL, _INSTALLED + + url = (database_url or get_database_url_from_env()).strip() + if not url: + return False + + normalized = url.lower() + if normalized.startswith(_SQLITE_SCHEMES): + return False + + if not normalized.startswith(_POSTGRES_SCHEMES): + raise RuntimeError("DATABASE_URL only supports postgresql:// or postgres:// for third-party database mode.") + + if _INSTALLED and _ACTIVE_DATABASE_URL == url: + return True + + try: + import psycopg # noqa: F401 + except Exception as exc: + raise RuntimeError( + "DATABASE_URL is set to a PostgreSQL URL, but psycopg is not installed. " + "Install dependencies with `pip install -r requirements.txt`." + ) from exc + + def _connect(_database: Any = None, *args: Any, **kwargs: Any) -> "PostgresCompatConnection": + return PostgresCompatConnection(url) + + sqlite3.connect = _connect # type: ignore[assignment] + _ACTIVE_DATABASE_URL = url + _INSTALLED = True + return True + + +class CompatRow: + """sqlite3.Row-like wrapper supporting both numeric and name lookup.""" + + def __init__(self, names: Iterable[str], values: Iterable[Any]): + self._names = list(names) + self._values = tuple(values) + self._index = {name: idx for idx, name in enumerate(self._names)} + + def __getitem__(self, key: str | int) -> Any: + if isinstance(key, int): + return self._values[key] + return self._values[self._index[key]] + + def __iter__(self) -> Iterator[Any]: + return iter(self._values) + + def __len__(self) -> int: + return len(self._values) + + def keys(self) -> list[str]: + return list(self._names) + + def items(self): + for name in self._names: + yield name, self[name] + + def values(self): + return iter(self._values) + + def __contains__(self, key: object) -> bool: + return key in self._index + + +class _StaticCursor: + def __init__(self, rows: list[CompatRow] | None = None, *, rowcount: int = -1): + self._rows = rows or [] + self._offset = 0 + self.rowcount = rowcount + self.lastrowid = None + + def fetchone(self) -> Optional[CompatRow]: + if self._offset >= len(self._rows): + return None + row = self._rows[self._offset] + self._offset += 1 + return row + + def fetchall(self) -> list[CompatRow]: + rows = self._rows[self._offset :] + self._offset = len(self._rows) + return rows + + +class PostgresCompatCursor: + def __init__(self, connection: "PostgresCompatConnection"): + self._connection = connection + self._cursor = None + self._rows: list[CompatRow] = [] + self._offset = 0 + self.rowcount = -1 + self.lastrowid = None + + def execute(self, sql: str, params: Any = None) -> "PostgresCompatCursor": + self._rows = [] + self._offset = 0 + self.lastrowid = None + + special = self._execute_special(sql, params) + if special is not None: + self._rows = special._rows + self.rowcount = special.rowcount + self.lastrowid = special.lastrowid + return self + + translated = translate_sqlite_sql(sql) + translated = _append_returning_id_if_needed(translated) + bound_params = _normalize_params(params) + + try: + pg_cursor = self._connection._raw.cursor() + # Queries come from existing app statements; params remain bound. + pg_cursor.execute(translated, bound_params) # NOSONAR + self._cursor = pg_cursor + self.rowcount = pg_cursor.rowcount + if pg_cursor.description: + names = [desc.name for desc in pg_cursor.description] + fetched = pg_cursor.fetchall() + self._rows = [CompatRow(names, row) for row in fetched] + if _returns_single_id(translated) and self._rows: + self.lastrowid = self._rows[0]["id"] + self._connection._last_insert_id = self.lastrowid + return self + except self._connection._psycopg.IntegrityError as exc: + self._connection.rollback() + raise sqlite3.IntegrityError(str(exc)) from exc + + def executemany(self, sql: str, seq_of_params: Iterable[Any]) -> "PostgresCompatCursor": + total_rowcount = 0 + self.lastrowid = None + for params in seq_of_params: + self.execute(sql, params) + if self.rowcount and self.rowcount > 0: + total_rowcount += self.rowcount + self.rowcount = total_rowcount + return self + + def fetchone(self) -> Optional[CompatRow]: + if self._offset >= len(self._rows): + return None + row = self._rows[self._offset] + self._offset += 1 + return row + + def fetchall(self) -> list[CompatRow]: + rows = self._rows[self._offset :] + self._offset = len(self._rows) + return rows + + def close(self) -> None: + if self._cursor is not None: + self._cursor.close() + + def _execute_special(self, sql: str, params: Any = None) -> _StaticCursor | None: + normalized = _collapse_sql(sql) + upper = normalized.upper() + + if upper.startswith("PRAGMA "): + return self._execute_pragma(normalized) + + sqlite_master_table = _sqlite_master_table_name(normalized) + if sqlite_master_table == "temp_email_messages": + row = CompatRow(["sql"], [_TEMP_EMAIL_MESSAGES_CREATE_SQL]) + return _StaticCursor([row], rowcount=1) + if sqlite_master_table: + return _StaticCursor([], rowcount=0) + + if upper in {"BEGIN", "BEGIN IMMEDIATE", "BEGIN EXCLUSIVE"}: + return _StaticCursor([]) + if upper == "COMMIT": + self._connection.commit() + return _StaticCursor([]) + if upper == "ROLLBACK": + self._connection.rollback() + return _StaticCursor([]) + if re.match(r"SELECT\s+last_insert_rowid\(\)\s+AS\s+id", normalized, re.I): + row = CompatRow(["id"], [self._connection._last_insert_id]) + return _StaticCursor([row], rowcount=1) + + return None + + def _execute_pragma(self, normalized: str) -> _StaticCursor: + table_info = re.match(r"PRAGMA\s+table_info\((?:'|\")?([^'\")]+)(?:'|\")?\)", normalized, re.I) + if table_info: + table_name = table_info.group(1) + return _StaticCursor(self._connection._table_info(table_name)) + + index_list = re.match(r"PRAGMA\s+index_list\((?:'|\")?([^'\")]+)(?:'|\")?\)", normalized, re.I) + if index_list: + return _StaticCursor([]) + + return _StaticCursor([]) + + +class PostgresCompatConnection: + def __init__(self, database_url: str): + import psycopg + + self._psycopg = psycopg + # The URL is provided by deployment configuration for the selected database backend. + self._raw = psycopg.connect(database_url) # NOSONAR + self._last_insert_id = None + self.row_factory = None + + def execute(self, sql: str, params: Any = None) -> PostgresCompatCursor: + cursor = self.cursor() + cursor.execute(sql, params) + return cursor + + def executemany(self, sql: str, seq_of_params: Iterable[Any]) -> PostgresCompatCursor: + cursor = self.cursor() + cursor.executemany(sql, seq_of_params) + return cursor + + def cursor(self) -> PostgresCompatCursor: + return PostgresCompatCursor(self) + + def commit(self) -> None: + self._raw.commit() + + def rollback(self) -> None: + self._raw.rollback() + + def close(self) -> None: + self._raw.close() + + def _table_info(self, table_name: str) -> list[CompatRow]: + cursor = self._raw.cursor() + # Fixed metadata query; the table name stays bound as a parameter. + cursor.execute( # NOSONAR + """ + SELECT + ordinal_position - 1 AS cid, + column_name AS name, + data_type AS type, + CASE WHEN is_nullable = 'NO' THEN 1 ELSE 0 END AS notnull, + column_default AS dflt_value, + 0 AS pk + FROM information_schema.columns + WHERE table_schema = current_schema() + AND table_name = %s + ORDER BY ordinal_position + """, + (table_name,), + ) + names = ["cid", "name", "type", "notnull", "dflt_value", "pk"] + return [CompatRow(names, row) for row in cursor.fetchall()] + + +def translate_sqlite_sql(sql: str) -> str: + translated = sql.strip() + translated = re.sub( + r"\bINTEGER\s+PRIMARY\s+KEY\s+AUTOINCREMENT\b", + "INTEGER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY", + translated, + flags=re.I, + ) + translated = re.sub(r" COLLATE NOCASE\b", "", translated, flags=re.I) + translated = translated.replace("unixepoch('now')", "EXTRACT(EPOCH FROM NOW())") + translated = translated.replace("strftime('%s','now')", "EXTRACT(EPOCH FROM NOW())") + translated = translated.replace('strftime("%s","now")', "EXTRACT(EPOCH FROM NOW())") + translated = re.sub( + r"strftime\(\s*['\"]%Y-%m-%dT%H:%M:%S['\"]\s*,\s*['\"]now['\"]\s*\)", + "TO_CHAR(CURRENT_TIMESTAMP, 'YYYY-MM-DD\"T\"HH24:MI:SS')", + translated, + flags=re.I, + ) + translated = translated.replace("datetime('now')", "CURRENT_TIMESTAMP") + translated = _translate_email_sqlite_functions(translated) + translated = _translate_insert_or_replace(translated) + translated = _translate_insert_or_ignore(translated) + translated = _replace_qmark_placeholders(translated) + return translated + + +def _translate_email_sqlite_functions(sql: str) -> str: + translated = re.sub( + r"\bLOWER\s*\(\s*SUBSTR\s*\(\s*email\s*,\s*INSTR\s*\(\s*email\s*,\s*['\"]@['\"]\s*\)\s*\+\s*1\s*\)\s*\)", + "LOWER(SPLIT_PART(email, '@', 2))", + sql, + flags=re.I, + ) + translated = re.sub( + r"\bSUBSTR\s*\(\s*email\s*,\s*1\s*,\s*INSTR\s*\(\s*email\s*,\s*['\"]@['\"]\s*\)\s*-\s*1\s*\)", + "SPLIT_PART(email, '@', 1)", + translated, + flags=re.I, + ) + translated = re.sub( + r"\bSUBSTR\s*\(\s*email\s*,\s*INSTR\s*\(\s*email\s*,\s*['\"]@['\"]\s*\)\s*\+\s*1\s*\)", + "SPLIT_PART(email, '@', 2)", + translated, + flags=re.I, + ) + return re.sub( + r"\bINSTR\s*\(\s*email\s*,\s*['\"]@['\"]\s*\)", + "POSITION('@' IN email)", + translated, + flags=re.I, + ) + + +def _translate_insert_or_replace(sql: str) -> str: + if re.match(r"\s*INSERT\s+OR\s+REPLACE\s+INTO\s+temp_email_messages\b", sql, flags=re.I): + translated = re.sub(r"\bINSERT\s+OR\s+REPLACE\s+INTO\b", "INSERT INTO", sql, count=1, flags=re.I) + if re.search(r"\bON\s+CONFLICT\b", translated, flags=re.I): + return translated + return translated.rstrip().rstrip(";") + """ + ON CONFLICT (email_address, message_id) + DO UPDATE SET + from_address = EXCLUDED.from_address, + subject = EXCLUDED.subject, + content = EXCLUDED.content, + html_content = EXCLUDED.html_content, + has_html = EXCLUDED.has_html, + timestamp = EXCLUDED.timestamp, + raw_content = EXCLUDED.raw_content + """ + + if not re.match(r"\s*INSERT\s+OR\s+REPLACE\s+INTO\s+settings\b", sql, flags=re.I): + return re.sub(r"\bINSERT\s+OR\s+REPLACE\s+INTO\b", "INSERT INTO", sql, flags=re.I) + + translated = re.sub(r"\bINSERT\s+OR\s+REPLACE\s+INTO\b", "INSERT INTO", sql, count=1, flags=re.I) + if re.search(r"\bON\s+CONFLICT\b", translated, flags=re.I): + return translated + return ( + translated.rstrip().rstrip(";") + + " ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, updated_at = EXCLUDED.updated_at" + ) + + +def _translate_insert_or_ignore(sql: str) -> str: + if not re.match(r"\s*INSERT\s+OR\s+IGNORE\s+INTO\b", sql, flags=re.I): + return sql + translated = re.sub(r"\bINSERT\s+OR\s+IGNORE\s+INTO\b", "INSERT INTO", sql, count=1, flags=re.I) + if re.search(r"\bON\s+CONFLICT\b", translated, flags=re.I): + return translated + return translated.rstrip().rstrip(";") + " ON CONFLICT DO NOTHING" + + +def _replace_qmark_placeholders(sql: str) -> str: + result: list[str] = [] + in_single = False + in_double = False + i = 0 + + while i < len(sql): + char = sql[i] + next_char = sql[i + 1] if i + 1 < len(sql) else "" + + if char == "'" and not in_double: + result.append(char) + if in_single and next_char == "'": + result.append(next_char) + i += 2 + continue + in_single = not in_single + i += 1 + continue + + if char == '"' and not in_single: + result.append(char) + in_double = not in_double + i += 1 + continue + + if char == "?" and not in_single and not in_double: + result.append("%s") + else: + result.append(char) + i += 1 + + return "".join(result) + + +def _append_returning_id_if_needed(sql: str) -> str: + if re.search(r"\bRETURNING\b", sql, flags=re.I): + return sql + match = re.match(r"\s*INSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)\b", sql, flags=re.I) + if not match: + return sql + table_name = match.group(1).lower() + if table_name not in _RETURNING_ID_TABLES: + return sql + return sql.rstrip().rstrip(";") + " RETURNING id" + + +def _returns_single_id(sql: str) -> bool: + return bool(re.search(r"\bRETURNING\s+id\b", sql, flags=re.I)) + + +def _normalize_params(params: Any) -> Any: + if params is None: + return None + if isinstance(params, tuple): + return params + if isinstance(params, list): + return tuple(params) + return params + + +def _collapse_sql(sql: str) -> str: + return " ".join(str(sql or "").strip().split()) + + +def _sqlite_master_table_name(sql: str) -> str | None: + match = re.match( + r"\s*SELECT\s+sql\s+FROM\s+sqlite_master\s+WHERE\s+type\s*=\s*['\"]table['\"]\s+AND\s+name\s*=\s*['\"]([^'\"]+)['\"]", + sql, + flags=re.I, + ) + return match.group(1) if match else None + + +def restore_sqlite_connect_for_tests() -> None: + global _ACTIVE_DATABASE_URL, _INSTALLED + sqlite3.connect = _ORIGINAL_SQLITE_CONNECT # type: ignore[assignment] + _ACTIVE_DATABASE_URL = "" + _INSTALLED = False diff --git a/requirements.txt b/requirements.txt index b6bc5b4..a03499e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ bcrypt>=4.0.0 cryptography>=41.0.0 python-dotenv>=1.0.0 docker>=6.0.0 +psycopg[binary]>=3.1.0 diff --git a/tests/test_db_postgres_compat.py b/tests/test_db_postgres_compat.py new file mode 100644 index 0000000..9a92912 --- /dev/null +++ b/tests/test_db_postgres_compat.py @@ -0,0 +1,434 @@ +from __future__ import annotations + +import os +import sqlite3 +import sys +import types +import unittest +from unittest.mock import patch + +from outlook_web import db_postgres_compat as compat + + +class FakeIntegrityError(Exception): + pass + + +class FakePsycopg: + IntegrityError = FakeIntegrityError + + +class FakePgCursor: + def __init__(self, *, rows=None, description=None, rowcount=1, error=None): + self.rows = list(rows or []) + self.description = description + self.rowcount = rowcount + self.error = error + self.executions = [] + self.closed = False + + def execute(self, sql, params=None): + self.executions.append((sql, params)) + if self.error is not None: + raise self.error + + def fetchall(self): + return list(self.rows) + + def close(self): + self.closed = True + + +class FakeRawConnection: + def __init__(self, cursors=()): + self._cursors = list(cursors) + self.cursor_calls = 0 + self.commits = 0 + self.rollbacks = 0 + self.closes = 0 + + def cursor(self): + self.cursor_calls += 1 + if not self._cursors: + raise AssertionError("No fake cursor queued") + return self._cursors.pop(0) + + def commit(self): + self.commits += 1 + + def rollback(self): + self.rollbacks += 1 + + def close(self): + self.closes += 1 + + +def make_connection(raw): + connection = object.__new__(compat.PostgresCompatConnection) + connection._psycopg = FakePsycopg + connection._raw = raw + connection._last_insert_id = None + connection.row_factory = None + return connection + + +class RecordingPgCursor: + def __init__(self, raw): + self.raw = raw + self.description = None + self.rows = [] + self.rowcount = -1 + self.closed = False + + def execute(self, sql, params=None): + self.raw.executions.append((sql, params)) + lowered = sql.lower().lstrip() + self.description = None + self.rows = [] + self.rowcount = 1 + + if "returning id" in lowered: + self.description = [types.SimpleNamespace(name="id")] + self.rows = [(len(self.raw.executions),)] + elif lowered.startswith("select"): + name = "c" if "count(" in lowered else "value" + self.description = [types.SimpleNamespace(name=name)] + self.rowcount = 0 + + return self + + def fetchall(self): + rows = list(self.rows) + self.rows.clear() + return rows + + def close(self): + self.closed = True + + +class RecordingRawConnection: + def __init__(self): + self.executions = [] + self.commits = 0 + self.rollbacks = 0 + self.closes = 0 + + def cursor(self): + return RecordingPgCursor(self) + + def commit(self): + self.commits += 1 + + def rollback(self): + self.rollbacks += 1 + + def close(self): + self.closes += 1 + + +class RecordingPsycopg: + IntegrityError = FakeIntegrityError + + def __init__(self): + self.database_urls = [] + self.connections = [] + + def connect(self, database_url): + self.database_urls.append(database_url) + raw = RecordingRawConnection() + self.connections.append(raw) + return raw + + +class PostgresCompatSqlTranslationTests(unittest.TestCase): + def tearDown(self) -> None: + compat.restore_sqlite_connect_for_tests() + + def test_no_database_url_keeps_sqlite_connect(self): + original = sqlite3.connect + with patch.dict(os.environ, {}, clear=True): + self.assertFalse(compat.install_postgres_sqlite_compat()) + self.assertIs(sqlite3.connect, original) + + def test_unsupported_database_url_scheme_fails_before_patching(self): + with patch.dict(os.environ, {"DATABASE_URL": "mysql://example"}, clear=True): + with self.assertRaises(RuntimeError): + compat.install_postgres_sqlite_compat() + + def test_sqlite_url_keeps_sqlite_connect_and_detection_is_precise(self): + original = sqlite3.connect + self.assertTrue(compat.is_postgres_database_url("postgres://example/db")) + self.assertTrue(compat.is_postgres_database_url("postgresql://example/db")) + self.assertFalse(compat.is_postgres_database_url("sqlite:///tmp/app.db")) + self.assertFalse(compat.is_postgres_database_url(None)) + + with patch.dict(os.environ, {"DATABASE_URL": "sqlite:///tmp/app.db"}, clear=True): + self.assertFalse(compat.install_postgres_sqlite_compat()) + self.assertIs(sqlite3.connect, original) + + def test_postgres_url_installs_and_restores_connect_shim(self): + original = sqlite3.connect + fake_psycopg = types.SimpleNamespace() + + with patch.dict(sys.modules, {"psycopg": fake_psycopg}): + self.assertTrue(compat.install_postgres_sqlite_compat("postgresql://example/db")) + self.assertIsNot(sqlite3.connect, original) + self.assertTrue(compat.install_postgres_sqlite_compat("postgresql://example/db")) + + compat.restore_sqlite_connect_for_tests() + self.assertIs(sqlite3.connect, original) + + def test_qmark_placeholders_ignore_string_literals(self): + sql = compat.translate_sqlite_sql("SELECT * FROM settings WHERE key = ? AND value != '?' AND note = \"?\"") + self.assertEqual( + sql, + "SELECT * FROM settings WHERE key = %s AND value != '?' AND note = \"?\"", + ) + + def test_insert_or_replace_settings_becomes_postgres_upsert(self): + sql = compat.translate_sqlite_sql(""" + INSERT OR REPLACE INTO settings (key, value, updated_at) + VALUES (?, ?, CURRENT_TIMESTAMP) + """) + self.assertIn("INSERT INTO settings", sql) + self.assertIn("ON CONFLICT (key) DO UPDATE", sql) + self.assertIn("VALUES (%s, %s, CURRENT_TIMESTAMP)", sql) + + def test_insert_or_replace_temp_messages_uses_message_unique_key(self): + sql = compat.translate_sqlite_sql(""" + INSERT OR REPLACE INTO temp_email_messages + (message_id, email_address, subject) + VALUES (?, ?, ?) + """) + self.assertIn("INSERT INTO temp_email_messages", sql) + self.assertIn("ON CONFLICT (email_address, message_id)", sql) + self.assertIn("subject = EXCLUDED.subject", sql) + + def test_insert_or_ignore_becomes_do_nothing(self): + sql = compat.translate_sqlite_sql("INSERT OR IGNORE INTO settings (key, value) VALUES (?, ?)") + self.assertEqual( + sql, + "INSERT INTO settings (key, value) VALUES (%s, %s) ON CONFLICT DO NOTHING", + ) + + def test_sqlite_schema_fragments_become_postgres_compatible(self): + sql = compat.translate_sqlite_sql(""" + CREATE TABLE IF NOT EXISTS sample ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email_domain TEXT COLLATE NOCASE, + created_at REAL DEFAULT (unixepoch('now')) + ) + """) + self.assertIn("INTEGER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY", sql) + self.assertNotIn("COLLATE NOCASE", sql) + self.assertIn("EXTRACT(EPOCH FROM NOW())", sql) + + def test_strftime_iso_default_becomes_postgres_text_default(self): + sql = compat.translate_sqlite_sql("pushed_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%S', 'now'))") + + self.assertIn("TO_CHAR(CURRENT_TIMESTAMP", sql) + self.assertIn('YYYY-MM-DD"T"HH24:MI:SS', sql) + self.assertNotIn("strftime", sql) + + def test_email_address_sqlite_helpers_become_postgres_functions(self): + sql = compat.translate_sqlite_sql(""" + UPDATE temp_emails + SET prefix = substr(email, 1, instr(email, '@') - 1), + domain = substr(email, instr(email, '@') + 1) + WHERE instr(email, '@') > 1 + """) + + self.assertIn("prefix = SPLIT_PART(email, '@', 1)", sql) + self.assertIn("domain = SPLIT_PART(email, '@', 2)", sql) + self.assertIn("POSITION('@' IN email) > 1", sql) + self.assertNotIn("instr(", sql.lower()) + self.assertNotIn("substr(", sql.lower()) + + domain_sql = compat.translate_sqlite_sql( + "UPDATE accounts SET email_domain = LOWER(SUBSTR(email, INSTR(email, '@') + 1)) " + "WHERE INSTR(email, '@') > 1" + ) + self.assertIn("LOWER(SPLIT_PART(email, '@', 2))", domain_sql) + self.assertIn("POSITION('@' IN email) > 1", domain_sql) + self.assertNotIn("instr(", domain_sql.lower()) + self.assertNotIn("substr(", domain_sql.lower()) + + def test_compat_row_behaves_like_sqlite_row_for_common_access(self): + row = compat.CompatRow(["id", "subject"], [7, "Hello"]) + + self.assertEqual(row[0], 7) + self.assertEqual(row["subject"], "Hello") + self.assertEqual(list(row), [7, "Hello"]) + self.assertEqual(len(row), 2) + self.assertEqual(row.keys(), ["id", "subject"]) + self.assertEqual(dict(row.items()), {"id": 7, "subject": "Hello"}) + self.assertEqual(list(row.values()), [7, "Hello"]) + self.assertIn("id", row) + self.assertNotIn("missing", row) + + def test_cursor_execute_translates_params_and_captures_returning_id(self): + description = [types.SimpleNamespace(name="id"), types.SimpleNamespace(name="email")] + pg_cursor = FakePgCursor(rows=[(42, "user@example.com")], description=description, rowcount=1) + connection = make_connection(FakeRawConnection([pg_cursor])) + + cursor = connection.cursor() + result = cursor.execute("INSERT INTO accounts (email) VALUES (?)", ["user@example.com"]) + + self.assertIs(result, cursor) + self.assertEqual(pg_cursor.executions[0][0], "INSERT INTO accounts (email) VALUES (%s) RETURNING id") + self.assertEqual(pg_cursor.executions[0][1], ("user@example.com",)) + self.assertEqual(cursor.rowcount, 1) + self.assertEqual(cursor.lastrowid, 42) + self.assertEqual(connection._last_insert_id, 42) + self.assertEqual(cursor.fetchone()["email"], "user@example.com") + self.assertIsNone(cursor.fetchone()) + self.assertEqual(cursor.fetchall(), []) + cursor.close() + self.assertTrue(pg_cursor.closed) + + def test_cursor_execute_maps_postgres_integrity_errors_to_sqlite(self): + raw = FakeRawConnection([FakePgCursor(error=FakeIntegrityError("duplicate key"))]) + connection = make_connection(raw) + + with self.assertRaises(sqlite3.IntegrityError): + connection.execute("INSERT INTO accounts (email) VALUES (?)", ("user@example.com",)) + + self.assertEqual(raw.rollbacks, 1) + + def test_connection_executemany_sums_rowcount_and_lifecycle_delegates(self): + raw = FakeRawConnection([FakePgCursor(rowcount=1), FakePgCursor(rowcount=2), FakePgCursor(rowcount=-1)]) + connection = make_connection(raw) + + cursor = connection.executemany( + "UPDATE accounts SET email = ? WHERE id = ?", + [("a@example.com", 1), ("b@example.com", 2), ("c@example.com", 3)], + ) + connection.commit() + connection.rollback() + connection.close() + + self.assertEqual(cursor.rowcount, 3) + self.assertEqual(raw.cursor_calls, 3) + self.assertEqual(raw.commits, 1) + self.assertEqual(raw.rollbacks, 1) + self.assertEqual(raw.closes, 1) + + def test_special_statements_handle_transactions_pragmas_and_last_insert_id(self): + connection = make_connection(FakeRawConnection()) + connection._last_insert_id = 99 + cursor = connection.cursor() + + cursor.execute("BEGIN IMMEDIATE") + cursor.execute("COMMIT") + cursor.execute("ROLLBACK") + cursor.execute("SELECT last_insert_rowid() AS id") + row = cursor.fetchone() + cursor.execute("PRAGMA index_list(settings)") + + self.assertEqual(connection._raw.commits, 1) + self.assertEqual(connection._raw.rollbacks, 1) + self.assertEqual(row["id"], 99) + self.assertEqual(cursor.fetchall(), []) + + def test_sqlite_master_query_returns_temp_message_schema(self): + connection = make_connection(FakeRawConnection()) + cursor = connection.cursor() + + cursor.execute("SELECT sql FROM sqlite_master WHERE type = 'table' AND name = 'temp_email_messages'") + row = cursor.fetchone() + + self.assertIsNotNone(row) + self.assertIn("UNIQUE(email_address, message_id)", row["sql"]) + + def test_pragma_table_info_uses_information_schema_rows(self): + pg_cursor = FakePgCursor(rows=[(0, "id", "integer", 1, None, 0)]) + connection = make_connection(FakeRawConnection([pg_cursor])) + + cursor = connection.cursor() + cursor.execute('PRAGMA table_info("accounts")') + rows = cursor.fetchall() + + self.assertEqual(pg_cursor.executions[0][1], ("accounts",)) + self.assertEqual(rows[0]["cid"], 0) + self.assertEqual(rows[0]["name"], "id") + self.assertEqual(rows[0]["type"], "integer") + + def test_helper_functions_cover_returning_params_and_sql_collapse(self): + self.assertEqual( + compat._append_returning_id_if_needed("INSERT INTO accounts (email) VALUES (%s);"), + "INSERT INTO accounts (email) VALUES (%s) RETURNING id", + ) + self.assertEqual( + compat._append_returning_id_if_needed("INSERT INTO accounts (email) VALUES (%s) RETURNING id"), + "INSERT INTO accounts (email) VALUES (%s) RETURNING id", + ) + self.assertEqual( + compat._append_returning_id_if_needed("INSERT INTO unknown_table (name) VALUES (%s)"), + "INSERT INTO unknown_table (name) VALUES (%s)", + ) + self.assertTrue(compat._returns_single_id("insert into accounts (email) values (%s) returning id")) + self.assertEqual(compat._normalize_params(["a", "b"]), ("a", "b")) + self.assertEqual(compat._normalize_params(("a", "b")), ("a", "b")) + self.assertIsNone(compat._normalize_params(None)) + self.assertEqual(compat._normalize_params({"email": "user@example.com"}), {"email": "user@example.com"}) + self.assertEqual(compat._collapse_sql(" SELECT 1\n FROM dual "), "SELECT 1 FROM dual") + + def test_placeholder_replacement_handles_escaped_quotes(self): + sql = compat.translate_sqlite_sql("SELECT '?' AS literal, 'it''s ?' AS escaped, value FROM settings WHERE key = ?") + + self.assertEqual( + sql, + "SELECT '?' AS literal, 'it''s ?' AS escaped, value FROM settings WHERE key = %s", + ) + + def test_postgres_shim_runs_full_init_db_without_sqlite_only_sql(self): + fake_psycopg = RecordingPsycopg() + + with patch.dict(sys.modules, {"psycopg": fake_psycopg}): + with patch.dict( + os.environ, + { + "DATABASE_URL": "postgresql://example/db", + "SECRET_KEY": "test-secret-key", + "LOGIN_PASSWORD": "admin123", + "TEMP_MAIL_API_KEY": "test-temp-api-key", + }, + clear=False, + ): + self.assertTrue(compat.install_postgres_sqlite_compat()) + from outlook_web import db + + db.init_db("/tmp/ignored.sqlite") + + self.assertEqual(fake_psycopg.database_urls, ["postgresql://example/db"]) + raw = fake_psycopg.connections[0] + executed_sql = "\n".join(sql for sql, _ in raw.executions) + + self.assertGreater(len(raw.executions), 100) + self.assertGreaterEqual(raw.commits, 1) + self.assertEqual(raw.rollbacks, 0) + self.assertEqual(raw.closes, 1) + self.assertIn("CREATE TABLE IF NOT EXISTS telegram_push_log", executed_sql) + self.assertIn("CREATE TABLE IF NOT EXISTS temp_email_messages", executed_sql) + self.assertIn("TO_CHAR(CURRENT_TIMESTAMP", executed_sql) + self.assertIn("ON CONFLICT (key) DO UPDATE", executed_sql) + + sqlite_only_fragments = [ + "sqlite_master", + "PRAGMA ", + "strftime", + "INSERT OR REPLACE", + "INSERT OR IGNORE", + "AUTOINCREMENT", + "COLLATE NOCASE", + "unixepoch", + "instr(", + "substr(", + ] + lowered_sql = executed_sql.lower() + for fragment in sqlite_only_fragments: + self.assertNotIn(fragment.lower(), lowered_sql) + + +if __name__ == "__main__": + unittest.main()