diff --git a/README.md b/README.md index fc8ef3d..9c821a9 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,16 @@ make docker-build ## Step 1: Generating C Specifications -To generate specs with an LLM, you first need to put your API key in a `.env` file. +To generate specs with an LLM, you first need to put your API key(s) in a `.env` file. ```sh echo "LLM_API_KEY=" > models/.env +echo "ANTHROPIC_API_KEY=" >> models/.env ``` +The `ANTHROPIC_API_KEY` is required for specification generation and repair via Anthropic's + Claude models. + Then run the Python script ```sh diff --git a/main.py b/main.py index d37e464..c8c26b8 100755 --- a/main.py +++ b/main.py @@ -41,7 +41,7 @@ VALID_LOG_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL") -DEFAULT_MODEL = "gpt-4o" +DEFAULT_MODEL = "claude-sonnet-4-6" DEFAULT_HEADERS_FOR_VERIFICATION: Sequence[str] = ( "#include ", "#include ", diff --git a/models/default_llm_backend.py b/models/default_llm_backend.py index 18771a2..fc0cb11 100644 --- a/models/default_llm_backend.py +++ b/models/default_llm_backend.py @@ -12,6 +12,8 @@ import os import pathlib import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any import litellm from litellm import completion @@ -33,7 +35,10 @@ def __init__(self, model: str, use_vertex_api: bool): else: self.vertex_credentials = None self.model = model - self.api_key = os.environ["LLM_API_KEY"] + api_key_for_model = ( + "ANTHROPIC_API_KEY" if self._is_claude_model(model) else "LLM_API_KEY" + ) + self.api_key = os.environ[api_key_for_model] if "claude" in model: self.max_tokens = 64000 @@ -46,21 +51,81 @@ def __init__(self, model: str, use_vertex_api: bool): def send_messages( self, messages: tuple[ConversationMessage, ...], temperature: float = 0, top_k: int = 1 ) -> list[str]: - """messages: [{'role': 'system', 'content': 'You are an intelligent code assistant'}, - {'role': 'user', 'content': 'Translate this program...'}, - {'role': 'assistant', 'content': 'Here 
is the translation...'}, - {'role': 'user', 'content': 'Do something else...'}] - - : ['Sure, here is...', - 'Okay, let me see...', - ...] - len() == top_k + """Return `top_k` sampled responses from the LLM for the given messages. + + Args: + messages (tuple[ConversationMessage, ...]): The conversation to send to the LLM. + temperature (float): The sampling temperature. Must be non-zero when `top_k > 1`. + top_k (int): The number of responses to sample. + + Returns: + list[str]: The sampled responses. `len(returned) == top_k`. """ if top_k < 1: raise GenerationError("top_k must be >= 1") if top_k != 1 and temperature == 0: raise GenerationError("Top k sampling requires a non-zero temperature") + # Claude models do not support the `n` parameter; issue parallel requests instead. + if "claude" in self.model: + return self._send_parallel(messages, temperature, top_k) + + response = self._send_with_retry(messages, temperature, n=top_k) + return [choice["message"]["content"] for choice in response["choices"]] + + def _send_parallel( + self, messages: tuple[ConversationMessage, ...], temperature: float, top_k: int + ) -> list[str]: + """Return `top_k` responses by issuing parallel single requests. + + Used for models that do not support the `n` parameter (e.g. Claude). Each request runs + its own retry loop independently. + + Args: + messages (tuple[ConversationMessage, ...]): The conversation to send. + temperature (float): The sampling temperature. + top_k (int): The number of parallel requests to make. + + Returns: + list[str]: The `top_k` sampled responses. + """ + with ThreadPoolExecutor(max_workers=top_k) as executor: + futures = [ + executor.submit(self._send_one_message, messages, temperature) for _ in range(top_k) + ] + return [f.result() for f in futures] + + def _send_one_message( + self, messages: tuple[ConversationMessage, ...], temperature: float + ) -> str: + """Return a single response from the LLM with retry and compaction logic. 
+ + Args: + messages (tuple[ConversationMessage, ...]): The conversation to send. + temperature (float): The sampling temperature. + + Returns: + str: The model's response text. + """ + response = self._send_with_retry(messages, temperature) + return response["choices"][0]["message"]["content"] + + def _send_with_retry( + self, + messages: tuple[ConversationMessage, ...], + temperature: float, + **kwargs: Any, + ) -> Any: + """Return the raw LLM response, retrying on transient errors and compacting on overflow. + + Args: + messages (tuple[ConversationMessage, ...]): The conversation to send. + temperature (float): The sampling temperature. + **kwargs: Extra keyword arguments forwarded to `completion` (e.g. `n=top_k`). + + Returns: + Any: The raw litellm response object. + """ count = 0 while True: try: @@ -68,12 +133,12 @@ def send_messages( model=self.model, messages=[message.to_dict() for message in messages], temperature=temperature, - n=top_k, api_key=self.api_key, vertex_credentials=self.vertex_credentials, max_tokens=self.max_tokens, + **kwargs, ) - break + return response except litellm.ContextWindowExceededError as e: compacted = self._compact_conversation(messages) if compacted is None: @@ -95,14 +160,13 @@ def send_messages( ) as e: count += 1 if count >= 5: - raise ModelError("Vertex AI API: Too many retries") + msg = f"LLM API retries exceeded with model {self.model}" + raise ModelError(msg) logger.warning(f"LLM Error {e}. Waiting 10 seconds and retrying") time.sleep(10) except Exception as e: raise GenerationError(f"LLM Error: {e}") - return [choice["message"]["content"] for choice in response["choices"]] - @staticmethod def get_instance(model_name: str, use_vertex_api: bool) -> LlmBackend: """Return an instance of LlmBackend for the given model. 
@@ -174,3 +238,14 @@ def _compact_conversation( ): return None return (system_message, initial_user_message, latest_llm_response, last_user_message) + + def _is_claude_model(self, model_name: str) -> bool: + """Return True iff the model name corresponds to an Anthropic Claude model. + + Args: + model_name (str): The model name to check. + + Returns: + bool: True if the stripped model name starts with 'claude', False otherwise. + """ + return model_name.strip().startswith("claude")