diff --git a/frontend/src/core/islands/worker/worker.tsx b/frontend/src/core/islands/worker/worker.tsx index 6d4f4f04f29..090236f9b86 100644 --- a/frontend/src/core/islands/worker/worker.tsx +++ b/frontend/src/core/islands/worker/worker.tsx @@ -9,6 +9,7 @@ import { } from "rpc-anywhere"; import type { NotificationPayload } from "@/core/kernel/messages"; import type { ParentSchema } from "@/core/wasm/rpc"; +import { shouldLoadDuckDBPackages } from "@/core/wasm/utils"; import { TRANSPORT_ID } from "@/core/wasm/worker/constants"; import { getPyodideVersion } from "@/core/wasm/worker/getPyodideVersion"; import { MessageBuffer } from "@/core/wasm/worker/message-buffer"; @@ -85,8 +86,8 @@ const requestHandler = createRPCRequestHandler({ loadPackages: async (code: string) => { await pyodideReadyPromise; // Make sure loading is done - if (code.includes("mo.sql")) { - // Add pandas and duckdb to the code + if (shouldLoadDuckDBPackages(code)) { + // Add pandas, duckdb, and sqlglot for mo.sql and for remote duckdb sources code = `import pandas\n${code}`; code = `import duckdb\n${code}`; code = `import sqlglot\n${code}`; diff --git a/frontend/src/core/wasm/__tests__/utils.test.ts b/frontend/src/core/wasm/__tests__/utils.test.ts new file mode 100644 index 00000000000..604463b24ba --- /dev/null +++ b/frontend/src/core/wasm/__tests__/utils.test.ts @@ -0,0 +1,34 @@ +/* Copyright 2026 Marimo. All rights reserved. */ + +import { describe, expect, it } from "vitest"; +import { shouldLoadDuckDBPackages } from "../utils"; + +describe("shouldLoadDuckDBPackages", () => { + it("loads for mo.sql", () => { + expect(shouldLoadDuckDBPackages('df = mo.sql("SELECT 1")')).toBe(true); + }); + + it("loads for duckdb imports and usage", () => { + expect(shouldLoadDuckDBPackages("import duckdb")).toBe(true); + expect(shouldLoadDuckDBPackages("from duckdb import sql")).toBe(true); + expect(shouldLoadDuckDBPackages("import pandas, duckdb")).toBe(true); + expect(shouldLoadDuckDBPackages("rows = duckdb.sql('SELECT 1')")).toBe( + true, + ); + }); + + it("loads when package discovery found duckdb", () => { + expect( + shouldLoadDuckDBPackages("print('hello')", new Set(["duckdb"])), + ).toBe(true); + }); + + it("does not load for incidental duckdb text", () => { + expect(shouldLoadDuckDBPackages("name = 'duckdb'")).toBe(false); + expect(shouldLoadDuckDBPackages("# import duckdb")).toBe(false); + }); + + it("does not load without mo.sql, duckdb usage, or discovery", () => { + expect(shouldLoadDuckDBPackages("print('hello')")).toBe(false); + }); +}); diff --git a/frontend/src/core/wasm/utils.ts b/frontend/src/core/wasm/utils.ts index 2801b9d3e3f..b8f371477c5 100644 --- a/frontend/src/core/wasm/utils.ts +++ b/frontend/src/core/wasm/utils.ts @@ -10,3 +10,17 @@ export function isWasm(): boolean { document.querySelector("marimo-wasm") !== null ); } + +const DUCKDB_USAGE_PATTERN = + /(^|\n)\s*(?:import\s+[^\n#]*\bduckdb\b|from\s+duckdb\b|[^\n#]*\bduckdb\s*\.)/; + +export function shouldLoadDuckDBPackages( + code: string, + foundPackages?: ReadonlySet<string>, +): boolean { + return ( + code.includes("mo.sql") || + DUCKDB_USAGE_PATTERN.test(code) || + foundPackages?.has("duckdb") === true + ); +} diff --git a/frontend/src/core/wasm/worker/bootstrap.ts b/frontend/src/core/wasm/worker/bootstrap.ts index 25dd7dd3d0b..a40e4fbb930 100644 --- a/frontend/src/core/wasm/worker/bootstrap.ts +++ b/frontend/src/core/wasm/worker/bootstrap.ts @@ -9,6 +9,7 @@ import { WasmFileSystem } from "./fs"; import { getMarimoWheel } from "./getMarimoWheel"; import { t
} from "./tracer"; import type { SerializedBridge, WasmController } from "./types"; +import { shouldLoadDuckDBPackages } from "../utils"; const MAKE_SNAPSHOT = false; @@ -163,8 +164,8 @@ export class DefaultWasmController implements WasmController { private async loadNotebookDeps(code: string, foundPackages: Set) { const pyodide = this.requirePyodide; - if (code.includes("mo.sql")) { - // We need pandas and duckdb for mo.sql + if (shouldLoadDuckDBPackages(code, foundPackages)) { + // We need pandas and duckdb for mo.sql and for remote duckdb sources code = `import pandas\n${code}`; code = `import duckdb\n${code}`; code = `import sqlglot\n${code}`; diff --git a/frontend/src/core/wasm/worker/worker.ts b/frontend/src/core/wasm/worker/worker.ts index e7890c9d0bd..c8a880f6a48 100644 --- a/frontend/src/core/wasm/worker/worker.ts +++ b/frontend/src/core/wasm/worker/worker.ts @@ -34,6 +34,7 @@ import type { SerializedBridge, WasmController, } from "./types"; +import { shouldLoadDuckDBPackages } from "../utils"; /** * Web worker responsible for running the notebook. @@ -141,8 +142,8 @@ const requestHandler = createRPCRequestHandler({ const span = t.startSpan("loadPackages"); await pyodideReadyPromise; // Make sure loading is done - if (code.includes("mo.sql")) { - // Add pandas and duckdb to the code + if (shouldLoadDuckDBPackages(code)) { + // Add pandas and duckdb to the code for mo.sql and for remote duckdb sources code = `import pandas\n${code}`; code = `import duckdb\n${code}`; code = `import sqlglot\n${code}`; diff --git a/marimo/_output/formatters/df_formatters.py b/marimo/_output/formatters/df_formatters.py index 71a17705cce..f7ebe708fe2 100644 --- a/marimo/_output/formatters/df_formatters.py +++ b/marimo/_output/formatters/df_formatters.py @@ -19,6 +19,7 @@ from marimo._plugins.stateless.plain_text import plain_text from marimo._plugins.ui._impl import tabs from marimo._plugins.ui._impl.table import get_default_table_page_size, table +from marimo._runtime._wasm._duckdb import patch_duckdb_for_wasm from marimo._runtime._wasm._polars import patch_polars_for_wasm LOGGER = _loggers.marimo_logger() @@ -160,6 +161,17 @@ def _show_marimo_dataframe( return table(df, selection=None, pagination=None)._mime_() +class DuckDBFormatter(FormatterFactory): + """Use DuckDB's lazy import hook to install WASM runtime patches.""" + + @staticmethod + def package_name() -> str: + return "duckdb" + + def register(self) -> Unregister: + return patch_duckdb_for_wasm() + + class DataFusionFormatter(FormatterFactory): @staticmethod def package_name() -> str: diff --git a/marimo/_output/formatters/formatters.py b/marimo/_output/formatters/formatters.py index 5a1ae706e56..ee375220989 100644 --- a/marimo/_output/formatters/formatters.py +++ b/marimo/_output/formatters/formatters.py @@ -18,6 +18,7 @@ from marimo._output.formatters.cell import CellFormatter from marimo._output.formatters.df_formatters import ( DataFusionFormatter, + DuckDBFormatter, IbisFormatter, PolarsFormatter, PyArrowFormatter, @@ -54,6 +55,7 @@ AltairFormatter.package_name(): AltairFormatter(), MatplotlibFormatter.package_name(): MatplotlibFormatter(), DataFusionFormatter.package_name(): DataFusionFormatter(), + DuckDBFormatter.package_name(): DuckDBFormatter(), IbisFormatter.package_name(): IbisFormatter(), PandasFormatter.package_name(): PandasFormatter(), PolarsFormatter.package_name(): PolarsFormatter(), diff --git a/marimo/_runtime/_wasm/_duckdb/__init__.py b/marimo/_runtime/_wasm/_duckdb/__init__.py new file mode 100644 index 
00000000000..1d1540692d5 --- /dev/null +++ b/marimo/_runtime/_wasm/_duckdb/__init__.py @@ -0,0 +1,826 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Install WASM-only DuckDB fallbacks for remote file scans. + +DuckDB-WASM cannot use ``httpfs``. marimo fetches supported URLs itself, +materializes supported files as pandas DataFrames, then hands those frames +back to DuckDB through replacement scans. + +We have two concrete use cases to patch: + +* **Direct read methods** — ``duckdb.read_csv/read_parquet/read_json`` are + wrapped with :class:`WasmPatchSet`. Supported remote URLs are fetched by + marimo and returned as DuckDB relations. +* **SQL remote scans** — raw DuckDB APIs and marimo's ``mo.sql`` path call the + same sqlglot rewrite helper. It replaces supported URL scans with generated + table names and evaluates the original DuckDB call with fetched DataFrames + in scope so DuckDB replacement scans can resolve them. +""" + +from __future__ import annotations + +import functools +import inspect +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, NamedTuple + +from marimo import _loggers +from marimo._dependencies.dependencies import DependencyManager +from marimo._runtime._wasm._duckdb.io import ( + RemoteFileSource, + remote_file_source_from_reader_args, +) +from marimo._runtime._wasm._duckdb.sources import ( + remote_file_source_from_table, +) +from marimo._runtime._wasm._patches import ( + Unpatch, + WasmPatchSet, + WrapperFactory, +) +from marimo._utils.platform import is_pyodide + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + + import pandas as pd + from sqlglot import exp + +LOGGER = _loggers.marimo_logger() + + +class _EvalBindingNames(NamedTuple): + original: str + args: str + kwargs: str + + +class _SqlApiSpec(NamedTuple): + query_positional_index: int + query_keyword_names: tuple[str, ...] + + +class _DirectReaderSpec(NamedTuple): + source_keyword_names: tuple[str, ...] + + +class _DirectReaderCallSpec(NamedTuple): + source_positional_index: int + connection_positional_index: int | None + + +# DuckDB SQL APIs can receive non-string query objects. This sentinel lets us +# tell an omitted query argument apart from a present value such as None. +_MISSING = object() +_SQL_CALL_EXPRESSION = "{original}(*{args}, **{kwargs})" + +# The SQL wrappers invoke the original DuckDB callable through eval so DuckDB +# can still see caller-local replacement scans. These are the local binding +# names used for that eval call after collision checks. +_EVAL_BINDING_NAME_BASES = _EvalBindingNames( + original="__marimo_wasm_duckdb_original", + args="__marimo_wasm_duckdb_args", + kwargs="__marimo_wasm_duckdb_kwargs", +) + +# Module-level DuckDB SQL functions put the SQL string in different argument +# slots. The specs identify where wrappers should look for the query text. +_MODULE_SQL_FUNCTIONS: dict[str, _SqlApiSpec] = { + "sql": _SqlApiSpec( + query_positional_index=0, + query_keyword_names=("query",), + ), + "query": _SqlApiSpec( + query_positional_index=0, + query_keyword_names=("query",), + ), + "execute": _SqlApiSpec( + query_positional_index=0, + query_keyword_names=("query",), + ), + "query_df": _SqlApiSpec( + query_positional_index=2, + query_keyword_names=("sql_query", "query"), + ), +} + +# Bound connection methods include the connection as args[0], so their query +# argument starts one slot later than the module-level functions. 
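A sketch of the slot difference (standard duckdb calls; values illustrative, not part of the patch):

    import duckdb

    duckdb.sql("SELECT 1")  # module function: args == ("SELECT 1",), query at index 0
    conn = duckdb.connect()
    conn.sql("SELECT 1")    # patched method: args == (conn, "SELECT 1"), query at index 1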
+_CONNECTION_SQL_METHODS: dict[str, _SqlApiSpec] = { + "sql": _SqlApiSpec( + query_positional_index=1, + query_keyword_names=("query",), + ), + "query": _SqlApiSpec( + query_positional_index=1, + query_keyword_names=("query",), + ), + "execute": _SqlApiSpec( + query_positional_index=1, + query_keyword_names=("query",), + ), +} + +# Direct reader APIs use reader-specific keyword names for the file source. +_DIRECT_READER_SPECS: dict[str, _DirectReaderSpec] = { + "read_csv": _DirectReaderSpec( + source_keyword_names=("path_or_buffer", "source", "file", "path"), + ), + "read_parquet": _DirectReaderSpec( + source_keyword_names=( + "file_glob", + "file_globs", + "source", + "file", + "path", + ), + ), + "read_json": _DirectReaderSpec( + source_keyword_names=("path_or_buffer", "source", "file", "path"), + ), +} + +# Module-level direct readers receive the file source as their first positional +# argument. They may also receive an explicit connection= keyword. +_MODULE_DIRECT_READER_CALL = _DirectReaderCallSpec( + source_positional_index=0, + connection_positional_index=None, +) + +# Connection direct readers receive the bound connection as args[0], followed +# by the file source as args[1]. +_CONNECTION_DIRECT_READER_CALL = _DirectReaderCallSpec( + source_positional_index=1, + connection_positional_index=0, +) + + +@dataclass(frozen=True) +class WasmDuckDBQueryPatch: + """A rewritten DuckDB query and the DataFrames it references.""" + + query: str + tables: Mapping[str, pd.DataFrame] + + +@dataclass(frozen=True) +class WasmDuckDBSqlResult: + """Result of a DuckDB SQL API call that was rewritten for WASM.""" + + value: Any + + +class _RemoteTableNames: + """Track URL sources and generated table names for one SQL rewrite.""" + + def __init__(self, reserved_names: Sequence[str]) -> None: + self._reserved_names = { + _duckdb_identifier_key(name) for name in reserved_names + } + self._names_by_source: dict[RemoteFileSource, str] = {} + + def __bool__(self) -> bool: + return bool(self._names_by_source) + + def name_for(self, source: RemoteFileSource) -> str: + name = self._names_by_source.get(source) + if name is not None: + return name + + idx = len(self._names_by_source) + while True: + candidate = f"__marimo_wasm_duckdb_remote_{idx}" + candidate_key = _duckdb_identifier_key(candidate) + if candidate_key not in self._reserved_names: + self._names_by_source[source] = candidate + self._reserved_names.add(candidate_key) + return candidate + idx += 1 + + def read_dataframes(self) -> dict[str, pd.DataFrame]: + return { + table_name: source.read_dataframe() + for source, table_name in self._names_by_source.items() + } + + +def _duckdb_identifier_key(name: str) -> str: + # DuckDB preserves identifier case but resolves names case-insensitively. + return name.casefold() + + +def patch_duckdb_query_for_wasm( + query: str, + *, + statements: Sequence[exp.Expression] | None = None, + reserved_names: Sequence[str] = (), +) -> WasmDuckDBQueryPatch | None: + """Replace supported remote file reads with generated table names. + + For example, ``SELECT * FROM read_csv('https://example.com/cars.csv')`` + becomes ``SELECT * FROM __marimo_wasm_duckdb_remote_0`` when suffix ``0`` + is free. The returned ``tables`` mapping binds that name to the fetched + DataFrame. If the query or ``reserved_names`` already use that identifier, + the rewriter uses the next free suffix. + + In Pyodide this raises if sqlglot is unavailable and the query may contain + a remote source. 
Returns ``None`` when: + + - marimo is not running in Pyodide; + - the query has no remote URL; + - the query has no supported remote file source; + - the query cannot be parsed. + """ + if not is_pyodide(): + return None + + if statements is None: + if not _query_may_contain_remote_file_source(query): + return None + _require_sqlglot() + statements = _parse_duckdb_query(query) + if statements is None: + return None + if _contains_remote_view_definition(statements, query=query): + return None + + table_names = _RemoteTableNames( + (*reserved_names, *_reserved_sql_names(statements)) + ) + patched_statements = _replace_remote_sources( + statements, table_names, query=query + ) + if not table_names: + return None + + _require_pandas() + + return WasmDuckDBQueryPatch( + query=_format_duckdb_query(patched_statements, original_query=query), + tables=table_names.read_dataframes(), + ) + + +def patch_duckdb_for_wasm() -> Unpatch: + """Install WASM fallbacks for DuckDB remote file and SQL APIs.""" + if not is_pyodide(): + return lambda: None + + try: + import duckdb + except ImportError: + return lambda: None + + patches = WasmPatchSet() + for function_name in _DIRECT_READER_SPECS: + patches.replace( + duckdb, + function_name, + _make_direct_reader_wrapper( + function_name, + call_spec=_MODULE_DIRECT_READER_CALL, + ), + ) + patches.replace( + duckdb.DuckDBPyConnection, + function_name, + _make_direct_reader_wrapper( + function_name, + call_spec=_CONNECTION_DIRECT_READER_CALL, + ), + ) + for function_name, spec in _MODULE_SQL_FUNCTIONS.items(): + patches.replace( + duckdb, + function_name, + _make_sql_api_wrapper( + query_arg_index=spec.query_positional_index, + query_kwarg_names=spec.query_keyword_names, + ), + ) + for method_name, spec in _CONNECTION_SQL_METHODS.items(): + patches.replace( + duckdb.DuckDBPyConnection, + method_name, + _make_sql_api_wrapper( + query_arg_index=spec.query_positional_index, + query_kwarg_names=spec.query_keyword_names, + ), + ) + return patches.unpatch_all() + + +def run_duckdb_sql_with_wasm_patch( + original: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], + eval_globals: dict[str, Any], + eval_locals: Mapping[str, Any], + reserved_names: Sequence[str] = (), +) -> Any: + """Run a DuckDB SQL API call after rewriting supported remote scans.""" + wasm_result = try_run_duckdb_sql_with_wasm_patch( + original, + args, + kwargs, + query_arg_index=query_arg_index, + query_kwarg_names=query_kwarg_names, + eval_globals=eval_globals, + eval_locals=eval_locals, + reserved_names=reserved_names, + ) + if wasm_result is not None: + return wasm_result.value + + kwargs_dict = dict(kwargs) + binding_names = _eval_binding_names( + _reserved_namespace_names( + eval_globals, + eval_locals, + ( + *reserved_names, + *_identifier_string_args(args), + *_identifier_string_args(tuple(kwargs_dict.values())), + ), + ) + ) + return _eval_duckdb_original_call( + original, + args, + kwargs_dict, + eval_globals=eval_globals, + eval_locals=eval_locals, + binding_names=binding_names, + ) + + +def try_run_duckdb_sql_with_wasm_patch( + original: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], + eval_globals: dict[str, Any], + eval_locals: Mapping[str, Any], + reserved_names: Sequence[str] = (), +) -> WasmDuckDBSqlResult | None: + """Run only if a DuckDB SQL API call needs a WASM rewrite.""" + if not is_pyodide(): + 
return None + + kwargs_dict = dict(kwargs) + query = _query_argument( + args, + kwargs_dict, + query_arg_index=query_arg_index, + query_kwarg_names=query_kwarg_names, + ) + if not isinstance(query, str): + return None + if not _query_may_contain_remote_file_source(query): + return None + + _require_sqlglot() + statements = _parse_duckdb_query(query) + if statements is None or not _contains_supported_remote_source( + statements, query=query + ): + return None + if _contains_remote_view_definition(statements, query=query): + return None + + namespace_names = _reserved_namespace_names( + eval_globals, + eval_locals, + ( + *reserved_names, + *_duckdb_catalog_names(original, args, kwargs_dict), + *_identifier_string_args(args), + *_identifier_string_args(tuple(kwargs_dict.values())), + ), + ) + binding_names = _eval_binding_names(namespace_names) + + wasm_patch = patch_duckdb_query_for_wasm( + query, + statements=statements, + reserved_names=( + *namespace_names, + binding_names.original, + binding_names.args, + binding_names.kwargs, + ), + ) + if wasm_patch is None: + return None + + patched_args, patched_kwargs = _replace_query_argument( + args, + kwargs_dict, + patched_query=wasm_patch.query, + query_arg_index=query_arg_index, + query_kwarg_names=query_kwarg_names, + ) + return WasmDuckDBSqlResult( + _eval_duckdb_original_call( + original, + patched_args, + patched_kwargs, + eval_globals=eval_globals, + eval_locals=eval_locals, + binding_names=binding_names, + extra_locals=wasm_patch.tables, + ) + ) + + +def _make_sql_api_wrapper( + *, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Wrap DuckDB SQL APIs while preserving the caller's local namespace.""" + + def _make_wrapper(original: Callable[..., Any]) -> Callable[..., Any]: + @functools.wraps(original) + def _wrapper(*args: Any, **kwargs: Any) -> Any: + frame = inspect.currentframe() + if frame is None or frame.f_back is None: + return original(*args, **kwargs) + caller_frame = frame.f_back + eval_globals = caller_frame.f_globals + eval_locals = caller_frame.f_locals + try: + return run_duckdb_sql_with_wasm_patch( + original, + args, + kwargs, + query_arg_index=query_arg_index, + query_kwarg_names=query_kwarg_names, + eval_globals=eval_globals, + eval_locals=eval_locals, + ) + finally: + del caller_frame + del frame + + return _wrapper + + return _make_wrapper + + +def _query_argument( + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], +) -> Any: + """Find the query text without assuming the caller used positional args.""" + if len(args) > query_arg_index: + return args[query_arg_index] + for name in query_kwarg_names: + if name in kwargs: + return kwargs[name] + return _MISSING + + +def _replace_query_argument( + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + patched_query: str, + query_arg_index: int, + query_kwarg_names: tuple[str, ...], +) -> tuple[tuple[Any, ...], dict[str, Any]]: + """Replace the query at the same call site shape used by the caller.""" + kwargs_dict = dict(kwargs) + if len(args) > query_arg_index: + patched_args = list(args) + patched_args[query_arg_index] = patched_query + return tuple(patched_args), kwargs_dict + + for name in query_kwarg_names: + if name in kwargs_dict: + kwargs_dict[name] = patched_query + return args, kwargs_dict + return args, kwargs_dict + + +def _eval_duckdb_original_call( + original: Callable[..., Any], + args: tuple[Any, ...], + 
kwargs: Mapping[str, Any], + *, + eval_globals: dict[str, Any], + eval_locals: Mapping[str, Any], + binding_names: _EvalBindingNames, + extra_locals: Mapping[str, Any] | None = None, +) -> Any: + """Evaluate the original call where DuckDB can see replacement DataFrames.""" + locals_for_eval = dict(eval_locals) + if extra_locals is not None: + locals_for_eval.update(extra_locals) + locals_for_eval[binding_names.original] = inspect.unwrap(original) + locals_for_eval[binding_names.args] = args + locals_for_eval[binding_names.kwargs] = dict(kwargs) + + return eval( + _SQL_CALL_EXPRESSION.format( + original=binding_names.original, + args=binding_names.args, + kwargs=binding_names.kwargs, + ), + eval_globals, + locals_for_eval, + ) + + +def _reserved_namespace_names( + eval_globals: Mapping[str, Any], + eval_locals: Mapping[str, Any], + reserved_names: Sequence[str], +) -> tuple[str, ...]: + names = set(reserved_names) + names.update(eval_globals) + names.update(eval_locals) + return tuple(names) + + +def _identifier_string_args(args: tuple[Any, ...]) -> tuple[str, ...]: + return tuple( + arg for arg in args if isinstance(arg, str) and arg.isidentifier() + ) + + +def _eval_binding_names(reserved_names: Sequence[str]) -> _EvalBindingNames: + used = set(reserved_names) + original = _unused_name(_EVAL_BINDING_NAME_BASES.original, used) + used.add(original) + args = _unused_name(_EVAL_BINDING_NAME_BASES.args, used) + used.add(args) + kwargs = _unused_name(_EVAL_BINDING_NAME_BASES.kwargs, used) + return _EvalBindingNames(original=original, args=args, kwargs=kwargs) + + +def _unused_name(base: str, used: set[str]) -> str: + if base not in used: + return base + + idx = 0 + while True: + candidate = f"{base}_{idx}" + if candidate not in used: + return candidate + idx += 1 + + +def _duckdb_catalog_names( + original: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], +) -> tuple[str, ...]: + """Reserve existing DuckDB table names before generating replacements.""" + try: + relation = _show_duckdb_tables(original, args, kwargs) + rows = relation.fetchall() + except Exception: + return () + return tuple(str(row[0]) for row in rows if row and row[0] is not None) + + +def _show_duckdb_tables( + original: Callable[..., Any], + args: tuple[Any, ...], + kwargs: Mapping[str, Any], +) -> Any: + """Run ``SHOW TABLES`` through the same DuckDB entry point being patched.""" + import duckdb + + original_call = inspect.unwrap(original) + if args and isinstance(args[0], duckdb.DuckDBPyConnection): + return inspect.unwrap(type(args[0]).sql)(args[0], "SHOW TABLES") + connection = kwargs.get("connection") + if isinstance(connection, duckdb.DuckDBPyConnection): + return inspect.unwrap(type(connection).sql)(connection, "SHOW TABLES") + if original_call is inspect.unwrap(duckdb.query_df): + return inspect.unwrap(duckdb.sql)("SHOW TABLES") + return original_call("SHOW TABLES") + + +def _make_direct_reader_wrapper( + function_name: str, + *, + call_spec: _DirectReaderCallSpec, +) -> WrapperFactory: + def _wrap(original: Callable[..., Any]) -> Callable[..., Any]: + @functools.wraps(original) + def _wrapper(*args: Any, **kwargs: Any) -> Any: + source_info = _direct_reader_source( + function_name, + args, + kwargs, + call_spec=call_spec, + ) + if source_info is None: + return original(*args, **kwargs) + + DependencyManager.pandas.require( + f"to read DuckDB {function_name} sources in WASM" + ) + + source, connection = source_info + import duckdb + + df = source.read_dataframe() + if connection is None: + 
return duckdb.from_df(df) + return duckdb.from_df(df, connection=connection) + + return _wrapper + + return _wrap + + +def _direct_reader_source( + function_name: str, + args: tuple[Any, ...], + kwargs: Mapping[str, Any], + *, + call_spec: _DirectReaderCallSpec, +) -> tuple[RemoteFileSource, Any] | None: + """Return a remote source only for direct reader calls we can emulate.""" + options = dict(kwargs) + try: + source, rest_args = _pop_source_argument( + function_name, + args, + options, + call_spec=call_spec, + ) + except TypeError: + return None + if rest_args: + return None + + if call_spec.connection_positional_index is None: + connection = options.pop("connection", None) + else: + connection = args[call_spec.connection_positional_index] + source_info = remote_file_source_from_reader_args( + function_name, source, options + ) + if source_info is None: + return None + return source_info, connection + + +def _pop_source_argument( + function_name: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + *, + call_spec: _DirectReaderCallSpec, +) -> tuple[Any, tuple[Any, ...]]: + """Remove the source argument so remaining kwargs are pure reader options.""" + source_positional_index = call_spec.source_positional_index + if len(args) > source_positional_index: + return ( + args[source_positional_index], + args[source_positional_index + 1 :], + ) + + for key in _DIRECT_READER_SPECS[function_name].source_keyword_names: + if key in kwargs: + return kwargs.pop(key), args[source_positional_index + 1 :] + + raise TypeError(f"Missing source argument for duckdb.{function_name}") + + +def _require_pandas() -> None: + DependencyManager.pandas.require( + "to read remote DuckDB file sources in WASM" + ) + import pandas # noqa: F401 + + +def _require_sqlglot() -> None: + DependencyManager.sqlglot.require( + "to rewrite remote DuckDB SQL sources in WASM" + ) + import sqlglot # noqa: F401 + + +def _query_may_contain_remote_file_source(query: str) -> bool: + return "https://" in query or "http://" in query + + +def _parse_duckdb_query(query: str) -> list[exp.Expression] | None: + import sqlglot + from sqlglot import exp as sqlglot_exp + + try: + parsed = sqlglot.parse(query, read="duckdb") + except Exception as e: + LOGGER.debug("Failed to parse DuckDB query for WASM patch: %s", e) + return None + + return [ + statement + for statement in parsed + if isinstance(statement, sqlglot_exp.Expression) + ] + + +def _contains_supported_remote_source( + statements: Sequence[exp.Expression], + *, + query: str, +) -> bool: + """Check for rewrite work before paying for DuckDB catalog inspection.""" + from sqlglot import exp + + return any( + remote_file_source_from_table(table, query=query) is not None + for statement in statements + for table in statement.find_all(exp.Table) + ) + + +def _contains_remote_view_definition( + statements: Sequence[exp.Expression], + *, + query: str, +) -> bool: + """Views persist SQL text, so replacement-scan locals would go stale.""" + from sqlglot import exp + + return any( + isinstance(statement, exp.Create) + and str(statement.args.get("kind")).upper() == "VIEW" + and _contains_supported_remote_source((statement,), query=query) + for statement in statements + ) + + +def _replace_remote_sources( + statements: Sequence[exp.Expression], + table_names: _RemoteTableNames, + *, + query: str, +) -> list[exp.Expression]: + """Replace supported remote table nodes while preserving aliases.""" + from sqlglot import exp + + def replace_table(node: exp.Expression) -> exp.Expression: + if not 
isinstance(node, exp.Table): + return node + + source = remote_file_source_from_table(node, query=query) + if source is None: + return node + + replacement = exp.Table( + this=exp.to_identifier(table_names.name_for(source)) + ) + alias = node.args.get("alias") + if alias is not None: + replacement.set("alias", alias.copy()) + return replacement + + return [ + statement.transform(replace_table, copy=True) + for statement in statements + ] + + +def _reserved_sql_names( + statements: Sequence[exp.Expression], +) -> tuple[str, ...]: + """Collect SQL identifiers that generated table names must not shadow.""" + from sqlglot import exp + + names: set[str] = set() + for statement in statements: + for identifier in statement.find_all(exp.Identifier): + if isinstance(identifier.this, str): + names.add(identifier.this) + for table in statement.find_all(exp.Table): + if table.name: + names.add(table.name) + return tuple(names) + + +def _format_duckdb_query( + statements: Sequence[exp.Expression], *, original_query: str +) -> str: + """Serialize sqlglot statements without dropping a trailing semicolon.""" + patched_query = "; ".join( + statement.sql(dialect="duckdb") for statement in statements + ) + if original_query.rstrip().endswith(";"): + patched_query += ";" + return patched_query diff --git a/marimo/_runtime/_wasm/_duckdb/dataframe.py b/marimo/_runtime/_wasm/_duckdb/dataframe.py new file mode 100644 index 00000000000..27ebe070a40 --- /dev/null +++ b/marimo/_runtime/_wasm/_duckdb/dataframe.py @@ -0,0 +1,191 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Decode fetched DuckDB file bytes into pandas DataFrames. + +The WASM patch fetches remote bytes in Python, but DuckDB's Python readers +still expect local paths for CSV, JSON, and parquet parsing. This module +materializes fetched bytes through short-lived temp files where DuckDB parsing +is needed and synthesizes DataFrames for ``read_text`` and ``read_blob``. 
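A minimal usage sketch (bytes and URL illustrative):

    df = read_csv_dataframe(
        b"a,b\n1,2\n",
        {},
        url="https://example.com/tiny.csv",
    )

The bytes land in a ``.csv`` temp file, DuckDB parses it with
``from_csv_auto``, and the temp file is removed afterwards.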
+""" + +from __future__ import annotations + +import os +import tempfile +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + import pandas as pd + +_CSV_SUFFIXES = (".csv.gz", ".tsv.gz", ".csv", ".tsv") +_JSON_SUFFIXES = ( + ".geojson.gz", + ".ndjson.gz", + ".jsonl.gz", + ".json.gz", + ".geojson", + ".ndjson", + ".jsonl", + ".json", +) +_JSON_OBJECT_FUNCTIONS = frozenset( + { + "read_json_objects", + "read_json_objects_auto", + "read_ndjson_objects", + } +) + + +def read_csv_dataframe( + data: bytes, options: Mapping[str, Any], *, url: str +) -> pd.DataFrame: + """Read CSV/TSV bytes with a suffix that preserves compression hints.""" + import duckdb + + return _read_temp_dataframe( + data, + suffix=_temp_suffix( + url, + suffixes=_CSV_SUFFIXES, + default=".csv", + ), + reader=lambda path: duckdb.from_csv_auto(path, **options).df(), + ) + + +def read_parquet_dataframe(data: bytes) -> pd.DataFrame: + import duckdb + + return _read_temp_dataframe( + data, + suffix=".parquet", + reader=lambda path: duckdb.read_parquet(path).df(), + ) + + +def read_json_dataframe( + data: bytes, options: Mapping[str, Any], *, url: str +) -> pd.DataFrame: + import duckdb + + return _read_temp_dataframe( + data, + suffix=_temp_suffix( + url, + suffixes=_JSON_SUFFIXES, + default=".json", + ), + reader=lambda path: duckdb.read_json(path, **options).df(), + ) + + +def read_json_objects_dataframe( + data: bytes, options: Mapping[str, Any], *, url: str, function_name: str +) -> pd.DataFrame: + """Read JSON-object bytes through DuckDB's SQL-only table function.""" + if function_name not in _JSON_OBJECT_FUNCTIONS: + raise ValueError( + f"Unsupported DuckDB JSON object reader: {function_name}" + ) + + return _read_temp_dataframe( + data, + suffix=_temp_suffix( + url, + suffixes=_JSON_SUFFIXES, + default=".json", + ), + reader=lambda path: _read_json_objects_path( + path, options, function_name + ), + ) + + +def _read_json_objects_path( + path: str, options: Mapping[str, Any], function_name: str +) -> pd.DataFrame: + import duckdb + + # DuckDB exposes JSON-object readers as SQL table functions, not Python + # module methods, so invoke the table function with bound parameters. + option_items = tuple(options.items()) + query_args = ["?"] + query_args.extend(f"{key} := ?" 
for key, _ in option_items) + return duckdb.sql( + f"SELECT * FROM {function_name}({', '.join(query_args)})", + params=[path, *(value for _, value in option_items)], + ).df() + + +def read_text_dataframe(data: bytes, url: str) -> pd.DataFrame: + """Match DuckDB's ``read_text`` shape for an already-fetched object.""" + import pandas as pd + + return pd.DataFrame( + { + "filename": [url], + "content": [data.decode("utf-8")], + "size": [len(data)], + "last_modified": [pd.NaT], + } + ) + + +def read_blob_dataframe(data: bytes, url: str) -> pd.DataFrame: + """Match DuckDB's ``read_blob`` shape for an already-fetched object.""" + import pandas as pd + + return pd.DataFrame( + { + "filename": [url], + "content": [data], + "size": [len(data)], + "last_modified": [pd.NaT], + } + ) + + +def append_filename_column( + df: pd.DataFrame, url: str, column_name: str +) -> pd.DataFrame: + """Apply DuckDB's filename option after bytes have been decoded.""" + if column_name in df.columns: + raise ValueError( + f'Option filename adds column "{column_name}", but a column with this ' + "name is also in the file" + ) + + df = df.copy() + df[column_name] = url + return df + + +def _read_temp_dataframe( + data: bytes, + *, + suffix: str, + reader: Callable[[str], pd.DataFrame], +) -> pd.DataFrame: + """Materialize fetched bytes so DuckDB can parse them from a local path.""" + fd, path = tempfile.mkstemp(suffix=suffix) + try: + with os.fdopen(fd, "wb") as file: + file.write(data) + return reader(path) + finally: + try: + os.unlink(path) + except FileNotFoundError: + pass + + +def _temp_suffix(url: str, *, suffixes: tuple[str, ...], default: str) -> str: + """Preserve file extensions so DuckDB can infer format details.""" + path = urlparse(url).path.lower() + for suffix in suffixes: + if path.endswith(suffix): + return suffix + return default diff --git a/marimo/_runtime/_wasm/_duckdb/io.py b/marimo/_runtime/_wasm/_duckdb/io.py new file mode 100644 index 00000000000..508916ee3ae --- /dev/null +++ b/marimo/_runtime/_wasm/_duckdb/io.py @@ -0,0 +1,504 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Resolve DuckDB remote file reads into fetch and DataFrame reader steps. + +DuckDB's native network scanner is unavailable in Pyodide. This module is the +compatibility layer that recognizes supported URL shapes, validates the reader +options we can reproduce, fetches bytes through the WASM fetch shim, and +dispatches to a local DataFrame reader. Unsupported readers or option +combinations return ``None`` so callers can use the unpatched DuckDB path +and surface the underlying error.
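A minimal sketch of the mapping this module performs (URL illustrative):

    source = remote_file_source_from_reader_args(
        "read_csv",
        "https://example.com/cars.csv",
        {"header": True},
    )
    if source is not None:
        df = source.read_dataframe()  # fetch bytes, then parse locally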
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol +from urllib.parse import urlparse + +from marimo._runtime._wasm import _fetch +from marimo._runtime._wasm._duckdb.dataframe import ( + append_filename_column, + read_blob_dataframe, + read_csv_dataframe, + read_json_dataframe, + read_json_objects_dataframe, + read_parquet_dataframe, + read_text_dataframe, +) + +if TYPE_CHECKING: + from collections.abc import Mapping + + import pandas as pd + + +_ReaderName = str + + +@dataclass(frozen=True) +class _FetchedBytes: + url: str + data: bytes + + +@dataclass(frozen=True) +class RemoteFile: + url: str + + def fetch(self) -> _FetchedBytes: + """Fetch through the browser-backed Pyodide HTTP shim.""" + return _FetchedBytes( + url=self.url, + data=_fetch.fetch_url_bytes(self.url), + ) + + +@dataclass(frozen=True) +class _ReadRequest: + file: _FetchedBytes + function_name: str + options: Mapping[str, Any] + + +class _DataFrameReader(Protocol): + name: str + direct_extensions: tuple[str, ...] + function_names: tuple[str, ...] + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: ... + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: ... + + +class _CsvReader: + name = "csv" + direct_extensions: tuple[str, ...] = ( + ".csv", + ".csv.gz", + ".tsv", + ".tsv.gz", + ) + function_names: tuple[str, ...] = ("read_csv", "read_csv_auto") + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + del function_name + options: dict[str, Any] = {} + for key, value in raw_options.items(): + if key in {"delim", "delimiter", "sep", "separator"}: + options["delimiter"] = _normalize_delimiter(value) + elif key == "header": + options["header"] = value + elif key in {"auto_detect", "normalize_names"}: + options[key] = value + elif _apply_shared_source_option(options, key, value): + pass + else: + return None + return options + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_csv_dataframe( + request.file.data, + _csv_reader_options(request.options), + url=request.file.url, + ) + + +class _ParquetReader: + name = "parquet" + direct_extensions: tuple[str, ...] = (".parquet", ".parq") + function_names: tuple[str, ...] = ("read_parquet", "parquet_scan") + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + del function_name + options: dict[str, Any] = {} + for key, value in raw_options.items(): + if _apply_common_table_option(options, key, value): + pass + else: + return None + return options + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_parquet_dataframe(request.file.data) + + +class _JsonReader: + name = "json" + direct_extensions: tuple[str, ...] = ( + ".json", + ".json.gz", + ".geojson", + ".geojson.gz", + ".jsonl", + ".jsonl.gz", + ".ndjson", + ".ndjson.gz", + ) + function_names: tuple[str, ...] 
= ( + "read_json", + "read_json_auto", + "read_ndjson", + ) + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + options: dict[str, Any] = {} + if function_name == "read_ndjson": + options["format"] = "newline_delimited" + + for key, value in raw_options.items(): + if not _apply_json_option(options, key, value): + return None + return options + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_json_dataframe( + request.file.data, + _json_reader_options(request.options), + url=request.file.url, + ) + + +class _JsonObjectsReader: + name = "json_objects" + direct_extensions: tuple[str, ...] = () + function_names: tuple[str, ...] = ( + "read_json_objects", + "read_json_objects_auto", + "read_ndjson_objects", + ) + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + options: dict[str, Any] = {} + if function_name == "read_ndjson_objects": + options["format"] = "newline_delimited" + + for key, value in raw_options.items(): + if not _apply_json_option(options, key, value): + return None + return options + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_json_objects_dataframe( + request.file.data, + _json_reader_options(request.options), + url=request.file.url, + function_name=request.function_name, + ) + + +class _TextReader: + name = "text" + direct_extensions: tuple[str, ...] = () + function_names: tuple[str, ...] = ("read_text",) + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + del function_name + return {} if not raw_options else None + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_text_dataframe(request.file.data, request.file.url) + + +class _BlobReader: + name = "blob" + direct_extensions: tuple[str, ...] = () + function_names: tuple[str, ...] = ("read_blob",) + + def read_options( + self, + function_name: str, + raw_options: Mapping[str, Any], + ) -> dict[str, Any] | None: + del function_name + return {} if not raw_options else None + + def read_dataframe(self, request: _ReadRequest) -> pd.DataFrame: + return read_blob_dataframe(request.file.data, request.file.url) + + +_READERS: tuple[_DataFrameReader, ...] = ( + _CsvReader(), + _ParquetReader(), + _JsonReader(), + _JsonObjectsReader(), + _TextReader(), + _BlobReader(), +) + + +@dataclass(frozen=True) +class RemoteFileSource: + files: tuple[RemoteFile, ...] + reader_name: _ReaderName + options: tuple[tuple[str, Any], ...] 
= () + function_name: str | None = None + + def read_options(self) -> dict[str, Any]: + """Expose sorted, hashable options as regular reader kwargs.""" + return dict(self.options) + + def read_dataframe(self) -> pd.DataFrame: + """Read one or more remote files using DuckDB-compatible concat rules.""" + frames = [self._read_file_dataframe(file) for file in self.files] + if len(frames) == 1: + return frames[0] + + import pandas as pd + + if self.read_options().get("union_by_name") is True: + return pd.concat(frames, ignore_index=True, sort=False) + + columns = list(frames[0].columns) + for frame in frames[1:]: + if list(frame.columns) != columns: + raise ValueError( + "DuckDB WASM remote sources must have matching columns " + "unless union_by_name=True" + ) + return pd.concat(frames, ignore_index=True) + + def _read_file_dataframe(self, file: RemoteFile) -> pd.DataFrame: + """Fetch bytes, decode them, then apply DuckDB's filename option.""" + fetched = file.fetch() + options = self.read_options() + reader = _reader_by_name(self.reader_name) + df = reader.read_dataframe( + _ReadRequest( + file=fetched, + function_name=self.function_name or reader.function_names[0], + options=options, + ) + ) + filename = options.get("filename") + if filename is True: + return append_filename_column(df, fetched.url, "filename") + if isinstance(filename, str): + return append_filename_column(df, fetched.url, filename) + return df + + +def remote_file_from_url(url: str) -> RemoteFile | None: + """Return a fetchable remote file only for URL schemes marimo supports.""" + if urlparse(url).scheme not in {"http", "https"}: + return None + return RemoteFile(url=url) + + +def remote_file_source_from_reader_args( + function_name: str, + source: Any, + raw_options: Mapping[str, Any], +) -> RemoteFileSource | None: + """Map a DuckDB reader call to a reproducible remote DataFrame source.""" + reader = reader_for_function(function_name) + if reader is None: + return None + + files = _remote_files_from_source_arg(source) + if files is None: + return None + + options = reader.read_options(function_name, raw_options) + if options is None: + return None + return RemoteFileSource( + files, + reader.name, + tuple(sorted(options.items())), + function_name=function_name, + ) + + +def _remote_files_from_source_arg( + source: Any, +) -> tuple[RemoteFile, ...] 
| None: + """Accept DuckDB source shapes that are static URL strings or URL lists.""" + if isinstance(source, str): + file = remote_file_from_url(source) + return (file,) if file is not None else None + + if isinstance(source, Sequence) and not isinstance( + source, bytes | bytearray + ): + files: list[RemoteFile] = [] + for item in source: + if not isinstance(item, str): + return None + file = remote_file_from_url(item) + if file is None: + return None + files.append(file) + return tuple(files) if files else None + + return None + + +def _reader_by_name( + name: _ReaderName, +) -> _DataFrameReader: + for reader in _READERS: + if reader.name == name: + return reader + raise KeyError(f"Unknown DuckDB WASM reader: {name}") + + +def reader_for_url(url: str) -> _DataFrameReader | None: + """Infer a reader from direct URL table syntax such as ``FROM 'x.csv'``.""" + path = urlparse(url).path.lower() + return next( + ( + reader + for reader in _READERS + if path.endswith(reader.direct_extensions) + ), + None, + ) + + +def reader_for_function( + function_name: str, +) -> _DataFrameReader | None: + """Resolve DuckDB table-function names to marimo's fallback readers.""" + return next( + ( + reader + for reader in _READERS + if function_name in reader.function_names + ), + None, + ) + + +def _csv_reader_options(options: Mapping[str, Any]) -> dict[str, Any]: + """Drop options implemented outside DuckDB's CSV reader call.""" + return { + key: value + for key, value in options.items() + if key not in {"filename", "union_by_name"} + } + + +def _json_reader_options(options: Mapping[str, Any]) -> dict[str, Any]: + """Translate DuckDB JSON option spelling to DuckDB Python API spelling.""" + return { + key: _normalize_json_reader_option(key, value) + for key, value in options.items() + if key not in {"filename", "union_by_name"} + } + + +def _normalize_delimiter(value: Any) -> Any: + """Convert SQL's escaped tab literal to the byte delimiter DuckDB expects.""" + if value == r"\t": + return "\t" + return value + + +def _normalize_json_reader_option(key: str, value: Any) -> Any: + """Normalize JSON values whose SQL names differ from Python reader values.""" + if key == "compression": + return _normalize_json_compression(value) + if key == "format": + return _normalize_json_format(value) + return value + + +def _normalize_json_compression(value: Any) -> str: + """Map SQL compression aliases to DuckDB Python JSON reader values.""" + compression = str(value).lower() + if compression == "auto": + return "auto_detect" + if compression == "none": + return "uncompressed" + return compression + + +def _normalize_json_format(value: Any) -> str: + """Map DuckDB SQL JSON format aliases to Python reader values.""" + fmt = str(value).lower() + if fmt == "ndjson": + return "newline_delimited" + if fmt == "array_of_objects": + return "array" + return fmt + + +def _apply_json_option(options: dict[str, Any], key: str, value: Any) -> bool: + """Keep JSON options only when the fallback can safely pass them through.""" + if key == "format": + options["format"] = _normalize_json_format(value) + return True + if _apply_shared_source_option(options, key, value): + return True + if _is_safe_reader_option_name(key): + options[key] = value + return True + return False + + +def _is_safe_reader_option_name(key: str) -> bool: + """Reject option names that cannot be passed as Python reader kwargs.""" + return key.isidentifier() + + +def _apply_common_table_option( + options: dict[str, Any], key: str, value: Any +) -> bool: + """Handle 
options marimo applies after per-file reads are decoded.""" + if key == "filename" and isinstance(value, bool | str): + options["filename"] = value + return True + if key == "union_by_name" and isinstance(value, bool): + options["union_by_name"] = value + return True + return False + + +def _apply_compression_option( + options: dict[str, Any], key: str, value: Any +) -> bool: + """Accept only compression modes supported by the byte-fetch fallback.""" + if key == "compression" and _is_supported_compression(value): + options["compression"] = str(value).lower() + return True + return False + + +def _apply_shared_source_option( + options: dict[str, Any], key: str, value: Any +) -> bool: + """Apply source options shared by CSV, parquet, and JSON fallbacks.""" + return _apply_compression_option( + options, key, value + ) or _apply_common_table_option(options, key, value) + + +def _is_supported_compression(value: Any) -> bool: + """Limit compression to modes the fallback knows how to preserve.""" + return str(value).lower() in {"auto", "none", "gzip"} diff --git a/marimo/_runtime/_wasm/_duckdb/sources.py b/marimo/_runtime/_wasm/_duckdb/sources.py new file mode 100644 index 00000000000..8086cfa75cb --- /dev/null +++ b/marimo/_runtime/_wasm/_duckdb/sources.py @@ -0,0 +1,200 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Resolve sqlglot DuckDB table nodes to remote file sources. + +The SQL patch should only rewrite queries it can execute with fetched +DataFrames. This module recognizes direct URL table syntax and supported +``read_*`` table functions, extracts literal URL/options from sqlglot's AST, +and returns ``None`` for dynamic expressions so they continue through DuckDB +unchanged. +""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +from marimo._runtime._wasm._duckdb.io import ( + RemoteFileSource, + reader_for_url, + remote_file_from_url, + remote_file_source_from_reader_args, +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from sqlglot import exp + +# SQL options can parse to falsy values such as false, 0, or ""; this marks +# unsupported expressions without conflating them with valid literal values. 
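A sketch of the sentinel pattern this enables (hypothetical option expression in hand):

    value = _literal_value(value_expr)
    if value is _MISSING:
        ...  # dynamic or unsupported expression: leave the query unrewritten
    elif value is False:
        ...  # a genuine literal FALSE the rewrite can pass through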
+_MISSING = object() + + +def remote_file_source_from_table( + table: exp.Table, + *, + query: str | None = None, +) -> RemoteFileSource | None: + """Return a remote source for supported direct URLs or reader calls.""" + table_name = table.name + if table_name and _is_single_quoted_table_identifier(table, query): + reader = reader_for_url(table_name) + remote_file = remote_file_from_url(table_name) + if reader is not None and remote_file is not None: + return RemoteFileSource((remote_file,), reader.name) + + table_expr = table.this + if table_expr is None: + return None + + table_function = _table_function_call(table_expr) + if table_function is None: + return None + + function_name, args = table_function + source = _read_function_source(args) + if source is None: + return None + + raw_options = _read_function_options(args[1:]) + if raw_options is None: + return None + return remote_file_source_from_reader_args( + function_name, source, raw_options + ) + + +def _is_single_quoted_table_identifier( + table: exp.Table, query: str | None +) -> bool: + """Distinguish DuckDB file scans from ordinary quoted identifiers.""" + if query is None: + return False + meta = getattr(table.this, "meta", {}) + start = meta.get("start") + end = meta.get("end") + if not isinstance(start, int) or not isinstance(end, int): + return _query_has_single_quoted_table_reference(query, table.name) + return ( + 0 <= start + and end < len(query) + and query[start] == "'" + and query[end] == "'" + ) + + +def _query_has_single_quoted_table_reference( + query: str, table_name: str +) -> bool: + if not table_name: + return False + + quoted_table = re.escape(f"'{table_name}'") + return any( + re.search(pattern, query, flags=re.IGNORECASE) is not None + for pattern in ( + rf"(?:^|[()\s])FROM\s+{quoted_table}", + rf"\bJOIN\s+{quoted_table}", + ) + ) + + +def _table_function_call( + table_expr: exp.Expression, +) -> tuple[str, list[exp.Expression]] | None: + """Return a normalized DuckDB reader name and its arguments.""" + import sqlglot.expressions as exp + + # sqlglot versions model first-party DuckDB readers either as explicit + # Read* nodes or as generic anonymous table functions. + for node_name, function_name in ( + ("ReadCSV", "read_csv"), + ("ReadParquet", "read_parquet"), + ): + read_node = getattr(exp, node_name, None) + if read_node is not None and isinstance(table_expr, read_node): + first = [table_expr.this] if table_expr.this is not None else [] + return function_name, [*first, *table_expr.expressions] + + if isinstance(table_expr, exp.Anonymous): + return str(table_expr.this).lower(), list(table_expr.expressions) + return None + + +def _read_function_source( + args: Sequence[exp.Expression], +) -> str | tuple[str, ...] 
| None: + """Accept only literal URL sources that can be fetched before execution.""" + import sqlglot.expressions as exp + + if not args: + return None + source_expr = args[0] + if isinstance(source_expr, exp.Literal) and source_expr.is_string: + return str(source_expr.this) + + if isinstance(source_expr, exp.Array) and source_expr.expressions: + urls: list[str] = [] + for item_expr in source_expr.expressions: + if not ( + isinstance(item_expr, exp.Literal) and item_expr.is_string + ): + return None + urls.append(str(item_expr.this)) + return tuple(urls) + + return None + + +def _read_function_options( + option_exprs: Sequence[exp.Expression], +) -> dict[str, Any] | None: + """Decode literal keyword options from a DuckDB table-function call.""" + options: dict[str, Any] = {} + for option_expr in option_exprs: + option = _read_function_option(option_expr) + if option is None: + return None + key, value = option + options[key] = value + return options + + +def _read_function_option( + option_expr: exp.Expression, +) -> tuple[str, Any] | None: + """Return one static option or ``None`` for unsupported expressions.""" + import sqlglot.expressions as exp + + property_eq = getattr(exp, "PropertyEQ", None) + option_classes = ( + (exp.EQ,) if property_eq is None else (exp.EQ, property_eq) + ) + if not isinstance(option_expr, option_classes): + return None + + value_expr = option_expr.args.get("expression") + if value_expr is None: + return None + + value = _literal_value(value_expr) + if value is _MISSING: + return None + + key = getattr(option_expr.this, "name", None) + if key is None: + return None + return key.lower(), value + + +def _literal_value(value_expr: exp.Expression) -> Any: + """Convert sqlglot literals while preserving falsy values via _MISSING.""" + import sqlglot.expressions as exp + + if isinstance(value_expr, exp.Boolean): + return value_expr.this + if isinstance(value_expr, exp.Literal): + if value_expr.is_string: + return value_expr.this + return value_expr.to_py() + return _MISSING diff --git a/marimo/_runtime/_wasm/_fetch.py b/marimo/_runtime/_wasm/_fetch.py new file mode 100644 index 00000000000..c84d1bf3d19 --- /dev/null +++ b/marimo/_runtime/_wasm/_fetch.py @@ -0,0 +1,40 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Shared URL fetch for WASM fallbacks.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict, cast + +if TYPE_CHECKING: + import ssl + + +class RequestKwargs(TypedDict, total=False): + data: bytes | None + headers: dict[str, str] + origin_req_host: str + unverifiable: bool + method: str + + +class UrlOpenKwargs(TypedDict, total=False): + data: bytes | None + timeout: float + context: ssl.SSLContext | None + + +def fetch_url_bytes( + url: str, + *, + request_kwargs: RequestKwargs | None = None, + urlopen_kwargs: UrlOpenKwargs | None = None, +) -> bytes: + """Sync fetch via urllib — works in pyodide thanks to pyodide_http's + patch_all (installed at marimo startup), which routes urllib through JS + fetch. Single sync path for text and binary. 
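A sketch of a call that adds request headers (header values illustrative):

    data = fetch_url_bytes(
        "https://example.com/cars.csv",
        request_kwargs={"headers": {"Accept": "text/csv"}},
    )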
+ """ + import urllib.request + + request = urllib.request.Request(url, **(request_kwargs or {})) + with urllib.request.urlopen(request, **(urlopen_kwargs or {})) as response: + return cast(bytes, response.read()) diff --git a/marimo/_runtime/_wasm/_patches.py b/marimo/_runtime/_wasm/_patches.py index a726d45d7f6..06662c80eb4 100644 --- a/marimo/_runtime/_wasm/_patches.py +++ b/marimo/_runtime/_wasm/_patches.py @@ -19,6 +19,7 @@ Unpatch = Callable[[], None] Fallback = Callable[..., Any] +WrapperFactory = Callable[[Callable[..., Any]], Callable[..., Any]] class WasmPatchSet: @@ -47,6 +48,42 @@ def patch( No-op outside pyodide or if ``attr`` is missing (e.g. renamed across polars versions). """ + + def wrapper_factory( + original: Callable[..., Any], + ) -> Callable[..., Any]: + @functools.wraps(original) + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return original(*args, **kwargs) + except catch as original_exc: + original_tb = original_exc.__traceback__ + try: + return fallback(original, *args, **kwargs) + except ModuleNotFoundError: + # Let missing-dependency errors bubble up so marimo can + # prompt the user to install the package. + raise + except Exception as fallback_exc: + raise original_exc.with_traceback( + original_tb + ) from fallback_exc + + return wrapper + + self.replace(owner, attr, wrapper_factory) + + def replace( + self, + owner: Any, + attr: str, + wrapper_factory: WrapperFactory, + ) -> None: + """Replace ``owner.attr`` with a WASM-only wrapper. + + Unlike ``patch``, this does not call the original first. Use this for + APIs where an original call can have side effects before failing. + """ if not self._active: return @@ -54,23 +91,7 @@ def patch( if original is None: return - @functools.wraps(original) - def wrapper(*args: Any, **kwargs: Any) -> Any: - try: - return original(*args, **kwargs) - except catch as original_exc: - original_tb = original_exc.__traceback__ - try: - return fallback(original, *args, **kwargs) - except ModuleNotFoundError: - # Let missing-dependency errors bubble up so marimo can - # prompt the user to install the package. - raise - except Exception as fallback_exc: - raise original_exc.with_traceback( - original_tb - ) from fallback_exc - + wrapper = wrapper_factory(original) setattr(owner, attr, wrapper) def _unpatch() -> None: diff --git a/marimo/_runtime/_wasm/_polars.py b/marimo/_runtime/_wasm/_polars.py index 36571c6f060..c212653588b 100644 --- a/marimo/_runtime/_wasm/_polars.py +++ b/marimo/_runtime/_wasm/_polars.py @@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any from marimo import _loggers +from marimo._runtime._wasm import _fetch from marimo._runtime._wasm._patches import Unpatch, WasmPatchSet from marimo._utils.platform import is_pyodide @@ -47,17 +48,6 @@ def _is_url(source: Any) -> bool: ) -def _fetch_url_bytes(url: str) -> bytes: - """Sync fetch via urllib — works in pyodide thanks to pyodide_http's - patch_all (installed at marimo startup), which routes urllib through JS - fetch. Single sync path for text and binary. 
- """ - import urllib.request - - with urllib.request.urlopen(url) as response: - return response.read() # type: ignore[no-any-return] - - def _resolve_to_bytes(source: Any) -> bytes: """Coerce a polars I/O source into raw bytes.""" if isinstance(source, bytes): @@ -73,7 +63,7 @@ def _resolve_to_bytes(source: Any) -> bytes: return source.read_bytes() if isinstance(source, str): if _is_url(source): - return _fetch_url_bytes(source) + return _fetch.fetch_url_bytes(source) return pathlib.Path(source).read_bytes() raise TypeError( f"Unsupported source type for WASM polars fallback: {type(source)!r}" diff --git a/marimo/_sql/utils.py b/marimo/_sql/utils.py index 4d17982906b..2249fd8d0b9 100644 --- a/marimo/_sql/utils.py +++ b/marimo/_sql/utils.py @@ -8,6 +8,9 @@ from marimo._config.config import SqlOutputType from marimo._data.models import DataType from marimo._dependencies.dependencies import DependencyManager +from marimo._runtime._wasm._duckdb import ( + try_run_duckdb_sql_with_wasm_patch, +) from marimo._runtime.context.types import ( ContextNotInitializedError, get_context, @@ -24,6 +27,38 @@ LOGGER = _loggers.marimo_logger() CHEAP_DISCOVERY_DATABASES = ["duckdb", "sqlite", "mysql", "postgresql"] +# DuckDB SQL can return None for DDL, so keep "patch did not apply" distinct. +_NO_WASM_DUCKDB_RESULT = object() + + +def _try_wasm_duckdb( + method_name: str, + query: str, + connection: Any, + glbls: dict[str, Any], + *trailing_args: Any, +) -> object: + import duckdb + + if connection is duckdb: + original = getattr(duckdb, method_name) + args: tuple[Any, ...] = (query, *trailing_args) + query_arg_index = 0 + else: + original = getattr(type(connection), method_name) + args = (connection, query, *trailing_args) + query_arg_index = 1 + + result = try_run_duckdb_sql_with_wasm_patch( + original, + args, + {}, + query_arg_index=query_arg_index, + query_kwarg_names=("query",), + eval_globals=glbls, + eval_locals=glbls, + ) + return _NO_WASM_DUCKDB_RESULT if result is None else result.value def wrapped_sql( @@ -31,6 +66,7 @@ def wrapped_sql( connection: duckdb.DuckDBPyConnection | None, ) -> duckdb.DuckDBPyRelation: DependencyManager.duckdb.require("to execute sql") + import duckdb # In Python globals() are scoped to modules; since this function # is in a different module than user code, globals() doesn't return @@ -39,14 +75,17 @@ def wrapped_sql( # However, duckdb needs access to the kernel's globals. For this reason, # we manually exec duckdb and provide it with the kernel's globals. if connection is None: - import duckdb - connection = cast(duckdb.DuckDBPyConnection, duckdb) try: ctx = get_context() except ContextNotInitializedError: - relation = connection.sql(query=query) + result = _try_wasm_duckdb("sql", query, connection, globals()) + if result is _NO_WASM_DUCKDB_RESULT: + # No WASM rewrite was needed; use DuckDB's normal SQL path. + relation = connection.sql(query=query) + else: + relation = cast(duckdb.DuckDBPyRelation, result) else: install_connection = ( ctx.execution_context.with_connection @@ -54,11 +93,21 @@ def wrapped_sql( else nullcontext ) with install_connection(connection): - relation = eval( - "connection.sql(query=query)", + result = _try_wasm_duckdb( + "sql", + query, + connection, ctx.globals, - {"query": query, "connection": connection}, ) + if result is _NO_WASM_DUCKDB_RESULT: + # Run in kernel globals so DuckDB replacement scans see user data. 
+                relation = eval(
+                    "connection.sql(query=query)",
+                    ctx.globals,
+                    {"query": query, "connection": connection},
+                )
+            else:
+                relation = cast(duckdb.DuckDBPyRelation, result)
 
    return relation
 
@@ -74,16 +123,21 @@
    parameterized queries ($1, $2, ...) for safe value interpolation.
    """
    DependencyManager.duckdb.require("to execute sql")
+    import duckdb
 
    if connection is None:
-        import duckdb
-
        connection = cast(duckdb.DuckDBPyConnection, duckdb)
 
    try:
        ctx = get_context()
    except ContextNotInitializedError:
-        return connection.execute(query, params)
+        result = _try_wasm_duckdb(
+            "execute", query, connection, globals(), params
+        )
+        if result is _NO_WASM_DUCKDB_RESULT:
+            # No WASM rewrite was needed; preserve DuckDB's parameterized path.
+            return connection.execute(query, params)
+        return cast(duckdb.DuckDBPyConnection, result)
    else:
        install_connection = (
            ctx.execution_context.with_connection
@@ -91,16 +145,27 @@
            else nullcontext
        )
        with install_connection(connection):
-            result: duckdb.DuckDBPyConnection = eval(
-                "connection.execute(query, params)",
+            result = _try_wasm_duckdb(
+                "execute",
+                query,
+                connection,
                ctx.globals,
-                {
-                    "query": query,
-                    "params": params,
-                    "connection": connection,
-                },
+                params,
            )
-            return result
+            if result is _NO_WASM_DUCKDB_RESULT:
+                # Run in kernel globals so parameterized SQL can scan user data.
+                value = eval(
+                    "connection.execute(query, params)",
+                    ctx.globals,
+                    {
+                        "query": query,
+                        "params": params,
+                        "connection": connection,
+                    },
+                )
+            else:
+                value = result
+            return cast(duckdb.DuckDBPyConnection, value)
 
 
 def try_convert_to_polars(
diff --git a/tests/_runtime/test_duckdb_wasm.py b/tests/_runtime/test_duckdb_wasm.py
new file mode 100644
index 00000000000..cd87c5d4262
--- /dev/null
+++ b/tests/_runtime/test_duckdb_wasm.py
@@ -0,0 +1,1543 @@
+# Copyright 2026 Marimo. All rights reserved.
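+"""Tests for marimo's WASM DuckDB patches.
+
+Covers native-DuckDB parity for rewritten remote sources, the direct-read
+patches, mo.sql and marimo._sql.utils integration, and that unpatching
+restores every wrapped API.
+"""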
+from __future__ import annotations
+
+import gzip
+import tempfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING
+from unittest.mock import patch
+
+import pytest
+
+from marimo._runtime._wasm._duckdb import (
+    patch_duckdb_for_wasm,
+    patch_duckdb_query_for_wasm,
+)
+from marimo._runtime._wasm._duckdb.sources import (
+    remote_file_source_from_table,
+)
+from marimo._sql.engines.duckdb import DuckDBEngine
+from tests.conftest import ExecReqProvider, mock_pyodide
+
+pytestmark = pytest.mark.requires("duckdb", "pandas", "sqlglot")
+
+pytest.importorskip("duckdb")
+pytest.importorskip("pandas")
+pytest.importorskip("sqlglot")
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from marimo._runtime.runtime import Kernel
+
+
+def _normalize_value(value: object) -> object:
+    import pandas as pd
+
+    if hasattr(value, "tolist") and not isinstance(
+        value, dict | list | str | bytes | bytearray
+    ):
+        return value.tolist()
+    if isinstance(value, list):
+        return [_normalize_value(item) for item in value]
+    if isinstance(value, dict):
+        return {key: _normalize_value(item) for key, item in value.items()}
+    if isinstance(value, bytes | bytearray):
+        return bytes(value)
+    try:
+        return None if bool(pd.isna(value)) else value
+    except (TypeError, ValueError):
+        return value
+
+
+def _records(df: object) -> list[dict[str, object]]:
+    return [
+        {key: _normalize_value(value) for key, value in row.items()}
+        for row in df.to_dict("records")  # type: ignore[attr-defined]
+    ]
+
+
+def _rows(rows: Sequence[Sequence[object]]) -> list[tuple[object, ...]]:
+    return [tuple(_normalize_value(value) for value in row) for row in rows]
+
+
+@dataclass(frozen=True)
+class RemoteFixture:
+    url: str
+    suffix: str
+    data: bytes
+
+
+@dataclass(frozen=True)
+class QueryParityCase:
+    name: str
+    query: str
+    fixtures: tuple[RemoteFixture, ...]
+
+
+@dataclass(frozen=True)
+class DirectReadParityCase:
+    name: str
+    function_name: str
+    fixture: RemoteFixture
+    source_kwarg: str | None = None
+    options: tuple[tuple[str, object], ...] = ()
+
+
+def _parquet_bytes(sql: str) -> bytes:
+    import duckdb
+
+    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as file:
+        path = Path(file.name)
+    try:
+        duckdb.sql(sql).write_parquet(str(path))
+        return path.read_bytes()
+    finally:
+        path.unlink(missing_ok=True)
+
+
+def _local_fixture_path(
+    fixture: RemoteFixture, tmp_path: Path, filename: str
+) -> str:
+    path = tmp_path / f"{filename}{fixture.suffix}"
+    path.write_bytes(fixture.data)
+    return path.as_posix()
+
+
+def _local_query(
+    remote_query: str,
+    fixtures: Sequence[RemoteFixture],
+    tmp_path: Path,
+) -> str:
+    query = remote_query
+    for idx, fixture in enumerate(fixtures):
+        query = query.replace(
+            fixture.url,
+            _local_fixture_path(fixture, tmp_path, f"remote_{idx}"),
+        )
+    return query
+
+
+def _native_rows(query: str) -> list[tuple[object, ...]]:
+    import duckdb
+
+    connection = duckdb.connect(":memory:")
+    try:
+        return _rows(connection.sql(query).fetchall())
+    finally:
+        connection.close()
+
+
+def _patched_rows(
+    query: str,
+    fixtures: Sequence[RemoteFixture],
+) -> tuple[list[tuple[object, ...]], list[str]]:
+    import duckdb
+
+    fixtures_by_url = {fixture.url: fixture for fixture in fixtures}
+
+    def fetch(url: str) -> bytes:
+        return fixtures_by_url[url].data
+
+    with (
+        mock_pyodide(),
+        patch(
+            "marimo._runtime._wasm._fetch.fetch_url_bytes",
+            side_effect=fetch,
+        ) as fetch_url_bytes,
+    ):
+        patch_result = patch_duckdb_query_for_wasm(query)
+
+    assert patch_result is not None
+    assert all(fixture.url not in patch_result.query for fixture in fixtures)
+
+    connection = duckdb.connect(":memory:")
+    try:
+        for table_name, df in patch_result.tables.items():
+            connection.register(table_name, df)
+        rows = _rows(connection.sql(patch_result.query).fetchall())
+    finally:
+        connection.close()
+
+    fetched_urls = [call.args[0] for call in fetch_url_bytes.call_args_list]
+    return rows, fetched_urls
+
+
+def _direct_reader_args(
+    case: DirectReadParityCase, source: str
+) -> tuple[tuple[object, ...], dict[str, object]]:
+    kwargs = dict(case.options)
+    if case.source_kwarg is None:
+        return (source,), kwargs
+    kwargs[case.source_kwarg] = source
+    return (), kwargs
+
+
+def _run_direct_reader(
+    case: DirectReadParityCase,
+    source: str,
+    *,
+    api_kind: str,
+) -> list[tuple[object, ...]]:
+    import duckdb
+
+    args, kwargs = _direct_reader_args(case, source)
+    if api_kind == "module":
+        relation = getattr(duckdb, case.function_name)(*args, **kwargs)
+        return _rows(relation.fetchall())
+
+    connection = duckdb.connect(":memory:")
+    try:
+        if api_kind == "connection":
+            relation = getattr(connection, case.function_name)(*args, **kwargs)
+        elif api_kind == "module-connection-kw":
+            kwargs["connection"] = connection
+            relation = getattr(duckdb, case.function_name)(*args, **kwargs)
+        else:
+            raise ValueError(f"Unknown DuckDB direct reader API: {api_kind}")
+        return _rows(relation.fetchall())
+    finally:
+        connection.close()
+
+
+def _patched_direct_rows(
+    case: DirectReadParityCase,
+    *,
+    api_kind: str,
+) -> tuple[list[tuple[object, ...]], list[str]]:
+    with (
+        mock_pyodide(),
+        patch(
+            "marimo._runtime._wasm._fetch.fetch_url_bytes",
+            return_value=case.fixture.data,
+        ) as fetch_url_bytes,
+    ):
+        unpatch = patch_duckdb_for_wasm()
+        try:
+            rows = _run_direct_reader(
+                case, case.fixture.url, api_kind=api_kind
+            )
+        finally:
+            unpatch()
+
+    fetched_urls = [call.args[0] for call in fetch_url_bytes.call_args_list]
+    return rows, fetched_urls
+
+
+def _direct_reader_parity_cases() -> list[DirectReadParityCase]:
+    return [
+        DirectReadParityCase(
+            "csv-positional",
+            "read_csv",
+            RemoteFixture(
+                "https://datasets.marimo.app/cars.csv",
+                ".csv",
+                b"1;ford\n2;toyota\n",
+            ),
+            options=(("delimiter", ";"), ("header", False)),
+        ),
+        DirectReadParityCase(
+            "parquet-file-glob",
+            "read_parquet",
+            RemoteFixture(
+                "https://datasets.marimo.app/cars.parquet",
+                ".parquet",
+                _parquet_bytes("SELECT 'ford' AS make"),
+            ),
+            source_kwarg="file_glob",
+        ),
+        DirectReadParityCase(
+            "json-path-or-buffer",
+            "read_json",
+            RemoteFixture(
+                "https://datasets.marimo.app/cars.json",
+                ".json",
+                b'[{"make":"ford"},{"make":"toyota"}]',
+            ),
+            source_kwarg="path_or_buffer",
+            options=(("format", "array"),),
+        ),
+    ]
+
+
+def _query_parity_cases() -> list[QueryParityCase]:
+    csv = RemoteFixture(
+        "https://example.com/cars.csv",
+        ".csv",
+        b"make,mpg\nford,25\ntoyota,18\n",
+    )
+    csv_semicolon = RemoteFixture(
+        "https://example.com/cars-semicolon.csv",
+        ".csv",
+        b"1;ford\n2;toyota\n",
+    )
+    csv_gzip = RemoteFixture(
+        "https://example.com/cars.csv.gz",
+        ".csv.gz",
+        gzip.compress(b"make,mpg\nford,25\n"),
+    )
+    csv_download = RemoteFixture(
+        "https://example.com/download",
+        "",
+        b"make,mpg\nford,25\n",
+    )
+    csv_normalize = RemoteFixture(
+        "https://example.com/names.csv",
+        ".csv",
+        b"make name,mpg\nford,25\n",
+    )
+    tsv = RemoteFixture(
+        "https://example.com/walmarts.tsv",
+        ".tsv",
+        b"longitude\tlatitude\n1\t2\n",
+    )
+    csv_a = RemoteFixture("https://example.com/a.csv", ".csv", b"a,b\n1,2\n")
+    csv_b = RemoteFixture("https://example.com/b.csv", ".csv", b"a,c\n3,4\n")
+    parquet = RemoteFixture(
+        "https://example.com/a.parquet",
+        ".parquet",
+        _parquet_bytes(
+            """
+            SELECT 1 AS a, 'x' AS b
+            UNION ALL SELECT 2, 'y'
+            """
+        ),
+    )
+    parquet_b = RemoteFixture(
+        "https://example.com/b.parquet",
+        ".parquet",
+        _parquet_bytes("SELECT 3 AS a, 'z' AS b"),
+    )
+    json_array = RemoteFixture(
+        "https://example.com/a.json",
+        ".json",
+        b'[{"a":1,"b":"x"},{"a":2,"b":"y"}]',
+    )
+    json_array_b = RemoteFixture(
+        "https://example.com/b.json",
+        ".json",
+        b'[{"a":3,"b":"z"}]',
+    )
+    complex_json = RemoteFixture(
+        "https://example.com/countries.json",
+        ".json",
+        b'{"type":"Topology","arcs":[[[0]],1]}',
+    )
+    unstructured_json = RemoteFixture(
+        "https://example.com/unstructured.json",
+        ".json",
+        b'{"a":1} {"a":2}',
+    )
+    ndjson = RemoteFixture(
+        "https://example.com/a.ndjson",
+        ".ndjson",
+        b'{"a":1}\n{"a":2}\n',
+    )
+    ndjson_gzip = RemoteFixture(
+        "https://example.com/events.ndjson.gz",
+        ".ndjson.gz",
+        gzip.compress(b'{"event_id":5,"value":10}\n'),
+    )
+    ndjson_objects_gzip = RemoteFixture(
+        "https://example.com/objects.ndjson.gz",
+        ".ndjson.gz",
+        gzip.compress(b'{"a":1}\n{"a":2}\n'),
+    )
+    geojson = RemoteFixture(
+        "https://example.com/a.geojson",
+        ".geojson",
+        b'{"type":"FeatureCollection","features":[]}',
+    )
+    text = RemoteFixture("https://example.com/a.txt", ".txt", b"hello")
+    blob = RemoteFixture("https://example.com/a.bin", ".bin", b"\x00\x01")
+    return [
+        QueryParityCase(
+            "direct-csv-literal",
+            f"SELECT make, mpg FROM '{csv.url}' ORDER BY make",
+            (csv,),
+        ),
+        QueryParityCase(
+            "csv-reader-options",
+            f"""
+            SELECT column1 FROM read_csv(
+                '{csv_semicolon.url}', delim := ';', header := false
+            )
+            ORDER BY column0
+            """,
+            (csv_semicolon,),
+        ),
+        QueryParityCase(
+            "tsv-escaped-delimiter",
+            f"""
+            SELECT * FROM read_csv_auto('{tsv.url}', delim='\\t')
+            """,
+            (tsv,),
+        ),
+        QueryParityCase(
+            "gzipped-csv",
+            f"SELECT make, mpg FROM read_csv('{csv_gzip.url}')",
+            (csv_gzip,),
+        ),
+        QueryParityCase(
+            "reader-without-extension",
+            f"SELECT make, mpg FROM read_csv('{csv_download.url}')",
+            (csv_download,),
+        ),
+        QueryParityCase(
+            "csv-normalize-names",
+            f"""
+            SELECT make_name, mpg FROM read_csv(
+                '{csv_normalize.url}', normalize_names=true
+            )
+            """,
+            (csv_normalize,),
+        ),
+        QueryParityCase(
+            "csv-union-by-name",
+            f"""
+            SELECT * FROM read_csv(
+                ['{csv_a.url}', '{csv_b.url}'], union_by_name=true
+            )
+            ORDER BY a
+            """,
+            (csv_a, csv_b),
+        ),
+        QueryParityCase(
+            "parquet-reader",
+            f"SELECT * FROM read_parquet('{parquet.url}') ORDER BY a",
+            (parquet,),
+        ),
+        QueryParityCase(
+            "parquet-list",
+            f"""
+            SELECT * FROM read_parquet(['{parquet.url}', '{parquet_b.url}'])
+            ORDER BY a
+            """,
+            (parquet, parquet_b),
+        ),
+        QueryParityCase(
+            "parquet-scan-alias",
+            f"SELECT * FROM parquet_scan('{parquet.url}') ORDER BY a",
+            (parquet,),
+        ),
+        QueryParityCase(
+            "parquet-direct-literal",
+            f"SELECT * FROM '{parquet.url}' ORDER BY a",
+            (parquet,),
+        ),
+        QueryParityCase(
+            "json-reader",
+            f"SELECT * FROM read_json_auto('{json_array.url}') ORDER BY a",
+            (json_array,),
+        ),
+        QueryParityCase(
+            "json-list",
+            f"""
+            SELECT * FROM read_json_auto(
+                ['{json_array.url}', '{json_array_b.url}']
+            )
+            ORDER BY a
+            """,
+            (json_array, json_array_b),
+        ),
+        QueryParityCase(
+            "complex-json",
+            f"SELECT type, arcs FROM read_json_auto('{complex_json.url}')",
+            (complex_json,),
+        ),
+        QueryParityCase(
+            "unstructured-json",
+            f"""
+            SELECT * FROM read_json(
+                '{unstructured_json.url}', format='unstructured'
+            )
+            ORDER BY a
+            """,
+            (unstructured_json,),
+        ),
+        QueryParityCase(
+            "ndjson-reader",
+            f"SELECT * FROM read_ndjson('{ndjson.url}') ORDER BY a",
+            (ndjson,),
+        ),
+        QueryParityCase(
+            "direct-ndjson-literal",
+            f"SELECT * FROM '{ndjson.url}' ORDER BY a",
+            (ndjson,),
+        ),
+        QueryParityCase(
+            "gzipped-ndjson",
+            f"""
+            SELECT * FROM read_json(
+                '{ndjson_gzip.url}', format='newline_delimited',
+                compression='gzip'
+            )
+            """,
+            (ndjson_gzip,),
+        ),
+        QueryParityCase(
+            "ndjson-objects",
+            f"SELECT json FROM read_ndjson_objects('{json_array.url}')",
+            (json_array,),
+        ),
+        QueryParityCase(
+            "json-objects-auto-gzip",
+            f"""
+            SELECT json FROM read_json_objects_auto('{ndjson_objects_gzip.url}')
+            ORDER BY json
+            """,
+            (ndjson_objects_gzip,),
+        ),
+        QueryParityCase(
+            "geojson-reader",
+            f"SELECT type FROM read_json_auto('{geojson.url}')",
+            (geojson,),
+        ),
+        QueryParityCase(
+            "text-and-blob",
+            f"""
+            SELECT content, size FROM read_text('{text.url}')
+            UNION ALL
+            SELECT content::VARCHAR, size FROM read_blob('{blob.url}')
+            """,
+            (text, blob),
+        ),
+        QueryParityCase(
+            "mixed-parquet-json",
+            f"""
+            SELECT * FROM read_parquet('{parquet.url}')
+            UNION ALL
+            SELECT * FROM read_json_auto('{json_array.url}')
+            ORDER BY a
+            """,
+            (parquet, json_array),
+        ),
+    ]
+
+
+def test_patch_duckdb_for_wasm_noop_outside_pyodide() -> None:
+    import duckdb
+
+    original_read_csv = duckdb.read_csv
+    original_sql = duckdb.sql
+    original_connection_sql = duckdb.DuckDBPyConnection.sql
+
+    unpatch = patch_duckdb_for_wasm()
+
+    assert duckdb.read_csv is original_read_csv
+    assert duckdb.sql is original_sql
+    assert duckdb.DuckDBPyConnection.sql is original_connection_sql
+    unpatch()
+    assert duckdb.read_csv is original_read_csv
+    assert duckdb.sql is original_sql
+    assert duckdb.DuckDBPyConnection.sql is original_connection_sql
+
+
+class TestDuckDBWasmDirectReadPatch:
+    @staticmethod
+    def test_patch_installation_does_not_require_sqlglot() -> None:
+        import duckdb
+
+        original = duckdb.read_csv
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._duckdb._require_sqlglot",
+                side_effect=AssertionError("sqlglot should be lazy"),
+            ),
+        ):
+            unpatch = patch_duckdb_for_wasm()
+            try:
+                assert duckdb.read_csv is not original
+            finally:
+                unpatch()
+
+        assert duckdb.read_csv is original
+
+    @staticmethod
+    @pytest.mark.parametrize("api_kind", ["module", "connection"])
+    @pytest.mark.parametrize(
+        "case",
+        _direct_reader_parity_cases(),
+        ids=lambda case: case.name,
+    )
+    def test_direct_readers_match_native_duckdb(
+        case: DirectReadParityCase,
+        api_kind: str,
+        tmp_path: Path,
+    ) -> None:
+        native_rows = _run_direct_reader(
+            case,
+            _local_fixture_path(case.fixture, tmp_path, case.name),
+            api_kind=api_kind,
+        )
+        patched_rows, fetched_urls = _patched_direct_rows(
+            case, api_kind=api_kind
+        )
+
+        assert patched_rows == native_rows
+        assert fetched_urls == [case.fixture.url]
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "case",
+        [
+            case
+            for case in _direct_reader_parity_cases()
+            if case.function_name != "read_csv"
+        ],
+        ids=lambda case: case.name,
+    )
+    def test_module_readers_preserve_connection_kw(
+        case: DirectReadParityCase, tmp_path: Path
+    ) -> None:
+        native_rows = _run_direct_reader(
+            case,
+            _local_fixture_path(case.fixture, tmp_path, case.name),
+            api_kind="module-connection-kw",
+        )
+        patched_rows, fetched_urls = _patched_direct_rows(
+            case, api_kind="module-connection-kw"
+        )
+
+        assert patched_rows == native_rows
+        assert fetched_urls == [case.fixture.url]
+
+
+class TestDuckDBWasmQueryPatch:
+    @staticmethod
+    def test_noop_outside_pyodide() -> None:
+        assert (
+            patch_duckdb_query_for_wasm(
+                "SELECT * FROM 'https://datasets.marimo.app/cars.csv'"
+            )
+            is None
+        )
+
+    @staticmethod
+    def test_read_parquet_node_preserves_this_argument() -> None:
+        from sqlglot import exp
+
+        table = exp.Table(
+            this=exp.ReadParquet(
+                this=exp.Literal.string("https://example.com/a.parquet")
+            )
+        )
+
+        source = remote_file_source_from_table(table)
+
+        assert source is not None
+        assert source.reader_name == "parquet"
+        assert [file.url for file in source.files] == [
+            "https://example.com/a.parquet"
+        ]
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        ("query", "expected_source"),
+        [
+            (
+                "SELECT * FROM 'https://example.com/a.csv'",
+                True,
+            ),
+            (
+                'SELECT * FROM "https://example.com/a.csv"',
+                False,
+            ),
+            (
+                """
+                SELECT 1, 'https://example.com/a.csv' AS label
+                FROM "https://example.com/a.csv"
+                """,
+                False,
+            ),
+        ],
+        ids=[
+            "direct-literal",
+            "double-quoted-identifier",
+            "string-literal-and-double-quoted-identifier",
+        ],
+    )
+    def test_token_metadata_fallback_detects_single_quoted_remote_sources(
+        query: str,
+        expected_source: bool,
+    ) -> None:
+        from sqlglot import exp
+
+        table = exp.Table(
+            this=exp.Identifier(this="https://example.com/a.csv", quoted=True)
+        )
+
+        source = remote_file_source_from_table(
+            table,
+            query=query,
+        )
+
+        if not expected_source:
+            assert source is None
+            return
+
+        assert source is not None
+        assert source.reader_name == "csv"
+        assert [file.url for file in source.files] == [
+            "https://example.com/a.csv"
+        ]
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "function_name",
+        [
+            "read_json_objects",
+            "read_json_objects_auto",
+            "read_ndjson_objects",
+        ],
+    )
+    def test_json_objects_reader_preserves_requested_function(
+        monkeypatch: pytest.MonkeyPatch, function_name: str
+    ) -> None:
+        import duckdb
+        import pandas as pd
+
+        from marimo._runtime._wasm._duckdb.dataframe import (
+            read_json_objects_dataframe,
+        )
+
+        queries: list[str] = []
+
+        class Relation:
+            def df(self) -> pd.DataFrame:
+                return pd.DataFrame({"json": []})
+
+        def fake_sql(query: str, *, params: list[object]) -> Relation:
+            del params
+            queries.append(query)
+            return Relation()
+
+        monkeypatch.setattr(duckdb, "sql", fake_sql)
+
+        read_json_objects_dataframe(
+            b'{"a":1}\n',
+            {},
+            url="https://example.com/a.json",
+            function_name=function_name,
+        )
+
+        assert queries == [f"SELECT * FROM {function_name}(?)"]
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "case",
+        _query_parity_cases(),
+        ids=lambda case: case.name,
+    )
+    def test_rewrites_remote_sources_with_native_duckdb_parity(
+        case: QueryParityCase, tmp_path: Path
+    ) -> None:
+        native_rows = _native_rows(
+            _local_query(case.query, case.fixtures, tmp_path)
+        )
+        patched_rows, fetched_urls = _patched_rows(case.query, case.fixtures)
+
+        assert patched_rows == native_rows
+        assert fetched_urls == [fixture.url for fixture in case.fixtures]
+
+    @staticmethod
+    def test_list_argument_requires_matching_schemas() -> None:
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                side_effect=[b"a,b\n1,2\n", b"a,c\n3,4\n"],
+            ),
+            pytest.raises(ValueError),
+        ):
+            patch_duckdb_query_for_wasm(
+                """
+                SELECT * FROM read_csv([
+                    'https://example.com/a.csv',
+                    'https://example.com/b.csv'
+                ])
+                """,
+            )
+
+    @staticmethod
+    def test_rewrites_direct_geojson_literal() -> None:
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                return_value=b'{"type":"FeatureCollection","features":[]}',
+            ),
+        ):
+            patch_result = patch_duckdb_query_for_wasm(
+                "FROM 'https://example.com/a.geojson'",
+            )
+
+        assert patch_result is not None
+        assert (
+            patch_result.query == "SELECT * FROM __marimo_wasm_duckdb_remote_0"
+        )
+        assert _records(next(iter(patch_result.tables.values()))) == [
+            {"type": "FeatureCollection", "features": []}
+        ]
+
+    @staticmethod
+    def test_avoids_reserved_table_names() -> None:
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                return_value=b"make,mpg\nford,25\n",
+            ),
+        ):
+            patch_result = patch_duckdb_query_for_wasm(
+                "SELECT * FROM 'https://datasets.marimo.app/cars.csv'",
+                reserved_names=("__marimo_wasm_duckdb_remote_0",),
+            )
+
+        assert patch_result is not None
+        assert (
+            patch_result.query == "SELECT * FROM __marimo_wasm_duckdb_remote_1"
+        )
+
+    @staticmethod
+    def test_avoids_sql_cte_table_names_case_insensitively() -> None:
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                return_value=b"mpg\n25\n",
+            ),
+        ):
+            patch_result = patch_duckdb_query_for_wasm(
+                """
+                WITH __MARIMO_WASM_DUCKDB_REMOTE_0 AS (SELECT 99 AS mpg)
+                SELECT mpg FROM 'https://datasets.marimo.app/cars.csv'
+                """,
+            )
+
+        assert patch_result is not None
+        assert "FROM __marimo_wasm_duckdb_remote_1" in patch_result.query
+
+    @staticmethod
+    def test_does_not_rewrite_create_view_remote_source() -> None:
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                return_value=b"make,mpg\nford,25\n",
+            ) as fetch_url_bytes,
+        ):
+            patch_result = patch_duckdb_query_for_wasm(
+                """
+                CREATE OR REPLACE VIEW remote_cars AS
+                SELECT * FROM 'https://datasets.marimo.app/cars.csv'
+                """,
+            )
+
+        assert patch_result is None
+        fetch_url_bytes.assert_not_called()
+
+
+class TestDuckDBWasmMoSqlIntegration:
+    @staticmethod
+    async def test_mo_sql_rewrites_remote_literal_in_kernel(
+        executing_kernel: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                return_value=b"make,mpg\nford,25\ntoyota,18\n",
+            ) as fetch_url_bytes,
+            patch.object(
+                DuckDBEngine, "sql_output_format", return_value="native"
+            ),
+        ):
+            await executing_kernel.run(
+                [
+                    exec_req.get("import marimo as mo"),
+                    exec_req.get(
+                        """
+                        result = mo.sql(
+                            '''
+                            SELECT make
+                            FROM 'https://datasets.marimo.app/cars.csv'
+                            WHERE mpg > 20
+                            ''',
+                            output=False,
+                        )
+                        """
+                    ),
+                ]
+            )
+
+        result = executing_kernel.globals["result"]
+        assert result.fetchall() == [("ford",)]
+        fetch_url_bytes.assert_called_once_with(
+            "https://datasets.marimo.app/cars.csv"
+        )
+
+    @staticmethod
+    async def test_mo_sql_create_table_remote_literal_runs_once(
+        executing_kernel: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                return_value=b"make,mpg\nford,25\ntoyota,18\n",
+            ) as fetch_url_bytes,
+            patch.object(
+                DuckDBEngine, "sql_output_format", return_value="native"
+            ),
+        ):
+            await executing_kernel.run(
+                [
+                    exec_req.get("import duckdb"),
+                    exec_req.get("import marimo as mo"),
+                    exec_req.get(
+                        """
+                        result = mo.sql(
+                            '''
+                            CREATE OR REPLACE TABLE __marimo_wasm_create_once AS (
+                                SELECT * FROM 'https://datasets.marimo.app/cars.csv'
+                            )
+                            ''',
+                            output=False,
+                        )
+                        """
+                    ),
+                    exec_req.get(
+                        """
+                        rows = duckdb.sql(
+                            '''
+                            SELECT make, mpg
+                            FROM __marimo_wasm_create_once
+                            ORDER BY make
+                            '''
+                        ).fetchall()
+                        duckdb.sql("DROP TABLE IF EXISTS __marimo_wasm_create_once")
+                        """
+                    ),
+                ]
+            )
+
+        assert executing_kernel.globals["result"] is None
+        assert executing_kernel.globals["rows"] == [
+            ("ford", 25),
+            ("toyota", 18),
+        ]
+        fetch_url_bytes.assert_called_once_with(
+            "https://datasets.marimo.app/cars.csv"
+        )
+
+
+class TestDuckDBWasmSqlUtils:
+    @staticmethod
+    def test_wrapped_sql_rewrites_remote_literal_with_explicit_connection() -> (
+        None
+    ):
+        import duckdb
+
+        from marimo._sql.utils import wrapped_sql
+
+        connection = duckdb.connect(":memory:")
+        try:
+            with (
+                mock_pyodide(),
+                patch(
+                    "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                    return_value=b"make,mpg\nford,25\ntoyota,18\n",
+                ) as fetch_url_bytes,
+            ):
+                relation = wrapped_sql(
+                    """
+                    SELECT make
+                    FROM 'https://datasets.marimo.app/cars.csv'
+                    WHERE mpg > 20
+                    """,
+                    connection,
+                )
+                rows = relation.fetchall()
+        finally:
+            connection.close()
+
+        assert rows == [("ford",)]
+        fetch_url_bytes.assert_called_once_with(
+            "https://datasets.marimo.app/cars.csv"
+        )
+
+    @staticmethod
+    def test_execute_duckdb_sql_rewrites_remote_literal_with_explicit_connection() -> (
+        None
+    ):
+        import duckdb
+
+        from marimo._sql.utils import execute_duckdb_sql
+
+        table_name = "__marimo_duckdb_wasm_sql_utils_execute_test"
+        connection = duckdb.connect(":memory:")
+        try:
+            with (
+                mock_pyodide(),
+                patch(
+                    "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                    return_value=b"make,mpg\nford,25\ntoyota,18\n",
+                ) as fetch_url_bytes,
+            ):
+                execute_duckdb_sql(
+                    f"""
+                    CREATE OR REPLACE TABLE {table_name} AS
+                    SELECT make
+                    FROM 'https://datasets.marimo.app/cars.csv'
+                    WHERE mpg > ?
+ """, + [20], + connection, + ) + rows = connection.sql( + f"SELECT make FROM {table_name} ORDER BY make" + ).fetchall() + finally: + connection.close() + + assert rows == [("ford",)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + +class TestDuckDBWasmSqlApiPatch: + @staticmethod + def test_module_sql_rewrites_remote_literal_and_preserves_params() -> None: + import duckdb + + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + relation = duckdb.sql( + """ + SELECT make FROM 'https://datasets.marimo.app/cars.csv' + WHERE mpg > ? + """, + params=[20], + ) + finally: + unpatch() + + assert relation.fetchall() == [("ford",)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_module_query_rewrites_reader_call() -> None: + import duckdb + + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"1;ford\n2;toyota\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + relation = duckdb.query( + """ + SELECT column1 FROM read_csv( + 'https://datasets.marimo.app/cars.csv', + delim=';', header=false + ) + """ + ) + finally: + unpatch() + + assert relation.fetchall() == [("ford",), ("toyota",)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_module_query_df_rewrites_reader_call() -> None: + import duckdb + import pandas as pd + + local_df = pd.DataFrame({"make": ["ford"], "score": [7]}) + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ) as fetch_url_bytes, + ): + unpatch = patch_duckdb_for_wasm() + try: + relation = duckdb.query_df( + df=local_df, + virtual_table_name="query_df_local", + sql_query=""" + SELECT query_df_local.score, cars.mpg + FROM query_df_local + JOIN read_csv('https://datasets.marimo.app/cars.csv') AS cars + USING (make) + """, + ) + rows = relation.fetchall() + finally: + unpatch() + duckdb.sql("DROP VIEW IF EXISTS query_df_local") + + assert rows == [(7, 25)] + fetch_url_bytes.assert_called_once_with( + "https://datasets.marimo.app/cars.csv" + ) + + @staticmethod + def test_module_query_df_avoids_existing_catalog_table_names() -> None: + import duckdb + import pandas as pd + + local_df = pd.DataFrame({"make": ["ford"], "score": [7]}) + with ( + mock_pyodide(), + patch( + "marimo._runtime._wasm._fetch.fetch_url_bytes", + return_value=b"make,mpg\nford,25\ntoyota,18\n", + ), + ): + unpatch = patch_duckdb_for_wasm() + try: + duckdb.sql( + """ + CREATE OR REPLACE TABLE "__MARIMO_WASM_DUCKDB_REMOTE_0" + AS SELECT 'ford' AS make, 99 AS mpg + """ + ) + relation = duckdb.query_df( + df=local_df, + virtual_table_name="query_df_local_collision", + sql_query=""" + SELECT query_df_local_collision.score, cars.mpg + FROM query_df_local_collision + JOIN read_csv('https://datasets.marimo.app/cars.csv') AS cars + USING (make) + """, + ) + rows = relation.fetchall() + finally: + unpatch() + duckdb.sql("DROP VIEW IF EXISTS query_df_local_collision") + duckdb.sql( + 'DROP TABLE IF EXISTS "__MARIMO_WASM_DUCKDB_REMOTE_0"' + ) + + assert rows == [(7, 25)] + + @staticmethod + def test_module_sql_preserves_caller_replacement_scan() -> None: + import duckdb + import pandas as pd + + local_df = pd.DataFrame({"x": [1, 2]}) + 
+        with mock_pyodide():
+            unpatch = patch_duckdb_for_wasm()
+            try:
+                rows = duckdb.sql("SELECT sum(x) FROM local_df").fetchall()
+            finally:
+                unpatch()
+
+        assert rows == [(3,)]
+
+    @staticmethod
+    def test_module_sql_skips_catalog_lookup_without_remote_source() -> None:
+        import duckdb
+
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._duckdb._duckdb_catalog_names"
+            ) as catalog_names,
+        ):
+            unpatch = patch_duckdb_for_wasm()
+            try:
+                rows = duckdb.sql("SELECT 1").fetchall()
+            finally:
+                unpatch()
+
+        assert rows == [(1,)]
+        catalog_names.assert_not_called()
+
+    @staticmethod
+    def test_module_sql_without_remote_source_does_not_require_sqlglot() -> (
+        None
+    ):
+        import duckdb
+
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._duckdb._require_sqlglot",
+                side_effect=AssertionError("sqlglot should be lazy"),
+            ),
+        ):
+            unpatch = patch_duckdb_for_wasm()
+            try:
+                rows = duckdb.sql("SELECT 1").fetchall()
+            finally:
+                unpatch()
+
+        assert rows == [(1,)]
+
+    @staticmethod
+    def test_module_execute_rewrites_before_side_effects() -> None:
+        import duckdb
+
+        table_name = "__marimo_duckdb_wasm_execute_test"
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                return_value=b"make,mpg\nford,25\ntoyota,18\n",
+            ),
+        ):
+            unpatch = patch_duckdb_for_wasm()
+            try:
+                duckdb.execute(
+                    f"""
+                    CREATE OR REPLACE TABLE {table_name} AS
+                    SELECT make FROM 'https://datasets.marimo.app/cars.csv'
+                    WHERE mpg > ?
+                    """,
+                    [20],
+                )
+                rows = duckdb.sql(
+                    f"SELECT make FROM {table_name} ORDER BY make"
+                ).fetchall()
+            finally:
+                unpatch()
+                duckdb.sql(f"DROP TABLE IF EXISTS {table_name}")
+
+        assert rows == [("ford",)]
+
+    @staticmethod
+    def test_module_execute_creates_table_from_remote_literal() -> None:
+        import duckdb
+
+        table_name = "__marimo_duckdb_wasm_execute_create_test"
+        with (
+            mock_pyodide(),
+            patch(
+                "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                return_value=b"make,mpg\nford,25\ntoyota,18\n",
+            ) as fetch_url_bytes,
+        ):
+            unpatch = patch_duckdb_for_wasm()
+            try:
+                duckdb.execute(
+                    f"""
+                    CREATE OR REPLACE TABLE {table_name} AS (
+                        SELECT * FROM 'https://datasets.marimo.app/cars.csv'
+                    )
+                    """
+                )
+                rows = duckdb.sql(
+                    f"SELECT make, mpg FROM {table_name} ORDER BY make"
+                ).fetchall()
+            finally:
+                unpatch()
+                duckdb.sql(f"DROP TABLE IF EXISTS {table_name}")
+
+        assert rows == [("ford", 25), ("toyota", 18)]
+        fetch_url_bytes.assert_called_once_with(
+            "https://datasets.marimo.app/cars.csv"
+        )
+
+    @staticmethod
+    def test_connection_methods_preserve_caller_replacement_scans() -> None:
+        import duckdb
+        import pandas as pd
+
+        local_df = pd.DataFrame({"x": [1, 2]})
+        connection = duckdb.connect(":memory:")
+        try:
+            with mock_pyodide():
+                unpatch = patch_duckdb_for_wasm()
+                try:
+                    sql_rows = connection.sql(
+                        "SELECT sum(x) FROM local_df"
+                    ).fetchall()
+                    query_rows = connection.query(
+                        "SELECT count(*) FROM local_df"
+                    ).fetchall()
+                    execute_rows = connection.execute(
+                        "SELECT max(x) FROM local_df"
+                    ).fetchall()
+                finally:
+                    unpatch()
+        finally:
+            connection.close()
+
+        assert sql_rows == [(3,)]
+        assert query_rows == [(2,)]
+        assert execute_rows == [(2,)]
+
+    @staticmethod
+    def test_connection_execute_creates_table_from_remote_literal() -> None:
+        import duckdb
+
+        table_name = "__marimo_duckdb_wasm_conn_execute_create_test"
+        connection = duckdb.connect(":memory:")
+        try:
+            with (
+                mock_pyodide(),
+                patch(
+                    "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                    return_value=b"make,mpg\nford,25\ntoyota,18\n",
+                ) as fetch_url_bytes,
+            ):
+                unpatch = patch_duckdb_for_wasm()
+                try:
+                    connection.execute(
+                        f"""
+                        CREATE OR REPLACE TABLE {table_name} AS (
+                            SELECT * FROM 'https://datasets.marimo.app/cars.csv'
+                        )
+                        """
+                    )
+                    rows = connection.execute(
+                        f"SELECT make, mpg FROM {table_name} ORDER BY make"
+                    ).fetchall()
+                finally:
+                    unpatch()
+        finally:
+            connection.close()
+
+        assert rows == [("ford", 25), ("toyota", 18)]
+        fetch_url_bytes.assert_called_once_with(
+            "https://datasets.marimo.app/cars.csv"
+        )
+
+    @staticmethod
+    def test_connection_sql_joins_caller_local_and_remote_tables() -> None:
+        import duckdb
+        import pandas as pd
+
+        local_df = pd.DataFrame({"make": ["ford"], "score": [7]})
+        connection = duckdb.connect(":memory:")
+        try:
+            with (
+                mock_pyodide(),
+                patch(
+                    "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                    return_value=b"make,mpg\nford,25\ntoyota,18\n",
+                ) as fetch_url_bytes,
+            ):
+                unpatch = patch_duckdb_for_wasm()
+                try:
+                    relation = connection.sql(
+                        """
+                        SELECT local_df.score, cars.mpg
+                        FROM local_df
+                        JOIN read_csv('https://datasets.marimo.app/cars.csv') AS cars
+                        USING (make)
+                        """
+                    )
+                    rows = relation.fetchall()
+                finally:
+                    unpatch()
+        finally:
+            connection.close()
+
+        assert rows == [(7, 25)]
+        fetch_url_bytes.assert_called_once_with(
+            "https://datasets.marimo.app/cars.csv"
+        )
+
+    @staticmethod
+    def test_connection_sql_avoids_existing_catalog_table_names() -> None:
+        import duckdb
+
+        connection = duckdb.connect(":memory:")
+        try:
+            connection.sql(
+                """
+                CREATE OR REPLACE TABLE "__MARIMO_WASM_DUCKDB_REMOTE_0"
+                AS SELECT 99 AS mpg
+                """
+            )
+            with (
+                mock_pyodide(),
+                patch(
+                    "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                    return_value=b"mpg\n25\n",
+                ) as fetch_url_bytes,
+            ):
+                unpatch = patch_duckdb_for_wasm()
+                try:
+                    rows = connection.sql(
+                        """
+                        SELECT mpg
+                        FROM 'https://datasets.marimo.app/cars.csv'
+                        """
+                    ).fetchall()
+                finally:
+                    unpatch()
+        finally:
+            connection.close()
+
+        assert rows == [(25,)]
+        fetch_url_bytes.assert_called_once_with(
+            "https://datasets.marimo.app/cars.csv"
+        )
+
+    @staticmethod
+    @pytest.mark.parametrize("api_kind", ["sql", "query", "execute"])
+    def test_module_methods_with_connection_avoid_existing_catalog_table_names(
+        api_kind: str,
+    ) -> None:
+        import duckdb
+
+        connection = duckdb.connect(":memory:")
+        result_table = "__marimo_duckdb_wasm_module_conn_result"
+        try:
+            connection.sql(
+                """
+                CREATE OR REPLACE TABLE "__MARIMO_WASM_DUCKDB_REMOTE_0"
+                AS SELECT 99 AS mpg
+                """
+            )
+            with (
+                mock_pyodide(),
+                patch(
+                    "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                    return_value=b"mpg\n25\n",
+                ) as fetch_url_bytes,
+            ):
+                unpatch = patch_duckdb_for_wasm()
+                try:
+                    if api_kind == "execute":
+                        duckdb.execute(
+                            f"""
+                            CREATE OR REPLACE TABLE {result_table} AS
+                            SELECT mpg
+                            FROM 'https://datasets.marimo.app/cars.csv'
+                            """,
+                            connection=connection,
+                        )
+                        rows = connection.sql(
+                            f"SELECT mpg FROM {result_table}"
+                        ).fetchall()
+                    else:
+                        rows = getattr(duckdb, api_kind)(
+                            """
+                            SELECT mpg
+                            FROM 'https://datasets.marimo.app/cars.csv'
+                            """,
+                            connection=connection,
+                        ).fetchall()
+                finally:
+                    unpatch()
+        finally:
+            connection.close()
+
+        assert rows == [(25,)]
+        fetch_url_bytes.assert_called_once_with(
+            "https://datasets.marimo.app/cars.csv"
+        )
+
+    @staticmethod
+    def test_double_quoted_url_table_identifier_is_not_remote_source() -> None:
+        import duckdb
+
+        table_name = '"https://datasets.marimo.app/cars.csv"'
+        try:
+            duckdb.sql(
+                f"CREATE OR REPLACE TABLE {table_name} AS SELECT 42 AS x"
+            )
+            with (
+                mock_pyodide(),
+                patch(
+                    "marimo._runtime._wasm._fetch.fetch_url_bytes",
+                    return_value=b"x\n7\n",
+                ) as fetch_url_bytes,
+            ):
+                unpatch = patch_duckdb_for_wasm()
+                try:
+                    rows = duckdb.sql(f"SELECT x FROM {table_name}").fetchall()
+                finally:
+                    unpatch()
+        finally:
+            duckdb.sql(f"DROP TABLE IF EXISTS {table_name}")
+
+        assert rows == [(42,)]
+        fetch_url_bytes.assert_not_called()
+
+    @staticmethod
+    @mock_pyodide()
+    def test_unpatch_restores_module_functions_and_connection_methods() -> (
+        None
+    ):
+        import duckdb
+
+        original_sql = duckdb.sql
+        original_query_df = duckdb.query_df
+        original_execute = duckdb.execute
+        original_connection_sql = duckdb.DuckDBPyConnection.sql
+        original_connection_execute = duckdb.DuckDBPyConnection.execute
+        original_connection_read_csv = duckdb.DuckDBPyConnection.read_csv
+        original_connection_read_parquet = (
+            duckdb.DuckDBPyConnection.read_parquet
+        )
+        original_connection_read_json = duckdb.DuckDBPyConnection.read_json
+
+        unpatch = patch_duckdb_for_wasm()
+        assert duckdb.sql is not original_sql
+        assert duckdb.query_df is not original_query_df
+        assert duckdb.execute is not original_execute
+        assert duckdb.DuckDBPyConnection.sql is not original_connection_sql
+        assert (
+            duckdb.DuckDBPyConnection.execute
+            is not original_connection_execute
+        )
+        assert (
+            duckdb.DuckDBPyConnection.read_csv
+            is not original_connection_read_csv
+        )
+        assert (
+            duckdb.DuckDBPyConnection.read_parquet
+            is not original_connection_read_parquet
+        )
+        assert (
+            duckdb.DuckDBPyConnection.read_json
+            is not original_connection_read_json
+        )
+
+        unpatch()
+        assert duckdb.sql is original_sql
+        assert duckdb.query_df is original_query_df
+        assert duckdb.execute is original_execute
+        assert duckdb.DuckDBPyConnection.sql is original_connection_sql
+        assert duckdb.DuckDBPyConnection.execute is original_connection_execute
+        assert (
+            duckdb.DuckDBPyConnection.read_csv is original_connection_read_csv
+        )
+        assert (
+            duckdb.DuckDBPyConnection.read_parquet
+            is original_connection_read_parquet
+        )
+        assert (
+            duckdb.DuckDBPyConnection.read_json
+            is original_connection_read_json
+        )
+
+        unpatch()
diff --git a/tests/_runtime/test_patches.py b/tests/_runtime/test_patches.py
index abd42633c96..9b820f8f315 100644
--- a/tests/_runtime/test_patches.py
+++ b/tests/_runtime/test_patches.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import io
+from contextlib import nullcontext
 from typing import TYPE_CHECKING
 from unittest.mock import patch
 
@@ -332,6 +333,34 @@ def test_skips_missing_attr() -> None:
    patches.unpatch_all()()
 
 
+def test_fetch_url_bytes_forwards_request_and_urlopen_kwargs() -> None:
+    import urllib.request
+
+    from marimo._runtime._wasm._fetch import fetch_url_bytes
+
+    with patch(
+        "urllib.request.urlopen",
+        return_value=nullcontext(io.BytesIO(b"ok")),
+    ) as urlopen:
+        assert (
+            fetch_url_bytes(
+                "https://example.com/cars.csv",
+                request_kwargs={
+                    "headers": {"User-Agent": "marimo"},
+                    "method": "GET",
+                },
+                urlopen_kwargs={"timeout": 5},
+            )
+            == b"ok"
+        )
+
+    request = urlopen.call_args.args[0]
+    assert isinstance(request, urllib.request.Request)
+    assert request.get_header("User-agent") == "marimo"
+    assert request.get_method() == "GET"
+    assert urlopen.call_args.kwargs == {"timeout": 5}
+
+
 @pytest.mark.requires("polars", "pyarrow")
 class TestPolarsIoWasmPatch:
    @staticmethod