Skip to content

Commit 0676acf

Browse files
Changing AgentState.response -> AgentState.messages
1 parent b1219b2 commit 0676acf

7 files changed

Lines changed: 135 additions & 89 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,13 @@ async def awrap_tool_call(
279279
assert resp.artifact is None, "artifact is already populated"
280280

281281
if resp.name.startswith(AGENT_PREFIX):
282-
resp.artifact = SubagentFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
282+
resp.artifact = SubagentFailureResult(
283+
str(resp.content)
284+
) # pyright: ignore[reportUnknownArgumentType]
283285
else:
284-
resp.artifact = ToolFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
286+
resp.artifact = ToolFailureResult(
287+
str(resp.content)
288+
) # pyright: ignore[reportUnknownArgumentType]
285289

286290
return resp
287291

@@ -967,9 +971,9 @@ async def _sdk_handler(request: ToolRequest) -> ToolResponse:
967971
lc_request = _convert_tool_request_to_lc(request, original_request)
968972
result = await handler(lc_request)
969973
sdk_result = _convert_tool_message_from_lc(result)
970-
assert isinstance(sdk_result, ToolMessage), (
971-
"Expected tool response from tool middleware handler"
972-
)
974+
assert isinstance(
975+
sdk_result, ToolMessage
976+
), "Expected tool response from tool middleware handler"
973977
return ToolResponse(sdk_result.result)
974978

975979
return _sdk_handler
@@ -987,9 +991,9 @@ async def _sdk_handler(
987991
lc_request = _convert_subagent_request_to_lc(request, original_request)
988992
result = await handler(lc_request)
989993
sdk_result = _convert_tool_message_from_lc(result)
990-
assert isinstance(sdk_result, SubagentMessage), (
991-
"Expected subagent response from subagent middleware handler"
992-
)
994+
assert isinstance(
995+
sdk_result, SubagentMessage
996+
), "Expected subagent response from subagent middleware handler"
993997
return SubagentResponse(sdk_result.result)
994998

995999
return _sdk_handler
@@ -1182,16 +1186,18 @@ def _convert_tool_message_from_lc(
11821186
)
11831187
case LC_ToolMessage():
11841188
# If this is reached, we likely passed an invalid tool name to LangChain.
1185-
assert message.name is not None, (
1186-
"LangChain responded with a nameless tool call"
1187-
)
1189+
assert (
1190+
message.name is not None
1191+
), "LangChain responded with a nameless tool call"
11881192

11891193
if message.name.startswith(TOOL_STRATEGY_TOOL_PREFIX):
11901194
return StructuredOutputMessage(
11911195
name=message.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
11921196
call_id=message.tool_call_id,
11931197
status=message.status,
1194-
content=str(message.content), # pyright: ignore[reportUnknownArgumentType]
1198+
content=str(
1199+
message.content
1200+
), # pyright: ignore[reportUnknownArgumentType]
11951201
)
11961202

11971203
assert isinstance(message.artifact, ToolResult) or isinstance(
@@ -1266,7 +1272,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
12661272

12671273

12681274
def _convert_agent_state_to_lc(state: AgentState) -> LC_AgentState[Any]:
1269-
messages = [_map_message_to_langchain(m) for m in state.response.messages]
1275+
messages = [_map_message_to_langchain(m) for m in state.messages]
12701276
return LC_AgentState(messages=messages)
12711277

12721278

@@ -1351,7 +1357,9 @@ async def _tool_call(
13511357
except ToolException as e:
13521358
raise LC_ToolException(*e.args) from e
13531359
except LC_ToolException:
1354-
assert False, ( # noqa: PT015
1360+
assert (
1361+
False
1362+
), ( # noqa: PT015
13551363
"ToolException from LangChain should not be raised in tool.func"
13561364
)
13571365

@@ -1454,6 +1462,7 @@ async def _run( # pyright: ignore[reportRedeclaration]
14541462
content: str, thread_id: str
14551463
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
14561464
return await invoke_agent(HumanMessage(content=content), thread_id)
1465+
14571466
else:
14581467

14591468
async def _run( # pyright: ignore[reportRedeclaration]
@@ -1627,14 +1636,9 @@ def _convert_agent_state_from_langchain(
16271636
messages = state["messages"]
16281637
total_tokens_counter = _get_approximate_token_counter(model)
16291638
total_tokens = total_tokens_counter(messages)
1630-
1631-
response = AgentResponse[Any | None](
1632-
messages=[_map_message_from_langchain(m) for m in state["messages"]],
1633-
structured_output=state.get("structured_response"),
1634-
)
1635-
1639+
messages = [_map_message_from_langchain(m) for m in state["messages"]]
16361640
return AgentState(
1637-
response=response,
1641+
messages=messages,
16381642
total_steps=len(messages),
16391643
token_count=total_tokens,
16401644
)
@@ -1646,7 +1650,9 @@ def _get_approximate_token_counter(model: BaseChatModel) -> LC_TokenCounter:
16461650
# NOTE: This is adapted from the backend provider library
16471651
# 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
16481652
# API: https://platform.claude.com/docs/en/build-with-claude/token-counting
1649-
if model._llm_type == ANTHROPIC_CHAT_MODEL_TYPE: # pyright: ignore[reportPrivateUsage]
1653+
if (
1654+
model._llm_type == ANTHROPIC_CHAT_MODEL_TYPE
1655+
): # pyright: ignore[reportPrivateUsage]
16501656
return partial(count_tokens_approximately, chars_per_token=3.3)
16511657
return count_tokens_approximately
16521658

splunklib/ai/middleware.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15-
from collections.abc import Awaitable, Callable
15+
from collections.abc import Sequence, Awaitable, Callable
1616
from dataclasses import dataclass
1717
from typing import Any, override
1818

@@ -35,7 +35,7 @@ class AgentState:
3535
"""AgentState is available through certain middlewares and contains information about the current state of an agent execution."""
3636

3737
# holds messages exchanged so far in the conversation
38-
response: AgentResponse[Any | None]
38+
messages: Sequence[BaseMessage]
3939
# steps taken so far in the conversation
4040
total_steps: int
4141
# tokens used so far in the conversation
@@ -96,7 +96,7 @@ def __post_init__(self) -> None:
9696

9797
@dataclass(frozen=True)
9898
class AgentRequest:
99-
messages: list[BaseMessage]
99+
messages: Sequence[BaseMessage]
100100

101101

102102
AgentMiddlewareHandler = Callable[[AgentRequest], Awaitable[AgentResponse[Any | None]]]

tests/integration/ai/test_agent.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ async def test_agent_with_openai_round_trip(self):
6565
)
6666

6767
response = result.final_message.content.strip().lower().replace(".", "")
68-
assert result.structured_output is None, (
69-
"The structured output should not be populated"
70-
)
68+
assert (
69+
result.structured_output is None
70+
), "The structured output should not be populated"
7171
assert "stefan" in response
7272

7373
@pytest.mark.asyncio
@@ -160,9 +160,9 @@ class Person(BaseModel):
160160

161161
# check if the last message contains the response in natural language
162162
assert response.name in last_message, "Name field not found in the message"
163-
assert str(response.age) in last_message, (
164-
"Age field not found in the message"
165-
)
163+
assert (
164+
str(response.age) in last_message
165+
), "Age field not found in the message"
166166

167167
async def test_agent_uses_subagent(self):
168168
pytest.importorskip("langchain_openai")
@@ -215,9 +215,9 @@ class NicknameGeneratorInput(BaseModel):
215215
subagent_message = next(
216216
filter(lambda m: m.role == "subagent", result.messages), None
217217
)
218-
assert isinstance(subagent_message, SubagentMessage), (
219-
"Invalid subagent message"
220-
)
218+
assert isinstance(
219+
subagent_message, SubagentMessage
220+
), "Invalid subagent message"
221221
assert subagent_message, "No subagent message found in response"
222222

223223
response = result.final_message.content
@@ -366,12 +366,12 @@ class SupervisorOutput(BaseModel):
366366
)
367367

368368
response = result.structured_output
369-
assert type(response) == SupervisorOutput, (
370-
"Response is not of type Team"
371-
)
372-
assert len(response.member_descriptions) == 3, (
373-
"Team does not have 3 members"
374-
)
369+
assert (
370+
type(response) == SupervisorOutput
371+
), "Response is not of type Team"
372+
assert (
373+
len(response.member_descriptions) == 3
374+
), "Team does not have 3 members"
375375

376376
@pytest.mark.asyncio
377377
async def test_duplicated_subagent_name(self) -> None:
@@ -520,9 +520,9 @@ async def _subagent_call_middleware(
520520

521521
# Override the arguments, such that are invalid.
522522
resp = await handler(replace(request, call=replace(request.call, args={})))
523-
assert isinstance(resp.result, SubagentFailureResult), (
524-
"subagent call did not fail"
525-
)
523+
assert isinstance(
524+
resp.result, SubagentFailureResult
525+
), "subagent call did not fail"
526526

527527
after_subagent_call = True
528528
return resp
@@ -532,7 +532,7 @@ async def _model_call_middleware(
532532
req: ModelRequest, _handler: ModelMiddlewareHandler
533533
) -> ModelResponse:
534534
if after_subagent_call:
535-
msgs = req.state.response.messages
535+
msgs = req.state.messages
536536
assert isinstance(msgs[-1], SubagentMessage)
537537
assert isinstance(msgs[-1].result, SubagentFailureResult)
538538

tests/integration/ai/test_conversation_store.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ async def _model_middleware(
6666

6767
if after_first_call:
6868
# Previous messages included.
69-
assert len(request.state.response.messages) == 3
69+
assert len(request.state.messages) == 3
7070
else:
71-
assert len(request.state.response.messages) == 1
71+
assert len(request.state.messages) == 1
7272
return await handler(request)
7373

7474
@agent_middleware
@@ -166,7 +166,7 @@ async def _model_middleware(
166166
nonlocal model_middleware_called
167167
model_middleware_called = True
168168

169-
assert len(request.state.response.messages) == 1
169+
assert len(request.state.messages) == 1
170170
return await handler(request)
171171

172172
async with Agent(
@@ -186,9 +186,9 @@ async def _model_middleware(
186186
thread_id="2",
187187
)
188188
response = result.final_message.content
189-
assert "Mike" not in response, (
190-
"Agent remembered the name from a different thread_id"
191-
)
189+
assert (
190+
"Mike" not in response
191+
), "Agent remembered the name from a different thread_id"
192192

193193
assert model_middleware_called
194194

@@ -276,9 +276,9 @@ async def _model_middleware(
276276
nonlocal after_first_call
277277

278278
if after_first_call:
279-
assert len(request.state.response.messages) == 3
279+
assert len(request.state.messages) == 3
280280
else:
281-
assert len(request.state.response.messages) == 1
281+
assert len(request.state.messages) == 1
282282

283283
after_first_call = True
284284
return await handler(request)
@@ -347,9 +347,9 @@ async def _model_middleware(
347347
nonlocal after_first_call
348348

349349
if after_first_call:
350-
assert len(request.state.response.messages) == 3
350+
assert len(request.state.messages) == 3
351351
else:
352-
assert len(request.state.response.messages) == 1
352+
assert len(request.state.messages) == 1
353353

354354
after_first_call = True
355355
return await handler(request)

tests/integration/ai/test_hooks.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,13 @@
3030
before_model,
3131
)
3232
from splunklib.ai.messages import AIMessage, AgentResponse, HumanMessage
33-
from splunklib.ai.middleware import AgentRequest, ModelMiddlewareHandler, ModelRequest, ModelResponse, model_middleware
33+
from splunklib.ai.middleware import (
34+
AgentRequest,
35+
ModelMiddlewareHandler,
36+
ModelRequest,
37+
ModelResponse,
38+
model_middleware,
39+
)
3440
from tests.ai_testlib import AITestCase
3541

3642

@@ -47,15 +53,15 @@ def test_hook_before(req: ModelRequest) -> None:
4753
hook_calls += 1
4854

4955
assert req.system_message.startswith("Your name is stefan")
50-
assert len(req.state.response.messages) == 1
56+
assert len(req.state.messages) == 1
5157

5258
@before_model
5359
async def test_async_hook_before(req: ModelRequest) -> None:
5460
nonlocal hook_calls
5561
hook_calls += 1
5662

5763
assert req.system_message.startswith("Your name is stefan")
58-
assert len(req.state.response.messages) == 1
64+
assert len(req.state.messages) == 1
5965

6066
@after_model
6167
def test_hook_after(resp: ModelResponse) -> None:
@@ -197,10 +203,12 @@ async def test_agent_loop_stop_conditions_conversation_limit(self) -> None:
197203
with pytest.raises(
198204
StepsLimitExceededException, match="Steps limit of 2 exceeded"
199205
):
200-
_ = await agent.invoke([
201-
HumanMessage(content="hi, my name is Chris"),
202-
HumanMessage(content="What is my name?"),
203-
])
206+
_ = await agent.invoke(
207+
[
208+
HumanMessage(content="hi, my name is Chris"),
209+
HumanMessage(content="What is my name?"),
210+
]
211+
)
204212

205213
@pytest.mark.asyncio
206214
async def test_agent_loop_stop_conditions_conversation_limit_with_checkpointer(
@@ -220,13 +228,17 @@ async def test_agent_loop_stop_conditions_conversation_limit_with_checkpointer(
220228
with pytest.raises(
221229
StepsLimitExceededException, match="Steps limit of 2 exceeded"
222230
):
223-
_ = await agent.invoke([
224-
HumanMessage(content="What is my name?"),
225-
HumanMessage(content="Are you sure?"),
226-
])
231+
_ = await agent.invoke(
232+
[
233+
HumanMessage(content="What is my name?"),
234+
HumanMessage(content="Are you sure?"),
235+
]
236+
)
227237

228238
@pytest.mark.asyncio
229-
async def test_agent_loop_stop_conditions_steps_accumulate_across_invokes(self) -> None:
239+
async def test_agent_loop_stop_conditions_steps_accumulate_across_invokes(
240+
self,
241+
) -> None:
230242
pytest.importorskip("langchain_openai")
231243

232244
step_limit = StepLimitMiddleware(2)

tests/integration/ai/test_middleware.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ async def test_middleware(
7878
assert call.args == {"city": "Krakow"}
7979

8080
state = request.state
81-
assert len(state.response.messages) == 2
81+
assert len(state.messages) == 2
8282

8383
response = await handler(request)
8484
assert isinstance(response.result, ToolResult)
@@ -500,9 +500,9 @@ async def test_middleware(
500500
)
501501
assert subagent_message, "SubagentMessage not found in messages"
502502
assert isinstance(subagent_message.result, SubagentTextResult)
503-
assert subagent_message.result.content == "Chris-superstar", (
504-
"Invalid response from subagent"
505-
)
503+
assert (
504+
subagent_message.result.content == "Chris-superstar"
505+
), "Invalid response from subagent"
506506
assert middleware_called, "Middleware was not called"
507507

508508
@pytest.mark.asyncio
@@ -699,10 +699,7 @@ async def mutating_middleware(
699699
) -> ModelResponse:
700700
new_state = replace(
701701
request.state,
702-
response=replace(
703-
request.state.response,
704-
messages=[HumanMessage(content="What is the capital of France?")],
705-
),
702+
messages=[HumanMessage(content="What is the capital of France?")],
706703
)
707704
return await handler(replace(request, state=new_state))
708705

0 commit comments

Comments
 (0)