diff --git a/examples/critic.py b/examples/critic.py
new file mode 100644
index 0000000..145d444
--- /dev/null
+++ b/examples/critic.py
@@ -0,0 +1,44 @@
+import asyncio
+
+from narada import Narada, CriticConfig
+from pydantic import BaseModel, Field
+
+
+class SearchCriticOutput(BaseModel):
+    search_query_used: str = Field(description="The exact search query the agent used")
+    result_count: int = Field(description="The number of results the agent found")
+
+
+async def main() -> None:
+    # Initialize the Narada client.
+    async with Narada() as narada:
+        window = await narada.open_and_initialize_browser_window()
+
+        # Define a critic that verifies the agent completed the task and extracts
+        # additional structured information from the agent's actions.
+        critic = CriticConfig(
+            prompt=(
+                "Verify that the agent successfully searched Google and found results. "
+                "Extract the exact search query the agent used and the number of results found."
+            ),
+            output_schema=SearchCriticOutput,
+        )
+
+        # Run a task with the critic. After the main agent finishes, the critic
+        # evaluates whether the task was completed successfully.
+        response = await window.agent(
+            prompt='Search Google for "Narada AI" and tell me how many results were found.',
+            critic=critic,
+        )
+
+        print("Agent response:", response.text)
+
+        # `critic_result` is optional on the response, so guard before use.
+        critic_result = response.critic_result
+        if critic_result is not None:
+            print("Critic result:", critic_result.validation_passed)
+            print("Critic structured output:", critic_result.structured_output)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/packages/narada-core/src/narada_core/actions/critic.py b/packages/narada-core/src/narada_core/actions/critic.py
new file mode 100644
index 0000000..a8a3f23
--- /dev/null
+++ b/packages/narada-core/src/narada_core/actions/critic.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from typing import Any, Awaitable, Callable
+
+from narada_core.models import Agent, CriticConfig
+from pydantic import BaseModel, create_model
+
+from narada_core.actions.models import AgentUsage, CriticResult
+from narada_core.tracing.model import parse_action_trace
+
+# Name of the boolean field injected into the critic's output schema; the
+# critic's pass/fail verdict is read back from it.
+_VALIDATION_VAR = "narada_validation_passed"
+_DEFAULT_CRITIC_PROMPT = (
+    "Using your context about the actions and outcome of the previous agent, "
+    "determine whether its task was completed successfully."
+)
+
+
+async def run_critic(
+    *,
+    dispatch_request: Callable[..., Awaitable[Any]],
+    original_prompt: str,
+    response_content: dict[str, Any],
+    action_trace_raw: list[Any] | None,
+    critic: CriticConfig,
+    time_zone: str,
+    timeout: int,
+) -> CriticResult:
+    """Dispatch a critic request that judges a finished agent run.
+
+    The critic's output schema is the caller's ``output_schema`` (if any) plus
+    a required boolean field named by ``_VALIDATION_VAR``. That flag becomes
+    ``CriticResult.validation_passed``; the remaining fields are re-validated
+    against ``output_schema`` and returned as ``structured_output``.
+
+    Raises:
+        ValueError: If ``output_schema`` already declares the reserved
+            validation field, or if the critic dispatch returns no response.
+    """
+    output_schema = critic.get("output_schema")
+    combined_fields: dict[str, Any] = {}
+    if output_schema is not None:
+        if _VALIDATION_VAR in output_schema.model_fields:
+            # The injected flag would silently replace the caller's field and
+            # then be stripped again before `output_schema.model_validate`.
+            raise ValueError(
+                "Critic output_schema must not define the reserved field "
+                f"{_VALIDATION_VAR!r}"
+            )
+        combined_fields.update(
+            (name, (info.annotation, info))
+            for name, info in output_schema.model_fields.items()
+        )
+    combined_fields[_VALIDATION_VAR] = (bool, ...)
+    CriticOutputModel = create_model("CriticOutput", **combined_fields)
+
+    critic_dispatch_response = await dispatch_request(
+        prompt=critic.get("prompt", _DEFAULT_CRITIC_PROMPT),
+        agent=Agent.PRODUCTIVITY,
+        output_schema=CriticOutputModel,
+        critic_context={
+            "agentPrompt": original_prompt,
+            "agentOutput": response_content["text"],
+            "actionTrace": action_trace_raw or [],
+            "validationVariableName": _VALIDATION_VAR,
+        },
+        mcp_servers=critic.get("mcp_servers"),
+        time_zone=time_zone,
+        timeout=timeout,
+    )
+
+    critic_content = critic_dispatch_response["response"]
+    if critic_content is None:
+        raise ValueError("Critic dispatch returned no response")
+
+    # Accept either a hydrated model or a plain mapping; anything else
+    # (including a missing output) is treated as a failed validation.
+    combined_output = critic_content.get("structuredOutput")
+    output_dict: dict[str, Any] | None = None
+    if isinstance(combined_output, BaseModel):
+        output_dict = combined_output.model_dump()
+    elif isinstance(combined_output, Mapping):
+        output_dict = dict(combined_output)
+
+    validation_passed = False
+    structured_output: BaseModel | None = None
+    if output_dict is not None:
+        validation_passed = bool(output_dict.pop(_VALIDATION_VAR, False))
+        if output_schema is not None:
+            structured_output = output_schema.model_validate(output_dict)
+
+    critic_action_trace_raw = critic_content.get("actionTrace")
+    critic_action_trace = (
+        parse_action_trace(critic_action_trace_raw)
+        if critic_action_trace_raw is not None
+        else None
+    )
+
+    return CriticResult(
+        validation_passed=validation_passed,
+        structured_output=structured_output,
+        usage=AgentUsage.model_validate(critic_dispatch_response["usage"]),
+        action_trace=critic_action_trace,
+    )
diff --git a/packages/narada-core/src/narada_core/actions/models.py b/packages/narada-core/src/narada_core/actions/models.py
index 358efa5..3b9dfbf 100644
--- a/packages/narada-core/src/narada_core/actions/models.py
+++ b/packages/narada-core/src/narada_core/actions/models.py
@@ -40,6 +40,13 @@ class StructuredOutput(BaseModel, Generic[_StructuredOutputT]):
     content: _StructuredOutputT
 
 
+class CriticResult(BaseModel):
+    validation_passed: bool
+    structured_output: Any
+    usage: AgentUsage
+    action_trace: tracing_model.ActionTrace | None = None
+
+
 class AgentResponse(BaseModel, Generic[_StructuredOutputT]):
     request_id: str
     status: Literal["success", "error", "input-required"]
@@ -51,6 +58,7 @@ class AgentResponse(BaseModel, Generic[_StructuredOutputT]):
     ]
     usage: AgentUsage
     action_trace: tracing_model.ActionTrace | None = None
+    critic_result: CriticResult | None = None
 
 
 class AgenticSelectorClickAction(TypedDict):
diff --git a/packages/narada-core/src/narada_core/models.py b/packages/narada-core/src/narada_core/models.py
index d7d075f..307c7a9 100644
--- a/packages/narada-core/src/narada_core/models.py
+++ b/packages/narada-core/src/narada_core/models.py
@@ -77,6 +77,12 @@ class McpServer(BaseModel):
     selectedTools: list[str] | None = None
 
 
+class CriticConfig(TypedDict, total=False):
+    prompt: str
+    output_schema: type[BaseModel]
+    mcp_servers: list[McpServer]
+
+
 class RemoteDispatchChatHistoryItem(TypedDict):
     role: Literal["user", "assistant"]
     content: str
diff --git a/packages/narada-pyodide/src/narada/__init__.py b/packages/narada-pyodide/src/narada/__init__.py
index 386ed83..ce52c7c 100644
--- a/packages/narada-pyodide/src/narada/__init__.py
+++ b/packages/narada-pyodide/src/narada/__init__.py
@@ -2,7 +2,15 @@
     NaradaError,
     NaradaTimeoutError,
 )
-from narada_core.models import Agent, File, ReasoningEffort, Response, ResponseContent
+from narada_core.actions.models import CriticResult
+from narada_core.models import (
+    Agent,
+    CriticConfig,
+    File,
+    ReasoningEffort,
+    Response,
+    ResponseContent,
+)
 
 from narada.client import Narada
 from narada.utils import download_file, render_html
@@ -17,6 +25,8 @@
     "__version__",
     "Agent",
     "CloudBrowserWindow",
+    "CriticConfig",
+    "CriticResult",
     "download_file",
     "File",
     "LocalBrowserWindow",
diff --git a/packages/narada-pyodide/src/narada/window.py b/packages/narada-pyodide/src/narada/window.py
index 2a74d42..39b3cb2 100644
--- a/packages/narada-pyodide/src/narada/window.py
+++ b/packages/narada-pyodide/src/narada/window.py
@@ -20,6 +20,7 @@
 from urllib.parse import urlencode
 
 from js import AbortController, setTimeout  # type: ignore
+from narada_core.actions.critic import run_critic
 from narada_core.actions.models import (
     AgenticMouseAction,
     AgenticMouseActionRequest,
@@ -30,6 +31,7 @@
     AgentResponse,
     AgentUsage,
     CloseWindowRequest,
+    CriticResult,
     ExtensionActionRequest,
     ExtensionActionResponse,
     GetFullHtmlRequest,
@@ -63,6 +65,7 @@
 )
 from narada_core.models import (
     Agent,
+    CriticConfig,
     File,
     McpServer,
     ReasoningEffort,
@@ -268,6 +271,7 @@ async def dispatch_request(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: dict[str, Any] | None = None,
+        critic_context: dict[str, Any] | None = None,
         callback_url: str | None = None,
         callback_secret: str | None = None,
         callback_headers: dict[str, Any] | None = None,
@@ -291,6 +295,7 @@ async def dispatch_request(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: dict[str, Any] | None = None,
+        critic_context: dict[str, Any] | None = None,
         callback_url: str | None = None,
         callback_secret: str | None = None,
         callback_headers: dict[str, Any] | None = None,
@@ -314,6 +319,7 @@ async def dispatch_request(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: dict[str, Any] | None = None,
+        critic_context: dict[str, Any] | None = None,
         callback_url: str | None = None,
         callback_secret: str | None = None,
         callback_headers: dict[str, Any] | None = None,
@@ -377,6 +383,8 @@ async def dispatch_request(
             body["secretVariables"] = secret_variables
         if input_variables is not None:
             body["inputVariables"] = input_variables
+        if critic_context is not None:
+            body["criticContext"] = critic_context
         if callback_url is not None:
             body["callbackUrl"] = callback_url
         if callback_secret is not None:
@@ -551,6 +559,7 @@ async def agent(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: dict[str, Any] | None = None,
+        critic: CriticConfig | None = None,
         timeout: int = 1000,
     ) -> AgentResponse[dict[str, Any]]: ...
 
@@ -567,6 +576,7 @@ async def agent(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: dict[str, Any] | None = None,
+        critic: CriticConfig | None = None,
         timeout: int = 1000,
     ) -> AgentResponse[_StructuredOutput]: ...
 
@@ -583,6 +593,7 @@ async def agent(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: dict[str, Any] | None = None,
+        critic: CriticConfig | None = None,
         timeout: int = 1000,
     ) -> AgentResponse:
         """Invokes an agent in the Narada extension side panel chat."""
@@ -639,6 +650,18 @@ async def agent(
             else None
         )
 
+        critic_result: CriticResult | None = None
+        if critic is not None:
+            critic_result = await run_critic(
+                dispatch_request=self.dispatch_request,
+                original_prompt=prompt,
+                response_content=response_content,
+                action_trace_raw=action_trace_raw,
+                critic=critic,
+                time_zone=time_zone,
+                timeout=timeout,
+            )
+
         return AgentResponse(
             request_id=remote_dispatch_response["requestId"],
             status=remote_dispatch_response["status"],
@@ -647,6 +670,7 @@ async def agent(
             structured_output=response_content.get("structuredOutput"),
             usage=AgentUsage.model_validate(remote_dispatch_response["usage"]),
             action_trace=action_trace,
+            critic_result=critic_result,
         )
 
     async def agentic_selector(
diff --git a/packages/narada/src/narada/__init__.py b/packages/narada/src/narada/__init__.py
index 1434183..756eccf 100644
--- a/packages/narada/src/narada/__init__.py
+++ b/packages/narada/src/narada/__init__.py
@@ -7,7 +7,15 @@
     NaradaUnsupportedBrowserError,
     UserAbortedError,
 )
-from narada_core.models import Agent, File, ReasoningEffort, Response, ResponseContent
+from narada_core.actions.models import CriticResult
+from narada_core.models import (
+    Agent,
+    CriticConfig,
+    File,
+    ReasoningEffort,
+    Response,
+    ResponseContent,
+)
 
 from narada.client import Narada
 from narada.config import BrowserConfig, ProxyConfig
@@ -20,6 +28,8 @@
     "Agent",
     "BrowserConfig",
     "CloudBrowserWindow",
+    "CriticConfig",
+    "CriticResult",
     "download_file",
     "File",
     "LocalBrowserWindow",
diff --git a/packages/narada/src/narada/window.py b/packages/narada/src/narada/window.py
index dff0332..d3c07fa 100644
--- a/packages/narada/src/narada/window.py
+++ b/packages/narada/src/narada/window.py
@@ -21,6 +21,7 @@
 )
 
 import aiohttp
+from narada_core.actions.critic import run_critic
 from narada_core.actions.models import (
     AgenticMouseAction,
     AgenticMouseActionRequest,
@@ -31,6 +32,7 @@
     AgentResponse,
     AgentUsage,
     CloseWindowRequest,
+    CriticResult,
     ExtensionActionRequest,
     ExtensionActionResponse,
     GetFullHtmlRequest,
@@ -64,6 +66,7 @@
 )
 from narada_core.models import (
     Agent,
+    CriticConfig,
     File,
     McpServer,
     ReasoningEffort,
@@ -308,6 +311,7 @@ async def dispatch_request(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: Mapping[str, Any] | None = None,
+        critic_context: dict[str, Any] | None = None,
         callback_url: str | None = None,
         callback_secret: str | None = None,
         callback_headers: Mapping[str, Any] | None = None,
@@ -332,6 +336,7 @@ async def dispatch_request(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: Mapping[str, Any] | None = None,
+        critic_context: dict[str, Any] | None = None,
         callback_url: str | None = None,
         callback_secret: str | None = None,
         callback_headers: Mapping[str, Any] | None = None,
@@ -356,6 +361,7 @@ async def dispatch_request(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: Mapping[str, Any] | None = None,
+        critic_context: dict[str, Any] | None = None,
         callback_url: str | None = None,
         callback_secret: str | None = None,
         callback_headers: Mapping[str, Any] | None = None,
@@ -412,6 +418,8 @@ async def dispatch_request(
             body["inputVariables"] = await self._normalize_input_variables(
                 input_variables=input_variables
             )
+        if critic_context is not None:
+            body["criticContext"] = critic_context
         if callback_url is not None:
             body["callbackUrl"] = callback_url
         if callback_secret is not None:
@@ -522,6 +530,7 @@ async def agent(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: Mapping[str, Any] | None = None,
+        critic: CriticConfig | None = None,
         timeout: int = 1000,
     ) -> AgentResponse[dict[str, Any]]: ...
 
@@ -539,6 +548,7 @@ async def agent(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: Mapping[str, Any] | None = None,
+        critic: CriticConfig | None = None,
         timeout: int = 1000,
     ) -> AgentResponse[_StructuredOutput]: ...
 
@@ -556,6 +566,7 @@ async def agent(
         mcp_servers: list[McpServer] | None = None,
         secret_variables: dict[str, str] | None = None,
         input_variables: Mapping[str, Any] | None = None,
+        critic: CriticConfig | None = None,
         timeout: int = 1000,
     ) -> AgentResponse:
         """Invokes an agent in the Narada extension side panel chat."""
@@ -614,6 +625,18 @@ async def agent(
             else None
         )
 
+        critic_result: CriticResult | None = None
+        if critic is not None:
+            critic_result = await run_critic(
+                dispatch_request=self.dispatch_request,
+                original_prompt=prompt,
+                response_content=response_content,
+                action_trace_raw=action_trace_raw,
+                critic=critic,
+                time_zone=time_zone,
+                timeout=timeout,
+            )
+
         return AgentResponse(
             request_id=remote_dispatch_response["requestId"],
             status=remote_dispatch_response["status"],
@@ -622,6 +645,7 @@ async def agent(
             structured_output=response_content.get("structuredOutput"),
             usage=AgentUsage.model_validate(remote_dispatch_response["usage"]),
             action_trace=action_trace,
+            critic_result=critic_result,
         )
 
     async def agentic_selector(