diff --git a/CHANGELOG.md b/CHANGELOG.md index c0b76b0..f06e571 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.3.9] + +### Added + +- Configurable base API URL for local D1 proxy support ([#22](https://github.com/CollierKing/sqlalchemy-cloudflare-d1/issues/22)) + - `Connection` and `AsyncConnection` now accept a `base_url` kwarg to override the Cloudflare endpoint + - Falls back to the `CF_D1_BASE_URL` environment variable, then the default Cloudflare URL + - Works with `create_engine(..., connect_args={"base_url": "http://localhost:8787"})` for the cleanest integration + - Enables local development against a `wrangler dev` D1 proxy without modifying source code + + ## [0.3.8] ### Added diff --git a/pyproject.toml b/pyproject.toml index 527167d..f359105 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlalchemy-cloudflare-d1" -version = "0.3.8" +version = "0.3.9" description = "A SQLAlchemy dialect for Cloudflare's D1 Serverless SQLite Database" readme = "README.md" authors = [ diff --git a/src/sqlalchemy_cloudflare_d1/connection.py b/src/sqlalchemy_cloudflare_d1/connection.py index af7a9c8..bd67b48 100644 --- a/src/sqlalchemy_cloudflare_d1/connection.py +++ b/src/sqlalchemy_cloudflare_d1/connection.py @@ -6,6 +6,7 @@ 2. Worker Binding - for use inside Cloudflare Python Workers (d1_binding) """ +import os from typing import Any, Dict, List, Optional, Sequence, Union try: @@ -475,8 +476,14 @@ def __init__(self, account_id: str, database_id: str, api_token: str, **kwargs): self.database_id = database_id self.api_token = api_token - # Build the D1 REST API URL - self.base_url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/d1/database/{database_id}" + # Build the D1 REST API URL, allowing override via kwarg or CF_D1_BASE_URL env var + _default_base = ( + f"https://api.cloudflare.com/client/v4/accounts/{account_id}" + f"/d1/database/{database_id}" + ) + self.base_url = kwargs.get( + "base_url", os.environ.get("CF_D1_BASE_URL", _default_base) + ) # HTTP client self.client = httpx.Client( @@ -826,11 +833,14 @@ def __init__(self, account_id: str, database_id: str, api_token: str, **kwargs): self.database_id = database_id self.api_token = api_token - # Build the D1 REST API URL - self.base_url = ( + # Build the D1 REST API URL, allowing override via kwarg or CF_D1_BASE_URL env var + _default_base = ( f"https://api.cloudflare.com/client/v4/accounts/{account_id}" f"/d1/database/{database_id}" ) + self.base_url = kwargs.get( + "base_url", os.environ.get("CF_D1_BASE_URL", _default_base) + ) # Async HTTP client self.client = httpx.AsyncClient( diff --git a/tests/integration/test_base_url_integration.py b/tests/integration/test_base_url_integration.py new file mode 100644 index 0000000..c106a48 --- /dev/null +++ b/tests/integration/test_base_url_integration.py @@ -0,0 +1,175 @@ +"""Integration tests for the configurable base URL feature (Issue #22). + +These tests verify that CF_D1_BASE_URL and the base_url kwarg correctly +redirect HTTP requests away from the hard-coded Cloudflare endpoint. + +A minimal local HTTP server mimics the D1 /raw response format so that +no real Cloudflare credentials are required to run these tests. +""" + +import json +import os +import threading +import urllib.parse +from http.server import BaseHTTPRequestHandler, HTTPServer +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine, text + +from sqlalchemy_cloudflare_d1.connection import Connection + + +# MARK: - Local D1 Proxy Fixture + + +@pytest.fixture() +def local_d1_server(): + """Start a local HTTP server that mimics the D1 REST API /raw endpoint. + + Modelled after the local D1 proxy described in Issue #22. Responds with + the Cloudflare /raw response envelope so the Connection can parse it. + """ + + received_requests: list[str] = [] + + class _D1Handler(BaseHTTPRequestHandler): + def do_POST(self): + received_requests.append(self.path) + body_bytes = self.rfile.read(int(self.headers["Content-Length"])) + json.loads(body_bytes) # parse but ignore — any SQL is fine + + response = { + "success": True, + "result": [ + { + "results": { + "columns": ["value"], + "rows": [[42]], + }, + "meta": {}, + "success": True, + } + ], + "errors": [], + "messages": [], + } + payload = json.dumps(response).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + self.wfile.write(payload) + + def log_message(self, format, *args): # noqa: A002 + pass # suppress server log noise in test output + + server = HTTPServer(("127.0.0.1", 0), _D1Handler) + port = server.server_address[1] + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + yield f"http://127.0.0.1:{port}", received_requests + + server.shutdown() + + +# MARK: - Base URL Override Tests + + +class TestBaseUrlOverride: + """Verify that base_url kwarg and CF_D1_BASE_URL env var route requests + to a custom endpoint instead of the default Cloudflare API URL. + """ + + def test_base_url_kwarg_redirects_requests(self, local_d1_server): + """Connection uses base_url kwarg instead of the Cloudflare endpoint.""" + base_url, received = local_d1_server + conn = Connection( + account_id="fake_account", + database_id="fake_db", + api_token="fake_token", + base_url=base_url, + ) + try: + cur = conn.cursor() + cur.execute("SELECT 1 AS value") + row = cur.fetchone() + assert row == (42,) + assert any("/raw" in path for path in received) + finally: + conn.close() + + def test_cf_d1_base_url_env_var_redirects_requests(self, local_d1_server): + """Connection reads CF_D1_BASE_URL from the environment and routes there.""" + base_url, received = local_d1_server + with patch.dict(os.environ, {"CF_D1_BASE_URL": base_url}): + conn = Connection( + account_id="fake_account", + database_id="fake_db", + api_token="fake_token", + ) + try: + cur = conn.cursor() + cur.execute("SELECT 1 AS value") + row = cur.fetchone() + assert row == (42,) + assert any("/raw" in path for path in received) + finally: + conn.close() + + def test_sqlalchemy_engine_with_connect_args(self, local_d1_server): + """Engine built with connect_args={'base_url': ...} routes to the local server.""" + base_url, received = local_d1_server + engine = create_engine( + "cloudflare_d1://fake_account:fake_token@fake_db", + connect_args={"base_url": base_url}, + ) + try: + with engine.connect() as conn: + result = conn.execute(text("SELECT 1 AS value")) + row = result.fetchone() + assert row == (42,) + assert any("/raw" in path for path in received) + finally: + engine.dispose() + + def test_sqlalchemy_engine_with_base_url_query_param(self, local_d1_server): + """Engine built from a URL with ?base_url= routes requests to the local server.""" + base_url, received = local_d1_server + encoded = urllib.parse.quote(base_url, safe="") + engine = create_engine( + f"cloudflare_d1://fake_account:fake_token@fake_db?base_url={encoded}" + ) + try: + with engine.connect() as conn: + result = conn.execute(text("SELECT 1 AS value")) + row = result.fetchone() + assert row == (42,) + assert any("/raw" in path for path in received) + finally: + engine.dispose() + + def test_base_url_kwarg_takes_precedence_over_env_var(self, local_d1_server): + """When both base_url kwarg and CF_D1_BASE_URL are set, kwarg wins.""" + base_url, received = local_d1_server + with patch.dict( + os.environ, {"CF_D1_BASE_URL": "http://should-not-be-used:9999"} + ): + conn = Connection( + account_id="fake_account", + database_id="fake_db", + api_token="fake_token", + base_url=base_url, + ) + try: + cur = conn.cursor() + cur.execute("SELECT 1 AS value") + row = cur.fetchone() + assert row == (42,) + finally: + conn.close() + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/unit/test_dialect.py b/tests/unit/test_dialect.py index 1c587c1..66104a7 100644 --- a/tests/unit/test_dialect.py +++ b/tests/unit/test_dialect.py @@ -234,5 +234,65 @@ def test_async_dbapi_module(): assert hasattr(AsyncAdapt_d1_dbapi, "connect") +def test_connection_base_url_kwarg(): + """Test that Connection accepts a custom base_url kwarg.""" + from unittest.mock import patch + + from sqlalchemy_cloudflare_d1.connection import Connection + + with patch("sqlalchemy_cloudflare_d1.connection.httpx") as mock_httpx: + mock_httpx.Client.return_value = None + conn = Connection( + account_id="acct", + database_id="db", + api_token="token", + base_url="http://localhost:8787", + ) + assert conn.base_url == "http://localhost:8787" + + +def test_connection_base_url_env_var(): + """Test that Connection reads CF_D1_BASE_URL from the environment.""" + import os + from unittest.mock import patch + + from sqlalchemy_cloudflare_d1.connection import Connection + + with patch.dict(os.environ, {"CF_D1_BASE_URL": "http://localhost:9999"}): + with patch("sqlalchemy_cloudflare_d1.connection.httpx") as mock_httpx: + mock_httpx.Client.return_value = None + conn = Connection(account_id="acct", database_id="db", api_token="token") + assert conn.base_url == "http://localhost:9999" + + +def test_connection_base_url_default(): + """Test that Connection falls back to the Cloudflare URL when no override is set.""" + import os + from unittest.mock import patch + + from sqlalchemy_cloudflare_d1.connection import Connection + + env = {k: v for k, v in os.environ.items() if k != "CF_D1_BASE_URL"} + with patch.dict(os.environ, env, clear=True): + with patch("sqlalchemy_cloudflare_d1.connection.httpx") as mock_httpx: + mock_httpx.Client.return_value = None + conn = Connection(account_id="acct", database_id="db", api_token="token") + assert conn.base_url == ( + "https://api.cloudflare.com/client/v4/accounts/acct/d1/database/db" + ) + + +def test_base_url_passthrough_via_url_query(): + """Test that base_url can be passed via the connection string query params.""" + from sqlalchemy.engine.url import make_url + + dialect = CloudflareD1Dialect() + url = make_url( + "cloudflare_d1://acct:token@db?base_url=http%3A%2F%2Flocalhost%3A8787" + ) + _, kwargs = dialect.create_connect_args(url) + assert kwargs["base_url"] == "http://localhost:8787" + + if __name__ == "__main__": pytest.main([__file__])