From 295b0804af9528bffed85a4798ea4b0f2b8fb783 Mon Sep 17 00:00:00 2001 From: Jacob Ellis Date: Mon, 11 Aug 2025 13:24:57 +0930 Subject: [PATCH] Handle max tokens properly --- .../algo/ai_handlers/litellm_ai_handler.py | 49 +++++++++++++++++-- pr_agent/algo/token_handler.py | 10 +++- pr_agent/config_loader.py | 20 ++++++++ pr_agent/git_providers/utils.py | 25 ++++++++++ pr_agent/settings/configuration.toml | 12 ++++- pr_agent/tools/pr_reviewer.py | 29 ++++++++--- requirements.txt | 6 ++- 7 files changed, 135 insertions(+), 16 deletions(-) diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index 8d727b8b..cfcec8db 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -7,7 +7,7 @@ from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler -from pr_agent.algo.utils import ReasoningEffort, get_version +from pr_agent.algo.utils import ReasoningEffort, get_version, get_max_tokens from pr_agent.config_loader import get_settings from pr_agent.log import get_logger import json @@ -188,6 +188,8 @@ def _configure_claude_extended_thinking(self, model: str, kwargs: dict) -> dict: """ extended_thinking_budget_tokens = get_settings().config.get("extended_thinking_budget_tokens", 2048) extended_thinking_max_output_tokens = get_settings().config.get("extended_thinking_max_output_tokens", 4096) + # Ensure the model still has room to return an answer when thinking is enabled + min_output_after_thinking = get_settings().config.get("extended_thinking_min_output_tokens", 1024) # Validate extended thinking parameters if not isinstance(extended_thinking_budget_tokens, int) or extended_thinking_budget_tokens <= 0: @@ -197,13 +199,31 @@ def _configure_claude_extended_thinking(self, model: str, kwargs: dict) -> dict: if extended_thinking_max_output_tokens < extended_thinking_budget_tokens: raise ValueError(f"extended_thinking_max_output_tokens ({extended_thinking_max_output_tokens}) must be greater than or equal to extended_thinking_budget_tokens ({extended_thinking_budget_tokens})") + # Start from any precomputed max_tokens (dynamic fit), cap by configured max + precomputed_max = kwargs.get("max_tokens", extended_thinking_max_output_tokens) + max_tokens_final = min(precomputed_max, extended_thinking_max_output_tokens) + + # If max_tokens is too small relative to budget, reduce budget to fit a minimum output buffer + adjusted_budget = min(extended_thinking_budget_tokens, max(0, max_tokens_final - min_output_after_thinking)) + if adjusted_budget <= 0: + # Not enough room for thinking + any output → disable thinking this call + if get_settings().config.verbosity_level >= 1: + get_logger().warning( + "Insufficient token budget for extended thinking; disabling thinking for this call" + ) + kwargs.pop("thinking", None) + kwargs["max_tokens"] = max_tokens_final + return kwargs + kwargs["thinking"] = { "type": "enabled", - "budget_tokens": extended_thinking_budget_tokens + "budget_tokens": int(adjusted_budget), } if get_settings().config.verbosity_level >= 2: - get_logger().info(f"Adding max output tokens {extended_thinking_max_output_tokens} to model {model}, extended thinking budget tokens: {extended_thinking_budget_tokens}") - kwargs["max_tokens"] = extended_thinking_max_output_tokens + get_logger().info( + f"Extended thinking: budget={adjusted_budget}, max_tokens={max_tokens_final} for model {model}" + ) + kwargs["max_tokens"] = max_tokens_final # temperature may only be set to 1 when thinking is enabled if get_settings().config.verbosity_level >= 2: @@ -325,6 +345,27 @@ async def chat_completion(self, model: str, system: str, user: str, temperature: "api_base": self.api_base, } + # Dynamically set max_tokens to fit within the model's context window + try: + # Conservative output budget; can be overridden by extended thinking configuration + target_output_tokens = int(get_settings().config.get("default_max_output_tokens", 2048)) + # Estimate input tokens from concatenated message content + concatenated_inputs = "\n".join([ + m.get("content", "") if isinstance(m.get("content", ""), str) else json.dumps(m.get("content", "")) + for m in messages + ]) + from pr_agent.algo.token_handler import TokenHandler + token_handler = TokenHandler() + input_tokens_estimate = token_handler.count_tokens(concatenated_inputs) + model_ctx = get_max_tokens(model) + # If input + desired output exceed context, reduce max output + available_for_output = max(256, model_ctx - input_tokens_estimate - 512) + kwargs["max_tokens"] = max(256, min(target_output_tokens, available_for_output)) + if get_settings().config.verbosity_level >= 2: + get_logger().info(f"Setting max_tokens={kwargs['max_tokens']} (ctx={model_ctx}, input≈{input_tokens_estimate}) for model {model}") + except Exception as e: + get_logger().debug(f"Unable to set dynamic max_tokens: {e}") + # Add temperature only if model supports it if model not in self.no_support_temperature_models and not get_settings().config.custom_reasoning_model: # get_logger().info(f"Adding temperature with value {temperature} to model {model}.") diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 60cf2c84..ab85778f 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -72,7 +72,15 @@ def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user): environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(system).render(vars) user_prompt = environment.from_string(user).render(vars) - system_prompt_tokens = len(encoder.encode(system_prompt)) + + # Ensure token accounting includes repository-specific rules appended to system prompt + try: + from pr_agent.git_providers.utils import add_repository_rules_to_prompt as _add_rules + system_prompt_with_rules = _add_rules(system_prompt) + except Exception: + system_prompt_with_rules = system_prompt + + system_prompt_tokens = len(encoder.encode(system_prompt_with_rules)) user_prompt_tokens = len(encoder.encode(user_prompt)) return system_prompt_tokens + user_prompt_tokens except Exception as e: diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index 0c865865..730977d4 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -86,3 +86,23 @@ def _find_pyproject() -> Optional[Path]: # Support GITHUB_TOKEN as an alternative to GITHUB__USER_TOKEN if os.environ.get('GITHUB_TOKEN') and not os.environ.get('GITHUB__USER_TOKEN'): get_settings().set('github.user_token', os.environ['GITHUB_TOKEN']) + +# Auto-select default model if user hasn’t explicitly set or overrides: +try: + # Only set defaults if not explicitly set by repo/global settings + configured_model = get_settings().get('config.model', '').strip() + if not configured_model: + openai_key = get_settings().get('openai.key') or os.environ.get('OPENAI_API_KEY') or os.environ.get('OPENAI__KEY') + anthropic_key = get_settings().get('anthropic.key') or os.environ.get('ANTHROPIC_API_KEY') or os.environ.get('ANTHROPIC__KEY') + + if openai_key: + # Prefer GPT-5 if OpenAI key is present + get_settings().set('config.model', 'gpt-5') + get_settings().set('config.fallback_models', ['anthropic/claude-sonnet-4-20250514'] if anthropic_key else []) + elif anthropic_key: + # Fall back to Claude Sonnet 4 if only Anthropic is present + get_settings().set('config.model', 'anthropic/claude-sonnet-4-20250514') + # else: leave as-is, user must set +except Exception: + # Silent: don’t block initialization if auto-detect fails + pass diff --git a/pr_agent/git_providers/utils.py b/pr_agent/git_providers/utils.py index e3a0a91d..285b2a70 100644 --- a/pr_agent/git_providers/utils.py +++ b/pr_agent/git_providers/utils.py @@ -169,6 +169,31 @@ def get_repository_rules_for_prompt(): rules_handler = get_cursor_rules() if rules_handler and rules_handler.has_rules(): rules_content = rules_handler.get_rules_for_prompt() + + # Optionally clip rules to a max token budget to avoid exceeding context + try: + from pr_agent.algo.token_handler import TokenHandler + from pr_agent.algo.utils import get_max_tokens + model = get_settings().config.model + token_handler = TokenHandler() + rules_tokens = token_handler.count_tokens(rules_content) + + max_rules_tokens = int(get_settings().config.get('max_cursor_rules_tokens', 20000)) + hard_cap_ratio = float(get_settings().config.get('cursor_rules_context_ratio', 0.25)) + model_ctx = get_max_tokens(model) + hard_cap_tokens = max(2000, int(model_ctx * hard_cap_ratio)) + allowed_tokens = min(max_rules_tokens, hard_cap_tokens) + + if rules_tokens > allowed_tokens: + from pr_agent.algo.utils import clip_tokens + clipped = clip_tokens(rules_content, allowed_tokens, add_three_dots=True) + get_logger().warning( + f"Cursor rules too large ({rules_tokens} tokens). Clipped to {allowed_tokens} tokens for prompting." + ) + rules_content = clipped + except Exception as e: + get_logger().debug(f"Failed to apply cursor rules token budget: {e}") + # Count rules size for logging rules_size = len(rules_content) get_logger().info(f"📋 Including repository Cursor rules in AI prompt ({rules_size:,} characters)") diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 76f30ff2..4c4a63b5 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -6,7 +6,8 @@ [config] # models -model="anthropic/claude-sonnet-4-20250514" +# Leave blank to enable auto-selection in config_loader.py (GPT-5 if OpenAI key, else Claude Sonnet 4 if Anthropic key) +model="" fallback_models=[] #model_reasoning="o4-mini" # dedictated reasoning model for self-reflection #model_weak="gpt-4o" # optional, a weaker model to use for some easier tasks @@ -31,7 +32,7 @@ response_language="en-US" # Language locales code for PR responses in ISO 3166 a # token limits max_description_tokens = 500 max_commits_tokens = 500 -max_model_tokens = 100000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities. +max_model_tokens = 175000 # Global cap to align with desired headroom under 200K models custom_model_max_tokens=-1 # for models not in the default list model_token_count_estimate_factor=0.3 # factor to increase the token count estimate, in order to reduce likelihood of model failure due to too many tokens - applicable only when requesting an accurate estimate. # patch extension logic @@ -74,6 +75,13 @@ feedback_on_draft_pr=false # Set to true to enable processing of draft PRs # Cursor rules handling use_cursor_rules=true # Set to true to enable reading official Cursor rules from repositories (.cursor/rules/*.mdc and .cursorrules) +# Cursor rules budgeting +max_cursor_rules_tokens = 20000 +cursor_rules_context_ratio = 0.25 + +# Default output budget when not using extended thinking +default_max_output_tokens = 2048 + # extended thinking for Claude reasoning models enable_claude_extended_thinking = true # Re-enabled with updated LiteLLM 1.71.1 extended_thinking_budget_tokens = 16384 # Good balance of deep reasoning and PR size capacity diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index cef28e84..7104cb80 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -200,13 +200,28 @@ async def _prepare_prediction(self, model: str) -> None: self.patches_diff = result self.diff_was_pruned = False else: - # Normal review mode - self.patches_diff = get_pr_diff(self.git_provider, - self.token_handler, - model, - add_line_numbers_to_hunks=True, - disable_extra_lines=False,) - self.diff_was_pruned = False + # Normal review mode: if diff too large, fallback to chunked processing + patches = get_pr_diff(self.git_provider, + self.token_handler, + model, + add_line_numbers_to_hunks=True, + disable_extra_lines=False, + return_pruning_info=True) + if isinstance(patches, tuple): + diff_text, pruned = patches + if pruned and not diff_text: + # Use chunked diffs for very large PRs + from pr_agent.algo.pr_processing import get_pr_multi_diffs + max_calls = int(get_settings().pr_reviewer.get("max_number_of_calls", 3)) + chunks = get_pr_multi_diffs(self.git_provider, self.token_handler, model, max_calls=max_calls, add_line_numbers=True) + # Concatenate chunks with separators to keep context manageable + self.patches_diff = "\n\n---\n\n".join(chunks) + else: + self.patches_diff = diff_text + self.diff_was_pruned = pruned + else: + self.patches_diff = patches + self.diff_was_pruned = False if self.patches_diff: get_logger().debug(f"PR diff", diff=self.patches_diff) diff --git a/requirements.txt b/requirements.txt index a3733df4..596c3115 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ # Core AI provider support -litellm>=1.71.1 +# Bump LiteLLM for GPT-5 routing and latest providers +litellm>=1.74.0 openai>=1.55.3 anthropic>=0.48 tiktoken==0.8.0 @@ -17,13 +18,14 @@ Jinja2==3.1.2 loguru==0.7.2 tenacity==8.2.3 retry==0.9.2 -pydantic==2.8.2 +pydantic>=2.11.2,<3 html2text==2024.2.26 ujson==5.8.0 # Web framework (for GitHub App) fastapi==0.111.0 uvicorn==0.22.0 +httpx==0.27.2 gunicorn==22.0.0 starlette-context==0.3.6 aiohttp>=3.10.0