Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 67 additions & 15 deletions proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,26 @@ def _extract_call_objects(obj):
def parse_tool_calls(content, tools=None):
"""
Parse <tool_call>…</tool_call> blocks from the model's text.
Also handles bare JSON tool calls without tags (fallback).

Returns (text_without_tags, list_of_openai_tool_call_dicts).
"""
pattern = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
matches = pattern.findall(content)

# Fallback: detect bare JSON like {"name": "...", "arguments": {...}}
# when the model forgets to wrap in <tool_call> tags
bare_match = None
if not matches:
bare_pattern = re.compile(
r'(\{[^{}]*"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{.*?\}[^{}]*\})',
re.DOTALL,
)
bare_matches = bare_pattern.findall(content)
if bare_matches:
matches = bare_matches
bare_match = True

if not matches:
return content, []

Expand Down Expand Up @@ -263,7 +277,15 @@ def parse_tool_calls(content, tools=None):
except (json.JSONDecodeError, KeyError, AttributeError):
continue

text = pattern.sub("", content).strip()
if bare_match:
# Remove the matched bare JSON blobs from text
bare_pattern = re.compile(
r'(\{[^{}]*"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{.*?\}[^{}]*\})',
re.DOTALL,
)
text = bare_pattern.sub("", content).strip()
else:
text = pattern.sub("", content).strip()
return text, tool_calls


Expand Down Expand Up @@ -394,7 +416,7 @@ def do_POST(self):
for t in openai_req.get("tools", [])
if t.get("function", {}).get("name", "").lower() not in FILTERED_TOOLS
]
tool_choice = openai_req.get("tool_choice", "auto")
tool_choice = openai_req.get("tool_choice", "required")

last_content = extract_text_content(
messages[-1].get("content", "") if messages else ""
Expand Down Expand Up @@ -468,27 +490,51 @@ def do_POST(self):
else:
chat_messages.append({"role": role, "content": content})

# Append tool definitions to the system prompt
full_system_prompt = system_prompt.strip()
if tools:
full_system_prompt += format_tools_for_prompt(tools, tool_choice)
# ----- Build final system prompt (tools first, then base) -----
# chatjimmy silently returns empty responses above ~8000 chars.
# Strategy: tools are ALWAYS included intact (they must not be truncated),
# and the base system prompt is trimmed to fit within the budget.
MAX_TOTAL_SYSTEM = 80000

tools_section = format_tools_for_prompt(tools, tool_choice) if tools else ""
tools_len = len(tools_section)

base_budget = max(0, MAX_TOTAL_SYSTEM - tools_len)
base_system_prompt = system_prompt.strip()

# ChatJimmy returns empty responses when system prompt exceeds ~30K chars
MAX_SYSTEM_PROMPT = 28000
if len(full_system_prompt) > MAX_SYSTEM_PROMPT:
if len(base_system_prompt) > base_budget:
logfile(
f"WARNING: system prompt is {len(full_system_prompt)} chars, truncating to {MAX_SYSTEM_PROMPT}"
f"WARNING: base system prompt truncated from {len(base_system_prompt)} to {base_budget} chars "
f"(tools use {tools_len} chars)"
)
full_system_prompt = full_system_prompt[:MAX_SYSTEM_PROMPT]
base_system_prompt = base_system_prompt[:base_budget]

# Tools go LAST so they are never cut off by truncation
full_system_prompt = base_system_prompt + tools_section

log(f"system_prompt={len(full_system_prompt)} chars (base={len(base_system_prompt)}, tools={tools_len})")

# Clean messages: drop empty content, keep all valid roles
clean_messages = []
for m in chat_messages:
msg_content = m.get("content", "")
if not msg_content or not str(msg_content).strip():
continue
clean_messages.append({
"role": m.get("role", "user"),
"content": str(msg_content),
})

if not clean_messages:
clean_messages = [{"role": "user", "content": "Hello"}]

jimmy_payload = {
"messages": chat_messages,
"messages": clean_messages,
"chatOptions": {
"selectedModel": MODELS.get(model, model),
"selectedModel": MODELS.get(model, "llama3.1-8B"),
"systemPrompt": full_system_prompt,
"topK": 8,
},
"attachment": None,
}

# File: translated payload
Expand Down Expand Up @@ -526,6 +572,12 @@ def do_POST(self):
logfile("--- RAW UPSTREAM RESPONSE ---")
logfile(raw_response)

# Warn on empty or suspiciously short responses
if not raw_response.strip():
log(f"WARNING: upstream returned empty response (system_prompt={len(full_system_prompt)} chars, tools={len(tools)})")
elif len(raw_response.strip()) < 10:
log(f"WARNING: upstream returned very short response: {repr(raw_response)}")

# Strip stats, parse usage
content = re.sub(
r"<\|stats\|>.*?<\|/stats\|>", "", raw_response, flags=re.DOTALL
Expand Down Expand Up @@ -726,4 +778,4 @@ def main():


if __name__ == "__main__":
main()
main()