Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@ make docker-build

## Step 1: Generating C Specifications

To generate specs with an LLM, you first need to put your API key in a `.env` file.
To generate specs with an LLM, you first need to put your API key(s) in a `.env` file.

```sh
echo "LLM_API_KEY=<your key here>" > models/.env
echo "ANTHROPIC_API_KEY=<your key here>" >> models/.env
```

An `ANTHROPIC_API_KEY` is required to generate and repair specifications with Anthropic's Claude models.
Comment thread
jyoo980 marked this conversation as resolved.

Then run the Python script

```sh
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

VALID_LOG_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")

DEFAULT_MODEL = "gpt-4o"
DEFAULT_MODEL = "claude-sonnet-4-6"
DEFAULT_HEADERS_FOR_VERIFICATION: Sequence[str] = (
"#include <stdlib.h>",
"#include <limits.h>",
Expand Down
105 changes: 90 additions & 15 deletions models/default_llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import os
import pathlib
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import litellm
from litellm import completion
Expand All @@ -33,7 +35,10 @@ def __init__(self, model: str, use_vertex_api: bool):
else:
self.vertex_credentials = None
self.model = model
self.api_key = os.environ["LLM_API_KEY"]
api_key_for_model = (
"ANTHROPIC_API_KEY" if self._is_claude_model(model) else "LLM_API_KEY"
)
self.api_key = os.environ[api_key_for_model]

if "claude" in model:
self.max_tokens = 64000
Expand All @@ -46,34 +51,94 @@ def __init__(self, model: str, use_vertex_api: bool):
def send_messages(
    self, messages: tuple[ConversationMessage, ...], temperature: float = 0, top_k: int = 1
) -> list[str]:
    """Return `top_k` sampled responses from the LLM for the given messages.

    Args:
        messages (tuple[ConversationMessage, ...]): The conversation to send to the LLM.
        temperature (float): The sampling temperature. Must be non-zero when `top_k > 1`.
        top_k (int): The number of responses to sample.

    Returns:
        list[str]: The sampled responses. `len(returned) == top_k`.

    Raises:
        GenerationError: If `top_k < 1`, or if `top_k > 1` with a zero temperature.
    """
    if top_k < 1:
        raise GenerationError("top_k must be >= 1")
    if top_k != 1 and temperature == 0:
        raise GenerationError("Top k sampling requires a non-zero temperature")

    # Claude models do not support the `n` parameter; issue parallel requests instead.
    if "claude" in self.model:
        return self._send_parallel(messages, temperature, top_k)

    response = self._send_with_retry(messages, temperature, n=top_k)
    return [choice["message"]["content"] for choice in response["choices"]]

def _send_parallel(
self, messages: tuple[ConversationMessage, ...], temperature: float, top_k: int
) -> list[str]:
"""Return `top_k` responses by issuing parallel single requests.

Used for models that do not support the `n` parameter (e.g. Claude). Each request runs
its own retry loop independently.

Args:
messages (tuple[ConversationMessage, ...]): The conversation to send.
temperature (float): The sampling temperature.
top_k (int): The number of parallel requests to make.

Returns:
list[str]: The `top_k` sampled responses.
"""
with ThreadPoolExecutor(max_workers=top_k) as executor:
futures = [
executor.submit(self._send_one_message, messages, temperature) for _ in range(top_k)
]
return [f.result() for f in futures]

def _send_one_message(
self, messages: tuple[ConversationMessage, ...], temperature: float
) -> str:
"""Return a single response from the LLM with retry and compaction logic.

Args:
messages (tuple[ConversationMessage, ...]): The conversation to send.
temperature (float): The sampling temperature.

Returns:
str: The model's response text.
"""
response = self._send_with_retry(messages, temperature)
return response["choices"][0]["message"]["content"]

def _send_with_retry(
self,
messages: tuple[ConversationMessage, ...],
temperature: float,
**kwargs: Any,
) -> Any:
"""Return the raw LLM response, retrying on transient errors and compacting on overflow.

Args:
messages (tuple[ConversationMessage, ...]): The conversation to send.
temperature (float): The sampling temperature.
**kwargs: Extra keyword arguments forwarded to `completion` (e.g. `n=top_k`).

Returns:
Any: The raw litellm response object.
"""
count = 0
while True:
try:
response = completion(
model=self.model,
messages=[message.to_dict() for message in messages],
temperature=temperature,
n=top_k,
api_key=self.api_key,
vertex_credentials=self.vertex_credentials,
max_tokens=self.max_tokens,
**kwargs,
)
break
return response
except litellm.ContextWindowExceededError as e:
compacted = self._compact_conversation(messages)
if compacted is None:
Expand All @@ -95,14 +160,13 @@ def send_messages(
) as e:
count += 1
if count >= 5:
raise ModelError("Vertex AI API: Too many retries")
msg = f"LLM API retries exceeded with model {self.model}"
raise ModelError(msg)
logger.warning(f"LLM Error {e}. Waiting 10 seconds and retrying")
time.sleep(10)
except Exception as e:
raise GenerationError(f"LLM Error: {e}")

return [choice["message"]["content"] for choice in response["choices"]]

@staticmethod
def get_instance(model_name: str, use_vertex_api: bool) -> LlmBackend:
"""Return an instance of LlmBackend for the given model.
Expand Down Expand Up @@ -174,3 +238,14 @@ def _compact_conversation(
):
return None
return (system_message, initial_user_message, latest_llm_response, last_user_message)

def _is_claude_model(self, model_name: str) -> bool:
"""Return True iff the model name corresponds to an Anthropic Claude model.

Args:
model_name (str): The model name to check.

Returns:
bool: True iff the model name corresponds to an Anthropic Claude model.
"""
return model_name.strip().startswith("claude")