Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 104 additions & 32 deletions wayflowcore/src/wayflowcore/_utils/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple

from json_repair import json_repair

Expand Down Expand Up @@ -296,53 +296,125 @@ def generate_tool_id() -> str:
return str(uuid.uuid4())


_UNPARSE_ERRORS = (AttributeError, ValueError, TypeError)


# AST visitor class to parse tool calls
class CallVisitor(ast.NodeVisitor):
def __init__(self) -> None:
self.tool_calls: List[Tuple[str, Dict[str, Any]]] = []
"""
Collects function call expressions.

def visit_Call(self, node: ast.Call) -> None:
arg_dict = {}
Design goals:
- Never crash on weird AST shapes
- Preserve python values where safe (ast.literal_eval), otherwise fall back to source strings
- Keep explicit *args/**kwargs separate (they aren't normal positional/keyword args)
"""

def __init__(
self,
*,
allowed_names: Optional[Sequence[str]] = None,
) -> None:
self.tool_calls: List[ToolRequest] = []
self.allowed_names = set(allowed_names) if allowed_names else None

def _safe_value(self, expr: ast.AST) -> Any:
"""
Return a real python value if it's a literal (numbers, strings, dict/list literals, etc.),
otherwise return source code string.
"""
try:
return ast.literal_eval(expr)
except (ValueError, SyntaxError, TypeError):
try:
return ast.unparse(expr)
except _UNPARSE_ERRORS:
return ast.dump(expr, include_attributes=False)

def _call_name(self, func: ast.AST) -> str:
"""
Best-effort fully qualified-ish name:
- Name: foo
- Attribute chain: pkg.mod.foo or obj.method
- Other callables: <expr>
"""
# Name: foo
if isinstance(func, ast.Name):
return func.id

# Attribute: x.y (possibly chained)
if isinstance(func, ast.Attribute):
parts: List[str] = []
cur: ast.AST = func
while isinstance(cur, ast.Attribute):
parts.append(cur.attr)
cur = cur.value

if isinstance(cur, ast.Name):
parts.append(cur.id)
return ".".join(reversed(parts))

# Something like (get_obj()).method -> can't name base cleanly
try:
base = ast.unparse(cur)
except _UNPARSE_ERRORS:
return "<attribute>"

# parts currently holds attrs from outer->inner, reversed gives inner->outer.
# For (get_obj()).method, parts would be ["method"] so this becomes "base.method".
return f"{base}." + ".".join(reversed(parts))

# Subscript call: fns[i](...)
if isinstance(func, ast.Subscript):
try:
return ast.unparse(func)
except _UNPARSE_ERRORS:
return "<subscript>"

# Lambda / Call / etc.
try:
return ast.unparse(func)
except _UNPARSE_ERRORS:
return "<expr>"

# first parse children to enqueue recursive tool calls first
def visit_Call(self, node: ast.Call) -> None:
self.generic_visit(node)

# positional arguments
for i, arg in enumerate(node.args):
if isinstance(arg, ast.AST):
key = f"arg{i}"
arg_dict[key] = ast.unparse(arg)
name = self._call_name(node.func)
if self.allowed_names is not None and name not in self.allowed_names:
return

# keyword arguments
for kw in node.keywords:
if isinstance(kw, ast.keyword):
key = kw.arg if kw.arg else "**"
arg_dict[key] = (
kw.value.value if isinstance(kw.value, ast.Constant) else ast.unparse(kw.value)
)
kwargs: Dict[str, Any] = {}
starred_kwargs: List[str] = []

if isinstance(node.func, ast.Attribute):
name = f"{getattr(node.func.value, 'id')}.{node.func.attr}"
else:
name = getattr(node.func, "id")
for kw in node.keywords:
if kw.arg is None:
try:
starred_kwargs.append(ast.unparse(kw.value))
except _UNPARSE_ERRORS:
try:
starred_kwargs.append(ast.dump(kw.value, include_attributes=False))
except TypeError:
starred_kwargs.append("<kwargs>")
else:
kwargs[kw.arg] = self._safe_value(kw.value)

self.tool_calls.append((name, arg_dict))
self.tool_calls.append(
ToolRequest(
name=name,
args=kwargs,
tool_request_id=generate_tool_id(),
)
)


def parse_tool_call_using_ast(raw_txt: str) -> List[ToolRequest]:
try:
ast_tree = ast.parse(raw_txt)
visitor = CallVisitor()
visitor.visit(ast_tree)
return [
ToolRequest(
name=name,
args=args,
tool_request_id=generate_tool_id(),
)
for name, args in visitor.tool_calls
]
except Exception as e:
return visitor.tool_calls
except (SyntaxError, ValueError, TypeError, RecursionError) as e:
logger.debug("Could not find any tool call in %s (%s)", raw_txt, str(e))
return []

Expand Down
6 changes: 5 additions & 1 deletion wayflowcore/src/wayflowcore/models/_modelhelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,8 @@ def _is_gemma_model(model_id: str) -> bool:


def _is_llama_legacy_model(model_id: str) -> bool:
return "llama" in model_id.lower() and "3." in model_id
return "llama" in model_id.lower() and ("3." in model_id or "-3" in model_id)


def _is_recent_llama_model(model_id: str) -> bool:
return "llama" in model_id.lower() and ("4." in model_id or "-4" in model_id)
10 changes: 8 additions & 2 deletions wayflowcore/src/wayflowcore/models/ocigenaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from wayflowcore.tools import Tool, ToolRequest
from wayflowcore.transforms import CanonicalizationMessageTransform

from ._modelhelpers import _is_llama_legacy_model
from ._modelhelpers import _is_llama_legacy_model, _is_recent_llama_model
from ._openaihelpers import _property_to_openai_schema
from ._openaihelpers._utils import _safe_json_loads
from ._requesthelpers import StreamChunkType, TaggedMessageChunkTypeWithTokenUsage
Expand Down Expand Up @@ -393,11 +393,17 @@ def default_agent_template(self) -> "PromptTemplate":

if self.provider == ModelProvider.COHERE:
return NATIVE_AGENT_TEMPLATE
if self.provider == ModelProvider.META and _is_recent_llama_model(self.model_id):
logger.debug(
"Llama-4.x models have limited performance with native tool calling. Wayflow will instead use the `PYTHON_CALL_AGENT_TEMPLATE`, which yields better performance than native tool calling"
)
from wayflowcore.templates.llama4template import LLAMA4_PYTHONIC_AGENT_TEMPLATE

return LLAMA4_PYTHONIC_AGENT_TEMPLATE
if self.provider == ModelProvider.META and _is_llama_legacy_model(self.model_id):
logger.debug(
"Llama-3.x models have limited performance with native tool calling. Wayflow will instead use the `LLAMA_AGENT_TEMPLATE`, which yields better performance than native tool calling"
)
# llama3.x works better with custom template
return LLAMA_AGENT_TEMPLATE
if self.provider == ModelProvider.GOOGLE:
# google models do not support standalone system messages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@
from wayflowcore.agentspec.components import (
PluginVllmEmbeddingConfig as AgentSpecPluginVllmEmbeddingConfig,
)
from wayflowcore.agentspec.components import all_deserialization_plugin
from wayflowcore.agentspec.components import (
all_deserialization_plugin,
)
from wayflowcore.agentspec.components.agent import ExtendedAgent as AgentSpecExtendedAgent
from wayflowcore.agentspec.components.contextprovider import (
PluginConstantContextProvider as AgentSpecPluginConstantContextProvider,
Expand Down Expand Up @@ -644,24 +646,39 @@ def convert_to_wayflow(
)
return agent
elif isinstance(agentspec_component, (AgentSpecMCPTool, AgentSpecPluginMCPTool)):
mcp_tool_misses_description_and_inputs = (
not agentspec_component.description and not agentspec_component.inputs
)
return RuntimeMCPTool(
name=agentspec_component.name,
client_transport=conversion_context.convert(
agentspec_component.client_transport, tool_registry, converted_components
),
description=agentspec_component.description,
input_descriptors=[
self._convert_property_to_runtime(input_property)
for input_property in agentspec_component.inputs or []
],
output_descriptors=[
self._convert_property_to_runtime(output_property)
for output_property in agentspec_component.outputs or []
],
description=(
agentspec_component.description
if not mcp_tool_misses_description_and_inputs
else None
),
input_descriptors=(
[
self._convert_property_to_runtime(input_property)
for input_property in agentspec_component.inputs or []
]
if not mcp_tool_misses_description_and_inputs
else None
),
output_descriptors=(
[
self._convert_property_to_runtime(output_property)
for output_property in agentspec_component.outputs or []
]
if not mcp_tool_misses_description_and_inputs
else None
),
requires_confirmation=agentspec_component.requires_confirmation,
id=agentspec_component.id,
_validate_server_exists=False,
_validate_tool_exist_on_server=False,
_validate_server_exists=mcp_tool_misses_description_and_inputs,
_validate_tool_exist_on_server=mcp_tool_misses_description_and_inputs,
)
elif isinstance(agentspec_component, AgentSpecPluginConstantValuesNode):
# Map PluginConstantValuesNode -> RuntimeConstantValuesStep
Expand Down
Loading