diff --git a/.gitignore b/.gitignore
index a5c325a2..abed402a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -448,3 +448,7 @@ pyrightconfig.json
 .ionide
 
 # End of https://www.toptal.com/developers/gitignore/api/python,direnv,visualstudiocode,pycharm,macos,jetbrains
+
+# Mellea config files (may contain credentials)
+mellea.toml
+.mellea.toml
diff --git a/AGENTS.md b/AGENTS.md
index 60384a7f..93cb1ac9 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -33,6 +33,13 @@ uv run mypy .  # Type check
 ## 2. Directory Structure
 | Path | Contents |
 |------|----------|
 | `mellea/core/` | Core abstractions: Backend, Base, Formatter, Requirement, Sampling |
 | `mellea/stdlib/` | Standard library: Sessions, Components, Context |
 | `mellea/backends/` | Providers: HF, OpenAI, Ollama, Watsonx, LiteLLM |
+| `mellea/config.py` | Configuration file support (TOML) |
@@ -40,11 +47,58 @@ uv run mypy .  # Type check
 | `mellea/templates/` | Jinja2 templates |
 | `mellea/helpers/` | Utilities, logging, model ID tables |
-| `cli/` | CLI commands (`m serve`, `m alora`, `m decompose`, `m eval`) |
+| `cli/` | CLI commands (`m serve`, `m alora`, `m decompose`, `m eval`, `m config`) |
 | `test/` | All tests (run from repo root) |
 | `docs/examples/` | Example code (run as tests via pytest) |
 | `scratchpad/` | Experiments (git-ignored) |
 
-## 3. Test Markers
+## 3. Configuration Files
+Mellea supports TOML configuration files for setting default backends, models, and credentials.
+
+**Config Location:** `./mellea.toml` (searched in current dir and parents)
+
+**Value Precedence:** Explicit params > Project config > Defaults
+
+**CLI Commands:**
+```bash
+m config init   # Create project config
+m config show   # Display effective config
+m config path   # Show loaded config file
+m config where  # Show config location
+```
+
+**Development Usage:**
+- If `mellea.toml` exists, it is used; otherwise defaults apply
+- Store credentials in environment variables (never commit credentials)
+- Config files are git-ignored by default (`mellea.toml`, `.mellea.toml`)
+
+**Example Project Config** (`./mellea.toml`):
+```toml
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+
+# Generic model options (apply to all backends)
+[backend.model_options]
+temperature = 0.7
+
+# Per-backend model options (override generic for that backend)
+[backend.model_options.ollama]
+num_ctx = 4096
+
+[backend.model_options.openai]
+presence_penalty = 0.5
+
+[credentials]
+# openai_api_key = "sk-..."  # Better: use env vars
+```
+
+**Testing with Config:**
+- Tests use temporary config directories (see `test/config/test_config.py`)
+- Integration tests verify config precedence (see `test/config/test_config_integration.py`)
+- Clear config cache in tests with `clear_config_cache()` from `mellea.config`
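+
+**Example Config Test** (a minimal sketch; assumes the `tmp_path`/`monkeypatch` pattern used in `test/config/`):
+```python
+from mellea.config import clear_config_cache, load_config
+
+def test_uses_project_config(tmp_path, monkeypatch):
+    monkeypatch.chdir(tmp_path)
+    (tmp_path / "mellea.toml").write_text('[backend]\nname = "ollama"')
+    clear_config_cache()  # drop any config cached by an earlier test
+    config, path = load_config()
+    assert config.backend.name == "ollama"
+    assert path == tmp_path / "mellea.toml"
+```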
+## 4. Test Markers
 All tests and examples use markers to indicate requirements. The test infrastructure automatically skips tests based on system capabilities.
 
 **Backend Markers:**
@@ -107,12 +161,21 @@ Pre-commit runs: ruff, mypy, uv-lock, codespell
 | Ollama refused | Run `ollama serve` |
 
 ## 8. Self-Review (before notifying user)
-1. `uv run pytest test/ -m "not qualitative"` passes?
-2. `ruff format` and `ruff check` clean?
+1. **Pre-commit checks pass?** Run `uv run pre-commit run --all-files` or at minimum:
+   - `uv run ruff format . && uv run ruff check .` (formatting & linting)
+   - `uv run mypy .` (type checking)
+2. `uv run pytest test/ -m "not qualitative"` passes?
 3. New functions typed with concise docstrings?
 4. Unit tests added for new functionality?
 5. Avoided over-engineering?
 
+**Note:** All pre-commit hooks (ruff, mypy, codespell, uv-lock) must pass before a task is considered complete.
+
 ## 9. Writing Tests
 - Place tests in `test/` mirroring source structure
 - Name files `test_*.py` (required for pydocstyle)
diff --git a/cli/config/__init__.py b/cli/config/__init__.py
new file mode 100644
index 00000000..29933f85
--- /dev/null
+++ b/cli/config/__init__.py
@@ -0,0 +1,5 @@
+"""Configuration management commands for Mellea CLI."""
+
+from .commands import config_app
+
+__all__ = ["config_app"]
diff --git a/cli/config/commands.py b/cli/config/commands.py
new file mode 100644
index 00000000..5b9dbece
--- /dev/null
+++ b/cli/config/commands.py
@@ -0,0 +1,158 @@
+"""CLI commands for Mellea configuration management."""
+
+from pathlib import Path
+
+import typer
+from rich.console import Console
+from rich.syntax import Syntax
+from rich.table import Table
+
+from mellea.config import find_config_file, init_project_config, load_config
+
+config_app = typer.Typer(name="config", help="Manage Mellea configuration files")
+console = Console()
+
+
+@config_app.command("init")
+def init_project(
+    force: bool = typer.Option(
+        False, "--force", "-f", help="Overwrite existing config file"
+    ),
+) -> None:
+    """Create a project configuration file at ./mellea.toml."""
+    try:
+        config_path = init_project_config(force=force)
+        console.print(f"[green]✓[/green] Created project config at: {config_path}")
+        console.print("\nEdit this file to set your backend, model, and other options.")
+        console.print(
+            "Run [cyan]m config show[/cyan] to view the current configuration."
+        )
+    except FileExistsError as e:
+        console.print(f"[red]✗[/red] {e}")
+        console.print("Use [cyan]--force[/cyan] to overwrite the existing file.")
+        raise typer.Exit(1)
+    except Exception as e:
+        console.print(f"[red]✗[/red] Error creating config: {e}")
+        raise typer.Exit(1)
+
+
+@config_app.command("show")
+def show_config() -> None:
+    """Display the current effective configuration."""
+    try:
+        config, config_path = load_config()
+
+        # Display config source
+        if config_path:
+            console.print(f"[bold]Configuration loaded from:[/bold] {config_path}\n")
+        else:
+            console.print(
+                "[yellow]No configuration file found. Using defaults.[/yellow]\n"
+            )
+
+        # Create a table for the configuration
+        table = Table(
+            title="Effective Configuration", show_header=True, header_style="bold cyan"
+        )
+        table.add_column("Setting", style="dim")
+        table.add_column("Value")
+
+        # Backend settings
+        table.add_row(
+            "Backend Name", config.backend.name or "[dim](default: ollama)[/dim]"
+        )
+        table.add_row(
+            "Model ID",
+            config.backend.model_id or "[dim](default: granite-4-micro:3b)[/dim]",
+        )
+
+        # Model options
+        if config.backend.model_options:
+            for key, value in config.backend.model_options.items():
+                table.add_row(f"  {key}", str(value))
+
+        # Backend kwargs
+        if config.backend.kwargs:
+            for key, value in config.backend.kwargs.items():
+                table.add_row(f"  backend.{key}", str(value))
+
+        # Credentials (masked)
+        if config.credentials.openai_api_key:
+            table.add_row("OpenAI API Key", "[dim]***configured***[/dim]")
+        if config.credentials.watsonx_api_key:
+            table.add_row("Watsonx API Key", "[dim]***configured***[/dim]")
+        if config.credentials.watsonx_project_id:
+            table.add_row("Watsonx Project ID", config.credentials.watsonx_project_id)
+        if config.credentials.watsonx_url:
+            table.add_row("Watsonx URL", config.credentials.watsonx_url)
+
+        # General settings
+        table.add_row(
+            "Context Type", config.context_type or "[dim](default: simple)[/dim]"
+        )
+        table.add_row("Log Level", config.log_level or "[dim](default: INFO)[/dim]")
+
+        console.print(table)
+
+        console.print(
+            "\n[dim]Explicit parameters in code override config file values.[/dim]"
+        )
+
+    except Exception as e:
+        console.print(f"[red]✗[/red] Error loading config: {e}")
+        raise typer.Exit(1)
+
+
+@config_app.command("path")
+def show_path() -> None:
+    """Show the path to the currently loaded configuration file."""
+    try:
+        config_path = find_config_file()
+
+        if config_path:
+            console.print(f"[green]✓[/green] Using config file: {config_path}")
+
+            # Show the file content
+            console.print("\n[bold]File contents:[/bold]")
+            with open(config_path) as f:
+                content = f.read()
+            syntax = Syntax(content, "toml", theme="monokai", line_numbers=True)
+            console.print(syntax)
+        else:
+            console.print("[yellow]No configuration file found.[/yellow]")
+            console.print("\nSearched: ./mellea.toml (current dir and parents)")
+            console.print(
+                "\nRun [cyan]m config init[/cyan] to create a project config."
+            )
+    except Exception as e:
+        console.print(f"[red]✗[/red] Error: {e}")
+        raise typer.Exit(1)
+
+
+@config_app.command("where")
+def show_locations() -> None:
+    """Show configuration file location."""
+    project_config_path = Path.cwd() / "mellea.toml"
+
+    console.print("[bold]Configuration file location:[/bold]\n")
+
+    # Project config
+    console.print(f"[cyan]Project config:[/cyan] {project_config_path}")
+    if project_config_path.exists():
+        console.print("  [green]✓ exists[/green]")
+    else:
+        console.print("  [dim]✗ not found[/dim]")
+        console.print("  Run [cyan]m config init[/cyan] to create")
+
+    console.print()
+
+    # Currently loaded (might be in a parent dir)
+    current = find_config_file()
+    if current:
+        console.print(f"[bold green]Currently loaded:[/bold green] {current}")
+        if current != project_config_path:
+            console.print("  [dim](found in parent directory)[/dim]")
+    else:
+        console.print(
+            "[yellow]No config file currently loaded (using defaults)[/yellow]"
+        )
diff --git a/cli/eval/runner.py b/cli/eval/runner.py
index 3aface94..19ea441c 100644
--- a/cli/eval/runner.py
+++ b/cli/eval/runner.py
@@ -76,7 +76,7 @@ def pass_rate(self) -> float:
 
 
 def create_session(
-    backend: str, model: str | None, max_tokens: int | None
+    backend: str | None, model: str | None, max_tokens: int | None
 ) -> mellea.MelleaSession:
     """Create a mellea session with the specified backend and model."""
     model_id = None
@@ -92,6 +92,11 @@ def create_session(
         model_id = mellea.model_ids.IBM_GRANITE_4_MICRO_3B
 
     try:
+        from mellea.core.backend import Backend
+
+        if backend is None:
+            raise ValueError("Backend must be specified")
+
         backend_lower = backend.lower()
 
         backend_instance: Backend
diff --git a/cli/m.py b/cli/m.py
index ab39440e..95d692ea 100644
--- a/cli/m.py
+++ b/cli/m.py
@@ -3,6 +3,7 @@
 import typer
 
 from cli.alora.commands import alora_app
+from cli.config.commands import config_app
 from cli.decompose import app as decompose_app
 from cli.eval.commands import eval_app
 from cli.serve.app import serve
@@ -25,6 +26,7 @@ def callback() -> None:
 # Add new subcommand groups by importing and adding with `cli.add_typer()`
 # as documented: https://typer.tiangolo.com/tutorial/subcommands/add-typer/#put-them-together.
 cli.add_typer(alora_app)
+cli.add_typer(config_app)
 cli.add_typer(decompose_app)
 cli.add_typer(eval_app)
diff --git a/docs/configuration.md b/docs/configuration.md
new file mode 100644
index 00000000..455f438d
--- /dev/null
+++ b/docs/configuration.md
@@ -0,0 +1,454 @@
+# Mellea Configuration Guide
+
+This guide explains how to use configuration files to set default backends, models, credentials, and other options for Mellea.
+
+## Quick Start
+
+Create a project configuration file in your current directory:
+
+```bash
+m config init
+```
+
+This creates `./mellea.toml` with commented example settings. Edit this file to set your preferences.
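+
+The generated file is a commented TOML template. Abridged, it looks like this (the full template lives in `mellea.config.init_project_config`):
+
+```toml
+# Mellea Project Configuration
+# Explicit parameters passed to start_session() override these settings.
+
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+
+[backend.model_options]
+temperature = 0.7
+```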
+
+## Configuration Discovery
+
+Mellea looks for a single project configuration file:
+
+- **Project config**: `./mellea.toml`, searched in the current directory and then in each parent directory (the first file found wins)
+
+### Precedence Rules
+
+Values are applied with the following precedence (highest to lowest):
+
+1. **Explicit parameters** passed to `start_session()`
+2. **Project config** (`./mellea.toml`)
+3. **Built-in defaults**
+
+This means you can set per-project defaults in `mellea.toml` and still override them for an individual call, as the sketch below shows.
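+
+A minimal sketch of the precedence rules (assumes the Ollama models are pulled locally):
+
+```toml
+# ./mellea.toml
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+```
+
+```python
+from mellea import start_session
+
+# No explicit parameters: the project config supplies backend and model.
+with start_session() as session:
+    print(session.backend.model_id)  # llama3.2:1b
+
+# Explicit parameters win over the config file.
+with start_session(model_id="llama3.2:3b") as session:
+    print(session.backend.model_id)  # llama3.2:3b
+```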
+
+## Configuration File Format
+
+Configuration files use [TOML](https://toml.io/) format. Here's a complete example:
+
+```toml
+# ./mellea.toml
+
+# General settings (top-level keys must come before any [table])
+context_type = "simple"  # or "chat"
+log_level = "INFO"       # DEBUG, INFO, WARNING, ERROR
+
+[backend]
+# Backend to use: "ollama", "hf", "openai", "watsonx", "litellm"
+name = "ollama"
+
+# Model identifier
+model_id = "granite-4-micro:3b"
+
+# Model options (temperature, max_tokens, etc.)
+[backend.model_options]
+temperature = 0.7
+max_tokens = 2048
+top_p = 0.9
+
+# Backend-specific options
+[backend.kwargs]
+# For Ollama:
+# base_url = "http://localhost:11434"
+
+# For OpenAI:
+# organization = "org-..."
+
+[credentials]
+# API keys (environment variables take precedence)
+# openai_api_key = "sk-..."
+# watsonx_api_key = "..."
+# watsonx_project_id = "..."
+# watsonx_url = "https://us-south.ml.cloud.ibm.com"
+```
+
+## Configuration Options
+
+### Backend Settings
+
+#### `backend.name`
+- **Type**: String
+- **Options**: `"ollama"`, `"hf"`, `"openai"`, `"watsonx"`, `"litellm"`
+- **Default**: `"ollama"`
+- **Description**: The backend to use for model inference
+
+#### `backend.model_id`
+- **Type**: String
+- **Default**: `"granite-4-micro:3b"`
+- **Description**: Model identifier. Format depends on backend:
+  - Ollama: `"llama3.2:1b"`, `"granite-4-micro:3b"`
+  - OpenAI: `"gpt-4"`, `"gpt-3.5-turbo"`
+  - HuggingFace: `"microsoft/DialoGPT-medium"`
+  - Watsonx: Model ID from IBM Watsonx catalog
+
+#### `backend.model_options`
+- **Type**: Dictionary
+- **Default**: `{}`
+- **Description**: Model-specific options. Common options:
+  - `temperature` (float): Sampling temperature (0.0-2.0)
+  - `max_tokens` (int): Maximum tokens to generate
+  - `top_p` (float): Nucleus sampling threshold
+  - `top_k` (int): Top-k sampling parameter
+  - `frequency_penalty` (float): Frequency penalty (OpenAI)
+  - `presence_penalty` (float): Presence penalty (OpenAI)
+
+  Options can also be scoped to a single backend with a nested table such as `[backend.model_options.ollama]`; backend-specific keys override the generic ones for that backend (see the sketch below).
+
+#### `backend.kwargs`
+- **Type**: Dictionary
+- **Default**: `{}`
+- **Description**: Backend-specific constructor arguments:
+  - Ollama: `base_url`, `timeout`
+  - OpenAI: `organization`, `base_url`
+  - HuggingFace: `device`, `torch_dtype`
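+
+How the per-backend merge behaves, as implemented in `BackendConfig.get_model_options_for_backend` (a sketch mirroring the unit tests):
+
+```toml
+[backend.model_options]
+temperature = 0.7
+
+[backend.model_options.ollama]
+temperature = 0.9
+num_ctx = 4096
+```
+
+```python
+from mellea.config import BackendConfig
+
+config = BackendConfig(
+    model_options={"temperature": 0.7, "ollama": {"temperature": 0.9, "num_ctx": 4096}}
+)
+
+# Backend-specific keys override the generic ones for that backend...
+assert config.get_model_options_for_backend("ollama") == {"temperature": 0.9, "num_ctx": 4096}
+# ...while other backends see only the generic options.
+assert config.get_model_options_for_backend("openai") == {"temperature": 0.7}
+```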
+
+### Credentials
+
+#### `credentials.openai_api_key`
+- **Type**: String
+- **Default**: None
+- **Description**: OpenAI API key. Environment variable `OPENAI_API_KEY` takes precedence.
+
+#### `credentials.watsonx_api_key`
+- **Type**: String
+- **Default**: None
+- **Description**: IBM Watsonx API key. Environment variable `WATSONX_API_KEY` takes precedence.
+
+#### `credentials.watsonx_project_id`
+- **Type**: String
+- **Default**: None
+- **Description**: IBM Watsonx project ID. Environment variable `WATSONX_PROJECT_ID` takes precedence.
+
+#### `credentials.watsonx_url`
+- **Type**: String
+- **Default**: None
+- **Description**: IBM Watsonx API URL. Environment variable `WATSONX_URL` takes precedence.
+
+### General Settings
+
+#### `context_type`
+- **Type**: String
+- **Options**: `"simple"`, `"chat"`
+- **Default**: `"simple"`
+- **Description**: Default context type for sessions
+  - `"simple"`: Each interaction is independent
+  - `"chat"`: Maintains conversation history
+
+#### `log_level`
+- **Type**: String
+- **Options**: `"DEBUG"`, `"INFO"`, `"WARNING"`, `"ERROR"`, `"CRITICAL"`
+- **Default**: `"INFO"`
+- **Description**: Logging level for Mellea
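+
+When `context_type = "chat"` is set and no explicit context is passed, `start_session()` creates a `ChatContext` instead of a `SimpleContext`. A minimal sketch (assumes a local Ollama model):
+
+```python
+from mellea import start_session
+
+# With context_type = "chat" in ./mellea.toml:
+with start_session() as session:
+    print(type(session.ctx).__name__)  # ChatContext
+```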
+
+## Example Configurations
+
+### Local Development with Ollama
+
+```toml
+# ./mellea.toml
+context_type = "chat"
+log_level = "DEBUG"
+
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+
+[backend.model_options]
+temperature = 0.8
+max_tokens = 4096
+```
+
+### Production with OpenAI
+
+```toml
+# ./mellea.toml
+context_type = "chat"
+log_level = "INFO"
+
+[backend]
+name = "openai"
+model_id = "gpt-4"
+
+[backend.model_options]
+temperature = 0.7
+max_tokens = 2048
+
+[credentials]
+openai_api_key = "sk-..."  # Better: use environment variable
+```
+
+### Project-Specific Override
+
+```toml
+# ./mellea.toml (in your project directory)
+context_type = "simple"
+
+[backend]
+# Backend name falls back to the default (ollama) when omitted
+model_id = "llama3.2:3b"
+
+[backend.model_options]
+temperature = 0.9  # More creative for this project
+```
+
+### HuggingFace Local Models
+
+```toml
+# ./mellea.toml
+[backend]
+name = "hf"
+model_id = "microsoft/DialoGPT-medium"
+
+[backend.kwargs]
+device = "cuda"  # or "cpu"
+torch_dtype = "float16"
+
+[backend.model_options]
+temperature = 0.8
+max_tokens = 512
+```
+
+## CLI Commands
+
+### View Current Configuration
+
+```bash
+# Show effective configuration
+m config show
+
+# Show which config file is being used (and its contents)
+m config path
+
+# Show the expected config location and what is currently loaded
+m config where
+```
+
+### Initialize Configuration
+
+```bash
+# Create ./mellea.toml in the current directory
+m config init
+
+# Force overwrite an existing config
+m config init --force
+```
+
+## Security Best Practices
+
+### Credentials Management
+
+1. **Use Environment Variables**: For CI/CD and production, use environment variables instead of config files:
+   ```bash
+   export OPENAI_API_KEY="sk-..."
+   export WATSONX_API_KEY="..."
+   ```
+
+2. **Don't Commit Credentials**: The `.gitignore` file excludes `mellea.toml` and `.mellea.toml` by default. If you choose to commit a project config, keep credentials out of it and rely on environment variables instead.
+
+3. **File Permissions**: Ensure config files with credentials have restricted permissions:
+   ```bash
+   chmod 600 mellea.toml
+   ```
+
+### Example: Secure Setup
+
+Set credentials in the environment:
+
+```bash
+export OPENAI_API_KEY="sk-..."
+export WATSONX_API_KEY="..."
+```
+
+**Project config** (`./mellea.toml`, safe to commit because it contains no credentials):
+
+```toml
+[backend]
+name = "openai"
+model_id = "gpt-4"
+
+[backend.model_options]
+temperature = 0.7
+```
+
+## Programmatic Usage
+
+Configuration is automatically loaded when you call `start_session()`:
+
+```python
+from mellea import start_session
+
+# Uses config file settings
+with start_session() as session:
+    response = session.instruct("Hello!")
+
+# Override config with explicit parameters
+with start_session(backend_name="openai", model_id="gpt-4") as session:
+    response = session.instruct("Hello!")
+
+# Merge model_options with config
+with start_session(model_options={"temperature": 0.9}) as session:
+    # Config temperature is overridden to 0.9
+    response = session.instruct("Hello!")
+```
+
+## Troubleshooting
+
+### Config Not Loading
+
+1. Check which config is being used:
+   ```bash
+   m config path
+   ```
+
+2. Verify config syntax:
+   ```bash
+   python -c "import tomllib; tomllib.load(open('mellea.toml', 'rb'))"
+   ```
+
+3. Check for typos in field names (they are case-sensitive)
+
+### Credentials Not Working
+
+1. Environment variables take precedence over config files
+2. Check if credentials are set in the environment:
+   ```bash
+   echo $OPENAI_API_KEY
+   ```
+3. Verify credentials are in the correct section:
+   ```toml
+   [credentials]  # Not [backend.credentials]
+   openai_api_key = "..."
+   ```
+
+### Model Not Found
+
+1. Verify the model ID format for your backend
+2. For Ollama, ensure the model is pulled:
+   ```bash
+   ollama pull llama3.2:1b
+   ```
+3. Check backend-specific model naming conventions
+
+## Advanced Topics
+
+### Multiple Profiles
+
+Profiles are not directly supported, but you can keep environment-specific project configs and copy the one you need into place:
+
+```bash
+cp mellea-dev.toml mellea.toml   # For development
+cp mellea-prod.toml mellea.toml  # For production
+```
+
+### Config in CI/CD
+
+For CI/CD pipelines, use environment variables instead of config files:
+
+```yaml
+# GitHub Actions example
+env:
+  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+  WATSONX_API_KEY: ${{ secrets.WATSONX_API_KEY }}
+```
+
+### Dynamic Configuration
+
+For dynamic configuration, use explicit parameters:
+
+```python
+import os
+
+from mellea import start_session
+
+# Load config from a custom source
+backend = os.getenv("MELLEA_BACKEND", "ollama")
+model = os.getenv("MELLEA_MODEL", "llama3.2:1b")
+
+with start_session(backend_name=backend, model_id=model) as session:
+    response = session.instruct("Hello!")
+```
+
+## Migration Guide
+
+### From Hardcoded Settings
+
+**Before:**
+```python
+from mellea import start_session
+
+with start_session("ollama", "llama3.2:1b") as session:
+    response = session.instruct("Hello!")
+```
+
+**After (with config):**
+```toml
+# ./mellea.toml
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+```
+
+```python
+from mellea import start_session
+
+# Uses config automatically
+with start_session() as session:
+    response = session.instruct("Hello!")
+```
+
+### From Environment Variables
+
+**Before:**
+```bash
+export MELLEA_BACKEND="ollama"
+export MELLEA_MODEL="llama3.2:1b"
+```
+
+**After:**
+```toml
+# ./mellea.toml
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+```
+
+Environment variables for credentials still work and take precedence.
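+
+You can also inspect the loaded configuration directly from Python; a minimal sketch using the helpers in `mellea.config`:
+
+```python
+from mellea.config import clear_config_cache, load_config
+
+config, path = load_config()
+print(path)                 # e.g. /home/user/project/mellea.toml, or None
+print(config.backend.name)  # e.g. "ollama"
+
+# Results are cached per process; force a reload after editing the file:
+clear_config_cache()
+config, path = load_config()
+```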
+
+## See Also
+
+- [Mellea Documentation](../README.md)
+- [TOML Specification](https://toml.io/)
diff --git a/mellea/config.py b/mellea/config.py
new file mode 100644
index 00000000..19d43102
--- /dev/null
+++ b/mellea/config.py
@@ -0,0 +1,283 @@
+r"""Configuration file support for Mellea.
+
+This module provides support for TOML configuration files to set default
+backends, models, credentials, and other options without hardcoding them.
+
+Configuration files are searched for in the current directory and parent
+directories (./mellea.toml). If found, the config is used; if not, defaults
+apply.
+
+Values are applied with the following precedence:
+1. Explicit parameters passed to start_session()
+2. Project config file (if it exists)
+3. Built-in defaults
+"""
+
+import os
+import sys
+from pathlib import Path
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+# Import tomllib for Python 3.11+, tomli for Python 3.10
+if sys.version_info >= (3, 11):
+    import tomllib
+else:
+    try:
+        import tomli as tomllib  # type: ignore[import-not-found]
+    except ImportError:
+        raise ImportError(
+            "tomli is required for Python 3.10. Install it with: pip install tomli"
+        )
+
+
+class BackendConfig(BaseModel):
+    """Configuration for backend settings.
+
+    Model options can be specified generically or per-backend:
+
+    ```toml
+    [backend.model_options]
+    temperature = 0.7  # applies to all backends
+
+    [backend.model_options.ollama]
+    num_ctx = 4096  # ollama-specific
+
+    [backend.model_options.openai]
+    presence_penalty = 0.5  # openai-specific
+    ```
+    """
+
+    name: str | None = None
+    model_id: str | None = None
+    model_options: dict[str, Any] = Field(default_factory=dict)
+    kwargs: dict[str, Any] = Field(default_factory=dict)
+
+    # Known backend names for detecting per-backend options
+    _BACKEND_NAMES = {
+        "ollama",
+        "hf",
+        "huggingface",
+        "openai",
+        "watsonx",
+        "litellm",
+        "vllm",
+    }
+
+    def get_model_options_for_backend(self, backend_name: str) -> dict[str, Any]:
+        """Get merged model options for a specific backend.
+
+        Merges generic options with backend-specific options.
+        Backend-specific options override generic ones.
+
+        Args:
+            backend_name: The backend name (e.g., "ollama", "openai")
+
+        Returns:
+            Merged model options dictionary
+        """
+        result = {}
+
+        # First, add generic options (non-dict values that aren't backend names)
+        for key, value in self.model_options.items():
+            if key not in self._BACKEND_NAMES and not isinstance(value, dict):
+                result[key] = value
+
+        # Then, merge backend-specific options (overrides generic)
+        backend_specific = self.model_options.get(backend_name)
+        if isinstance(backend_specific, dict):
+            result.update(backend_specific)
+
+        return result
+
+
+class CredentialsConfig(BaseModel):
+    """Configuration for API credentials."""
+
+    openai_api_key: str | None = None
+    watsonx_api_key: str | None = None
+    watsonx_project_id: str | None = None
+    watsonx_url: str | None = None
+
+
+class MelleaConfig(BaseModel):
+    """Main configuration model for Mellea."""
+
+    backend: BackendConfig = Field(default_factory=BackendConfig)
+    credentials: CredentialsConfig = Field(default_factory=CredentialsConfig)
+    context_type: str | None = None
+    log_level: str | None = None
+
+
+# Global cache for the loaded config
+_config_cache: tuple[MelleaConfig, Path | None] | None = None
+
+
+def find_config_file() -> Path | None:
+    """Find a configuration file in the current directory or parent directories.
+
+    Searches for ./mellea.toml starting from the current directory and walking
+    up through parent directories.
+
+    Returns:
+        Path to the config file if found, None otherwise
+    """
+    current = Path.cwd()
+    for parent in [current, *current.parents]:
+        project_config = parent / "mellea.toml"
+        if project_config.exists():
+            return project_config
+
+    return None
+
+
+def load_config(config_path: Path | None = None) -> tuple[MelleaConfig, Path | None]:
+    """Load configuration from file.
+
+    Args:
+        config_path: Optional explicit path to a config file. If None, searches
+            standard locations.
+
+    Returns:
+        Tuple of (MelleaConfig, config_path). config_path is None if no config
+        file was found.
+    """
+    global _config_cache
+
+    # Return the cached config if available and no explicit path provided
+    if _config_cache is not None and config_path is None:
+        return _config_cache
+
+    # Find a config file if not explicitly provided
+    if config_path is None:
+        config_path = find_config_file()
+
+    # No config file found - return an empty config
+    if config_path is None:
+        config = MelleaConfig()
+        _config_cache = (config, None)
+        return config, None
+
+    # Load and parse the config file
+    try:
+        with open(config_path, "rb") as f:
+            data = tomllib.load(f)
+
+        # Parse into the Pydantic model
+        config = MelleaConfig(**data)
+        _config_cache = (config, config_path)
+        return config, config_path
+
+    except Exception as e:
+        raise ValueError(f"Error loading config from {config_path}: {e}") from e
+
+
+def get_config_path() -> Path | None:
+    """Get the path to the currently loaded config file.
+
+    Returns:
+        Path to the config file if one was loaded, None otherwise
+    """
+    if _config_cache is None:
+        load_config()
+    return _config_cache[1] if _config_cache else None
+
+
+def apply_credentials_to_env(config: MelleaConfig) -> None:
+    """Apply credentials from config to environment variables.
+
+    Only sets environment variables if they are not already set and the
+    credential is present in the config.
+
+    Args:
+        config: Configuration containing credentials
+    """
+    creds = config.credentials
+
+    # Map config fields to environment variable names
+    env_mappings = {
+        "openai_api_key": "OPENAI_API_KEY",
+        "watsonx_api_key": "WATSONX_API_KEY",
+        "watsonx_project_id": "WATSONX_PROJECT_ID",
+        "watsonx_url": "WATSONX_URL",
+    }
+
+    for config_field, env_var in env_mappings.items():
+        value = getattr(creds, config_field)
+        if value is not None and env_var not in os.environ:
+            os.environ[env_var] = value
+
+
+def init_project_config(force: bool = False) -> Path:
+    """Create an example project configuration file.
+
+    Args:
+        force: If True, overwrite an existing config file
+
+    Returns:
+        Path to the created config file
+
+    Raises:
+        FileExistsError: If the config file exists and force=False
+    """
+    config_path = Path.cwd() / "mellea.toml"
+
+    if config_path.exists() and not force:
+        raise FileExistsError(
+            f"Config file already exists at {config_path}. Use --force to overwrite."
+        )
+
+    # Example project config content
+    example_config = """# Mellea Project Configuration
+# If this file exists, it will be used to configure start_session() defaults.
+# Explicit parameters passed to start_session() override these settings.
+
+# General settings (top-level keys must appear before any [table])
+# context_type = "simple"  # or "chat"
+# log_level = "INFO"
+
+[backend]
+# Backend to use (ollama, openai, huggingface, vllm, watsonx, litellm)
+name = "ollama"
+
+# Model ID
+model_id = "llama3.2:1b"
+
+# Generic model options (apply to all backends)
+[backend.model_options]
+temperature = 0.7
+
+# Per-backend model options (override generic options for that backend)
+# [backend.model_options.ollama]
+# num_ctx = 4096
+
+# [backend.model_options.openai]
+# presence_penalty = 0.5
+
+# Backend-specific constructor options
+[backend.kwargs]
+# base_url = "http://localhost:11434"  # For Ollama
+
+[credentials]
+# API keys (environment variables take precedence)
+# openai_api_key = "sk-..."
+# watsonx_api_key = "..."
+# watsonx_project_id = "..."
+# watsonx_url = "https://us-south.ml.cloud.ibm.com"
+"""
+
+    # Write the config file
+    with open(config_path, "w") as f:
+        f.write(example_config)
+
+    return config_path
+
+
+def clear_config_cache() -> None:
+    """Clear the cached configuration.
+
+    Useful for testing or when config files change during runtime.
+    """
+    global _config_cache
+    _config_cache = None
diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py
index 1d706cf5..07956140 100644
--- a/mellea/stdlib/session.py
+++ b/mellea/stdlib/session.py
@@ -31,6 +31,17 @@
 from .context import SimpleContext
 from .sampling import RejectionSamplingStrategy
 
+
+# Sentinel value to detect when a parameter was not explicitly provided
+class _Unset:
+    """Sentinel class to detect unset parameters."""
+
+    def __repr__(self) -> str:
+        return "<UNSET>"
+
+
+_UNSET = _Unset()
+
 # Global context variable for the context session
 _context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar(
     "context_session", default=None
@@ -96,14 +107,15 @@ def backend_name_to_class(name: str) -> Any:
 
 
 def start_session(
-    backend_name: Literal["ollama", "hf", "openai", "watsonx", "litellm"] = "ollama",
-    model_id: str | ModelIdentifier = IBM_GRANITE_4_MICRO_3B,
+    backend_name: Literal["ollama", "hf", "openai", "watsonx", "litellm"]
+    | _Unset = _UNSET,
+    model_id: str | ModelIdentifier | _Unset = _UNSET,
     ctx: Context | None = None,
     *,
     model_options: dict | None = None,
     **backend_kwargs,
 ) -> MelleaSession:
-    """Start a new Mellea session. Can be used as a context manager or called directly.
+    r"""Start a new Mellea session. Can be used as a context manager or called directly.
 
     This function creates and configures a new Mellea session with the specified backend
     and model. When used as a context manager (with `with` statement), it automatically
@@ -111,6 +123,10 @@ def start_session(
     like `instruct()`, `chat()`, `query()`, and `transform()`. When called directly,
     it returns a session object that can be used directly.
 
+    If a configuration file (./mellea.toml) exists in the current directory or any
+    parent directory, it will be loaded and used to set defaults. Explicit parameters
+    override config file values.
+
     Args:
         backend_name: The backend to use. Options are:
             - "ollama": Use Ollama backend for local models
@@ -152,21 +168,84 @@ def start_session(
         session.cleanup()
         ```
     """
+    # Load the configuration file
+    from ..config import apply_credentials_to_env, load_config
+
+    config, config_path = load_config()
+
+    # Apply credentials from config to the environment
+    apply_credentials_to_env(config)
+
     logger = FancyLogger.get_logger()
 
-    backend_class = backend_name_to_class(backend_name)
+    # Apply config values with precedence: explicit params > config > defaults.
+    # Use the sentinel to detect whether parameters were explicitly provided
+    # and resolve them into properly typed variables.
+    resolved_backend: str
+    if isinstance(backend_name, _Unset):
+        # Not explicitly provided - use config or default
+        resolved_backend = config.backend.name if config.backend.name else "ollama"
+    else:
+        resolved_backend = backend_name
+
+    resolved_model: str | ModelIdentifier
+    if isinstance(model_id, _Unset):
+        # Not explicitly provided - use config or default
+        resolved_model = (
+            config.backend.model_id
+            if config.backend.model_id
+            else IBM_GRANITE_4_MICRO_3B
+        )
+    else:
+        resolved_model = model_id
+
+    # Merge model_options: config base (with backend-specific) + explicit overrides
+    merged_model_options = {}
+    # Get config model options merged for the selected backend
+    config_model_options = config.backend.get_model_options_for_backend(
+        resolved_backend
+    )
+    if config_model_options:
+        merged_model_options.update(config_model_options)
+    # Explicit options override config
+    if model_options:
+        merged_model_options.update(model_options)
+    model_options = merged_model_options if merged_model_options else None
+
+    # Merge backend_kwargs: config base + explicit overrides
+    merged_backend_kwargs = {}
+    if config.backend.kwargs:
+        merged_backend_kwargs.update(config.backend.kwargs)
+    merged_backend_kwargs.update(backend_kwargs)
+    backend_kwargs = merged_backend_kwargs
+
+    # Set the log level from config if specified
+    if config.log_level:
+        import logging
+
+        logger.setLevel(getattr(logging, config.log_level.upper(), logging.INFO))
+
+    backend_class = backend_name_to_class(resolved_backend)
     if backend_class is None:
         raise Exception(
-            f"Backend name {backend_name} unknown. Please see the docstring for `mellea.stdlib.session.start_session` for a list of options."
+            f"Backend name {resolved_backend} unknown. Please see the docstring for `mellea.stdlib.session.start_session` for a list of options."
         )
     assert backend_class is not None
 
-    backend = backend_class(model_id, model_options=model_options, **backend_kwargs)
+    backend = backend_class(
+        resolved_model, model_options=model_options, **backend_kwargs
+    )
 
+    # Create the context based on config if not provided
     if ctx is None:
-        ctx = SimpleContext()
+        if config.context_type == "chat":
+            from .context import ChatContext
+
+            ctx = ChatContext()
+        else:
+            ctx = SimpleContext()
 
-    # Log session configuration
-    if isinstance(model_id, ModelIdentifier):
+    # Log session configuration with config source
+    if isinstance(resolved_model, ModelIdentifier):
         # Get the backend-specific model name
         backend_to_attr = {
             "ollama": "ollama_name",
@@ -176,16 +255,21 @@ def start_session(
             "watsonx": "watsonx_name",
             "litellm": "hf_model_name",
         }
-        attr = backend_to_attr.get(backend_name, "hf_model_name")
+        attr = backend_to_attr.get(resolved_backend, "hf_model_name")
         model_id_str = (
-            getattr(model_id, attr, None) or model_id.hf_model_name or str(model_id)
+            getattr(resolved_model, attr, None)
+            or resolved_model.hf_model_name
+            or str(resolved_model)
         )
     else:
-        model_id_str = model_id
+        model_id_str = resolved_model
+
+    config_source = f" (config: {config_path})" if config_path else ""
 
     logger.info(
-        f"Starting Mellea session: backend={backend_name}, model={model_id_str}, "
+        f"Starting Mellea session: backend={resolved_backend}, model={model_id_str}, "
         f"context={ctx.__class__.__name__}"
         + (f", model_options={model_options}" if model_options else "")
+        + config_source
     )
 
     return MelleaSession(backend, ctx)
diff --git a/pyproject.toml b/pyproject.toml
index 0d6290b5..fc4b7143 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
     "math_verify",  # Needed for Majority Voting Sampling Strategies.
     "rouge_score",  # Needed for Majority Voting Sampling Strategies.
     "llm-sandbox[docker]>=0.3.23",
+    "tomli>=2.0.0; python_version < '3.11'",  # TOML parser for Python 3.10
 ]
 
 [project.scripts]
diff --git a/test/backends/test_vision_openai.py b/test/backends/test_vision_openai.py
index bfcfd681..11ac8693 100644
--- a/test/backends/test_vision_openai.py
+++ b/test/backends/test_vision_openai.py
@@ -120,6 +120,7 @@ def test_image_block_in_instruction(
     # image url
     image_url = content_img.get("image_url")
     assert image_url is not None
+    assert isinstance(image_url, dict)
     assert "url" in image_url
-    assert isinstance(image_url, dict)
 
@@ -178,6 +179,7 @@ def test_image_block_in_chat(
     # image url
     image_url = content_img.get("image_url")
     assert image_url is not None
+    assert isinstance(image_url, dict)
     assert "url" in image_url
-    assert isinstance(image_url, dict)
 
diff --git a/test/config/test_config.py b/test/config/test_config.py
new file mode 100644
index 00000000..492de570
--- /dev/null
+++ b/test/config/test_config.py
@@ -0,0 +1,335 @@
+"""Unit tests for Mellea configuration module."""
+
+import os
+from pathlib import Path
+
+import pytest
+
+from mellea.config import (
+    BackendConfig,
+    CredentialsConfig,
+    MelleaConfig,
+    apply_credentials_to_env,
+    clear_config_cache,
+    find_config_file,
+    init_project_config,
+    load_config,
+)
+
+
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear the config cache before each test."""
+    clear_config_cache()
+    yield
+    clear_config_cache()
+
+
+@pytest.fixture
+def temp_project_dir(tmp_path, monkeypatch):
+    """Create a temporary project directory."""
+    project_dir = tmp_path / "project"
+    project_dir.mkdir(parents=True)
+
+    # Change to the project directory
+    original_cwd = Path.cwd()
+    os.chdir(project_dir)
+
+    yield project_dir
+
+    # Restore the original directory
+    os.chdir(original_cwd)
+
+
+class TestConfigModels:
+    """Test Pydantic config models."""
+
+    def test_backend_config_defaults(self):
+        """Test BackendConfig with default values."""
+        config = BackendConfig()
+        assert config.name is None
+        assert config.model_id is None
+        assert config.model_options == {}
+        assert config.kwargs == {}
+
+    def test_backend_config_with_values(self):
+        """Test BackendConfig with explicit values."""
+        config = BackendConfig(
+            name="ollama",
+            model_id="llama3.2:1b",
+            model_options={"temperature": 0.7},
+            kwargs={"base_url": "http://localhost:11434"},
+        )
+        assert config.name == "ollama"
+        assert config.model_id == "llama3.2:1b"
+        assert config.model_options["temperature"] == 0.7
+        assert config.kwargs["base_url"] == "http://localhost:11434"
+
+    def test_credentials_config_defaults(self):
+        """Test CredentialsConfig with default values."""
+        config = CredentialsConfig()
+        assert config.openai_api_key is None
+        assert config.watsonx_api_key is None
+        assert config.watsonx_project_id is None
+        assert config.watsonx_url is None
+
+    def test_mellea_config_defaults(self):
+        """Test MelleaConfig with default values."""
+        config = MelleaConfig()
+        assert isinstance(config.backend, BackendConfig)
+        assert isinstance(config.credentials, CredentialsConfig)
+        assert config.context_type is None
+        assert config.log_level is None
+
+
+class TestConfigDiscovery:
+    """Test configuration file discovery."""
+
+    def test_find_config_file_none(self, temp_project_dir):
+        """Test finding config when none exists."""
+        config_path = find_config_file()
+        assert config_path is None
+
+    def test_find_config_file_project(self, temp_project_dir):
+        """Test finding the project config file."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("[backend]\nname = 'openai'")
+
+        config_path = find_config_file()
+        assert config_path == project_config
+
+    def test_find_config_file_parent_dir(self, temp_project_dir):
+        """Test finding a config in a parent directory."""
+        # Create the config in the project dir
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("[backend]\nname = 'ollama'")
+
+        # Create and cd to a subdirectory
+        subdir = temp_project_dir / "src" / "module"
+        subdir.mkdir(parents=True)
+        os.chdir(subdir)
+
+        config_path = find_config_file()
+        assert config_path == project_config
+
+
+class TestConfigLoading:
+    """Test configuration loading and parsing."""
+
+    def test_load_config_empty(self, temp_project_dir):
+        """Test loading config when no file exists."""
+        config, path = load_config()
+        assert isinstance(config, MelleaConfig)
+        assert path is None
+        assert config.backend.name is None
+
+    def test_load_config_basic(self, temp_project_dir):
+        """Test loading a basic config file."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+
+[backend.model_options]
+temperature = 0.8
+max_tokens = 2048
+""")
+
+        config, path = load_config()
+        assert path == project_config
+        assert config.backend.name == "ollama"
+        assert config.backend.model_id == "llama3.2:1b"
+        assert config.backend.model_options["temperature"] == 0.8
+        assert config.backend.model_options["max_tokens"] == 2048
+
+    def test_load_config_with_credentials(self, temp_project_dir):
+        """Test loading config with credentials."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[credentials]
+openai_api_key = "sk-test123"
+watsonx_api_key = "wx-test456"
+watsonx_project_id = "proj-789"
+""")
+
+        config, _path = load_config()
+        assert config.credentials.openai_api_key == "sk-test123"
+        assert config.credentials.watsonx_api_key == "wx-test456"
+        assert config.credentials.watsonx_project_id == "proj-789"
+
+    def test_load_config_with_general_settings(self, temp_project_dir):
+        """Test loading config with general settings."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+context_type = "chat"
+log_level = "DEBUG"
+""")
+
+        config, _path = load_config()
+        assert config.context_type == "chat"
+        assert config.log_level == "DEBUG"
+
+    def test_load_config_invalid_toml(self, temp_project_dir):
+        """Test that loading invalid TOML raises an error."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("invalid toml [[[")
+
+        with pytest.raises(ValueError, match="Error loading config"):
+            load_config()
+
+    def test_load_config_caching(self, temp_project_dir):
+        """Test that the config is cached after the first load."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("[backend]\nname = 'ollama'")
+
+        # First load
+        config1, path1 = load_config()
+
+        # Second load should return the cached version
+        config2, path2 = load_config()
+
+        assert config1 is config2
+        assert path1 == path2
+
+
+class TestCredentialApplication:
+    """Test credential application to the environment."""
+
+    def test_apply_credentials_to_env(self, monkeypatch):
+        """Test applying credentials to environment variables."""
+        # Clear any existing env vars
+        for key in ["OPENAI_API_KEY", "WATSONX_API_KEY", "WATSONX_PROJECT_ID"]:
+            monkeypatch.delenv(key, raising=False)
+
+        config = MelleaConfig(
+            credentials=CredentialsConfig(
+                openai_api_key="sk-test123",
+                watsonx_api_key="wx-test456",
+                watsonx_project_id="proj-789",
+            )
+        )
+
+        apply_credentials_to_env(config)
+
+        assert os.environ["OPENAI_API_KEY"] == "sk-test123"
+        assert os.environ["WATSONX_API_KEY"] == "wx-test456"
+        assert os.environ["WATSONX_PROJECT_ID"] == "proj-789"
+
+    def test_apply_credentials_respects_existing_env(self, monkeypatch):
+        """Test that existing env vars are not overwritten."""
+        monkeypatch.setenv("OPENAI_API_KEY", "existing-key")
+
+        config = MelleaConfig(credentials=CredentialsConfig(openai_api_key="new-key"))
+
+        apply_credentials_to_env(config)
+
+        # Should keep the existing value
+        assert os.environ["OPENAI_API_KEY"] == "existing-key"
+
+    def test_apply_credentials_skips_none(self, monkeypatch):
+        """Test that None credentials are not set."""
+        for key in ["OPENAI_API_KEY", "WATSONX_API_KEY"]:
+            monkeypatch.delenv(key, raising=False)
+
+        config = MelleaConfig(
+            credentials=CredentialsConfig(openai_api_key=None, watsonx_api_key=None)
+        )
+
+        apply_credentials_to_env(config)
+
+        assert "OPENAI_API_KEY" not in os.environ
+        assert "WATSONX_API_KEY" not in os.environ
+
+
+class TestBackendModelOptionsHierarchy:
+    """Test per-backend model options."""
+
+    def test_generic_options_only(self):
+        """Test that generic options are returned when no backend-specific options exist."""
+        config = BackendConfig(model_options={"temperature": 0.7, "max_tokens": 100})
+        result = config.get_model_options_for_backend("ollama")
+        assert result == {"temperature": 0.7, "max_tokens": 100}
+
+    def test_backend_specific_options(self):
+        """Test that backend-specific options are returned."""
+        config = BackendConfig(
+            model_options={"temperature": 0.7, "ollama": {"num_ctx": 4096}}
+        )
+        result = config.get_model_options_for_backend("ollama")
+        assert result == {"temperature": 0.7, "num_ctx": 4096}
+
+    def test_backend_specific_overrides_generic(self):
+        """Test that backend-specific options override generic options."""
+        config = BackendConfig(
+            model_options={
+                "temperature": 0.7,
+                "ollama": {"temperature": 0.9, "num_ctx": 4096},
+            }
+        )
+        result = config.get_model_options_for_backend("ollama")
+        assert result == {"temperature": 0.9, "num_ctx": 4096}
+
+    def test_different_backends_get_different_options(self):
+        """Test that different backends get their own specific options."""
+        config = BackendConfig(
+            model_options={
+                "temperature": 0.7,
+                "ollama": {"num_ctx": 4096},
+                "openai": {"presence_penalty": 0.5},
+            }
+        )
+        ollama_result = config.get_model_options_for_backend("ollama")
+        openai_result = config.get_model_options_for_backend("openai")
+
+        assert ollama_result == {"temperature": 0.7, "num_ctx": 4096}
+        assert openai_result == {"temperature": 0.7, "presence_penalty": 0.5}
+
+    def test_backend_without_specific_options(self):
+        """Test that a backend without specific options gets only generic options."""
+        config = BackendConfig(
+            model_options={"temperature": 0.7, "ollama": {"num_ctx": 4096}}
+        )
+        result = config.get_model_options_for_backend("openai")
+        assert result == {"temperature": 0.7}
+
+    def test_empty_model_options(self):
+        """Test with empty model options."""
+        config = BackendConfig(model_options={})
+        result = config.get_model_options_for_backend("ollama")
+        assert result == {}
+
+
+class TestConfigInitialization:
+    """Test config file initialization."""
+
+    def test_init_project_config(self, temp_project_dir):
+        """Test creating the project config file."""
+        config_path = init_project_config()
+
+        assert config_path.exists()
+        assert config_path == temp_project_dir / "mellea.toml"
+
+        # Verify the content is valid TOML
+        content = config_path.read_text()
+        assert "[backend]" in content
+
+    def test_init_project_config_exists(self, temp_project_dir):
+        """Test that init fails if the project config exists without force."""
+        config_path = temp_project_dir / "mellea.toml"
+        config_path.write_text("existing")
+
+        with pytest.raises(FileExistsError, match="already exists"):
+            init_project_config(force=False)
+
+    def test_init_project_config_force(self, temp_project_dir):
+        """Test that force overwrites an existing project config."""
+        config_path = temp_project_dir / "mellea.toml"
+        config_path.write_text("existing")
+
+        new_path = init_project_config(force=True)
+
+        assert new_path == config_path
+        content = config_path.read_text()
+        assert "existing" not in content
+        assert "[backend]" in content
diff --git a/test/config/test_config_integration.py b/test/config/test_config_integration.py
new file mode 100644
index 00000000..0c62efef
--- /dev/null
+++ b/test/config/test_config_integration.py
@@ -0,0 +1,270 @@
+"""Integration tests for Mellea configuration with start_session()."""
+
+import os
+from pathlib import Path
+
+import pytest
+
+from mellea.config import clear_config_cache
+from mellea.stdlib.session import start_session
+
+
+@pytest.fixture(autouse=True)
+def clear_cache_and_env():
+    """Clear the config cache and environment variables before each test."""
+    clear_config_cache()
+
+    # Store original env vars
+    original_env = {}
+    env_vars = [
+        "OPENAI_API_KEY",
+        "WATSONX_API_KEY",
+        "WATSONX_PROJECT_ID",
+        "WATSONX_URL",
+    ]
+
+    for var in env_vars:
+        if var in os.environ:
+            original_env[var] = os.environ[var]
+            del os.environ[var]
+
+    yield
+
+    # Restore original env vars
+    for var in env_vars:
+        if var in os.environ:
+            del os.environ[var]
+    for var, value in original_env.items():
+        os.environ[var] = value
+
+    clear_config_cache()
+
+
+@pytest.fixture
+def temp_project_dir(tmp_path, monkeypatch):
+    """Create a temporary project directory."""
+    project_dir = tmp_path / "project"
+    project_dir.mkdir(parents=True)
+
+    # Change to the project directory
+    original_cwd = Path.cwd()
+    os.chdir(project_dir)
+
+    yield project_dir
+
+    # Restore the original directory
+    os.chdir(original_cwd)
+
+
+class TestSessionWithConfig:
+    """Test start_session() with configuration files."""
+
+    @pytest.mark.ollama
+    def test_session_uses_project_config(self, temp_project_dir):
+        """Test that start_session() uses the project config."""
+        # Create the project config
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+
+[backend.model_options]
+temperature = 0.8
+max_tokens = 100
+""")
+
+        # Start a session without explicit parameters
+        with start_session() as session:
+            assert session.backend.model_id == "llama3.2:1b"
+
+    @pytest.mark.ollama
+    def test_session_explicit_overrides_config(self, temp_project_dir):
+        """Test that explicit parameters override the config."""
+        # Create the project config
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+""")
+
+        # Start a session with an explicit model_id
+        with start_session(model_id="granite4:3b") as session:
+            assert session.backend.model_id == "granite4:3b"
+
+    @pytest.mark.ollama
+    def test_session_without_config(self, temp_project_dir):
+        """Test that start_session() works without config files."""
+        # No config files created
+
+        # Start a session with defaults
+        with start_session() as session:
+            # Should use the default backend and model
+            assert session.backend is not None
+            assert session.ctx is not None
+
+    @pytest.mark.ollama
+    def test_session_credentials_from_config(self, temp_project_dir):
+        """Test that credentials from config are applied to the environment."""
+        # Create a project config with credentials
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+
+[credentials]
+openai_api_key = "sk-test-from-config"
+watsonx_api_key = "wx-test-from-config"
+""")
+
+        # Start a session
+        with start_session() as _session:
+            # Credentials should be in the environment
+            assert os.environ.get("OPENAI_API_KEY") == "sk-test-from-config"
+            assert os.environ.get("WATSONX_API_KEY") == "wx-test-from-config"
+
+    @pytest.mark.ollama
+    def test_session_env_overrides_config_credentials(
+        self, temp_project_dir, monkeypatch
+    ):
+        """Test that environment variables override config credentials."""
+        # Set the environment variable
+        monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env")
+
+        # Create a project config with a different credential
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+
+[credentials]
+openai_api_key = "sk-from-config"
+""")
+
+        # Start a session
+        with start_session() as _session:
+            # The environment variable should take precedence
+            assert os.environ.get("OPENAI_API_KEY") == "sk-from-env"
+
+
+class TestConfigPrecedence:
+    """Test configuration precedence in real scenarios."""
+
+    @pytest.mark.ollama
+    def test_explicit_overrides_project(self, temp_project_dir):
+        """Test the complete precedence chain: explicit > project > default."""
+        # Create the project config
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+
+[backend.model_options]
+temperature = 0.7
+""")
+
+        # Test 1: No explicit params - uses the project config
+        with start_session() as session:
+            assert session.backend.model_id == "llama3.2:1b"
+
+        # Test 2: Explicit model_id - overrides the project config
+        with start_session(model_id="granite4:3b") as session:
+            assert session.backend.model_id == "granite4:3b"
+
+
+class TestConfigWithDifferentBackends:
+    """Test configuration with different backend types."""
+
+    @pytest.mark.ollama
+    def test_ollama_backend_from_config(self, temp_project_dir):
+        """Test Ollama backend configuration."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+
+[backend.kwargs]
+base_url = "http://localhost:11434"
+""")
+
+        with start_session() as session:
+            assert session.backend.model_id == "llama3.2:1b"
+
+    @pytest.mark.openai
+    @pytest.mark.requires_api_key
+    def test_openai_backend_from_config(self, temp_project_dir, monkeypatch):
+        """Test OpenAI backend configuration."""
+        # Set the API key in the environment
+        monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key")
+
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "openai"
+model_id = "gpt-3.5-turbo"
+
+[backend.model_options]
+temperature = 0.7
+max_tokens = 100
+""")
+
+        with start_session() as session:
+            assert session.backend.model_id == "gpt-3.5-turbo"
+
+
+class TestConfigCaching:
+    """Test that config caching works correctly with sessions."""
+
+    @pytest.mark.ollama
+    def test_config_cached_across_sessions(self, temp_project_dir):
+        """Test that the config is cached and reused across multiple sessions."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+""")
+
+        # First session
+        with start_session() as session1:
+            model1 = session1.backend.model_id
+
+        # Second session - should use the cached config
+        with start_session() as session2:
+            model2 = session2.backend.model_id
+
+        assert model1 == model2 == "llama3.2:1b"
+
+    @pytest.mark.ollama
+    def test_config_cache_cleared(self, temp_project_dir):
+        """Test that clearing the cache forces a config reload."""
+        project_config = temp_project_dir / "mellea.toml"
+        project_config.write_text("""
+[backend]
+name = "ollama"
+model_id = "llama3.2:1b"
+""")
+
+        # First session
+        with start_session() as session1:
+            model1 = session1.backend.model_id
+
+        # Modify the config
+        project_config.write_text("""
+[backend]
+name = "ollama"
+model_id = "llama3.2:3b"
+""")
+
+        # Clear the cache
+        clear_config_cache()
+
+        # Second session - should reload the config
+        with start_session() as session2:
+            model2 = session2.backend.model_id
+
+        assert model1 == "llama3.2:1b"
+        assert model2 == "llama3.2:3b"
diff --git a/uv.lock b/uv.lock
index 1d8fe940..0f3b15a3 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3392,6 +3392,7 @@ dependencies = [
     { name = "pydantic" },
     { name = "requests" },
     { name = "rouge-score" },
+    { name = "tomli", marker = "python_full_version < '3.11'" },
     { name = "typer" },
     { name = "types-requests" },
     { name = "types-tqdm" },
@@ -3505,6 +3506,7 @@ requires-dist = [
    { name = "pydantic" },
     { name = "requests", specifier = ">=2.32.3" },
     { name = "rouge-score" },
+    { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0" },
     { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53.2,<5" },
     { name = "transformers", marker = "extra == 'vllm'", specifier = "<4.54.0" },
     { name = "trl", marker = "extra == 'hf'", specifier = "==0.19.1" },