Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastllm/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
'fastllm.chat.split_tools': ('chat.html#split_tools', 'fastllm/chat.py'),
'fastllm.chat.stop_reason': ('chat.html#stop_reason', 'fastllm/chat.py'),
'fastllm.chat.stop_sequences': ('chat.html#stop_sequences', 'fastllm/chat.py'),
'fastllm.chat.strip_tc_args': ('chat.html#strip_tc_args', 'fastllm/chat.py'),
'fastllm.chat.structured': ('chat.html#structured', 'fastllm/chat.py')},
'fastllm.codex': {},
'fastllm.gemini': { 'fastllm.gemini._gem_filter_sch': ('gemini.html#_gem_filter_sch', 'fastllm/gemini.py'),
Expand Down
11 changes: 9 additions & 2 deletions fastllm/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
__all__ = ['tool_dtls_tag', 're_tools', 'token_dtls_tag', 're_token', 'think_start', 'think_end', 're_think', 'effort',
'MediaUrl', 'remove_cache_ckpts', 'contents', 'stop_reason', 'mk_msg', 'FenceToolStop', 'extract_fence_call',
'split_tools', 'fmt2hist', 'mk_msgs', 'cite_footnote', 'postproc', 'lite_mk_func', 'ToolResponse',
'structured', 'StopResponse', 'FullResponse', 'search_count', 'UsageStats', 'AsyncChat',
'strip_tc_args', 'structured', 'StopResponse', 'FullResponse', 'search_count', 'UsageStats', 'AsyncChat',
'astream_with_complete', 'ChatCallback', 'DeepseekMsgsCallback', 'DeepseekPrefillCallback', 'add_warning',
'StopReasonCallback', 'run_fence_tool', 'FenceToolCallback', 'ToolReminderCallback', 'stop_sequences',
'StopSequencesCallback', 'mk_tr_details', 'StreamFormatter', 'AsyncStreamFormatter', 'adisplay_stream']
Expand All @@ -15,7 +15,7 @@
import asyncio, base64, json, mimetypes, random, string, ast, warnings
from typing import Optional,Callable
from html import escape
from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema
from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema, strip_tool_arg_defaults
from fastcore.utils import *
from fastcore.meta import delegates
from fastcore import imghdr
Expand Down Expand Up @@ -286,6 +286,12 @@ def _lite_call_func(tc, tool_schemas, ns):
res = _call_func(tc, tool_schemas, ns, call_func)
return _mk_tool_result(res)

# %% ../nbs/07_chat.ipynb #51c968d3
def strip_tc_args(tcs, tool_schemas):
'Update list of ToolCall arguments by stripping the defaults'
tcs_args = strip_tool_arg_defaults([dict(name=tc.name,arguments=tc.arguments) for tc in tcs], tool_schemas)
for tc, args in zip(tcs, tcs_args): tc.arguments = args

# %% ../nbs/07_chat.ipynb #6fb0e375
@delegates(acomplete)
async def structured(
Expand Down Expand Up @@ -522,6 +528,7 @@ async def _call(self:AsyncChat, msg=None, prefill=None, temp=None, think=None, s
res = astream_with_complete(res, postproc=postproc)
async for chunk in res: yield chunk
res = res.value
strip_tc_args(res.tool_calls, self.tool_schemas)
self.turn_res, self.turn_msg = res, contents(res)
if self.prefill: self.turn_msg.content[0].text = self.prefill + self.turn_msg.content[0].text
self.hist.append(self.turn_msg)
Expand Down
103 changes: 94 additions & 9 deletions nbs/07_chat.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"import asyncio, base64, json, mimetypes, random, string, ast, warnings\n",
"from typing import Optional,Callable\n",
"from html import escape\n",
"from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema\n",
"from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema, strip_tool_arg_defaults\n",
"from fastcore.utils import *\n",
"from fastcore.meta import delegates\n",
"from fastcore import imghdr\n",
Expand Down Expand Up @@ -4071,7 +4071,87 @@
"id": "5f83d9ed",
"metadata": {},
"source": [
"Test tool calls that were not in tool_choice are caught:"
"Strip tool call default args:"
]
},
{
"cell_type": "markdown",
"id": "0a76bbed",
"metadata": {},
"source": [
"`strip_tool_arg_defaults` from toolslm expects a list of function name -> tool call arguments mapping and tool schemas. We'll use that to strip the default args that might be passed in `ToolCall` objects"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a5b7bb3",
"metadata": {},
"outputs": [],
"source": [
"def defaults_tool(a:int, b:int=0, flag:bool=False, name:str='x', ratio:float=1.0) -> str:\n",
" \"Test defaults of different primitive types\"\n",
" return f'{a=} {b=} {flag=} {name=} {ratio=}'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a51d7a45",
"metadata": {},
"outputs": [],
"source": [
"tc1 = ToolCall(id='x', name='simple_add', arguments=dict(a=1234, b=0))\n",
"tc2 = ToolCall(id='y', name='defaults_tool', arguments=dict(a=1234, b=0, flag=False, name='x', ratio=1.0))\n",
"tc3 = ToolCall(id='y', name='web_search', arguments=dict(q='Weather'), server=True)\n",
"tcs = [tc1,tc2,tc3]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4bc2d3c5",
"metadata": {},
"outputs": [],
"source": [
"tschemas = [lite_mk_func(simple_add), lite_mk_func(defaults_tool)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51c968d3",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def strip_tc_args(tcs, tool_schemas):\n",
" 'Update list of ToolCall arguments by stripping the defaults'\n",
" tcs_args = strip_tool_arg_defaults([dict(name=tc.name,arguments=tc.arguments) for tc in tcs], tool_schemas)\n",
" for tc, args in zip(tcs, tcs_args): tc.arguments = args"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ab0a5ea",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[ToolCall(id='x', name='simple_add', arguments={'a': 1234}, server=False, extra={}),\n",
" ToolCall(id='y', name='defaults_tool', arguments={'a': 1234}, server=False, extra={}),\n",
" ToolCall(id='y', name='web_search', arguments={'q': 'Weather'}, server=True, extra={})]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"strip_tc_args(tcs, tschemas); tcs"
]
},
{
Expand Down Expand Up @@ -5095,6 +5175,7 @@
" res = astream_with_complete(res, postproc=postproc)\n",
" async for chunk in res: yield chunk\n",
" res = res.value\n",
" strip_tc_args(res.tool_calls, self.tool_schemas)\n",
" self.turn_res, self.turn_msg = res, contents(res)\n",
" if self.prefill: self.turn_msg.content[0].text = self.prefill + self.turn_msg.content[0].text\n",
" self.hist.append(self.turn_msg)\n",
Expand Down Expand Up @@ -6137,7 +6218,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"models/gemini-3-flash-preview\n",
"models/gemini-3-flash-preview\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"claude-sonnet-4-6\n"
]
},
Expand Down Expand Up @@ -8009,7 +8096,7 @@
"\n",
"\n",
"<details class='tool-usage-details' markdown='1'>\n",
"<summary><code>asimple_div(a=&quot;3&quot;, b=&quot;0&quot;)→&quot;Traceback (most recent call last):\\n Fil…&quot;</code></summary>\n",
"<summary><code>asimple_div(a=&quot;3&quot;)→&quot;Traceback (most recent call last):\\n Fil…&quot;</code></summary>\n",
"\n",
"```json\n",
"{\n",
Expand All @@ -8018,8 +8105,7 @@
" \"call\": {\n",
" \"function\": \"asimple_div\",\n",
" \"arguments\": {\n",
" \"a\": \"3\",\n",
" \"b\": \"0\"\n",
" \"a\": \"3\"\n",
" }\n",
" },\n",
" \"result\": \"Traceback (most recent call last):\\n File \\\"/Users/keremturgutlu/aai-ws/toolslm/toolslm/funccall.py\\\", line 276, in call_func_async\\n res = await maybe_await(res)\\n ^^^^^^^^^^^^^^^^^^^^^^\\n File \\\"/Users/keremturgutlu/aai-ws/fastcore/fastcore/xtras.py\\\", line 1063, in maybe_await\\n return await o if isawaitable(o) else o\\n ^^^^^^^\\n File \\\"/var/folders/zl/js35kg3914qc7d8lsdtqsyf00000gn/T/ipykernel_43288/466431256.py\\\", line 6, in asimple_div\\n return a/b\\n ~^~\\nZeroDivisionError: division by zero\"\n",
Expand Down Expand Up @@ -8077,7 +8163,7 @@
"\n",
"\n",
"<details class='tool-usage-details' markdown='1'>\n",
"<summary><code>asimple_div(a=&quot;3&quot;, b=&quot;0&quot;)→&quot;Traceback (most recent call last):\\n Fil…&quot;</code></summary>\n",
"<summary><code>asimple_div(a=&quot;3&quot;)→&quot;Traceback (most recent call last):\\n Fil…&quot;</code></summary>\n",
"\n",
"```json\n",
"{\n",
Expand All @@ -8086,8 +8172,7 @@
" \"call\": {\n",
" \"function\": \"asimple_div\",\n",
" \"arguments\": {\n",
" \"a\": \"3\",\n",
" \"b\": \"0\"\n",
" \"a\": \"3\"\n",
" }\n",
" },\n",
" \"result\": \"Traceback (most recent call last):\\n File \\\"/Users/keremturgutlu/aai-ws/toolslm/toolslm/funccall.py\\\", line 276, in call_func_async\\n res = await maybe_await(res)\\n ^^^^^^^^^^^^^^^^^^^^^^\\n File \\\"/Users/keremturgutlu/aai-ws/fastcore/fastcore/xtras.py\\\", line 1063, in maybe_await\\n return await o if isawaitable(o) else o\\n ^^^^^^^\\n File \\\"/var/folders/zl/js35kg3914qc7d8lsdtqsyf00000gn/T/ipykernel_43288/466431256.py\\\", line 6, in asimple_div\\n return a/b\\n ~^~\\nZeroDivisionError: division by zero\"\n",
Expand Down