diff --git a/openless-all/app/src-tauri/src/commands.rs b/openless-all/app/src-tauri/src/commands.rs index ddbff79a..0d37f1e7 100644 --- a/openless-all/app/src-tauri/src/commands.rs +++ b/openless-all/app/src-tauri/src/commands.rs @@ -165,9 +165,17 @@ struct ProviderConfig { } fn read_openai_provider_config(kind: &str) -> Result { - let (api_key_account, endpoint_account) = match kind { - "llm" => (CredentialAccount::ArkApiKey, CredentialAccount::ArkEndpoint), - "asr" => (CredentialAccount::AsrApiKey, CredentialAccount::AsrEndpoint), + let (api_key_account, endpoint_account, api_key_required) = match kind { + "llm" => ( + CredentialAccount::ArkApiKey, + CredentialAccount::ArkEndpoint, + false, + ), + "asr" => ( + CredentialAccount::AsrApiKey, + CredentialAccount::AsrEndpoint, + true, + ), _ => return Err(format!("unknown provider kind: {kind}")), }; let api_key = CredentialsVault::get(api_key_account) @@ -176,7 +184,7 @@ fn read_openai_provider_config(kind: &str) -> Result { let base_url = CredentialsVault::get(endpoint_account) .map_err(|e| e.to_string())? .unwrap_or_default(); - if api_key.trim().is_empty() { + if api_key_required && api_key.trim().is_empty() { return Err("API Key 为空".to_string()); } if base_url.trim().is_empty() { @@ -217,18 +225,17 @@ async fn fetch_provider_models(config: &ProviderConfig) -> Result, S .timeout(Duration::from_secs(15)) .build() .map_err(|e| format!("HTTP client 初始化失败: {e}"))?; - let response = client - .get(&url) - .header("Authorization", format!("Bearer {}", config.api_key)) - .send() - .await - .map_err(|e| { - if e.is_timeout() { - "请求超时".to_string() - } else { - format!("网络错误: {e}") - } - })?; + let mut request = client.get(&url); + if !config.api_key.trim().is_empty() { + request = request.header("Authorization", format!("Bearer {}", config.api_key)); + } + let response = request.send().await.map_err(|e| { + if e.is_timeout() { + "请求超时".to_string() + } else { + format!("网络错误: {e}") + } + })?; let status = response.status(); let body = response .text() @@ -550,11 +557,17 @@ fn _ensure_snapshot_used(_: CredentialsSnapshot) {} #[cfg(test)] mod tests { - use super::{models_url, parse_model_ids, persist_settings, SettingsWriter}; + use super::{ + fetch_provider_models, models_url, parse_model_ids, persist_settings, ProviderConfig, + SettingsWriter, + }; use crate::types::{ HotkeyBinding, HotkeyMode, HotkeyTrigger, QaHotkeyBinding, UserPreferences, }; + use std::io::{Read, Write}; + use std::net::TcpListener; use std::sync::Mutex; + use std::thread; #[derive(Default)] struct FakeSettingsWriter { @@ -630,4 +643,46 @@ mod tests { assert_eq!(*writer.dictation_refreshes.lock().unwrap(), 1); assert_eq!(*writer.qa_refreshes.lock().unwrap(), 1); } + + #[tokio::test] + async fn fetch_provider_models_omits_authorization_when_api_key_is_empty() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let server = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + let mut buf = [0u8; 8192]; + let mut request = Vec::new(); + loop { + let n = stream.read(&mut buf).unwrap(); + if n == 0 { + break; + } + request.extend_from_slice(&buf[..n]); + if request.windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let request_text = String::from_utf8_lossy(&request); + assert!(!request_text.contains("Authorization: Bearer")); + + let body = r#"{"data":[{"id":"m1"},{"id":"m2"}]}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).unwrap(); + }); + + let models = fetch_provider_models(&ProviderConfig { + base_url: format!("http://{}", addr), + api_key: String::new(), + }) + .await + .unwrap(); + + assert_eq!(models, vec!["m1".to_string(), "m2".to_string()]); + server.join().unwrap(); + } } diff --git a/openless-all/app/src-tauri/src/coordinator.rs b/openless-all/app/src-tauri/src/coordinator.rs index 42eea2e8..c588ec4d 100644 --- a/openless-all/app/src-tauri/src/coordinator.rs +++ b/openless-all/app/src-tauri/src/coordinator.rs @@ -1589,15 +1589,10 @@ async fn polish_text( front_app: Option<&str>, ) -> anyhow::Result { let api_key = CredentialsVault::get(CredentialAccount::ArkApiKey)?.unwrap_or_default(); - if api_key.is_empty() { - anyhow::bail!("ark api key missing"); - } let model = CredentialsVault::get(CredentialAccount::ArkModelId)? .filter(|s| !s.is_empty()) .unwrap_or_else(|| "deepseek-v3-2".to_string()); - let endpoint = CredentialsVault::get(CredentialAccount::ArkEndpoint)? - .filter(|s| !s.is_empty()) - .unwrap_or_else(|| "https://ark.cn-beijing.volces.com/api/v3/chat/completions".to_string()); + let endpoint = resolve_ark_endpoint(&api_key)?; let base_url = endpoint .trim_end_matches("/chat/completions") .trim_end_matches('/') @@ -1634,15 +1629,10 @@ async fn translate_text( front_app: Option<&str>, ) -> anyhow::Result { let api_key = CredentialsVault::get(CredentialAccount::ArkApiKey)?.unwrap_or_default(); - if api_key.is_empty() { - anyhow::bail!("ark api key missing"); - } let model = CredentialsVault::get(CredentialAccount::ArkModelId)? .filter(|s| !s.is_empty()) .unwrap_or_else(|| "deepseek-v3-2".to_string()); - let endpoint = CredentialsVault::get(CredentialAccount::ArkEndpoint)? - .filter(|s| !s.is_empty()) - .unwrap_or_else(|| "https://ark.cn-beijing.volces.com/api/v3/chat/completions".to_string()); + let endpoint = resolve_ark_endpoint(&api_key)?; let base_url = endpoint .trim_end_matches("/chat/completions") .trim_end_matches('/') @@ -2154,15 +2144,10 @@ where C: Fn() -> bool + Send + Sync, { let api_key = CredentialsVault::get(CredentialAccount::ArkApiKey)?.unwrap_or_default(); - if api_key.is_empty() { - anyhow::bail!("ark api key missing"); - } let model = CredentialsVault::get(CredentialAccount::ArkModelId)? .filter(|s| !s.is_empty()) .unwrap_or_else(|| "deepseek-v3-2".to_string()); - let endpoint = CredentialsVault::get(CredentialAccount::ArkEndpoint)? - .filter(|s| !s.is_empty()) - .unwrap_or_else(|| "https://ark.cn-beijing.volces.com/api/v3/chat/completions".to_string()); + let endpoint = resolve_ark_endpoint(&api_key)?; let base_url = endpoint .trim_end_matches("/chat/completions") .trim_end_matches('/') @@ -2170,10 +2155,32 @@ where let config = OpenAICompatibleConfig::new("ark", "Doubao Ark", base_url, api_key, model); let provider = OpenAICompatibleLLMProvider::new(config); Ok(provider - .answer_chat_streaming(messages, working_languages, front_app, on_delta, should_cancel) + .answer_chat_streaming( + messages, + working_languages, + front_app, + on_delta, + should_cancel, + ) .await?) } +fn resolve_ark_endpoint(api_key: &str) -> anyhow::Result { + let endpoint = CredentialsVault::get(CredentialAccount::ArkEndpoint)?.filter(|s| !s.is_empty()); + resolve_ark_endpoint_with_policy(api_key, endpoint) +} + +fn resolve_ark_endpoint_with_policy( + api_key: &str, + endpoint: Option, +) -> anyhow::Result { + if api_key.trim().is_empty() && endpoint.is_none() { + anyhow::bail!("API Key 为空"); + } + Ok(endpoint + .unwrap_or_else(|| "https://ark.cn-beijing.volces.com/api/v3/chat/completions".to_string())) +} + #[cfg(test)] mod tests { use super::*; @@ -2226,6 +2233,26 @@ mod tests { assert!(!window_key_matches_trigger(HotkeyTrigger::Fn, "Fn", "Fn")); } + #[test] + fn resolve_ark_endpoint_rejects_blank_key_without_custom_endpoint() { + assert_eq!( + resolve_ark_endpoint_with_policy("", None) + .unwrap_err() + .to_string(), + "API Key 为空" + ); + } + + #[test] + fn resolve_ark_endpoint_allows_blank_key_with_custom_endpoint() { + let endpoint = resolve_ark_endpoint_with_policy( + "", + Some("https://example.com/v1/chat/completions".to_string()), + ) + .unwrap(); + assert_eq!(endpoint, "https://example.com/v1/chat/completions"); + } + #[test] fn deferred_asr_bridge_flushes_startup_audio_before_live_chunks() { #[derive(Default)] diff --git a/openless-all/app/src-tauri/src/polish.rs b/openless-all/app/src-tauri/src/polish.rs index 4388bf0d..f394eb4e 100644 --- a/openless-all/app/src-tauri/src/polish.rs +++ b/openless-all/app/src-tauri/src/polish.rs @@ -146,10 +146,6 @@ impl OpenAICompatibleLLMProvider { system_prompt: &str, user_prompt: &str, ) -> Result { - if self.config.api_key.trim().is_empty() { - return Err(LLMError::MissingCredentials); - } - let url = chat_completions_url(&self.config.base_url); let body = json!({ "model": self.config.model, @@ -171,8 +167,10 @@ impl OpenAICompatibleLLMProvider { let mut request = self .client .post(&url) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", self.config.api_key)); + .header("Content-Type", "application/json"); + if !self.config.api_key.trim().is_empty() { + request = request.header("Authorization", format!("Bearer {}", self.config.api_key)); + } for (k, v) in &self.config.extra_headers { request = request.header(k.as_str(), v.as_str()); } @@ -222,10 +220,6 @@ impl OpenAICompatibleLLMProvider { F: Fn(&str) + Send + Sync, C: Fn() -> bool + Send + Sync, { - if self.config.api_key.trim().is_empty() { - return Err(LLMError::MissingCredentials); - } - let mut msgs: Vec = Vec::with_capacity(history.len() + 1); msgs.push(json!({ "role": "system", "content": system_prompt })); for m in history { @@ -252,8 +246,10 @@ impl OpenAICompatibleLLMProvider { .client .post(&url) .header("Content-Type", "application/json") - .header("Accept", "text/event-stream") - .header("Authorization", format!("Bearer {}", self.config.api_key)); + .header("Accept", "text/event-stream"); + if !self.config.api_key.trim().is_empty() { + request = request.header("Authorization", format!("Bearer {}", self.config.api_key)); + } for (k, v) in &self.config.extra_headers { request = request.header(k.as_str(), v.as_str()); } @@ -310,7 +306,10 @@ impl OpenAICompatibleLLMProvider { let event = buffer[..idx].to_string(); buffer.drain(..idx + 2); for line in event.lines() { - let Some(payload) = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:")) else { + let Some(payload) = line + .strip_prefix("data: ") + .or_else(|| line.strip_prefix("data:")) + else { continue; }; let payload = payload.trim(); @@ -320,7 +319,10 @@ impl OpenAICompatibleLLMProvider { let v: Value = match serde_json::from_str(payload) { Ok(v) => v, Err(e) => { - log::warn!("[llm] SSE parse skip: {e}; payload preview: {}", safe_str_slice(payload, 80)); + log::warn!( + "[llm] SSE parse skip: {e}; payload preview: {}", + safe_str_slice(payload, 80) + ); continue; } }; @@ -382,9 +384,7 @@ fn context_premise(working_languages: &[String], front_app: Option<&str>) -> Opt .map(|s| s.trim()) .filter(|s| !s.is_empty()) .collect(); - let app = front_app - .map(str::trim) - .filter(|s| !s.is_empty()); + let app = front_app.map(str::trim).filter(|s| !s.is_empty()); if langs.is_empty() && app.is_none() { return None; @@ -832,6 +832,9 @@ pub mod prompts { #[cfg(test)] mod tests { use super::*; + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::thread; #[test] fn clean_polish_output_strips_think_tag_block() { @@ -900,11 +903,60 @@ mod tests { #[test] fn compose_system_prompt_prefers_correct_spelling_for_hotwords() { - let prompt = compose_system_prompt(PolishMode::Light, &["GitHub".into(), "OpenLess".into()]); + let prompt = + compose_system_prompt(PolishMode::Light, &["GitHub".into(), "OpenLess".into()]); assert!(prompt.contains("用户希望以下写法在输出中保持准确")); assert!(prompt.contains("同音 / 近形误识别时,优先按上述写法输出")); assert!(prompt.contains("- GitHub")); assert!(prompt.contains("- OpenLess")); } + + #[tokio::test] + async fn chat_completion_omits_authorization_when_api_key_is_empty() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let server = thread::spawn(move || { + let (mut stream, _) = listener.accept().unwrap(); + let mut buf = [0u8; 8192]; + let mut request = Vec::new(); + loop { + let n = stream.read(&mut buf).unwrap(); + if n == 0 { + break; + } + request.extend_from_slice(&buf[..n]); + if request.windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let request_text = String::from_utf8_lossy(&request); + assert!(!request_text.contains("Authorization: Bearer")); + + let body = r#"{"choices":[{"message":{"content":"最终文本。"}}]}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).unwrap(); + }); + + let provider = OpenAICompatibleLLMProvider::new(OpenAICompatibleConfig::new( + "ark", + "Doubao Ark", + format!("http://{}", addr), + "", + "deepseek-v3-2", + )); + + let output = provider + .polish("原文", PolishMode::Raw, &[], &[], None) + .await + .unwrap(); + assert_eq!(output, "最终文本。"); + + server.join().unwrap(); + } }