diff --git a/fastllm/_modidx.py b/fastllm/_modidx.py index fa7229f..41c16a9 100644 --- a/fastllm/_modidx.py +++ b/fastllm/_modidx.py @@ -151,6 +151,7 @@ 'fastllm.chat.split_tools': ('chat.html#split_tools', 'fastllm/chat.py'), 'fastllm.chat.stop_reason': ('chat.html#stop_reason', 'fastllm/chat.py'), 'fastllm.chat.stop_sequences': ('chat.html#stop_sequences', 'fastllm/chat.py'), + 'fastllm.chat.strip_tc_args': ('chat.html#strip_tc_args', 'fastllm/chat.py'), 'fastllm.chat.structured': ('chat.html#structured', 'fastllm/chat.py')}, 'fastllm.codex': {}, 'fastllm.gemini': { 'fastllm.gemini._gem_filter_sch': ('gemini.html#_gem_filter_sch', 'fastllm/gemini.py'), diff --git a/fastllm/chat.py b/fastllm/chat.py index 7fb9cf2..5f38b60 100644 --- a/fastllm/chat.py +++ b/fastllm/chat.py @@ -6,7 +6,7 @@ __all__ = ['tool_dtls_tag', 're_tools', 'token_dtls_tag', 're_token', 'think_start', 'think_end', 're_think', 'effort', 'MediaUrl', 'remove_cache_ckpts', 'contents', 'stop_reason', 'mk_msg', 'FenceToolStop', 'extract_fence_call', 'split_tools', 'fmt2hist', 'mk_msgs', 'cite_footnote', 'postproc', 'lite_mk_func', 'ToolResponse', - 'structured', 'StopResponse', 'FullResponse', 'search_count', 'UsageStats', 'AsyncChat', + 'strip_tc_args', 'structured', 'StopResponse', 'FullResponse', 'search_count', 'UsageStats', 'AsyncChat', 'astream_with_complete', 'ChatCallback', 'DeepseekMsgsCallback', 'DeepseekPrefillCallback', 'add_warning', 'StopReasonCallback', 'run_fence_tool', 'FenceToolCallback', 'ToolReminderCallback', 'stop_sequences', 'StopSequencesCallback', 'mk_tr_details', 'StreamFormatter', 'AsyncStreamFormatter', 'adisplay_stream'] @@ -15,7 +15,7 @@ import asyncio, base64, json, mimetypes, random, string, ast, warnings from typing import Optional,Callable from html import escape -from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema +from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema, strip_tool_arg_defaults from fastcore.utils import * from fastcore.meta import delegates from fastcore import imghdr @@ -286,6 +286,12 @@ def _lite_call_func(tc, tool_schemas, ns): res = _call_func(tc, tool_schemas, ns, call_func) return _mk_tool_result(res) +# %% ../nbs/07_chat.ipynb #51c968d3 +def strip_tc_args(tcs, tool_schemas): + 'Update list of ToolCall arguments by stripping the defaults' + tcs_args = strip_tool_arg_defaults([dict(name=tc.name,arguments=tc.arguments) for tc in tcs], tool_schemas) + for tc, args in zip(tcs, tcs_args): tc.arguments = args + # %% ../nbs/07_chat.ipynb #6fb0e375 @delegates(acomplete) async def structured( @@ -522,6 +528,7 @@ async def _call(self:AsyncChat, msg=None, prefill=None, temp=None, think=None, s res = astream_with_complete(res, postproc=postproc) async for chunk in res: yield chunk res = res.value + strip_tc_args(res.tool_calls, self.tool_schemas) self.turn_res, self.turn_msg = res, contents(res) if self.prefill: self.turn_msg.content[0].text = self.prefill + self.turn_msg.content[0].text self.hist.append(self.turn_msg) diff --git a/nbs/07_chat.ipynb b/nbs/07_chat.ipynb index 766bd75..82f1db6 100644 --- a/nbs/07_chat.ipynb +++ b/nbs/07_chat.ipynb @@ -52,7 +52,7 @@ "import asyncio, base64, json, mimetypes, random, string, ast, warnings\n", "from typing import Optional,Callable\n", "from html import escape\n", - "from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema\n", + "from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema, strip_tool_arg_defaults\n", "from fastcore.utils import *\n", "from fastcore.meta import delegates\n", "from fastcore import imghdr\n", @@ -4071,7 +4071,87 @@ "id": "5f83d9ed", "metadata": {}, "source": [ - "Test tool calls that were not in tool_choice are caught:" + "Strip tool call default args:" + ] + }, + { + "cell_type": "markdown", + "id": "0a76bbed", + "metadata": {}, + "source": [ + "`strip_tool_arg_defaults` from toolslm expects a list of function name -> tool call arguments mapping and tool schemas. We'll use that to strip the default args that might be passed in `ToolCall` objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a5b7bb3", + "metadata": {}, + "outputs": [], + "source": [ + "def defaults_tool(a:int, b:int=0, flag:bool=False, name:str='x', ratio:float=1.0) -> str:\n", + " \"Test defaults of different primitive types\"\n", + " return f'{a=} {b=} {flag=} {name=} {ratio=}'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a51d7a45", + "metadata": {}, + "outputs": [], + "source": [ + "tc1 = ToolCall(id='x', name='simple_add', arguments=dict(a=1234, b=0))\n", + "tc2 = ToolCall(id='y', name='defaults_tool', arguments=dict(a=1234, b=0, flag=False, name='x', ratio=1.0))\n", + "tc3 = ToolCall(id='y', name='web_search', arguments=dict(q='Weather'), server=True)\n", + "tcs = [tc1,tc2,tc3]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bc2d3c5", + "metadata": {}, + "outputs": [], + "source": [ + "tschemas = [lite_mk_func(simple_add), lite_mk_func(defaults_tool)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51c968d3", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "def strip_tc_args(tcs, tool_schemas):\n", + " 'Update list of ToolCall arguments by stripping the defaults'\n", + " tcs_args = strip_tool_arg_defaults([dict(name=tc.name,arguments=tc.arguments) for tc in tcs], tool_schemas)\n", + " for tc, args in zip(tcs, tcs_args): tc.arguments = args" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ab0a5ea", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[ToolCall(id='x', name='simple_add', arguments={'a': 1234}, server=False, extra={}),\n", + " ToolCall(id='y', name='defaults_tool', arguments={'a': 1234}, server=False, extra={}),\n", + " ToolCall(id='y', name='web_search', arguments={'q': 'Weather'}, server=True, extra={})]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strip_tc_args(tcs, tschemas); tcs" ] }, { @@ -5095,6 +5175,7 @@ " res = astream_with_complete(res, postproc=postproc)\n", " async for chunk in res: yield chunk\n", " res = res.value\n", + " strip_tc_args(res.tool_calls, self.tool_schemas)\n", " self.turn_res, self.turn_msg = res, contents(res)\n", " if self.prefill: self.turn_msg.content[0].text = self.prefill + self.turn_msg.content[0].text\n", " self.hist.append(self.turn_msg)\n", @@ -6137,7 +6218,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "models/gemini-3-flash-preview\n", + "models/gemini-3-flash-preview\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "claude-sonnet-4-6\n" ] }, @@ -8009,7 +8096,7 @@ "\n", "\n", "
\n", - "asimple_div(a="3", b="0")→"Traceback (most recent call last):\\n Fil…"\n", + "asimple_div(a="3")→"Traceback (most recent call last):\\n Fil…"\n", "\n", "```json\n", "{\n", @@ -8018,8 +8105,7 @@ " \"call\": {\n", " \"function\": \"asimple_div\",\n", " \"arguments\": {\n", - " \"a\": \"3\",\n", - " \"b\": \"0\"\n", + " \"a\": \"3\"\n", " }\n", " },\n", " \"result\": \"Traceback (most recent call last):\\n File \\\"/Users/keremturgutlu/aai-ws/toolslm/toolslm/funccall.py\\\", line 276, in call_func_async\\n res = await maybe_await(res)\\n ^^^^^^^^^^^^^^^^^^^^^^\\n File \\\"/Users/keremturgutlu/aai-ws/fastcore/fastcore/xtras.py\\\", line 1063, in maybe_await\\n return await o if isawaitable(o) else o\\n ^^^^^^^\\n File \\\"/var/folders/zl/js35kg3914qc7d8lsdtqsyf00000gn/T/ipykernel_43288/466431256.py\\\", line 6, in asimple_div\\n return a/b\\n ~^~\\nZeroDivisionError: division by zero\"\n", @@ -8077,7 +8163,7 @@ "\n", "\n", "
\n", - "asimple_div(a="3", b="0")→"Traceback (most recent call last):\\n Fil…"\n", + "asimple_div(a="3")→"Traceback (most recent call last):\\n Fil…"\n", "\n", "```json\n", "{\n", @@ -8086,8 +8172,7 @@ " \"call\": {\n", " \"function\": \"asimple_div\",\n", " \"arguments\": {\n", - " \"a\": \"3\",\n", - " \"b\": \"0\"\n", + " \"a\": \"3\"\n", " }\n", " },\n", " \"result\": \"Traceback (most recent call last):\\n File \\\"/Users/keremturgutlu/aai-ws/toolslm/toolslm/funccall.py\\\", line 276, in call_func_async\\n res = await maybe_await(res)\\n ^^^^^^^^^^^^^^^^^^^^^^\\n File \\\"/Users/keremturgutlu/aai-ws/fastcore/fastcore/xtras.py\\\", line 1063, in maybe_await\\n return await o if isawaitable(o) else o\\n ^^^^^^^\\n File \\\"/var/folders/zl/js35kg3914qc7d8lsdtqsyf00000gn/T/ipykernel_43288/466431256.py\\\", line 6, in asimple_div\\n return a/b\\n ~^~\\nZeroDivisionError: division by zero\"\n",