Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions packages/bub-tapestore-sqlalchemy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ The plugin reads environment variables with prefix `BUB_TAPESTORE_SQLALCHEMY_`:
- `BUB_TAPESTORE_SQLALCHEMY_URL` (optional): SQLAlchemy database URL
- Default: `sqlite+pysqlite:///<BUB_HOME>/tapes.db`
- `BUB_TAPESTORE_SQLALCHEMY_ECHO` (optional, default: `false`)
- `BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS` (optional, default: `{}`): JSON object forwarded to
`sqlalchemy.create_engine(connect_args=...)`. User-supplied keys override the
built-in defaults. Useful for drivers that need connection keyword arguments that
cannot be expressed in the URL — for example, Turso/libSQL auth tokens:

```bash
export BUB_TAPESTORE_SQLALCHEMY_URL='sqlite+libsql://your-db.turso.io/?secure=true'
export BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS='{"auth_token": "<TURSO_AUTH_TOKEN>"}'
```

## Runtime Behavior

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

import json
from collections.abc import Callable
from functools import lru_cache
from pathlib import Path
from typing import Any

import bub
from bub import hookimpl
from bub import inquirer as bub_inquirer
from pydantic import Field
from pydantic import Field, field_validator
from pydantic_settings import SettingsConfigDict
from sqlalchemy import URL

import bub
from bub import hookimpl
from bub import inquirer as bub_inquirer
from bub_tapestore_sqlalchemy.store import SQLAlchemyTapeStore

CONFIG_NAME = "tapestore-sqlalchemy"
Expand All @@ -38,6 +39,38 @@ class SQLAlchemyTapeStoreSettings(bub.Settings):
default=False,
validation_alias="BUB_TAPESTORE_SQLALCHEMY_ECHO",
)
connect_args: Any = Field(
default_factory=dict,
validation_alias="BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS",
description=(
"Extra keyword arguments forwarded to ``sqlalchemy.create_engine(connect_args=...)``. "
"Provide a JSON object via the env var, for example "
'\'{"auth_token": "..."}\' for Turso/libSQL.'
),
)

@field_validator("connect_args", mode="before")
@classmethod
def _decode_connect_args(cls, value: object) -> dict[str, Any]:
    """Coerce the raw ``connect_args`` setting into a dict.

    Accepted inputs:

    * ``None`` or an empty string -- treated as "no extra arguments".
    * A dict -- returned unchanged.
    * A string containing a JSON object -- decoded with ``json.loads``.

    Raises:
        ValueError: If the string is not valid JSON, the decoded JSON is not
            an object, or the value is of any other type. ``ValueError`` is
            used for every rejection because pydantic v2 only wraps
            ``ValueError``/``AssertionError`` raised by validators into its
            validation error; a ``TypeError`` would escape unwrapped and
            bypass callers that handle validation failures.
    """
    if value is None or value == "":
        return {}
    if isinstance(value, dict):
        return value
    if not isinstance(value, str):
        raise ValueError(
            "BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS must be a JSON object string or a mapping"
        )

    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as exc:
        raise ValueError(
            "BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS must be a JSON object"
        ) from exc

    # Valid JSON that is not an object (list, number, string, ...) cannot be
    # used as keyword arguments, so reject it explicitly.
    if not isinstance(parsed, dict):
        raise ValueError(
            "BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS must decode to a JSON object"
        )
    return parsed

@classmethod
def from_env(cls) -> SQLAlchemyTapeStoreSettings:
Expand All @@ -56,7 +89,11 @@ def _build_store(
),
) -> SQLAlchemyTapeStore:
settings = settings_factory()
return SQLAlchemyTapeStore(url=settings.resolved_url, echo=settings.echo)
return SQLAlchemyTapeStore(
url=settings.resolved_url,
echo=settings.echo,
connect_args=settings.connect_args or None,
)


@lru_cache(maxsize=1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@


class SQLAlchemyTapeStore(InMemoryQueryMixin):
def __init__(self, url: str, *, echo: bool = False) -> None:
def __init__(
self,
url: str,
*,
echo: bool = False,
connect_args: dict[str, object] | None = None,
) -> None:
self._url = self._normalize_url(url)
self._echo = echo
self._write_lock = threading.RLock()
Expand All @@ -25,7 +31,7 @@ def __init__(self, url: str, *, echo: bool = False) -> None:
echo=echo,
future=True,
pool_pre_ping=True,
connect_args=self._connect_args(self._url),
connect_args=self._build_connect_args(self._url, connect_args or {}),
)
self._configure_engine(self._engine)
self._session_factory = sessionmaker(
Expand All @@ -43,30 +49,28 @@ def list_tapes(self) -> list[str]:
)

def reset(self, tape: str) -> None:
with self._write_lock:
with self._session_factory.begin() as session:
tape_record = self._find_tape_record(session, tape)
if tape_record is not None:
session.delete(tape_record)
with self._write_lock, self._session_factory.begin() as session:
tape_record = self._find_tape_record(session, tape)
if tape_record is not None:
session.delete(tape_record)

def append(self, tape: str, entry: TapeEntry) -> None:
with self._write_lock:
with self._session_factory.begin() as session:
tape_record = self._load_or_create_tape(session, tape)
next_entry_id = self._next_entry_id(session, tape_record)
anchor_name = self._anchor_name_of(entry)
session.add(
TapeEntryRecord(
tape_id=tape_record.id,
entry_id=next_entry_id,
kind=entry.kind,
anchor_name=anchor_name,
anchor_name_key=self._key_for(anchor_name) if anchor_name else None,
payload=dict(entry.payload),
meta=dict(entry.meta),
entry_date=entry.date,
)
with self._write_lock, self._session_factory.begin() as session:
tape_record = self._load_or_create_tape(session, tape)
next_entry_id = self._next_entry_id(session, tape_record)
anchor_name = self._anchor_name_of(entry)
session.add(
TapeEntryRecord(
tape_id=tape_record.id,
entry_id=next_entry_id,
kind=entry.kind,
anchor_name=anchor_name,
anchor_name_key=self._key_for(anchor_name) if anchor_name else None,
payload=dict(entry.payload),
meta=dict(entry.meta),
entry_date=entry.date,
)
)

def fetch_all(self, query: TapeQuery) -> Iterable[TapeEntry]:
normalized_query = replace(query, _kinds=self._normalized_kinds(query._kinds))
Expand Down Expand Up @@ -134,17 +138,22 @@ def _normalize_url(url: str) -> URL:
raise ValueError(f"Invalid SQLAlchemy URL: {url}") from exc

@staticmethod
def _connect_args(url: URL) -> dict[str, object]:
def _build_connect_args(url: URL, overrides: dict[str, object]) -> dict[str, object]:
defaults: dict[str, object] = {}
if url.get_backend_name() == "sqlite":
return {
"check_same_thread": False,
"timeout": 30,
}
return {}
defaults["check_same_thread"] = False
# ``timeout`` is a pysqlite-only kwarg; libSQL drivers reject it.
if url.get_driver_name() == "pysqlite":
defaults["timeout"] = 30
return {**defaults, **overrides}

@staticmethod
def _configure_engine(engine: Engine) -> None:
if engine.url.get_backend_name() != "sqlite":
# PRAGMAs are only safe to issue against the local pysqlite driver.
# Remote libSQL endpoints (Hrana over HTTP/WebSocket) reject PRAGMA
# statements, and embedded replicas pick the same defaults from
# ``libsql_experimental`` automatically.
if engine.url.get_driver_name() != "pysqlite":
return

@event.listens_for(engine, "connect")
Expand Down
34 changes: 34 additions & 0 deletions packages/bub-tapestore-sqlalchemy/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path

import bub_tapestore_sqlalchemy.plugin as plugin
import pytest
from bub_tapestore_sqlalchemy.plugin import SQLAlchemyTapeStoreSettings
from bub_tapestore_sqlalchemy.store import SQLAlchemyTapeStore

Expand Down Expand Up @@ -80,3 +81,36 @@ def test_onboard_config_skips_sqlalchemy_when_declined(monkeypatch) -> None:
monkeypatch.setattr(plugin.bub_inquirer, "ask_confirm", lambda *args, **kwargs: False)

assert plugin.onboard_config({}) is None


def test_connect_args_default_empty(monkeypatch) -> None:
    """With the env var unset, ``connect_args`` is an empty dict."""
    # Make sure no value leaks in from the surrounding environment.
    monkeypatch.delenv("BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS", raising=False)

    assert SQLAlchemyTapeStoreSettings.from_env().connect_args == {}


def test_connect_args_decoded_from_env_json(monkeypatch) -> None:
    """A JSON object supplied through the env var is decoded into a dict."""
    raw_json = '{"auth_token": "abc", "sync_interval": 5}'
    monkeypatch.setenv("BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS", raw_json)

    decoded = SQLAlchemyTapeStoreSettings.from_env().connect_args

    assert decoded == {"auth_token": "abc", "sync_interval": 5}


def test_connect_args_invalid_json_is_rejected(monkeypatch) -> None:
    """A value that is not valid JSON makes settings loading fail."""
    env_name, bad_value = "BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS", "not-json"
    monkeypatch.setenv(env_name, bad_value)

    with pytest.raises(Exception, match="JSON object"):
        SQLAlchemyTapeStoreSettings.from_env()


def test_connect_args_must_be_object(monkeypatch) -> None:
    """Valid JSON that is not an object (here a list) is rejected."""
    env_name, json_list = "BUB_TAPESTORE_SQLALCHEMY_CONNECT_ARGS", "[1, 2, 3]"
    monkeypatch.setenv(env_name, json_list)

    with pytest.raises(Exception, match="JSON object"):
        SQLAlchemyTapeStoreSettings.from_env()
65 changes: 63 additions & 2 deletions packages/bub-tapestore-sqlalchemy/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from pathlib import Path

import pytest
from republic import RepublicError, TapeEntry, TapeQuery

from bub_tapestore_sqlalchemy.store import SQLAlchemyTapeStore
from republic import RepublicError, TapeEntry, TapeQuery


def _store(tmp_path: Path) -> SQLAlchemyTapeStore:
Expand Down Expand Up @@ -236,3 +235,65 @@ def test_read_missing_tape_matches_builtin_shape(tmp_path: Path) -> None:
store = _store(tmp_path)

assert store.read("missing__tape") == []


def test_build_connect_args_defaults_for_local_sqlite() -> None:
    """A pysqlite URL with no overrides yields both built-in defaults."""
    from sqlalchemy.engine import make_url

    local_url = make_url("sqlite+pysqlite:///tapes.db")
    built = SQLAlchemyTapeStore._build_connect_args(local_url, {})

    assert built == {"check_same_thread": False, "timeout": 30}


def test_build_connect_args_skips_pysqlite_only_kwargs_for_libsql() -> None:
    """A libSQL URL keeps ``check_same_thread`` but gets no ``timeout``."""
    from sqlalchemy.engine import make_url

    remote_url = make_url("sqlite+libsql://example.turso.io/?secure=true")
    built = SQLAlchemyTapeStore._build_connect_args(remote_url, {})

    assert "timeout" not in built
    assert built["check_same_thread"] is False


def test_build_connect_args_lets_overrides_win() -> None:
    """Overrides replace matching defaults and add keys the defaults lack."""
    from sqlalchemy.engine import make_url

    local_url = make_url("sqlite+pysqlite:///tapes.db")
    user_overrides = {"timeout": 99, "extra": "value"}

    merged = SQLAlchemyTapeStore._build_connect_args(local_url, user_overrides)

    assert merged["timeout"] == 99
    assert merged["extra"] == "value"
    # A default the caller did not override is left in place.
    assert merged["check_same_thread"] is False


def test_store_init_forwards_connect_args(tmp_path: Path) -> None:
    """connect_args reach create_engine and override the built-in defaults."""
    seen: dict[str, object] = {}

    # The namespace ``__init__`` resolves ``create_engine`` from; swapping the
    # entry there lets us observe the call while still building a real engine.
    store_globals = SQLAlchemyTapeStore.__init__.__globals__
    original_create_engine = store_globals["create_engine"]

    def recording_create_engine(url, **kwargs):
        seen.update(kwargs.get("connect_args", {}))
        return original_create_engine(url, **kwargs)

    db_url = f"sqlite+pysqlite:///{tmp_path / 'tapes.db'}"
    # ``cached_statements`` is a real pysqlite kwarg, so the engine still
    # connects; the ``timeout`` override proves user values beat the defaults.
    user_args = {"timeout": 12, "cached_statements": 50}

    store_globals["create_engine"] = recording_create_engine
    try:
        SQLAlchemyTapeStore(db_url, connect_args=user_args)
    finally:
        # Always restore the real factory so other tests are unaffected.
        store_globals["create_engine"] = original_create_engine

    assert seen["timeout"] == 12
    assert seen["cached_statements"] == 50
    assert seen["check_same_thread"] is False
Loading