Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions python/nemo_relay/integrations/langchain/_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,50 +304,55 @@ def model_response_from_json(payload: Any, codec: Any) -> ModelResponse[Any]:
raise TypeError(f"NeMo Relay model execution returned {type(decoded)!r}, expected ModelResponse")


def _prepare_outputs(outputs: Any) -> Any:
"""Prepare a NeMo Relay scope output dict for returning to LangChain."""
if isinstance(outputs, dict):
prepared_outputs = {}
for key, value in outputs.items():
prepared_outputs[key] = _prepare_outputs(value)
elif isinstance(outputs, list | tuple):
prepared_outputs = []
for value in outputs:
prepared_outputs.append(_prepare_outputs(value))
elif isinstance(outputs, Command):
prepared_outputs = {
def _prepare_lc_payloads(payload: Any) -> Any:
"""
Convert a LangChain payload to a JSON-serializable structure

Typically the entry point to this method is a LangChain dictionary containing LC message objects, and the returned
dictionary should contain the same structure, but the values are JSON serializable representations
"""
if isinstance(payload, dict):
prepared = {}
for key, value in payload.items():
prepared[key] = _prepare_lc_payloads(value)
elif isinstance(payload, list | tuple):
prepared = []
for value in payload:
prepared.append(_prepare_lc_payloads(value))
elif isinstance(payload, Command):
prepared = {
"type": "command",
"command": {
"graph": _prepare_outputs(outputs.graph),
"update": _prepare_outputs(outputs.update),
"resume": _prepare_outputs(outputs.resume),
"goto": _prepare_outputs(outputs.goto),
"graph": _prepare_lc_payloads(payload.graph),
"update": _prepare_lc_payloads(payload.update),
"resume": _prepare_lc_payloads(payload.resume),
"goto": _prepare_lc_payloads(payload.goto),
},
}
elif isinstance(outputs, Send):
prepared_outputs = {
elif isinstance(payload, Send):
prepared = {
"type": "send",
"send": {
"node": outputs.node,
"arg": _prepare_outputs(outputs.arg),
"node": payload.node,
"arg": _prepare_lc_payloads(payload.arg),
},
}
elif isinstance(outputs, ToolMessage):
prepared_outputs = {
elif isinstance(payload, ToolMessage):
prepared = {
"type": "tool_message",
"tool_call": {
"name": outputs.name,
"id": outputs.id,
"tool_call_id": outputs.tool_call_id,
"content": outputs.content,
"name": payload.name,
"id": payload.id,
"tool_call_id": payload.tool_call_id,
"content": payload.content,
},
}
elif isinstance(outputs, BaseMessage):
prepared_outputs = {
elif isinstance(payload, BaseMessage):
prepared = {
"type": "message",
"message": messages_to_dict([outputs]),
"message": messages_to_dict([payload]),
}
else:
prepared_outputs = outputs
prepared = payload

return prepared_outputs
return prepared
19 changes: 10 additions & 9 deletions python/nemo_relay/integrations/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from langchain_core.callbacks.base import BaseCallbackHandler

import nemo_relay
from nemo_relay.integrations.langchain._serialization import _prepare_outputs
from nemo_relay.integrations.langchain._serialization import _prepare_lc_payloads

if typing.TYPE_CHECKING:
from uuid import UUID
Expand Down Expand Up @@ -54,21 +54,23 @@ def on_chain_start(
if name is None:
name = "Unknown"

parent = self._scope_handles.get(parent_run_id) if parent_run_id else None
parent = None
if parent_run_id is not None:
parent = self._scope_handles.get(parent_run_id)

scope_metadata = metadata.copy() if metadata else {}
scope_metadata["langchain_run_id"] = str(run_id)
prepared_inputs = _prepare_lc_payloads(inputs)
handle = nemo_relay.scope.push(
name,
nemo_relay.ScopeType.Agent,
handle=parent,
input=inputs,
input=prepared_inputs,
metadata=scope_metadata,
)
self._scope_handles[run_id] = handle
except Exception:
_logger.debug("NeMo Relay: on_chain_start failed", exc_info=True)
return None
_logger.error("NeMo Relay: on_chain_start failed", exc_info=True)

def on_chain_end(
self,
Expand All @@ -80,7 +82,6 @@ def on_chain_end(
) -> typing.Any:
"""Pop the NeMo Relay scope associated with a LangChain chain run."""
self._pop_scope(run_id, output=outputs)
return None

def on_chain_error(
self,
Expand All @@ -92,14 +93,14 @@ def on_chain_error(
) -> typing.Any:
"""Pop the NeMo Relay scope associated with a failed LangChain chain run."""
self._pop_scope(run_id, output={"error": repr(error)})
return None

def _pop_scope(self, run_id: UUID, *, output: dict[str, typing.Any] | None = None) -> None:
handle = self._scope_handles.pop(run_id, None)
if handle is None:
return

try:
prepared_outputs = _prepare_outputs(output) if output is not None else None
prepared_outputs = _prepare_lc_payloads(output) if output is not None else None
nemo_relay.scope.pop(handle, output=prepared_outputs)
except Exception:
_logger.warning("NeMo Relay: scope.pop failed", exc_info=True)
_logger.error("NeMo Relay: scope.pop failed", exc_info=True)
4 changes: 2 additions & 2 deletions python/nemo_relay/integrations/langgraph/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from langgraph.callbacks import GraphCallbackHandler, GraphInterruptEvent, GraphResumeEvent

import nemo_relay
from nemo_relay.integrations.langchain._serialization import _prepare_outputs
from nemo_relay.integrations.langchain._serialization import _prepare_lc_payloads
from nemo_relay.integrations.langchain.callbacks import NemoRelayCallbackHandler as LangChainNemoRelayCallbackHandler

_logger = logging.getLogger(__name__)
Expand All @@ -20,7 +20,7 @@
def _json_safe(value: Any) -> nemo_relay.Json:
"""Return a conservative JSON-compatible representation for mark payloads."""
try:
value = _prepare_outputs(value)
value = _prepare_lc_payloads(value)
except Exception:
pass

Expand Down
12 changes: 8 additions & 4 deletions python/nemo_relay/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ToolRequestIntercept,
ToolSanitizeGuardrail,
UnsupportedBehavior,
subscribers,
)
from nemo_relay._native import (
active_plugin_report as _active_plugin_report,
Expand Down Expand Up @@ -332,23 +333,26 @@ def clear() -> None:


@asynccontextmanager
async def plugin(config: PluginConfig | JsonObject) -> AsyncIterator[ConfigReport]:
async def plugin(config: PluginConfig | JsonObject, *, clear_on_exit: bool = True) -> AsyncIterator[ConfigReport]:
"""Context manager for plugin initialization and cleanup.

Args:
config: `PluginConfig` or an equivalent JSON object.
clear_on_exit: Whether to clear the plugin configuration on exit.

Yields:
The `ConfigReport` for the initialized configuration.

Behavior:
This context manager initializes the plugin configuration on entry and clears it on exit.
"""
report = await initialize(config)
report_ = await initialize(config)
try:
yield report
yield report_
finally:
clear()
subscribers.flush()
if clear_on_exit:
clear()


def report() -> ConfigReport | None:
Expand Down
Loading