Skip to content

Commit 2c62eae

Browse files
committed
Handle AIMessage.content properly
1 parent b1219b2 commit 2c62eae

15 files changed

Lines changed: 632 additions & 89 deletions

File tree

examples/ai_modinput_app/bin/agentic_weather.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from _collections_abc import dict_items
2121
from typing import final, override
2222

23+
from splunklib.ai.messages import AIMessage, ContentBlock, TextBlock
24+
2325
# ! NOTE: This insert is only needed for splunk-sdk-python CI/CD to work.
2426
# ! Remove this if you're modifying this example locally.
2527
sys.path.insert(0, "/splunklib-deps")
@@ -95,9 +97,9 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None:
9597
weather_events += list(reader)
9698

9799
for weather_event in weather_events:
98-
weather_event["human_readable"] = asyncio.run(
99-
self.invoke_agent(weather_event)
100-
)
100+
result = asyncio.run(self.invoke_agent(weather_event))
101+
weather_event["human_readable"] = self.parse_content(result)
102+
101103
logger.debug(f"{weather_event=}")
102104

103105
event = Event(
@@ -112,7 +114,7 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None:
112114

113115
logger.debug(f"Finishing enrichment for {input_name} at {csv_file_path}")
114116

115-
async def invoke_agent(self, weather_event: dict[str, str | int]) -> str:
117+
async def invoke_agent(self, weather_event: dict[str, str | int]) -> AIMessage:
116118
if not self.service:
117119
raise AssertionError("No Splunk connection available")
118120

@@ -127,7 +129,27 @@ async def invoke_agent(self, weather_event: dict[str, str | int]) -> str:
127129
data=weather_event,
128130
)
129131
logger.debug(f"{response=}")
130-
return response.final_message.content
132+
return response.final_message
133+
134+
def _parse_content_block(self, block: str | ContentBlock) -> str | None:
    """Return the text carried by a single content item, if any.

    A ``TextBlock`` yields its ``text`` attribute, a plain string is
    returned unchanged, and every other kind of block yields ``None``.
    """
    if isinstance(block, TextBlock):
        return block.text
    if isinstance(block, str):
        return block
    return None
142+
143+
def parse_content(self, message: AIMessage) -> str:
    """Flatten the content of an ``AIMessage`` into a single string.

    String content is returned unchanged. List content is reduced to the
    text of each item (via ``_parse_content_block``); items that produce
    no text, or empty text, are skipped and the rest are joined with a
    single space.
    """
    content = message.content
    if isinstance(content, str):
        return content

    texts = (self._parse_content_block(item) for item in content)
    return " ".join(text for text in texts if text)
131153

132154

133155
if __name__ == "__main__":

splunklib/ai/engines/langchain.py

Lines changed: 126 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@
7777
AgentResponse,
7878
AIMessage,
7979
BaseMessage,
80+
ContentBlock,
8081
HumanMessage,
82+
OpaqueBlock,
8183
OutputT,
8284
StructuredOutputCall,
8385
StructuredOutputMessage,
@@ -87,6 +89,7 @@
8789
SubagentStructuredResult,
8890
SubagentTextResult,
8991
SystemMessage,
92+
TextBlock,
9093
ToolCall,
9194
ToolFailureResult,
9295
ToolMessage,
@@ -951,7 +954,7 @@ async def awrap_tool_call(
951954
return LC_ToolMessage(
952955
name=_normalize_agent_name(call.name),
953956
tool_call_id=call.id,
954-
content=content,
957+
content=_map_content_to_langchain(content),
955958
status=status,
956959
artifact=sdk_result,
957960
)
@@ -1085,7 +1088,10 @@ def _convert_model_response_to_model_result(
10851088
# This invariant is asserted via ModelResponse.__post_init__
10861089
assert len(resp.message.structured_output_calls) <= 1
10871090

1088-
lc_message = LC_AIMessage(content=resp.message.content)
1091+
lc_message = LC_AIMessage(
1092+
content=_map_content_to_langchain(resp.message.content),
1093+
additional_kwargs=resp.message.extras or {},
1094+
)
10891095
# This field can't be set via __init__()
10901096
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls]
10911097

@@ -1160,7 +1166,7 @@ def _convert_tool_message_to_lc(
11601166
name=name,
11611167
tool_call_id=message.call_id,
11621168
status=status,
1163-
content=content,
1169+
content=_map_content_to_langchain(content),
11641170
artifact=artifact,
11651171
)
11661172

@@ -1243,9 +1249,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
12431249
ai_message = model_response
12441250
structured_response = None
12451251

1252+
additional_kwargs = cast(dict[str, Any], ai_message.additional_kwargs)
12461253
return ModelResponse(
12471254
message=AIMessage(
1248-
content=ai_message.content.__str__(),
1255+
content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType]
12491256
calls=[
12501257
_map_tool_call_from_langchain(tc)
12511258
for tc in ai_message.tool_calls
@@ -1260,6 +1267,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
12601267
for tc in ai_message.tool_calls
12611268
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
12621269
],
1270+
extras=additional_kwargs,
12631271
),
12641272
structured_output=structured_response,
12651273
)
@@ -1422,6 +1430,28 @@ def _is_agent_name_valid(name: str) -> bool:
14221430
return set(name).issubset(AGENT_NAME_ALLOWED_CHARS)
14231431

14241432

1433+
def _parse_content_block(block: str | ContentBlock) -> str | None:
    """Return the text carried by a single content item, if any.

    A ``TextBlock`` yields its ``text`` attribute, a plain string is
    returned unchanged, and every other kind of block yields ``None``.
    """
    if isinstance(block, TextBlock):
        return block.text
    if isinstance(block, str):
        return block
    return None
1441+
1442+
1443+
def _parse_content(content: str | list[str | ContentBlock]) -> str:
    """Flatten message content into a single string.

    String content is returned unchanged. List content is reduced to the
    text of each item (via ``_parse_content_block``); items that produce
    no text, or empty text, are skipped and the rest are joined with a
    single space.
    """
    if isinstance(content, str):
        return content

    fragments: list[str] = []
    for item in content:
        text = _parse_content_block(item)
        if text:
            fragments.append(text)
    return " ".join(fragments)
1453+
1454+
14251455
def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
14261456
if not agent.name:
14271457
raise AssertionError("Agent must have a name to be used by other Agents")
@@ -1433,7 +1463,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
14331463

14341464
async def invoke_agent(
14351465
message: HumanMessage, thread_id: str | None
1436-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1466+
) -> tuple[
1467+
OutputT | str,
1468+
SubagentStructuredResult | SubagentTextResult,
1469+
]:
14371470
result = await agent.invoke([message], thread_id=thread_id)
14381471

14391472
if agent.output_schema:
@@ -1442,23 +1475,28 @@ async def invoke_agent(
14421475
structured_output=result.structured_output.model_dump(),
14431476
)
14441477

1445-
return result.final_message.content, SubagentTextResult(
1446-
content=result.final_message.content
1447-
)
1478+
text_content = _parse_content(result.final_message.content)
1479+
return text_content, SubagentTextResult(content=text_content)
14481480

14491481
InputSchema = agent.input_schema
14501482
if InputSchema is None:
14511483
if agent.conversation_store:
14521484

14531485
async def _run( # pyright: ignore[reportRedeclaration]
14541486
content: str, thread_id: str
1455-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1487+
) -> tuple[
1488+
OutputT | str,
1489+
SubagentStructuredResult | SubagentTextResult,
1490+
]:
14561491
return await invoke_agent(HumanMessage(content=content), thread_id)
14571492
else:
14581493

14591494
async def _run( # pyright: ignore[reportRedeclaration]
14601495
content: str,
1461-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1496+
) -> tuple[
1497+
OutputT | str,
1498+
SubagentStructuredResult | SubagentTextResult,
1499+
]:
14621500
return await invoke_agent(HumanMessage(content=content), None)
14631501

14641502
return StructuredTool.from_function(
@@ -1471,7 +1509,10 @@ async def _run( # pyright: ignore[reportRedeclaration]
14711509

14721510
async def invoke_agent_structured(
14731511
content: BaseModel, thread_id: str | None
1474-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1512+
) -> tuple[
1513+
OutputT | str,
1514+
SubagentStructuredResult | SubagentTextResult,
1515+
]:
14751516
result = await agent.invoke_with_data(
14761517
instructions="Follow the system prompt.",
14771518
data=content.model_dump(),
@@ -1484,15 +1525,17 @@ async def invoke_agent_structured(
14841525
structured_output=result.structured_output.model_dump(),
14851526
)
14861527

1487-
return result.final_message.content, SubagentTextResult(
1488-
content=result.final_message.content
1489-
)
1528+
text_content = _parse_content(result.final_message.content)
1529+
return text_content, SubagentTextResult(content=text_content)
14901530

14911531
if agent.conversation_store:
14921532

14931533
async def _run(
14941534
**kwargs: Any, # noqa: ANN401
1495-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1535+
) -> tuple[
1536+
OutputT | str,
1537+
SubagentStructuredResult | SubagentTextResult,
1538+
]:
14961539
content: BaseModel = kwargs["content"]
14971540
thread_id: str = kwargs["thread_id"]
14981541
return await invoke_agent_structured(content, thread_id)
@@ -1512,7 +1555,10 @@ async def _run(
15121555

15131556
async def _run(
15141557
**kwargs: Any, # noqa: ANN401
1515-
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
1558+
) -> tuple[
1559+
OutputT | str,
1560+
SubagentStructuredResult | SubagentTextResult,
1561+
]:
15161562
content = InputSchema(**kwargs)
15171563
return await invoke_agent_structured(content, None)
15181564

@@ -1564,11 +1610,69 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
15641610
return LC_ToolCall(id=call.id, name=name, args=args)
15651611

15661612

1613+
def _map_content_from_langchain(
    content: str | list[str | dict[str, Any]],
) -> str | list[str | ContentBlock]:
    """Convert LangChain message content into its SDK representation.

    String content is passed through untouched; list content is converted
    item by item with ``_map_content_block_from_langchain``.
    """
    if not isinstance(content, str):
        return [_map_content_block_from_langchain(item) for item in content]
    return content
1622+
1623+
1624+
def _map_content_block_from_langchain(
    block: str | dict[str, Any],
) -> str | ContentBlock:
    """Convert one LangChain content item into its SDK representation.

    Plain strings are returned unchanged. A dict whose ``"type"`` is
    ``"text"`` becomes a ``TextBlock`` built from its ``"text"`` key and
    its optional ``"extras"`` and ``"id"`` keys. Any other dict is wrapped
    in an ``OpaqueBlock``.
    """
    if isinstance(block, str):
        return block

    if block.get("type") == "text":
        return TextBlock(
            text=block["text"],
            extras=block.get("extras"),
            id=block.get("id"),
        )

    # NOTE: data we do not handle is kept verbatim in an opaque block so
    # that it is preserved and sent back to the LLM.
    return OpaqueBlock(_data=block)
1640+
1641+
1642+
def _map_content_to_langchain(
    content: str | list[str | ContentBlock],
) -> str | list[str | dict[str, Any]]:
    """Convert SDK message content into the form LangChain expects.

    String content is passed through untouched; list content is converted
    item by item with ``_map_content_block_to_langchain``.
    """
    if not isinstance(content, str):
        return [_map_content_block_to_langchain(item) for item in content]
    return content
1651+
1652+
1653+
def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]:
1654+
if isinstance(block, str):
1655+
return block
1656+
1657+
match block:
1658+
case TextBlock():
1659+
result: dict[str, Any] = {
1660+
"type": "text",
1661+
"text": block.text,
1662+
"id": block.id,
1663+
}
1664+
if block.extras:
1665+
result["extras"] = block.extras
1666+
return result
1667+
case OpaqueBlock():
1668+
return block._data # pyright: ignore[reportPrivateUsage]
1669+
1670+
15671671
def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15681672
match message:
15691673
case LC_AIMessage():
15701674
return AIMessage(
1571-
content=message.content.__str__(),
1675+
content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType]
15721676
calls=[
15731677
_map_tool_call_from_langchain(tc)
15741678
for tc in message.tool_calls
@@ -1583,6 +1687,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15831687
for tc in message.tool_calls
15841688
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
15851689
],
1690+
extras=cast(dict[str, Any], message.additional_kwargs),
15861691
)
15871692
case LC_HumanMessage():
15881693
return HumanMessage(content=message.content.__str__())
@@ -1597,7 +1702,10 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
15971702
def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
15981703
match message:
15991704
case AIMessage():
1600-
lc_message = LC_AIMessage(content=message.content)
1705+
lc_message = LC_AIMessage(
1706+
content=_map_content_to_langchain(message.content),
1707+
additional_kwargs=message.extras or {},
1708+
)
16011709
# This field can't be set via constructor
16021710
lc_message.tool_calls = [
16031711
_map_tool_call_to_langchain(c) for c in message.calls

splunklib/ai/hooks.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,5 +199,3 @@ async def model_middleware(
199199
if self._deadline is not None and monotonic() >= self._deadline:
200200
raise TimeoutExceededException(timeout_seconds=self._seconds)
201201
return await handler(request)
202-
203-

0 commit comments

Comments
 (0)