@@ -268,16 +268,14 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
268268 self ._agent_middleware .extend (agent .middleware or [])
269269 self ._agent_middleware .extend (after_user_middlewares )
270270
271- if agent .limits .max_tokens is not None :
272- self ._agent_middleware .append (_TokenLimitMiddleware (agent .limits .max_tokens ))
273271 if agent .limits .max_steps is not None :
274272 self ._agent_middleware .append (_StepLimitMiddleware (agent .limits .max_steps ))
275273 if agent .limits .timeout is not None :
276274 self ._agent_middleware .append (_TimeoutLimitMiddleware (agent .limits .timeout ))
277275
278276 model_impl = _create_langchain_model (agent .model )
279277
280- lc_middleware : list [LC_AgentMiddleware ] = [_Middleware (self ._agent_middleware , model_impl )]
278+ lc_middleware : list [LC_AgentMiddleware ] = [_Middleware (self ._agent_middleware )]
281279
282280 # This middleware is executed just after the tool execution and populates
283281 # the artifact field for failed tool calls, since in such cases we can't
@@ -587,6 +585,27 @@ async def awrap_tool_call(
587585 if _DEBUG :
588586 lc_middleware .append (_DEBUGMiddleware ())
589587
588+ if agent .limits .max_tokens is not None :
589+ _max_tokens = agent .limits .max_tokens
590+
591+ class _TokenLimitMiddleware (LC_AgentMiddleware ):
592+ @override
593+ async def awrap_model_call (
594+ self ,
595+ request : LC_ModelRequest ,
596+ handler : Callable [[LC_ModelRequest ], Awaitable [LC_ModelCallResult ]],
597+ ) -> LC_ModelCallResult :
598+ token_count = _get_approximate_token_counter (request .model , request .tools )(
599+ request .state ["messages" ]
600+ )
601+
602+ if token_count >= _max_tokens :
603+ raise TokenLimitExceededException (token_limit = _max_tokens )
604+
605+ return await handler (request )
606+
607+ lc_middleware .append (_TokenLimitMiddleware ())
608+
590609 response_format = None
591610 if agent .output_schema is not None :
592611 if _supports_provider_strategy (model_impl ):
@@ -764,11 +783,9 @@ def _prepare_langchain_tools(agent_tools: Sequence[Tool]) -> list[BaseTool]:
764783
765784class _Middleware (LC_AgentMiddleware ):
766785 _middleware : list [AgentMiddleware ]
767- _model : BaseChatModel
768786
769- def __init__ (self , middleware : list [AgentMiddleware ], model : BaseChatModel ) -> None :
787+ def __init__ (self , middleware : list [AgentMiddleware ]) -> None :
770788 self ._middleware = middleware
771- self ._model = model
772789
773790 def _with_model_middleware (
774791 self , model_invoke : ModelMiddlewareHandler
@@ -837,7 +854,7 @@ async def awrap_model_call(
837854 request .state ["messages" ].append (request .runtime .context .retry )
838855 request .runtime .context .retry = False
839856
840- req = _convert_model_request_from_lc (request , self . _model )
857+ req = _convert_model_request_from_lc (request )
841858 final_handler = _convert_model_handler_from_lc (handler , original_request = request )
842859
843860 async def llm_handler (req : ModelRequest ) -> ModelResponse :
@@ -929,7 +946,7 @@ async def awrap_tool_call(
929946 call = _map_tool_call_from_langchain (request .tool_call )
930947
931948 if isinstance (call , ToolCall ):
932- req = _convert_tool_request_from_lc (request , self . _model )
949+ req = _convert_tool_request_from_lc (request )
933950 final_handler = _convert_tool_handler_from_lc (handler , original_request = request )
934951 sdk_response = await self ._with_tool_call_middleware (final_handler )(req )
935952
@@ -955,7 +972,7 @@ async def awrap_tool_call(
955972 artifact = sdk_result ,
956973 )
957974
958- req = _convert_subagent_request_from_lc (request , self . _model )
975+ req = _convert_subagent_request_from_lc (request )
959976 final_handler = _convert_subagent_handler_from_lc (handler , original_request = request )
960977 sdk_response = await self ._with_subagent_call_middleware (final_handler )(req )
961978
@@ -1030,32 +1047,31 @@ async def _sdk_handler(request: ModelRequest) -> ModelResponse:
10301047 return _sdk_handler
10311048
10321049
1033- def _convert_model_request_from_lc (request : LC_ModelRequest , model : BaseChatModel ) -> ModelRequest :
1050+ def _convert_model_request_from_lc (request : LC_ModelRequest ) -> ModelRequest :
10341051 thread_id = request .runtime .context .thread_id
10351052
10361053 system_message = request .system_message .content .__str__ () if request .system_message else ""
10371054
10381055 return ModelRequest (
10391056 system_message = system_message ,
1040- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1057+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
10411058 )
10421059
10431060
1044- def _convert_tool_request_from_lc (request : LC_ToolCallRequest , model : BaseChatModel ) -> ToolRequest :
1061+ def _convert_tool_request_from_lc (request : LC_ToolCallRequest ) -> ToolRequest :
10451062 assert isinstance (request .runtime .context , InvokeContext )
10461063 thread_id = request .runtime .context .thread_id
10471064
10481065 tool_call = _map_tool_call_from_langchain (request .tool_call )
10491066 assert isinstance (tool_call , ToolCall ), "Expected tool call"
10501067 return ToolRequest (
10511068 call = tool_call ,
1052- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1069+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
10531070 )
10541071
10551072
10561073def _convert_subagent_request_from_lc (
10571074 request : LC_ToolCallRequest ,
1058- model : BaseChatModel ,
10591075) -> SubagentRequest :
10601076 assert isinstance (request .runtime .context , InvokeContext )
10611077 thread_id = request .runtime .context .thread_id
@@ -1064,7 +1080,7 @@ def _convert_subagent_request_from_lc(
10641080 assert isinstance (subagent_call , SubagentCall ), "Expected subagent call"
10651081 return SubagentRequest (
10661082 call = subagent_call ,
1067- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1083+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
10681084 )
10691085
10701086
@@ -1732,30 +1748,29 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
17321748 raise InvalidMessageTypeError ("Invalid SDK message type" )
17331749
17341750
1735- def _convert_agent_state_from_langchain (
1736- state : LC_AgentState [Any ], model : BaseChatModel , thread_id : str
1737- ) -> AgentState :
1751+ def _convert_agent_state_from_langchain (state : LC_AgentState [Any ], thread_id : str ) -> AgentState :
17381752 messages = state ["messages" ]
1739- total_tokens_counter = _get_approximate_token_counter (model )
1740- total_tokens = total_tokens_counter (messages )
17411753 messages = [_map_message_from_langchain (m ) for m in state ["messages" ]]
17421754 return AgentState (
17431755 messages = messages ,
1744- total_steps = len (messages ),
1745- token_count = total_tokens ,
17461756 thread_id = thread_id ,
17471757 )
17481758
17491759
1750- def _get_approximate_token_counter (model : BaseChatModel ) -> LC_TokenCounter :
1760+ def _get_approximate_token_counter (
1761+ model : BaseChatModel , tools : list [BaseTool | dict [str , Any ]]
1762+ ) -> LC_TokenCounter :
17511763 """Tune parameters of approximate token counter based on model type."""
17521764
1765+ # TODO: consider using use_usage_metadata_scaling option once
1766+ # we expose token usage details from LLMs.
1767+
17531768 # NOTE: This is adapted from the backend provider library
17541769 # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
17551770 # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
17561771 if model ._llm_type == ANTHROPIC_CHAT_MODEL_TYPE : # pyright: ignore[reportPrivateUsage]
1757- return partial (count_tokens_approximately , chars_per_token = 3.3 )
1758- return count_tokens_approximately
1772+ return partial (count_tokens_approximately , tools = tools , chars_per_token = 3.3 )
1773+ return partial ( count_tokens_approximately , tools = tools )
17591774
17601775
17611776def _create_langchain_model (model : PredefinedModel ) -> BaseChatModel :
@@ -1964,25 +1979,6 @@ def check_tool_name(type: str, name: str) -> None:
19641979 raise _InvalidMessagesException ("last AIMessage has tool calls" )
19651980
19661981
1967- class _TokenLimitMiddleware (AgentMiddleware ):
1968- """Stops agent execution when the token count of messages passed to the model exceeds the given limit."""
1969-
1970- _limit : int
1971-
1972- def __init__ (self , limit : int ) -> None :
1973- self ._limit = limit
1974-
1975- @override
1976- async def model_middleware (
1977- self ,
1978- request : ModelRequest ,
1979- handler : ModelMiddlewareHandler ,
1980- ) -> ModelResponse :
1981- if request .state .token_count >= self ._limit :
1982- raise TokenLimitExceededException (token_limit = self ._limit )
1983- return await handler (request )
1984-
1985-
19861982class _StepLimitMiddleware (AgentMiddleware ):
19871983 """Stops agent execution when the number of steps taken reaches the given limit."""
19881984
@@ -1997,7 +1993,7 @@ async def model_middleware(
19971993 request : ModelRequest ,
19981994 handler : ModelMiddlewareHandler ,
19991995 ) -> ModelResponse :
2000- if request .state .total_steps >= self ._limit :
1996+ if len ( request .state .messages ) >= self ._limit :
20011997 raise StepsLimitExceededException (steps_limit = self ._limit )
20021998 return await handler (request )
20031999
0 commit comments