Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions packages/narada-core/src/narada_core/actions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# There is no `AgentRequest` because the `agent` action delegates to the `dispatch_request` method
# under the hood.

_MaybeStructuredOutput = TypeVar("_MaybeStructuredOutput", bound=BaseModel | None)
_StructuredOutputT = TypeVar("_StructuredOutputT")


class AgentUsage(BaseModel):
Expand Down Expand Up @@ -209,6 +209,11 @@ class ObjectSetPropertiesTrace(BaseModel):
description: str


class OutputTrace(BaseModel):
step_type: Literal["output"]
description: str


ApaStepTrace = Annotated[
GoToUrlTrace
| GetUrlTrace
Expand Down Expand Up @@ -237,7 +242,8 @@ class ObjectSetPropertiesTrace(BaseModel):
| WaitTrace
| DataTableInsertRowTrace
| DataTableUpdateCellValueTrace
| ObjectSetPropertiesTrace,
| ObjectSetPropertiesTrace
| OutputTrace,
Field(discriminator="step_type"),
]

Expand All @@ -259,11 +265,25 @@ def parse_action_trace(trace_data: list[dict[str, Any] | Any]) -> ActionTrace:
return _ApaActionTraceAdapter.validate_python(trace_data)


class AgentResponse(BaseModel, Generic[_MaybeStructuredOutput]):
class TextOutput(BaseModel):
type: Literal["text"]
content: str


class StructuredOutput(BaseModel, Generic[_StructuredOutputT]):
type: Literal["structured"]
content: _StructuredOutputT


class AgentResponse(BaseModel, Generic[_StructuredOutputT]):
request_id: str
status: Literal["success", "error", "input-required"]
text: str
structured_output: _MaybeStructuredOutput | None
structured_output: _StructuredOutputT | None
output: Annotated[
TextOutput | StructuredOutput[_StructuredOutputT],
Field(discriminator="type"),
]
usage: AgentUsage
action_trace: ActionTrace | None = None

Expand Down
6 changes: 6 additions & 0 deletions packages/narada-core/src/narada_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ class ReadCsvTrace(TypedDict):
description: str


class OutputTrace(TypedDict):
step_type: Literal["output"]
description: str


class StartTrace(TypedDict):
step_type: Literal["start"]
url: str
Expand Down Expand Up @@ -280,6 +285,7 @@ class ObjectSetPropertiesTrace(TypedDict):
| GetSimplifiedHtmlTrace
| GetScreenshotTrace
| RunCustomAgentTrace
| OutputTrace
| IfTrace
| SetVariableTrace
| WaitTrace
Expand Down
19 changes: 12 additions & 7 deletions packages/narada-pyodide/src/narada/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,17 @@ async def dispatch_request(
if response_content is not None:
# Populate the `structuredOutput` field. This is a client-side field
# that's not directly returned by the API.
if output_schema is None:
response_content["structuredOutput"] = None
else:
structured_output = output_schema.model_validate_json(
response_content["text"]
output_data = response_content.get("output")
if (
output_schema is not None
and output_data is not None
and output_data.get("type") == "structured"
):
response_content["structuredOutput"] = (
output_schema.model_validate(output_data["content"])
)
response_content["structuredOutput"] = structured_output
else:
response_content["structuredOutput"] = None

return response

Expand All @@ -319,7 +323,7 @@ async def agent(
mcp_servers: list[McpServer] | None = None,
variables: dict[str, str] | None = None,
timeout: int = 1000,
) -> AgentResponse[None]: ...
) -> AgentResponse[dict[str, Any]]: ...

@overload
async def agent(
Expand Down Expand Up @@ -375,6 +379,7 @@ async def agent(
request_id=remote_dispatch_response["requestId"],
status=remote_dispatch_response["status"],
text=response_content["text"],
output=response_content.get("output"),
structured_output=response_content.get("structuredOutput"),
usage=AgentUsage.model_validate(remote_dispatch_response["usage"]),
action_trace=action_trace,
Expand Down
19 changes: 12 additions & 7 deletions packages/narada/src/narada/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,17 @@ async def dispatch_request(
if response_content is not None:
# Populate the `structuredOutput` field. This is a client-side field
# that's not directly returned by the API.
Comment thread
pavlo-haidar marked this conversation as resolved.
if output_schema is None:
response_content["structuredOutput"] = None
else:
structured_output = output_schema.model_validate_json(
response_content["text"]
output_data = response_content.get("output")
if (
output_schema is not None
and output_data is not None
and output_data.get("type") == "structured"
):
response_content["structuredOutput"] = (
output_schema.model_validate(output_data["content"])
)
response_content["structuredOutput"] = structured_output
else:
response_content["structuredOutput"] = None

return response

Expand All @@ -294,7 +298,7 @@ async def agent(
mcp_servers: list[McpServer] | None = None,
variables: dict[str, str] | None = None,
timeout: int = 1000,
) -> AgentResponse[None]: ...
) -> AgentResponse[dict[str, Any]]: ...

@overload
async def agent(
Expand Down Expand Up @@ -353,6 +357,7 @@ async def agent(
request_id=remote_dispatch_response["requestId"],
status=remote_dispatch_response["status"],
text=response_content["text"],
output=response_content["output"],
structured_output=response_content.get("structuredOutput"),
usage=AgentUsage.model_validate(remote_dispatch_response["usage"]),
action_trace=action_trace,
Expand Down
Loading