diff --git a/bindings/python/.gitignore b/bindings/python/.gitignore index d2628ced..32ac0fc7 100644 --- a/bindings/python/.gitignore +++ b/bindings/python/.gitignore @@ -1,4 +1,5 @@ /target +.databend # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/bindings/python/README.md b/bindings/python/README.md index 95964e15..1619ded8 100644 --- a/bindings/python/README.md +++ b/bindings/python/README.md @@ -8,6 +8,106 @@ Databend Python Client ## Usage +### Local Embedded Connection + +The local embedded mode runs a full Databend engine in-process without any +server. It is useful for local analytics, testing, and offline workflows. + +Install the `local` extra to pull in the embedded engine: + +```bash +pip install "databend-driver[local]" +``` + +The embedded dependency currently requires Python 3.12 or later. + +```python +from databend_driver import connect + +# Persistent state stored under ./local-state +conn = connect("databend+local:///./local-state") +conn.exec("CREATE TABLE books(id INT, title STRING)") +conn.exec("INSERT INTO books VALUES (1, 'Databend')") + +row = conn.query_row("SELECT title FROM books ORDER BY id LIMIT 1") +print(row.values()) # ('Databend',) + +rows = [row.values() for row in conn.query_iter("SELECT * FROM books ORDER BY id")] +``` + +Supported local targets: + +- `connect(":memory:")` — temporary in-memory instance (discarded on close) +- `connect("databend+local:///:memory:")` — explicit in-memory instance +- `connect("databend+local:///./local-state")` — persistent state under `./local-state` +- `connect("databend+local:///./local-state?tenant=default")` — persistent state with an explicit tenant +- `connect("databend+local:///./local-state?database=mydb")` — open a specific database + +You can also use `connect_local()` directly for more control: + +```python +from databend_driver import connect_local + +conn = connect_local(database=":memory:") +conn = connect_local(data_path="./local-state", tenant="default") +``` + 
+If the optional `databend` package is not installed, `connect()` raises an +`ImportError` with guidance about enabling the `local` extra and the Python +version requirement. + +For remote Databend, the same `connect()` entrypoint accepts standard DSNs: + +```python +from databend_driver import connect + +conn = connect("databend://root:@localhost:8000/?sslmode=disable") +row = conn.query_row("SELECT 1") +``` + +#### Relation API + +The local connection exposes an embedded-specific relation API for working +with query results as DataFrames or Arrow tables: + +```python +relation = conn.sql("SELECT * FROM books") + +df = relation.df() # pandas DataFrame +pl = relation.pl() # polars DataFrame +tbl = relation.arrow() # pyarrow Table + +rows = relation.fetchall() # list[tuple] +row = relation.fetchone() # tuple | None +``` + +#### Registering External Data + +You can register files or in-memory data as virtual tables: + +```python +# Register a Parquet file +conn.register("sales", "./data/sales.parquet") +conn.sql("SELECT * FROM sales LIMIT 10").df() + +# Register a CSV file +conn.register("events", "./data/events.csv") + +# Register a pandas or polars DataFrame +import pandas as pd +df = pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]}) +conn.register("users", df) + +# Shorthand: register a DataFrame and return a relation immediately +relation = conn.from_df(df) + +# Read helpers (register and return relation in one call) +relation = conn.read_parquet("./data/sales.parquet") +relation = conn.read_csv("./data/events.csv") +relation = conn.read_json("./data/logs.ndjson") +relation = conn.read_text("./data/raw.txt") +``` + ### PEP 249 Cursor Object ```python @@ -96,17 +196,25 @@ asyncio.run(main()) ### Parameter bindings ```python -# Test with positional parameters +# Positional parameters using ? 
row = await context.conn.query_row("SELECT ?, ?, ?, ?", (3, False, 4, "55")) + +# Named parameters using :name row = await context.conn.query_row( "SELECT :a, :b, :c, :d", {"a": 3, "b": False, "c": 4, "d": "55"} ) -row = await context.conn.query_row( - "SELECT ?", 3 -) -row = await context.conn.query_row("SELECT ?, ?, ?, ?", params = (3, False, 4, "55")) + +# Single value (no tuple needed) +row = await context.conn.query_row("SELECT ?", 3) + +# Keyword argument form +row = await context.conn.query_row("SELECT ?, ?, ?, ?", params=(3, False, 4, "55")) ``` +Named parameters use token-aware matching, so `:a` will not corrupt `:ab`. +For local embedded connections, passing a mismatched number of `?` placeholders +and values raises a `ValueError` immediately. + ### Query ID tracking and query management ```python @@ -368,6 +476,92 @@ class ConnectionInfo: def warehouse(self) -> str | None: ... ``` +### connect_local + +```python +def connect_local( + database: str = ":memory:", + *, + data_path: str | None = None, + tenant: str | None = None, +) -> LocalConnection: ... +``` + +### LocalConnection + +```python +class LocalConnection: + def sql(self, query: str) -> LocalRelation: ... + def table(self, name: str) -> LocalRelation: ... + def format_sql(self, sql: str, params: Any = None) -> str: ... + def execute(self, query: str, params: Any = None) -> None: ... + def exec(self, sql: str, params: Any = None) -> None: ... + def query_row(self, sql: str, params: Any = None) -> LocalRow | None: ... + def query_all(self, sql: str, params: Any = None) -> list[LocalRow]: ... + def query_iter(self, sql: str, params: Any = None) -> LocalRowIterator: ... + def close(self) -> None: ... + def last_query_id(self) -> None: ... # always None for local mode + def kill_query(self, query_id: str) -> None: ... 
# raises NotImplementedError + def register( + self, + name: str, + source: Any, # path str/Path, pandas/polars DataFrame, or pyarrow Table + *, + format: str | None = None, + pattern: str | None = None, + connection: str | None = None, + ) -> LocalConnection: ... + def from_df(self, source: Any, *, name: str | None = None) -> LocalRelation: ... + def read_parquet( + self, path: str | Path, *, pattern: str | None = None, + connection: str | None = None, name: str | None = None, + ) -> LocalRelation: ... + def read_csv( + self, path: str | Path, *, pattern: str | None = None, + connection: str | None = None, name: str | None = None, + ) -> LocalRelation: ... + def read_json( + self, path: str | Path, *, pattern: str | None = None, + connection: str | None = None, name: str | None = None, + ) -> LocalRelation: ... + def read_text( + self, path: str | Path, *, pattern: str | None = None, + connection: str | None = None, name: str | None = None, + ) -> LocalRelation: ... +``` + +### LocalRelation + +```python +class LocalRelation: + def df(self) -> Any: ... # pandas DataFrame + def pl(self) -> Any: ... # polars DataFrame + def arrow(self) -> Any: ... # pyarrow Table + def fetchall(self) -> list[tuple]: ... + def fetchone(self) -> tuple | None: ... +``` + +### LocalRow + +```python +class LocalRow: + def values(self) -> tuple[Any, ...]: ... + def __len__(self) -> int: ... + def __iter__(self) -> LocalRow: ... + def __next__(self) -> Any: ... + def __getitem__(self, key: int) -> Any: ... +``` + +### LocalRowIterator + +```python +class LocalRowIterator: + def schema(self) -> Any: ... # not yet implemented for local mode + def close(self) -> None: ... + def __iter__(self) -> LocalRowIterator: ... + def __next__(self) -> LocalRow: ... 
+``` + ## Development ``` @@ -377,11 +571,14 @@ make up ```shell cd bindings/python -uv sync +uv python install 3.12 +uv venv --python 3.12 +uv sync --extra local source .venv/bin/activate -maturin develop --uv +maturin develop behave tests/asyncio behave tests/blocking behave tests/cursor +behave tests/local ``` diff --git a/bindings/python/package/databend_driver/__init__.py b/bindings/python/package/databend_driver/__init__.py index 589abc26..7f0a39a8 100644 --- a/bindings/python/package/databend_driver/__init__.py +++ b/bindings/python/package/databend_driver/__init__.py @@ -15,6 +15,14 @@ # flake8: noqa from ._databend_driver import * +from .local import ( + LocalConnection, + LocalRelation, + LocalRow, + LocalRowIterator, + connect, + connect_local, +) # Export for convenience at module level __all__ = [ @@ -42,4 +50,11 @@ "Row", "RowIterator", "ServerStats", + # Local embedded mode + "LocalConnection", + "LocalRelation", + "LocalRow", + "LocalRowIterator", + "connect", + "connect_local", ] diff --git a/bindings/python/package/databend_driver/__init__.pyi b/bindings/python/package/databend_driver/__init__.pyi index 4e4d8a6d..ccfc7df6 100644 --- a/bindings/python/package/databend_driver/__init__.pyi +++ b/bindings/python/package/databend_driver/__init__.pyi @@ -14,6 +14,9 @@ # flake8: noqa +from pathlib import Path +from typing import Any + # Exception classes - PEP 249 compliant class Warning(Exception): ... class Error(Exception): ... @@ -69,9 +72,9 @@ class Row: def values(self) -> tuple: ... def __len__(self) -> int: ... def __iter__(self) -> Row: ... - def __next__(self) -> any: ... + def __next__(self) -> Any: ... def __dict__(self) -> dict: ... - def __getitem__(self, key: int | str) -> any: ... + def __getitem__(self, key: int | str) -> Any: ... class RowIterator: def schema(self) -> Schema: ... @@ -87,9 +90,11 @@ class AsyncDatabendConnection: async def close(self) -> None: ... def last_query_id(self) -> str | None: ... 
async def kill_query(self, query_id: str) -> None: ... - async def exec(self, sql: str) -> int: ... - async def query_row(self, sql: str) -> Row: ... - async def query_iter(self, sql: str) -> RowIterator: ... + def format_sql(self, sql: str, params: Any = None) -> str: ... + async def exec(self, sql: str, params: Any = None) -> int: ... + async def query_row(self, sql: str, params: Any = None) -> Row: ... + async def query_all(self, sql: str, params: Any = None) -> list[Row]: ... + async def query_iter(self, sql: str, params: Any = None) -> RowIterator: ... async def stream_load(self, sql: str, data: list[list[str]]) -> ServerStats: ... async def load_file( self, sql: str, file: str, format_option: dict, copy_options: dict = None @@ -105,9 +110,11 @@ class BlockingDatabendConnection: def close(self) -> None: ... def last_query_id(self) -> str | None: ... def kill_query(self, query_id: str) -> None: ... - def exec(self, sql: str) -> int: ... - def query_row(self, sql: str) -> Row: ... - def query_iter(self, sql: str) -> RowIterator: ... + def format_sql(self, sql: str, params: Any = None) -> str: ... + def exec(self, sql: str, params: Any = None) -> int: ... + def query_row(self, sql: str, params: Any = None) -> Row: ... + def query_all(self, sql: str, params: Any = None) -> list[Row]: ... + def query_iter(self, sql: str, params: Any = None) -> RowIterator: ... def stream_load(self, sql: str, data: list[list[str]]) -> ServerStats: ... def load_file( self, sql: str, file: str, format_option: dict, copy_options: dict = None @@ -124,10 +131,10 @@ class BlockingDatabendCursor: def rowcount(self) -> int: ... def close(self) -> None: ... def execute( - self, operation: str, params: list[any] | tuple[any] = None + self, operation: str, params: list[Any] | tuple[Any] = None ) -> None | int: ... def executemany( - self, operation: str, params: list[list[any] | tuple[any]] + self, operation: str, params: list[list[Any] | tuple[Any]] ) -> None | int: ... 
def fetchone(self) -> Row | None: ... def fetchmany(self, size: int = 1) -> list[Row]: ... @@ -142,3 +149,88 @@ class BlockingDatabendClient: def __init__(self, dsn: str): ... def get_conn(self) -> BlockingDatabendConnection: ... def cursor(self) -> BlockingDatabendCursor: ... + +class LocalRelation: + def df(self) -> Any: ... + def pl(self) -> Any: ... + def arrow(self) -> Any: ... + def fetchall(self) -> list[tuple]: ... + def fetchone(self) -> tuple | None: ... + +class LocalRow: + def values(self) -> tuple[Any, ...]: ... + def __len__(self) -> int: ... + def __iter__(self) -> LocalRow: ... + def __next__(self) -> Any: ... + def __getitem__(self, key: int) -> Any: ... + +class LocalRowIterator: + def schema(self) -> Any: ... + def close(self) -> None: ... + def __iter__(self) -> LocalRowIterator: ... + def __next__(self) -> LocalRow: ... + +class LocalConnection: + def format_sql(self, sql: str, params: Any = None) -> str: ... + def execute(self, query: str, params: Any = None) -> None: ... + def exec(self, sql: str, params: Any = None) -> None: ... + def query_row(self, sql: str, params: Any = None) -> LocalRow | None: ... + def query_all(self, sql: str, params: Any = None) -> list[LocalRow]: ... + def query_iter(self, sql: str, params: Any = None) -> LocalRowIterator: ... + def close(self) -> None: ... + def last_query_id(self) -> None: ... + def kill_query(self, query_id: str) -> None: ... + def sql(self, query: str) -> LocalRelation: ... + def table(self, name: str) -> LocalRelation: ... + def register( + self, + name: str, + source: Any, + *, + format: str | None = None, + pattern: str | None = None, + connection: str | None = None, + ) -> LocalConnection: ... + def from_df(self, source: Any, *, name: str | None = None) -> LocalRelation: ... + def read_parquet( + self, + path: str | Path, + *, + pattern: str | None = None, + connection: str | None = None, + name: str | None = None, + ) -> LocalRelation: ... 
+ def read_csv( + self, + path: str | Path, + *, + pattern: str | None = None, + connection: str | None = None, + name: str | None = None, + ) -> LocalRelation: ... + def read_json( + self, + path: str | Path, + *, + pattern: str | None = None, + connection: str | None = None, + name: str | None = None, + ) -> LocalRelation: ... + def read_text( + self, + path: str | Path, + *, + pattern: str | None = None, + connection: str | None = None, + name: str | None = None, + ) -> LocalRelation: ... + +def connect_local( + database: str = ":memory:", + *, + data_path: str | None = None, + tenant: str | None = None, +) -> LocalConnection: ... +def connect( + target: str = ":memory:", **kwargs: Any +) -> BlockingDatabendConnection | LocalConnection: ... diff --git a/bindings/python/package/databend_driver/local.py b/bindings/python/package/databend_driver/local.py new file mode 100644 index 00000000..5556ebc3 --- /dev/null +++ b/bindings/python/package/databend_driver/local.py @@ -0,0 +1,474 @@ +# Copyright 2021 Datafuse Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import uuid +from importlib import import_module +from pathlib import Path +from tempfile import mkdtemp +from typing import Any +from urllib.parse import parse_qs, urlparse + + +def _load_embedded_module(): + try: + import databend as embedded + except ImportError as exc: + version_hint = "" + if _python_version_tuple() < (3, 12): + version_hint = ( + f" Current interpreter is Python {_python_version_str()}, but the " + "embedded dependency currently requires Python 3.12+." + ) + raise ImportError( + "Local embedded mode requires the optional `databend` package. " + "Install databend-driver with the `local` extra or provide the " + "internal databend binding in the environment." + version_hint + ) from exc + return embedded + + +def _normalize_path(path: str | Path) -> str: + return str(Path(path).expanduser().resolve()) + + +def _random_name(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex}" + + +class LocalRelation: + def __init__(self, relation: Any): + self._relation = relation + + def __repr__(self) -> str: + return repr(self._relation) + + def __getattr__(self, name: str) -> Any: + return getattr(self._relation, name) + + def df(self): + return self._relation.to_pandas() + + def pl(self): + return self._relation.to_polars() + + def arrow(self): + return self._relation.to_arrow_table() + + def fetchall(self) -> list[tuple[Any, ...]]: + table = self.arrow() + columns = [ + table.column(index).to_pylist() for index in range(table.num_columns) + ] + return [ + tuple(column[row_index] for column in columns) + for row_index in range(table.num_rows) + ] + + def fetchone(self) -> tuple[Any, ...] 
| None: + rows = self.fetchall() + return rows[0] if rows else None + + +class LocalRow: + def __init__(self, values: tuple[Any, ...]): + self._values = values + self._idx = 0 + + def values(self) -> tuple[Any, ...]: + return self._values + + def __len__(self) -> int: + return len(self._values) + + def __iter__(self) -> LocalRow: + return self + + def __next__(self) -> Any: + if self._idx >= len(self._values): + raise StopIteration("Columns exhausted") + value = self._values[self._idx] + self._idx += 1 + return value + + def __getitem__(self, key: int) -> Any: + if not isinstance(key, int): + raise TypeError("key must be an integer") + return self._values[key] + + def __repr__(self) -> str: + return repr(self._values) + + +class LocalRowIterator: + def __init__(self, rows: list[LocalRow]): + self._rows = rows + self._idx = 0 + + def schema(self): + raise NotImplementedError( + "schema() is not available for local embedded queries yet." + ) + + def close(self) -> None: + self._idx = len(self._rows) + + def __iter__(self) -> LocalRowIterator: + return self + + def __next__(self) -> LocalRow: + if self._idx >= len(self._rows): + raise StopIteration("Rows exhausted") + row = self._rows[self._idx] + self._idx += 1 + return row + + +class LocalConnection: + def __init__(self, impl: Any): + self._impl = impl + + def __repr__(self) -> str: + return repr(self._impl) + + def __getattr__(self, name: str) -> Any: + return getattr(self._impl, name) + + def sql(self, query: str) -> LocalRelation: + return LocalRelation(self._impl.sql(query)) + + def format_sql(self, sql: str, params: Any = None) -> str: + if params is None: + return sql + + if isinstance(params, dict): + import re + + def _replace_named(m: re.Match) -> str: + key = m.group(1) + if key not in params: + return m.group(0) + return _sql_literal(params[key]) + + return re.sub(r":([A-Za-z_][A-Za-z0-9_]*)", _replace_named, sql) + + if not isinstance(params, (list, tuple)): + params = [params] + + placeholder_count = 
sql.count("?") + if placeholder_count != len(params): + raise ValueError( + f"Parameter count mismatch: SQL has {placeholder_count} placeholder(s) " + f"but {len(params)} value(s) were provided." + ) + + rendered = sql + for value in params: + rendered = rendered.replace("?", _sql_literal(value), 1) + return rendered + + def execute(self, query: str, params: Any = None) -> None: + statement = self.format_sql(query, params) + self._impl.sql(statement).collect() + + def exec(self, sql: str, params: Any = None) -> None: + self.execute(sql, params) + + def query_row(self, sql: str, params: Any = None) -> LocalRow | None: + statement = self.format_sql(sql, params) + row = self.sql(statement).fetchone() + if row is None: + return None + return LocalRow(tuple(row)) + + def query_all(self, sql: str, params: Any = None) -> list[LocalRow]: + statement = self.format_sql(sql, params) + return [LocalRow(tuple(row)) for row in self.sql(statement).fetchall()] + + def query_iter(self, sql: str, params: Any = None) -> LocalRowIterator: + return LocalRowIterator(self.query_all(sql, params)) + + def close(self) -> None: + if hasattr(self._impl, "close"): + self._impl.close() + + def last_query_id(self) -> None: + return None + + def kill_query(self, query_id: str) -> None: + raise NotImplementedError( + "kill_query() is not supported for local embedded mode." 
+ ) + + def table(self, name: str) -> LocalRelation: + return self.sql(f"SELECT * FROM {name}") + + def register( + self, + name: str, + source: Any, + *, + format: str | None = None, + pattern: str | None = None, + connection: str | None = None, + ) -> LocalConnection: + if isinstance(source, (str, Path)): + source_path = str(source) + source_format = (format or Path(source_path).suffix.lstrip(".")).lower() + if source_format in {"parquet", "pq"}: + self._impl.register_parquet( + name, source_path, pattern=pattern, connection=connection + ) + elif source_format in {"csv"}: + self._impl.register_csv( + name, source_path, pattern=pattern, connection=connection + ) + elif source_format in {"json", "ndjson"}: + self._impl.register_ndjson( + name, source_path, pattern=pattern, connection=connection + ) + elif source_format in {"txt", "text", "tsv"}: + self._impl.register_text( + name, source_path, pattern=pattern, connection=connection + ) + else: + raise ValueError( + f"Unsupported format for {source_path!r}. " + "Use format= explicitly or pass pandas/polars/pyarrow data." 
+ ) + return self + + parquet_path = self._materialize_relation_source(name, source) + self._impl.register_parquet( + name, parquet_path, pattern=pattern, connection=connection + ) + return self + + def from_df(self, source: Any, *, name: str | None = None) -> LocalRelation: + target = name or _random_name("df") + self.register(target, source) + return self.table(target) + + def read_parquet( + self, + path: str | Path, + *, + pattern: str | None = None, + connection: str | None = None, + name: str | None = None, + ) -> LocalRelation: + target = name or _random_name("parquet") + self._impl.register_parquet( + target, str(path), pattern=pattern, connection=connection + ) + return self.table(target) + + def read_csv( + self, + path: str | Path, + *, + pattern: str | None = None, + connection: str | None = None, + name: str | None = None, + ) -> LocalRelation: + target = name or _random_name("csv") + self._impl.register_csv( + target, str(path), pattern=pattern, connection=connection + ) + return self.table(target) + + def read_json( + self, + path: str | Path, + *, + pattern: str | None = None, + connection: str | None = None, + name: str | None = None, + ) -> LocalRelation: + target = name or _random_name("json") + self._impl.register_ndjson( + target, str(path), pattern=pattern, connection=connection + ) + return self.table(target) + + def read_text( + self, + path: str | Path, + *, + pattern: str | None = None, + connection: str | None = None, + name: str | None = None, + ) -> LocalRelation: + target = name or _random_name("text") + self._impl.register_text( + target, str(path), pattern=pattern, connection=connection + ) + return self.table(target) + + def _materialize_relation_source(self, name: str, source: Any) -> str: + table = self._to_arrow_table(source) + temp_dir = self._data_path() / "python" / "registered" + temp_dir.mkdir(parents=True, exist_ok=True) + parquet_path = temp_dir / f"{name}_{uuid.uuid4().hex}.parquet" + + import pyarrow.parquet as pq + + 
pq.write_table(table, parquet_path) + return _normalize_path(parquet_path) + + @staticmethod + def _to_arrow_table(source: Any): + if hasattr(source, "schema") and hasattr(source, "to_pydict"): + return source + + if hasattr(source, "to_arrow"): + return source.to_arrow() + + if hasattr(source, "to_pandas"): + source = source.to_pandas() + + try: + import pyarrow as pa + + return pa.Table.from_pandas(source, preserve_index=False) + except Exception as exc: + raise TypeError( + "Unsupported source type. Expected path, pandas.DataFrame, " + "polars.DataFrame, or pyarrow.Table." + ) from exc + + def _data_path(self) -> Path: + value = getattr(self._impl, "_data_path", None) + if value is None: + return Path(".databend").resolve() + return Path(value).expanduser().resolve() + + +def connect_local( + database: str = ":memory:", + *, + data_path: str | None = None, + tenant: str | None = None, +) -> LocalConnection: + embedded = _load_embedded_module() + memory_target = database == ":memory:" + explicit_data_path = ( + None if memory_target and data_path == ":memory:" else data_path + ) + + if tenant is None and hasattr(embedded, "connect"): + if explicit_data_path is not None: + return LocalConnection( + embedded.connect(database=database, data_path=explicit_data_path) + ) + if memory_target: + conn = LocalConnection( + embedded.connect(data_path=mkdtemp(prefix="databend-embedded-")) + ) + conn._ephemeral = True + return conn + return LocalConnection(embedded.connect(data_path=database)) + + target_path = explicit_data_path or (".databend" if memory_target else database) + return LocalConnection(embedded.SessionContext(tenant, data_path=target_path)) + + +def connect(target: str = ":memory:", **kwargs: Any): + if _is_local_target(target): + database, data_path, tenant = _parse_local_target( + target, + kwargs.get("data_path"), + kwargs.get("tenant"), + ) + return connect_local(database=database, data_path=data_path, tenant=tenant) + + package = 
import_module("databend_driver") + client = package.BlockingDatabendClient(target) + return client.get_conn() + + +def _is_local_target(target: str) -> bool: + return target == ":memory:" or target.startswith("databend+local://") + + +def _parse_local_target( + target: str, explicit_data_path: str | None, explicit_tenant: str | None +) -> tuple[str, str | None, str | None]: + if target == ":memory:": + return ":memory:", explicit_data_path, explicit_tenant + + parsed = urlparse(target) + database = ":memory:" + query = parse_qs(parsed.query) + tenant = explicit_tenant + + if explicit_data_path is not None: + data_path = explicit_data_path + elif "data_path" in query and query["data_path"]: + data_path = query["data_path"][0] + else: + raw_path = parsed.path or "" + if raw_path == "/:memory:": + raw_path = ":memory:" + elif raw_path.startswith("/./") or raw_path.startswith("/../"): + # Strip the leading slash so relative paths like /./local-state + # are preserved as ./local-state rather than forced absolute. 
+ raw_path = raw_path[1:] + data_path = raw_path if raw_path not in {"", "/"} else None + + if "database" in query and query["database"]: + database = query["database"][0] + elif data_path is not None: + database = data_path + + if tenant is None and "tenant" in query and query["tenant"]: + tenant = query["tenant"][0] + + return database, data_path, tenant + + +def _python_version_tuple() -> tuple[int, int]: + import sys + + return sys.version_info[:2] + + +def _python_version_str() -> str: + major, minor = _python_version_tuple() + return f"{major}.{minor}" + + +def _sql_literal(value: Any) -> str: + if value is None: + return "NULL" + if isinstance(value, bool): + return "TRUE" if value else "FALSE" + if isinstance(value, (int, float)): + return str(value) + if isinstance(value, str): + return "'" + value.replace("\\", "\\\\").replace("'", "''") + "'" + raise TypeError( + f"Invalid parameter type for {value!r}, expected str, bool, int, float or None" + ) + + +__all__ = [ + "LocalConnection", + "LocalRelation", + "LocalRow", + "LocalRowIterator", + "connect", + "connect_local", +] diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index ee1799dd..1fda564a 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -35,6 +35,9 @@ readme = "README.md" requires-python = ">=3.8, < 3.14" dynamic = ["version"] +[project.optional-dependencies] +local = ['databend>=1.2.895; python_version >= "3.12"', 'pyarrow'] + [project.urls] Repository = "https://github.com/databendlabs/bendsql" diff --git a/bindings/python/tests/local/local.feature b/bindings/python/tests/local/local.feature new file mode 100644 index 00000000..1d374f89 --- /dev/null +++ b/bindings/python/tests/local/local.feature @@ -0,0 +1,53 @@ +Feature: Databend Driver Local Mode + + Scenario: Local connect with persistent path + Given Real local embedded dependencies are available + When A new local embedded connection is created + Then Local select 1 should equal 
1 + + Scenario: Local connect with memory target + Given Real local embedded dependencies are available + When A new local memory connection is created + Then Local numbers aggregate should match expected values + + Scenario: Local explicit memory dsn parsing + Given Real local embedded dependencies are available + Then Local explicit memory dsn should parse as memory mode + + Scenario: Local execute and query roundtrip + Given Real local embedded dependencies are available + When A new local embedded connection is created + Then Local execute should create and populate a table + + Scenario: Local tenant mode + Given Real local embedded dependencies are available + When A new local tenant connection is created + Then Local tenant connection should use the configured data path + + Scenario: Local register parquet + Given Real local embedded dependencies are available + When A parquet file is registered in local mode + Then Local parquet query should return expected rows + + Scenario: Local dsn connect + Given Real local embedded dependencies are available + When A new local dsn connection is created + Then Local dsn connection should execute queries + + Scenario: Local tenant dsn connect + Given Real local embedded dependencies are available + When A new local tenant dsn connection is created + Then Local tenant dsn connection should execute queries + + Scenario: Local import error message + Then Local import error should mention Python 3.12 requirement + + Scenario: Local blocking query api + Given Real local embedded dependencies are available + When A new local embedded connection is created + Then Local blocking query api should behave like expected + + Scenario: Local parameter formatting + Given Real local embedded dependencies are available + When A new local embedded connection is created + Then Local parameter formatting should behave like expected diff --git a/bindings/python/tests/local/steps/binding.py b/bindings/python/tests/local/steps/binding.py new 
# Copyright 2021 Datafuse Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Behave step implementations for the local embedded-connection scenarios.

These steps exercise ``databend_driver/local.py`` against the *real*
``databend`` embedded engine (no mocks), so most scenarios first gate on
``require_real_embedded()``.
"""

import importlib.util
import os
import sys
import tempfile
from importlib import metadata
from pathlib import Path
from unittest.mock import patch

from behave import given, then, when


# Path to the module under test, resolved relative to this file so the tests
# run against the working tree without installing the package.
LOCAL_MODULE_PATH = (
    Path(__file__).resolve().parent.parent.parent.parent
    / "package"
    / "databend_driver"
    / "local.py"
)

# Minimum embedded-engine release these steps are written against.
MIN_DATABEND_VERSION = (1, 2, 895)


def load_local_module():
    """Load ``databend_driver/local.py`` straight from its file path.

    Returns:
        The freshly executed module object (not cached in ``sys.modules``),
        so each call gets an independent copy — useful for the import-error
        scenario which patches module attributes.
    """
    spec = importlib.util.spec_from_file_location(
        "databend_driver.local", LOCAL_MODULE_PATH
    )
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module


def parse_version(version: str) -> tuple[int, ...]:
    """Parse the leading numeric components of a dotted version string.

    Each ``.``-separated chunk contributes its leading digits only (so
    ``"895rc1"`` yields ``895``); parsing stops at the first chunk with no
    leading digits.  Returns a tuple suitable for lexicographic comparison.
    """
    parts = []
    for chunk in version.split("."):
        digits = []
        for char in chunk:
            if char.isdigit():
                digits.append(char)
            else:
                break
        if not digits:
            break
        parts.append(int("".join(digits)))
    return tuple(parts)


def require_real_embedded():
    """Fail fast unless the real embedded-engine dependencies are usable.

    Raises:
        AssertionError: if Python is older than 3.12, if ``databend`` or
            ``pyarrow`` cannot be imported, if ``databend`` distribution
            metadata is missing, or if the installed ``databend`` is older
            than ``MIN_DATABEND_VERSION``.
    """
    # databend>=1.2.895 only ships cp312 wheels, so older interpreters
    # cannot run these integration tests at all.
    if sys.version_info < (3, 12):
        raise AssertionError(
            "local integration tests require Python 3.12+, because "
            "databend>=1.2.895 currently only publishes cp312 wheels"
        )

    try:
        import databend  # noqa: F401
        import pyarrow  # noqa: F401
    except ImportError as exc:
        raise AssertionError(
            "local integration tests require "
            "real `databend` and `pyarrow` packages installed"
        ) from exc

    try:
        version = metadata.version("databend")
    except metadata.PackageNotFoundError as exc:
        raise AssertionError("databend package metadata is not available") from exc

    if parse_version(version) < MIN_DATABEND_VERSION:
        raise AssertionError(
            f"local integration tests require databend >= 1.2.895, found {version}"
        )


@given("Real local embedded dependencies are available")
def _(context):
    # Gate scenario execution on the real engine and load the module under
    # test; tmpdirs collects TemporaryDirectory handles so they stay alive
    # (and get cleaned up) for the duration of the scenario.
    require_real_embedded()
    context.local = load_local_module()
    context.tmpdirs = []


@when("A new local embedded connection is created")
def _(context):
    # Persistent local state under a scenario-scoped temporary directory.
    tmpdir = tempfile.TemporaryDirectory(prefix="bendsql-local-")
    context.tmpdirs.append(tmpdir)
    context.tmpdir = tmpdir.name
    context.conn = context.local.connect_local(context.tmpdir)


@when("A new local memory connection is created")
def _(context):
    # Bare ":memory:" shorthand — no on-disk state.
    context.conn = context.local.connect(":memory:")


@when("A new local tenant connection is created")
def _(context):
    tmpdir = tempfile.TemporaryDirectory(prefix="bendsql-tenant-")
    context.tmpdirs.append(tmpdir)
    context.tmpdir = tmpdir.name
    context.conn = context.local.connect_local(context.tmpdir, tenant="default")


@when("A parquet file is registered in local mode")
def _(context):
    import pyarrow as pa
    import pyarrow.parquet as pq

    tmpdir = tempfile.TemporaryDirectory(prefix="bendsql-register-")
    context.tmpdirs.append(tmpdir)
    context.tmpdir = tmpdir.name
    parquet_path = Path(context.tmpdir) / "books.parquet"
    pq.write_table(
        pa.table({"id": [1, 2], "name": ["databend", "bendsql"]}),
        parquet_path,
    )

    context.conn = context.local.connect_local(context.tmpdir)
    context.conn.register("books", parquet_path, format="parquet")


@when("A new local dsn connection is created")
def _(context):
    tmpdir = tempfile.TemporaryDirectory(prefix="bendsql-dsn-")
    context.tmpdirs.append(tmpdir)
    context.tmpdir = tmpdir.name
    # FIX: removed a stray, unassigned `connect("databend+local:///tmp/demo")`
    # call here — it leaked an unclosed embedded engine and wrote persistent
    # state to the hard-coded path /tmp/demo outside the test sandbox.
    # lstrip('/') because the DSN form already carries the leading slash.
    context.conn = context.local.connect(
        f"databend+local:///{Path(context.tmpdir).as_posix().lstrip('/')}"
    )


@when("A new local tenant dsn connection is created")
def _(context):
    tmpdir = tempfile.TemporaryDirectory(prefix="bendsql-dsn-tenant-")
    context.tmpdirs.append(tmpdir)
    context.tmpdir = tmpdir.name
    dsn = f"databend+local:///{Path(context.tmpdir).as_posix().lstrip('/')}?tenant=test_tenant"
    context.conn = context.local.connect(dsn)


@then("Local select 1 should equal 1")
def _(context):
    assert context.conn.sql("select 1").fetchone() == (1,)


@then("Local numbers aggregate should match expected values")
def _(context):
    assert context.conn.sql("select sum(number), 'a' from numbers(101)").fetchone() == (
        5050,
        "a",
    )


@then("Local explicit memory dsn should parse as memory mode")
def _(context):
    # Exercises the private DSN parser directly: explicit memory DSN should
    # map both database and data_path to ":memory:" with no tenant.
    database, data_path, tenant = context.local._parse_local_target(
        "databend+local:///:memory:", None, None
    )
    assert database == ":memory:"
    assert data_path == ":memory:"
    assert tenant is None


@then("Local execute should create and populate a table")
def _(context):
    context.conn.execute("create or replace table t(a int)")
    context.conn.exec("insert into t values (1), (2), (3)")
    assert context.conn.query_row("select sum(a) from t").values() == (6,)


@then("Local tenant connection should use the configured data path")
def _(context):
    assert str(context.conn._impl._data_path) == str(Path(context.tmpdir).resolve())
    assert context.conn.query_row("select 1").values() == (1,)


@then("Local parquet query should return expected rows")
def _(context):
    assert context.conn.query_row("select count(*) from books").values() == (2,)
    assert context.conn.query_row(
        "select max(name), min(name) from books"
    ).values() == (
        "databend",
        "bendsql",
    )


@then("Local dsn connection should execute queries")
def _(context):
    assert context.conn.query_row("select 1").values() == (1,)


@then("Local tenant dsn connection should execute queries")
def _(context):
    assert str(context.conn._impl._data_path) == str(Path(context.tmpdir).resolve())
    assert context.conn.query_row("select 11111").values() == (11111,)


@then("Local import error should mention Python 3.12 requirement")
def _(context):
    # Simulate an unusable environment: `import databend` fails AND the
    # interpreter reports as 3.11; connect() must then raise an ImportError
    # that tells the user how to fix it.
    local = load_local_module()
    real_import = __import__

    def fake_import(name, *args, **kwargs):
        if name == "databend":
            raise ImportError("missing databend")
        return real_import(name, *args, **kwargs)

    with patch("builtins.__import__", side_effect=fake_import):
        with patch.object(local, "_python_version_tuple", return_value=(3, 11)):
            try:
                local.connect(":memory:")
            except ImportError as exc:
                message = str(exc)
            else:
                raise AssertionError("expected ImportError for missing databend")

    assert "databend-driver with the `local` extra" in message
    assert "Python 3.12+" in message


@then("Local blocking query api should behave like expected")
def _(context):
    context.conn.exec("create or replace table t(a int)")
    context.conn.exec("insert into t values (1), (2), (3)")
    assert context.conn.query_row("SELECT 1, 'x', TRUE").values() == (1, "x", True)
    assert [row.values() for row in context.conn.query_iter("SELECT * FROM t")] == [
        (1,),
        (2,),
        (3,),
    ]
    assert [row.values() for row in context.conn.query_all("SELECT * FROM t")] == [
        (1,),
        (2,),
        (3,),
    ]
    # exec/execute are fire-and-forget in local mode and return None;
    # last_query_id is likewise None (no server-side query id).
    assert context.conn.execute("SELECT 1, 2") is None
    assert context.conn.exec("SELECT 1, 2") is None
    assert context.conn.last_query_id() is None


@then("Local parameter formatting should behave like expected")
def _(context):
    # Positional `?` placeholders and named `:name` formatting must both work.
    row = context.conn.query_row("SELECT ?, ?, ?", params=(1, "abc", False))
    assert row.values() == (1, "abc", False)
    formatted = context.conn.format_sql(
        "SELECT :a, :b, :c", {"a": 1, "b": "x", "c": True}
    )
    assert formatted == "SELECT 1, 'x', TRUE"