Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 83 additions & 41 deletions mle/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,52 +263,94 @@ def new(name):
console.log("Please provide a valid project name. Aborted.")
return

platform = questionary.select(
"Which language model platform do you want to use?",
choices=['OpenAI', 'Ollama', 'Claude', 'Gemini', 'MistralAI', 'DeepSeek', 'vLLM']
# First, ask for the type of platform/engine
platform_type = questionary.select(
"What type of LLM platform or engine do you want to use?",
choices=['API Platform', 'Local Engine']
).ask()

api_key = None
base_url = None
model_name = None
if platform == 'OpenAI':
api_key = questionary.password("What is your OpenAI API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return

elif platform == 'Claude':
api_key = questionary.password("What is your Anthropic API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return

elif platform == 'MistralAI':
api_key = questionary.password("What is your MistralAI API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return

elif platform == 'DeepSeek':
api_key = questionary.password("What is your DeepSeek API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return

elif platform == 'Gemini':
api_key = questionary.password("What is your Gemini API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return

elif platform == 'vLLM':
base_url = questionary.text(
"What is your vLLM server URL? (default: http://localhost:8000/v1)"
).ask() or "http://localhost:8000/v1"

model_name = questionary.text(
"What is the model name loaded in your vLLM server? (default: mistralai/Mistral-7B-Instruct-v0.3B)"
).ask() or "mistralai/Mistral-7B-Instruct-v0.3"
platform = None

if platform_type == 'API Platform':
# API-based platforms
platform = questionary.select(
"Which API platform do you want to use?",
choices=['OpenAI', 'Claude', 'Gemini', 'MistralAI', 'DeepSeek']
).ask()

if platform == 'OpenAI':
api_key = questionary.password("What is your OpenAI API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return
model_name = questionary.text(
"What model do you want to use? (default: gpt-4o-2024-08-06)"
).ask() or "gpt-4o-2024-08-06"

elif platform == 'Claude':
api_key = questionary.password("What is your Anthropic API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return
model_name = questionary.text(
"What model do you want to use? (default: claude-3-5-sonnet-20240620)"
).ask() or "claude-3-5-sonnet-20240620"

elif platform == 'MistralAI':
api_key = questionary.password("What is your MistralAI API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return
model_name = questionary.text(
"What model do you want to use? (default: mistral-large-latest)"
).ask() or "mistral-large-latest"

elif platform == 'DeepSeek':
api_key = questionary.password("What is your DeepSeek API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return
model_name = questionary.text(
"What model do you want to use? (default: deepseek-coder)"
).ask() or "deepseek-coder"

elif platform == 'Gemini':
api_key = questionary.password("What is your Gemini API key?").ask()
if not api_key:
console.log("API key is required. Aborted.")
return
model_name = questionary.text(
"What model do you want to use? (default: gemini-2.5-flash)"
).ask() or "gemini-2.5-flash"

elif platform_type == 'Local Engine':
# Local engines
platform = questionary.select(
"Which local engine do you want to use?",
choices=['Ollama', 'vLLM']
).ask()

if platform == 'Ollama':
model_name = questionary.text(
"What model do you want to use? (default: llama3)"
).ask() or "llama3"

host_url = questionary.text(
"What is your Ollama host URL? (default: http://localhost:11434)"
).ask() or "http://localhost:11434"
base_url = host_url

elif platform == 'vLLM':
base_url = questionary.text(
"What is your vLLM server URL? (default: http://localhost:8000/v1)"
).ask() or "http://localhost:8000/v1"

model_name = questionary.text(
"What is the model name loaded in your vLLM server? (default: Qwen/Qwen2.5-1.5B-Instruct)"
).ask() or "Qwen/Qwen2.5-1.5B-Instruct"

search_api_key = questionary.password("What is your Tavily API key? (if no, the web search will be disabled)").ask()
if search_api_key:
Expand Down
4 changes: 3 additions & 1 deletion mle/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def load_model(project_dir: str, model_name: str=None, observable=True):
model = None

if config['platform'] == MODEL_OLLAMA:
model = OllamaModel(model=model_name)
# For Ollama, use base_url as host_url if available
host_url = config.get('base_url', None)
model = OllamaModel(model=model_name, host_url=host_url)
if config['platform'] == MODEL_OPENAI:
model = OpenAIModel(api_key=config['api_key'], model=model_name)
if config['platform'] == MODEL_CLAUDE:
Expand Down
2 changes: 1 addition & 1 deletion mle/model/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, base_url: Optional[str] = None,
"pip install openai"
)

self.model = model if model else 'mistralai/Mistral-7B-Instruct-v0.3'
self.model = model if model else 'Qwen/Qwen2.5-1.5B-Instruct'
self.model_type = 'vLLM'
self.temperature = temperature
self.client = self.openai(
Expand Down