From 195d21878b726fbe91ed8470034f8c5bf09ea860 Mon Sep 17 00:00:00 2001 From: HuaizhengZhang Date: Tue, 15 Jul 2025 06:04:32 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=F0=9F=90=9B=20use=20qwen=201.5b=20as=20the?= =?UTF-8?q?=20vllm=20testing=20bed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mle/cli.py | 4 ++-- mle/model/vllm.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mle/cli.py b/mle/cli.py index f08b225..f664e06 100644 --- a/mle/cli.py +++ b/mle/cli.py @@ -307,8 +307,8 @@ def new(name): ).ask() or "http://localhost:8000/v1" model_name = questionary.text( - "What is the model name loaded in your vLLM server? (default: mistralai/Mistral-7B-Instruct-v0.3B)" - ).ask() or "mistralai/Mistral-7B-Instruct-v0.3" + "What is the model name loaded in your vLLM server? (default: Qwen/Qwen2.5-1.5B-Instruct)" + ).ask() or "Qwen/Qwen2.5-1.5B-Instruct" search_api_key = questionary.password("What is your Tavily API key? (if no, the web search will be disabled)").ask() if search_api_key: diff --git a/mle/model/vllm.py b/mle/model/vllm.py index 3038e04..95d0648 100644 --- a/mle/model/vllm.py +++ b/mle/model/vllm.py @@ -32,7 +32,7 @@ def __init__(self, base_url: Optional[str] = None, "pip install openai" ) - self.model = model if model else 'mistralai/Mistral-7B-Instruct-v0.3' + self.model = model if model else 'Qwen/Qwen2.5-1.5B-Instruct' self.model_type = 'vLLM' self.temperature = temperature self.client = self.openai( From 51d7aa59c2efbeba3b88cf26dc1ed108b2b2a9d1 Mon Sep 17 00:00:00 2001 From: HuaizhengZhang Date: Tue, 15 Jul 2025 06:33:46 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=F0=9F=9A=91=20engine=20and=20api=20platfor?= =?UTF-8?q?m=20more=20clear?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mle/cli.py | 124 ++++++++++++++++++++++++++++-------------- mle/model/__init__.py | 4 +- 2 files changed, 86 insertions(+), 42 deletions(-) diff --git a/mle/cli.py b/mle/cli.py index f664e06..8f56cbe 100644 --- a/mle/cli.py +++ b/mle/cli.py @@ -263,52 +263,94 @@ def new(name): console.log("Please provide a valid project name. Aborted.") return - platform = questionary.select( - "Which language model platform do you want to use?", - choices=['OpenAI', 'Ollama', 'Claude', 'Gemini', 'MistralAI', 'DeepSeek', 'vLLM'] + # First, ask for the type of platform/engine + platform_type = questionary.select( + "What type of LLM platform or engine do you want to use?", + choices=['API Platform', 'Local Engine'] ).ask() api_key = None base_url = None model_name = None - if platform == 'OpenAI': - api_key = questionary.password("What is your OpenAI API key?").ask() - if not api_key: - console.log("API key is required. Aborted.") - return - - elif platform == 'Claude': - api_key = questionary.password("What is your Anthropic API key?").ask() - if not api_key: - console.log("API key is required. Aborted.") - return - - elif platform == 'MistralAI': - api_key = questionary.password("What is your MistralAI API key?").ask() - if not api_key: - console.log("API key is required. Aborted.") - return - - elif platform == 'DeepSeek': - api_key = questionary.password("What is your DeepSeek API key?").ask() - if not api_key: - console.log("API key is required. Aborted.") - return - - elif platform == 'Gemini': - api_key = questionary.password("What is your Gemini API key?").ask() - if not api_key: - console.log("API key is required. Aborted.") - return - - elif platform == 'vLLM': - base_url = questionary.text( - "What is your vLLM server URL? (default: http://localhost:8000/v1)" - ).ask() or "http://localhost:8000/v1" - - model_name = questionary.text( - "What is the model name loaded in your vLLM server? (default: Qwen/Qwen2.5-1.5B-Instruct)" - ).ask() or "Qwen/Qwen2.5-1.5B-Instruct" + platform = None + + if platform_type == 'API Platform': + # API-based platforms + platform = questionary.select( + "Which API platform do you want to use?", + choices=['OpenAI', 'Claude', 'Gemini', 'MistralAI', 'DeepSeek'] + ).ask() + + if platform == 'OpenAI': + api_key = questionary.password("What is your OpenAI API key?").ask() + if not api_key: + console.log("API key is required. Aborted.") + return + model_name = questionary.text( + "What model do you want to use? (default: gpt-4o-2024-08-06)" + ).ask() or "gpt-4o-2024-08-06" + + elif platform == 'Claude': + api_key = questionary.password("What is your Anthropic API key?").ask() + if not api_key: + console.log("API key is required. Aborted.") + return + model_name = questionary.text( + "What model do you want to use? (default: claude-3-5-sonnet-20240620)" + ).ask() or "claude-3-5-sonnet-20240620" + + elif platform == 'MistralAI': + api_key = questionary.password("What is your MistralAI API key?").ask() + if not api_key: + console.log("API key is required. Aborted.") + return + model_name = questionary.text( + "What model do you want to use? (default: mistral-large-latest)" + ).ask() or "mistral-large-latest" + + elif platform == 'DeepSeek': + api_key = questionary.password("What is your DeepSeek API key?").ask() + if not api_key: + console.log("API key is required. Aborted.") + return + model_name = questionary.text( + "What model do you want to use? (default: deepseek-coder)" + ).ask() or "deepseek-coder" + + elif platform == 'Gemini': + api_key = questionary.password("What is your Gemini API key?").ask() + if not api_key: + console.log("API key is required. Aborted.") + return + model_name = questionary.text( + "What model do you want to use? (default: gemini-2.5-flash)" + ).ask() or "gemini-2.5-flash" + + elif platform_type == 'Local Engine': + # Local engines + platform = questionary.select( + "Which local engine do you want to use?", + choices=['Ollama', 'vLLM'] + ).ask() + + if platform == 'Ollama': + model_name = questionary.text( + "What model do you want to use? (default: llama3)" + ).ask() or "llama3" + + host_url = questionary.text( + "What is your Ollama host URL? (default: http://localhost:11434)" + ).ask() or "http://localhost:11434" + base_url = host_url + + elif platform == 'vLLM': + base_url = questionary.text( + "What is your vLLM server URL? (default: http://localhost:8000/v1)" + ).ask() or "http://localhost:8000/v1" + + model_name = questionary.text( + "What is the model name loaded in your vLLM server? (default: Qwen/Qwen2.5-1.5B-Instruct)" + ).ask() or "Qwen/Qwen2.5-1.5B-Instruct" search_api_key = questionary.password("What is your Tavily API key? (if no, the web search will be disabled)").ask() if search_api_key: diff --git a/mle/model/__init__.py b/mle/model/__init__.py index 611413c..fab71d8 100644 --- a/mle/model/__init__.py +++ b/mle/model/__init__.py @@ -59,7 +59,9 @@ def load_model(project_dir: str, model_name: str=None, observable=True): model = None if config['platform'] == MODEL_OLLAMA: - model = OllamaModel(model=model_name) + # For Ollama, use base_url as host_url if available + host_url = config.get('base_url', None) + model = OllamaModel(model=model_name, host_url=host_url) if config['platform'] == MODEL_OPENAI: model = OpenAIModel(api_key=config['api_key'], model=model_name) if config['platform'] == MODEL_CLAUDE: