From dca790ae96063df866c68fc6b81d657aefb16388 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Fri, 8 May 2026 13:40:23 +0200 Subject: [PATCH 01/22] refactor: share WASM fetch and replacement patch utilities --- marimo/_runtime/_wasm/_fetch.py | 40 +++++++++++++++++++++++++++++++ marimo/_runtime/_wasm/_patches.py | 29 ++++++++++++++++++++++ marimo/_runtime/_wasm/_polars.py | 14 ++--------- tests/_runtime/test_patches.py | 29 ++++++++++++++++++++++ 4 files changed, 100 insertions(+), 12 deletions(-) create mode 100644 marimo/_runtime/_wasm/_fetch.py diff --git a/marimo/_runtime/_wasm/_fetch.py b/marimo/_runtime/_wasm/_fetch.py new file mode 100644 index 00000000000..c84d1bf3d19 --- /dev/null +++ b/marimo/_runtime/_wasm/_fetch.py @@ -0,0 +1,40 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Shared URL fetch for WASM fallbacks.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict, cast + +if TYPE_CHECKING: + import ssl + + +class RequestKwargs(TypedDict, total=False): + data: bytes | None + headers: dict[str, str] + origin_req_host: str + unverifiable: bool + method: str + + +class UrlOpenKwargs(TypedDict, total=False): + data: bytes | None + timeout: float + context: ssl.SSLContext | None + + +def fetch_url_bytes( + url: str, + *, + request_kwargs: RequestKwargs | None = None, + urlopen_kwargs: UrlOpenKwargs | None = None, +) -> bytes: + """Sync fetch via urllib — works in pyodide thanks to pyodide_http's + patch_all (installed at marimo startup), which routes urllib through JS + fetch. Single sync path for text and binary. 
+ """ + import urllib.request + + request = urllib.request.Request(url, **(request_kwargs or {})) + with urllib.request.urlopen(request, **(urlopen_kwargs or {})) as response: + return cast(bytes, response.read()) diff --git a/marimo/_runtime/_wasm/_patches.py b/marimo/_runtime/_wasm/_patches.py index a726d45d7f6..ae628a63c3b 100644 --- a/marimo/_runtime/_wasm/_patches.py +++ b/marimo/_runtime/_wasm/_patches.py @@ -19,6 +19,7 @@ Unpatch = Callable[[], None] Fallback = Callable[..., Any] +WrapperFactory = Callable[[Callable[..., Any]], Callable[..., Any]] class WasmPatchSet: @@ -80,6 +81,34 @@ def _unpatch() -> None: self._unpatches.append(_unpatch) + def replace( + self, + owner: Any, + attr: str, + wrapper_factory: WrapperFactory, + ) -> None: + """Replace ``owner.attr`` with a WASM-only wrapper. + + Unlike ``patch``, this does not call the original first. Use this for + APIs where an original call can have side effects before failing. + """ + if not self._active: + return + + original = getattr(owner, attr, None) + if original is None: + return + + wrapper = wrapper_factory(original) + setattr(owner, attr, wrapper) + + def _unpatch() -> None: + # Only restore if we're still the active wrapper. 
+ if getattr(owner, attr, None) is wrapper: + setattr(owner, attr, original) + + self._unpatches.append(_unpatch) + def unpatch_all(self) -> Unpatch: """Return a callable that restores all originals (idempotent).""" unpatches = self._unpatches diff --git a/marimo/_runtime/_wasm/_polars.py b/marimo/_runtime/_wasm/_polars.py index 36571c6f060..c212653588b 100644 --- a/marimo/_runtime/_wasm/_polars.py +++ b/marimo/_runtime/_wasm/_polars.py @@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any from marimo import _loggers +from marimo._runtime._wasm import _fetch from marimo._runtime._wasm._patches import Unpatch, WasmPatchSet from marimo._utils.platform import is_pyodide @@ -47,17 +48,6 @@ def _is_url(source: Any) -> bool: ) -def _fetch_url_bytes(url: str) -> bytes: - """Sync fetch via urllib — works in pyodide thanks to pyodide_http's - patch_all (installed at marimo startup), which routes urllib through JS - fetch. Single sync path for text and binary. - """ - import urllib.request - - with urllib.request.urlopen(url) as response: - return response.read() # type: ignore[no-any-return] - - def _resolve_to_bytes(source: Any) -> bytes: """Coerce a polars I/O source into raw bytes.""" if isinstance(source, bytes): @@ -73,7 +63,7 @@ def _resolve_to_bytes(source: Any) -> bytes: return source.read_bytes() if isinstance(source, str): if _is_url(source): - return _fetch_url_bytes(source) + return _fetch.fetch_url_bytes(source) return pathlib.Path(source).read_bytes() raise TypeError( f"Unsupported source type for WASM polars fallback: {type(source)!r}" diff --git a/tests/_runtime/test_patches.py b/tests/_runtime/test_patches.py index abd42633c96..9b820f8f315 100644 --- a/tests/_runtime/test_patches.py +++ b/tests/_runtime/test_patches.py @@ -2,6 +2,7 @@ from __future__ import annotations import io +from contextlib import nullcontext from typing import TYPE_CHECKING from unittest.mock import patch @@ -332,6 +333,34 @@ def test_skips_missing_attr() -> None: 
patches.unpatch_all()() +def test_fetch_url_bytes_forwards_request_and_urlopen_kwargs() -> None: + import urllib.request + + from marimo._runtime._wasm._fetch import fetch_url_bytes + + with patch( + "urllib.request.urlopen", + return_value=nullcontext(io.BytesIO(b"ok")), + ) as urlopen: + assert ( + fetch_url_bytes( + "https://example.com/cars.csv", + request_kwargs={ + "headers": {"User-Agent": "marimo"}, + "method": "GET", + }, + urlopen_kwargs={"timeout": 5}, + ) + == b"ok" + ) + + request = urlopen.call_args.args[0] + assert isinstance(request, urllib.request.Request) + assert request.get_header("User-agent") == "marimo" + assert request.get_method() == "GET" + assert urlopen.call_args.kwargs == {"timeout": 5} + + @pytest.mark.requires("polars", "pyarrow") class TestPolarsIoWasmPatch: @staticmethod From b6351cf365e610805551a83b2490d945e585d422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Fri, 8 May 2026 13:40:52 +0200 Subject: [PATCH 02/22] feat: add DuckDB WASM remote-source fallbacks --- marimo/_runtime/_wasm/_duckdb/__init__.py | 677 +++++++++++++++++++++ marimo/_runtime/_wasm/_duckdb/dataframe.py | 165 +++++ marimo/_runtime/_wasm/_duckdb/io.py | 493 +++++++++++++++ marimo/_runtime/_wasm/_duckdb/sources.py | 160 +++++ 4 files changed, 1495 insertions(+) create mode 100644 marimo/_runtime/_wasm/_duckdb/__init__.py create mode 100644 marimo/_runtime/_wasm/_duckdb/dataframe.py create mode 100644 marimo/_runtime/_wasm/_duckdb/io.py create mode 100644 marimo/_runtime/_wasm/_duckdb/sources.py diff --git a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py new file mode 100644 index 00000000000..a810243780c --- /dev/null +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -0,0 +1,677 @@ +# Copyright 2026 Marimo. All rights reserved. +"""WASM-only DuckDB fallbacks for remote file scans. + +DuckDB-WASM cannot use ``httpfs``. 
marimo fetches supported URLs itself, +materializes supported files as pandas DataFrames, then hands those frames +back to DuckDB. + +We have two concrete use cases to patch: + +* **Direct read methods** — ``duckdb.read_csv/read_parquet/read_json`` are + wrapped with :class:`WasmPatchSet`; supported remote URLs are fetched by + marimo and returned as DuckDB relations. +* **SQL remote scans** — raw DuckDB APIs and marimo's ``mo.sql`` path call the + same rewrite helper, which replaces supported URLs with generated + replacement-scan tables backed by fetched pandas DataFrames. +""" + +from __future__ import annotations + +import functools +import inspect +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from marimo import _loggers +from marimo._dependencies.dependencies import DependencyManager +from marimo._runtime._wasm._duckdb.io import ( + RemoteFileSource, + remote_file_source_from_reader_args, +) +from marimo._runtime._wasm._duckdb.sources import ( + remote_file_source_from_table, +) +from marimo._runtime._wasm._patches import ( + Unpatch, + WasmPatchSet, + WrapperFactory, +) +from marimo._utils.platform import is_pyodide + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + + import pandas as pd + from sqlglot import exp + +LOGGER = _loggers.marimo_logger() + +# DuckDB SQL APIs also accept non-string query objects; this marks "not found". +_MISSING = object() +_DIRECT_READER_NAMES: tuple[str, ...] 
= ( + "read_csv", + "read_parquet", + "read_json", +) +_SQL_CALL_EXPRESSION = "{original}(*{args}, **{kwargs})" +_SQL_HELPER_NAME_BASES = ( + "__marimo_wasm_duckdb_original", + "__marimo_wasm_duckdb_args", + "__marimo_wasm_duckdb_kwargs", +) +_MODULE_SQL_FUNCTIONS: dict[str, tuple[int, tuple[str, ...]]] = { + "sql": (0, ("query",)), + "query": (0, ("query",)), + "execute": (0, ("query",)), + "query_df": (2, ("sql_query", "query")), +} +_CONNECTION_SQL_METHODS: dict[str, tuple[int, tuple[str, ...]]] = { + "sql": (1, ("query",)), + "query": (1, ("query",)), + "execute": (1, ("query",)), +} +_SOURCE_KWARGS: dict[str, tuple[str, ...]] = { + "read_csv": ("path_or_buffer", "source", "file", "path"), + "read_parquet": ("file_glob", "file_globs", "source", "file", "path"), + "read_json": ("path_or_buffer", "source", "file", "path"), +} + + +@dataclass(frozen=True) +class WasmDuckDBQueryPatch: + """A rewritten DuckDB query and the DataFrames it references.""" + + query: str + tables: Mapping[str, pd.DataFrame] + + +@dataclass(frozen=True) +class WasmDuckDBSqlResult: + """Result of a DuckDB SQL API call that was rewritten for WASM.""" + + value: Any + + +@dataclass(frozen=True) +class _EvalBindingNames: + original: str + args: str + kwargs: str + + +class _RemoteTableNames: + """Track URL sources and generated table names for one SQL rewrite.""" + + def __init__(self, reserved_names: Sequence[str]) -> None: + self._reserved_names = { + _duckdb_identifier_key(name) for name in reserved_names + } + self._names_by_source: dict[RemoteFileSource, str] = {} + + def __bool__(self) -> bool: + return bool(self._names_by_source) + + def name_for(self, source: RemoteFileSource) -> str: + name = self._names_by_source.get(source) + if name is not None: + return name + + idx = len(self._names_by_source) + while True: + candidate = f"__marimo_wasm_duckdb_remote_{idx}" + candidate_key = _duckdb_identifier_key(candidate) + if candidate_key not in self._reserved_names: + 
self._names_by_source[source] = candidate + self._reserved_names.add(candidate_key) + return candidate + idx += 1 + + def read_dataframes(self) -> dict[str, pd.DataFrame]: + return { + table_name: source.read_dataframe() + for source, table_name in self._names_by_source.items() + } + + +def _duckdb_identifier_key(name: str) -> str: + # DuckDB preserves identifier case but resolves names case-insensitively. + return name.casefold() + + +def patch_duckdb_query_for_wasm( + query: str, + *, + reserved_names: Sequence[str] = (), +) -> WasmDuckDBQueryPatch | None: + """Rewrite remote file sources to generated DataFrame names. + + Example: ``SELECT * FROM read_csv('https://example.com/cars.csv')`` + becomes ``SELECT * FROM __marimo_wasm_duckdb_remote_0``, and + ``tables["__marimo_wasm_duckdb_remote_0"]`` holds the fetched DataFrame. + + In Pyodide this raises if sqlglot is unavailable. Returns ``None`` when + not running in Pyodide, when the query has no supported remote file + source, or when the query cannot be parsed. + """ + if not is_pyodide(): + return None + + _require_sqlglot() + statements = _parse_duckdb_query(query) + if statements is None: + return None + + table_names = _RemoteTableNames( + (*reserved_names, *_reserved_sql_names(statements)) + ) + patched_statements = _replace_remote_sources(statements, table_names) + if not table_names: + return None + + _require_pandas() + + return WasmDuckDBQueryPatch( + query=_format_duckdb_query(patched_statements, original_query=query), + tables=table_names.read_dataframes(), + ) + + +def patch_duckdb_for_wasm() -> Unpatch: + """Install WASM fallbacks for DuckDB remote file and SQL APIs.""" + if not is_pyodide(): + return lambda: None + + try: + import duckdb + except ImportError: + return lambda: None + + # DuckDB's WASM patch set is atomic: direct reader and SQL API wrappers + # install together, and SQL rewriting hard-requires sqlglot. 
+ _require_sqlglot() + + patches = WasmPatchSet() + for function_name in _DIRECT_READER_NAMES: + patches.replace( + duckdb, + function_name, + _make_direct_reader_wrapper( + function_name, + source_arg_index=0, + connection_arg_index=None, + ), + ) + patches.replace( + duckdb.DuckDBPyConnection, + function_name, + _make_direct_reader_wrapper( + function_name, + source_arg_index=1, + connection_arg_index=0, + ), + ) + for function_name, ( + query_index, + query_kwargs, + ) in _MODULE_SQL_FUNCTIONS.items(): + patches.replace( + duckdb, + function_name, + _make_sql_api_wrapper( + query_arg_index=query_index, + query_kwarg_names=query_kwargs, + ), + ) + for method_name, ( + query_index, + query_kwargs, + ) in _CONNECTION_SQL_METHODS.items(): + patches.replace( + duckdb.DuckDBPyConnection, + method_name, + _make_sql_api_wrapper( + query_arg_index=query_index, + query_kwarg_names=query_kwargs, + ), + ) + return patches.unpatch_all() + + +def run_duckdb_sql_with_wasm_patch( + original: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], + eval_globals: dict[str, Any], + eval_locals: Mapping[str, Any], + reserved_names: Sequence[str] = (), +) -> Any: + """Run a DuckDB SQL API call after rewriting supported remote scans.""" + wasm_result = try_run_duckdb_sql_with_wasm_patch( + original, + args, + kwargs, + query_arg_index=query_arg_index, + query_kwarg_names=query_kwarg_names, + eval_globals=eval_globals, + eval_locals=eval_locals, + reserved_names=reserved_names, + ) + if wasm_result is not None: + return wasm_result.value + + kwargs_dict = dict(kwargs) + binding_names = _eval_binding_names( + _reserved_namespace_names( + eval_globals, + eval_locals, + ( + *reserved_names, + *_identifier_string_args(args), + *_identifier_string_args(tuple(kwargs_dict.values())), + ), + ) + ) + return _eval_duckdb_original_call( + original, + args, + kwargs_dict, + eval_globals=eval_globals, + 
eval_locals=eval_locals, + binding_names=binding_names, + ) + + +def try_run_duckdb_sql_with_wasm_patch( + original: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], + eval_globals: dict[str, Any], + eval_locals: Mapping[str, Any], + reserved_names: Sequence[str] = (), +) -> WasmDuckDBSqlResult | None: + """Run only if a DuckDB SQL API call needs a WASM rewrite.""" + if not is_pyodide(): + return None + + kwargs_dict = dict(kwargs) + query = _query_argument( + args, + kwargs_dict, + query_arg_index=query_arg_index, + query_kwarg_names=query_kwarg_names, + ) + if not isinstance(query, str): + return None + + namespace_names = _reserved_namespace_names( + eval_globals, + eval_locals, + ( + *reserved_names, + *_duckdb_catalog_names(original, args), + *_identifier_string_args(args), + *_identifier_string_args(tuple(kwargs_dict.values())), + ), + ) + binding_names = _eval_binding_names(namespace_names) + + wasm_patch = patch_duckdb_query_for_wasm( + query, + reserved_names=( + *namespace_names, + binding_names.original, + binding_names.args, + binding_names.kwargs, + ), + ) + if wasm_patch is None: + return None + + patched_args, patched_kwargs = _replace_query_argument( + args, + kwargs_dict, + patched_query=wasm_patch.query, + query_arg_index=query_arg_index, + query_kwarg_names=query_kwarg_names, + ) + return WasmDuckDBSqlResult( + _eval_duckdb_original_call( + original, + patched_args, + patched_kwargs, + eval_globals=eval_globals, + eval_locals=eval_locals, + binding_names=binding_names, + extra_locals=wasm_patch.tables, + ) + ) + + +def _make_sql_api_wrapper( + *, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def _make_wrapper(original: Callable[..., Any]) -> Callable[..., Any]: + @functools.wraps(original) + def _wrapper(*args: Any, **kwargs: Any) -> Any: + frame = inspect.currentframe() + if frame 
is None or frame.f_back is None: + return original(*args, **kwargs) + caller_frame = frame.f_back + try: + return run_duckdb_sql_with_wasm_patch( + original, + args, + kwargs, + query_arg_index=query_arg_index, + query_kwarg_names=query_kwarg_names, + eval_globals=caller_frame.f_globals, + eval_locals=caller_frame.f_locals, + ) + finally: + del frame + + return _wrapper + + return _make_wrapper + + +def _query_argument( + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], +) -> Any: + if len(args) > query_arg_index: + return args[query_arg_index] + for name in query_kwarg_names: + if name in kwargs: + return kwargs[name] + return _MISSING + + +def _replace_query_argument( + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + patched_query: str, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], +) -> tuple[tuple[Any, ...], dict[str, Any]]: + kwargs_dict = dict(kwargs) + if len(args) > query_arg_index: + patched_args = list(args) + patched_args[query_arg_index] = patched_query + return tuple(patched_args), kwargs_dict + + for name in query_kwarg_names: + if name in kwargs_dict: + kwargs_dict[name] = patched_query + return args, kwargs_dict + return args, kwargs_dict + + +def _eval_duckdb_original_call( + original: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + eval_globals: dict[str, Any], + eval_locals: Mapping[str, Any], + binding_names: _EvalBindingNames, + extra_locals: Mapping[str, Any] | None = None, +) -> Any: + locals_for_eval = dict(eval_locals) + if extra_locals is not None: + locals_for_eval.update(extra_locals) + locals_for_eval[binding_names.original] = inspect.unwrap(original) + locals_for_eval[binding_names.args] = args + locals_for_eval[binding_names.kwargs] = dict(kwargs) + + return eval( + _SQL_CALL_EXPRESSION.format( + original=binding_names.original, + args=binding_names.args, + kwargs=binding_names.kwargs, + ), + eval_globals, 
+ locals_for_eval, + ) + + +def _reserved_namespace_names( + eval_globals: Mapping[str, Any], + eval_locals: Mapping[str, Any], + reserved_names: Sequence[str], +) -> tuple[str, ...]: + names = set(reserved_names) + names.update(eval_globals) + names.update(eval_locals) + return tuple(names) + + +def _identifier_string_args(args: tuple[Any, ...]) -> tuple[str, ...]: + return tuple( + arg for arg in args if isinstance(arg, str) and arg.isidentifier() + ) + + +def _eval_binding_names(reserved_names: Sequence[str]) -> _EvalBindingNames: + used = set(reserved_names) + original = _unused_name(_SQL_HELPER_NAME_BASES[0], used) + used.add(original) + args = _unused_name(_SQL_HELPER_NAME_BASES[1], used) + used.add(args) + kwargs = _unused_name(_SQL_HELPER_NAME_BASES[2], used) + return _EvalBindingNames(original=original, args=args, kwargs=kwargs) + + +def _unused_name(base: str, used: set[str]) -> str: + if base not in used: + return base + + idx = 0 + while True: + candidate = f"{base}_{idx}" + if candidate not in used: + return candidate + idx += 1 + + +def _duckdb_catalog_names( + original: Callable[..., Any], + args: tuple[Any, ...], +) -> tuple[str, ...]: + try: + relation = _show_duckdb_tables(original, args) + rows = relation.fetchall() + except Exception: + return () + return tuple(str(row[0]) for row in rows if row and row[0] is not None) + + +def _show_duckdb_tables( + original: Callable[..., Any], + args: tuple[Any, ...], +) -> Any: + import duckdb + + original_call = inspect.unwrap(original) + if args and isinstance(args[0], duckdb.DuckDBPyConnection): + return inspect.unwrap(type(args[0]).sql)(args[0], "SHOW TABLES") + if original_call is inspect.unwrap(duckdb.query_df): + return inspect.unwrap(duckdb.sql)("SHOW TABLES") + return original_call("SHOW TABLES") + + +def _make_direct_reader_wrapper( + function_name: str, + *, + source_arg_index: int, + connection_arg_index: int | None, +) -> WrapperFactory: + def _wrap(original: Callable[..., Any]) -> Callable[..., 
Any]: + @functools.wraps(original) + def _wrapper(*args: Any, **kwargs: Any) -> Any: + source_info = _direct_reader_source( + function_name, + args, + kwargs, + source_arg_index=source_arg_index, + connection_arg_index=connection_arg_index, + ) + if source_info is None: + return original(*args, **kwargs) + + DependencyManager.pandas.require( + f"to read DuckDB {function_name} sources in WASM" + ) + + source, connection = source_info + import duckdb + + df = source.read_dataframe() + if connection is None: + return duckdb.from_df(df) + return duckdb.from_df(df, connection=connection) + + return _wrapper + + return _wrap + + +def _direct_reader_source( + function_name: str, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + source_arg_index: int, + connection_arg_index: int | None, +) -> tuple[RemoteFileSource, Any] | None: + options = dict(kwargs) + try: + source, rest_args = _pop_source_argument( + function_name, + args, + options, + source_arg_index=source_arg_index, + ) + except TypeError: + return None + if rest_args: + return None + + if connection_arg_index is None: + connection = options.pop("connection", None) + else: + connection = args[connection_arg_index] + source_info = remote_file_source_from_reader_args( + function_name, source, options + ) + if source_info is None: + return None + return source_info, connection + + +def _pop_source_argument( + function_name: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + source_arg_index: int, +) -> tuple[Any, tuple[Any, ...]]: + if len(args) > source_arg_index: + return args[source_arg_index], args[source_arg_index + 1 :] + + for key in _SOURCE_KWARGS[function_name]: + if key in kwargs: + return kwargs.pop(key), args[source_arg_index + 1 :] + + raise TypeError(f"Missing source argument for duckdb.{function_name}") + + +def _require_pandas() -> None: + DependencyManager.pandas.require( + "to read remote DuckDB file sources in WASM" + ) + + +def _require_sqlglot() -> None: + 
DependencyManager.sqlglot.require( + "to rewrite remote DuckDB SQL sources in WASM" + ) + + +def _parse_duckdb_query(query: str) -> list[exp.Expression] | None: + import sqlglot + + try: + parsed = sqlglot.parse(query, read="duckdb") + except Exception as e: + LOGGER.debug("Failed to parse DuckDB query for WASM patch: %s", e) + return None + + return [statement for statement in parsed if statement is not None] + + +def _replace_remote_sources( + statements: Sequence[exp.Expression], + table_names: _RemoteTableNames, +) -> list[exp.Expression]: + from sqlglot import exp + + def replace_table(node: exp.Expression) -> exp.Expression: + if not isinstance(node, exp.Table): + return node + + source = remote_file_source_from_table(node) + if source is None: + return node + + replacement = exp.Table( + this=exp.to_identifier(table_names.name_for(source)) + ) + alias = node.args.get("alias") + if alias is not None: + replacement.set("alias", alias.copy()) + return replacement + + return [ + statement.transform(replace_table, copy=True) + for statement in statements + ] + + +def _reserved_sql_names( + statements: Sequence[exp.Expression], +) -> tuple[str, ...]: + from sqlglot import exp + + names: set[str] = set() + for statement in statements: + for identifier in statement.find_all(exp.Identifier): + if isinstance(identifier.this, str): + names.add(identifier.this) + for table in statement.find_all(exp.Table): + if table.name: + names.add(table.name) + return tuple(names) + + +def _format_duckdb_query( + statements: Sequence[exp.Expression], *, original_query: str +) -> str: + patched_query = "; ".join( + statement.sql(dialect="duckdb") for statement in statements + ) + if original_query.rstrip().endswith(";"): + patched_query += ";" + return patched_query diff --git a/marimo/_runtime/_wasm/_duckdb/dataframe.py b/marimo/_runtime/_wasm/_duckdb/dataframe.py new file mode 100644 index 00000000000..528bb42fd12 --- /dev/null +++ b/marimo/_runtime/_wasm/_duckdb/dataframe.py @@ 
-0,0 +1,165 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Decode fetched DuckDB file bytes into pandas DataFrames.""" + +from __future__ import annotations + +import os +import tempfile +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + import pandas as pd + +_CSV_SUFFIXES = (".csv.gz", ".tsv.gz", ".csv", ".tsv") +_JSON_SUFFIXES = ( + ".geojson.gz", + ".ndjson.gz", + ".jsonl.gz", + ".json.gz", + ".geojson", + ".ndjson", + ".jsonl", + ".json", +) + + +def read_csv_dataframe( + data: bytes, options: Mapping[str, Any], *, url: str +) -> pd.DataFrame: + import duckdb + + return _read_temp_dataframe( + data, + suffix=_temp_suffix( + url, + suffixes=_CSV_SUFFIXES, + default=".csv", + ), + reader=lambda path: duckdb.from_csv_auto(path, **options).df(), + ) + + +def read_parquet_dataframe(data: bytes) -> pd.DataFrame: + import duckdb + + return _read_temp_dataframe( + data, + suffix=".parquet", + reader=lambda path: duckdb.read_parquet(path).df(), + ) + + +def read_json_dataframe( + data: bytes, options: Mapping[str, Any], *, url: str +) -> pd.DataFrame: + import duckdb + + return _read_temp_dataframe( + data, + suffix=_temp_suffix( + url, + suffixes=_JSON_SUFFIXES, + default=".json", + ), + reader=lambda path: duckdb.read_json(path, **options).df(), + ) + + +def read_json_objects_dataframe( + data: bytes, options: Mapping[str, Any], *, url: str +) -> pd.DataFrame: + return _read_temp_dataframe( + data, + suffix=_temp_suffix( + url, + suffixes=_JSON_SUFFIXES, + default=".json", + ), + reader=lambda path: _read_json_objects_path(path, options), + ) + + +def _read_json_objects_path( + path: str, options: Mapping[str, Any] +) -> pd.DataFrame: + import duckdb + + # DuckDB exposes JSON-object readers as SQL table functions, not Python + # module methods, so invoke the table function with bound parameters. 
+ option_items = tuple(options.items()) + query_args = ["?"] + query_args.extend(f"{key} := ?" for key, _ in option_items) + return duckdb.sql( + f"SELECT * FROM read_json_objects_auto({', '.join(query_args)})", + params=[path, *(value for _, value in option_items)], + ).df() + + +def read_text_dataframe(data: bytes, url: str) -> pd.DataFrame: + import pandas as pd + + return pd.DataFrame( + { + "filename": [url], + "content": [data.decode("utf-8")], + "size": [len(data)], + "last_modified": [pd.NaT], + } + ) + + +def read_blob_dataframe(data: bytes, url: str) -> pd.DataFrame: + import pandas as pd + + return pd.DataFrame( + { + "filename": [url], + "content": [data], + "size": [len(data)], + "last_modified": [pd.NaT], + } + ) + + +def append_filename_column( + df: pd.DataFrame, url: str, column_name: str +) -> pd.DataFrame: + if column_name in df.columns: + raise ValueError( + f'Option filename adds column "{column_name}", but a column with this ' + "name is also in the file" + ) + + df = df.copy() + df[column_name] = url + return df + + +def _read_temp_dataframe( + data: bytes, + *, + suffix: str, + reader: Callable[[str], pd.DataFrame], +) -> pd.DataFrame: + """Materialize fetched bytes so DuckDB can parse them from a local path.""" + fd, path = tempfile.mkstemp(suffix=suffix) + try: + with os.fdopen(fd, "wb") as file: + file.write(data) + return reader(path) + finally: + try: + os.unlink(path) + except FileNotFoundError: + pass + + +def _temp_suffix(url: str, *, suffixes: tuple[str, ...], default: str) -> str: + path = urlparse(url).path.lower() + for suffix in suffixes: + if path.endswith(suffix): + return suffix + return default diff --git a/marimo/_runtime/_wasm/_duckdb/io.py b/marimo/_runtime/_wasm/_duckdb/io.py new file mode 100644 index 00000000000..47c049cced9 --- /dev/null +++ b/marimo/_runtime/_wasm/_duckdb/io.py @@ -0,0 +1,493 @@ +# Copyright 2026 Marimo. All rights reserved. 
+"""Map DuckDB remote files to readers.""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol +from urllib.parse import urlparse + +from marimo._runtime._wasm import _fetch +from marimo._runtime._wasm._duckdb.dataframe import ( + append_filename_column, + read_blob_dataframe, + read_csv_dataframe, + read_json_dataframe, + read_json_objects_dataframe, + read_parquet_dataframe, + read_text_dataframe, +) + +if TYPE_CHECKING: + from collections.abc import Mapping + + import pandas as pd + + +_ReaderName = str + + +@dataclass(frozen=True) +class _FetchedBytes: + url: str + data: bytes + + +@dataclass(frozen=True) +class RemoteFile: + url: str + fetcher_name: str + + def fetch(self) -> _FetchedBytes: + return _fetcher_by_name(self.fetcher_name).fetch(self.url) + + +@dataclass(frozen=True) +class _ReadRequest: + file: _FetchedBytes + options: Mapping[str, Any] + + +class _ByteFetcher(Protocol): + name: str + + def can_fetch(self, url: str) -> bool: ... + + def fetch(self, url: str) -> _FetchedBytes: ... + + +class _DataFrameReader(Protocol): + name: str + direct_extensions: tuple[str, ...] + function_names: tuple[str, ...] + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: ... + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: ... + + +class _HttpFetcher: + name: str = "http" + + def can_fetch(self, url: str) -> bool: + return urlparse(url).scheme in {"http", "https"} + + def fetch(self, url: str) -> _FetchedBytes: + return _FetchedBytes(url=url, data=_fetch.fetch_url_bytes(url)) + + +class _CsvReader: + name = "csv" + direct_extensions: tuple[str, ...] = ( + ".csv", + ".csv.gz", + ".tsv", + ".tsv.gz", + ) + function_names: tuple[str, ...] 
= ("read_csv", "read_csv_auto") + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + del function_name + options: dict[str, Any] = {} + for key, value in raw_options.items(): + if key in {"delim", "delimiter", "sep", "separator"}: + options["delimiter"] = _normalize_delimiter(value) + elif key == "header": + options["header"] = value + elif key in {"auto_detect", "normalize_names"}: + options[key] = value + elif _apply_shared_source_option(options, key, value): + pass + else: + return None + return options + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_csv_dataframe( + request.file.data, + _csv_reader_options(request.options), + url=request.file.url, + ) + + +class _ParquetReader: + name = "parquet" + direct_extensions: tuple[str, ...] = (".parquet", ".parq") + function_names: tuple[str, ...] = ("read_parquet", "parquet_scan") + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + del function_name + options: dict[str, Any] = {} + for key, value in raw_options.items(): + if _apply_common_table_option(options, key, value): + pass + else: + return None + return options + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_parquet_dataframe(request.file.data) + + +class _JsonReader: + name = "json" + direct_extensions: tuple[str, ...] = ( + ".json", + ".json.gz", + ".geojson", + ".geojson.gz", + ".jsonl", + ".jsonl.gz", + ".ndjson", + ".ndjson.gz", + ) + function_names: tuple[str, ...] 
= ( + "read_json", + "read_json_auto", + "read_ndjson", + ) + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + options: dict[str, Any] = {} + if function_name == "read_ndjson": + options["format"] = "newline_delimited" + + for key, value in raw_options.items(): + if not _apply_json_option(options, key, value): + return None + return options + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_json_dataframe( + request.file.data, + _json_reader_options(request.options), + url=request.file.url, + ) + + +class _JsonObjectsReader: + name = "json_objects" + direct_extensions: tuple[str, ...] = () + function_names: tuple[str, ...] = ( + "read_json_objects", + "read_json_objects_auto", + "read_ndjson_objects", + ) + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + options: dict[str, Any] = {} + if function_name == "read_ndjson_objects": + options["format"] = "newline_delimited" + + for key, value in raw_options.items(): + if not _apply_json_option(options, key, value): + return None + return options + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_json_objects_dataframe( + request.file.data, + _json_reader_options(request.options), + url=request.file.url, + ) + + +class _TextReader: + name = "text" + direct_extensions: tuple[str, ...] = () + function_names: tuple[str, ...] = ("read_text",) + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + del function_name + return {} if not raw_options else None + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_text_dataframe(request.file.data, request.file.url) + + +class _BlobReader: + name = "blob" + direct_extensions: tuple[str, ...] = () + function_names: tuple[str, ...] 
= ("read_blob",) + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + del function_name + return {} if not raw_options else None + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_blob_dataframe(request.file.data, request.file.url) + + +_FETCHERS: tuple[_ByteFetcher, ...] = (_HttpFetcher(),) +_READERS: tuple[_DataFrameReader, ...] = ( + _CsvReader(), + _ParquetReader(), + _JsonReader(), + _JsonObjectsReader(), + _TextReader(), + _BlobReader(), +) + + +@dataclass(frozen=True) +class RemoteFileSource: + files: tuple[RemoteFile, ...] + reader_name: _ReaderName + options: tuple[tuple[str, Any], ...] = () + + def read_options(self) -> dict[str, Any]: + return dict(self.options) + + def read_dataframe(self) -> pd.DataFrame: + frames = [self._read_file_dataframe(file) for file in self.files] + if len(frames) == 1: + return frames[0] + + import pandas as pd + + if self.read_options().get("union_by_name") is True: + return pd.concat(frames, ignore_index=True, sort=False) + + columns = list(frames[0].columns) + for frame in frames[1:]: + if list(frame.columns) != columns: + raise ValueError( + "DuckDB WASM remote sources must have matching columns " + "unless union_by_name=True" + ) + return pd.concat(frames, ignore_index=True) + + def _read_file_dataframe(self, file: RemoteFile) -> pd.DataFrame: + fetched = file.fetch() + options = self.read_options() + reader = _reader_by_name(self.reader_name) + df = reader.read_dataframe(_ReadRequest(file=fetched, options=options)) + filename = options.get("filename") + if filename is True: + return append_filename_column(df, fetched.url, "filename") + if isinstance(filename, str): + return append_filename_column(df, fetched.url, filename) + return df + + +def remote_file_from_url(url: str) -> RemoteFile | None: + fetcher = _fetcher_for_url(url) + if fetcher is None: + return None + return RemoteFile(url=url, fetcher_name=fetcher.name) + + 
+def remote_file_source_from_reader_args( + function_name: str, + source: Any, + raw_options: Mapping[str, Any], +) -> RemoteFileSource | None: + reader = reader_for_function(function_name) + if reader is None: + return None + + files = _remote_files_from_source_arg(source) + if files is None: + return None + + options = reader.read_options(function_name, raw_options) + if options is None: + return None + return RemoteFileSource(files, reader.name, tuple(sorted(options.items()))) + + +def _remote_files_from_source_arg( + source: Any, +) -> tuple[RemoteFile, ...] | None: + if isinstance(source, str): + file = remote_file_from_url(source) + return (file,) if file is not None else None + + if isinstance(source, Sequence) and not isinstance( + source, bytes | bytearray + ): + files: list[RemoteFile] = [] + for item in source: + if not isinstance(item, str): + return None + file = remote_file_from_url(item) + if file is None: + return None + files.append(file) + return tuple(files) if files else None + + return None + + +def _fetcher_for_url(url: str) -> _ByteFetcher | None: + return next( + (fetcher for fetcher in _FETCHERS if fetcher.can_fetch(url)), + None, + ) + + +def _fetcher_by_name(name: str) -> _ByteFetcher: + for fetcher in _FETCHERS: + if fetcher.name == name: + return fetcher + raise KeyError(f"Unknown DuckDB WASM fetcher: {name}") + + +def _reader_by_name( + name: _ReaderName, +) -> _DataFrameReader: + for reader in _READERS: + if reader.name == name: + return reader + raise KeyError(f"Unknown DuckDB WASM reader: {name}") + + +def reader_for_url(url: str) -> _DataFrameReader | None: + path = urlparse(url).path.lower() + return next( + ( + reader + for reader in _READERS + if path.endswith(reader.direct_extensions) + ), + None, + ) + + +def reader_for_function( + function_name: str, +) -> _DataFrameReader | None: + return next( + ( + reader + for reader in _READERS + if function_name in reader.function_names + ), + None, + ) + + +def 
_csv_reader_options(options: Mapping[str, Any]) -> dict[str, Any]: + return { + key: value + for key, value in options.items() + if key not in {"filename", "union_by_name"} + } + + +def _json_reader_options(options: Mapping[str, Any]) -> dict[str, Any]: + return { + key: _normalize_json_reader_option(key, value) + for key, value in options.items() + if key not in {"filename", "union_by_name"} + } + + +def _normalize_delimiter(value: Any) -> Any: + if value == r"\t": + return "\t" + return value + + +def _normalize_json_reader_option(key: str, value: Any) -> Any: + if key == "compression": + return _normalize_json_compression(value) + if key == "format": + return _normalize_json_format(value) + return value + + +def _normalize_json_compression(value: Any) -> str: + compression = str(value).lower() + if compression == "auto": + return "auto_detect" + if compression == "none": + return "uncompressed" + return compression + + +def _normalize_json_format(value: Any) -> str: + fmt = str(value).lower() + if fmt == "ndjson": + return "newline_delimited" + if fmt == "array_of_objects": + return "array" + return fmt + + +def _apply_json_option(options: dict[str, Any], key: str, value: Any) -> bool: + if key == "format": + options["format"] = _normalize_json_format(value) + return True + if _apply_shared_source_option(options, key, value): + return True + if _is_safe_reader_option_name(key): + options[key] = value + return True + return False + + +def _is_safe_reader_option_name(key: str) -> bool: + return key.isidentifier() + + +def _apply_common_table_option( + options: dict[str, Any], key: str, value: Any +) -> bool: + if key == "filename" and isinstance(value, bool | str): + options["filename"] = value + return True + if key == "union_by_name" and isinstance(value, bool): + options["union_by_name"] = value + return True + return False + + +def _apply_compression_option( + options: dict[str, Any], key: str, value: Any +) -> bool: + if key == "compression" and 
_is_supported_compression(value): + options["compression"] = str(value).lower() + return True + return False + + +def _apply_shared_source_option( + options: dict[str, Any], key: str, value: Any +) -> bool: + return _apply_compression_option( + options, key, value + ) or _apply_common_table_option(options, key, value) + + +def _is_supported_compression(value: Any) -> bool: + return str(value).lower() in {"auto", "none", "gzip"} diff --git a/marimo/_runtime/_wasm/_duckdb/sources.py b/marimo/_runtime/_wasm/_duckdb/sources.py new file mode 100644 index 00000000000..d8337f664f0 --- /dev/null +++ b/marimo/_runtime/_wasm/_duckdb/sources.py @@ -0,0 +1,160 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Resolve sqlglot table nodes to remote file sources. + +Handles direct URL tables and reader calls such as +``read_csv('https://example.com/cars.csv')``. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from marimo._runtime._wasm._duckdb.io import ( + RemoteFileSource, + reader_for_url, + remote_file_from_url, + remote_file_source_from_reader_args, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from sqlglot import exp + +# SQL options can parse to falsy values such as false, 0, or ""; this marks +# unsupported expressions without conflating them with valid literal values. 
+_MISSING = object() + + +def remote_file_source_from_table( + table: exp.Table, +) -> RemoteFileSource | None: + table_name = table.name + if table_name: + reader = reader_for_url(table_name) + remote_file = remote_file_from_url(table_name) + if reader is not None and remote_file is not None: + return RemoteFileSource((remote_file,), reader.name) + + table_expr = table.this + if table_expr is None: + return None + + function_name = _table_function_name(table_expr) + if function_name is None: + return None + + args = _table_function_args(table_expr) + source = _read_function_source(args) + if source is None: + return None + + raw_options = _read_function_options(args[1:]) + if raw_options is None: + return None + return remote_file_source_from_reader_args( + function_name, source, raw_options + ) + + +def _table_function_name(table_expr: exp.Expression) -> str | None: + import sqlglot.expressions as exp + + # sqlglot versions model first-party DuckDB readers either as explicit + # Read* nodes or as generic anonymous table functions. + read_csv = getattr(exp, "ReadCSV", None) + if read_csv is not None and isinstance(table_expr, read_csv): + return "read_csv" + + read_parquet = getattr(exp, "ReadParquet", None) + if read_parquet is not None and isinstance(table_expr, read_parquet): + return "read_parquet" + + if isinstance(table_expr, exp.Anonymous): + return str(table_expr.this).lower() + return None + + +def _table_function_args(table_expr: exp.Expression) -> list[exp.Expression]: + import sqlglot.expressions as exp + + read_csv = getattr(exp, "ReadCSV", None) + if read_csv is not None and isinstance(table_expr, read_csv): + first = [table_expr.this] if table_expr.this is not None else [] + return [*first, *table_expr.expressions] + return list(table_expr.expressions) + + +def _read_function_source( + args: Sequence[exp.Expression], +) -> str | tuple[str, ...] 
| None: + import sqlglot.expressions as exp + + if not args: + return None + source_expr = args[0] + if isinstance(source_expr, exp.Literal) and source_expr.is_string: + return str(source_expr.this) + + if isinstance(source_expr, exp.Array) and source_expr.expressions: + urls: list[str] = [] + for item_expr in source_expr.expressions: + if not ( + isinstance(item_expr, exp.Literal) and item_expr.is_string + ): + return None + urls.append(str(item_expr.this)) + return tuple(urls) + + return None + + +def _read_function_options( + option_exprs: Sequence[exp.Expression], +) -> dict[str, Any] | None: + options: dict[str, Any] = {} + for option_expr in option_exprs: + option = _read_function_option(option_expr) + if option is None: + return None + key, value = option + options[key] = value + return options + + +def _read_function_option( + option_expr: exp.Expression, +) -> tuple[str, Any] | None: + import sqlglot.expressions as exp + + property_eq = getattr(exp, "PropertyEQ", None) + option_classes = ( + (exp.EQ,) if property_eq is None else (exp.EQ, property_eq) + ) + if not isinstance(option_expr, option_classes): + return None + + value_expr = option_expr.args.get("expression") + if value_expr is None: + return None + + value = _literal_value(value_expr) + if value is _MISSING: + return None + + key = getattr(option_expr.this, "name", None) + if key is None: + return None + return key.lower(), value + + +def _literal_value(value_expr: exp.Expression) -> Any: + import sqlglot.expressions as exp + + if isinstance(value_expr, exp.Boolean): + return value_expr.this + if isinstance(value_expr, exp.Literal): + if value_expr.is_string: + return value_expr.this + return value_expr.to_py() + return _MISSING From 9f030933d349423cc308dc4d696806ebd5f34631 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Fri, 8 May 2026 13:41:09 +0200 Subject: [PATCH 03/22] feat: wire DuckDB WASM SQL through marimo --- frontend/src/core/islands/worker/worker.tsx | 
2 +- frontend/src/core/wasm/worker/bootstrap.ts | 2 +- frontend/src/core/wasm/worker/worker.ts | 2 +- marimo/_output/formatters/df_formatters.py | 12 ++ marimo/_output/formatters/formatters.py | 2 + marimo/_sql/utils.py | 126 +++++++++++++++++--- 6 files changed, 129 insertions(+), 17 deletions(-) diff --git a/frontend/src/core/islands/worker/worker.tsx b/frontend/src/core/islands/worker/worker.tsx index 6d4f4f04f29..4423f6c0c5b 100644 --- a/frontend/src/core/islands/worker/worker.tsx +++ b/frontend/src/core/islands/worker/worker.tsx @@ -85,7 +85,7 @@ const requestHandler = createRPCRequestHandler({ loadPackages: async (code: string) => { await pyodideReadyPromise; // Make sure loading is done - if (code.includes("mo.sql")) { + if (code.includes("mo.sql") || code.includes("duckdb")) { // Add pandas and duckdb to the code code = `import pandas\n${code}`; code = `import duckdb\n${code}`; diff --git a/frontend/src/core/wasm/worker/bootstrap.ts b/frontend/src/core/wasm/worker/bootstrap.ts index 25dd7dd3d0b..d32b34563c7 100644 --- a/frontend/src/core/wasm/worker/bootstrap.ts +++ b/frontend/src/core/wasm/worker/bootstrap.ts @@ -163,7 +163,7 @@ export class DefaultWasmController implements WasmController { private async loadNotebookDeps(code: string, foundPackages: Set) { const pyodide = this.requirePyodide; - if (code.includes("mo.sql")) { + if (code.includes("mo.sql") || code.includes("duckdb")) { // We need pandas and duckdb for mo.sql code = `import pandas\n${code}`; code = `import duckdb\n${code}`; diff --git a/frontend/src/core/wasm/worker/worker.ts b/frontend/src/core/wasm/worker/worker.ts index e7890c9d0bd..6e2c9a98fbd 100644 --- a/frontend/src/core/wasm/worker/worker.ts +++ b/frontend/src/core/wasm/worker/worker.ts @@ -141,7 +141,7 @@ const requestHandler = createRPCRequestHandler({ const span = t.startSpan("loadPackages"); await pyodideReadyPromise; // Make sure loading is done - if (code.includes("mo.sql")) { + if (code.includes("mo.sql") || 
code.includes("duckdb")) { // Add pandas and duckdb to the code code = `import pandas\n${code}`; code = `import duckdb\n${code}`; diff --git a/marimo/_output/formatters/df_formatters.py b/marimo/_output/formatters/df_formatters.py index 71a17705cce..f7ebe708fe2 100644 --- a/marimo/_output/formatters/df_formatters.py +++ b/marimo/_output/formatters/df_formatters.py @@ -19,6 +19,7 @@ from marimo._plugins.stateless.plain_text import plain_text from marimo._plugins.ui._impl import tabs from marimo._plugins.ui._impl.table import get_default_table_page_size, table +from marimo._runtime._wasm._duckdb import patch_duckdb_for_wasm from marimo._runtime._wasm._polars import patch_polars_for_wasm LOGGER = _loggers.marimo_logger() @@ -160,6 +161,17 @@ def _show_marimo_dataframe( return table(df, selection=None, pagination=None)._mime_() +class DuckDBFormatter(FormatterFactory): + """Use DuckDB's lazy import hook to install WASM runtime patches.""" + + @staticmethod + def package_name() -> str: + return "duckdb" + + def register(self) -> Unregister: + return patch_duckdb_for_wasm() + + class DataFusionFormatter(FormatterFactory): @staticmethod def package_name() -> str: diff --git a/marimo/_output/formatters/formatters.py b/marimo/_output/formatters/formatters.py index 5a1ae706e56..ee375220989 100644 --- a/marimo/_output/formatters/formatters.py +++ b/marimo/_output/formatters/formatters.py @@ -18,6 +18,7 @@ from marimo._output.formatters.cell import CellFormatter from marimo._output.formatters.df_formatters import ( DataFusionFormatter, + DuckDBFormatter, IbisFormatter, PolarsFormatter, PyArrowFormatter, @@ -54,6 +55,7 @@ AltairFormatter.package_name(): AltairFormatter(), MatplotlibFormatter.package_name(): MatplotlibFormatter(), DataFusionFormatter.package_name(): DataFusionFormatter(), + DuckDBFormatter.package_name(): DuckDBFormatter(), IbisFormatter.package_name(): IbisFormatter(), PandasFormatter.package_name(): PandasFormatter(), PolarsFormatter.package_name(): 
PolarsFormatter(), diff --git a/marimo/_sql/utils.py b/marimo/_sql/utils.py index 4d17982906b..4e5f79a587e 100644 --- a/marimo/_sql/utils.py +++ b/marimo/_sql/utils.py @@ -8,6 +8,10 @@ from marimo._config.config import SqlOutputType from marimo._data.models import DataType from marimo._dependencies.dependencies import DependencyManager +from marimo._runtime._wasm._duckdb import ( + WasmDuckDBSqlResult, + try_run_duckdb_sql_with_wasm_patch, +) from marimo._runtime.context.types import ( ContextNotInitializedError, get_context, @@ -26,10 +30,41 @@ CHEAP_DISCOVERY_DATABASES = ["duckdb", "sqlite", "mysql", "postgresql"] +def _try_wasm_duckdb_sql( + query: str, + connection: Any, + glbls: dict[str, Any], + *, + reserved_names: tuple[str, ...] = (), +) -> WasmDuckDBSqlResult | None: + import duckdb + + if connection is duckdb: + original = duckdb.sql + args: tuple[Any, ...] = (query,) + query_arg_index = 0 + else: + original = type(connection).sql + args = (connection, query) + query_arg_index = 1 + + result = try_run_duckdb_sql_with_wasm_patch( + original, + args, + {}, + query_arg_index=query_arg_index, + query_kwarg_names=("query",), + eval_globals=glbls, + eval_locals=glbls, + reserved_names=reserved_names, + ) + return result + + def wrapped_sql( query: str, connection: duckdb.DuckDBPyConnection | None, -) -> duckdb.DuckDBPyRelation: +) -> duckdb.DuckDBPyRelation | None: DependencyManager.duckdb.require("to execute sql") # In Python globals() are scoped to modules; since this function @@ -46,7 +81,13 @@ def wrapped_sql( try: ctx = get_context() except ContextNotInitializedError: - relation = connection.sql(query=query) + relation: duckdb.DuckDBPyRelation | None + result = _try_wasm_duckdb_sql(query, connection, globals()) + if result is None: + # No WASM rewrite was needed; use DuckDB's normal SQL path. 
+ relation = connection.sql(query=query) + else: + relation = cast(duckdb.DuckDBPyRelation | None, result.value) else: install_connection = ( ctx.execution_context.with_connection @@ -54,15 +95,57 @@ def wrapped_sql( else nullcontext ) with install_connection(connection): - relation = eval( - "connection.sql(query=query)", + result = _try_wasm_duckdb_sql( + query, + connection, ctx.globals, - {"query": query, "connection": connection}, + reserved_names=tuple(ctx.globals), ) + if result is None: + # Run in kernel globals so DuckDB replacement scans see user data. + relation = eval( + "connection.sql(query=query)", + ctx.globals, + {"query": query, "connection": connection}, + ) + else: + relation = cast(duckdb.DuckDBPyRelation | None, result.value) return relation +def _try_wasm_duckdb_execute( + query: str, + params: list[Any], + connection: Any, + glbls: dict[str, Any], + *, + reserved_names: tuple[str, ...] = (), +) -> WasmDuckDBSqlResult | None: + import duckdb + + if connection is duckdb: + original = duckdb.execute + args: tuple[Any, ...] = (query, params) + query_arg_index = 0 + else: + original = type(connection).execute + args = (connection, query, params) + query_arg_index = 1 + + result = try_run_duckdb_sql_with_wasm_patch( + original, + args, + {}, + query_arg_index=query_arg_index, + query_kwarg_names=("query",), + eval_globals=glbls, + eval_locals=glbls, + reserved_names=reserved_names, + ) + return result + + def execute_duckdb_sql( query: str, params: list[Any], @@ -83,7 +166,11 @@ def execute_duckdb_sql( try: ctx = get_context() except ContextNotInitializedError: - return connection.execute(query, params) + result = _try_wasm_duckdb_execute(query, params, connection, globals()) + if result is None: + # No WASM rewrite was needed; preserve DuckDB's parameterized path. 
+ return connection.execute(query, params) + return cast(duckdb.DuckDBPyConnection, result.value) else: install_connection = ( ctx.execution_context.with_connection @@ -91,16 +178,27 @@ def execute_duckdb_sql( else nullcontext ) with install_connection(connection): - result: duckdb.DuckDBPyConnection = eval( - "connection.execute(query, params)", + result = _try_wasm_duckdb_execute( + query, + params, + connection, ctx.globals, - { - "query": query, - "params": params, - "connection": connection, - }, + reserved_names=tuple(ctx.globals), ) - return result + if result is None: + # Run in kernel globals so parameterized SQL can scan user data. + value = eval( + "connection.execute(query, params)", + ctx.globals, + { + "query": query, + "params": params, + "connection": connection, + }, + ) + else: + value = result.value + return cast(duckdb.DuckDBPyConnection, value) def try_convert_to_polars( From 35499881d2a4a5abca099f7994ca8a358058925e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Fri, 8 May 2026 13:41:22 +0200 Subject: [PATCH 04/22] test: cover DuckDB WASM parity and runtime integration --- tests/_runtime/test_duckdb_wasm.py | 1196 ++++++++++++++++++++++++++++ 1 file changed, 1196 insertions(+) create mode 100644 tests/_runtime/test_duckdb_wasm.py diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py new file mode 100644 index 00000000000..e8ff9b158d1 --- /dev/null +++ b/tests/_runtime/test_duckdb_wasm.py @@ -0,0 +1,1196 @@ +# Copyright 2026 Marimo. All rights reserved. 
+from __future__ import annotations + +import gzip +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest + +from marimo._runtime._wasm._duckdb import ( + patch_duckdb_for_wasm, + patch_duckdb_query_for_wasm, +) +from marimo._sql.engines.duckdb import DuckDBEngine +from tests.conftest import ExecReqProvider, mock_pyodide + +pytestmark = pytest.mark.requires("duckdb", "pandas", "sqlglot") + +pytest.importorskip("duckdb") +pytest.importorskip("pandas") +pytest.importorskip("sqlglot") + +if TYPE_CHECKING: + from collections.abc import Sequence + + from marimo._runtime.runtime import Kernel + + +def _normalize_value(value: object) -> object: + import pandas as pd + + if hasattr(value, "tolist") and not isinstance( + value, dict | list | str | bytes | bytearray + ): + return value.tolist() + if isinstance(value, list): + return [_normalize_value(item) for item in value] + if isinstance(value, dict): + return {key: _normalize_value(item) for key, item in value.items()} + if isinstance(value, bytes | bytearray): + return bytes(value) + try: + return None if bool(pd.isna(value)) else value + except (TypeError, ValueError): + return value + + +def _records(df: object) -> list[dict[str, object]]: + return [ + {key: _normalize_value(value) for key, value in row.items()} + for row in df.to_dict("records") # type: ignore[attr-defined] + ] + + +def _rows(rows: Sequence[Sequence[object]]) -> list[tuple[object, ...]]: + return [tuple(_normalize_value(value) for value in row) for row in rows] + + +@dataclass(frozen=True) +class RemoteFixture: + url: str + suffix: str + data: bytes + + +@dataclass(frozen=True) +class QueryParityCase: + name: str + query: str + fixtures: tuple[RemoteFixture, ...] 
+
+
+@dataclass(frozen=True)
+class DirectReadParityCase:
+    name: str
+    function_name: str
+    fixture: RemoteFixture
+    source_kwarg: str | None = None
+    options: tuple[tuple[str, object], ...] = ()
+
+
+def _parquet_bytes(sql: str) -> bytes:
+    import duckdb
+
+    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as file:
+        path = Path(file.name)
+    try:
+        duckdb.sql(sql).write_parquet(str(path))
+        return path.read_bytes()
+    finally:
+        path.unlink(missing_ok=True)
+
+
+def _local_fixture_path(
+    fixture: RemoteFixture, tmp_path: Path, filename: str
+) -> str:
+    path = tmp_path / f"{filename}{fixture.suffix}"
+    path.write_bytes(fixture.data)
+    return path.as_posix()
+
+
+def _local_query(
+    remote_query: str,
+    fixtures: Sequence[RemoteFixture],
+    tmp_path: Path,
+) -> str:
+    query = remote_query
+    for idx, fixture in enumerate(fixtures):
+        query = query.replace(
+            fixture.url,
+            _local_fixture_path(fixture, tmp_path, f"remote_{idx}"),
+        )
+    return query
+
+
+def _native_rows(query: str) -> list[tuple[object, ...]]:
+    import duckdb
+
+    connection = duckdb.connect(":memory:")
+    try:
+        return _rows(connection.sql(query).fetchall())
+    finally:
+        connection.close()
+
+
+def _patched_rows(
+    query: str,
+    fixtures: Sequence[RemoteFixture],
+) -> tuple[list[tuple[object, ...]], list[str]]:
+    import duckdb
+
+    fixtures_by_url = {fixture.url: fixture for fixture in fixtures}
+
+    def fetch(url: str) -> bytes:
+        return fixtures_by_url[url].data
+
+    with (
+        mock_pyodide(),
+        patch(
+            "marimo._runtime._wasm._fetch.fetch_url_bytes",
+            side_effect=fetch,
+        ) as fetch_url_bytes,
+    ):
+        patch_result = patch_duckdb_query_for_wasm(query)
+
+    assert patch_result is not None
+    assert all(fixture.url not in patch_result.query for fixture in fixtures)
+
+    connection = duckdb.connect(":memory:")
+    try:
+        for table_name, df in patch_result.tables.items():
+            connection.register(table_name, df)
+        rows = _rows(connection.sql(patch_result.query).fetchall())
+    finally:
+        
connection.close() + + fetched_urls = [call.args[0] for call in fetch_url_bytes.call_args_list] + return rows, fetched_urls + + +def _direct_reader_args( + case: DirectReadParityCase, source: str +) -> tuple[tuple[object, ...], dict[str, object]]: + kwargs = dict(case.options) + if case.source_kwarg is None: + return (source,), kwargs + kwargs[case.source_kwarg] = source + return (), kwargs + + +def _run_direct_reader( + case: DirectReadParityCase, + source: str, + *, + api_kind: str, +) -> list[tuple[object, ...]]: + import duckdb + + args, kwargs = _direct_reader_args(case, source) + if api_kind == "module": + relation = getattr(duckdb, case.function_name)(*args, **kwargs) + return _rows(relation.fetchall()) + + connection = duckdb.connect(":memory:") + try: + if api_kind == "connection": + relation = getattr(connection, case.function_name)(*args, **kwargs) + elif api_kind == "module-connection-kw": + kwargs["connection"] = connection + relation = getattr(duckdb, case.function_name)(*args, **kwargs) + else: + raise ValueError(f"Unknown DuckDB direct reader API: {api_kind}") + return _rows(relation.fetchall()) + finally: + connection.close() + + +def _patched_direct_rows( + case: DirectReadParityCase, + *, + api_kind: str, +) -> tuple[list[tuple[object, ...]], list[str]]: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=case.fixture.data, + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + rows = _run_direct_reader( + case, case.fixture.url, api_kind=api_kind + ) + finally: + unpatch() + + fetched_urls = [call.args[0] for call in fetch_url_bytes.call_args_list] + return rows, fetched_urls + + +def _direct_reader_parity_cases() -> list[DirectReadParityCase]: + return [ + DirectReadParityCase( + "csv-positional", + "read_csv", + RemoteFixture( + "https://datasets.marimo.app/cars.csv", + ".csv", + b"1;ford\n2;toyota\n", + ), + options=(("delimiter", ";"), ("header", False)), + ), + 
DirectReadParityCase( + "parquet-file-glob", + "read_parquet", + RemoteFixture( + "https://datasets.marimo.app/cars.parquet", + ".parquet", + _parquet_bytes("SELECT 'ford' AS make"), + ), + source_kwarg="file_glob", + ), + DirectReadParityCase( + "json-path-or-buffer", + "read_json", + RemoteFixture( + "https://datasets.marimo.app/cars.json", + ".json", + b'[{"make":"ford"},{"make":"toyota"}]', + ), + source_kwarg="path_or_buffer", + options=(("format", "array"),), + ), + ] + + +def _query_parity_cases() -> list[QueryParityCase]: + csv = RemoteFixture( + "https://example.com/cars.csv", + ".csv", + b"make,mpg\nford,25\ntoyota,18\n", + ) + csv_semicolon = RemoteFixture( + "https://example.com/cars-semicolon.csv", + ".csv", + b"1;ford\n2;toyota\n", + ) + csv_gzip = RemoteFixture( + "https://example.com/cars.csv.gz", + ".csv.gz", + gzip.compress(b"make,mpg\nford,25\n"), + ) + csv_download = RemoteFixture( + "https://example.com/download", + "", + b"make,mpg\nford,25\n", + ) + csv_normalize = RemoteFixture( + "https://example.com/names.csv", + ".csv", + b"make name,mpg\nford,25\n", + ) + tsv = RemoteFixture( + "https://example.com/walmarts.tsv", + ".tsv", + b"longitude\tlatitude\n1\t2\n", + ) + csv_a = RemoteFixture("https://example.com/a.csv", ".csv", b"a,b\n1,2\n") + csv_b = RemoteFixture("https://example.com/b.csv", ".csv", b"a,c\n3,4\n") + parquet = RemoteFixture( + "https://example.com/a.parquet", + ".parquet", + _parquet_bytes( + """ + SELECT 1 AS a, 'x' AS b + UNION ALL SELECT 2, 'y' + """ + ), + ) + parquet_b = RemoteFixture( + "https://example.com/b.parquet", + ".parquet", + _parquet_bytes("SELECT 3 AS a, 'z' AS b"), + ) + json_array = RemoteFixture( + "https://example.com/a.json", + ".json", + b'[{"a":1,"b":"x"},{"a":2,"b":"y"}]', + ) + json_array_b = RemoteFixture( + "https://example.com/b.json", + ".json", + b'[{"a":3,"b":"z"}]', + ) + complex_json = RemoteFixture( + "https://example.com/countries.json", + ".json", + b'{"type":"Topology","arcs":[[[0]],1]}', 
+ ) + unstructured_json = RemoteFixture( + "https://example.com/unstructured.json", + ".json", + b'{"a":1} {"a":2}', + ) + ndjson = RemoteFixture( + "https://example.com/a.ndjson", + ".ndjson", + b'{"a":1}\n{"a":2}\n', + ) + ndjson_gzip = RemoteFixture( + "https://example.com/events.ndjson.gz", + ".ndjson.gz", + gzip.compress(b'{"event_id":5,"value":10}\n'), + ) + ndjson_objects_gzip = RemoteFixture( + "https://example.com/objects.ndjson.gz", + ".ndjson.gz", + gzip.compress(b'{"a":1}\n{"a":2}\n'), + ) + geojson = RemoteFixture( + "https://example.com/a.geojson", + ".geojson", + b'{"type":"FeatureCollection","features":[]}', + ) + text = RemoteFixture("https://example.com/a.txt", ".txt", b"hello") + blob = RemoteFixture("https://example.com/a.bin", ".bin", b"\x00\x01") + return [ + QueryParityCase( + "direct-csv-literal", + f"SELECT make, mpg FROM '{csv.url}' ORDER BY make", + (csv,), + ), + QueryParityCase( + "csv-reader-options", + f""" + SELECT column1 FROM read_csv( + '{csv_semicolon.url}', delim := ';', header := false + ) + ORDER BY column0 + """, + (csv_semicolon,), + ), + QueryParityCase( + "tsv-escaped-delimiter", + f""" + SELECT * FROM read_csv_auto('{tsv.url}', delim='\\t') + """, + (tsv,), + ), + QueryParityCase( + "gzipped-csv", + f"SELECT make, mpg FROM read_csv('{csv_gzip.url}')", + (csv_gzip,), + ), + QueryParityCase( + "reader-without-extension", + f"SELECT make, mpg FROM read_csv('{csv_download.url}')", + (csv_download,), + ), + QueryParityCase( + "csv-normalize-names", + f""" + SELECT make_name, mpg FROM read_csv( + '{csv_normalize.url}', normalize_names=true + ) + """, + (csv_normalize,), + ), + QueryParityCase( + "csv-union-by-name", + f""" + SELECT * FROM read_csv( + ['{csv_a.url}', '{csv_b.url}'], union_by_name=true + ) + ORDER BY a + """, + (csv_a, csv_b), + ), + QueryParityCase( + "parquet-reader", + f"SELECT * FROM read_parquet('{parquet.url}') ORDER BY a", + (parquet,), + ), + QueryParityCase( + "parquet-list", + f""" + SELECT * FROM 
read_parquet(['{parquet.url}', '{parquet_b.url}']) + ORDER BY a + """, + (parquet, parquet_b), + ), + QueryParityCase( + "parquet-scan-alias", + f"SELECT * FROM parquet_scan('{parquet.url}') ORDER BY a", + (parquet,), + ), + QueryParityCase( + "parquet-direct-literal", + f"SELECT * FROM '{parquet.url}' ORDER BY a", + (parquet,), + ), + QueryParityCase( + "json-reader", + f"SELECT * FROM read_json_auto('{json_array.url}') ORDER BY a", + (json_array,), + ), + QueryParityCase( + "json-list", + f""" + SELECT * FROM read_json_auto( + ['{json_array.url}', '{json_array_b.url}'] + ) + ORDER BY a + """, + (json_array, json_array_b), + ), + QueryParityCase( + "complex-json", + f"SELECT type, arcs FROM read_json_auto('{complex_json.url}')", + (complex_json,), + ), + QueryParityCase( + "unstructured-json", + f""" + SELECT * FROM read_json( + '{unstructured_json.url}', format='unstructured' + ) + ORDER BY a + """, + (unstructured_json,), + ), + QueryParityCase( + "ndjson-reader", + f"SELECT * FROM read_ndjson('{ndjson.url}') ORDER BY a", + (ndjson,), + ), + QueryParityCase( + "direct-ndjson-literal", + f"SELECT * FROM '{ndjson.url}' ORDER BY a", + (ndjson,), + ), + QueryParityCase( + "gzipped-ndjson", + f""" + SELECT * FROM read_json( + '{ndjson_gzip.url}', format='newline_delimited', + compression='gzip' + ) + """, + (ndjson_gzip,), + ), + QueryParityCase( + "ndjson-objects", + f"SELECT json FROM read_ndjson_objects('{json_array.url}')", + (json_array,), + ), + QueryParityCase( + "json-objects-auto-gzip", + f""" + SELECT json FROM read_json_objects_auto('{ndjson_objects_gzip.url}') + ORDER BY json + """, + (ndjson_objects_gzip,), + ), + QueryParityCase( + "geojson-reader", + f"SELECT type FROM read_json_auto('{geojson.url}')", + (geojson,), + ), + QueryParityCase( + "text-and-blob", + f""" + SELECT content, size FROM read_text('{text.url}') + UNION ALL + SELECT content::VARCHAR, size FROM read_blob('{blob.url}') + """, + (text, blob), + ), + QueryParityCase( + 
"mixed-parquet-json", + f""" + SELECT * FROM read_parquet('{parquet.url}') + UNION ALL + SELECT * FROM read_json_auto('{json_array.url}') + ORDER BY a + """, + (parquet, json_array), + ), + ] + + +class TestDuckDBWasmDirectReadPatch: + @staticmethod + def test_noop_outside_pyodide() -> None: + import duckdb + + original = duckdb.read_csv + unpatch = patch_duckdb_for_wasm() + assert duckdb.read_csv is original + unpatch() + assert duckdb.read_csv is original + + @staticmethod + @pytest.mark.parametrize("api_kind", ["module", "connection"]) + @pytest.mark.parametrize( + "case", + _direct_reader_parity_cases(), + ids=lambda case: case.name, + ) + def test_direct_readers_match_native_duckdb( + case: DirectReadParityCase, + api_kind: str, + tmp_path: Path, + ) -> None: + native_rows = _run_direct_reader( + case, + _local_fixture_path(case.fixture, tmp_path, case.name), + api_kind=api_kind, + ) + patched_rows, fetched_urls = _patched_direct_rows( + case, api_kind=api_kind + ) + + assert patched_rows == native_rows + assert fetched_urls == [case.fixture.url] + + @staticmethod + @pytest.mark.parametrize( + "case", + [ + case + for case in _direct_reader_parity_cases() + if case.function_name != "read_csv" + ], + ids=lambda case: case.name, + ) + def test_module_readers_preserve_connection_kw( + case: DirectReadParityCase, tmp_path: Path + ) -> None: + native_rows = _run_direct_reader( + case, + _local_fixture_path(case.fixture, tmp_path, case.name), + api_kind="module-connection-kw", + ) + patched_rows, fetched_urls = _patched_direct_rows( + case, api_kind="module-connection-kw" + ) + + assert patched_rows == native_rows + assert fetched_urls == [case.fixture.url] + + +class TestDuckDBWasmQueryPatch: + @staticmethod + def test_noop_outside_pyodide() -> None: + assert ( + patch_duckdb_query_for_wasm( + "SELECT * FROM 'https://datasets.marimo.app/cars.csv'" + ) + is None + ) + + @staticmethod + @pytest.mark.parametrize( + "case", + _query_parity_cases(), + ids=lambda case: 
case.name, + ) + def test_rewrites_remote_sources_with_native_duckdb_parity( + case: QueryParityCase, tmp_path: Path + ) -> None: + native_rows = _native_rows( + _local_query(case.query, case.fixtures, tmp_path) + ) + patched_rows, fetched_urls = _patched_rows(case.query, case.fixtures) + + assert patched_rows == native_rows + assert fetched_urls == [fixture.url for fixture in case.fixtures] + + @staticmethod + def test_list_argument_requires_matching_schemas() -> None: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + side_effect=[b"a,b\n1,2\n", b"a,c\n3,4\n"], + ), + pytest.raises(ValueError), + ): + patch_duckdb_query_for_wasm( + """ + SELECT * FROM read_csv([ + 'https://example.com/a.csv', + 'https://example.com/b.csv' + ]) + """, + ) + + @staticmethod + def test_rewrites_direct_geojson_literal() -> None: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b'{"type":"FeatureCollection","features":[]}', + ), + ): + patch_result = patch_duckdb_query_for_wasm( + "FROM 'https://example.com/a.geojson'", + ) + + assert patch_result is not None + assert ( + patch_result.query == "SELECT * FROM __marimo_wasm_duckdb_remote_0" + ) + assert _records(next(iter(patch_result.tables.values()))) == [ + {"type": "FeatureCollection", "features": []} + ] + + @staticmethod + def test_avoids_reserved_table_names() -> None: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\n", + ), + ): + patch_result = patch_duckdb_query_for_wasm( + "SELECT * FROM 'https://datasets.marimo.app/cars.csv'", + reserved_names=("__marimo_wasm_duckdb_remote_0",), + ) + + assert patch_result is not None + assert ( + patch_result.query == "SELECT * FROM __marimo_wasm_duckdb_remote_1" + ) + + @staticmethod + def test_avoids_sql_cte_table_names_case_insensitively() -> None: + with ( + mock_pyodide(), + patch( + 
"marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"mpg\n25\n", + ), + ): + patch_result = patch_duckdb_query_for_wasm( + """ + WITH __MARIMO_WASM_DUCKDB_REMOTE_0 AS (SELECT 99 AS mpg) + SELECT mpg FROM 'https://datasets.marimo.app/cars.csv' + """, + ) + + assert patch_result is not None + assert "FROM __marimo_wasm_duckdb_remote_1" in patch_result.query + + +class TestDuckDBWasmMoSqlIntegration: + @staticmethod + async def test_mo_sql_rewrites_remote_literal_in_kernel( + executing_kernel: Kernel, exec_req: ExecReqProvider + ) -> None: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + patch.object( + DuckDBEngine, "sql_output_format", return_value="native" + ), + ): + await executing_kernel.run( + [ + exec_req.get("import marimo as mo"), + exec_req.get( + """ + result = mo.sql( + ''' + SELECT make + FROM 'https://datasets.marimo.app/cars.csv' + WHERE mpg > 20 + ''', + output=False, + ) + """ + ), + ] + ) + + result = executing_kernel.globals["result"] + assert result.fetchall() == [("ford",)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + async def test_mo_sql_create_table_remote_literal_runs_once( + executing_kernel: Kernel, exec_req: ExecReqProvider + ) -> None: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + patch.object( + DuckDBEngine, "sql_output_format", return_value="native" + ), + ): + await executing_kernel.run( + [ + exec_req.get("import duckdb"), + exec_req.get("import marimo as mo"), + exec_req.get( + """ + result = mo.sql( + ''' + CREATE OR REPLACE TABLE __marimo_wasm_create_once AS ( + SELECT * FROM 'https://datasets.marimo.app/cars.csv' + ) + ''', + output=False, + ) + """ + ), + exec_req.get( + """ + rows = duckdb.sql( + ''' + SELECT make, mpg + FROM 
__marimo_wasm_create_once + ORDER BY make + ''' + ).fetchall() + duckdb.close() + """ + ), + ] + ) + + assert executing_kernel.globals["result"] is None + assert executing_kernel.globals["rows"] == [ + ("ford", 25), + ("toyota", 18), + ] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + +class TestDuckDBWasmSqlApiPatch: + @staticmethod + def test_noop_outside_pyodide() -> None: + import duckdb + + original_sql = duckdb.sql + original_connection_sql = duckdb.DuckDBPyConnection.sql + unpatch = patch_duckdb_for_wasm() + assert duckdb.sql is original_sql + assert duckdb.DuckDBPyConnection.sql is original_connection_sql + unpatch() + assert duckdb.sql is original_sql + assert duckdb.DuckDBPyConnection.sql is original_connection_sql + + @staticmethod + def test_module_sql_rewrites_remote_literal_and_preserves_params() -> None: + import duckdb + + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + relation = duckdb.sql( + """ + SELECT make FROM 'https://datasets.marimo.app/cars.csv' + WHERE mpg > ? 
+ """, + params=[20], + ) + finally: + unpatch() + + assert relation.fetchall() == [("ford",)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_module_query_rewrites_reader_call() -> None: + import duckdb + + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"1;ford\n2;toyota\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + relation = duckdb.query( + """ + SELECT column1 FROM read_csv( + 'https://datasets.marimo.app/cars.csv', + delim=';', header=false + ) + """ + ) + finally: + unpatch() + + assert relation.fetchall() == [("ford",), ("toyota",)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_module_query_df_rewrites_reader_call() -> None: + import duckdb + import pandas as pd + + local_df = pd.DataFrame({"make": ["ford"], "score": [7]}) + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + relation = duckdb.query_df( + df=local_df, + virtual_table_name="query_df_local", + sql_query=""" + SELECT query_df_local.score, cars.mpg + FROM query_df_local + JOIN read_csv('https://datasets.marimo.app/cars.csv') AS cars + USING (make) + """, + ) + rows = relation.fetchall() + finally: + unpatch() + duckdb.close() + + assert rows == [(7, 25)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_module_query_df_avoids_existing_catalog_table_names() -> None: + import duckdb + import pandas as pd + + local_df = pd.DataFrame({"make": ["ford"], "score": [7]}) + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ), + ): + unpatch = patch_duckdb_for_wasm() + try: + 
duckdb.sql( + """ + CREATE OR REPLACE TABLE "__MARIMO_WASM_DUCKDB_REMOTE_0" + AS SELECT 'ford' AS make, 99 AS mpg + """ + ) + relation = duckdb.query_df( + df=local_df, + virtual_table_name="query_df_local_collision", + sql_query=""" + SELECT query_df_local_collision.score, cars.mpg + FROM query_df_local_collision + JOIN read_csv('https://datasets.marimo.app/cars.csv') AS cars + USING (make) + """, + ) + rows = relation.fetchall() + finally: + unpatch() + duckdb.close() + + assert rows == [(7, 25)] + + @staticmethod + def test_module_sql_preserves_caller_replacement_scan() -> None: + import duckdb + import pandas as pd + + local_df = pd.DataFrame({"x": [1, 2]}) + with mock_pyodide(): + unpatch = patch_duckdb_for_wasm() + try: + rows = duckdb.sql("SELECT sum(x) FROM local_df").fetchall() + finally: + unpatch() + + assert rows == [(3,)] + + @staticmethod + def test_module_execute_rewrites_before_side_effects() -> None: + import duckdb + + table_name = "__marimo_duckdb_wasm_execute_test" + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ), + ): + unpatch = patch_duckdb_for_wasm() + try: + duckdb.execute( + f""" + CREATE OR REPLACE TABLE {table_name} AS + SELECT make FROM 'https://datasets.marimo.app/cars.csv' + WHERE mpg > ? 
+ """, + [20], + ) + rows = duckdb.sql( + f"SELECT make FROM {table_name} ORDER BY make" + ).fetchall() + finally: + unpatch() + duckdb.close() + + assert rows == [("ford",)] + + @staticmethod + def test_module_execute_creates_table_from_remote_literal() -> None: + import duckdb + + table_name = "__marimo_duckdb_wasm_execute_create_test" + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + duckdb.execute( + f""" + CREATE OR REPLACE TABLE {table_name} AS ( + SELECT * FROM 'https://datasets.marimo.app/cars.csv' + ) + """ + ) + rows = duckdb.sql( + f"SELECT make, mpg FROM {table_name} ORDER BY make" + ).fetchall() + finally: + unpatch() + duckdb.close() + + assert rows == [("ford", 25), ("toyota", 18)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_connection_methods_preserve_caller_replacement_scans() -> None: + import duckdb + import pandas as pd + + local_df = pd.DataFrame({"x": [1, 2]}) + connection = duckdb.connect(":memory:") + try: + with mock_pyodide(): + unpatch = patch_duckdb_for_wasm() + try: + sql_rows = connection.sql( + "SELECT sum(x) FROM local_df" + ).fetchall() + query_rows = connection.query( + "SELECT count(*) FROM local_df" + ).fetchall() + execute_rows = connection.execute( + "SELECT max(x) FROM local_df" + ).fetchall() + finally: + unpatch() + finally: + connection.close() + + assert sql_rows == [(3,)] + assert query_rows == [(2,)] + assert execute_rows == [(2,)] + + @staticmethod + def test_connection_execute_creates_table_from_remote_literal() -> None: + import duckdb + + table_name = "__marimo_duckdb_wasm_conn_execute_create_test" + connection = duckdb.connect(":memory:") + try: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as 
fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + connection.execute( + f""" + CREATE OR REPLACE TABLE {table_name} AS ( + SELECT * FROM 'https://datasets.marimo.app/cars.csv' + ) + """ + ) + rows = connection.execute( + f"SELECT make, mpg FROM {table_name} ORDER BY make" + ).fetchall() + finally: + unpatch() + finally: + connection.close() + + assert rows == [("ford", 25), ("toyota", 18)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_connection_sql_joins_caller_local_and_remote_tables() -> None: + import duckdb + import pandas as pd + + local_df = pd.DataFrame({"make": ["ford"], "score": [7]}) + connection = duckdb.connect(":memory:") + try: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + relation = connection.sql( + """ + SELECT local_df.score, cars.mpg + FROM local_df + JOIN read_csv('https://datasets.marimo.app/cars.csv') AS cars + USING (make) + """ + ) + rows = relation.fetchall() + finally: + unpatch() + finally: + connection.close() + + assert rows == [(7, 25)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_connection_sql_avoids_existing_catalog_table_names() -> None: + import duckdb + + connection = duckdb.connect(":memory:") + try: + connection.sql( + """ + CREATE OR REPLACE TABLE "__MARIMO_WASM_DUCKDB_REMOTE_0" + AS SELECT 99 AS mpg + """ + ) + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"mpg\n25\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + rows = connection.sql( + """ + SELECT mpg + FROM 'https://datasets.marimo.app/cars.csv' + """ + ).fetchall() + finally: + unpatch() + finally: + connection.close() + + assert rows == [(25,)] + 
fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + @mock_pyodide() + def test_unpatch_restores_module_functions_and_connection_methods() -> ( + None + ): + import duckdb + + original_sql = duckdb.sql + original_query_df = duckdb.query_df + original_execute = duckdb.execute + original_connection_sql = duckdb.DuckDBPyConnection.sql + original_connection_execute = duckdb.DuckDBPyConnection.execute + original_connection_read_csv = duckdb.DuckDBPyConnection.read_csv + original_connection_read_parquet = ( + duckdb.DuckDBPyConnection.read_parquet + ) + original_connection_read_json = duckdb.DuckDBPyConnection.read_json + + unpatch = patch_duckdb_for_wasm() + assert duckdb.sql is not original_sql + assert duckdb.query_df is not original_query_df + assert duckdb.execute is not original_execute + assert duckdb.DuckDBPyConnection.sql is not original_connection_sql + assert ( + duckdb.DuckDBPyConnection.execute + is not original_connection_execute + ) + assert ( + duckdb.DuckDBPyConnection.read_csv + is not original_connection_read_csv + ) + assert ( + duckdb.DuckDBPyConnection.read_parquet + is not original_connection_read_parquet + ) + assert ( + duckdb.DuckDBPyConnection.read_json + is not original_connection_read_json + ) + + unpatch() + assert duckdb.sql is original_sql + assert duckdb.query_df is original_query_df + assert duckdb.execute is original_execute + assert duckdb.DuckDBPyConnection.sql is original_connection_sql + assert duckdb.DuckDBPyConnection.execute is original_connection_execute + assert ( + duckdb.DuckDBPyConnection.read_csv is original_connection_read_csv + ) + assert ( + duckdb.DuckDBPyConnection.read_parquet + is original_connection_read_parquet + ) + assert ( + duckdb.DuckDBPyConnection.read_json + is original_connection_read_json + ) + + unpatch() From c0218b468dabfcb459a97122f440a3b1107a1370 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Fri, 8 May 2026 13:50:48 
+0200 Subject: [PATCH 05/22] refactor: clarify DuckDB WASM patch specs --- marimo/_runtime/_wasm/_duckdb/__init__.py | 192 ++++++++++++++-------- 1 file changed, 127 insertions(+), 65 deletions(-) diff --git a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py index a810243780c..f384c77b966 100644 --- a/marimo/_runtime/_wasm/_duckdb/__init__.py +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -20,7 +20,7 @@ import functools import inspect from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple from marimo import _loggers from marimo._dependencies.dependencies import DependencyManager @@ -46,36 +46,112 @@ LOGGER = _loggers.marimo_logger() -# DuckDB SQL APIs also accept non-string query objects; this marks "not found". + +class _EvalBindingNames(NamedTuple): + original: str + args: str + kwargs: str + + +class _SqlApiSpec(NamedTuple): + query_positional_index: int + query_keyword_names: tuple[str, ...] + + +class _DirectReaderSpec(NamedTuple): + source_keyword_names: tuple[str, ...] + + +class _DirectReaderCallSpec(NamedTuple): + source_positional_index: int + connection_positional_index: int | None + + +# DuckDB SQL APIs can receive non-string query objects. This sentinel lets us +# tell an omitted query argument apart from a present value such as None. _MISSING = object() -_DIRECT_READER_NAMES: tuple[str, ...] = ( - "read_csv", - "read_parquet", - "read_json", -) _SQL_CALL_EXPRESSION = "{original}(*{args}, **{kwargs})" -_SQL_HELPER_NAME_BASES = ( - "__marimo_wasm_duckdb_original", - "__marimo_wasm_duckdb_args", - "__marimo_wasm_duckdb_kwargs", + +# The SQL wrappers invoke the original DuckDB callable through eval so DuckDB +# can still see caller-local replacement scans. These are the local binding +# names used for that eval call after collision checks. 
+_EVAL_BINDING_NAME_BASES = _EvalBindingNames( + original="__marimo_wasm_duckdb_original", + args="__marimo_wasm_duckdb_args", + kwargs="__marimo_wasm_duckdb_kwargs", ) -_MODULE_SQL_FUNCTIONS: dict[str, tuple[int, tuple[str, ...]]] = { - "sql": (0, ("query",)), - "query": (0, ("query",)), - "execute": (0, ("query",)), - "query_df": (2, ("sql_query", "query")), + +# Module-level DuckDB SQL functions put the SQL string in different argument +# slots. The specs identify where wrappers should look for the query text. +_MODULE_SQL_FUNCTIONS: dict[str, _SqlApiSpec] = { + "sql": _SqlApiSpec( + query_positional_index=0, + query_keyword_names=("query",), + ), + "query": _SqlApiSpec( + query_positional_index=0, + query_keyword_names=("query",), + ), + "execute": _SqlApiSpec( + query_positional_index=0, + query_keyword_names=("query",), + ), + "query_df": _SqlApiSpec( + query_positional_index=2, + query_keyword_names=("sql_query", "query"), + ), } -_CONNECTION_SQL_METHODS: dict[str, tuple[int, tuple[str, ...]]] = { - "sql": (1, ("query",)), - "query": (1, ("query",)), - "execute": (1, ("query",)), + +# Bound connection methods include the connection as args[0], so their query +# argument starts one slot later than the module-level functions. +_CONNECTION_SQL_METHODS: dict[str, _SqlApiSpec] = { + "sql": _SqlApiSpec( + query_positional_index=1, + query_keyword_names=("query",), + ), + "query": _SqlApiSpec( + query_positional_index=1, + query_keyword_names=("query",), + ), + "execute": _SqlApiSpec( + query_positional_index=1, + query_keyword_names=("query",), + ), } -_SOURCE_KWARGS: dict[str, tuple[str, ...]] = { - "read_csv": ("path_or_buffer", "source", "file", "path"), - "read_parquet": ("file_glob", "file_globs", "source", "file", "path"), - "read_json": ("path_or_buffer", "source", "file", "path"), + +# Direct reader APIs use reader-specific keyword names for the file source. 
+_DIRECT_READER_SPECS: dict[str, _DirectReaderSpec] = { + "read_csv": _DirectReaderSpec( + source_keyword_names=("path_or_buffer", "source", "file", "path"), + ), + "read_parquet": _DirectReaderSpec( + source_keyword_names=( + "file_glob", + "file_globs", + "source", + "file", + "path", + ), + ), + "read_json": _DirectReaderSpec( + source_keyword_names=("path_or_buffer", "source", "file", "path"), + ), } +# Module-level direct readers receive the file source as their first positional +# argument. They may also receive an explicit connection= keyword. +_MODULE_DIRECT_READER_CALL = _DirectReaderCallSpec( + source_positional_index=0, + connection_positional_index=None, +) + +# Connection direct readers receive the bound connection as args[0], followed +# by the file source as args[1]. +_CONNECTION_DIRECT_READER_CALL = _DirectReaderCallSpec( + source_positional_index=1, + connection_positional_index=0, +) + @dataclass(frozen=True) class WasmDuckDBQueryPatch: @@ -92,13 +168,6 @@ class WasmDuckDBSqlResult: value: Any -@dataclass(frozen=True) -class _EvalBindingNames: - original: str - args: str - kwargs: str - - class _RemoteTableNames: """Track URL sources and generated table names for one SQL rewrite.""" @@ -191,14 +260,13 @@ def patch_duckdb_for_wasm() -> Unpatch: _require_sqlglot() patches = WasmPatchSet() - for function_name in _DIRECT_READER_NAMES: + for function_name in _DIRECT_READER_SPECS: patches.replace( duckdb, function_name, _make_direct_reader_wrapper( function_name, - source_arg_index=0, - connection_arg_index=None, + call_spec=_MODULE_DIRECT_READER_CALL, ), ) patches.replace( @@ -206,32 +274,25 @@ def patch_duckdb_for_wasm() -> Unpatch: function_name, _make_direct_reader_wrapper( function_name, - source_arg_index=1, - connection_arg_index=0, + call_spec=_CONNECTION_DIRECT_READER_CALL, ), ) - for function_name, ( - query_index, - query_kwargs, - ) in _MODULE_SQL_FUNCTIONS.items(): + for function_name, spec in _MODULE_SQL_FUNCTIONS.items(): patches.replace( 
duckdb, function_name, _make_sql_api_wrapper( - query_arg_index=query_index, - query_kwarg_names=query_kwargs, + query_arg_index=spec.query_positional_index, + query_kwarg_names=spec.query_keyword_names, ), ) - for method_name, ( - query_index, - query_kwargs, - ) in _CONNECTION_SQL_METHODS.items(): + for method_name, spec in _CONNECTION_SQL_METHODS.items(): patches.replace( duckdb.DuckDBPyConnection, method_name, _make_sql_api_wrapper( - query_arg_index=query_index, - query_kwarg_names=query_kwargs, + query_arg_index=spec.query_positional_index, + query_kwarg_names=spec.query_keyword_names, ), ) return patches.unpatch_all() @@ -466,11 +527,11 @@ def _identifier_string_args(args: tuple[Any, ...]) -> tuple[str, ...]: def _eval_binding_names(reserved_names: Sequence[str]) -> _EvalBindingNames: used = set(reserved_names) - original = _unused_name(_SQL_HELPER_NAME_BASES[0], used) + original = _unused_name(_EVAL_BINDING_NAME_BASES.original, used) used.add(original) - args = _unused_name(_SQL_HELPER_NAME_BASES[1], used) + args = _unused_name(_EVAL_BINDING_NAME_BASES.args, used) used.add(args) - kwargs = _unused_name(_SQL_HELPER_NAME_BASES[2], used) + kwargs = _unused_name(_EVAL_BINDING_NAME_BASES.kwargs, used) return _EvalBindingNames(original=original, args=args, kwargs=kwargs) @@ -515,8 +576,7 @@ def _show_duckdb_tables( def _make_direct_reader_wrapper( function_name: str, *, - source_arg_index: int, - connection_arg_index: int | None, + call_spec: _DirectReaderCallSpec, ) -> WrapperFactory: def _wrap(original: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(original) @@ -525,8 +585,7 @@ def _wrapper(*args: Any, **kwargs: Any) -> Any: function_name, args, kwargs, - source_arg_index=source_arg_index, - connection_arg_index=connection_arg_index, + call_spec=call_spec, ) if source_info is None: return original(*args, **kwargs) @@ -553,8 +612,7 @@ def _direct_reader_source( args: tuple[Any, ...], kwargs: Mapping[str, Any], *, - source_arg_index: int, - 
connection_arg_index: int | None, + call_spec: _DirectReaderCallSpec, ) -> tuple[RemoteFileSource, Any] | None: options = dict(kwargs) try: @@ -562,17 +620,17 @@ def _direct_reader_source( function_name, args, options, - source_arg_index=source_arg_index, + call_spec=call_spec, ) except TypeError: return None if rest_args: return None - if connection_arg_index is None: + if call_spec.connection_positional_index is None: connection = options.pop("connection", None) else: - connection = args[connection_arg_index] + connection = args[call_spec.connection_positional_index] source_info = remote_file_source_from_reader_args( function_name, source, options ) @@ -586,14 +644,18 @@ def _pop_source_argument( args: tuple[Any, ...], kwargs: dict[str, Any], *, - source_arg_index: int, + call_spec: _DirectReaderCallSpec, ) -> tuple[Any, tuple[Any, ...]]: - if len(args) > source_arg_index: - return args[source_arg_index], args[source_arg_index + 1 :] + source_positional_index = call_spec.source_positional_index + if len(args) > source_positional_index: + return ( + args[source_positional_index], + args[source_positional_index + 1 :], + ) - for key in _SOURCE_KWARGS[function_name]: + for key in _DIRECT_READER_SPECS[function_name].source_keyword_names: if key in kwargs: - return kwargs.pop(key), args[source_arg_index + 1 :] + return kwargs.pop(key), args[source_positional_index + 1 :] raise TypeError(f"Missing source argument for duckdb.{function_name}") From 2fc69d07745cb3b12e0b7498c9abce26da6f7d5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Fri, 8 May 2026 15:34:57 +0200 Subject: [PATCH 06/22] docs: clarify def details --- marimo/_runtime/_wasm/_duckdb/__init__.py | 41 +++++++++++++++------- marimo/_runtime/_wasm/_duckdb/dataframe.py | 14 +++++++- marimo/_runtime/_wasm/_duckdb/io.py | 31 +++++++++++++++- marimo/_runtime/_wasm/_duckdb/sources.py | 16 +++++++-- 4 files changed, 85 insertions(+), 17 deletions(-) diff --git 
a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py index f384c77b966..3513e2dc7ed 100644 --- a/marimo/_runtime/_wasm/_duckdb/__init__.py +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -1,18 +1,19 @@ # Copyright 2026 Marimo. All rights reserved. -"""WASM-only DuckDB fallbacks for remote file scans. +"""Install WASM-only DuckDB fallbacks for remote file scans. DuckDB-WASM cannot use ``httpfs``. marimo fetches supported URLs itself, materializes supported files as pandas DataFrames, then hands those frames -back to DuckDB. +back to DuckDB through replacement scans. We have two concrete use cases to patch: * **Direct read methods** — ``duckdb.read_csv/read_parquet/read_json`` are - wrapped with :class:`WasmPatchSet`; supported remote URLs are fetched by + wrapped with :class:`WasmPatchSet`. Supported remote URLs are fetched by marimo and returned as DuckDB relations. * **SQL remote scans** — raw DuckDB APIs and marimo's ``mo.sql`` path call the - same rewrite helper, which replaces supported URLs with generated - replacement-scan tables backed by fetched pandas DataFrames. + same sqlglot rewrite helper. It replaces supported URL scans with generated + table names and evaluates the original DuckDB call with fetched DataFrames + in scope so DuckDB replacement scans can resolve them. """ from __future__ import annotations @@ -212,15 +213,19 @@ def patch_duckdb_query_for_wasm( *, reserved_names: Sequence[str] = (), ) -> WasmDuckDBQueryPatch | None: - """Rewrite remote file sources to generated DataFrame names. + """Replace supported remote file reads with generated table names. - Example: ``SELECT * FROM read_csv('https://example.com/cars.csv')`` - becomes ``SELECT * FROM __marimo_wasm_duckdb_remote_0``, and - ``tables["__marimo_wasm_duckdb_remote_0"]`` holds the fetched DataFrame. 
+ For example, ``SELECT * FROM read_csv('https://example.com/cars.csv')`` + becomes ``SELECT * FROM __marimo_wasm_duckdb_remote_0`` when suffix ``0`` + is free. The returned ``tables`` mapping binds that name to the fetched + DataFrame. If the query or ``reserved_names`` already use that identifier, + the rewriter uses the next free suffix. - In Pyodide this raises if sqlglot is unavailable. Returns ``None`` when - not running in Pyodide, when the query has no supported remote file - source, or when the query cannot be parsed. + In Pyodide this raises if sqlglot is unavailable. Returns ``None`` when: + + - marimo is not running in Pyodide; + - the query has no supported remote file source; + - the query cannot be parsed. """ if not is_pyodide(): return None @@ -419,6 +424,8 @@ def _make_sql_api_wrapper( query_arg_index: int, query_kwarg_names: tuple[str, ...], ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Wrap DuckDB SQL APIs while preserving the caller's local namespace.""" + def _make_wrapper(original: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(original) def _wrapper(*args: Any, **kwargs: Any) -> Any: @@ -451,6 +458,7 @@ def _query_argument( query_arg_index: int, query_kwarg_names: tuple[str, ...], ) -> Any: + """Find the query text without assuming the caller used positional args.""" if len(args) > query_arg_index: return args[query_arg_index] for name in query_kwarg_names: @@ -467,6 +475,7 @@ def _replace_query_argument( query_arg_index: int, query_kwarg_names: tuple[str, ...], ) -> tuple[tuple[Any, ...], dict[str, Any]]: + """Replace the query at the same call site shape used by the caller.""" kwargs_dict = dict(kwargs) if len(args) > query_arg_index: patched_args = list(args) @@ -490,6 +499,7 @@ def _eval_duckdb_original_call( binding_names: _EvalBindingNames, extra_locals: Mapping[str, Any] | None = None, ) -> Any: + """Evaluate the original call where DuckDB can see replacement DataFrames.""" locals_for_eval = 
dict(eval_locals) if extra_locals is not None: locals_for_eval.update(extra_locals) @@ -551,6 +561,7 @@ def _duckdb_catalog_names( original: Callable[..., Any], args: tuple[Any, ...], ) -> tuple[str, ...]: + """Reserve existing DuckDB table names before generating replacements.""" try: relation = _show_duckdb_tables(original, args) rows = relation.fetchall() @@ -563,6 +574,7 @@ def _show_duckdb_tables( original: Callable[..., Any], args: tuple[Any, ...], ) -> Any: + """Run ``SHOW TABLES`` through the same DuckDB entry point being patched.""" import duckdb original_call = inspect.unwrap(original) @@ -614,6 +626,7 @@ def _direct_reader_source( *, call_spec: _DirectReaderCallSpec, ) -> tuple[RemoteFileSource, Any] | None: + """Return a remote source only for direct reader calls we can emulate.""" options = dict(kwargs) try: source, rest_args = _pop_source_argument( @@ -646,6 +659,7 @@ def _pop_source_argument( *, call_spec: _DirectReaderCallSpec, ) -> tuple[Any, tuple[Any, ...]]: + """Remove the source argument so remaining kwargs are pure reader options.""" source_positional_index = call_spec.source_positional_index if len(args) > source_positional_index: return ( @@ -688,6 +702,7 @@ def _replace_remote_sources( statements: Sequence[exp.Expression], table_names: _RemoteTableNames, ) -> list[exp.Expression]: + """Replace supported remote table nodes while preserving aliases.""" from sqlglot import exp def replace_table(node: exp.Expression) -> exp.Expression: @@ -715,6 +730,7 @@ def replace_table(node: exp.Expression) -> exp.Expression: def _reserved_sql_names( statements: Sequence[exp.Expression], ) -> tuple[str, ...]: + """Collect SQL identifiers that generated table names must not shadow.""" from sqlglot import exp names: set[str] = set() @@ -731,6 +747,7 @@ def _reserved_sql_names( def _format_duckdb_query( statements: Sequence[exp.Expression], *, original_query: str ) -> str: + """Serialize sqlglot statements without dropping a trailing semicolon.""" 
patched_query = "; ".join( statement.sql(dialect="duckdb") for statement in statements ) diff --git a/marimo/_runtime/_wasm/_duckdb/dataframe.py b/marimo/_runtime/_wasm/_duckdb/dataframe.py index 528bb42fd12..e1aae2bc0d5 100644 --- a/marimo/_runtime/_wasm/_duckdb/dataframe.py +++ b/marimo/_runtime/_wasm/_duckdb/dataframe.py @@ -1,5 +1,11 @@ # Copyright 2026 Marimo. All rights reserved. -"""Decode fetched DuckDB file bytes into pandas DataFrames.""" +"""Decode fetched DuckDB file bytes into pandas DataFrames. + +The WASM patch fetches remote bytes in Python, but DuckDB's Python readers +still expect local paths for CSV, JSON, and parquet parsing. This module +materializes fetched bytes through short-lived temp files where DuckDB parsing +is needed and synthesizes DataFrames for ``read_text`` and ``read_blob``. +""" from __future__ import annotations @@ -29,6 +35,7 @@ def read_csv_dataframe( data: bytes, options: Mapping[str, Any], *, url: str ) -> pd.DataFrame: + """Read CSV/TSV bytes with a suffix that preserves compression hints.""" import duckdb return _read_temp_dataframe( @@ -71,6 +78,7 @@ def read_json_dataframe( def read_json_objects_dataframe( data: bytes, options: Mapping[str, Any], *, url: str ) -> pd.DataFrame: + """Read JSON-object bytes through DuckDB's SQL-only table function.""" return _read_temp_dataframe( data, suffix=_temp_suffix( @@ -99,6 +107,7 @@ def _read_json_objects_path( def read_text_dataframe(data: bytes, url: str) -> pd.DataFrame: + """Match DuckDB's ``read_text`` shape for an already-fetched object.""" import pandas as pd return pd.DataFrame( @@ -112,6 +121,7 @@ def read_text_dataframe(data: bytes, url: str) -> pd.DataFrame: def read_blob_dataframe(data: bytes, url: str) -> pd.DataFrame: + """Match DuckDB's ``read_blob`` shape for an already-fetched object.""" import pandas as pd return pd.DataFrame( @@ -127,6 +137,7 @@ def read_blob_dataframe(data: bytes, url: str) -> pd.DataFrame: def append_filename_column( df: pd.DataFrame, url: str, 
column_name: str ) -> pd.DataFrame: + """Apply DuckDB's filename option after bytes have been decoded.""" if column_name in df.columns: raise ValueError( f'Option filename adds column "{column_name}", but a column with this ' @@ -158,6 +169,7 @@ def _read_temp_dataframe( def _temp_suffix(url: str, *, suffixes: tuple[str, ...], default: str) -> str: + """Preserve file extensions so DuckDB can infer format details.""" path = urlparse(url).path.lower() for suffix in suffixes: if path.endswith(suffix): diff --git a/marimo/_runtime/_wasm/_duckdb/io.py b/marimo/_runtime/_wasm/_duckdb/io.py index 47c049cced9..f82cbcdbaed 100644 --- a/marimo/_runtime/_wasm/_duckdb/io.py +++ b/marimo/_runtime/_wasm/_duckdb/io.py @@ -1,5 +1,13 @@ # Copyright 2026 Marimo. All rights reserved. -"""Map DuckDB remote files to readers.""" +"""Resolve DuckDB remote file reads into fetch and DataFrame reader steps. + +DuckDB's native network scanner is unavailable in Pyodide. This module is the +compatibility layer that recognizes supported URL shapes, validates the reader +options we can reproduce, fetches bytes through WASM fetch shim, and +dispatches to a local DataFrame reader. Unsupported readers or option +combinations return ``None`` so callers can use the unpatched DuckDB path +and surface the underlying error. +""" from __future__ import annotations @@ -40,6 +48,7 @@ class RemoteFile: fetcher_name: str def fetch(self) -> _FetchedBytes: + """Fetch through a named strategy so sources stay hashable.""" return _fetcher_by_name(self.fetcher_name).fetch(self.url) @@ -265,9 +274,11 @@ class RemoteFileSource: options: tuple[tuple[str, Any], ...] 
= () def read_options(self) -> dict[str, Any]: + """Expose sorted, hashable options as regular reader kwargs.""" return dict(self.options) def read_dataframe(self) -> pd.DataFrame: + """Read one or more remote files using DuckDB-compatible concat rules.""" frames = [self._read_file_dataframe(file) for file in self.files] if len(frames) == 1: return frames[0] @@ -287,6 +298,7 @@ def read_dataframe(self) -> pd.DataFrame: return pd.concat(frames, ignore_index=True) def _read_file_dataframe(self, file: RemoteFile) -> pd.DataFrame: + """Fetch bytes, decode them, then apply DuckDB's filename option.""" fetched = file.fetch() options = self.read_options() reader = _reader_by_name(self.reader_name) @@ -300,6 +312,7 @@ def _read_file_dataframe(self, file: RemoteFile) -> pd.DataFrame: def remote_file_from_url(url: str) -> RemoteFile | None: + """Return a fetchable remote file only for URL schemes marimo supports.""" fetcher = _fetcher_for_url(url) if fetcher is None: return None @@ -311,6 +324,7 @@ def remote_file_source_from_reader_args( source: Any, raw_options: Mapping[str, Any], ) -> RemoteFileSource | None: + """Map a DuckDB reader call to a reproducible remote DataFrame source.""" reader = reader_for_function(function_name) if reader is None: return None @@ -328,6 +342,7 @@ def remote_file_source_from_reader_args( def _remote_files_from_source_arg( source: Any, ) -> tuple[RemoteFile, ...] 
| None: + """Accept DuckDB source shapes that are static URL strings or URL lists.""" if isinstance(source, str): file = remote_file_from_url(source) return (file,) if file is not None else None @@ -372,6 +387,7 @@ def _reader_by_name( def reader_for_url(url: str) -> _DataFrameReader | None: + """Infer a reader from direct URL table syntax such as ``FROM 'x.csv'``.""" path = urlparse(url).path.lower() return next( ( @@ -386,6 +402,7 @@ def reader_for_url(url: str) -> _DataFrameReader | None: def reader_for_function( function_name: str, ) -> _DataFrameReader | None: + """Resolve DuckDB table-function names to marimo's fallback readers.""" return next( ( reader @@ -397,6 +414,7 @@ def reader_for_function( def _csv_reader_options(options: Mapping[str, Any]) -> dict[str, Any]: + """Drop options implemented outside DuckDB's CSV reader call.""" return { key: value for key, value in options.items() @@ -405,6 +423,7 @@ def _csv_reader_options(options: Mapping[str, Any]) -> dict[str, Any]: def _json_reader_options(options: Mapping[str, Any]) -> dict[str, Any]: + """Translate DuckDB JSON option spelling to DuckDB Python API spelling.""" return { key: _normalize_json_reader_option(key, value) for key, value in options.items() @@ -413,12 +432,14 @@ def _json_reader_options(options: Mapping[str, Any]) -> dict[str, Any]: def _normalize_delimiter(value: Any) -> Any: + """Convert SQL's escaped tab literal to the byte delimiter DuckDB expects.""" if value == r"\t": return "\t" return value def _normalize_json_reader_option(key: str, value: Any) -> Any: + """Normalize JSON values whose SQL names differ from Python reader values.""" if key == "compression": return _normalize_json_compression(value) if key == "format": @@ -427,6 +448,7 @@ def _normalize_json_reader_option(key: str, value: Any) -> Any: def _normalize_json_compression(value: Any) -> str: + """Map SQL compression aliases to DuckDB Python JSON reader values.""" compression = str(value).lower() if compression == "auto": 
return "auto_detect" @@ -436,6 +458,7 @@ def _normalize_json_compression(value: Any) -> str: def _normalize_json_format(value: Any) -> str: + """Map DuckDB SQL JSON format aliases to Python reader values.""" fmt = str(value).lower() if fmt == "ndjson": return "newline_delimited" @@ -445,6 +468,7 @@ def _normalize_json_format(value: Any) -> str: def _apply_json_option(options: dict[str, Any], key: str, value: Any) -> bool: + """Keep JSON options only when the fallback can safely pass them through.""" if key == "format": options["format"] = _normalize_json_format(value) return True @@ -457,12 +481,14 @@ def _apply_json_option(options: dict[str, Any], key: str, value: Any) -> bool: def _is_safe_reader_option_name(key: str) -> bool: + """Reject option names that cannot be passed as Python reader kwargs.""" return key.isidentifier() def _apply_common_table_option( options: dict[str, Any], key: str, value: Any ) -> bool: + """Handle options marimo applies after per-file reads are decoded.""" if key == "filename" and isinstance(value, bool | str): options["filename"] = value return True @@ -475,6 +501,7 @@ def _apply_common_table_option( def _apply_compression_option( options: dict[str, Any], key: str, value: Any ) -> bool: + """Accept only compression modes supported by the byte-fetch fallback.""" if key == "compression" and _is_supported_compression(value): options["compression"] = str(value).lower() return True @@ -484,10 +511,12 @@ def _apply_compression_option( def _apply_shared_source_option( options: dict[str, Any], key: str, value: Any ) -> bool: + """Apply source options shared by CSV, parquet, and JSON fallbacks.""" return _apply_compression_option( options, key, value ) or _apply_common_table_option(options, key, value) def _is_supported_compression(value: Any) -> bool: + """Limit compression to modes the fallback knows how to preserve.""" return str(value).lower() in {"auto", "none", "gzip"} diff --git a/marimo/_runtime/_wasm/_duckdb/sources.py 
b/marimo/_runtime/_wasm/_duckdb/sources.py index d8337f664f0..1b76cfa10e5 100644 --- a/marimo/_runtime/_wasm/_duckdb/sources.py +++ b/marimo/_runtime/_wasm/_duckdb/sources.py @@ -1,8 +1,11 @@ # Copyright 2026 Marimo. All rights reserved. -"""Resolve sqlglot table nodes to remote file sources. +"""Resolve sqlglot DuckDB table nodes to remote file sources. -Handles direct URL tables and reader calls such as -``read_csv('https://example.com/cars.csv')``. +The SQL patch should only rewrite queries it can execute with fetched +DataFrames. This module recognizes direct URL table syntax and supported +``read_*`` table functions, extracts literal URL/options from sqlglot's AST, +and returns ``None`` for dynamic expressions so they continue through DuckDB +unchanged. """ from __future__ import annotations @@ -29,6 +32,7 @@ def remote_file_source_from_table( table: exp.Table, ) -> RemoteFileSource | None: + """Return a remote source for supported direct URLs or reader calls.""" table_name = table.name if table_name: reader = reader_for_url(table_name) @@ -58,6 +62,7 @@ def remote_file_source_from_table( def _table_function_name(table_expr: exp.Expression) -> str | None: + """Normalize sqlglot's version-dependent DuckDB reader node names.""" import sqlglot.expressions as exp # sqlglot versions model first-party DuckDB readers either as explicit @@ -76,6 +81,7 @@ def _table_function_name(table_expr: exp.Expression) -> str | None: def _table_function_args(table_expr: exp.Expression) -> list[exp.Expression]: + """Return table-function args despite sqlglot reader-node differences.""" import sqlglot.expressions as exp read_csv = getattr(exp, "ReadCSV", None) @@ -88,6 +94,7 @@ def _table_function_args(table_expr: exp.Expression) -> list[exp.Expression]: def _read_function_source( args: Sequence[exp.Expression], ) -> str | tuple[str, ...] 
| None: + """Accept only literal URL sources that can be fetched before execution.""" import sqlglot.expressions as exp if not args: @@ -112,6 +119,7 @@ def _read_function_source( def _read_function_options( option_exprs: Sequence[exp.Expression], ) -> dict[str, Any] | None: + """Decode literal keyword options from a DuckDB table-function call.""" options: dict[str, Any] = {} for option_expr in option_exprs: option = _read_function_option(option_expr) @@ -125,6 +133,7 @@ def _read_function_options( def _read_function_option( option_expr: exp.Expression, ) -> tuple[str, Any] | None: + """Return one static option or ``None`` for unsupported expressions.""" import sqlglot.expressions as exp property_eq = getattr(exp, "PropertyEQ", None) @@ -149,6 +158,7 @@ def _read_function_option( def _literal_value(value_expr: exp.Expression) -> Any: + """Convert sqlglot literals while preserving falsy values via _MISSING.""" import sqlglot.expressions as exp if isinstance(value_expr, exp.Boolean): From 32836708de767e68862486135b4fe4e64939ccc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Fri, 8 May 2026 17:12:29 +0200 Subject: [PATCH 07/22] fix: preserve DuckDB WASM remote reader semantics --- marimo/_runtime/_wasm/_duckdb/dataframe.py | 22 ++++++-- marimo/_runtime/_wasm/_duckdb/io.py | 18 ++++++- marimo/_runtime/_wasm/_duckdb/sources.py | 9 ++-- tests/_runtime/test_duckdb_wasm.py | 62 ++++++++++++++++++++++ 4 files changed, 101 insertions(+), 10 deletions(-) diff --git a/marimo/_runtime/_wasm/_duckdb/dataframe.py b/marimo/_runtime/_wasm/_duckdb/dataframe.py index e1aae2bc0d5..27ebe070a40 100644 --- a/marimo/_runtime/_wasm/_duckdb/dataframe.py +++ b/marimo/_runtime/_wasm/_duckdb/dataframe.py @@ -30,6 +30,13 @@ ".jsonl", ".json", ) +_JSON_OBJECT_FUNCTIONS = frozenset( + { + "read_json_objects", + "read_json_objects_auto", + "read_ndjson_objects", + } +) def read_csv_dataframe( @@ -76,9 +83,14 @@ def read_json_dataframe( def 
read_json_objects_dataframe( - data: bytes, options: Mapping[str, Any], *, url: str + data: bytes, options: Mapping[str, Any], *, url: str, function_name: str ) -> pd.DataFrame: """Read JSON-object bytes through DuckDB's SQL-only table function.""" + if function_name not in _JSON_OBJECT_FUNCTIONS: + raise ValueError( + f"Unsupported DuckDB JSON object reader: {function_name}" + ) + return _read_temp_dataframe( data, suffix=_temp_suffix( @@ -86,12 +98,14 @@ def read_json_objects_dataframe( suffixes=_JSON_SUFFIXES, default=".json", ), - reader=lambda path: _read_json_objects_path(path, options), + reader=lambda path: _read_json_objects_path( + path, options, function_name + ), ) def _read_json_objects_path( - path: str, options: Mapping[str, Any] + path: str, options: Mapping[str, Any], function_name: str ) -> pd.DataFrame: import duckdb @@ -101,7 +115,7 @@ def _read_json_objects_path( query_args = ["?"] query_args.extend(f"{key} := ?" for key, _ in option_items) return duckdb.sql( - f"SELECT * FROM read_json_objects_auto({', '.join(query_args)})", + f"SELECT * FROM {function_name}({', '.join(query_args)})", params=[path, *(value for _, value in option_items)], ).df() diff --git a/marimo/_runtime/_wasm/_duckdb/io.py b/marimo/_runtime/_wasm/_duckdb/io.py index f82cbcdbaed..b0cab2190ad 100644 --- a/marimo/_runtime/_wasm/_duckdb/io.py +++ b/marimo/_runtime/_wasm/_duckdb/io.py @@ -55,6 +55,7 @@ def fetch(self) -> _FetchedBytes: @dataclass(frozen=True) class _ReadRequest: file: _FetchedBytes + function_name: str options: Mapping[str, Any] @@ -219,6 +220,7 @@ def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: request.file.data, _json_reader_options(request.options), url=request.file.url, + function_name=request.function_name, ) @@ -272,6 +274,7 @@ class RemoteFileSource: files: tuple[RemoteFile, ...] reader_name: _ReaderName options: tuple[tuple[str, Any], ...] 
= () + function_name: str | None = None def read_options(self) -> dict[str, Any]: """Expose sorted, hashable options as regular reader kwargs.""" @@ -302,7 +305,13 @@ def _read_file_dataframe(self, file: RemoteFile) -> pd.DataFrame: fetched = file.fetch() options = self.read_options() reader = _reader_by_name(self.reader_name) - df = reader.read_dataframe(_ReadRequest(file=fetched, options=options)) + df = reader.read_dataframe( + _ReadRequest( + file=fetched, + function_name=self.function_name or reader.function_names[0], + options=options, + ) + ) filename = options.get("filename") if filename is True: return append_filename_column(df, fetched.url, "filename") @@ -336,7 +345,12 @@ def remote_file_source_from_reader_args( options = reader.read_options(function_name, raw_options) if options is None: return None - return RemoteFileSource(files, reader.name, tuple(sorted(options.items()))) + return RemoteFileSource( + files, + reader.name, + tuple(sorted(options.items())), + function_name=function_name, + ) def _remote_files_from_source_arg( diff --git a/marimo/_runtime/_wasm/_duckdb/sources.py b/marimo/_runtime/_wasm/_duckdb/sources.py index 1b76cfa10e5..6ebdd47c7a4 100644 --- a/marimo/_runtime/_wasm/_duckdb/sources.py +++ b/marimo/_runtime/_wasm/_duckdb/sources.py @@ -84,10 +84,11 @@ def _table_function_args(table_expr: exp.Expression) -> list[exp.Expression]: """Return table-function args despite sqlglot reader-node differences.""" import sqlglot.expressions as exp - read_csv = getattr(exp, "ReadCSV", None) - if read_csv is not None and isinstance(table_expr, read_csv): - first = [table_expr.this] if table_expr.this is not None else [] - return [*first, *table_expr.expressions] + for node_name in ("ReadCSV", "ReadParquet"): + read_node = getattr(exp, node_name, None) + if read_node is not None and isinstance(table_expr, read_node): + first = [table_expr.this] if table_expr.this is not None else [] + return [*first, *table_expr.expressions] return 
list(table_expr.expressions) diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py index e8ff9b158d1..9d0cff212fe 100644 --- a/tests/_runtime/test_duckdb_wasm.py +++ b/tests/_runtime/test_duckdb_wasm.py @@ -14,6 +14,9 @@ patch_duckdb_for_wasm, patch_duckdb_query_for_wasm, ) +from marimo._runtime._wasm._duckdb.sources import ( + remote_file_source_from_table, +) from marimo._sql.engines.duckdb import DuckDBEngine from tests.conftest import ExecReqProvider, mock_pyodide @@ -584,6 +587,65 @@ def test_noop_outside_pyodide() -> None: is None ) + @staticmethod + def test_read_parquet_node_preserves_this_argument() -> None: + from sqlglot import exp + + table = exp.Table( + this=exp.ReadParquet( + this=exp.Literal.string("https://example.com/a.parquet") + ) + ) + + source = remote_file_source_from_table(table) + + assert source is not None + assert source.reader_name == "parquet" + assert [file.url for file in source.files] == [ + "https://example.com/a.parquet" + ] + + @staticmethod + @pytest.mark.parametrize( + "function_name", + [ + "read_json_objects", + "read_json_objects_auto", + "read_ndjson_objects", + ], + ) + def test_json_objects_reader_preserves_requested_function( + monkeypatch: pytest.MonkeyPatch, function_name: str + ) -> None: + import duckdb + import pandas as pd + + from marimo._runtime._wasm._duckdb.dataframe import ( + read_json_objects_dataframe, + ) + + queries: list[str] = [] + + class Relation: + def df(self) -> pd.DataFrame: + return pd.DataFrame({"json": []}) + + def fake_sql(query: str, *, params: list[object]) -> Relation: + del params + queries.append(query) + return Relation() + + monkeypatch.setattr(duckdb, "sql", fake_sql) + + read_json_objects_dataframe( + b'{"a":1}\n', + {}, + url="https://example.com/a.json", + function_name=function_name, + ) + + assert queries == [f"SELECT * FROM {function_name}(?)"] + @staticmethod @pytest.mark.parametrize( "case", From a78037e24125d69a17bcff477621e11109267724 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Fri, 8 May 2026 17:12:37 +0200 Subject: [PATCH 08/22] fix: tighten DuckDB WASM SQL patch wrapper --- marimo/_runtime/_wasm/_duckdb/__init__.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py index 3513e2dc7ed..71161b239cd 100644 --- a/marimo/_runtime/_wasm/_duckdb/__init__.py +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -433,6 +433,8 @@ def _wrapper(*args: Any, **kwargs: Any) -> Any: if frame is None or frame.f_back is None: return original(*args, **kwargs) caller_frame = frame.f_back + eval_globals = caller_frame.f_globals + eval_locals = caller_frame.f_locals try: return run_duckdb_sql_with_wasm_patch( original, @@ -440,10 +442,11 @@ def _wrapper(*args: Any, **kwargs: Any) -> Any: kwargs, query_arg_index=query_arg_index, query_kwarg_names=query_kwarg_names, - eval_globals=caller_frame.f_globals, - eval_locals=caller_frame.f_locals, + eval_globals=eval_globals, + eval_locals=eval_locals, ) finally: + del caller_frame del frame return _wrapper @@ -688,6 +691,7 @@ def _require_sqlglot() -> None: def _parse_duckdb_query(query: str) -> list[exp.Expression] | None: import sqlglot + from sqlglot import exp as sqlglot_exp try: parsed = sqlglot.parse(query, read="duckdb") @@ -695,7 +699,11 @@ def _parse_duckdb_query(query: str) -> list[exp.Expression] | None: LOGGER.debug("Failed to parse DuckDB query for WASM patch: %s", e) return None - return [statement for statement in parsed if statement is not None] + return [ + statement + for statement in parsed + if isinstance(statement, sqlglot_exp.Expression) + ] def _replace_remote_sources( From a3ae6926cc11c78650f0c899de78033f6ebc8828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Sat, 9 May 2026 13:43:41 +0200 Subject: [PATCH 09/22] fix: handle nullable DuckDB SQL summary relations --- 
marimo/_data/sql_summaries.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/marimo/_data/sql_summaries.py b/marimo/_data/sql_summaries.py index d23ec26c2ec..159047c624c 100644 --- a/marimo/_data/sql_summaries.py +++ b/marimo/_data/sql_summaries.py @@ -65,9 +65,13 @@ def get_sql_stats( FROM {table_name} """ - stats_result: tuple[int, ...] | None = wrapped_sql( - stats_query, connection=None - ).fetchone() + relation = wrapped_sql(stats_query, connection=None) + if relation is None: + raise ValueError( + f"Column {column_name} not found in table {table_name}" + ) + + stats_result: tuple[int, ...] | None = relation.fetchone() if stats_result is None: raise ValueError( f"Column {column_name} not found in table {table_name}" From 31e3e192689c160e170a336b91e66893596e1bf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Sat, 9 May 2026 13:43:47 +0200 Subject: [PATCH 10/22] fix: skip DuckDB catalog lookup for local WASM SQL --- marimo/_runtime/_wasm/_duckdb/__init__.py | 21 ++++++++++++++++++++- tests/_runtime/test_duckdb_wasm.py | 19 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py index 71161b239cd..c486169ba51 100644 --- a/marimo/_runtime/_wasm/_duckdb/__init__.py +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -211,6 +211,7 @@ def _duckdb_identifier_key(name: str) -> str: def patch_duckdb_query_for_wasm( query: str, *, + statements: Sequence[exp.Expression] | None = None, reserved_names: Sequence[str] = (), ) -> WasmDuckDBQueryPatch | None: """Replace supported remote file reads with generated table names. 
@@ -231,7 +232,7 @@ def patch_duckdb_query_for_wasm( return None _require_sqlglot() - statements = _parse_duckdb_query(query) + statements = statements or _parse_duckdb_query(query) if statements is None: return None @@ -375,6 +376,10 @@ def try_run_duckdb_sql_with_wasm_patch( if not isinstance(query, str): return None + statements = _parse_duckdb_query(query) + if statements is None or not _contains_supported_remote_source(statements): + return None + namespace_names = _reserved_namespace_names( eval_globals, eval_locals, @@ -389,6 +394,7 @@ def try_run_duckdb_sql_with_wasm_patch( wasm_patch = patch_duckdb_query_for_wasm( query, + statements=statements, reserved_names=( *namespace_names, binding_names.original, @@ -706,6 +712,19 @@ def _parse_duckdb_query(query: str) -> list[exp.Expression] | None: ] +def _contains_supported_remote_source( + statements: Sequence[exp.Expression], +) -> bool: + """Check for rewrite work before paying for DuckDB catalog inspection.""" + from sqlglot import exp + + return any( + remote_file_source_from_table(table) is not None + for statement in statements + for table in statement.find_all(exp.Table) + ) + + def _replace_remote_sources( statements: Sequence[exp.Expression], table_names: _RemoteTableNames, diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py index 9d0cff212fe..c6a1b1eabac 100644 --- a/tests/_runtime/test_duckdb_wasm.py +++ b/tests/_runtime/test_duckdb_wasm.py @@ -995,6 +995,25 @@ def test_module_sql_preserves_caller_replacement_scan() -> None: assert rows == [(3,)] + @staticmethod + def test_module_sql_skips_catalog_lookup_without_remote_source() -> None: + import duckdb + + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._duckdb._duckdb_catalog_names" + ) as catalog_names, + ): + unpatch = patch_duckdb_for_wasm() + try: + rows = duckdb.sql("SELECT 1").fetchall() + finally: + unpatch() + + assert rows == [(1,)] + catalog_names.assert_not_called() + @staticmethod def 
test_module_execute_rewrites_before_side_effects() -> None: import duckdb From e4f4bfb96ba6cc8e5d2eae1ecb1e94293612d2c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Sat, 9 May 2026 14:44:41 +0200 Subject: [PATCH 11/22] fix: avoid stale DuckDB WASM view rewrites --- marimo/_runtime/_wasm/_duckdb/__init__.py | 18 ++++++++++++++ tests/_runtime/test_duckdb_wasm.py | 30 +++++++++++++++++++---- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py index c486169ba51..b3266ee3b77 100644 --- a/marimo/_runtime/_wasm/_duckdb/__init__.py +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -235,6 +235,8 @@ def patch_duckdb_query_for_wasm( statements = statements or _parse_duckdb_query(query) if statements is None: return None + if _contains_remote_view_definition(statements): + return None table_names = _RemoteTableNames( (*reserved_names, *_reserved_sql_names(statements)) @@ -379,6 +381,8 @@ def try_run_duckdb_sql_with_wasm_patch( statements = _parse_duckdb_query(query) if statements is None or not _contains_supported_remote_source(statements): return None + if _contains_remote_view_definition(statements): + return None namespace_names = _reserved_namespace_names( eval_globals, @@ -725,6 +729,20 @@ def _contains_supported_remote_source( ) +def _contains_remote_view_definition( + statements: Sequence[exp.Expression], +) -> bool: + """Views persist SQL text, so replacement-scan locals would go stale.""" + from sqlglot import exp + + return any( + isinstance(statement, exp.Create) + and str(statement.args.get("kind")).upper() == "VIEW" + and _contains_supported_remote_source((statement,)) + for statement in statements + ) + + def _replace_remote_sources( statements: Sequence[exp.Expression], table_names: _RemoteTableNames, diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py index c6a1b1eabac..5788dccd945 100644 --- 
a/tests/_runtime/test_duckdb_wasm.py +++ b/tests/_runtime/test_duckdb_wasm.py @@ -741,6 +741,25 @@ def test_avoids_sql_cte_table_names_case_insensitively() -> None: assert patch_result is not None assert "FROM __marimo_wasm_duckdb_remote_1" in patch_result.query + @staticmethod + def test_does_not_rewrite_create_view_remote_source() -> None: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\n", + ) as fetch_url_bytes, + ): + patch_result = patch_duckdb_query_for_wasm( + """ + CREATE OR REPLACE VIEW remote_cars AS + SELECT * FROM 'https://datasets.marimo.app/cars.csv' + """, + ) + + assert patch_result is None + fetch_url_bytes.assert_not_called() + class TestDuckDBWasmMoSqlIntegration: @staticmethod @@ -820,7 +839,7 @@ async def test_mo_sql_create_table_remote_literal_runs_once( ORDER BY make ''' ).fetchall() - duckdb.close() + duckdb.sql("DROP TABLE IF EXISTS __marimo_wasm_create_once") """ ), ] @@ -935,7 +954,6 @@ def test_module_query_df_rewrites_reader_call() -> None: rows = relation.fetchall() finally: unpatch() - duckdb.close() assert rows == [(7, 25)] fetch_url_bytes.assert_called_once_with( @@ -976,7 +994,9 @@ def test_module_query_df_avoids_existing_catalog_table_names() -> None: rows = relation.fetchall() finally: unpatch() - duckdb.close() + duckdb.sql( + 'DROP TABLE IF EXISTS "__MARIMO_WASM_DUCKDB_REMOTE_0"' + ) assert rows == [(7, 25)] @@ -1041,7 +1061,7 @@ def test_module_execute_rewrites_before_side_effects() -> None: ).fetchall() finally: unpatch() - duckdb.close() + duckdb.sql(f"DROP TABLE IF EXISTS {table_name}") assert rows == [("ford",)] @@ -1071,7 +1091,7 @@ def test_module_execute_creates_table_from_remote_literal() -> None: ).fetchall() finally: unpatch() - duckdb.close() + duckdb.sql(f"DROP TABLE IF EXISTS {table_name}") assert rows == [("ford", 25), ("toyota", 18)] fetch_url_bytes.assert_called_once_with( From a5be2291422bb152c22bfb54428efb4ca4b94f2a Mon Sep 17 00:00:00 
2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Sat, 9 May 2026 14:46:00 +0200 Subject: [PATCH 12/22] fix: narrow DuckDB WASM package preloading --- frontend/src/core/islands/worker/worker.tsx | 5 +-- .../src/core/wasm/__tests__/utils.test.ts | 34 +++++++++++++++++++ frontend/src/core/wasm/utils.ts | 27 +++++++++++++++ frontend/src/core/wasm/worker/bootstrap.ts | 5 +-- frontend/src/core/wasm/worker/worker.ts | 5 +-- 5 files changed, 70 insertions(+), 6 deletions(-) create mode 100644 frontend/src/core/wasm/__tests__/utils.test.ts diff --git a/frontend/src/core/islands/worker/worker.tsx b/frontend/src/core/islands/worker/worker.tsx index 4423f6c0c5b..5abf93786e9 100644 --- a/frontend/src/core/islands/worker/worker.tsx +++ b/frontend/src/core/islands/worker/worker.tsx @@ -9,6 +9,7 @@ import { } from "rpc-anywhere"; import type { NotificationPayload } from "@/core/kernel/messages"; import type { ParentSchema } from "@/core/wasm/rpc"; +import { shouldLoadDuckDBPackages } from "@/core/wasm/utils"; import { TRANSPORT_ID } from "@/core/wasm/worker/constants"; import { getPyodideVersion } from "@/core/wasm/worker/getPyodideVersion"; import { MessageBuffer } from "@/core/wasm/worker/message-buffer"; @@ -85,8 +86,8 @@ const requestHandler = createRPCRequestHandler({ loadPackages: async (code: string) => { await pyodideReadyPromise; // Make sure loading is done - if (code.includes("mo.sql") || code.includes("duckdb")) { - // Add pandas and duckdb to the code + if (shouldLoadDuckDBPackages(code)) { + // DuckDB SQL and remote readers need these packages loaded up front. code = `import pandas\n${code}`; code = `import duckdb\n${code}`; code = `import sqlglot\n${code}`; diff --git a/frontend/src/core/wasm/__tests__/utils.test.ts b/frontend/src/core/wasm/__tests__/utils.test.ts new file mode 100644 index 00000000000..d0534b00790 --- /dev/null +++ b/frontend/src/core/wasm/__tests__/utils.test.ts @@ -0,0 +1,34 @@ +/* Copyright 2026 Marimo. All rights reserved. 
*/ + +import { describe, expect, it } from "vitest"; +import { shouldLoadDuckDBPackages } from "../utils"; + +describe("shouldLoadDuckDBPackages", () => { + it("loads for mo.sql", () => { + expect(shouldLoadDuckDBPackages('df = mo.sql("SELECT 1")')).toBe(true); + }); + + it("loads for detected duckdb imports and usage", () => { + expect(shouldLoadDuckDBPackages("import duckdb")).toBe(true); + expect(shouldLoadDuckDBPackages("from duckdb import sql")).toBe(true); + expect(shouldLoadDuckDBPackages("rows = duckdb.sql('SELECT 1')")).toBe( + true, + ); + }); + + it("loads when package discovery found duckdb", () => { + expect( + shouldLoadDuckDBPackages("print('hello')", new Set(["duckdb"])), + ).toBe(true); + }); + + it("ignores incidental duckdb text", () => { + expect(shouldLoadDuckDBPackages("# duckdb is mentioned here")).toBe(false); + expect(shouldLoadDuckDBPackages("# duckdb.sql is mentioned here")).toBe( + false, + ); + expect(shouldLoadDuckDBPackages("name = 'duckdb'")).toBe(false); + expect(shouldLoadDuckDBPackages("name = 'duckdb.sql'")).toBe(false); + expect(shouldLoadDuckDBPackages('name = "duckdb.sql"')).toBe(false); + }); +}); diff --git a/frontend/src/core/wasm/utils.ts b/frontend/src/core/wasm/utils.ts index 2801b9d3e3f..8dbd33c0598 100644 --- a/frontend/src/core/wasm/utils.ts +++ b/frontend/src/core/wasm/utils.ts @@ -10,3 +10,30 @@ export function isWasm(): boolean { document.querySelector("marimo-wasm") !== null ); } + +const DUCKDB_IMPORT = /^(import\s+duckdb\b|from\s+duckdb\s+import\b)/; +const DUCKDB_USAGE = /(^|[^"'#])\bduckdb\s*\./; + +function hasDuckDBImportOrUsage(code: string): boolean { + return code.split("\n").some((line) => { + const trimmed = line.trimStart(); + if (trimmed.startsWith("#")) { + return false; + } + if (DUCKDB_IMPORT.test(trimmed)) { + return true; + } + return DUCKDB_USAGE.test(line.split("#")[0]); + }); +} + +export function shouldLoadDuckDBPackages( + code: string, + foundPackages?: ReadonlySet, +): boolean { + return ( 
+ code.includes("mo.sql") || + foundPackages?.has("duckdb") === true || + hasDuckDBImportOrUsage(code) + ); +} diff --git a/frontend/src/core/wasm/worker/bootstrap.ts b/frontend/src/core/wasm/worker/bootstrap.ts index d32b34563c7..1d9c0e28376 100644 --- a/frontend/src/core/wasm/worker/bootstrap.ts +++ b/frontend/src/core/wasm/worker/bootstrap.ts @@ -9,6 +9,7 @@ import { WasmFileSystem } from "./fs"; import { getMarimoWheel } from "./getMarimoWheel"; import { t } from "./tracer"; import type { SerializedBridge, WasmController } from "./types"; +import { shouldLoadDuckDBPackages } from "../utils"; const MAKE_SNAPSHOT = false; @@ -163,8 +164,8 @@ export class DefaultWasmController implements WasmController { private async loadNotebookDeps(code: string, foundPackages: Set) { const pyodide = this.requirePyodide; - if (code.includes("mo.sql") || code.includes("duckdb")) { - // We need pandas and duckdb for mo.sql + if (shouldLoadDuckDBPackages(code, foundPackages)) { + // DuckDB SQL and remote readers need these packages loaded up front. code = `import pandas\n${code}`; code = `import duckdb\n${code}`; code = `import sqlglot\n${code}`; diff --git a/frontend/src/core/wasm/worker/worker.ts b/frontend/src/core/wasm/worker/worker.ts index 6e2c9a98fbd..bbd57d9850c 100644 --- a/frontend/src/core/wasm/worker/worker.ts +++ b/frontend/src/core/wasm/worker/worker.ts @@ -34,6 +34,7 @@ import type { SerializedBridge, WasmController, } from "./types"; +import { shouldLoadDuckDBPackages } from "../utils"; /** * Web worker responsible for running the notebook. @@ -141,8 +142,8 @@ const requestHandler = createRPCRequestHandler({ const span = t.startSpan("loadPackages"); await pyodideReadyPromise; // Make sure loading is done - if (code.includes("mo.sql") || code.includes("duckdb")) { - // Add pandas and duckdb to the code + if (shouldLoadDuckDBPackages(code)) { + // DuckDB SQL and remote readers need these packages loaded up front. 
code = `import pandas\n${code}`; code = `import duckdb\n${code}`; code = `import sqlglot\n${code}`; From 2e77fbf24fcd00c864ec7af5ab8e1a574601c2e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Sat, 9 May 2026 15:18:06 +0200 Subject: [PATCH 13/22] refactor: keep DuckDB SQL wrapper contract non-optional --- marimo/_data/sql_summaries.py | 10 +++------- marimo/_sql/utils.py | 30 +++++++++++++++--------------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/marimo/_data/sql_summaries.py b/marimo/_data/sql_summaries.py index 159047c624c..d23ec26c2ec 100644 --- a/marimo/_data/sql_summaries.py +++ b/marimo/_data/sql_summaries.py @@ -65,13 +65,9 @@ def get_sql_stats( FROM {table_name} """ - relation = wrapped_sql(stats_query, connection=None) - if relation is None: - raise ValueError( - f"Column {column_name} not found in table {table_name}" - ) - - stats_result: tuple[int, ...] | None = relation.fetchone() + stats_result: tuple[int, ...] | None = wrapped_sql( + stats_query, connection=None + ).fetchone() if stats_result is None: raise ValueError( f"Column {column_name} not found in table {table_name}" diff --git a/marimo/_sql/utils.py b/marimo/_sql/utils.py index 4e5f79a587e..2cf7880acd8 100644 --- a/marimo/_sql/utils.py +++ b/marimo/_sql/utils.py @@ -9,7 +9,6 @@ from marimo._data.models import DataType from marimo._dependencies.dependencies import DependencyManager from marimo._runtime._wasm._duckdb import ( - WasmDuckDBSqlResult, try_run_duckdb_sql_with_wasm_patch, ) from marimo._runtime.context.types import ( @@ -28,6 +27,8 @@ LOGGER = _loggers.marimo_logger() CHEAP_DISCOVERY_DATABASES = ["duckdb", "sqlite", "mysql", "postgresql"] +# DuckDB SQL can return None for DDL, so keep "patch did not apply" distinct. +_NO_WASM_DUCKDB_RESULT = object() def _try_wasm_duckdb_sql( @@ -36,7 +37,7 @@ def _try_wasm_duckdb_sql( glbls: dict[str, Any], *, reserved_names: tuple[str, ...] 
= (), -) -> WasmDuckDBSqlResult | None: +) -> object: import duckdb if connection is duckdb: @@ -58,13 +59,13 @@ def _try_wasm_duckdb_sql( eval_locals=glbls, reserved_names=reserved_names, ) - return result + return _NO_WASM_DUCKDB_RESULT if result is None else result.value def wrapped_sql( query: str, connection: duckdb.DuckDBPyConnection | None, -) -> duckdb.DuckDBPyRelation | None: +) -> duckdb.DuckDBPyRelation: DependencyManager.duckdb.require("to execute sql") # In Python globals() are scoped to modules; since this function @@ -81,13 +82,12 @@ def wrapped_sql( try: ctx = get_context() except ContextNotInitializedError: - relation: duckdb.DuckDBPyRelation | None result = _try_wasm_duckdb_sql(query, connection, globals()) - if result is None: + if result is _NO_WASM_DUCKDB_RESULT: # No WASM rewrite was needed; use DuckDB's normal SQL path. relation = connection.sql(query=query) else: - relation = cast(duckdb.DuckDBPyRelation | None, result.value) + relation = cast(duckdb.DuckDBPyRelation, result) else: install_connection = ( ctx.execution_context.with_connection @@ -101,7 +101,7 @@ def wrapped_sql( ctx.globals, reserved_names=tuple(ctx.globals), ) - if result is None: + if result is _NO_WASM_DUCKDB_RESULT: # Run in kernel globals so DuckDB replacement scans see user data. relation = eval( "connection.sql(query=query)", @@ -109,7 +109,7 @@ def wrapped_sql( {"query": query, "connection": connection}, ) else: - relation = cast(duckdb.DuckDBPyRelation | None, result.value) + relation = cast(duckdb.DuckDBPyRelation, result) return relation @@ -121,7 +121,7 @@ def _try_wasm_duckdb_execute( glbls: dict[str, Any], *, reserved_names: tuple[str, ...] 
= (), -) -> WasmDuckDBSqlResult | None: +) -> object: import duckdb if connection is duckdb: @@ -143,7 +143,7 @@ def _try_wasm_duckdb_execute( eval_locals=glbls, reserved_names=reserved_names, ) - return result + return _NO_WASM_DUCKDB_RESULT if result is None else result.value def execute_duckdb_sql( @@ -167,10 +167,10 @@ def execute_duckdb_sql( ctx = get_context() except ContextNotInitializedError: result = _try_wasm_duckdb_execute(query, params, connection, globals()) - if result is None: + if result is _NO_WASM_DUCKDB_RESULT: # No WASM rewrite was needed; preserve DuckDB's parameterized path. return connection.execute(query, params) - return cast(duckdb.DuckDBPyConnection, result.value) + return cast(duckdb.DuckDBPyConnection, result) else: install_connection = ( ctx.execution_context.with_connection @@ -185,7 +185,7 @@ def execute_duckdb_sql( ctx.globals, reserved_names=tuple(ctx.globals), ) - if result is None: + if result is _NO_WASM_DUCKDB_RESULT: # Run in kernel globals so parameterized SQL can scan user data. 
value = eval( "connection.execute(query, params)", @@ -197,7 +197,7 @@ def execute_duckdb_sql( }, ) else: - value = result.value + value = result return cast(duckdb.DuckDBPyConnection, value) From 32dc803a897541146cbe0a083e3a9aff87ec3bd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Sat, 9 May 2026 15:18:28 +0200 Subject: [PATCH 14/22] test: clean up views --- tests/_runtime/test_duckdb_wasm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py index 5788dccd945..e4e4b209a8c 100644 --- a/tests/_runtime/test_duckdb_wasm.py +++ b/tests/_runtime/test_duckdb_wasm.py @@ -954,6 +954,7 @@ def test_module_query_df_rewrites_reader_call() -> None: rows = relation.fetchall() finally: unpatch() + duckdb.sql("DROP VIEW IF EXISTS query_df_local") assert rows == [(7, 25)] fetch_url_bytes.assert_called_once_with( @@ -994,6 +995,7 @@ def test_module_query_df_avoids_existing_catalog_table_names() -> None: rows = relation.fetchall() finally: unpatch() + duckdb.sql("DROP VIEW IF EXISTS query_df_local_collision") duckdb.sql( 'DROP TABLE IF EXISTS "__MARIMO_WASM_DUCKDB_REMOTE_0"' ) From 6fe3e01b7898e29675a4be8510a62b268a14673a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Sat, 9 May 2026 16:45:55 +0200 Subject: [PATCH 15/22] fix: preserve DuckDB catalog semantics in WASM rewrites --- marimo/_runtime/_wasm/_duckdb/__init__.py | 35 +++++++--- marimo/_runtime/_wasm/_duckdb/sources.py | 23 ++++++- tests/_runtime/test_duckdb_wasm.py | 82 +++++++++++++++++++++++ 3 files changed, 130 insertions(+), 10 deletions(-) diff --git a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py index b3266ee3b77..3e9bcaf2027 100644 --- a/marimo/_runtime/_wasm/_duckdb/__init__.py +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -235,13 +235,15 @@ def patch_duckdb_query_for_wasm( statements = statements or _parse_duckdb_query(query) if 
statements is None: return None - if _contains_remote_view_definition(statements): + if _contains_remote_view_definition(statements, query=query): return None table_names = _RemoteTableNames( (*reserved_names, *_reserved_sql_names(statements)) ) - patched_statements = _replace_remote_sources(statements, table_names) + patched_statements = _replace_remote_sources( + statements, table_names, query=query + ) if not table_names: return None @@ -379,9 +381,11 @@ def try_run_duckdb_sql_with_wasm_patch( return None statements = _parse_duckdb_query(query) - if statements is None or not _contains_supported_remote_source(statements): + if statements is None or not _contains_supported_remote_source( + statements, query=query + ): return None - if _contains_remote_view_definition(statements): + if _contains_remote_view_definition(statements, query=query): return None namespace_names = _reserved_namespace_names( @@ -389,7 +393,7 @@ def try_run_duckdb_sql_with_wasm_patch( eval_locals, ( *reserved_names, - *_duckdb_catalog_names(original, args), + *_duckdb_catalog_names(original, args, kwargs_dict), *_identifier_string_args(args), *_identifier_string_args(tuple(kwargs_dict.values())), ), @@ -573,10 +577,11 @@ def _unused_name(base: str, used: set[str]) -> str: def _duckdb_catalog_names( original: Callable[..., Any], args: tuple[Any, ...], + kwargs: Mapping[str, Any], ) -> tuple[str, ...]: """Reserve existing DuckDB table names before generating replacements.""" try: - relation = _show_duckdb_tables(original, args) + relation = _show_duckdb_tables(original, args, kwargs) rows = relation.fetchall() except Exception: return () @@ -586,6 +591,7 @@ def _duckdb_catalog_names( def _show_duckdb_tables( original: Callable[..., Any], args: tuple[Any, ...], + kwargs: Mapping[str, Any], ) -> Any: """Run ``SHOW TABLES`` through the same DuckDB entry point being patched.""" import duckdb @@ -593,6 +599,9 @@ def _show_duckdb_tables( original_call = inspect.unwrap(original) if args and 
isinstance(args[0], duckdb.DuckDBPyConnection): return inspect.unwrap(type(args[0]).sql)(args[0], "SHOW TABLES") + connection = kwargs.get("connection") + if isinstance(connection, duckdb.DuckDBPyConnection): + return inspect.unwrap(type(connection).sql)(connection, "SHOW TABLES") if original_call is inspect.unwrap(duckdb.query_df): return inspect.unwrap(duckdb.sql)("SHOW TABLES") return original_call("SHOW TABLES") @@ -691,12 +700,14 @@ def _require_pandas() -> None: DependencyManager.pandas.require( "to read remote DuckDB file sources in WASM" ) + import pandas # noqa: F401 def _require_sqlglot() -> None: DependencyManager.sqlglot.require( "to rewrite remote DuckDB SQL sources in WASM" ) + import sqlglot # noqa: F401 def _parse_duckdb_query(query: str) -> list[exp.Expression] | None: @@ -718,12 +729,14 @@ def _parse_duckdb_query(query: str) -> list[exp.Expression] | None: def _contains_supported_remote_source( statements: Sequence[exp.Expression], + *, + query: str, ) -> bool: """Check for rewrite work before paying for DuckDB catalog inspection.""" from sqlglot import exp return any( - remote_file_source_from_table(table) is not None + remote_file_source_from_table(table, query=query) is not None for statement in statements for table in statement.find_all(exp.Table) ) @@ -731,6 +744,8 @@ def _contains_supported_remote_source( def _contains_remote_view_definition( statements: Sequence[exp.Expression], + *, + query: str, ) -> bool: """Views persist SQL text, so replacement-scan locals would go stale.""" from sqlglot import exp @@ -738,7 +753,7 @@ def _contains_remote_view_definition( return any( isinstance(statement, exp.Create) and str(statement.args.get("kind")).upper() == "VIEW" - and _contains_supported_remote_source((statement,)) + and _contains_supported_remote_source((statement,), query=query) for statement in statements ) @@ -746,6 +761,8 @@ def _contains_remote_view_definition( def _replace_remote_sources( statements: Sequence[exp.Expression], 
table_names: _RemoteTableNames, + *, + query: str, ) -> list[exp.Expression]: """Replace supported remote table nodes while preserving aliases.""" from sqlglot import exp @@ -754,7 +771,7 @@ def replace_table(node: exp.Expression) -> exp.Expression: if not isinstance(node, exp.Table): return node - source = remote_file_source_from_table(node) + source = remote_file_source_from_table(node, query=query) if source is None: return node diff --git a/marimo/_runtime/_wasm/_duckdb/sources.py b/marimo/_runtime/_wasm/_duckdb/sources.py index 6ebdd47c7a4..612e4a84056 100644 --- a/marimo/_runtime/_wasm/_duckdb/sources.py +++ b/marimo/_runtime/_wasm/_duckdb/sources.py @@ -31,10 +31,12 @@ def remote_file_source_from_table( table: exp.Table, + *, + query: str | None = None, ) -> RemoteFileSource | None: """Return a remote source for supported direct URLs or reader calls.""" table_name = table.name - if table_name: + if table_name and _is_single_quoted_table_identifier(table, query): reader = reader_for_url(table_name) remote_file = remote_file_from_url(table_name) if reader is not None and remote_file is not None: @@ -61,6 +63,25 @@ def remote_file_source_from_table( ) +def _is_single_quoted_table_identifier( + table: exp.Table, query: str | None +) -> bool: + """Distinguish DuckDB file scans from ordinary quoted identifiers.""" + if query is None: + return False + meta = getattr(table.this, "meta", {}) + start = meta.get("start") + end = meta.get("end") + if not isinstance(start, int) or not isinstance(end, int): + return False + return ( + 0 <= start + and end < len(query) + and query[start] == "'" + and query[end] == "'" + ) + + def _table_function_name(table_expr: exp.Expression) -> str | None: """Normalize sqlglot's version-dependent DuckDB reader node names.""" import sqlglot.expressions as exp diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py index e4e4b209a8c..e723c9b7efe 100644 --- a/tests/_runtime/test_duckdb_wasm.py +++ 
b/tests/_runtime/test_duckdb_wasm.py @@ -1238,6 +1238,88 @@ def test_connection_sql_avoids_existing_catalog_table_names() -> None: "https://datasets.marimo.app/cars.csv" ) + @staticmethod + @pytest.mark.parametrize("api_kind", ["sql", "query", "execute"]) + def test_module_methods_with_connection_avoid_existing_catalog_table_names( + api_kind: str, + ) -> None: + import duckdb + + connection = duckdb.connect(":memory:") + result_table = "__marimo_duckdb_wasm_module_conn_result" + try: + connection.sql( + """ + CREATE OR REPLACE TABLE "__MARIMO_WASM_DUCKDB_REMOTE_0" + AS SELECT 99 AS mpg + """ + ) + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"mpg\n25\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + if api_kind == "execute": + duckdb.execute( + f""" + CREATE OR REPLACE TABLE {result_table} AS + SELECT mpg + FROM 'https://datasets.marimo.app/cars.csv' + """, + connection=connection, + ) + rows = connection.sql( + f"SELECT mpg FROM {result_table}" + ).fetchall() + else: + rows = getattr(duckdb, api_kind)( + """ + SELECT mpg + FROM 'https://datasets.marimo.app/cars.csv' + """, + connection=connection, + ).fetchall() + finally: + unpatch() + finally: + connection.close() + + assert rows == [(25,)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_double_quoted_url_table_identifier_is_not_remote_source() -> None: + import duckdb + + table_name = '"https://datasets.marimo.app/cars.csv"' + try: + duckdb.sql( + f"CREATE OR REPLACE TABLE {table_name} AS SELECT 42 AS x" + ) + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"x\n7\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + rows = duckdb.sql(f"SELECT x FROM {table_name}").fetchall() + finally: + unpatch() + finally: + duckdb.sql(f"DROP TABLE IF EXISTS {table_name}") + + assert rows == [(42,)] + 
fetch_url_bytes.assert_not_called() + @staticmethod @mock_pyodide() def test_unpatch_restores_module_functions_and_connection_methods() -> ( From 1665f9984d3423116b07c2587c94541dac3aa700 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Sat, 9 May 2026 16:45:58 +0200 Subject: [PATCH 16/22] fix: broaden DuckDB WASM package preloading --- .../src/core/wasm/__tests__/utils.test.ts | 14 +++++-------- frontend/src/core/wasm/utils.ts | 20 ++----------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/frontend/src/core/wasm/__tests__/utils.test.ts b/frontend/src/core/wasm/__tests__/utils.test.ts index d0534b00790..b9d4829ef69 100644 --- a/frontend/src/core/wasm/__tests__/utils.test.ts +++ b/frontend/src/core/wasm/__tests__/utils.test.ts @@ -8,12 +8,14 @@ describe("shouldLoadDuckDBPackages", () => { expect(shouldLoadDuckDBPackages('df = mo.sql("SELECT 1")')).toBe(true); }); - it("loads for detected duckdb imports and usage", () => { + it("loads for duckdb text", () => { expect(shouldLoadDuckDBPackages("import duckdb")).toBe(true); expect(shouldLoadDuckDBPackages("from duckdb import sql")).toBe(true); + expect(shouldLoadDuckDBPackages("import pandas, duckdb")).toBe(true); expect(shouldLoadDuckDBPackages("rows = duckdb.sql('SELECT 1')")).toBe( true, ); + expect(shouldLoadDuckDBPackages("name = 'duckdb'")).toBe(true); }); it("loads when package discovery found duckdb", () => { @@ -22,13 +24,7 @@ describe("shouldLoadDuckDBPackages", () => { ).toBe(true); }); - it("ignores incidental duckdb text", () => { - expect(shouldLoadDuckDBPackages("# duckdb is mentioned here")).toBe(false); - expect(shouldLoadDuckDBPackages("# duckdb.sql is mentioned here")).toBe( - false, - ); - expect(shouldLoadDuckDBPackages("name = 'duckdb'")).toBe(false); - expect(shouldLoadDuckDBPackages("name = 'duckdb.sql'")).toBe(false); - expect(shouldLoadDuckDBPackages('name = "duckdb.sql"')).toBe(false); + it("does not load without mo.sql, duckdb text, or 
discovery", () => { + expect(shouldLoadDuckDBPackages("print('hello')")).toBe(false); }); }); diff --git a/frontend/src/core/wasm/utils.ts b/frontend/src/core/wasm/utils.ts index 8dbd33c0598..a5e177e3f48 100644 --- a/frontend/src/core/wasm/utils.ts +++ b/frontend/src/core/wasm/utils.ts @@ -11,29 +11,13 @@ export function isWasm(): boolean { ); } -const DUCKDB_IMPORT = /^(import\s+duckdb\b|from\s+duckdb\s+import\b)/; -const DUCKDB_USAGE = /(^|[^"'#])\bduckdb\s*\./; - -function hasDuckDBImportOrUsage(code: string): boolean { - return code.split("\n").some((line) => { - const trimmed = line.trimStart(); - if (trimmed.startsWith("#")) { - return false; - } - if (DUCKDB_IMPORT.test(trimmed)) { - return true; - } - return DUCKDB_USAGE.test(line.split("#")[0]); - }); -} - export function shouldLoadDuckDBPackages( code: string, foundPackages?: ReadonlySet, ): boolean { return ( code.includes("mo.sql") || - foundPackages?.has("duckdb") === true || - hasDuckDBImportOrUsage(code) + code.includes("duckdb") || + foundPackages?.has("duckdb") === true ); } From 942e191f4ed3040eaa0217f6f178c04bb9435c5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Tue, 12 May 2026 15:05:07 +0200 Subject: [PATCH 17/22] fix: address DuckDB WASM review comments --- .../src/core/wasm/__tests__/utils.test.ts | 10 +- frontend/src/core/wasm/utils.ts | 5 +- marimo/_runtime/_wasm/_duckdb/__init__.py | 22 ++-- marimo/_sql/utils.py | 6 +- tests/_runtime/test_duckdb_wasm.py | 116 ++++++++++++++++++ 5 files changed, 144 insertions(+), 15 deletions(-) diff --git a/frontend/src/core/wasm/__tests__/utils.test.ts b/frontend/src/core/wasm/__tests__/utils.test.ts index b9d4829ef69..604463b24ba 100644 --- a/frontend/src/core/wasm/__tests__/utils.test.ts +++ b/frontend/src/core/wasm/__tests__/utils.test.ts @@ -8,14 +8,13 @@ describe("shouldLoadDuckDBPackages", () => { expect(shouldLoadDuckDBPackages('df = mo.sql("SELECT 1")')).toBe(true); }); - it("loads for duckdb text", () => { + 
it("loads for duckdb imports and usage", () => { expect(shouldLoadDuckDBPackages("import duckdb")).toBe(true); expect(shouldLoadDuckDBPackages("from duckdb import sql")).toBe(true); expect(shouldLoadDuckDBPackages("import pandas, duckdb")).toBe(true); expect(shouldLoadDuckDBPackages("rows = duckdb.sql('SELECT 1')")).toBe( true, ); - expect(shouldLoadDuckDBPackages("name = 'duckdb'")).toBe(true); }); it("loads when package discovery found duckdb", () => { @@ -24,7 +23,12 @@ describe("shouldLoadDuckDBPackages", () => { ).toBe(true); }); - it("does not load without mo.sql, duckdb text, or discovery", () => { + it("does not load for incidental duckdb text", () => { + expect(shouldLoadDuckDBPackages("name = 'duckdb'")).toBe(false); + expect(shouldLoadDuckDBPackages("# import duckdb")).toBe(false); + }); + + it("does not load without mo.sql, duckdb usage, or discovery", () => { expect(shouldLoadDuckDBPackages("print('hello')")).toBe(false); }); }); diff --git a/frontend/src/core/wasm/utils.ts b/frontend/src/core/wasm/utils.ts index a5e177e3f48..b8f371477c5 100644 --- a/frontend/src/core/wasm/utils.ts +++ b/frontend/src/core/wasm/utils.ts @@ -11,13 +11,16 @@ export function isWasm(): boolean { ); } +const DUCKDB_USAGE_PATTERN = + /(^|\n)\s*(?:import\s+[^\n#]*\bduckdb\b|from\s+duckdb\b|[^\n#]*\bduckdb\s*\.)/; + export function shouldLoadDuckDBPackages( code: string, foundPackages?: ReadonlySet, ): boolean { return ( code.includes("mo.sql") || - code.includes("duckdb") || + DUCKDB_USAGE_PATTERN.test(code) || foundPackages?.has("duckdb") === true ); } diff --git a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py index 3e9bcaf2027..1d1540692d5 100644 --- a/marimo/_runtime/_wasm/_duckdb/__init__.py +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -222,17 +222,22 @@ def patch_duckdb_query_for_wasm( DataFrame. If the query or ``reserved_names`` already use that identifier, the rewriter uses the next free suffix. 
- In Pyodide this raises if sqlglot is unavailable. Returns ``None`` when: + In Pyodide this raises if sqlglot is unavailable and the query may contain + a remote source. Returns ``None`` when: - marimo is not running in Pyodide; + - the query has no remote URL; - the query has no supported remote file source; - the query cannot be parsed. """ if not is_pyodide(): return None - _require_sqlglot() - statements = statements or _parse_duckdb_query(query) + if statements is None: + if not _query_may_contain_remote_file_source(query): + return None + _require_sqlglot() + statements = _parse_duckdb_query(query) if statements is None: return None if _contains_remote_view_definition(statements, query=query): @@ -265,10 +270,6 @@ def patch_duckdb_for_wasm() -> Unpatch: except ImportError: return lambda: None - # DuckDB's WASM patch set is atomic: direct reader and SQL API wrappers - # install together, and SQL rewriting hard-requires sqlglot. - _require_sqlglot() - patches = WasmPatchSet() for function_name in _DIRECT_READER_SPECS: patches.replace( @@ -379,7 +380,10 @@ def try_run_duckdb_sql_with_wasm_patch( ) if not isinstance(query, str): return None + if not _query_may_contain_remote_file_source(query): + return None + _require_sqlglot() statements = _parse_duckdb_query(query) if statements is None or not _contains_supported_remote_source( statements, query=query @@ -710,6 +714,10 @@ def _require_sqlglot() -> None: import sqlglot # noqa: F401 +def _query_may_contain_remote_file_source(query: str) -> bool: + return "https://" in query or "http://" in query + + def _parse_duckdb_query(query: str) -> list[exp.Expression] | None: import sqlglot from sqlglot import exp as sqlglot_exp diff --git a/marimo/_sql/utils.py b/marimo/_sql/utils.py index 2cf7880acd8..f8ebfd02947 100644 --- a/marimo/_sql/utils.py +++ b/marimo/_sql/utils.py @@ -67,6 +67,7 @@ def wrapped_sql( connection: duckdb.DuckDBPyConnection | None, ) -> duckdb.DuckDBPyRelation: DependencyManager.duckdb.require("to 
execute sql") + import duckdb # In Python globals() are scoped to modules; since this function # is in a different module than user code, globals() doesn't return @@ -75,8 +76,6 @@ def wrapped_sql( # However, duckdb needs access to the kernel's globals. For this reason, # we manually exec duckdb and provide it with the kernel's globals. if connection is None: - import duckdb - connection = cast(duckdb.DuckDBPyConnection, duckdb) try: @@ -157,10 +156,9 @@ def execute_duckdb_sql( parameterized queries ($1, $2, ...) for safe value interpolation. """ DependencyManager.duckdb.require("to execute sql") + import duckdb if connection is None: - import duckdb - connection = cast(duckdb.DuckDBPyConnection, duckdb) try: diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py index e723c9b7efe..42b47b89b15 100644 --- a/tests/_runtime/test_duckdb_wasm.py +++ b/tests/_runtime/test_duckdb_wasm.py @@ -527,6 +527,26 @@ def test_noop_outside_pyodide() -> None: unpatch() assert duckdb.read_csv is original + @staticmethod + def test_patch_installation_does_not_require_sqlglot() -> None: + import duckdb + + original = duckdb.read_csv + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._duckdb._require_sqlglot", + side_effect=AssertionError("sqlglot should be lazy"), + ), + ): + unpatch = patch_duckdb_for_wasm() + try: + assert duckdb.read_csv is not original + finally: + unpatch() + + assert duckdb.read_csv is original + @staticmethod @pytest.mark.parametrize("api_kind", ["module", "connection"]) @pytest.mark.parametrize( @@ -855,6 +875,81 @@ async def test_mo_sql_create_table_remote_literal_runs_once( ) +class TestDuckDBWasmSqlUtils: + @staticmethod + def test_wrapped_sql_rewrites_remote_literal_with_explicit_connection() -> ( + None + ): + import duckdb + + from marimo._sql.utils import wrapped_sql + + connection = duckdb.connect(":memory:") + try: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + 
return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + ): + relation = wrapped_sql( + """ + SELECT make + FROM 'https://datasets.marimo.app/cars.csv' + WHERE mpg > 20 + """, + connection, + ) + rows = relation.fetchall() + finally: + connection.close() + + assert rows == [("ford",)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_execute_duckdb_sql_rewrites_remote_literal_with_explicit_connection() -> ( + None + ): + import duckdb + + from marimo._sql.utils import execute_duckdb_sql + + table_name = "__marimo_duckdb_wasm_sql_utils_execute_test" + connection = duckdb.connect(":memory:") + try: + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + ): + execute_duckdb_sql( + f""" + CREATE OR REPLACE TABLE {table_name} AS + SELECT make + FROM 'https://datasets.marimo.app/cars.csv' + WHERE mpg > ? + """, + [20], + connection, + ) + rows = connection.sql( + f"SELECT make FROM {table_name} ORDER BY make" + ).fetchall() + finally: + connection.close() + + assert rows == [("ford",)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + class TestDuckDBWasmSqlApiPatch: @staticmethod def test_noop_outside_pyodide() -> None: @@ -1036,6 +1131,27 @@ def test_module_sql_skips_catalog_lookup_without_remote_source() -> None: assert rows == [(1,)] catalog_names.assert_not_called() + @staticmethod + def test_module_sql_without_remote_source_does_not_require_sqlglot() -> ( + None + ): + import duckdb + + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._duckdb._require_sqlglot", + side_effect=AssertionError("sqlglot should be lazy"), + ), + ): + unpatch = patch_duckdb_for_wasm() + try: + rows = duckdb.sql("SELECT 1").fetchall() + finally: + unpatch() + + assert rows == [(1,)] + @staticmethod def test_module_execute_rewrites_before_side_effects() 
-> None: import duckdb From 7224e5a29340835bccc64e431f0d0850076084d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Tue, 12 May 2026 15:26:24 +0200 Subject: [PATCH 18/22] fix: support older sqlglot URL scan metadata --- marimo/_runtime/_wasm/_duckdb/sources.py | 19 +++++++- tests/_runtime/test_duckdb_wasm.py | 56 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/marimo/_runtime/_wasm/_duckdb/sources.py b/marimo/_runtime/_wasm/_duckdb/sources.py index 612e4a84056..ecbc36b59f1 100644 --- a/marimo/_runtime/_wasm/_duckdb/sources.py +++ b/marimo/_runtime/_wasm/_duckdb/sources.py @@ -10,6 +10,7 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, Any from marimo._runtime._wasm._duckdb.io import ( @@ -73,7 +74,7 @@ def _is_single_quoted_table_identifier( start = meta.get("start") end = meta.get("end") if not isinstance(start, int) or not isinstance(end, int): - return False + return _query_has_single_quoted_table_reference(query, table.name) return ( 0 <= start and end < len(query) @@ -82,6 +83,22 @@ def _is_single_quoted_table_identifier( ) +def _query_has_single_quoted_table_reference( + query: str, table_name: str +) -> bool: + if not table_name: + return False + + quoted_table = re.escape(f"'{table_name}'") + return any( + re.search(pattern, query, flags=re.IGNORECASE) is not None + for pattern in ( + rf"(?:^|[()\s])FROM\s+{quoted_table}", + rf"\bJOIN\s+{quoted_table}", + ) + ) + + def _table_function_name(table_expr: exp.Expression) -> str | None: """Normalize sqlglot's version-dependent DuckDB reader node names.""" import sqlglot.expressions as exp diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py index 42b47b89b15..213c84f9077 100644 --- a/tests/_runtime/test_duckdb_wasm.py +++ b/tests/_runtime/test_duckdb_wasm.py @@ -625,6 +625,62 @@ def test_read_parquet_node_preserves_this_argument() -> None: "https://example.com/a.parquet" ] + 
@staticmethod + def test_direct_literal_without_token_metadata_is_remote_source() -> None: + from sqlglot import exp + + table = exp.Table( + this=exp.Identifier(this="https://example.com/a.csv", quoted=True) + ) + + source = remote_file_source_from_table( + table, + query="SELECT * FROM 'https://example.com/a.csv'", + ) + + assert source is not None + assert source.reader_name == "csv" + assert [file.url for file in source.files] == [ + "https://example.com/a.csv" + ] + + @staticmethod + def test_double_quoted_identifier_without_token_metadata_is_not_remote_source() -> ( + None + ): + from sqlglot import exp + + table = exp.Table( + this=exp.Identifier(this="https://example.com/a.csv", quoted=True) + ) + + source = remote_file_source_from_table( + table, + query='SELECT * FROM "https://example.com/a.csv"', + ) + + assert source is None + + @staticmethod + def test_string_literal_without_token_metadata_does_not_make_identifier_remote() -> ( + None + ): + from sqlglot import exp + + table = exp.Table( + this=exp.Identifier(this="https://example.com/a.csv", quoted=True) + ) + + source = remote_file_source_from_table( + table, + query=""" + SELECT 1, 'https://example.com/a.csv' AS label + FROM "https://example.com/a.csv" + """, + ) + + assert source is None + @staticmethod @pytest.mark.parametrize( "function_name", From e50c8b502c70dd1e8c66fda8d3b9bfaf7e5e8894 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Tue, 12 May 2026 17:08:31 +0200 Subject: [PATCH 19/22] docs: comment adjustments --- frontend/src/core/islands/worker/worker.tsx | 2 +- frontend/src/core/wasm/worker/bootstrap.ts | 2 +- frontend/src/core/wasm/worker/worker.ts | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/src/core/islands/worker/worker.tsx b/frontend/src/core/islands/worker/worker.tsx index 5abf93786e9..090236f9b86 100644 --- a/frontend/src/core/islands/worker/worker.tsx +++ b/frontend/src/core/islands/worker/worker.tsx @@ -87,7 +87,7 
+ // Add pandas, duckdb, and sqlglot to the code for mo.sql and for remote duckdb sources
+ // Add pandas, duckdb, and sqlglot to the code for mo.sql and for remote duckdb sources
- if getattr(owner, attr, None) is wrapper: - setattr(owner, attr, original) - - self._unpatches.append(_unpatch) + return original(*args, **kwargs) + except catch as original_exc: + original_tb = original_exc.__traceback__ + try: + return fallback(original, *args, **kwargs) + except ModuleNotFoundError: + # Let missing-dependency errors bubble up so marimo can + # prompt the user to install the package. + raise + except Exception as fallback_exc: + raise original_exc.with_traceback( + original_tb + ) from fallback_exc + + return wrapper + + self.replace(owner, attr, wrapper_factory) def replace( self, diff --git a/marimo/_sql/utils.py b/marimo/_sql/utils.py index f8ebfd02947..2249fd8d0b9 100644 --- a/marimo/_sql/utils.py +++ b/marimo/_sql/utils.py @@ -31,22 +31,22 @@ _NO_WASM_DUCKDB_RESULT = object() -def _try_wasm_duckdb_sql( +def _try_wasm_duckdb( + method_name: str, query: str, connection: Any, glbls: dict[str, Any], - *, - reserved_names: tuple[str, ...] = (), + *trailing_args: Any, ) -> object: import duckdb if connection is duckdb: - original = duckdb.sql - args: tuple[Any, ...] = (query,) + original = getattr(duckdb, method_name) + args: tuple[Any, ...] = (query, *trailing_args) query_arg_index = 0 else: - original = type(connection).sql - args = (connection, query) + original = getattr(type(connection), method_name) + args = (connection, query, *trailing_args) query_arg_index = 1 result = try_run_duckdb_sql_with_wasm_patch( @@ -57,7 +57,6 @@ def _try_wasm_duckdb_sql( query_kwarg_names=("query",), eval_globals=glbls, eval_locals=glbls, - reserved_names=reserved_names, ) return _NO_WASM_DUCKDB_RESULT if result is None else result.value @@ -81,7 +80,7 @@ def wrapped_sql( try: ctx = get_context() except ContextNotInitializedError: - result = _try_wasm_duckdb_sql(query, connection, globals()) + result = _try_wasm_duckdb("sql", query, connection, globals()) if result is _NO_WASM_DUCKDB_RESULT: # No WASM rewrite was needed; use DuckDB's normal SQL path. 
relation = connection.sql(query=query) @@ -94,11 +93,11 @@ def wrapped_sql( else nullcontext ) with install_connection(connection): - result = _try_wasm_duckdb_sql( + result = _try_wasm_duckdb( + "sql", query, connection, ctx.globals, - reserved_names=tuple(ctx.globals), ) if result is _NO_WASM_DUCKDB_RESULT: # Run in kernel globals so DuckDB replacement scans see user data. @@ -113,38 +112,6 @@ def wrapped_sql( return relation -def _try_wasm_duckdb_execute( - query: str, - params: list[Any], - connection: Any, - glbls: dict[str, Any], - *, - reserved_names: tuple[str, ...] = (), -) -> object: - import duckdb - - if connection is duckdb: - original = duckdb.execute - args: tuple[Any, ...] = (query, params) - query_arg_index = 0 - else: - original = type(connection).execute - args = (connection, query, params) - query_arg_index = 1 - - result = try_run_duckdb_sql_with_wasm_patch( - original, - args, - {}, - query_arg_index=query_arg_index, - query_kwarg_names=("query",), - eval_globals=glbls, - eval_locals=glbls, - reserved_names=reserved_names, - ) - return _NO_WASM_DUCKDB_RESULT if result is None else result.value - - def execute_duckdb_sql( query: str, params: list[Any], @@ -164,7 +131,9 @@ def execute_duckdb_sql( try: ctx = get_context() except ContextNotInitializedError: - result = _try_wasm_duckdb_execute(query, params, connection, globals()) + result = _try_wasm_duckdb( + "execute", query, connection, globals(), params + ) if result is _NO_WASM_DUCKDB_RESULT: # No WASM rewrite was needed; preserve DuckDB's parameterized path. return connection.execute(query, params) @@ -176,12 +145,12 @@ def execute_duckdb_sql( else nullcontext ) with install_connection(connection): - result = _try_wasm_duckdb_execute( + result = _try_wasm_duckdb( + "execute", query, - params, connection, ctx.globals, - reserved_names=tuple(ctx.globals), + params, ) if result is _NO_WASM_DUCKDB_RESULT: # Run in kernel globals so parameterized SQL can scan user data. 
From 43a4ddd2a6254f4d28c41a084d27e34d6081217e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Tue, 12 May 2026 17:46:03 +0200 Subject: [PATCH 21/22] refactor: simplify DuckDB WASM source resolution --- marimo/_runtime/_wasm/_duckdb/io.py | 46 ++++-------------------- marimo/_runtime/_wasm/_duckdb/sources.py | 41 +++++++++------------ 2 files changed, 23 insertions(+), 64 deletions(-) diff --git a/marimo/_runtime/_wasm/_duckdb/io.py b/marimo/_runtime/_wasm/_duckdb/io.py index b0cab2190ad..508916ee3ae 100644 --- a/marimo/_runtime/_wasm/_duckdb/io.py +++ b/marimo/_runtime/_wasm/_duckdb/io.py @@ -45,11 +45,13 @@ class _FetchedBytes: @dataclass(frozen=True) class RemoteFile: url: str - fetcher_name: str def fetch(self) -> _FetchedBytes: - """Fetch through a named strategy so sources stay hashable.""" - return _fetcher_by_name(self.fetcher_name).fetch(self.url) + """Fetch through the browser-backed Pyodide HTTP shim.""" + return _FetchedBytes( + url=self.url, + data=_fetch.fetch_url_bytes(self.url), + ) @dataclass(frozen=True) @@ -59,14 +61,6 @@ class _ReadRequest: options: Mapping[str, Any] -class _ByteFetcher(Protocol): - name: str - - def can_fetch(self, url: str) -> bool: ... - - def fetch(self, url: str) -> _FetchedBytes: ... - - class _DataFrameReader(Protocol): name: str direct_extensions: tuple[str, ...] @@ -81,16 +75,6 @@ def read_options( def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: ... -class _HttpFetcher: - name: str = "http" - - def can_fetch(self, url: str) -> bool: - return urlparse(url).scheme in {"http", "https"} - - def fetch(self, url: str) -> _FetchedBytes: - return _FetchedBytes(url=url, data=_fetch.fetch_url_bytes(url)) - - class _CsvReader: name = "csv" direct_extensions: tuple[str, ...] = ( @@ -258,7 +242,6 @@ def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: return read_blob_dataframe(request.file.data, request.file.url) -_FETCHERS: tuple[_ByteFetcher, ...] 
= (_HttpFetcher(),) _READERS: tuple[_DataFrameReader, ...] = ( _CsvReader(), _ParquetReader(), @@ -322,10 +305,9 @@ def _read_file_dataframe(self, file: RemoteFile) -> pd.DataFrame: def remote_file_from_url(url: str) -> RemoteFile | None: """Return a fetchable remote file only for URL schemes marimo supports.""" - fetcher = _fetcher_for_url(url) - if fetcher is None: + if urlparse(url).scheme not in {"http", "https"}: return None - return RemoteFile(url=url, fetcher_name=fetcher.name) + return RemoteFile(url=url) def remote_file_source_from_reader_args( @@ -377,20 +359,6 @@ def _remote_files_from_source_arg( return None -def _fetcher_for_url(url: str) -> _ByteFetcher | None: - return next( - (fetcher for fetcher in _FETCHERS if fetcher.can_fetch(url)), - None, - ) - - -def _fetcher_by_name(name: str) -> _ByteFetcher: - for fetcher in _FETCHERS: - if fetcher.name == name: - return fetcher - raise KeyError(f"Unknown DuckDB WASM fetcher: {name}") - - def _reader_by_name( name: _ReaderName, ) -> _DataFrameReader: diff --git a/marimo/_runtime/_wasm/_duckdb/sources.py b/marimo/_runtime/_wasm/_duckdb/sources.py index ecbc36b59f1..8086cfa75cb 100644 --- a/marimo/_runtime/_wasm/_duckdb/sources.py +++ b/marimo/_runtime/_wasm/_duckdb/sources.py @@ -47,11 +47,11 @@ def remote_file_source_from_table( if table_expr is None: return None - function_name = _table_function_name(table_expr) - if function_name is None: + table_function = _table_function_call(table_expr) + if table_function is None: return None - args = _table_function_args(table_expr) + function_name, args = table_function source = _read_function_source(args) if source is None: return None @@ -99,35 +99,26 @@ def _query_has_single_quoted_table_reference( ) -def _table_function_name(table_expr: exp.Expression) -> str | None: - """Normalize sqlglot's version-dependent DuckDB reader node names.""" +def _table_function_call( + table_expr: exp.Expression, +) -> tuple[str, list[exp.Expression]] | None: + """Return a 
normalized DuckDB reader name and its arguments.""" import sqlglot.expressions as exp # sqlglot versions model first-party DuckDB readers either as explicit # Read* nodes or as generic anonymous table functions. - read_csv = getattr(exp, "ReadCSV", None) - if read_csv is not None and isinstance(table_expr, read_csv): - return "read_csv" - - read_parquet = getattr(exp, "ReadParquet", None) - if read_parquet is not None and isinstance(table_expr, read_parquet): - return "read_parquet" - - if isinstance(table_expr, exp.Anonymous): - return str(table_expr.this).lower() - return None - - -def _table_function_args(table_expr: exp.Expression) -> list[exp.Expression]: - """Return table-function args despite sqlglot reader-node differences.""" - import sqlglot.expressions as exp - - for node_name in ("ReadCSV", "ReadParquet"): + for node_name, function_name in ( + ("ReadCSV", "read_csv"), + ("ReadParquet", "read_parquet"), + ): read_node = getattr(exp, node_name, None) if read_node is not None and isinstance(table_expr, read_node): first = [table_expr.this] if table_expr.this is not None else [] - return [*first, *table_expr.expressions] - return list(table_expr.expressions) + return function_name, [*first, *table_expr.expressions] + + if isinstance(table_expr, exp.Anonymous): + return str(table_expr.this).lower(), list(table_expr.expressions) + return None def _read_function_source( From 7110d6acec00154a8edea42d42b2ca93e91a5319 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Ferenc=20Gyarmati?= Date: Tue, 12 May 2026 17:46:09 +0200 Subject: [PATCH 22/22] test: compact DuckDB WASM regression coverage --- tests/_runtime/test_duckdb_wasm.py | 112 +++++++++++++---------------- 1 file changed, 51 insertions(+), 61 deletions(-) diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py index 213c84f9077..cd87c5d4262 100644 --- a/tests/_runtime/test_duckdb_wasm.py +++ b/tests/_runtime/test_duckdb_wasm.py @@ -516,17 +516,25 @@ def 
_query_parity_cases() -> list[QueryParityCase]: ] -class TestDuckDBWasmDirectReadPatch: - @staticmethod - def test_noop_outside_pyodide() -> None: - import duckdb +def test_patch_duckdb_for_wasm_noop_outside_pyodide() -> None: + import duckdb + + original_read_csv = duckdb.read_csv + original_sql = duckdb.sql + original_connection_sql = duckdb.DuckDBPyConnection.sql + + unpatch = patch_duckdb_for_wasm() + + assert duckdb.read_csv is original_read_csv + assert duckdb.sql is original_sql + assert duckdb.DuckDBPyConnection.sql is original_connection_sql + unpatch() + assert duckdb.read_csv is original_read_csv + assert duckdb.sql is original_sql + assert duckdb.DuckDBPyConnection.sql is original_connection_sql - original = duckdb.read_csv - unpatch = patch_duckdb_for_wasm() - assert duckdb.read_csv is original - unpatch() - assert duckdb.read_csv is original +class TestDuckDBWasmDirectReadPatch: @staticmethod def test_patch_installation_does_not_require_sqlglot() -> None: import duckdb @@ -626,7 +634,35 @@ def test_read_parquet_node_preserves_this_argument() -> None: ] @staticmethod - def test_direct_literal_without_token_metadata_is_remote_source() -> None: + @pytest.mark.parametrize( + ("query", "expected_source"), + [ + ( + "SELECT * FROM 'https://example.com/a.csv'", + True, + ), + ( + 'SELECT * FROM "https://example.com/a.csv"', + False, + ), + ( + """ + SELECT 1, 'https://example.com/a.csv' AS label + FROM "https://example.com/a.csv" + """, + False, + ), + ], + ids=[ + "direct-literal", + "double-quoted-identifier", + "string-literal-and-double-quoted-identifier", + ], + ) + def test_token_metadata_fallback_detects_single_quoted_remote_sources( + query: str, + expected_source: bool, + ) -> None: from sqlglot import exp table = exp.Table( @@ -635,52 +671,19 @@ def test_direct_literal_without_token_metadata_is_remote_source() -> None: source = remote_file_source_from_table( table, - query="SELECT * FROM 'https://example.com/a.csv'", + query=query, ) + if not 
expected_source: + assert source is None + return + assert source is not None assert source.reader_name == "csv" assert [file.url for file in source.files] == [ "https://example.com/a.csv" ] - @staticmethod - def test_double_quoted_identifier_without_token_metadata_is_not_remote_source() -> ( - None - ): - from sqlglot import exp - - table = exp.Table( - this=exp.Identifier(this="https://example.com/a.csv", quoted=True) - ) - - source = remote_file_source_from_table( - table, - query='SELECT * FROM "https://example.com/a.csv"', - ) - - assert source is None - - @staticmethod - def test_string_literal_without_token_metadata_does_not_make_identifier_remote() -> ( - None - ): - from sqlglot import exp - - table = exp.Table( - this=exp.Identifier(this="https://example.com/a.csv", quoted=True) - ) - - source = remote_file_source_from_table( - table, - query=""" - SELECT 1, 'https://example.com/a.csv' AS label - FROM "https://example.com/a.csv" - """, - ) - - assert source is None - @staticmethod @pytest.mark.parametrize( "function_name", @@ -1007,19 +1010,6 @@ def test_execute_duckdb_sql_rewrites_remote_literal_with_explicit_connection() - class TestDuckDBWasmSqlApiPatch: - @staticmethod - def test_noop_outside_pyodide() -> None: - import duckdb - - original_sql = duckdb.sql - original_connection_sql = duckdb.DuckDBPyConnection.sql - unpatch = patch_duckdb_for_wasm() - assert duckdb.sql is original_sql - assert duckdb.DuckDBPyConnection.sql is original_connection_sql - unpatch() - assert duckdb.sql is original_sql - assert duckdb.DuckDBPyConnection.sql is original_connection_sql - @staticmethod def test_module_sql_rewrites_remote_literal_and_preserves_params() -> None: import duckdb