From 575f5c484aae96191525a9858d7ee7b1307ec7d6 Mon Sep 17 00:00:00 2001 From: jschweiz Date: Wed, 21 Jan 2026 11:11:40 +0100 Subject: [PATCH] [feat]: add llama4 pythonic template --- .../src/wayflowcore/_utils/formatting.py | 136 ++++++--- .../src/wayflowcore/models/_modelhelpers.py | 6 +- .../src/wayflowcore/models/ocigenaimodel.py | 10 +- .../_builtins_deserialization_plugin.py | 41 ++- .../wayflowcore/templates/llama4template.py | 271 ++++++++++++++++++ 5 files changed, 417 insertions(+), 47 deletions(-) create mode 100644 wayflowcore/src/wayflowcore/templates/llama4template.py diff --git a/wayflowcore/src/wayflowcore/_utils/formatting.py b/wayflowcore/src/wayflowcore/_utils/formatting.py index cf0544f95..4d3d17a5b 100644 --- a/wayflowcore/src/wayflowcore/_utils/formatting.py +++ b/wayflowcore/src/wayflowcore/_utils/formatting.py @@ -8,7 +8,7 @@ import json import logging import uuid -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple from json_repair import json_repair @@ -296,37 +296,116 @@ def generate_tool_id() -> str: return str(uuid.uuid4()) +_UNPARSE_ERRORS = (AttributeError, ValueError, TypeError) + + # AST visitor class to parse tool calls class CallVisitor(ast.NodeVisitor): - def __init__(self) -> None: - self.tool_calls: List[Tuple[str, Dict[str, Any]]] = [] + """ + Collects function call expressions. - def visit_Call(self, node: ast.Call) -> None: - arg_dict = {} + Design goals: + - Never crash on weird AST shapes + - Preserve python values where safe (ast.literal_eval), otherwise fall back to source strings + - Keep explicit *args/**kwargs separate (they aren't normal positional/keyword args) + """ + + def __init__( + self, + *, + allowed_names: Optional[Sequence[str]] = None, + ) -> None: + self.tool_calls: List[ToolRequest] = [] + self.allowed_names = set(allowed_names) if allowed_names else None + + def _safe_value(self, expr: ast.AST) -> Any: + """ + Return a real python value if it's a literal (numbers, strings, dict/list literals, etc.), + otherwise return source code string. + """ + try: + return ast.literal_eval(expr) + except (ValueError, SyntaxError, TypeError): + try: + return ast.unparse(expr) + except _UNPARSE_ERRORS: + return ast.dump(expr, include_attributes=False) + + def _call_name(self, func: ast.AST) -> str: + """ + Best-effort fully qualified-ish name: + - Name: foo + - Attribute chain: pkg.mod.foo or obj.method + - Other callables: + """ + # Name: foo + if isinstance(func, ast.Name): + return func.id + + # Attribute: x.y (possibly chained) + if isinstance(func, ast.Attribute): + parts: List[str] = [] + cur: ast.AST = func + while isinstance(cur, ast.Attribute): + parts.append(cur.attr) + cur = cur.value + + if isinstance(cur, ast.Name): + parts.append(cur.id) + return ".".join(reversed(parts)) + + # Something like (get_obj()).method -> can't name base cleanly + try: + base = ast.unparse(cur) + except _UNPARSE_ERRORS: + return "" + + # parts currently holds attrs from outer->inner, reversed gives inner->outer. + # For (get_obj()).method, parts would be ["method"] so this becomes "base.method". + return f"{base}." + ".".join(reversed(parts)) + + # Subscript call: fns[i](...) + if isinstance(func, ast.Subscript): + try: + return ast.unparse(func) + except _UNPARSE_ERRORS: + return "" + + # Lambda / Call / etc. + try: + return ast.unparse(func) + except _UNPARSE_ERRORS: + return "" - # first parse children to enqueue recursive tool calls first + def visit_Call(self, node: ast.Call) -> None: self.generic_visit(node) - # positional arguments - for i, arg in enumerate(node.args): - if isinstance(arg, ast.AST): - key = f"arg{i}" - arg_dict[key] = ast.unparse(arg) + name = self._call_name(node.func) + if self.allowed_names is not None and name not in self.allowed_names: + return - # keyword arguments - for kw in node.keywords: - if isinstance(kw, ast.keyword): - key = kw.arg if kw.arg else "**" - arg_dict[key] = ( - kw.value.value if isinstance(kw.value, ast.Constant) else ast.unparse(kw.value) - ) + kwargs: Dict[str, Any] = {} + starred_kwargs: List[str] = [] - if isinstance(node.func, ast.Attribute): - name = f"{getattr(node.func.value, 'id')}.{node.func.attr}" - else: - name = getattr(node.func, "id") + for kw in node.keywords: + if kw.arg is None: + try: + starred_kwargs.append(ast.unparse(kw.value)) + except _UNPARSE_ERRORS: + try: + starred_kwargs.append(ast.dump(kw.value, include_attributes=False)) + except TypeError: + starred_kwargs.append("") + else: + kwargs[kw.arg] = self._safe_value(kw.value) - self.tool_calls.append((name, arg_dict)) + self.tool_calls.append( + ToolRequest( + name=name, + args=kwargs, + tool_request_id=generate_tool_id(), + ) + ) def parse_tool_call_using_ast(raw_txt: str) -> List[ToolRequest]: @@ -334,15 +413,8 @@ def parse_tool_call_using_ast(raw_txt: str) -> List[ToolRequest]: ast_tree = ast.parse(raw_txt) visitor = CallVisitor() visitor.visit(ast_tree) - return [ - ToolRequest( - name=name, - args=args, - tool_request_id=generate_tool_id(), - ) - for name, args in visitor.tool_calls - ] - except Exception as e: + return visitor.tool_calls + except (SyntaxError, ValueError, TypeError, RecursionError) as e: logger.debug("Could not find any tool call in %s (%s)", raw_txt, str(e)) return [] diff --git a/wayflowcore/src/wayflowcore/models/_modelhelpers.py b/wayflowcore/src/wayflowcore/models/_modelhelpers.py index 3c26e527f..6f7b028ce 100644 --- a/wayflowcore/src/wayflowcore/models/_modelhelpers.py +++ b/wayflowcore/src/wayflowcore/models/_modelhelpers.py @@ -85,4 +85,8 @@ def _is_gemma_model(model_id: str) -> bool: def _is_llama_legacy_model(model_id: str) -> bool: - return "llama" in model_id.lower() and "3." in model_id + return "llama" in model_id.lower() and ("3." in model_id or "-3" in model_id) + + +def _is_recent_llama_model(model_id: str) -> bool: + return "llama" in model_id.lower() and ("4." in model_id or "-4" in model_id) diff --git a/wayflowcore/src/wayflowcore/models/ocigenaimodel.py b/wayflowcore/src/wayflowcore/models/ocigenaimodel.py index 5b8bfd537..04cd7c7c4 100644 --- a/wayflowcore/src/wayflowcore/models/ocigenaimodel.py +++ b/wayflowcore/src/wayflowcore/models/ocigenaimodel.py @@ -22,7 +22,7 @@ from wayflowcore.tools import Tool, ToolRequest from wayflowcore.transforms import CanonicalizationMessageTransform -from ._modelhelpers import _is_llama_legacy_model +from ._modelhelpers import _is_llama_legacy_model, _is_recent_llama_model from ._openaihelpers import _property_to_openai_schema from ._openaihelpers._utils import _safe_json_loads from ._requesthelpers import StreamChunkType, TaggedMessageChunkTypeWithTokenUsage @@ -393,11 +393,17 @@ def default_agent_template(self) -> "PromptTemplate": if self.provider == ModelProvider.COHERE: return NATIVE_AGENT_TEMPLATE + if self.provider == ModelProvider.META and _is_recent_llama_model(self.model_id): + logger.debug( + "Llama-4.x models have limited performance with native tool calling. Wayflow will instead use the `PYTHON_CALL_AGENT_TEMPLATE`, which yields better performance than native tool calling" + ) + from wayflowcore.templates.llama4template import LLAMA4_PYTHONIC_AGENT_TEMPLATE + + return LLAMA4_PYTHONIC_AGENT_TEMPLATE if self.provider == ModelProvider.META and _is_llama_legacy_model(self.model_id): logger.debug( "Llama-3.x models have limited performance with native tool calling. Wayflow will instead use the `LLAMA_AGENT_TEMPLATE`, which yields better performance than native tool calling" ) - # llama3.x works better with custom template return LLAMA_AGENT_TEMPLATE if self.provider == ModelProvider.GOOGLE: # google models do not support standalone system messages diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py index eec9ea0b0..479563eb6 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py @@ -115,7 +115,9 @@ from wayflowcore.agentspec.components import ( PluginVllmEmbeddingConfig as AgentSpecPluginVllmEmbeddingConfig, ) -from wayflowcore.agentspec.components import all_deserialization_plugin +from wayflowcore.agentspec.components import ( + all_deserialization_plugin, +) from wayflowcore.agentspec.components.agent import ExtendedAgent as AgentSpecExtendedAgent from wayflowcore.agentspec.components.contextprovider import ( PluginConstantContextProvider as AgentSpecPluginConstantContextProvider, @@ -644,24 +646,39 @@ def convert_to_wayflow( ) return agent elif isinstance(agentspec_component, (AgentSpecMCPTool, AgentSpecPluginMCPTool)): + mcp_tool_misses_description_and_inputs = ( + not agentspec_component.description and not agentspec_component.inputs + ) return RuntimeMCPTool( name=agentspec_component.name, client_transport=conversion_context.convert( agentspec_component.client_transport, tool_registry, converted_components ), - description=agentspec_component.description, - input_descriptors=[ - self._convert_property_to_runtime(input_property) - for input_property in agentspec_component.inputs or [] - ], - output_descriptors=[ - self._convert_property_to_runtime(output_property) - for output_property in agentspec_component.outputs or [] - ], + description=( + agentspec_component.description + if not mcp_tool_misses_description_and_inputs + else None + ), + input_descriptors=( + [ + self._convert_property_to_runtime(input_property) + for input_property in agentspec_component.inputs or [] + ] + if not mcp_tool_misses_description_and_inputs + else None + ), + output_descriptors=( + [ + self._convert_property_to_runtime(output_property) + for output_property in agentspec_component.outputs or [] + ] + if not mcp_tool_misses_description_and_inputs + else None + ), requires_confirmation=agentspec_component.requires_confirmation, id=agentspec_component.id, - _validate_server_exists=False, - _validate_tool_exist_on_server=False, + _validate_server_exists=mcp_tool_misses_description_and_inputs, + _validate_tool_exist_on_server=mcp_tool_misses_description_and_inputs, ) elif isinstance(agentspec_component, AgentSpecPluginConstantValuesNode): # Map PluginConstantValuesNode -> RuntimeConstantValuesStep diff --git a/wayflowcore/src/wayflowcore/templates/llama4template.py b/wayflowcore/src/wayflowcore/templates/llama4template.py new file mode 100644 index 000000000..e5b321f5a --- /dev/null +++ b/wayflowcore/src/wayflowcore/templates/llama4template.py @@ -0,0 +1,271 @@ +# Copyright © 2025 Oracle and/or its affiliates. +# +# This software is under the Apache License 2.0 +# (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License +# (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. +import ast +import json +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence + +from wayflowcore._utils.formatting import generate_tool_id, stringify +from wayflowcore.messagelist import Message +from wayflowcore.outputparser import ToolOutputParser +from wayflowcore.serialization.serializer import SerializableObject +from wayflowcore.templates.template import PromptTemplate +from wayflowcore.tools import ToolRequest, ToolResult +from wayflowcore.transforms import MessageTransform + +#################################################### +########### MESSAGE TRANSFORM ########### +#################################################### + + +class Llama4PythonicTransform(MessageTransform, SerializableObject): + """Simple message processor that joins tool requests and calls into a python-like message""" + + def __call__(self, messages: List["Message"]) -> List["Message"]: + formatted_messages = [] + for msg in messages: + if msg.tool_requests is not None: + new_message = Message( + role="assistant", + content=Llama4PythonicTransform._tool_requests_to_call_str(msg.tool_requests), + ) + elif msg.tool_result is not None: + new_message = Message( + role="user", + content=Llama4PythonicTransform._tool_result_to_str(msg.tool_result), + ) + else: + new_message = msg + formatted_messages.append(new_message) + return formatted_messages + + @staticmethod + def _format_value(v: Any) -> str: + if v is None or isinstance(v, (bool, int, float, str)): + return repr(v) + if isinstance(v, (list, tuple)): + inner = ", ".join(Llama4PythonicTransform._format_value(x) for x in v) + return f"[{inner}]" if isinstance(v, list) else f"({inner}{',' if len(v) == 1 else ''})" + if isinstance(v, dict): + items = ", ".join( + f"{repr(k)}: {Llama4PythonicTransform._format_value(vv)}" + for k, vv in sorted(v.items(), key=lambda kv: str(kv[0])) + ) + return "{" + items + "}" + if isinstance(v, str): + return repr(v) + try: + return json.dumps(v, sort_keys=True, ensure_ascii=False) + except TypeError: + return repr(v) + + @staticmethod + def _tool_request_to_call_str(req: ToolRequest) -> str: + # Deterministic order + items = [] + for k in sorted(req.args.keys()): + items.append(f"{k}={Llama4PythonicTransform._format_value(req.args[k])}") + return f"{req.name}({', '.join(items)})" + + @staticmethod + def _tool_requests_to_call_str(tool_requests: List[ToolRequest]) -> str: + return ( + "[" + + ",".join( + Llama4PythonicTransform._tool_request_to_call_str(tr) for tr in tool_requests + ) + + "]" + ) + + @staticmethod + def _tool_result_to_str( + tool_result: ToolResult, + ) -> str: + return f"{stringify(tool_result)}" + + +#################################################### +########### TOOL OUTPUT PARSER ########### +#################################################### + + +@dataclass(frozen=True) +class ToolCall: + name: str + kwargs: Dict[str, Any] + + +_UNPARSE_ERRORS = (AttributeError, ValueError, TypeError) + + +class CallVisitor(ast.NodeVisitor): + """ + Collects function call expressions. + + Design goals: + - Never crash on weird AST shapes + - Preserve python values where safe (ast.literal_eval), otherwise fall back to source strings + - Keep explicit *args/**kwargs separate (they aren't normal positional/keyword args) + """ + + def __init__( + self, + *, + allowed_names: Optional[Sequence[str]] = None, + ) -> None: + self.tool_calls: List[ToolCall] = [] + self.allowed_names = set(allowed_names) if allowed_names else None + + def _safe_value(self, expr: ast.AST) -> Any: + """ + Return a real python value if it's a literal (numbers, strings, dict/list literals, etc.), + otherwise return source code string. + """ + try: + return ast.literal_eval(expr) + except (ValueError, SyntaxError, TypeError): + try: + return ast.unparse(expr) + except _UNPARSE_ERRORS: + return ast.dump(expr, include_attributes=False) + + def _call_name(self, func: ast.AST) -> str: + """ + Best-effort fully qualified-ish name: + - Name: foo + - Attribute chain: pkg.mod.foo or obj.method + - Other callables: + """ + # Name: foo + if isinstance(func, ast.Name): + return func.id + + # Attribute: x.y (possibly chained) + if isinstance(func, ast.Attribute): + parts: List[str] = [] + cur: ast.AST = func + while isinstance(cur, ast.Attribute): + parts.append(cur.attr) + cur = cur.value + + if isinstance(cur, ast.Name): + parts.append(cur.id) + return ".".join(reversed(parts)) + + # Something like (get_obj()).method -> can't name base cleanly + try: + base = ast.unparse(cur) + except _UNPARSE_ERRORS: + return "" + + # parts currently holds attrs from outer->inner, reversed gives inner->outer. + # For (get_obj()).method, parts would be ["method"] so this becomes "base.method". + return f"{base}." + ".".join(reversed(parts)) + + # Subscript call: fns[i](...) + if isinstance(func, ast.Subscript): + try: + return ast.unparse(func) + except _UNPARSE_ERRORS: + return "" + + # Lambda / Call / etc. + try: + return ast.unparse(func) + except _UNPARSE_ERRORS: + return "" + + def visit_Call(self, node: ast.Call) -> None: + self.generic_visit(node) + + name = self._call_name(node.func) + if self.allowed_names is not None and name not in self.allowed_names: + return + + kwargs: Dict[str, Any] = {} + starred_kwargs: List[str] = [] + + for kw in node.keywords: + if kw.arg is None: + try: + starred_kwargs.append(ast.unparse(kw.value)) + except _UNPARSE_ERRORS: + try: + starred_kwargs.append(ast.dump(kw.value, include_attributes=False)) + except TypeError: + starred_kwargs.append("") + else: + kwargs[kw.arg] = self._safe_value(kw.value) + + self.tool_calls.append( + ToolCall( + name=name, + kwargs=kwargs, + ) + ) + + +class PythonToolOutputParser(ToolOutputParser): + """Parses tool requests from Python function call syntax.""" + + def parse_tool_request_from_str(self, raw_txt: str) -> List[ToolRequest]: + """Parses tool calls of the format 'some_tool(arg1=...)'""" + try: + ast_tree = ast.parse(raw_txt) + visitor = CallVisitor() + visitor.visit(ast_tree) + return [ + ToolRequest( + name=tool_call.name, + args=tool_call.kwargs, + tool_request_id=generate_tool_id(), + ) + for tool_call in visitor.tool_calls + ] + + except (SyntaxError, ValueError, TypeError, RecursionError) as e: + logging.debug("Could not find any tool call in %s (%s)", raw_txt, str(e)) + return [] + + +#################################################### +########### TEMPLATE ########### +#################################################### + +PYTHON_CALL_CHAT_SYSTEM_TEMPLATE = """{%- if custom_instruction -%}{{custom_instruction}}{%- endif -%} + +Here is a list of functions in JSON format that you can invoke. Use exact python format: +[func_name1(param1=value1, param2=value2), func_name2(...)] + +Do not use variables. Do not output pure text, always output tool calls. + +Available tools: + +{%- for tool in __TOOLS__ %} + {{- tool.function | tojson(indent=4) }} + {{- "\n\n" }} +{%- endfor %} + +You ONLY have access to those tools, calling any tool not mentioned above will result in a FAILURE. + +Tool outputs will be put inside delimiters: ... in user messages. +Never output an empty tool list, use `talk_to_user` or `submit_tool` (whichever is available) instead. +""" + + +LLAMA4_PYTHONIC_AGENT_TEMPLATE = PromptTemplate( + messages=[ + Message(role="system", content=PYTHON_CALL_CHAT_SYSTEM_TEMPLATE), + PromptTemplate.CHAT_HISTORY_PLACEHOLDER, + ], + native_tool_calling=False, # <-- important, so that we don't use the buggy tool calling support + post_rendering_transforms=[ + Llama4PythonicTransform() + ], # <-- we format ourselves the tool format + output_parser=PythonToolOutputParser(), # <-- we parse ourselves the tool calls +) +"""Pythonic agent template that leverages llama4 pythonic syntax to write tool calls"""