diff --git a/crates/flyllm/.cargo-ok b/crates/flyllm/.cargo-ok new file mode 100644 index 0000000..5f8b795 --- /dev/null +++ b/crates/flyllm/.cargo-ok @@ -0,0 +1 @@ +{"v":1} \ No newline at end of file diff --git a/crates/flyllm/.cargo_vcs_info.json b/crates/flyllm/.cargo_vcs_info.json new file mode 100644 index 0000000..5966f42 --- /dev/null +++ b/crates/flyllm/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "d0c37eb8177f64585ededfdc6b57cb3550106464" + }, + "path_in_vcs": "" +} \ No newline at end of file diff --git a/crates/flyllm/.gitignore b/crates/flyllm/.gitignore new file mode 100644 index 0000000..1a6970a --- /dev/null +++ b/crates/flyllm/.gitignore @@ -0,0 +1,4 @@ +/target +Cargo.lock +CLAUDE.md +debug_folder \ No newline at end of file diff --git a/crates/flyllm/CHANGELOG.md b/crates/flyllm/CHANGELOG.md new file mode 100644 index 0000000..e549465 --- /dev/null +++ b/crates/flyllm/CHANGELOG.md @@ -0,0 +1,42 @@ +# Changelog + +All notable changes to FlyLLM will be documented in this file. + +## [0.3.1] - 2025-08-25 +### Added +- Upon request, the conversion from `&str` to ProviderType has been implemented + +## [0.3.0] - 2025-08-06 +### Added +- Refactored the internals of FlyLLM, making it way simpler to modify and understand +- Added optional debugging to LlmManager, allowing the user to store all requests and their metadata to JSON files automatically + +## [0.2.3] - 2025-06-06 +### Added +- Rate limiting with wait for whenever all providers are overloaded + +## [0.2.2] - 2025-05-19 +### Added +- Made the library entirely asynchronous, making the library more suitable for use in async contexts + +## [0.2.1] - 2025-05-12 +### Added +- Capability of listing all available models from all providers + +## [0.2.0] - 2025-04-30 +### Added +- Ollama provider support +- Builder pattern for easier configuration +- Aggregation of more basic routing strategies +- Added optional custom endpoint configuration for any provider + +## [0.1.0] - 2025-04-27 +### Added +- Initial release +- Multiple Provider Support (OpenAI, Anthropic, Google, Mistral) +- Task-Based Routing +- Load Balancing +- Failure Handling +- Parallel Processing +- Custom Parameters +- Usage Tracking \ No newline at end of file diff --git a/crates/flyllm/Cargo.toml b/crates/flyllm/Cargo.toml new file mode 100644 index 0000000..bbca4dd --- /dev/null +++ b/crates/flyllm/Cargo.toml @@ -0,0 +1,91 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +name = "flyllm" +version = "0.3.1" +authors = ["Pablo Rodríguez "] +build = false +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "A Rust library for unifying LLM backends as an abstraction layer with load balancing." +readme = "README.md" +keywords = [ + "llm", + "ai", + "openai", + "anthropic", + "load-balancing", +] +license = "MIT" +repository = "https://github.com/rodmarkun/flyllm" + +[lib] +name = "flyllm" +path = "src/lib.rs" + +[[example]] +name = "task_routing" +path = "examples/task_routing.rs" + +[dependencies.async-trait] +version = "0.1.88" + +[dependencies.env_logger] +version = "0.10" + +[dependencies.futures] +version = "0.3.31" + +[dependencies.json] +version = "0.12.4" + +[dependencies.log] +version = "0.4" + +[dependencies.rand] +version = "0.9.1" + +[dependencies.reqwest] +version = "0.12.15" +features = ["json"] + +[dependencies.serde] +version = "1.0" +features = ["derive"] + +[dependencies.serde_json] +version = "1.0.140" + +[dependencies.tokio] +version = "1" +features = [ + "macros", + "rt-multi-thread", + "sync", +] + +[dependencies.url] +version = "2.5.4" + +[dev-dependencies.tokio] +version = "1" +features = [ + "macros", + "rt-multi-thread", + "sync", +] +[dev-dependencies.httpmock] +version = "0.7" diff --git a/crates/flyllm/Cargo.toml.orig b/crates/flyllm/Cargo.toml.orig new file mode 100644 index 0000000..41bae9d --- /dev/null +++ b/crates/flyllm/Cargo.toml.orig @@ -0,0 +1,29 @@ +[package] +name = "flyllm" +version = "0.3.1" +edition = "2021" +description = "A Rust library for unifying LLM backends as an abstraction layer with load balancing." +authors = ["Pablo Rodríguez "] +license = "MIT" +repository = "https://github.com/rodmarkun/flyllm" +readme = "README.md" +keywords = ["llm", "ai", "openai", "anthropic", "load-balancing"] + +[dependencies] +async-trait = "0.1.88" +futures = "0.3.31" +json = "0.12.4" +reqwest = { version = "0.12.15", features = ["json"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0.140" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } +log = "0.4" +env_logger = "0.10" +rand = "0.9.1" +url = "2.5.4" + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } + +[dev-dependencies] +httpmock = "0.7" diff --git a/crates/flyllm/LICENSE b/crates/flyllm/LICENSE new file mode 100644 index 0000000..86e4229 --- /dev/null +++ b/crates/flyllm/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Pablo Rodríguez Martín + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/flyllm/README.md b/crates/flyllm/README.md new file mode 100644 index 0000000..8909bd1 --- /dev/null +++ b/crates/flyllm/README.md @@ -0,0 +1,335 @@ +# FlyLLM + +FlyLLM is a Rust library that provides a load-balanced, multi-provider client for Large Language Models. It enables developers to seamlessly work with multiple LLM providers (OpenAI, Anthropic, Google, Mistral...) through a unified API with request routing, load balancing, and failure handling. + +
+ FlyLLM Logo +
+ +## Features + +- **Multiple Provider Support** 🌐: Unified interface for OpenAI, Anthropic, Google, Ollama and Mistral APIs +- **Task-Based Routing** 🧭: Route requests to the most appropriate provider based on predefined tasks +- **Load Balancing** ⚖️: Automatically distribute load across multiple provider instances +- **Failure Handling** 🛡️: Retry mechanisms and automatic failover between providers +- **Parallel Processing** ⚡: Process multiple requests concurrently for improved throughput +- **Custom Parameters** 🔧: Set provider-specific parameters per task or request +- **Usage Tracking** 📊: Monitor token consumption for cost management +- **Debug Logging** 🔍: Optional request/response logging to JSON files for debugging and analysis +- **Builder Pattern Configuration** ✨: Fluent and readable setup for tasks and providers. + +## Installation + +Add FlyLLM to your `Cargo.toml`: + +```toml +[dependencies] +flyllm = "0.3.0" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } # For async runtime +``` + +## Architecture + +![Open Escordia_2025-04-25_13-41-55](https://github.com/user-attachments/assets/a56e375b-0bca-4de6-a4d3-c000812105d5) + +The LLM Manager (`LLMManager`) serves as the core component for orchestrating language model interactions in your application. It manages multiple LLM instances (`LLMInstance`), each defined by a model, API key, and supported tasks (`TaskDefinition`). + +When your application sends a generation request (`GenerationRequest`), the manager automatically selects an appropriate instance based on configurable strategies (Last Recently Used, Quickest Response Time, etc.) and returns the generated response by the LLM (`LLMResponse`). This design prevents rate limiting by distributing requests across multiple instances (even of the same model) with different API keys. + +The manager handles failures gracefully by re-routing requests to alternative instances. It also supports parallel execution for significant performance improvements when processing multiple requests simultaneously! + +You can define default parameters (temperature, max_tokens) for each task while retaining the ability to override these settings in specific requests. The system also tracks token usage across all instances: + +``` +--- Token Usage Statistics --- +ID Provider Model Prompt Tokens Completion Tokens Total Tokens +----------------------------------------------------------------------------------------------- +0 mistral mistral-small-latest 109 897 1006 +1 anthropic claude-3-sonnet-20240229 133 1914 2047 +2 anthropic claude-3-opus-20240229 51 529 580 +3 google gemini-2.0-flash 0 0 0 +4 openai gpt-3.5-turbo 312 1003 1315 +``` + +## Usage Examples + +The following sections describe the usage of flyllm. You can also check out the example given in `examples/task_routing.rs`! To activate FlyLLM's debug messages by setting the environment variable `RUST_LOG` to the value `"debug"`. + +### Quick Start + +```rust +use flyllm::{ + ProviderType, LlmManager, GenerationRequest, LlmManagerResponse, TaskDefinition, LlmResult, + use_logging, // Helper to setup basic logging +}; +use std::env; // To read API keys from environment variables + +#[tokio::main] +async fn main() -> LlmResult<()> { // Use LlmResult for error handling + // Initialize logging (optional, requires log and env_logger crates) + use_logging(); + + // Retrieve API key from environment + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + + // Configure the LLM manager using the builder pattern + let manager = LlmManager::builder() + // Define a task with specific default parameters + .define_task( + TaskDefinition::new("summary") + .with_max_tokens(500) // Set max tokens for this task + .with_param("temperature", 0.3) // Set temperature for this task + ) + // Add a provider instance and specify the tasks it supports + .add_instance( + ProviderType::OpenAI, + "gpt-3.5-turbo", + &openai_api_key, // Pass the API key + ) + .supports("summary") // Link the provider to the "summary" task + // Finalize the manager configuration + .build().await?; // Use await and '?' for error propagation + + // Create a generation request using the builder pattern + let request = GenerationRequest::builder( + "Summarize the following text: Climate change refers to long-term shifts in temperatures..." + ) + .task("summary") // Specify the task for routing + .build(); + + // Generate response sequentially (for a single request) + // The Manager will automatically choose the configured OpenAI provider for the "summary" task. + let responses = manager.generate_sequentially(vec![request]).await; + + // Handle the response + if let Some(response) = responses.first() { + if response.success { + println!("Response: {}", response.content); + } else { + println!("Error: {}", response.error.as_ref().unwrap_or(&"Unknown error".to_string())); + } + } + + // Print token usage statistics + manager.print_token_usage().await; + + Ok(()) +} +``` + +### Adding Multiple Providers + +Configure the LlmManager with various providers, each supporting different tasks. + +```rust +use flyllm::{ProviderType, LlmManager, TaskDefinition, LlmResult}; +use std::env; +use std::path::PathBuf; + +async fn configure_manager() -> LlmResult { + // --- API Keys --- + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let anthropic_api_key = env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); + let mistral_api_key = env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set"); + let google_api_key = env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY not set"); + // Ollama typically doesn't require an API key for local instances + + let manager = LlmManager::builder() + // Define all tasks first + .define_task(TaskDefinition::new("summary").with_max_tokens(500).with_param("temperature", 0.3)) + .define_task(TaskDefinition::new("qa").with_max_tokens(1000)) + .define_task(TaskDefinition::new("creative_writing").with_max_tokens(1500).with_temperature(0.9)) + .define_task(TaskDefinition::new("code_generation").with_param("temperature", 0.1)) + .define_task(TaskDefinition::new("translation")) // Task with default provider parameters + + // Add OpenAI provider supporting summary and QA + .add_instance(ProviderType::OpenAI, "gpt-4-turbo", &openai_api_key) + .supports_many(&["summary", "qa"]) // Assign multiple tasks + + // Add Anthropic provider supporting creative writing and code generation + .add_instance(ProviderType::Anthropic, "claude-3-sonnet-20240229", &anthropic_api_key) + .supports("creative_writing") + .supports("code_generation") + + // Add Mistral provider supporting summary and translation + .add_instance(ProviderType::Mistral, "mistral-large-latest", &mistral_api_key) + .supports("summary") + .supports("translation") + + // Add Google (Gemini) provider supporting QA and creative writing + .add_instance(ProviderType::Google, "gemini-1.5-pro", &google_api_key) + .supports("qa") + .supports("creative_writing") + + // Add a local Ollama provider supporting summary and code generation + .add_instance(ProviderType::Ollama, "llama3:8b", "") // API key often empty for local Ollama + .supports("summary") + .supports("code_generation") + .custom_endpoint("http://localhost:11434/api/chat") // Optional: Specify if not default + + // Optional: Enable debug logging to JSON files + .debug_folder(PathBuf::from("debug_logs")) // All request/response data will be logged here + + // Finalize configuration + .build().await?; + + println!("LlmManager configured with multiple providers."); + Ok(manager) +} +``` + +### Task-Based Routing + +Define tasks with specific default parameters and create requests targeting those tasks. FlyLLM routes the request to a provider configured to support that task. + +```rust +use flyllm::{LlmManager, GenerationRequest, TaskDefinition, LlmResult}; +use std::env; + +// Assume manager is configured as shown in "Adding Multiple Providers" +async fn route_by_task(manager: LlmManager) -> LlmResult<()> { + + // Define tasks centrally in the builder (shown conceptually here) + // LlmManager::builder() + // .define_task( + // TaskDefinition::new("summary") + // .with_max_tokens(500) + // .with_temperature(0.3) + // ) + // .define_task( + // TaskDefinition::new("creative_writing") + // .with_max_tokens(1500) + // .with_temperature(0.9) + // ) + // // ... add providers supporting these tasks ... + // .build()?; + + // Create requests with different tasks using the request builder + let summary_request = GenerationRequest::builder( + "Summarize the following article about renewable energy: ..." + ) + .task("summary") // This request will be routed to providers supporting "summary" + .build(); + + let story_request = GenerationRequest::builder( + "Write a short story about a futuristic city powered by algae." + ) + .task("creative_writing") // This request uses the "creative_writing" task defaults + .build(); + + // Example: Override task defaults for a specific request + let short_story_request = GenerationRequest::builder( + "Write a VERY short story about a time traveler meeting a dinosaur." + ) + .task("creative_writing") // Based on "creative_writing" task... + .max_tokens(200) // ...but override max_tokens for this specific request + .param("temperature", 0.95) // Can override any parameter + .build(); + + // Process requests (e.g., sequentially) + let requests = vec![summary_request, story_request, short_story_request]; + let results = manager.generate_sequentially(requests).await; + + // Handle results... + for (i, result) in results.iter().enumerate() { + println!("Request {}: Success = {}, Content/Error = {}", + i + 1, + result.success, + if result.success { &result.content[..std::cmp::min(result.content.len(), 50)] } else { result.error.as_deref().unwrap_or("Unknown") } + ); + } + + Ok(()) +} +``` + +### Parallel Processing + +```rust +// Process in parallel +let parallel_results = manager.batch_generate(requests).await; + +// Process each result +for result in parallel_results { + if result.success { + println!("Success: {}", result.content); + } else { + println!("Error: {}", result.error.as_ref().unwrap_or(&"Unknown error".to_string())); + } +} +``` + +### Debug Logging + +FlyLLM supports optional debug logging to help you analyze requests and responses. When enabled, it creates JSON files with detailed information about each generation call. + +```rust +use flyllm::{ProviderType, LlmManager, GenerationRequest, TaskDefinition, LlmResult}; +use std::path::PathBuf; + +async fn setup_with_debug_logging() -> LlmResult { + let manager = LlmManager::builder() + .define_task(TaskDefinition::new("summary").with_max_tokens(500)) + .add_instance(ProviderType::OpenAI, "gpt-3.5-turbo", &api_key) + .supports("summary") + + // Enable debug logging - creates folder structure: debug_logs/timestamp/instance_id_provider_model/debug.json + .debug_folder(PathBuf::from("debug_logs")) + + .build().await?; + + Ok(manager) +} +``` + +The debug files contain structured JSON with: +- **Metadata**: timestamp, instance details, request duration +- **Input**: prompt, task, parameters used +- **Output**: success status, generated content or error, token usage + +Example debug file structure: +```json +[ + { + "metadata": { + "timestamp": 1703123456, + "instance_id": 0, + "instance_name": "openai", + "instance_model": "gpt-3.5-turbo", + "duration_ms": 1250 + }, + "input": { + "prompt": "Summarize this text...", + "task": "summary", + "parameters": { + "max_tokens": 500, + "temperature": 0.3 + } + }, + "output": { + "success": true, + "content": "This text discusses...", + "usage": { + "prompt_tokens": 45, + "completion_tokens": 123, + "total_tokens": 168 + } + } + } +] +``` + +## License + +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. + +## Contributing + +Contributions are always welcome! If you're interested in contributing to FlyLLM, please fork the repository and create a new branch for your changes. When you're done with your changes, submit a pull request to merge your changes into the main branch. + +## Supporting FlyLLM + +If you want to support FlyLLM, you can: +- **Star** :star: the project in Github! +- **Donate** :coin: to my [Ko-fi](https://ko-fi.com/rodmarkun) page! +- **Share** :heart: the project with your friends! diff --git a/crates/flyllm/examples/task_routing.rs b/crates/flyllm/examples/task_routing.rs new file mode 100644 index 0000000..ef916ad --- /dev/null +++ b/crates/flyllm/examples/task_routing.rs @@ -0,0 +1,278 @@ +use flyllm::{ + use_logging, GenerationRequest, LlmManager, LlmManagerResponse, LlmResult, ModelDiscovery, + ModelInfo, ProviderType, TaskDefinition, +}; +use futures::future::join_all; +use log::info; +use std::collections::HashMap; +use std::env; +use std::path::PathBuf; +use std::time::Instant; + +#[tokio::main] +async fn main() -> LlmResult<()> { + env::set_var("RUST_LOG", "debug"); // Uncomment this to see debugging messages + use_logging(); // Setup logging + + info!("Starting Task Routing Example"); + + // --- API Keys --- + let anthropic_api_key = env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let mistral_api_key = env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set"); + let google_api_key = env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY not set"); + + // --- Fetch and print available models --- + print_available_models( + &anthropic_api_key, + &openai_api_key, + &mistral_api_key, + &google_api_key, + ) + .await; + + // --- Configure Manager using Builder --- + let manager = LlmManager::builder() + // Define tasks centrally + .define_task( + TaskDefinition::new("summary") + .with_max_tokens(500) // Use helper or with_param + .with_param("temperature", 0.3), // Use generic method + ) + .define_task( + TaskDefinition::new("creative_writing") + .with_max_tokens(1500) + .with_temperature(0.9), + ) + .define_task(TaskDefinition::new("code_generation")) + .define_task( + TaskDefinition::new("short_poem") + .with_max_tokens(100) + .with_temperature(0.8), + ) + // Add providers and link tasks by name + // .add_instance(ProviderType::Ollama, "llama2:7b", "") + // .supports("summary") // Chain configuration for this provider + // .supports("code_generation") + // .custom_endpoint("http://localhost:11434/api/chat") // This is the default Ollama endpoint, but we can specify custom ones. + // // .enabled(true) // Optional, defaults to true + .add_instance( + ProviderType::Mistral, + "mistral-large-latest", + &mistral_api_key, + ) + .supports("summary") + .supports("code_generation") + .add_instance( + ProviderType::Anthropic, + "claude-3-sonnet-20240229", + &anthropic_api_key, + ) + .supports("summary") + .supports("creative_writing") + .supports("code_generation") + .add_instance( + ProviderType::Anthropic, + "claude-3-opus-20240229", + &anthropic_api_key, + ) + .supports_many(&["creative_writing", "short_poem"]) // Example using supports_many + .add_instance(ProviderType::Google, "gemini-2.0-flash", &google_api_key) + .supports("short_poem") + .add_instance(ProviderType::OpenAI, "gpt-3.5-turbo", &openai_api_key) + .supports("summary") + // Example: Add a disabled provider + // .add_instance(ProviderType::OpenAI, "gpt-4", &openai_api_key) + // .supports("creative_writing") + // .supports("code_generation") + // .enabled(false) // Explicitly disable + // Adds a debug folder for debugging all requests made + .debug_folder(PathBuf::from("debug_folder")) + // Finalize the manager configuration + .build() + .await?; // Added .await here + + // Get provider count asynchronously + let provider_count = manager.get_provider_count().await; + info!("LlmManager configured with {} providers.", provider_count); + + // --- Define Requests using Builder --- + let requests = vec![ + GenerationRequest::builder( + "Summarize the following text: Climate change refers to long-term shifts...", + ) + .task("summary") + .build(), + GenerationRequest::builder("Write a short story about a robot discovering emotions.") + .task("creative_writing") + .build(), + GenerationRequest::builder( + "Write a Python function that calculates the Fibonacci sequence up to n terms.", + ) + .task("code_generation") + .build(), + // Example overriding parameters for a specific request + GenerationRequest::builder("Write a VERY short poem about the rain.") + .task("creative_writing") // Target creative writing task defaults... + .max_tokens(50) // ...but override max_tokens just for this request + // .param("temperature", 0.95) // Could override temperature too + .build(), + GenerationRequest::builder("Write a rust program to sum two input numbers via console.") + .task("code_generation") + .build(), + GenerationRequest::builder("Craft a haiku about a silent dawn.") + .task("short_poem") + .build(), + ]; + info!("Defined {} requests using builder pattern.", requests.len()); + + // --- Run Requests (Sequential and Parallel) --- + println!("\n=== Running requests sequentially... ==="); + let sequential_start = Instant::now(); + let sequential_results = manager.generate_sequentially(requests.clone()).await; + let sequential_duration = sequential_start.elapsed(); + println!( + "Sequential processing completed in {:?}", + sequential_duration + ); + print_results(&sequential_results); + + println!("\n=== Running requests in parallel... ==="); + let parallel_start = Instant::now(); + let parallel_results = manager.batch_generate(requests).await; // Use original requests vec + let parallel_duration = parallel_start.elapsed(); + println!("Parallel processing completed in {:?}", parallel_duration); + print_results(¶llel_results); + + info!("Task Routing Example Finished."); + + // --- Comparison --- + println!("\n--- Comparison ---"); + println!("Sequential Duration: {:?}", sequential_duration); + println!("Parallel Duration: {:?}", parallel_duration); + + if parallel_duration < sequential_duration && parallel_duration.as_nanos() > 0 { + let speedup = sequential_duration.as_secs_f64() / parallel_duration.as_secs_f64(); + println!("Parallel execution was roughly {:.2}x faster.", speedup); + } else if parallel_duration >= sequential_duration { + println!("Parallel execution was not faster (or was equal) in this run."); + } else { + println!("Parallel execution finished too quickly to measure speedup reliably."); + } + + // Print token usage asynchronously + manager.print_token_usage().await; + + Ok(()) +} + +/// Fetches models from all providers and prints them in a table format +async fn print_available_models( + anthropic_api_key: &str, + openai_api_key: &str, + mistral_api_key: &str, + google_api_key: &str, +) { + println!("\n=== AVAILABLE MODELS ==="); + + // Clone the API keys for use in the spawned tasks + let anthropic_key = anthropic_api_key.to_string(); + let openai_key = openai_api_key.to_string(); + let mistral_key = mistral_api_key.to_string(); + let google_key = google_api_key.to_string(); + + // Fetch models from different providers in parallel + let futures = vec![ + tokio::spawn(async move { ModelDiscovery::list_anthropic_models(&anthropic_key).await }), + tokio::spawn(async move { ModelDiscovery::list_openai_models(&openai_key).await }), + tokio::spawn(async move { ModelDiscovery::list_mistral_models(&mistral_key).await }), + tokio::spawn(async move { ModelDiscovery::list_google_models(&google_key).await }), + tokio::spawn(async { ModelDiscovery::list_ollama_models(None).await }), + ]; + + let results = join_all(futures).await; + + // Create a map to store models by provider + let mut models_by_provider: HashMap> = HashMap::new(); + + // Define the provider order for each index + let providers = [ + ProviderType::Anthropic, + ProviderType::OpenAI, + ProviderType::Mistral, + ProviderType::Google, + ProviderType::Ollama, + ]; + + // Process results + for (i, result) in results.into_iter().enumerate() { + if i >= providers.len() { + continue; + } + let provider = providers[i]; + + match result { + Ok(Ok(models)) => { + models_by_provider.insert(provider, models); + } + Ok(Err(e)) => { + println!("Error fetching {} models: {}", provider, e); + } + Err(e) => { + println!("Task error fetching {} models: {}", provider, e); + } + } + } + + // Print models in a table format + println!("\n{:<15} {:<40}", "PROVIDER", "MODEL NAME"); + println!("{}", "=".repeat(55)); + + // Print models in the specified provider order + for provider in providers.iter() { + if let Some(models) = models_by_provider.get(provider) { + for model in models { + println!("{:<15} {:<40}", provider.to_string(), model.name); + } + // Add a separator between providers + println!("{}", "-".repeat(55)); + } + } +} + +fn print_results(results: &[LlmManagerResponse]) { + println!("\n--- Request Results ---"); + + let task_labels = [ + "Summary Request", + "Creative Writing Request", + "Code Generation Request", + "Short Poem Request (Override)", + "Rust Code Request", + "Haiku Request", + ]; + + for (i, result) in results.iter().enumerate() { + let task_label = task_labels + .get(i) + .map_or_else(|| "Unknown Task", |&name| name); + println!("{}:", task_label); + if result.success { + let content_preview = result.content.chars().take(150).collect::(); + let ellipsis = if result.content.chars().count() > 150 { + "..." + } else { + "" + }; + println!(" Success: {}{}\n", content_preview, ellipsis); + } else { + println!( + " Error: {}\n", + result + .error + .as_ref() + .unwrap_or(&"Unknown error".to_string()) + ); + } + } +} diff --git a/crates/flyllm/src/constants.rs b/crates/flyllm/src/constants.rs new file mode 100644 index 0000000..fe78367 --- /dev/null +++ b/crates/flyllm/src/constants.rs @@ -0,0 +1,28 @@ +/// Common constants used throughout the crate + +// General +pub const DEFAULT_MAX_TOKENS: u32 = 1024; +pub const DEFAULT_MAX_TRIES: usize = 5; + +// OpenAI +pub const OPENAI_API_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions"; + +// Anthropic +pub const ANTHROPIC_API_ENDPOINT: &str = "https://api.anthropic.com/v1/messages"; +pub const ANTHROPIC_API_VERSION: &str = "2023-06-01"; + +// Mistral +pub const MISTRAL_API_ENDPOINT: &str = "https://api.mistral.ai/v1/chat/completions"; + +// Google +pub const GOOGLE_API_ENDPOINT_PREFIX: &str = "https://generativelanguage.googleapis.com"; + +// Ollama +pub const OLLAMA_API_ENDPOINT: &str = "http://localhost:11434/api/chat"; + +// LM Studio +pub const LM_STUDIO_API_ENDPOINT: &str = "http://127.0.0.1:1234/v1/chat/completions"; + +// Rate limiting +pub const DEFAULT_RATE_LIMIT_WAIT_SECS: u64 = 2; +pub const MAX_RATE_LIMIT_WAIT_SECS: u64 = 60; diff --git a/crates/flyllm/src/errors.rs b/crates/flyllm/src/errors.rs new file mode 100644 index 0000000..04bfd00 --- /dev/null +++ b/crates/flyllm/src/errors.rs @@ -0,0 +1,81 @@ +use serde_json; +use std::error::Error; +use std::fmt; + +/// Custom error types for LLM operations +#[derive(Debug)] +pub enum LlmError { + /// Error from the HTTP client + RequestError(reqwest::Error), + /// Error from the API provider + ApiError(String), + /// Rate limiting error + RateLimit(String), + /// Parsing error + ParseError(String), + /// Provider is disabled + ProviderDisabled(String), + /// Configuration error + ConfigError(String), +} + +impl fmt::Display for LlmError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LlmError::RequestError(err) => write!(f, "Request error: {}", err), + LlmError::ApiError(msg) => write!(f, "API error: {}", msg), + LlmError::RateLimit(msg) => write!(f, "Rate limit error: {}", msg), + LlmError::ParseError(msg) => write!(f, "Parse error: {}", msg), + LlmError::ProviderDisabled(provider) => write!(f, "Provider disabled: {}", provider), + LlmError::ConfigError(msg) => write!(f, "Configuration error: {}", msg), + } + } +} + +impl Error for LlmError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + LlmError::RequestError(err) => Some(err), + _ => None, + } + } +} + +/// Convert reqwest errors to LlmError +impl From for LlmError { + fn from(err: reqwest::Error) -> Self { + LlmError::RequestError(err) + } +} + +/// Convert serde_json errors to LlmError +impl From for LlmError { + fn from(err: serde_json::Error) -> Self { + LlmError::ParseError(err.to_string()) + } +} + +/// Result type alias for LLM operations +pub type LlmResult = Result; + +impl LlmError { + /// Returns RateLimit error for 429 status or rate limit keywords + pub fn from_api_response(status: reqwest::StatusCode, error_message: String) -> Self { + if status == reqwest::StatusCode::TOO_MANY_REQUESTS { + return LlmError::RateLimit(error_message); + } + + // Check error message for rate limit indicators + let msg_lower = error_message.to_lowercase(); + if msg_lower.contains("rate limit") + || msg_lower.contains("too many requests") + || msg_lower.contains("quota exceeded") + || msg_lower.contains("overloaded") + || msg_lower.contains("throttle") + { + return LlmError::RateLimit(error_message); + } + + LlmError::ApiError(error_message) + } +} diff --git a/crates/flyllm/src/lib.rs b/crates/flyllm/src/lib.rs new file mode 100644 index 0000000..88a964f --- /dev/null +++ b/crates/flyllm/src/lib.rs @@ -0,0 +1,64 @@ +//! FlyLLM is a Rust library that provides a load-balanced, multi-provider client for Large Language Models. +//! +//! It enables developers to seamlessly work with multiple LLM providers (OpenAI, Anthropic, Google, Mistral...) +//! through a unified API with request routing, load balancing, and failure handling. +//! +//! # Features +//! +//! - **Multi-provider support**: Integrate with OpenAI, Anthropic, Google, and Mistral +//! - **Load balancing**: Distribute requests across multiple providers +//! - **Automatic retries**: Handle provider failures with configurable retry policies +//! - **Task routing**: Route specific tasks to the most suitable providers +//! - **Metrics tracking**: Monitor response times, error rates, and token usage +//! +//! # Example +//! +//! ```ignore +//! use flyllm::{LlmManager, ProviderType, GenerationRequest, TaskDefinition}; +//! +//! async fn example() { +//! // Create a manager +//! let mut manager = LlmManager::new(); +//! +//! // Add providers +//! manager.add_provider( +//! ProviderType::OpenAI, +//! "api-key".to_string(), +//! "gpt-4-turbo".to_string(), +//! vec![], +//! true +//! ); +//! +//! // Generate a response +//! let request = GenerationRequest { +//! prompt: "Explain Rust in one paragraph".to_string(), +//! task: None, +//! params: None, +//! }; +//! +//! let responses = manager.generate_sequentially(vec![request]).await; +//! println!("{}", responses[0].content); +//! } +//! ``` + +pub mod constants; +pub mod errors; +pub mod load_balancer; +pub mod providers; + +pub use providers::{ + create_instance, AnthropicInstance, LlmInstance, LlmRequest, LlmResponse, ModelDiscovery, + ModelInfo, OpenAIInstance, ProviderType, +}; + +pub use errors::{LlmError, LlmResult}; + +pub use load_balancer::{GenerationRequest, LlmManager, LlmManagerResponse, TaskDefinition}; + +/// Initialize the logging system +/// +/// This should be called at the start of your application in case +/// you want to activate the library's debug and info logging. +pub fn use_logging() { + env_logger::init(); +} diff --git a/crates/flyllm/src/load_balancer/builder.rs b/crates/flyllm/src/load_balancer/builder.rs new file mode 100644 index 0000000..a11adb5 --- /dev/null +++ b/crates/flyllm/src/load_balancer/builder.rs @@ -0,0 +1,203 @@ +use super::LlmManager; +use crate::errors::{LlmError, LlmResult}; +use crate::load_balancer::strategies::{LeastRecentlyUsedStrategy, LoadBalancingStrategy}; +use crate::load_balancer::tasks::TaskDefinition; +use crate::{constants, ProviderType}; +use log::debug; +use std::collections::HashMap; +use std::path::PathBuf; + +/// Internal helper struct for Builder +#[derive(Clone)] +struct ProviderConfig { + provider_type: ProviderType, + api_key: String, + model: String, + supported_task_names: Vec, + enabled: bool, + custom_endpoint: Option, +} + +/// LlmManager Builder +pub struct LlmManagerBuilder { + defined_tasks: HashMap, + providers_to_build: Vec, + strategy: Box, + max_retries: usize, + debug_folder: Option, +} + +impl LlmManagerBuilder { + /// Creates a new builder with default settings. + pub fn new() -> Self { + LlmManagerBuilder { + defined_tasks: HashMap::new(), + providers_to_build: Vec::new(), + strategy: Box::new(LeastRecentlyUsedStrategy::new()), // Default strategy + max_retries: constants::DEFAULT_MAX_TRIES, // Default retries + debug_folder: None, + } + } + + /// Defines a task that providers can later reference by name. + pub fn define_task(mut self, task_def: TaskDefinition) -> Self { + self.defined_tasks.insert(task_def.name.clone(), task_def); + self + } + + /// Sets the load balancing strategy for the manager. + pub fn strategy(mut self, strategy: Box) -> Self { + self.strategy = strategy; + self + } + + /// Sets the maximum number of retries for failed requests. + pub fn max_retries(mut self, retries: usize) -> Self { + self.max_retries = retries; + self + } + + /// Begins configuring a new provider instance. + /// Subsequent calls like `.supports()`, `.enabled()`, `.custom_endpoint()` will apply to this provider. + pub fn add_instance( + mut self, + provider_type: ProviderType, + model: impl Into, + api_key: impl Into, + ) -> Self { + let config = ProviderConfig { + provider_type, + api_key: api_key.into(), + model: model.into(), + supported_task_names: Vec::new(), + enabled: true, // Default to enabled + custom_endpoint: None, + }; + self.providers_to_build.push(config); + self // Return self to allow chaining provider configurations + } + + /// Specifies that the *last added* provider supports the task with the given name. + /// Panics if `add_instance` was not called before this. + pub fn supports(mut self, task_name: &str) -> Self { + match self.providers_to_build.last_mut() { + Some(last_provider) => { + if !self.defined_tasks.contains_key(task_name) { + // Optional: Warn or error early if task isn't defined yet + log::warn!("Provider configured to support task '{}' which has not been defined yet with define_task().", task_name); + } + last_provider + .supported_task_names + .push(task_name.to_string()); + } + None => { + panic!("'.supports()' called before '.add_instance()'"); + } + } + self + } + + /// Specifies that the *last added* provider supports multiple tasks by name. + /// Panics if `add_provider` was not called before this. + pub fn supports_many(mut self, task_names: &[&str]) -> Self { + match self.providers_to_build.last_mut() { + Some(last_provider) => { + for task_name in task_names { + if !self.defined_tasks.contains_key(*task_name) { + log::warn!("Provider configured to support task '{}' which has not been defined yet with define_task().", task_name); + } + last_provider + .supported_task_names + .push(task_name.to_string()); + } + } + None => { + panic!("'.supports_many()' called before '.add_provider()'"); + } + } + self + } + + /// Sets the enabled status for the *last added* provider. + /// Panics if `add_provider` was not called before this. + pub fn enabled(mut self, enabled: bool) -> Self { + match self.providers_to_build.last_mut() { + Some(last_provider) => { + last_provider.enabled = enabled; + } + None => { + panic!("'.enabled()' called before '.add_provider()'"); + } + } + self + } + + pub fn debug_folder(mut self, path: impl Into) -> Self { + self.debug_folder = Some(path.into()); + self + } + + /// Sets a custom endpoint for the *last added* provider. + /// Panics if `add_provider` was not called before this. + pub fn custom_endpoint(mut self, endpoint: impl Into) -> Self { + match self.providers_to_build.last_mut() { + Some(last_provider) => { + last_provider.custom_endpoint = Some(endpoint.into()); + } + None => { + panic!("'.custom_endpoint()' called before '.add_provider()'"); + } + } + self + } + + /// Consumes the builder and constructs the `LlmManager`. + /// Returns an error if a referenced task was not defined. + pub async fn build(self) -> LlmResult { + let mut manager = + LlmManager::new_with_strategy_and_retries(self.strategy, self.max_retries); + + // Set debug folder if specified + manager.debug_folder = self.debug_folder; + + for provider_config in self.providers_to_build { + // Resolve TaskDefinition structs from names + let mut provider_tasks: Vec = Vec::new(); + for task_name in &provider_config.supported_task_names { + match self.defined_tasks.get(task_name) { + Some(task_def) => provider_tasks.push(task_def.clone()), + None => return Err(LlmError::ConfigError(format!( + "Build failed: Task '{}' referenced by provider '{}' ({}) was not defined using define_task()", + task_name, provider_config.provider_type, provider_config.model + ))), + } + } + + manager + .add_instance( + provider_config.provider_type, + provider_config.api_key, + provider_config.model.clone(), + provider_tasks, + provider_config.enabled, + provider_config.custom_endpoint, + ) + .await; + debug!( + "Built and added provider: {} ({})", + provider_config.provider_type, provider_config.model + ); + } + + // Check if the manager has instances + let trackers = manager.trackers.lock().await; + let is_empty = trackers.is_empty(); + drop(trackers); + + if is_empty { + log::warn!("LlmManager built with no provider instances."); + } + + Ok(manager) + } +} diff --git a/crates/flyllm/src/load_balancer/manager.rs b/crates/flyllm/src/load_balancer/manager.rs new file mode 100644 index 0000000..2b59545 --- /dev/null +++ b/crates/flyllm/src/load_balancer/manager.rs @@ -0,0 +1,971 @@ +use crate::errors::{LlmError, LlmResult}; +use crate::load_balancer::builder::LlmManagerBuilder; +use crate::load_balancer::strategies; +use crate::load_balancer::tasks::TaskDefinition; +use crate::load_balancer::tracker::InstanceTracker; +use crate::load_balancer::utils::{get_debug_path, write_to_debug_file}; +use crate::providers::{LlmInstance, LlmRequest, Message, TokenUsage}; +use crate::{constants, create_instance, ProviderType}; // TODO - ? +use futures::future::join_all; +use log::{debug, info, warn}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use tokio::sync::Mutex; + +/// User-facing request for LLM generation +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct GenerationRequest { + pub prompt: String, // Prompt for the LLM + pub task: Option, // Task to route for + pub params: Option>, // Extra parameters +} + +impl Default for GenerationRequest { + fn default() -> Self { + Self { + prompt: String::new(), + task: None, + params: None, + } + } +} + +impl GenerationRequest { + // Standard Constructor + pub fn new(prompt: String) -> Self { + GenerationRequest { + prompt, + ..Default::default() + } + } + + /// Creates a builder for a GenerationRequest. + pub fn builder(prompt: impl Into) -> GenerationRequest { + GenerationRequest::new(prompt.into()) + } + + /// Sets the target task for this request. + pub fn task(mut self, name: impl Into) -> Self { + self.task = Some(name.into()); + self + } + + /// Adds or overrides a parameter specifically for this request. + pub fn param(mut self, key: impl Into, value: impl Into) -> Self { + self.params + .get_or_insert_with(HashMap::new) + .insert(key.into(), value.into()); + self + } + + /// Sets max tokens for this generation in specific + pub fn max_tokens(self, tokens: u32) -> Self { + self.param("max_tokens", json!(tokens)) + } + + /// Finalizes the GenerationRequest + pub fn build(self) -> Self { + self + } +} + +/// Internal request structure with additional retry information +#[derive(Clone)] +struct LlmManagerRequest { + pub prompt: String, + pub task: Option, + pub params: Option>, + pub attempts: usize, + pub failed_instances: Vec, +} + +impl LlmManagerRequest { + /// Convert a user-facing GenerationRequest to internal format + fn from_generation_request(request: GenerationRequest) -> Self { + Self { + prompt: request.prompt, + task: request.task, + params: request.params, + attempts: 0, + failed_instances: Vec::new(), + } + } +} + +/// Response structure returned to users +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct LlmManagerResponse { + pub content: String, + pub success: bool, + pub error: Option, +} + +/// Main manager for LLM providers that handles load balancing and retries +/// +/// The LlmManager: +/// - Manages multiple LLM instances ( from different providers) +/// - Maps tasks to compatible instances +/// - Routes requests to appropriate instances +/// - Implements retries and fallbacks +/// - Tracks performance metrics and token usage +pub struct LlmManager { + pub trackers: Arc>>, // Current instance trackers in the manager (contains the instances themselves) + pub strategy: Arc>>, // Current strategy for load balancing being used + pub tasks_to_instances: Arc>>>, // Map of which instances handle which tasks + pub instance_counter: Mutex, // Used for giving unique IDs to each instance in this manager + pub max_retries: usize, // Controls how many times a failed request will be tried before giving up + pub total_usage: Mutex>, // Token usage of each instance + pub debug_folder: Option, // Path where JSONs with debug inputs/outputs of each model will be stored + pub creation_time: SystemTime, +} + +impl LlmManager { + /// Create a new LlmManager with default settings + pub fn new() -> Self { + Self { + trackers: Arc::new(Mutex::new(HashMap::new())), + strategy: Arc::new(Mutex::new(Box::new( + strategies::LeastRecentlyUsedStrategy::new(), + ))), + tasks_to_instances: Arc::new(Mutex::new(HashMap::new())), + instance_counter: Mutex::new(0), + max_retries: constants::DEFAULT_MAX_TRIES, + total_usage: Mutex::new(HashMap::new()), + debug_folder: None, + creation_time: SystemTime::now(), + } + } + + /// Creates a new builder to configure the LlmManager. + pub fn builder() -> LlmManagerBuilder { + LlmManagerBuilder::new() + } + + /// Create a new LlmManager with a custom load balancing strategy + /// + /// # Parameters + /// * `strategy` - The load balancing strategy to use + pub fn new_with_strategy( + strategy: Box, + ) -> Self { + Self { + trackers: Arc::new(Mutex::new(HashMap::new())), + strategy: Arc::new(Mutex::new(strategy)), + tasks_to_instances: Arc::new(Mutex::new(HashMap::new())), + instance_counter: Mutex::new(0), + max_retries: constants::DEFAULT_MAX_TRIES, + total_usage: Mutex::new(HashMap::new()), + debug_folder: None, + creation_time: SystemTime::now(), + } + } + + /// Constructor used by the builder. + pub fn new_with_strategy_and_retries( + strategy: Box, + max_retries: usize, + ) -> Self { + Self { + trackers: Arc::new(Mutex::new(HashMap::new())), + strategy: Arc::new(Mutex::new(strategy)), + tasks_to_instances: Arc::new(Mutex::new(HashMap::new())), + instance_counter: Mutex::new(0), + max_retries, // Use passed value + total_usage: Mutex::new(HashMap::new()), + debug_folder: None, + creation_time: SystemTime::now(), + } + } + + /// Adds a new LLM instance by creating it from basic parameters + /// + /// # Parameters + /// * `provider_type` - Which LLM provider to use (Anthropic, OpenAI, etc) + /// * `api_key` - API key for the provider + /// * `model` - Model identifier to use + /// * `tasks` - List of tasks this provider supports + /// * `enabled` - Whether this provider should be enabled + /// * `custom_endpont` - Optional specification on where the requests for this instance should go + pub async fn add_instance( + &mut self, + provider_type: ProviderType, + api_key: String, + model: String, + tasks: Vec, + enabled: bool, + custom_endpoint: Option, + ) { + debug!("Creating provider with model {}", model); + let instance = create_instance( + provider_type, + api_key, + model.clone(), + tasks.clone(), + enabled, + custom_endpoint, + ); + self.add_instance_to_manager(instance).await; + info!( + "Added Provider Instance ({}) - Model: {} - Supports Tasks: {:?}", + provider_type, + model, + tasks.iter().map(|t| t.name.as_str()).collect::>() + ); + } + + /// Add a pre-created provider instance + /// + /// # Parameters + /// * `provider` - The provider instance to add + pub async fn add_instance_to_manager(&mut self, instance: Arc) { + let id = { + let mut counter = self.instance_counter.lock().await; + let current_id = *counter; + *counter += 1; + current_id + }; + + let tracker = InstanceTracker::new(instance.clone()); + debug!("Adding instance {} ({})", id, instance.get_name()); + + let supported_tasks_names: Vec = + instance.get_supported_tasks().keys().cloned().collect(); + + { + let mut task_map = self.tasks_to_instances.lock().await; + for task_name in &supported_tasks_names { + task_map + .entry(task_name.clone()) + .or_insert_with(Vec::new) + .push(id); + debug!("Added instance {} to task mapping for '{}'", id, task_name); + } + } + + { + let mut trackers = self.trackers.lock().await; + trackers.insert(id, tracker); + } + + { + let mut usage_map = self.total_usage.lock().await; + usage_map.insert(id, TokenUsage::default()); // TODO - Implement default + } + } + + /// Set a new load balancing strategy + /// + /// # Parameters + /// * `strategy` - The new load balancing strategy to use + pub async fn set_strategy( + &mut self, + strategy: Box, + ) { + let mut current_strategy = self.strategy.lock().await; + *current_strategy = strategy; + } + + /// Process multiple requests sequentially + /// + /// # Parameters + /// * `requests` - List of generation requests to process + /// + /// # Returns + /// * List of responses in the same order as the requests + pub async fn generate_sequentially( + &self, + requests: Vec, + ) -> Vec { + let mut responses = Vec::with_capacity(requests.len()); + info!( + "Entering generate_sequentially with {} requests", + requests.len() + ); + + for (index, request) in requests.into_iter().enumerate() { + info!("Starting sequential request index: {}", index); + let internal_request = LlmManagerRequest::from_generation_request(request); + + let response_result = self.generate_response(internal_request, None).await; + info!( + "Sequential request index {} completed generate_response call.", + index + ); + + let response = match response_result { + Ok(content) => { + info!("Sequential request index {} succeeded.", index); + LlmManagerResponse { + content, + success: true, + error: None, + } + } + Err(e) => { + warn!("Sequential request index {} failed: {}", index, e); + LlmManagerResponse { + content: String::new(), + success: false, + error: Some(e.to_string()), + } + } + }; + + debug!("Pushing response for sequential request index {}", index); + responses.push(response); + info!("Finished processing sequential request index {}", index); + } + + info!("Exiting generate_sequentially"); + responses + } + + /// Process multiple requests in parallel + /// + /// # Parameters + /// * `requests` - List of generation requests to process + /// + /// # Returns + /// * List of responses in the same order as the requests + pub async fn batch_generate( + &self, + requests: Vec, + ) -> Vec { + info!("Entering batch_generate with {} requests", requests.len()); + let internal_requests = requests + .into_iter() + .map(|request| LlmManagerRequest::from_generation_request(request)) + .collect::>(); + + let futures = internal_requests + .into_iter() + .enumerate() + .map(|(index, request)| async move { + info!("Starting parallel request index: {}", index); + match self.generate_response(request, None).await { + Ok(content) => { + info!("Parallel request index {} succeeded.", index); + LlmManagerResponse { + content, + success: true, + error: None, + } + } + Err(e) => { + warn!("Parallel request index {} failed: {}", index, e); + LlmManagerResponse { + content: String::new(), + success: false, + error: Some(e.to_string()), + } + } + } + }) + .collect::>(); + + let results = join_all(futures).await; + info!("Exiting batch_generate"); + results + } + + /// Core function to generate a response with retries + /// + /// # Parameters + /// * `request` - The internal request with retry state + /// * `max_attempts` - Optional override for maximum retry attempts + /// + /// # Returns + /// * Result with either the generated content or an error + async fn generate_response( + &self, + request: LlmManagerRequest, + max_attempts: Option, + ) -> LlmResult { + let start_time = Instant::now(); + let mut attempts = request.attempts; + let mut failed_instances = request.failed_instances.clone(); + let prompt_preview = request.prompt.chars().take(50).collect::(); + let task = request.task.as_deref(); + let request_params = request.params.clone(); + let max_retries = max_attempts.unwrap_or(self.max_retries); + + info!( + "generate_response called for task: {:?}, prompt: '{}...'", + task, prompt_preview + ); + + while attempts <= max_retries { + debug!( + "Attempt {} of {} for request (task: {:?})", + attempts + 1, + max_retries + 1, + task + ); + + let attempt_result = self + .instance_selection( + &request.prompt, + task, + request_params.clone(), + &failed_instances, + ) + .await; + + match attempt_result { + Ok((content, instance_id)) => { + let duration = start_time.elapsed(); + info!( + "Request successful on attempt {} with instance {} after {:?}", + attempts + 1, + instance_id, + duration + ); + return Ok(content); + } + Err((error, instance_id)) => { + warn!( + "Attempt {} failed with instance {}: {}", + attempts + 1, + instance_id, + error + ); + + // Check if this is a rate limit error + if matches!(error, LlmError::RateLimit(_)) { + warn!( + "Rate limit detected for instance {}. Waiting before retry...", + instance_id + ); + + // Wait before retrying (exponential backoff) + let wait_time = + std::time::Duration::from_secs(2_u64.pow(attempts as u32).min(60)); + tokio::time::sleep(wait_time).await; + + // Don't mark this instance as failed for rate limits + // Just increment attempts and try again + attempts += 1; + } else { + // For non-rate-limit errors, mark instance as failed + failed_instances.push(instance_id); + attempts += 1; + } + + if attempts > max_retries { + warn!( + "Max retries ({}) reached for task: {:?}. Returning last error.", + max_retries + 1, + task + ); + return Err(error); + } + + debug!( + "Retrying with next eligible instance for task: {:?}...", + task + ); + } + } + } + + warn!("Exited retry loop unexpectedly for task: {:?}", task); + Err(LlmError::ConfigError( + "No available providers after all retry attempts".to_string(), + )) + } + + /// Select an appropriate instance and execute the request + /// + /// This function: + /// 1. Identifies instances that support the requested task + /// 2. Filters out failed and disabled instances + /// 3. Uses the load balancing strategy to select an instance + /// 4. Merges task and request parameters + /// 5. Executes the request against the selected provider + /// 6. Updates metrics based on the result + /// + /// # Parameters + /// * `prompt` - The prompt text to send + /// * `task` - Optional task identifier + /// * `request_params` - Optional request parameters + /// * `failed_instances` - List of instance IDs that have failed + /// + /// # Returns + /// * Success: (generated content, instance ID) + /// * Error: (error, instance ID that failed) + async fn instance_selection( + &self, + prompt: &str, + task: Option<&str>, + request_params: Option>, + failed_instances: &[usize], + ) -> Result<(String, usize), (LlmError, usize)> { + debug!( + "instance_selection: Starting selection for task: {:?}", + task + ); + + // 1. Get candidate instance IDs based on task (if any) + let candidate_ids: Option> = match task { + Some(task_name) => { + let task_map = self.tasks_to_instances.lock().await; + task_map.get(task_name).cloned() + } + None => None, // No specific task, consider all instances initially + }; + + if task.is_some() && candidate_ids.is_none() { + warn!("No instances found supporting task: '{}'", task.unwrap()); + debug!("instance_selection returning Err (no task support)"); + return Err(( + LlmError::ConfigError(format!( + "No providers available for task: {}", + task.unwrap() + )), + 0, + )); + } + + // 2. Filter candidates by availability and collect all needed data in one go + let eligible_instances_data: Vec<( + usize, + String, + Arc, + Option, + )>; + + // Get eligible instance IDs for strategy selection + let eligible_instance_ids: Vec; + + // Scope the lock to ensure it's dropped before strategy selection + { + let trackers_guard = self.trackers.lock().await; + debug!("instance_selection: Acquired trackers lock (1st time)"); + + if trackers_guard.is_empty() { + warn!("No LLM providers configured."); + return Err(( + LlmError::ConfigError("No LLM providers available".to_string()), + 0, + )); + } + + // Extract all the data we need while holding the lock + match candidate_ids { + Some(ids) => { + debug!( + "Filtering instances for task '{}' using IDs: {:?}", + task.unwrap(), + ids + ); + eligible_instances_data = trackers_guard + .iter() + .filter(|(id, tracker)| { + ids.contains(id) + && tracker.is_enabled() + && !failed_instances.contains(id) + }) + .map(|(id, tracker)| { + let task_def = task.and_then(|t| { + tracker.instance.get_supported_tasks().get(t).cloned() + }); + ( + *id, + tracker.instance.get_name().to_string(), + tracker.instance.clone(), + task_def, + ) + }) + .collect(); + debug!( + "Found {} eligible instances for task '{}'", + eligible_instances_data.len(), + task.unwrap() + ); + } + None => { + debug!("No specific task. Filtering all enabled instances."); + eligible_instances_data = trackers_guard + .iter() + .filter(|(id, tracker)| { + tracker.is_enabled() && !failed_instances.contains(id) + }) + .map(|(id, tracker)| { + let task_def = task.and_then(|t| { + tracker.instance.get_supported_tasks().get(t).cloned() + }); + ( + *id, + tracker.instance.get_name().to_string(), + tracker.instance.clone(), + task_def, + ) + }) + .collect(); + debug!( + "Found {} eligible instances (no task)", + eligible_instances_data.len() + ); + } + } + + // Extract just the IDs for strategy selection + eligible_instance_ids = eligible_instances_data + .iter() + .map(|(id, _, _, _)| *id) + .collect(); + + // No eligible instances check + if eligible_instances_data.is_empty() { + let error_msg = format!( + "No enabled providers available{}{}", + task.map_or_else(|| "".to_string(), |t| format!(" for task: '{}'", t)), + if !failed_instances.is_empty() { + format!(" (excluded {} failed instances)", failed_instances.len()) + } else { + "".to_string() + } + ); + warn!("{}", error_msg); + return Err((LlmError::ConfigError(error_msg), 0)); + } + } + + // 5. Select instance using strategy (need to re-acquire lock for metrics) + let selected_instance_id = { + let trackers_guard = self.trackers.lock().await; + let mut strategy = self.strategy.lock().await; + debug!("instance_selection: Acquired strategy and trackers locks"); + + // Build the trackers slice for the strategy + let eligible_trackers: Vec<(usize, &InstanceTracker)> = eligible_instance_ids + .iter() + .filter_map(|id| trackers_guard.get(id).map(|tracker| (*id, tracker))) + .collect(); + + let selected_metric_index = strategy.select_instance(&eligible_trackers); + let selected_id = eligible_trackers[selected_metric_index].0; + + debug!("instance_selection: Released strategy lock"); + selected_id + }; + + // Find the corresponding instance in our extracted data + let selected_instance = eligible_instances_data + .iter() + .find(|(id, _, _, _)| *id == selected_instance_id) + .expect("Selected instance ID from metrics not found in eligible list - LOGIC ERROR!"); + + // Unpack the tuple + let (selected_id, selected_name, selected_provider_arc, task_def) = ( + selected_instance.0, + &selected_instance.1, + &selected_instance.2, + &selected_instance.3, + ); + + debug!( + "Selected instance {} ({}) for the request.", + selected_id, selected_name + ); + + // 6. Merge parameters + let mut final_params = HashMap::new(); + if let Some(task_def) = task_def { + final_params.extend(task_def.parameters.clone()); + debug!("Applied parameters from task for instance {}", selected_id); + } + + if let Some(req_params) = request_params { + final_params.extend(req_params); + debug!( + "Applied request-specific parameters for instance {}", + selected_id + ); + } + + // Create and execute the request + let max_tokens = final_params + .get("max_tokens") + .and_then(|v| v.as_u64()) + .map(|v| v as u32); + + let temperature = final_params + .get("temperature") + .and_then(|v| v.as_f64()) + .map(|v| v as f32); + + let request = LlmRequest { + messages: vec![Message { + role: "user".to_string(), + content: prompt.to_string(), + }], + model: None, // Let provider use its configured model + max_tokens, + temperature, + }; + + debug!( + "Instance {} ({}) sending request to provider...", + selected_id, selected_name + ); + let start_time = Instant::now(); + let result = selected_provider_arc.generate(&request).await; + let duration = start_time.elapsed(); + info!( + "Instance {} ({}) received result in {:?}", + selected_id, selected_name, duration + ); + + // Update metrics regardless of success or failure + { + debug!("instance_selection: Attempting to acquire trackers lock (2nd time) for metrics update"); + let mut trackers_guard = self.trackers.lock().await; + debug!("instance_selection: Acquired trackers lock (2nd time)"); + if let Some((_id, instance_tracker)) = trackers_guard + .iter_mut() + .find(|(id, _tracker)| **id == selected_id) + { + debug!("Recording result for instance {}", selected_id); + instance_tracker.record_result(duration, &result); + debug!("Finished recording result for instance {}", selected_id); + } else { + warn!( + "Instance {} not found for metric update after request completion.", + selected_id + ); + } + debug!("instance_selection: Releasing trackers lock (2nd time) after metrics update"); + // Lock released when trackers_guard goes out of scope here + } + + // Write debug information if debug folder is configured + self.write_debug_info( + selected_id, + selected_name, + &selected_provider_arc.get_model(), + prompt, + task, + &final_params, + &result, + duration, + ) + .await; + + // Return either content or error with the instance ID + match result { + Ok(response) => { + if let Some(usage) = &response.usage { + self.update_instance_usage(selected_id, usage).await; + debug!( + "Updated token usage for instance {}: {:?}", + selected_id, usage + ); + } + debug!( + "instance_selection returning Ok for instance {}", + selected_id + ); + Ok((response.content, selected_id)) + } + Err(e) => { + debug!( + "instance_selection returning Err for instance {}: {}", + selected_id, e + ); + Err((e, selected_id)) + } + } + } + + /// Write debug information for a request/response to the debug folder + async fn write_debug_info( + &self, + instance_id: usize, + instance_name: &str, + instance_model: &str, + prompt: &str, + task: Option<&str>, + final_params: &HashMap, + result: &Result, + duration: std::time::Duration, + ) { + if let Some(debug_folder) = &self.debug_folder { + let timestamp = self + .creation_time + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let debug_path = get_debug_path( + debug_folder, + timestamp, + instance_id, + instance_name, + instance_model, + ); + + // Create the new generation entry + let generation_entry = json!({ + "metadata": { + "timestamp": SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs(), + "instance_id": instance_id, + "instance_name": instance_name, + "instance_model": instance_model, + "duration_ms": duration.as_millis() + }, + "input": { + "prompt": prompt, + "task": task, + "parameters": final_params + }, + "output": match result { + Ok(response) => json!({ + "success": true, + "content": response.content, + "usage": response.usage + }), + Err(error) => json!({ + "success": false, + "error": error.to_string() + }) + } + }); + + // Read existing file or create new array + let mut generations: Vec = if debug_path.exists() { + match fs::read_to_string(&debug_path) { + Ok(content) => match serde_json::from_str::>(&content) { + Ok(array) => array, + Err(e) => { + warn!("Failed to parse existing debug file as JSON array, creating new: {}", e); + Vec::new() + } + }, + Err(e) => { + warn!("Failed to read existing debug file, creating new: {}", e); + Vec::new() + } + } + } else { + Vec::new() + }; + + // Append new generation + generations.push(generation_entry); + + // Write updated array back to file + let json_string = match serde_json::to_string_pretty(&generations) { + Ok(s) => s, + Err(e) => { + warn!("Failed to serialize debug data: {}", e); + return; + } + }; + + if let Err(e) = write_to_debug_file(&debug_path, &json_string) { + warn!("Failed to write debug file: {}", e); + } + } + } + + /// Update token usage for a specific instance + /// + /// # Parameters + /// * `instance_id` - ID of the instance to update + /// * `usage` - The token usage to add + async fn update_instance_usage(&self, instance_id: usize, usage: &TokenUsage) { + let mut usage_map = self.total_usage.lock().await; + + let instance_usage = usage_map.entry(instance_id).or_insert(TokenUsage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }); + + instance_usage.prompt_tokens += usage.prompt_tokens; + instance_usage.completion_tokens += usage.completion_tokens; + instance_usage.total_tokens += usage.total_tokens; + + debug!( + "Updated usage for instance {}: current total is {} tokens", + instance_id, instance_usage.total_tokens + ); + } + + /// Get token usage for a specific instance + /// + /// # Parameters + /// * `instance_id` - ID of the instance to query + /// + /// # Returns + /// * Token usage for the specified instance, if found + pub async fn get_instance_usage(&self, instance_id: usize) -> Option { + let usage_map = self.total_usage.lock().await; + usage_map.get(&instance_id).cloned() + } + + /// Get total token usage across all instances + /// + /// # Returns + /// * Combined token usage statistics + pub async fn get_total_usage(&self) -> TokenUsage { + let usage_map = self.total_usage.lock().await; + + usage_map.values().fold( + TokenUsage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + |mut acc, usage| { + acc.prompt_tokens += usage.prompt_tokens; + acc.completion_tokens += usage.completion_tokens; + acc.total_tokens += usage.total_tokens; + acc + }, + ) + } + + /// Get the number of configured provider instances + /// + /// # Returns + /// * Number of provider instances in the manager + pub async fn get_provider_count(&self) -> usize { + let trackers = self.trackers.lock().await; + trackers.len() + } + + /// Print token usage statistics to console + pub async fn print_token_usage(&self) { + println!("\n--- Token Usage Statistics ---"); + println!( + "{:<5} {:<15} {:<30} {:<15} {:<15} {:<15}", + "ID", "Provider", "Model", "Prompt Tokens", "Completion Tokens", "Total Tokens" + ); + println!("{}", "-".repeat(95)); + + let trackers = self.trackers.lock().await; + let usage_map = self.total_usage.lock().await; + + // Print usage for each instance + for (instance_id, tracker) in trackers.iter() { + if let Some(usage) = usage_map.get(instance_id) { + println!( + "{:<5} {:<15} {:<30} {:<15} {:<15} {:<15}", + instance_id, + tracker.instance.get_name(), + tracker.instance.get_model(), + usage.prompt_tokens, + usage.completion_tokens, + usage.total_tokens + ); + } + } + } +} diff --git a/crates/flyllm/src/load_balancer/mod.rs b/crates/flyllm/src/load_balancer/mod.rs new file mode 100644 index 0000000..14cacb2 --- /dev/null +++ b/crates/flyllm/src/load_balancer/mod.rs @@ -0,0 +1,17 @@ +pub mod builder; +pub mod manager; +pub mod strategies; +pub mod tasks; +/// Load balancer module for distributing requests across multiple LLM providers +/// +/// This module contains components for: +/// - Managing provider instances with metrics tracking +/// - Selecting appropriate providers based on tasks and load +/// - Implementing different load balancing strategies +/// - Handling retries and fallbacks when providers fail +/// - Tracking token usage across providers +pub mod tracker; +pub mod utils; + +pub use manager::{GenerationRequest, LlmManager, LlmManagerResponse}; +pub use tasks::TaskDefinition; diff --git a/crates/flyllm/src/load_balancer/strategies.rs b/crates/flyllm/src/load_balancer/strategies.rs new file mode 100644 index 0000000..262d7dc --- /dev/null +++ b/crates/flyllm/src/load_balancer/strategies.rs @@ -0,0 +1,154 @@ +use log::debug; +use rand::Rng; + +use crate::load_balancer::tracker::InstanceTracker; + +/// Trait defining the interface for load balancing strategies +/// +/// Implementations of this trait determine how to select which LLM instance +/// will handle a particular request based on instance metrics. +pub trait LoadBalancingStrategy { + /// Select an instance from available candidates + /// + /// # Parameters + /// * `trackers` - Array of (id, tracker) tuples for available instances + /// + /// # Returns + /// * Index into the trackers array of the selected instance + fn select_instance(&mut self, trackers: &[(usize, &InstanceTracker)]) -> usize; +} + +/// Strategy that selects the instance that was used least recently +/// +/// This strategy helps distribute load by prioritizing instances +/// that haven't been used in the longest time. +pub struct LeastRecentlyUsedStrategy; + +impl LeastRecentlyUsedStrategy { + /// Creates a new LeastRecentlyUsedStrategy + pub fn new() -> Self { + Self {} + } +} + +impl LoadBalancingStrategy for LeastRecentlyUsedStrategy { + /// Select the instance with the oldest last_used timestamp + /// + /// # Parameters + /// * `trackers` - Array of (id, tracker) tuples for available instances + /// + /// # Returns + /// * Index into the trackers array of the least recently used instance + /// + /// # Panics + /// Panics if `trackers` is empty + fn select_instance(&mut self, trackers: &[(usize, &InstanceTracker)]) -> usize { + if trackers.is_empty() { + panic!("LoadBalancingStrategy::select_instance called with empty trackers slice"); + } + + let mut oldest_index = 0; + let mut oldest_time = trackers[0].1.last_used; + + for (i, (_id, tracker)) in trackers.iter().enumerate().skip(1) { + if tracker.last_used < oldest_time { + oldest_index = i; + oldest_time = tracker.last_used; + } + } + + debug!( + "LeastRecentlyUsedStrategy: Selected index {} (ID: {}) from {} eligible trackers with last_used: {:?}", + oldest_index, trackers[oldest_index].0, trackers.len(), oldest_time + ); + + oldest_index + } +} + +/// Strategy that selects the instance with the lowest average response time. +#[derive(Debug, Default)] +pub struct LowestLatencyStrategy; + +impl LowestLatencyStrategy { + /// Creates a new LowestLatencyStrategy + pub fn new() -> Self { + Self {} + } +} + +impl LoadBalancingStrategy for LowestLatencyStrategy { + /// Select the instance with the minimum `avg_response_time`. + /// + /// # Parameters + /// * `trackers` - Array of (id, tracker) tuples for available instances. + /// + /// # Returns + /// * Index into the trackers array of the fastest instance. + /// + /// # Panics + /// * Panics if `trackers` is empty. + fn select_instance(&mut self, trackers: &[(usize, &InstanceTracker)]) -> usize { + if trackers.is_empty() { + panic!("LowestLatencyStrategy::select_instance called with empty trackers slice"); + } + + let mut best_index = 0; + let mut lowest_time = trackers[0].1.avg_response_time(); + + for (i, (_id, tracker)) in trackers.iter().enumerate().skip(1) { + let avg_time = tracker.avg_response_time(); + if avg_time < lowest_time { + best_index = i; + lowest_time = avg_time; + } + } + + debug!( + "LowestLatencyStrategy: Selected index {} (ID: {}) from {} eligible trackers with avg_response_time: {:?}", + best_index, trackers[best_index].0, trackers.len(), lowest_time + ); + + best_index + } +} + +/// Strategy that selects a random instance from the available pool. +#[derive(Debug, Default)] +pub struct RandomStrategy; + +impl RandomStrategy { + /// Creates a new RandomStrategy + pub fn new() -> Self { + Self {} + } +} + +impl LoadBalancingStrategy for RandomStrategy { + /// Select a random instance. + /// + /// # Parameters + /// * `trackers` - Array of (id, tracker) tuples for available instances. + /// + /// # Returns + /// * Index into the trackers array of a randomly chosen instance. + /// + /// # Panics + /// * Panics if `trackers` is empty. + fn select_instance(&mut self, trackers: &[(usize, &InstanceTracker)]) -> usize { + if trackers.is_empty() { + panic!("RandomStrategy::select_instance called with empty trackers slice"); + } + + let index = rand::rng().random_range(0..trackers.len()); + + debug!( + "RandomStrategy: Selected random index {} (ID: {}) from {} eligible trackers", + index, + trackers[index].0, + trackers.len() + ); + + index + } +} diff --git a/crates/flyllm/src/load_balancer/tasks.rs b/crates/flyllm/src/load_balancer/tasks.rs new file mode 100644 index 0000000..7d0a210 --- /dev/null +++ b/crates/flyllm/src/load_balancer/tasks.rs @@ -0,0 +1,41 @@ +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::collections::HashMap; + +/// Definition of a task that can be routed to specific providers +/// +/// Tasks represent specialized capabilities or configurations that +/// certain providers might be better suited for. Each task can have +/// associated parameters that affect how the request is processed. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskDefinition { + pub name: String, + pub parameters: HashMap, +} + +impl TaskDefinition { + /// Creates a new TaskDefinition with the given name. + pub fn new(name: impl Into) -> Self { + TaskDefinition { + name: name.into(), + parameters: HashMap::new(), + } + } + + /// Adds or updates a parameter for this task definition. + /// Accepts any value that can be converted into a serde_json::Value. + pub fn with_param(mut self, key: impl Into, value: impl Into) -> Self { + self.parameters.insert(key.into(), value.into()); + self + } + + /// Sets the `max_tokens` parameter to a given value. + pub fn with_max_tokens(self, tokens: u32) -> Self { + self.with_param("max_tokens", json!(tokens)) + } + + /// Sets the `temperature` parameter to a given value. + pub fn with_temperature(self, temp: f32) -> Self { + self.with_param("temperature", json!(temp)) + } +} diff --git a/crates/flyllm/src/load_balancer/tracker.rs b/crates/flyllm/src/load_balancer/tracker.rs new file mode 100644 index 0000000..625b4ec --- /dev/null +++ b/crates/flyllm/src/load_balancer/tracker.rs @@ -0,0 +1,92 @@ +use crate::providers::LlmInstance; +use crate::{LlmResponse, LlmResult}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +/// An LLM provider instance with associated metrics +pub struct InstanceTracker { + pub instance: Arc, + pub last_used: Instant, + pub response_times: Vec, + pub request_count: usize, + pub error_count: usize, +} + +impl InstanceTracker { + /// Create a new LLM instance + /// + /// # Parameters + /// * `id` - Unique identifier for this instance + /// * `provider` - Reference to the provider implementation + pub fn new(instance: Arc) -> Self { + Self { + instance: instance, + last_used: Instant::now(), + response_times: Vec::new(), + request_count: 0, + error_count: 0, + } + } + + /// Record the result of a request for metrics tracking + /// + /// # Parameters + /// * `duration` - How long the request took + /// * `result` - The result of the request (success or error) + pub fn record_result(&mut self, duration: Duration, result: &LlmResult) { + self.last_used = Instant::now(); + self.request_count += 1; + + match result { + Ok(_) => { + self.response_times.push(duration); + if self.response_times.len() > 10 { + self.response_times.remove(0); + } + } + Err(e) => { + self.error_count += 1; + } + } + } + + /// Calculate the average response time from recent requests + /// + /// # Returns + /// * Average duration, or zero if no requests recorded + pub fn avg_response_time(&self) -> Duration { + if self.response_times.is_empty() { + return Duration::from_millis(0); + } + let total: Duration = self.response_times.iter().sum(); + total / self.response_times.len().max(1) as u32 // Avoid division by zero + } + + /// Calculate the error rate as a percentage + /// + /// # Returns + /// * Error rate from 0.0 to 100.0, or 0.0 if no requests + pub fn get_error_rate(&self) -> f64 { + if self.request_count > 0 { + (self.error_count as f64 / self.request_count as f64) * 100.0 + } else { + 0.0 + } + } + + /// Check if this instance is currently enabled + /// + /// # Returns + /// * Whether this instance is enabled or not + pub fn is_enabled(&self) -> bool { + self.instance.is_enabled() + } + + /// Check if this instance supports a specific task + /// + /// # Returns + /// * Whether this instance supports this task or not + pub fn supports_task(&self, task_name: &str) -> bool { + self.instance.get_supported_tasks().contains_key(task_name) + } +} diff --git a/crates/flyllm/src/load_balancer/utils.rs b/crates/flyllm/src/load_balancer/utils.rs new file mode 100644 index 0000000..adb7471 --- /dev/null +++ b/crates/flyllm/src/load_balancer/utils.rs @@ -0,0 +1,36 @@ +use crate::errors::LlmError; +use std::fs::{create_dir_all, File}; +use std::io::Write; +use std::path::PathBuf; + +pub fn get_debug_path( + debug_folder: &PathBuf, + timestamp: u64, + instance_id: usize, + instance_provider: &str, + instance_model: &str, +) -> PathBuf { + let timestamp_folder = debug_folder.join(timestamp.to_string()); + let instance_folder = timestamp_folder.join(format!( + "{}_{}_{}", + instance_id, instance_provider, instance_model + )); + instance_folder.join("debug.json") +} + +pub fn write_to_debug_file(file_path: &PathBuf, contents: &str) -> Result<(), LlmError> { + // Create parent directories if they don't exist + if let Some(parent) = file_path.parent() { + create_dir_all(parent).map_err(|e| { + LlmError::ConfigError(format!("Failed to create debug directories: {}", e)) + })?; + } + + let mut file = File::create(file_path) + .map_err(|e| LlmError::ConfigError(format!("Failed to create debug file: {}", e)))?; + + file.write_all(contents.as_bytes()) + .map_err(|e| LlmError::ConfigError(format!("Failed to write to debug file: {}", e)))?; + + Ok(()) +} diff --git a/crates/flyllm/src/providers/anthropic.rs b/crates/flyllm/src/providers/anthropic.rs new file mode 100644 index 0000000..94eb3d2 --- /dev/null +++ b/crates/flyllm/src/providers/anthropic.rs @@ -0,0 +1,222 @@ +use crate::constants; +use crate::errors::{LlmError, LlmResult}; +use crate::load_balancer::tasks::TaskDefinition; +use crate::providers::instances::{BaseInstance, LlmInstance}; +use crate::providers::types::{LlmRequest, LlmResponse, TokenUsage}; + +use async_trait::async_trait; +use reqwest::header; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Provider implementation for Anthropic's Claude API +pub struct AnthropicInstance { + base: BaseInstance, +} + +/// Request structure for the Anthropic Claude API +#[derive(Serialize)] +struct AnthropicRequest { + model: String, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, + messages: Vec, + max_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, +} + +/// Individual message structure for Anthropic's API +#[derive(Serialize)] +struct AnthropicMessage { + role: String, + content: String, +} + +/// Response structure from Anthropic's Claude API +#[derive(Deserialize)] +struct AnthropicResponse { + content: Vec, + model: String, + usage: Option, +} + +/// Content block from Anthropic's response +#[derive(Deserialize)] +struct AnthropicContent { + text: String, + #[serde(rename = "type")] + content_type: String, +} + +/// Token usage information from Anthropic +#[derive(Deserialize)] +struct AnthropicUsage { + input_tokens: u32, + output_tokens: u32, +} + +impl AnthropicInstance { + /// Creates a new Anthropic provider instance + /// + /// # Parameters + /// * `api_key` - Anthropic API key + /// * `model` - Default model to use (e.g. "claude-3-opus-20240229") + /// * `supported_tasks` - Map of tasks this provider supports + /// * `enabled` - Whether this provider is enabled + pub fn new( + api_key: String, + model: String, + supported_tasks: HashMap, + enabled: bool, + ) -> Self { + let base = BaseInstance::new( + "anthropic".to_string(), + api_key, + model, + supported_tasks, + enabled, + ); + Self { base } + } +} + +#[async_trait] +impl LlmInstance for AnthropicInstance { + /// Generates a completion using Anthropic's Claude API + /// + /// # Parameters + /// * `request` - The LLM request containing messages and parameters + /// + /// # Returns + /// * `LlmResult` - The response from the model or an error + async fn generate(&self, request: &LlmRequest) -> LlmResult { + if !self.base.is_enabled() { + return Err(LlmError::ProviderDisabled("Anthropic".to_string())); + } + + let mut headers = header::HeaderMap::new(); + headers.insert( + "x-api-key", + header::HeaderValue::from_str(self.base.api_key()) + .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, + ); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + headers.insert( + "anthropic-version", + header::HeaderValue::from_static(constants::ANTHROPIC_API_VERSION), + ); + + let model = request + .model + .clone() + .unwrap_or_else(|| self.base.model().to_string()); + + // Extract system message and regular messages + let mut system_content = None; + let mut regular_messages = Vec::new(); + + for msg in &request.messages { + if msg.role == "system" { + system_content = Some(msg.content.clone()); + } else { + regular_messages.push(AnthropicMessage { + role: msg.role.clone(), + content: msg.content.clone(), + }); + } + } + + // Ensure we have at least one message + if regular_messages.is_empty() && system_content.is_some() { + regular_messages.push(AnthropicMessage { + role: "user".to_string(), + content: format!("Using this context: {}", system_content.unwrap()), + }); + system_content = None; + } + + if regular_messages.is_empty() { + return Err(LlmError::ApiError( + "Anthropic requires at least one message".to_string(), + )); + } + + let anthropic_request = AnthropicRequest { + model, + system: system_content, + messages: regular_messages, + max_tokens: request.max_tokens.unwrap_or(constants::DEFAULT_MAX_TOKENS), + temperature: request.temperature, + }; + + let response = self + .base + .client() + .post(constants::ANTHROPIC_API_ENDPOINT) + .headers(headers) + .json(&anthropic_request) + .send() + .await?; + + if !response.status().is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(LlmError::ApiError(format!( + "Anthropic API error: {}", + error_text + ))); + } + + let anthropic_response: AnthropicResponse = response.json().await?; + + if anthropic_response.content.is_empty() { + return Err(LlmError::ApiError("No response from Anthropic".to_string())); + } + + let usage = anthropic_response.usage.map(|u| TokenUsage { + prompt_tokens: u.input_tokens, + completion_tokens: u.output_tokens, + total_tokens: u.input_tokens + u.output_tokens, + }); + + let text = anthropic_response + .content + .iter() + .filter(|c| c.content_type == "text") + .map(|c| c.text.clone()) + .collect::>() + .join(""); + + Ok(LlmResponse { + content: text, + model: anthropic_response.model, + usage, + }) + } + + /// Returns provider name + fn get_name(&self) -> &str { + self.base.name() + } + + /// Returns current model name + fn get_model(&self) -> &str { + self.base.model() + } + + /// Returns supported tasks for this provider + fn get_supported_tasks(&self) -> &HashMap { + self.base.supported_tasks() + } + + /// Returns whether this provider is enabled + fn is_enabled(&self) -> bool { + self.base.is_enabled() + } +} diff --git a/crates/flyllm/src/providers/google.rs b/crates/flyllm/src/providers/google.rs new file mode 100644 index 0000000..79d81d8 --- /dev/null +++ b/crates/flyllm/src/providers/google.rs @@ -0,0 +1,305 @@ +use crate::constants; +use crate::errors::{LlmError, LlmResult}; +use crate::load_balancer::tasks::TaskDefinition; +use crate::providers::instances::{BaseInstance, LlmInstance}; +use crate::providers::types::{LlmRequest, LlmResponse, Message, TokenUsage}; + +use async_trait::async_trait; +use log::debug; +use reqwest::header; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Provider implementation for Google's Gemini AI models +pub struct GoogleInstance { + base: BaseInstance, +} + +/// Request structure for Google's Gemini API +#[derive(Serialize)] +struct GoogleGenerateContentRequest { + contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "generationConfig")] + generation_config: Option, +} + +/// Content structure for Google's Gemini API messages +#[derive(Serialize, Deserialize)] +struct GoogleContent { + role: String, + parts: Vec, +} + +/// Individual content part for Google's Gemini API +#[derive(Serialize, Deserialize)] +struct GooglePart { + text: String, +} + +/// Generation configuration for Google's Gemini API +#[derive(Serialize, Default)] +struct GoogleGenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + // #[serde(skip_serializing_if = "Option::is_none")] + // top_k: Option, + // #[serde(skip_serializing_if = "Option::is_none")] + // top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "maxOutputTokens")] + max_output_tokens: Option, + // #[serde(skip_serializing_if = "Option::is_none")] + // stop_sequences: Option>, +} + +/// Response structure from Google's Gemini API +#[derive(Deserialize)] +struct GoogleGenerateContentResponse { + candidates: Vec, +} + +/// Individual candidate from Google's Gemini API response +#[derive(Deserialize)] +struct GoogleCandidate { + content: GoogleContent, + #[serde(rename = "tokenCount")] + #[serde(default)] + token_count: u32, // Note: Google provides total token count here + // safety_ratings: Vec, // We don't use this currently +} + +impl GoogleInstance { + /// Creates a new Google provider instance + /// + /// # Parameters + /// * `api_key` - Google API key + /// * `model` - Default model to use (e.g. "gemini-pro") + /// * `supported_tasks` - Map of tasks this provider supports + /// * `enabled` - Whether this provider is enabled + pub fn new( + api_key: String, + model: String, + supported_tasks: HashMap, + enabled: bool, + ) -> Self { + let base = BaseInstance::new( + "google".to_string(), + api_key, + model, + supported_tasks, + enabled, + ); + Self { base } + } + + /// Maps standard message format to Google's expected format + /// + /// This function handles several Google-specific requirements: + /// - Converts "assistant" role to "model" role + /// - Prepends system messages to the first user message + /// - Validates that the first message is from the user + /// + /// # Parameters + /// * `messages` - Array of messages in our standard format + /// + /// # Returns + /// * `LlmResult>` - Mapped contents or an error + fn map_messages_to_contents(messages: &[Message]) -> LlmResult> { + let mut contents = Vec::new(); + let mut system_prompt: Option = None; + let mut first_user_message_index: Option = None; + for (_, msg) in messages.iter().enumerate() { + match msg.role.as_str() { + "system" => { + if system_prompt.is_some() { + return Err(LlmError::ApiError("Multiple system messages are not supported by Google provider mapping.".to_string())); + } + system_prompt = Some(msg.content.clone()); + } + "user" | "model" | "assistant" => { + let role = if msg.role == "assistant" { + "model" + } else { + &msg.role + }; + if role == "user" && first_user_message_index.is_none() { + first_user_message_index = Some(contents.len()); + } + contents.push(GoogleContent { + role: role.to_string(), + parts: vec![GooglePart { + text: msg.content.clone(), + }], + }); + } + _ => { + log::warn!("Ignoring message with unknown role: {}", msg.role); + } + } + } + + if let Some(sys_prompt) = &system_prompt { + if let Some(user_idx) = first_user_message_index { + if let Some(user_content) = contents.get_mut(user_idx) { + if let Some(part) = user_content.parts.get_mut(0) { + part.text = format!("{}\n\n{}", sys_prompt, part.text); + } + } else { + return Err(LlmError::ApiError( + "System message provided but no user message found.".to_string(), + )); + } + } else { + return Err(LlmError::ApiError( + "System message provided but no user message found.".to_string(), + )); + } + } + + if contents.is_empty() { + return Err(LlmError::ApiError( + "No valid messages found for Google provider.".to_string(), + )); + } + if contents[0].role != "user" { + return Err(LlmError::ApiError(format!( + "Google chat must start with a 'user' role message, found '{}'.", + contents[0].role + ))); + } + Ok(contents) + } +} + +#[async_trait] +impl LlmInstance for GoogleInstance { + /// Generates a completion using Google's Gemini API + /// + /// # Parameters + /// * `request` - The LLM request containing messages and parameters + /// + /// # Returns + /// * `LlmResult` - The response from the model or an error + async fn generate(&self, request: &LlmRequest) -> LlmResult { + if !self.base.is_enabled() { + return Err(LlmError::ProviderDisabled("Google".to_string())); + } + + let model_name = self.base.model(); + let api_key = self.base.api_key(); + + let url = format!( + "{}/v1beta/models/{}:generateContent?key={}", + constants::GOOGLE_API_ENDPOINT_PREFIX, + model_name, + api_key + ); + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + + let contents = Self::map_messages_to_contents(&request.messages)?; + + let mut generation_config = GoogleGenerationConfig::default(); + generation_config.temperature = request.temperature; + generation_config.max_output_tokens = request.max_tokens; + + let google_request = GoogleGenerateContentRequest { + contents, + generation_config: Some(generation_config) + .filter(|gc| gc.temperature.is_some() || gc.max_output_tokens.is_some()), + }; + + let response = self + .base + .client() + .post(&url) + .headers(headers) + .json(&google_request) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let error_json: Result = response.json().await; + let error_details = match error_json { + Ok(json) => json + .get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("Unknown error structure: {}", json)), + Err(_) => "Failed to parse error response body".to_string(), + }; + + return Err(LlmError::ApiError(format!( + "Google API error ({}): {}", + status, error_details + ))); + } + + let google_response: GoogleGenerateContentResponse = + response.json().await.map_err(|e| { + LlmError::ApiError(format!("Failed to parse Google JSON response: {}", e)) + })?; + + if google_response.candidates.is_empty() { + return Err(LlmError::ApiError( + "No candidates returned from Google. Content may have been blocked.".to_string(), + )); + } + + let candidate = &google_response.candidates[0]; + + let combined_content = candidate + .content + .parts + .iter() + .map(|part| part.text.clone()) + .collect::>() + .join(""); + + let usage = if candidate.token_count > 0 { + // Simply use the token count as the total + Some(TokenUsage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: candidate.token_count, + }) + } else { + None + }; + + debug!("Google usage: {:?}", usage); + + Ok(LlmResponse { + content: combined_content, + model: model_name.to_string(), + usage, + }) + } + + /// Returns provider name + fn get_name(&self) -> &str { + self.base.name() + } + + /// Returns current model name + fn get_model(&self) -> &str { + self.base.model() + } + + /// Returns supported tasks for this provider + fn get_supported_tasks(&self) -> &HashMap { + &self.base.supported_tasks() + } + + /// Returns whether this provider is enabled + fn is_enabled(&self) -> bool { + self.base.is_enabled() + } +} diff --git a/crates/flyllm/src/providers/instances.rs b/crates/flyllm/src/providers/instances.rs new file mode 100644 index 0000000..f2af731 --- /dev/null +++ b/crates/flyllm/src/providers/instances.rs @@ -0,0 +1,177 @@ +use crate::errors::LlmResult; +use crate::load_balancer::tasks::TaskDefinition; +use crate::providers::anthropic::AnthropicInstance; +use crate::providers::google::GoogleInstance; +use crate::providers::lmstudio::LmStudioInstance; +use crate::providers::mistral::MistralInstance; +use crate::providers::ollama::OllamaInstance; +use crate::providers::openai::OpenAIInstance; +use crate::providers::types::{LlmRequest, LlmResponse, ProviderType}; +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use reqwest::Client; +use std::time::Duration; + +/// Common interface for all LLM instances +/// +/// This trait defines the interface that all LLM instances must implement +/// to be compatible with the load balancer system. +#[async_trait] +pub trait LlmInstance { + /// Generate a completion from the LLM instance + async fn generate(&self, request: &LlmRequest) -> LlmResult; + /// Get the name of this instance + fn get_name(&self) -> &str; + /// Get the currently configured model name + fn get_model(&self) -> &str; + /// Get the tasks this instance supports + fn get_supported_tasks(&self) -> &HashMap; + /// Check if this instance is enabled + fn is_enabled(&self) -> bool; +} + +/// Base instance implementation with common functionality +/// +/// Handles common properties and functionality shared across all instances: +/// - HTTP client with timeout +/// - API key storage +/// - Model selection +/// - Task support +/// - Enable/disable status +pub struct BaseInstance { + name: String, + client: Client, + api_key: String, + model: String, + supported_tasks: HashMap, + enabled: bool, +} + +impl BaseInstance { + /// Create a new Baseinstance with specified parameters + /// + /// # Parameters + /// * `name` - instance name identifier + /// * `api_key` - API key for authentication + /// * `model` - Default model identifier to use + /// * `supported_tasks` - Map of tasks this instance supports + /// * `enabled` - Whether this instance is enabled + pub fn new( + name: String, + api_key: String, + model: String, + supported_tasks: HashMap, + enabled: bool, + ) -> Self { + let client = Client::builder() + .timeout(Duration::from_secs(120)) + .build() + .expect("Failed to create HTTP client"); + + Self { + name, + client, + api_key, + model, + supported_tasks, + enabled, + } + } + + /// Get the HTTP client instance + pub fn client(&self) -> &Client { + &self.client + } + + /// Get the API key + pub fn api_key(&self) -> &str { + &self.api_key + } + + /// Get the current model name + pub fn model(&self) -> &str { + &self.model + } + + /// Check if this instance is enabled + pub fn is_enabled(&self) -> bool { + self.enabled + } + + /// Get the instance name + pub fn name(&self) -> &str { + &self.name + } + + /// Get the map of supported tasks + pub fn supported_tasks(&self) -> &HashMap { + &self.supported_tasks + } +} + +/// Factory function to create a instance instance based on type +/// +/// # Parameters +/// * `instance_type` - Which instance type to create +/// * `api_key` - API key for authentication +/// * `model` - Default model identifier +/// * `supported_tasks` - List of tasks this instance supports +/// * `enabled` - Whether this instance should be enabled +/// +/// # Returns +/// * Arc-wrapped trait object implementing Llminstance +pub fn create_instance( + instance_type: ProviderType, + api_key: String, + model: String, + supported_tasks: Vec, + enabled: bool, + endpoint_url: Option, +) -> Arc { + let supported_tasks: HashMap = supported_tasks + .into_iter() + .map(|task| (task.name.clone(), task)) + .collect(); + match instance_type { + ProviderType::Anthropic => Arc::new(AnthropicInstance::new( + api_key, + model, + supported_tasks, + enabled, + )), + ProviderType::OpenAI => Arc::new(OpenAIInstance::new( + api_key, + model, + supported_tasks, + enabled, + )), + ProviderType::Mistral => Arc::new(MistralInstance::new( + api_key, + model, + supported_tasks, + enabled, + )), + ProviderType::Google => Arc::new(GoogleInstance::new( + api_key, + model, + supported_tasks, + enabled, + )), + ProviderType::Ollama => Arc::new(OllamaInstance::new( + api_key, + model, + supported_tasks, + enabled, + endpoint_url, + )), + ProviderType::LmStudio => Arc::new(LmStudioInstance::new( + api_key, + model, + supported_tasks, + enabled, + endpoint_url, + )), + } +} diff --git a/crates/flyllm/src/providers/lmstudio.rs b/crates/flyllm/src/providers/lmstudio.rs new file mode 100644 index 0000000..d87f9eb --- /dev/null +++ b/crates/flyllm/src/providers/lmstudio.rs @@ -0,0 +1,172 @@ +use std::collections::HashMap; + +use crate::constants; +use crate::errors::{LlmError, LlmResult}; +use crate::load_balancer::tasks::TaskDefinition; +use crate::providers::instances::{BaseInstance, LlmInstance}; +use crate::providers::types::{LlmRequest, LlmResponse, Message, TokenUsage}; + +use async_trait::async_trait; +use reqwest::header; +use serde::{Deserialize, Serialize}; + +/// Provider implementation for LM Studio using its OpenAI-compatible API +pub struct LmStudioInstance { + base: BaseInstance, + chat_endpoint: String, +} + +impl LmStudioInstance { + pub fn new( + api_key: String, + model: String, + supported_tasks: HashMap, + enabled: bool, + endpoint_url: Option, + ) -> Self { + let chat_endpoint = Self::build_chat_endpoint(endpoint_url.clone()); + let base = BaseInstance::new( + "lmstudio".to_string(), + api_key, + model, + supported_tasks, + enabled, + ); + + Self { + base, + chat_endpoint, + } + } + + fn build_chat_endpoint(endpoint_url: Option) -> String { + let base = endpoint_url.unwrap_or_else(|| constants::LM_STUDIO_API_ENDPOINT.to_string()); + let trimmed = base.trim_end_matches('/'); + if trimmed.ends_with("/v1/chat/completions") { + trimmed.to_string() + } else if trimmed.ends_with("/v1") { + format!("{}/chat/completions", trimmed) + } else { + format!("{}/v1/chat/completions", trimmed) + } + } +} + +#[derive(Serialize)] +struct LmStudioRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, +} + +#[derive(Deserialize)] +struct LmStudioResponse { + choices: Vec, + model: String, + usage: Option, +} + +#[derive(Deserialize)] +struct LmStudioChoice { + message: Message, +} + +#[derive(Deserialize)] +struct LmStudioUsage { + prompt_tokens: u32, + completion_tokens: u32, + total_tokens: u32, +} + +#[async_trait] +impl LlmInstance for LmStudioInstance { + async fn generate(&self, request: &LlmRequest) -> LlmResult { + if !self.base.is_enabled() { + return Err(LlmError::ProviderDisabled("LmStudio".to_string())); + } + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + + if !self.base.api_key().is_empty() { + if let Ok(value) = + header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) + { + headers.insert(header::AUTHORIZATION, value); + } + } + + let model = request + .model + .clone() + .unwrap_or_else(|| self.base.model().to_string()); + + let payload = LmStudioRequest { + model, + messages: request.messages.clone(), + max_tokens: request.max_tokens, + temperature: request.temperature, + }; + + let response = self + .base + .client() + .post(&self.chat_endpoint) + .headers(headers) + .json(&payload) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(LlmError::ApiError(format!( + "LmStudio API error ({}): {}", + status, error_text + ))); + } + + let parsed: LmStudioResponse = response.json().await?; + let choice = parsed + .choices + .get(0) + .ok_or_else(|| LlmError::ApiError("LmStudio returned no choices".to_string()))?; + + let usage = parsed.usage.map(|u| TokenUsage { + prompt_tokens: u.prompt_tokens, + completion_tokens: u.completion_tokens, + total_tokens: u.total_tokens, + }); + + Ok(LlmResponse { + content: choice.message.content.clone(), + model: parsed.model, + usage, + }) + } + + fn get_name(&self) -> &str { + self.base.name() + } + + fn get_model(&self) -> &str { + self.base.model() + } + + fn get_supported_tasks(&self) -> &HashMap { + &self.base.supported_tasks() + } + + fn is_enabled(&self) -> bool { + self.base.is_enabled() + } +} diff --git a/crates/flyllm/src/providers/mistral.rs b/crates/flyllm/src/providers/mistral.rs new file mode 100644 index 0000000..d9eab09 --- /dev/null +++ b/crates/flyllm/src/providers/mistral.rs @@ -0,0 +1,218 @@ +use crate::constants; +use crate::errors::{LlmError, LlmResult}; +use crate::load_balancer::tasks::TaskDefinition; +use crate::providers::instances::{BaseInstance, LlmInstance}; +use crate::providers::types::{LlmRequest, LlmResponse, Message, TokenUsage}; + +use async_trait::async_trait; +use reqwest::header; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Provider implementation for Mistral AI's API +pub struct MistralInstance { + base: BaseInstance, +} + +/// Request structure for Mistral AI's chat completion API +#[derive(Serialize)] +struct MistralRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, +} + +/// Response structure from Mistral AI's chat completion API +#[derive(Deserialize, Debug)] +struct MistralResponse { + id: String, + model: String, + object: String, + created: u64, + choices: Vec, + usage: Option, +} + +/// Individual choice from Mistral's response +#[derive(Deserialize, Debug)] +struct MistralChoice { + index: u32, // Removed underscore prefix + message: Message, + finish_reason: Option, +} + +/// Token usage information from Mistral +#[derive(Deserialize, Debug)] +struct MistralUsage { + prompt_tokens: u32, + completion_tokens: u32, + total_tokens: u32, +} + +impl MistralInstance { + /// Creates a new Mistral provider instance + /// + /// # Parameters + /// * `api_key` - Mistral API key + /// * `model` - Default model to use (e.g. "mistral-large") + /// * `supported_tasks` - Map of tasks this provider supports + /// * `enabled` - Whether this provider is enabled + pub fn new( + api_key: String, + model: String, + supported_tasks: HashMap, + enabled: bool, + ) -> Self { + let base = BaseInstance::new( + "mistral".to_string(), + api_key, + model, + supported_tasks, + enabled, + ); + Self { base } + } +} + +#[async_trait] +impl LlmInstance for MistralInstance { + /// Generates a completion using Mistral AI's API + /// + /// # Parameters + /// * `request` - The LLM request containing messages and parameters + /// + /// # Returns + /// * `LlmResult` - The response from the model or an error + async fn generate(&self, request: &LlmRequest) -> LlmResult { + if !self.base.is_enabled() { + return Err(LlmError::ProviderDisabled("Mistral".to_string())); + } + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::AUTHORIZATION, + header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) + .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, + ); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + headers.insert( + header::ACCEPT, + header::HeaderValue::from_static("application/json"), + ); + + let model = request + .model + .clone() + .unwrap_or_else(|| self.base.model().to_string()); + + if request.messages.is_empty() { + return Err(LlmError::ApiError( + "Mistral requires at least one message".to_string(), + )); + } + + let mistral_request = MistralRequest { + model, + messages: request + .messages + .iter() + .map(|m| Message { + role: match m.role.as_str() { + "system" | "user" | "assistant" => m.role.clone(), + _ => "user".to_string(), + }, + content: m.content.clone(), + }) + .collect(), + temperature: request.temperature, + max_tokens: request.max_tokens, + }; + + let response = self + .base + .client() + .post(constants::MISTRAL_API_ENDPOINT) + .headers(headers) + .json(&mistral_request) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error reading response body".to_string()); + return Err(LlmError::ApiError(format!( + "Mistral API error ({}): {}", + status, error_text + ))); + } + + // Debug: Log raw response body for inspection if needed + let response_body = response.text().await.map_err(|e| { + LlmError::ApiError(format!("Failed to read Mistral response body: {}", e)) + })?; + + // Try to parse the response as JSON + let mistral_response: MistralResponse = + serde_json::from_str(&response_body).map_err(|e| { + // Provide more context in the error message + LlmError::ApiError(format!( + "Failed to parse Mistral JSON response: {}. Response body: {}", + e, + if response_body.len() > 200 { + format!("{}... (truncated)", &response_body[..200]) + } else { + response_body.clone() + } + )) + })?; + + if mistral_response.choices.is_empty() { + return Err(LlmError::ApiError( + "No choices returned from Mistral".to_string(), + )); + } + + let choice = &mistral_response.choices[0]; + + let usage = mistral_response.usage.map(|u| TokenUsage { + prompt_tokens: u.prompt_tokens, + completion_tokens: u.completion_tokens, + total_tokens: u.total_tokens, + }); + + Ok(LlmResponse { + content: choice.message.content.clone(), + model: mistral_response.model, + usage, + }) + } + + /// Returns provider name + fn get_name(&self) -> &str { + self.base.name() + } + + /// Returns current model name + fn get_model(&self) -> &str { + self.base.model() + } + + /// Returns supported tasks for this provider + fn get_supported_tasks(&self) -> &HashMap { + &self.base.supported_tasks() + } + + /// Returns whether this provider is enabled + fn is_enabled(&self) -> bool { + self.base.is_enabled() + } +} diff --git a/crates/flyllm/src/providers/mod.rs b/crates/flyllm/src/providers/mod.rs new file mode 100644 index 0000000..a3fb0ea --- /dev/null +++ b/crates/flyllm/src/providers/mod.rs @@ -0,0 +1,27 @@ +/// Module for various LLM provider implementations +/// +/// This module contains implementations for different LLM providers: +/// - Anthropic (Claude models) +/// - OpenAI (GPT models) +/// - Mistral AI +/// - Google (Gemini models) +/// - Ollama +/// +/// Each provider implements a common interface for generating text +/// completions through their respective APIs. +pub mod anthropic; +pub mod google; +pub mod instances; +pub mod lmstudio; +pub mod mistral; +pub mod model_discovery; +pub mod ollama; +pub mod openai; +pub mod types; + +pub use anthropic::AnthropicInstance; +pub use instances::{create_instance, LlmInstance}; +pub use lmstudio::LmStudioInstance; +pub use model_discovery::ModelDiscovery; +pub use openai::OpenAIInstance; +pub use types::{LlmRequest, LlmResponse, Message, ModelInfo, ProviderType, TokenUsage}; diff --git a/crates/flyllm/src/providers/model_discovery.rs b/crates/flyllm/src/providers/model_discovery.rs new file mode 100644 index 0000000..3f2ea61 --- /dev/null +++ b/crates/flyllm/src/providers/model_discovery.rs @@ -0,0 +1,505 @@ +use crate::constants; +use crate::errors::{LlmError, LlmResult}; +use crate::providers::types::{ModelInfo, ProviderType}; +use reqwest::{header, Client}; +use serde::Deserialize; +use std::time::Duration; + +/// Helper module for listing available models from providers +/// without requiring a fully initialized provider instance +pub struct ModelDiscovery; + +impl ModelDiscovery { + /// Create a standardized HTTP client for model discovery + fn create_client() -> Client { + Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .expect("Failed to create HTTP client") + } + + /// List available models from Anthropic + /// + /// # Parameters + /// * `api_key` - Anthropic API key + /// + /// # Returns + /// * Vector of ModelInfo structs containing model names + pub async fn list_anthropic_models(api_key: &str) -> LlmResult> { + let client = Self::create_client(); + + let mut headers = header::HeaderMap::new(); + headers.insert( + "x-api-key", + header::HeaderValue::from_str(api_key).map_err(|e| { + LlmError::ConfigError(format!("Invalid API key format for Anthropic: {}", e)) + })?, + ); + headers.insert( + "anthropic-version", + header::HeaderValue::from_static(constants::ANTHROPIC_API_VERSION), + ); + + let models_endpoint = "https://api.anthropic.com/v1/models"; + + let response = client.get(models_endpoint).headers(headers).send().await?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_else(|_| { + format!( + "Unknown error reading error response body, status: {}", + status + ) + }); + return Err(LlmError::ApiError(format!( + "Anthropic API error ({}): {}", + status, error_text + ))); + } + + let response_bytes = response.bytes().await?; + + #[derive(Deserialize, Debug)] + struct AnthropicModelsResponse { + data: Vec, + } + #[derive(Deserialize, Debug)] + struct AnthropicModelInfo { + id: String, + display_name: String, + } + + let anthropic_response: AnthropicModelsResponse = serde_json::from_slice(&response_bytes) + .map_err(|e| { + let snippet_len = std::cmp::min(response_bytes.len(), 256); + let response_snippet = + String::from_utf8_lossy(response_bytes.get(0..snippet_len).unwrap_or_default()); + LlmError::ParseError(format!( + "Error decoding Anthropic models JSON: {}. Response snippet: '{}'", + e, response_snippet + )) + })?; + + let models = anthropic_response + .data + .into_iter() + .map(|m| ModelInfo { + name: m.id, + provider: ProviderType::Anthropic, + }) + .collect(); + + Ok(models) + } + + /// List available models from OpenAI + /// + /// # Parameters + /// * `api_key` - OpenAI API key + /// + /// # Returns + /// * Vector of ModelInfo structs containing model names + pub async fn list_openai_models(api_key: &str) -> LlmResult> { + let client = Self::create_client(); + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::AUTHORIZATION, + header::HeaderValue::from_str(&format!("Bearer {}", api_key)) + .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, + ); + + let models_endpoint = "https://api.openai.com/v1/models"; + + let response = client.get(models_endpoint).headers(headers).send().await?; + + if !response.status().is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(LlmError::ApiError(format!( + "OpenAI API error: {}", + error_text + ))); + } + + #[derive(Deserialize)] + struct OpenAIModelsResponse { + data: Vec, + } + + #[derive(Deserialize)] + struct OpenAIModelInfo { + id: String, + } + + let openai_response: OpenAIModelsResponse = response.json().await?; + + let models = openai_response + .data + .into_iter() + .filter(|m| m.id.starts_with("gpt-")) + .map(|m| ModelInfo { + name: m.id, + provider: ProviderType::OpenAI, + }) + .collect(); + + Ok(models) + } + + /// List available models from Mistral + /// + /// # Parameters + /// * `api_key` - Mistral API key + /// + /// # Returns + /// * Vector of ModelInfo structs containing model names + pub async fn list_mistral_models(api_key: &str) -> LlmResult> { + let client = Self::create_client(); + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::AUTHORIZATION, + header::HeaderValue::from_str(&format!("Bearer {}", api_key)) + .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, + ); + + let models_endpoint = "https://api.mistral.ai/v1/models"; + + let response = client.get(models_endpoint).headers(headers).send().await?; + + if !response.status().is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(LlmError::ApiError(format!( + "Mistral API error: {}", + error_text + ))); + } + + #[derive(Deserialize)] + struct MistralModelsResponse { + data: Vec, + } + + #[derive(Deserialize)] + struct MistralModelInfo { + id: String, + } + + let mistral_response: MistralModelsResponse = response.json().await?; + + let models = mistral_response + .data + .into_iter() + .map(|m| ModelInfo { + name: m.id, + provider: ProviderType::Mistral, + }) + .collect(); + + Ok(models) + } + + /// List available models from Google + /// + /// # Parameters + /// * `api_key` - Google API key + /// + /// # Returns + /// * Vector of ModelInfo structs containing model names + pub async fn list_google_models(api_key: &str) -> LlmResult> { + let client = Self::create_client(); + + let models_endpoint = format!( + "{}/v1beta/models?key={}", + constants::GOOGLE_API_ENDPOINT_PREFIX, + api_key + ); + + let response = client.get(&models_endpoint).send().await?; + + if !response.status().is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(LlmError::ApiError(format!( + "Google API error: {}", + error_text + ))); + } + + #[derive(Deserialize)] + struct GoogleModelsResponse { + models: Vec, + } + + #[derive(Deserialize)] + struct GoogleModelInfo { + name: String, + #[serde(default)] + display_name: Option, + } + + let google_response: GoogleModelsResponse = response.json().await?; + + let models = google_response + .models + .into_iter() + .map(|m| { + let name = m + .display_name + .unwrap_or_else(|| m.name.split('/').last().unwrap_or(&m.name).to_string()); + + ModelInfo { + name, + provider: ProviderType::Google, + } + }) + .collect(); + + Ok(models) + } + + /// List available models from Ollama + /// + /// # Parameters + /// * `base_url` - Optional base URL for Ollama API, defaults to localhost + /// + /// # Returns + /// * Vector of ModelInfo structs containing model names + pub async fn list_ollama_models(base_url: Option<&str>) -> LlmResult> { + let client = Self::create_client(); + + // Use provided base URL or default to localhost + let base_url = base_url.unwrap_or("http://localhost:11434"); + let models_endpoint = format!("{}/api/tags", base_url.trim_end_matches('/')); + + let response = client.get(&models_endpoint).send().await?; + + if !response.status().is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(LlmError::ApiError(format!( + "Ollama API error: {}", + error_text + ))); + } + + #[derive(Deserialize)] + struct OllamaModelsResponse { + models: Vec, + } + + #[derive(Deserialize)] + struct OllamaModelInfo { + name: String, + } + + let ollama_response: OllamaModelsResponse = response.json().await?; + + let models = ollama_response + .models + .into_iter() + .map(|m| ModelInfo { + name: m.name, + provider: ProviderType::Ollama, + }) + .collect(); + + Ok(models) + } + + /// List available models from LM Studio + /// + /// # Parameters + /// * `base_url` - Optional base URL for LM Studio API, defaults to localhost + /// + /// # Returns + /// * Vector of ModelInfo structs containing model names + pub async fn list_lmstudio_models(base_url: Option<&str>) -> LlmResult> { + let client = Self::create_client(); + + let base_url = base_url.unwrap_or("http://127.0.0.1:1234"); + let models_endpoint = format!("{}/v1/models", base_url.trim_end_matches('/')); + + let response = client.get(&models_endpoint).send().await?; + let status = response.status(); + let body = response.bytes().await?; + + if !status.is_success() { + let message = String::from_utf8_lossy(&body).to_string(); + return Err(LlmError::ApiError(format!( + "LmStudio API error: {}", + message + ))); + } + + #[derive(Deserialize)] + struct LmStudioModelInfo { + id: String, + #[serde(default)] + object: Option, + #[serde(default)] + owned_by: Option, + } + + #[derive(Deserialize)] + struct LmStudioModelsResponse { + data: Vec, + } + + let model_names = match serde_json::from_slice::(&body) { + Ok(parsed) => parsed + .data + .into_iter() + .map(|m| m.id) + .collect::>(), + Err(_) => { + let value: serde_json::Value = serde_json::from_slice(&body)?; + if let Some(models) = value.get("models").and_then(|v| v.as_array()) { + models + .iter() + .filter_map(|model| model.get("id").and_then(|id| id.as_str())) + .map(|name| name.to_string()) + .collect() + } else { + return Err(LlmError::ParseError(format!( + "Unexpected LM Studio models response: {}", + String::from_utf8_lossy(&body) + ))); + } + } + }; + + Ok(model_names + .into_iter() + .map(|name| ModelInfo { + name, + provider: ProviderType::LmStudio, + }) + .collect()) + } + + /// List all models from a specific provider + /// + /// # Parameters + /// * `provider_type` - Type of provider to query + /// * `api_key` - API key for authentication + /// * `base_url` - Optional base URL (only used for Ollama) + /// + /// # Returns + /// * Vector of ModelInfo structs containing model names + pub async fn list_models( + provider_type: ProviderType, + api_key: &str, + base_url: Option<&str>, + ) -> LlmResult> { + match provider_type { + ProviderType::Anthropic => Self::list_anthropic_models(api_key).await, + ProviderType::OpenAI => Self::list_openai_models(api_key).await, + ProviderType::Mistral => Self::list_mistral_models(api_key).await, + ProviderType::Google => Self::list_google_models(api_key).await, + ProviderType::Ollama => Self::list_ollama_models(base_url).await, + ProviderType::LmStudio => Self::list_lmstudio_models(base_url).await, + } + } +} + +#[cfg(test)] +mod tests { + use super::ModelDiscovery; + use crate::providers::types::ProviderType; + use httpmock::prelude::*; + + fn sample_models_response() -> serde_json::Value { + serde_json::json!({ + "data": [ + { "id": "local/model-1", "object": "model", "owned_by": "local" }, + { "id": "local/model-2", "object": "model", "owned_by": "local" } + ], + "object": "list" + }) + } + + #[tokio::test] + async fn list_lmstudio_models_returns_models_from_server() { + let server = MockServer::start_async().await; + + let _mock = server + .mock_async(|when, then| { + when.method(GET).path("/v1/models"); + then.status(200).json_body(sample_models_response()); + }) + .await; + + let base_url = format!("http://{}", server.address()); + + let result = ModelDiscovery::list_lmstudio_models(Some(&base_url)) + .await + .expect("models should be parsed"); + + let names: Vec = result.into_iter().map(|m| m.name).collect(); + assert_eq!( + names, + vec!["local/model-1".to_string(), "local/model-2".to_string()] + ); + } + + #[tokio::test] + async fn list_lmstudio_models_returns_error_on_bad_status() { + let server = MockServer::start_async().await; + + let _mock = server + .mock_async(|when, then| { + when.method(GET).path("/v1/models"); + then.status(404).body("Not Found"); + }) + .await; + + let base_url = format!("http://{}", server.address()); + + let err = ModelDiscovery::list_lmstudio_models(Some(&base_url)) + .await + .expect_err("should return error for non-success status"); + + match err { + crate::errors::LlmError::ApiError(message) => { + assert!( + message.contains("LmStudio"), + "unexpected message: {}", + message + ); + } + other => panic!("expected ApiError, got {:?}", other), + } + } + + #[tokio::test] + async fn list_models_dispatches_to_lmstudio() { + let server = MockServer::start_async().await; + + let _mock = server + .mock_async(|when, then| { + when.method(GET).path("/v1/models"); + then.status(200).json_body(sample_models_response()); + }) + .await; + + let base_url = format!("http://{}", server.address()); + + let models = ModelDiscovery::list_models(ProviderType::LmStudio, "", Some(&base_url)) + .await + .expect("models should succeed"); + + assert_eq!(models.len(), 2); + } +} diff --git a/crates/flyllm/src/providers/ollama.rs b/crates/flyllm/src/providers/ollama.rs new file mode 100644 index 0000000..ff05683 --- /dev/null +++ b/crates/flyllm/src/providers/ollama.rs @@ -0,0 +1,240 @@ +use crate::constants; +use crate::errors::{LlmError, LlmResult}; +use crate::load_balancer::tasks::TaskDefinition; +use crate::providers::instances::{BaseInstance, LlmInstance}; +use crate::providers::types::{LlmRequest, LlmResponse, Message, TokenUsage}; +use async_trait::async_trait; +use reqwest::header; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use url::Url; + +/// Provider implementation for Ollama (local LLMs) +pub struct OllamaInstance { + base: BaseInstance, + // Specific URL for this provider instance + endpoint_url: String, +} + +/// Request structure for Ollama's chat API +#[derive(Serialize)] +struct OllamaRequest { + model: String, + messages: Vec, + stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + options: Option, +} + +#[derive(Serialize, Default)] +struct OllamaOptions { + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + num_predict: Option, // Corresponds to max_tokens +} + +/// Response structure from Ollama's chat API (non-streaming) +#[derive(Deserialize, Debug)] +struct OllamaResponse { + model: String, + created_at: String, + message: Message, + done: bool, // Should be true for non-streaming response + #[serde(default)] // Use default (0) if not present + prompt_eval_count: u32, + #[serde(default)] // Use default (0) if not present + eval_count: u32, // Corresponds roughly to completion tokens +} + +impl OllamaInstance { + /// Creates a new Ollama provider instance + /// + /// # Parameters + /// * `api_key` - Unused for Ollama by default, but kept for consistency. Could be repurposed (e.g., for future auth or endpoint override). + /// * `model` - Default model to use (e.g., "llama3") + /// * `supported_tasks` - Map of tasks this provider supports + /// * `enabled` - Whether this provider is enabled + /// * `endpoint_url` - Optional base endpoint URL override. If None, uses the default from constants. + pub fn new( + api_key: String, + model: String, + supported_tasks: HashMap, + enabled: bool, + endpoint_url: Option, + ) -> Self { + // Determine the endpoint: use provided one or default + let base_endpoint = + endpoint_url.unwrap_or_else(|| constants::OLLAMA_API_ENDPOINT.to_string()); + + // Validate and ensure the path ends correctly + let final_endpoint = match Url::parse(&base_endpoint) { + Ok(mut url) => { + if !url.path().ends_with("/api/chat") { + if url.path() == "/" { + url.set_path("api/chat"); + } else { + let current_path = url.path().trim_end_matches('/'); + url.set_path(&format!("{}/api/chat", current_path)); + } + } + url.to_string() + } + Err(_) => { + eprintln!( + "Warning: Invalid Ollama endpoint URL '{}' provided. Falling back to default: {}", + base_endpoint, constants::OLLAMA_API_ENDPOINT + ); + constants::OLLAMA_API_ENDPOINT.to_string() + } + }; + + // Create BaseProvider with the actual API key (even if empty/unused) + let base = BaseInstance::new( + "ollama".to_string(), + api_key, + model, + supported_tasks, + enabled, + ); + + Self { + base, + endpoint_url: final_endpoint, + } + } +} + +#[async_trait] +impl LlmInstance for OllamaInstance { + /// Generates a completion using Ollama's API + /// + /// # Parameters + /// * `request` - The LLM request containing messages and parameters + /// + /// # Returns + /// * `LlmResult` - The response from the model or an error + async fn generate(&self, request: &LlmRequest) -> LlmResult { + if !self.base.is_enabled() { + return Err(LlmError::ProviderDisabled("Ollama".to_string())); + } + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + + // Add Authorization header if an API key is actually provided and non-empty + if !self.base.api_key().is_empty() { + match header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) { + Ok(val) => { + headers.insert(header::AUTHORIZATION, val); + } + Err(e) => { + return Err(LlmError::ConfigError(format!( + "Invalid API key format for Ollama: {}", + e + ))) + } + } + } + + let model = request + .model + .clone() + .unwrap_or_else(|| self.base.model().to_string()); + + // Map common parameters to Ollama options + let mut options = OllamaOptions::default(); + if request.temperature.is_some() { + options.temperature = request.temperature; + } + if request.max_tokens.is_some() { + options.num_predict = request.max_tokens; + } + + let ollama_request = OllamaRequest { + model, + messages: request.messages.clone(), + stream: false, + options: if options.temperature.is_some() || options.num_predict.is_some() { + Some(options) + } else { + None + }, + }; + + let response = self + .base + .client() + .post(&self.endpoint_url) + .headers(headers) + .json(&ollama_request) + .send() + .await?; + + let response_status = response.status(); + if !response_status.is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| format!("Unknown error. Status: {}", response_status)); + return Err(LlmError::ApiError(format!( + "Ollama API error: {}", + error_text + ))); + } + + let response_text = response.text().await?; + if response_text.is_empty() { + return Err(LlmError::ApiError( + "Received empty response body from Ollama".to_string(), + )); + } + + // Attempt to parse the JSON response + let ollama_response: OllamaResponse = + serde_json::from_str(&response_text).map_err(|e| { + LlmError::ApiError(format!( + "Failed to parse Ollama JSON response: {}. Body: {}", + e, response_text + )) + })?; + + // Map Ollama token counts to unified format. + // Note: Ollama's `eval_count` is often used for completion tokens. `prompt_eval_count` for prompt. + // The exact definition might vary slightly depending on the model and Ollama version. + let usage = Some(TokenUsage { + prompt_tokens: ollama_response.prompt_eval_count, + completion_tokens: ollama_response.eval_count, + total_tokens: ollama_response.prompt_eval_count + ollama_response.eval_count, + }); + + Ok(LlmResponse { + content: ollama_response.message.content.clone(), + model: ollama_response.model, + usage, + }) + } + + /// Returns provider name + fn get_name(&self) -> &str { + self.base.name() + } + + /// Returns current model name + fn get_model(&self) -> &str { + self.base.model() + } + + /// Returns supported tasks for this provider + fn get_supported_tasks(&self) -> &HashMap { + &self.base.supported_tasks() + } + + /// Returns whether this provider is enabled + fn is_enabled(&self) -> bool { + self.base.is_enabled() + } +} diff --git a/crates/flyllm/src/providers/openai.rs b/crates/flyllm/src/providers/openai.rs new file mode 100644 index 0000000..9c31206 --- /dev/null +++ b/crates/flyllm/src/providers/openai.rs @@ -0,0 +1,172 @@ +use std::collections::HashMap; + +use crate::constants; +use crate::errors::{LlmError, LlmResult}; +use crate::load_balancer::tasks::TaskDefinition; +use crate::providers::instances::{BaseInstance, LlmInstance}; +use crate::providers::types::{LlmRequest, LlmResponse, Message, TokenUsage}; + +use async_trait::async_trait; +use reqwest::header; +use serde::{Deserialize, Serialize}; + +/// Provider implementation for OpenAI's API (GPT models) +pub struct OpenAIInstance { + base: BaseInstance, +} + +/// Request structure for OpenAI's chat completion API +/// Maps to the format expected by OpenAI's API +#[derive(Serialize)] +struct OpenAIRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, +} + +/// Response structure from OpenAI's chat completion API +#[derive(Deserialize)] +struct OpenAIResponse { + choices: Vec, + model: String, + usage: Option, +} + +/// Individual choice from OpenAI's response +#[derive(Deserialize)] +struct OpenAIChoice { + message: Message, +} + +/// Token usage information from OpenAI +#[derive(Deserialize)] +struct OpenAIUsage { + prompt_tokens: u32, + completion_tokens: u32, + total_tokens: u32, +} + +impl OpenAIInstance { + /// Creates a new OpenAI provider instance + /// + /// # Parameters + /// * `api_key` - OpenAI API key + /// * `model` - Default model to use (e.g. "gpt-4-turbo") + /// * `supported_tasks` - Map of tasks this provider supports + /// * `enabled` - Whether this provider is enabled + pub fn new( + api_key: String, + model: String, + supported_tasks: HashMap, + enabled: bool, + ) -> Self { + let base = BaseInstance::new( + "openai".to_string(), + api_key, + model, + supported_tasks, + enabled, + ); + Self { base } + } +} + +#[async_trait] +impl LlmInstance for OpenAIInstance { + /// Generates a completion using OpenAI's API + /// + /// # Parameters + /// * `request` - The LLM request containing messages and parameters + /// + /// # Returns + /// * `LlmResult` - The response from the model or an error + async fn generate(&self, request: &LlmRequest) -> LlmResult { + if !self.base.is_enabled() { + return Err(LlmError::ProviderDisabled("OpenAI".to_string())); + } + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::AUTHORIZATION, + header::HeaderValue::from_str(&format!("Bearer {}", self.base.api_key())) + .map_err(|e| LlmError::ConfigError(format!("Invalid API key format: {}", e)))?, + ); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + + let model = request + .model + .clone() + .unwrap_or_else(|| self.base.model().to_string()); + + let openai_request = OpenAIRequest { + model, + messages: request.messages.clone(), + max_tokens: request.max_tokens, + temperature: request.temperature, + }; + + let response = self + .base + .client() + .post(constants::OPENAI_API_ENDPOINT) + .headers(headers) + .json(&openai_request) + .send() + .await?; + + if !response.status().is_success() { + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(LlmError::ApiError(format!( + "OpenAI API error: {}", + error_text + ))); + } + + let openai_response: OpenAIResponse = response.json().await?; + + if openai_response.choices.is_empty() { + return Err(LlmError::ApiError("No response from OpenAI".to_string())); + } + + let usage = openai_response.usage.map(|u| TokenUsage { + prompt_tokens: u.prompt_tokens, + completion_tokens: u.completion_tokens, + total_tokens: u.total_tokens, + }); + + Ok(LlmResponse { + content: openai_response.choices[0].message.content.clone(), + model: openai_response.model, + usage, + }) + } + + /// Returns provider name + fn get_name(&self) -> &str { + self.base.name() + } + + /// Returns current model name + fn get_model(&self) -> &str { + self.base.model() + } + + /// Returns supported tasks for this provider + fn get_supported_tasks(&self) -> &HashMap { + &self.base.supported_tasks() + } + + /// Returns whether this provider is enabled + fn is_enabled(&self) -> bool { + self.base.is_enabled() + } +} diff --git a/crates/flyllm/src/providers/types.rs b/crates/flyllm/src/providers/types.rs new file mode 100644 index 0000000..b1fdcb0 --- /dev/null +++ b/crates/flyllm/src/providers/types.rs @@ -0,0 +1,89 @@ +use serde::{Deserialize, Serialize}; + +/// Enum representing the different LLM providers supported +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub enum ProviderType { + Anthropic, + OpenAI, + Mistral, + Google, + Ollama, + LmStudio, +} + +/// Unified request structure used across all providers +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct LlmRequest { + pub messages: Vec, + pub model: Option, + pub max_tokens: Option, + pub temperature: Option, +} + +/// Standard message format used across providers +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Message { + pub role: String, + pub content: String, +} + +/// Unified response structure returned by all providers +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct LlmResponse { + pub content: String, + pub model: String, + pub usage: Option, +} + +/// Token usage information returned by providers +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct TokenUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +impl Default for TokenUsage { + fn default() -> Self { + Self { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + } + } +} + +/// Information about an LLM model +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ModelInfo { + pub name: String, + pub provider: ProviderType, +} + +/// Display implementation for ProviderType +impl std::fmt::Display for ProviderType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ProviderType::Anthropic => write!(f, "Anthropic"), + ProviderType::OpenAI => write!(f, "OpenAI"), + ProviderType::Mistral => write!(f, "Mistral"), + ProviderType::Google => write!(f, "Google"), + ProviderType::Ollama => write!(f, "Ollama"), + ProviderType::LmStudio => write!(f, "LmStudio"), + } + } +} + +impl From<&str> for ProviderType { + fn from(value: &str) -> Self { + match value { + "Anthropic" => ProviderType::Anthropic, + "OpenAI" => ProviderType::OpenAI, + "Mistral" => ProviderType::Mistral, + "Google" => ProviderType::Google, + "Ollama" => ProviderType::Ollama, + "LmStudio" => ProviderType::LmStudio, + _ => panic!("Unknown provider: {}", value), + } + } +} diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index ff674db..1b9893e 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1916,8 +1916,6 @@ dependencies = [ [[package]] name = "flyllm" version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85d53872ad1ff95a3edb4f2c0504eb27f506309562403106b39aa483bb99e8a5" dependencies = [ "async-trait", "env_logger", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index fe0d87c..d6c0e68 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -57,3 +57,5 @@ regex = "1.11.2" tauri-plugin-shell = "2" chromiumoxide = { version = "0.7", features = ["tokio-runtime", "_fetcher-rusttls-tokio"], default-features = false } base64 = "0.22.1" +[patch.crates-io] +flyllm = { path = "../crates/flyllm" } diff --git a/src-tauri/src/commands/api.rs b/src-tauri/src/commands/api.rs index 3d66506..9b9c21c 100644 --- a/src-tauri/src/commands/api.rs +++ b/src-tauri/src/commands/api.rs @@ -2,7 +2,8 @@ use crate::commands::structure::responses; use crate::config::api_config::LlmUserConfig; use crate::config::app_settings::AgentSettings; use crate::config::keystore::ApiKeystore; -use crate::config::utils::{internal_str_to_provider_type, str_to_agent_type}; +use crate::config::utils::str_to_agent_type; +use crate::constants; use crate::services::llm_service::agents::AgentType; use crate::state::AppState; use std::collections::HashMap; @@ -155,6 +156,7 @@ pub fn get_provider_config(provider: String) -> Result Result, String> { } #[tauri::command] -pub async fn set_ollama_endpoint( +pub async fn set_provider_custom_endpoint( state: State<'_, AppState>, - endpoint_url: String, + provider: String, + endpoint_url: Option, ) -> Result<(), String> { + let normalized_endpoint = endpoint_url.and_then(|value| { + let trimmed = value.trim().to_string(); + if trimmed.is_empty() { + None + } else { + Some(trimmed) + } + }); + let mut config = LlmUserConfig::load().map_err(|e| e.to_string())?; - config.set_ollama_endpoint(Some(endpoint_url)).map_err(|e| e.to_string())?; + config + .set_provider_endpoint(&provider, normalized_endpoint) + .map_err(|e| e.to_string())?; + state .rebuild_llm_service() .await @@ -286,10 +301,25 @@ pub async fn set_ollama_endpoint( } #[tauri::command] -pub async fn get_ollama_endpoint() -> Result { +pub fn get_provider_custom_endpoint(provider: String) -> Result, String> { let config = LlmUserConfig::load().map_err(|e| e.to_string())?; config - .get_ollama_endpoint() - .map_err(|e| e.to_string())? + .get_provider_endpoint(&provider) + .map_err(|e| e.to_string()) +} + +#[tauri::command] +pub async fn set_ollama_endpoint( + state: State<'_, AppState>, + endpoint_url: String, +) -> Result<(), String> { + set_provider_custom_endpoint(state, "Ollama".to_string(), Some(endpoint_url)).await +} + +#[tauri::command] +pub async fn get_ollama_endpoint() -> Result { + let endpoint = get_provider_custom_endpoint("Ollama".to_string()).map_err(|e| e.to_string())?; + endpoint + .or_else(|| Some(constants::OLLAMA_CUSTOM_ENDPOINT.to_string())) .ok_or_else(|| "No custom Ollama endpoint configured".to_string()) } diff --git a/src-tauri/src/commands/structure/responses.rs b/src-tauri/src/commands/structure/responses.rs index 2b1c4a8..24d3f97 100644 --- a/src-tauri/src/commands/structure/responses.rs +++ b/src-tauri/src/commands/structure/responses.rs @@ -45,6 +45,7 @@ pub struct ProviderConfigResponse { pub enabled: bool, pub available_models: Vec, pub is_configured: bool, + pub endpoint_url: Option, } #[derive(Serialize)] diff --git a/src-tauri/src/config/api_config.rs b/src-tauri/src/config/api_config.rs index 257e66e..62d96a5 100644 --- a/src-tauri/src/config/api_config.rs +++ b/src-tauri/src/config/api_config.rs @@ -1,10 +1,10 @@ use crate::config::keystore::ApiKeystore; use crate::config::utils::internal_str_to_provider_type; use crate::config::APP_PATHS; +use crate::constants; use crate::constants::{ALL_PROVIDER_TYPES, API_SETTINGS_FILE_NAME}; use crate::errors::{AppError, AppResult}; use flyllm::{ModelDiscovery, ProviderType}; -use crate::constants; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; @@ -239,7 +239,9 @@ impl LlmUserConfig { internal_str_to_provider_type(provider_str).map_err(|e| AppError::ConfigError(e))?; if let Some(config) = self.provider_configs.get(&provider_type) { - let has_key = if provider_type == ProviderType::Ollama { + let has_key = if provider_type == ProviderType::Ollama + || provider_type == ProviderType::LmStudio + { true } else { let keystore = ApiKeystore::new(); @@ -282,43 +284,65 @@ impl LlmUserConfig { .get(&provider_type) .ok_or_else(|| AppError::ApiError(format!("Provider '{}' not found", provider_str)))?; - let models = if provider_type == ProviderType::Ollama { - // Determine the endpoint to use - let endpoint_to_use = config.endpoint_url - .clone() - .or_else(|| self.get_ollama_endpoint().ok().flatten()) - .unwrap_or_else(|| constants::OLLAMA_CUSTOM_ENDPOINT.to_string()); - - ModelDiscovery::list_ollama_models(Some(&endpoint_to_use)) - .await - .map_err(|e| AppError::ApiError(format!("Failed to fetch Ollama models: {}", e)))? - .into_iter() - .map(|m| m.name) - .collect() - } else { - let keystore = ApiKeystore::new(); - let api_key = keystore.get_api_key(provider_str)?.ok_or_else(|| { - AppError::ApiError(format!("API key required for provider '{}'", provider_str)) - })?; - - let discovered_models = match provider_type { - ProviderType::OpenAI => ModelDiscovery::list_openai_models(&api_key).await, - ProviderType::Anthropic => ModelDiscovery::list_anthropic_models(&api_key).await, - ProviderType::Mistral => ModelDiscovery::list_mistral_models(&api_key).await, - ProviderType::Google => ModelDiscovery::list_google_models(&api_key).await, - _ => { - return Err(AppError::ApiError(format!( - "Model discovery not supported for provider '{}'", - provider_str - ))); - } - }; + let models = match provider_type { + ProviderType::Ollama => { + let endpoint_to_use = config + .endpoint_url + .clone() + .or_else(|| self.get_ollama_endpoint().ok().flatten()) + .unwrap_or_else(|| constants::OLLAMA_CUSTOM_ENDPOINT.to_string()); + + ModelDiscovery::list_ollama_models(Some(&endpoint_to_use)) + .await + .map_err(|e| { + AppError::ApiError(format!("Failed to fetch Ollama models: {}", e)) + })? + .into_iter() + .map(|m| m.name) + .collect() + } + ProviderType::LmStudio => { + let endpoint_to_use = config + .endpoint_url + .clone() + .unwrap_or_else(|| constants::LM_STUDIO_DEFAULT_ENDPOINT.to_string()); + + ModelDiscovery::list_lmstudio_models(Some(&endpoint_to_use)) + .await + .map_err(|e| { + AppError::ApiError(format!("Failed to fetch LM Studio models: {}", e)) + })? + .into_iter() + .map(|m| m.name) + .collect() + } + _ => { + let keystore = ApiKeystore::new(); + let api_key = keystore.get_api_key(provider_str)?.ok_or_else(|| { + AppError::ApiError(format!("API key required for provider '{}'", provider_str)) + })?; + + let discovered_models = match provider_type { + ProviderType::OpenAI => ModelDiscovery::list_openai_models(&api_key).await, + ProviderType::Anthropic => { + ModelDiscovery::list_anthropic_models(&api_key).await + } + ProviderType::Mistral => ModelDiscovery::list_mistral_models(&api_key).await, + ProviderType::Google => ModelDiscovery::list_google_models(&api_key).await, + _ => { + return Err(AppError::ApiError(format!( + "Model discovery not supported for provider '{}'", + provider_str + ))); + } + }; - discovered_models - .map_err(|e| AppError::ApiError(format!("Failed to fetch models: {}", e)))? - .into_iter() - .map(|m| m.name) - .collect() + discovered_models + .map_err(|e| AppError::ApiError(format!("Failed to fetch models: {}", e)))? + .into_iter() + .map(|m| m.name) + .collect() + } }; self.set_available_models(provider_str, models) @@ -369,7 +393,7 @@ impl LlmUserConfig { pub fn validate_provider_setup(&self, provider: String) -> Vec { let mut issues = Vec::new(); if let Ok(provider_config) = self.get_provider_config(&provider) { - if provider != "Ollama" { + if provider != "Ollama" && provider != "LmStudio" { let keystore = ApiKeystore::new(); match keystore.get_api_key(&provider) { Ok(Some(_)) => {} @@ -392,18 +416,37 @@ impl LlmUserConfig { issues } - pub fn set_ollama_endpoint(&mut self, endpoint_url: Option) -> AppResult<()> { - let provider_type = ProviderType::Ollama; - - if let Some(config) = self.provider_configs.get_mut(&provider_type) { - config.endpoint_url = endpoint_url.clone(); - } - + pub fn set_provider_endpoint( + &mut self, + provider_str: &str, + endpoint_url: Option, + ) -> AppResult<()> { + let provider_type = + internal_str_to_provider_type(provider_str).map_err(|e| AppError::ConfigError(e))?; + + let config = self + .provider_configs + .get_mut(&provider_type) + .ok_or_else(|| AppError::ApiError(format!("Provider '{}' not found", provider_str)))?; + + config.endpoint_url = endpoint_url; self.save() } + pub fn get_provider_endpoint(&self, provider_str: &str) -> AppResult> { + self.get_endpoint_url(provider_str) + } + + pub fn set_ollama_endpoint(&mut self, endpoint_url: Option) -> AppResult<()> { + self.set_provider_endpoint("Ollama", endpoint_url) + } + pub fn get_ollama_endpoint(&self) -> AppResult> { - self.get_endpoint_url("Ollama") + self.get_provider_endpoint("Ollama") + } + + pub fn get_lmstudio_endpoint(&self) -> AppResult> { + self.get_provider_endpoint("LmStudio") } fn get_endpoint_url(&self, provider_str: &str) -> AppResult> { @@ -416,3 +459,27 @@ impl LlmUserConfig { .and_then(|config| config.endpoint_url.clone())) } } + +#[cfg(test)] +mod tests { + use super::*; + use flyllm::ProviderType; + + #[test] + fn default_config_includes_lmstudio() { + let config = LlmUserConfig::default(); + assert!(config + .provider_configs + .contains_key(&ProviderType::LmStudio)); + } + + #[test] + fn set_lmstudio_endpoint_persists_value() { + let mut config = LlmUserConfig::default(); + let endpoint = Some("http://localhost:1234".to_string()); + config + .set_provider_endpoint("LmStudio", endpoint.clone()) + .expect("should set endpoint"); + assert_eq!(config.get_provider_endpoint("LmStudio").unwrap(), endpoint); + } +} diff --git a/src-tauri/src/config/utils.rs b/src-tauri/src/config/utils.rs index e0a106a..aaf988a 100644 --- a/src-tauri/src/config/utils.rs +++ b/src-tauri/src/config/utils.rs @@ -7,6 +7,7 @@ pub fn internal_str_to_provider_type(provider_str: &str) -> Result Ok(flyllm::ProviderType::Mistral), "google" => Ok(flyllm::ProviderType::Google), "ollama" => Ok(flyllm::ProviderType::Ollama), + "lmstudio" => Ok(flyllm::ProviderType::LmStudio), _ => Err(format!("Unknown provider string: {}", provider_str)), } } @@ -21,3 +22,15 @@ pub fn str_to_agent_type(agent_str: &str) -> Result { _ => Err(format!("Unknown agent type: {}", agent_str)), } } + +#[cfg(test)] +mod tests { + use super::internal_str_to_provider_type; + use flyllm::ProviderType; + + #[test] + fn maps_lmstudio_provider_string() { + let provider = internal_str_to_provider_type("lmstudio").expect("should map"); + assert_eq!(provider, ProviderType::LmStudio); + } +} diff --git a/src-tauri/src/constants.rs b/src-tauri/src/constants.rs index 1cf9ec8..0c979af 100644 --- a/src-tauri/src/constants.rs +++ b/src-tauri/src/constants.rs @@ -58,14 +58,16 @@ pub const CONTENT_SELECTORS: [&'static str; 7] = [ "body", ]; // LLMs -pub const ALL_PROVIDER_TYPES: [ProviderType; 5] = [ +pub const ALL_PROVIDER_TYPES: [ProviderType; 6] = [ ProviderType::OpenAI, ProviderType::Anthropic, ProviderType::Mistral, ProviderType::Google, ProviderType::Ollama, + ProviderType::LmStudio, ]; pub const OLLAMA_CUSTOM_ENDPOINT: &str = "http://localhost:11434"; +pub const LM_STUDIO_DEFAULT_ENDPOINT: &str = "http://127.0.0.1:1234"; // Prompts pub const CONCEPT_EXTRACTOR_PROMPT: &str = @@ -87,6 +89,7 @@ pub fn get_default_agent_model(agent: &AgentType, provider: &ProviderType) -> &' (AgentType::ConceptExtractor, ProviderType::Google) => "gemini-2.5-flash", (AgentType::ConceptExtractor, ProviderType::Mistral) => "mistral-large-latest", (AgentType::ConceptExtractor, ProviderType::Ollama) => "llama3", + (AgentType::ConceptExtractor, ProviderType::LmStudio) => "openai/gpt-oss-20b", // FlashcardContentCreator (AgentType::FlashcardContentCreator, ProviderType::OpenAI) => "gpt-4.1-mini-2025-04-14", @@ -94,6 +97,7 @@ pub fn get_default_agent_model(agent: &AgentType, provider: &ProviderType) -> &' (AgentType::FlashcardContentCreator, ProviderType::Google) => "gemini-2.5-flash", (AgentType::FlashcardContentCreator, ProviderType::Mistral) => "mistral-large-latest", (AgentType::FlashcardContentCreator, ProviderType::Ollama) => "llama3", + (AgentType::FlashcardContentCreator, ProviderType::LmStudio) => "openai/gpt-oss-20b", // TestContentCreator (AgentType::TestContentCreator, ProviderType::OpenAI) => "gpt-4.1-mini-2025-04-14", @@ -101,6 +105,7 @@ pub fn get_default_agent_model(agent: &AgentType, provider: &ProviderType) -> &' (AgentType::TestContentCreator, ProviderType::Google) => "gemini-2.5-flash", (AgentType::TestContentCreator, ProviderType::Mistral) => "mistral-large-latest", (AgentType::TestContentCreator, ProviderType::Ollama) => "llama3", + (AgentType::TestContentCreator, ProviderType::LmStudio) => "openai/gpt-oss-20b", // ExplanationAgent (AgentType::ExplanationAgent, ProviderType::OpenAI) => "gpt-4.1-mini-2025-04-14", @@ -108,6 +113,7 @@ pub fn get_default_agent_model(agent: &AgentType, provider: &ProviderType) -> &' (AgentType::ExplanationAgent, ProviderType::Google) => "gemini-2.5-flash", (AgentType::ExplanationAgent, ProviderType::Mistral) => "mistral-large-latest", (AgentType::ExplanationAgent, ProviderType::Ollama) => "llama3", + (AgentType::ExplanationAgent, ProviderType::LmStudio) => "openai/gpt-oss-20b", // SearchAgent (AgentType::SearchAgent, ProviderType::OpenAI) => "gpt-4.1-mini-2025-04-14", @@ -115,6 +121,6 @@ pub fn get_default_agent_model(agent: &AgentType, provider: &ProviderType) -> &' (AgentType::SearchAgent, ProviderType::Google) => "gemini-2.5-flash", (AgentType::SearchAgent, ProviderType::Mistral) => "mistral-large-latest", (AgentType::SearchAgent, ProviderType::Ollama) => "llama3", - + (AgentType::SearchAgent, ProviderType::LmStudio) => "openai/gpt-oss-20b", } } diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index e28ed2f..37e6d05 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -111,6 +111,8 @@ async fn main() { api::is_provider_configured, api::validate_provider_setup, api::get_all_agent_model_configs, + api::set_provider_custom_endpoint, + api::get_provider_custom_endpoint, api::set_ollama_endpoint, api::get_ollama_endpoint, // Settings commands diff --git a/src-tauri/src/services/llm_service/manager.rs b/src-tauri/src/services/llm_service/manager.rs index 4f311e4..9780a01 100644 --- a/src-tauri/src/services/llm_service/manager.rs +++ b/src-tauri/src/services/llm_service/manager.rs @@ -1268,6 +1268,12 @@ async fn build_llm_manager_from_config( .flatten() .unwrap_or_else(|| constants::OLLAMA_CUSTOM_ENDPOINT.to_string()); + let lmstudio_endpoint = config + .get_lmstudio_endpoint() + .ok() + .flatten() + .unwrap_or_else(|| constants::LM_STUDIO_DEFAULT_ENDPOINT.to_string()); + // For each enabled provider, create instances for each agent type that has a model configured for (provider_type, provider_config) in &config.provider_configs { if !provider_config.enabled { @@ -1275,28 +1281,29 @@ async fn build_llm_manager_from_config( } // Get API key from keystore - let api_key = if *provider_type == ProviderType::Ollama { - String::new() - } else { - let provider_str = format!("{:?}", provider_type); - match keystore.get_api_key(&provider_str) { - Ok(Some(key)) => key, - Ok(None) => { - eprintln!( - "Skipping provider {:?}: API key required but not found in keystore", - provider_type - ); - continue; - } - Err(e) => { - eprintln!( - "Skipping provider {:?}: Failed to retrieve API key from keystore: {}", - provider_type, e - ); - continue; + let api_key = + if *provider_type == ProviderType::Ollama || *provider_type == ProviderType::LmStudio { + String::new() + } else { + let provider_str = format!("{:?}", provider_type); + match keystore.get_api_key(&provider_str) { + Ok(Some(key)) => key, + Ok(None) => { + eprintln!( + "Skipping provider {:?}: API key required but not found in keystore", + provider_type + ); + continue; + } + Err(e) => { + eprintln!( + "Skipping provider {:?}: Failed to retrieve API key from keystore: {}", + provider_type, e + ); + continue; + } } - } - }; + }; // Create a separate instance for each agent type that has a model configured for this provider for agent_type in AgentType::all() { @@ -1314,6 +1321,10 @@ async fn build_llm_manager_from_config( .add_instance(*provider_type, model_name, "") .supports(&task_name) .custom_endpoint(&ollama_endpoint), + ProviderType::LmStudio => builder + .add_instance(*provider_type, model_name, "") + .supports(&task_name) + .custom_endpoint(&lmstudio_endpoint), _ => builder .add_instance(*provider_type, model_name, &api_key) .supports(&task_name), @@ -1345,17 +1356,17 @@ async fn build_llm_manager_from_config( fn clean_json_content(content: &str) -> String { let trimmed = content.trim(); - let without_prefix = trimmed.strip_prefix("```json") - .or_else(|| trimmed.strip_prefix("```")) - .unwrap_or(trimmed); - - let without_suffix = without_prefix.strip_suffix("```") - .unwrap_or(without_prefix); - + let without_prefix = trimmed + .strip_prefix("```json") + .or_else(|| trimmed.strip_prefix("```")) + .unwrap_or(trimmed); + + let without_suffix = without_prefix.strip_suffix("```").unwrap_or(without_prefix); + let cleaned = without_suffix - .replace("**", "") // Remove bold markers - .replace("*****", "") // Remove emphasis markers - .replace("*", ""); // Remove remaining asterisks + .replace("**", "") // Remove bold markers + .replace("*****", "") // Remove emphasis markers + .replace("*", ""); // Remove remaining asterisks sanitize_json_strings(cleaned.trim()) } diff --git a/src/lib/components/Settings/LLMSettings.svelte b/src/lib/components/Settings/LLMSettings.svelte index b74400b..c76aa4d 100644 --- a/src/lib/components/Settings/LLMSettings.svelte +++ b/src/lib/components/Settings/LLMSettings.svelte @@ -8,7 +8,7 @@ import { llmStore } from '../../stores/llmStore'; // Types - import {type ProviderConfig, type ValidationResult} from '../../logic/Settings/types' + import { type ProviderConfig, type ValidationResult } from '../../logic/Settings/types'; // Props export let handleError = (error) => {}; @@ -18,30 +18,37 @@ let providers = []; let providerConfigs = {}; let validatingKeys = {}; - let ollamaEndpoint = ''; - let loadingOllamaEndpoint = false; + let providerEndpoints = {}; + let endpointLoading = {}; - const DEFAULT_OLLAMA_ENDPOINT = 'http://localhost:11434'; + const DEFAULT_ENDPOINTS: Record = { + ollama: 'http://localhost:11434', + lmstudio: 'http://127.0.0.1:1234' + }; - // Functions async function loadProviders() { try { const rawProviders: [] = await invoke('get_providers'); providers = rawProviders.sort(); - + providerConfigs = {}; validatingKeys = {}; - + providerEndpoints = {}; + endpointLoading = {}; + for (const provider of providers) { try { const config: ProviderConfig = await invoke('get_provider_config', { provider }); + const endpointProvider = hasCustomEndpoint(provider); + providerConfigs[provider] = { provider_id: provider, api_key: config.api_key || '', enabled: config.enabled || false, is_configured: config.is_configured || false, - keyValidated: config.api_key ? true : false + keyValidated: endpointProvider ? !!config.is_configured : !!config.api_key, }; + providerEndpoints[provider] = config.endpoint_url || ''; } catch (error) { console.error(`Failed to load config for ${provider}:`, error); providerConfigs[provider] = { @@ -49,13 +56,17 @@ api_key: '', enabled: false, is_configured: false, - keyValidated: false + keyValidated: false, }; + providerEndpoints[provider] = ''; } validatingKeys[provider] = false; + endpointLoading[provider] = false; } - + providerConfigs = { ...providerConfigs }; + providerEndpoints = { ...providerEndpoints }; + endpointLoading = { ...endpointLoading }; clearError(); } catch (error) { console.error('Failed to load providers:', error); @@ -63,69 +74,96 @@ } } - async function loadOllamaEndpoint() { - try { - const endpoint = await invoke('get_ollama_endpoint'); - ollamaEndpoint = endpoint; - } catch (error) { - ollamaEndpoint = ''; - } + function normalizeEndpointInput(value: string | undefined) { + if (!value) return null; + const trimmed = value.trim(); + return trimmed.length === 0 ? null : trimmed; } - async function saveOllamaEndpoint() { + async function saveProviderEndpoint(provider: string) { try { - loadingOllamaEndpoint = true; - let endpointToSave = ollamaEndpoint.trim() || DEFAULT_OLLAMA_ENDPOINT; - + endpointLoading[provider] = true; + endpointLoading = { ...endpointLoading }; + const endpointToSave = normalizeEndpointInput(providerEndpoints[provider]); + if (endpointToSave && !endpointToSave.startsWith('http://') && !endpointToSave.startsWith('https://')) { alert('Endpoint must start with http:// or https://'); return; } - - await invoke('set_ollama_endpoint', { endpointUrl: endpointToSave }); - + + await invoke('set_provider_custom_endpoint', { + provider, + endpointUrl: endpointToSave ?? null, + }); + + if (providerConfigs[provider]) { + providerConfigs[provider].keyValidated = false; + providerConfigs = { ...providerConfigs }; + } + clearError(); - alert('Ollama endpoint saved successfully'); + alert(`${provider} endpoint saved successfully`); } catch (error) { - console.error('Failed to save Ollama endpoint:', error); - handleError(`Failed to save Ollama endpoint: ${error.message || error}`); + console.error(`Failed to save ${provider} endpoint:`, error); + handleError(`Failed to save ${provider} endpoint: ${error.message || error}`); } finally { - loadingOllamaEndpoint = false; + endpointLoading[provider] = false; + endpointLoading = { ...endpointLoading }; } } + function hasCustomEndpoint(provider: string) { + const name = provider.toLowerCase(); + return name === 'ollama' || name === 'lmstudio'; + } + + function getDefaultEndpoint(provider: string) { + const name = provider.toLowerCase(); + return DEFAULT_ENDPOINTS[name] ?? ''; + } + + function getEndpointHelp(provider: string) { + const name = provider.toLowerCase(); + if (name === 'ollama') { + return `Leave empty to use default (${DEFAULT_ENDPOINTS.ollama})`; + } + if (name === 'lmstudio') { + return `Leave empty to use default (${DEFAULT_ENDPOINTS.lmstudio})`; + } + return ''; + } + async function validateApiKey(provider) { const config = providerConfigs[provider]; - - if (!config.api_key.trim() && provider.toLowerCase() !== 'ollama') { + + if (!config.api_key.trim() && !hasCustomEndpoint(provider)) { alert('Please enter an API key first'); return; } - + try { validatingKeys[provider] = true; validatingKeys = { ...validatingKeys }; - + const result: ValidationResult = await invoke('validate_api_key_and_fetch_models', { provider, - apiKey: provider.toLowerCase() === 'ollama' ? '' : config.api_key + apiKey: hasCustomEndpoint(provider) ? '' : config.api_key, }); - + if (result.valid) { config.keyValidated = true; providerConfigs = { ...providerConfigs }; - llmStore.refresh(); } else { config.keyValidated = false; providerConfigs = { ...providerConfigs }; - alert(`${provider === 'ollama' ? 'Connection failed' : 'API validation failed'}: ${result.error_message || 'Unknown error'}`); + alert(`${hasCustomEndpoint(provider) ? 'Connection failed' : 'API validation failed'}: ${result.error_message || 'Unknown error'}`); } } catch (error) { console.error(`Failed to validate API key for ${provider}:`, error); config.keyValidated = false; providerConfigs = { ...providerConfigs }; - alert(`Failed to ${provider === 'ollama' ? 'connect' : 'validate API key'}: ${error.message || error}`); + alert(`Failed to ${hasCustomEndpoint(provider) ? 'connect' : 'validate API key'}: ${error.message || error}`); } finally { validatingKeys[provider] = false; validatingKeys = { ...validatingKeys }; @@ -135,7 +173,7 @@ async function updateApiKey(provider, apiKey) { const config = providerConfigs[provider]; config.api_key = apiKey; - + if (!apiKey.trim()) { config.keyValidated = false; config.enabled = false; @@ -150,7 +188,7 @@ return; } } - + providerConfigs = { ...providerConfigs }; } @@ -178,7 +216,8 @@ function isProviderConfigured(provider) { const config = providerConfigs[provider]; - if (provider.toLowerCase() === 'ollama') { + if (!config) return false; + if (hasCustomEndpoint(provider)) { return config.enabled; } return config.keyValidated && config.enabled; @@ -186,7 +225,7 @@ function canEnableProvider(provider) { const config = providerConfigs[provider]; - return config.keyValidated; + return config && config.keyValidated; } function getProviderApiUrl(provider) { @@ -221,7 +260,6 @@ onMount(async () => { await loadProviders(); - await loadOllamaEndpoint(); }); @@ -276,23 +314,32 @@ onMount(async () => { - {#if provider.toLowerCase() === 'ollama'} + {#if hasCustomEndpoint(provider)}

Custom Endpoint (Optional)

{ + const target = e.target as HTMLInputElement; + providerEndpoints[provider] = target.value; + providerEndpoints = { ...providerEndpoints }; + if (providerConfigs[provider]) { + providerConfigs[provider].keyValidated = false; + providerConfigs = { ...providerConfigs }; + } + }} + placeholder={getDefaultEndpoint(provider)} class="config-input" />
-

Leave empty to use default (http://localhost:11434)

+

{getEndpointHelp(provider)}

{:else}
@@ -362,7 +409,7 @@ onMount(async () => { {/if}
- {#if provider.toLowerCase() !== 'ollama'} + {#if !hasCustomEndpoint(provider)} {@const apiUrl = getProviderApiUrl(provider)} {@const displayName = getProviderDisplayName(provider)} {#if apiUrl} @@ -370,38 +417,37 @@ onMount(async () => { {:else}

Get your API key from the {displayName}

{/if} - {:else} -
-
- +
+ + {#if config.keyValidated} +
+ + + - Testing Connection... - {:else} - - Test Connection - {/if} - + Connection successful +
+ {/if}
- - {#if config.keyValidated} -
- - - - - Connection successful -
- {/if} -
{/if}
diff --git a/src/lib/logic/Settings/types.ts b/src/lib/logic/Settings/types.ts index 99a1e4f..e58d688 100644 --- a/src/lib/logic/Settings/types.ts +++ b/src/lib/logic/Settings/types.ts @@ -4,6 +4,7 @@ enabled?: boolean; available_models?: string[]; is_configured?: boolean; + endpoint_url?: string | null; } export interface ValidationResult { @@ -14,4 +15,4 @@ export interface AgentModelConfig { agent_type?: string; provider_models?: Record; - } \ No newline at end of file + }