@@ -276,20 +276,14 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
276276 self ._agent_middleware .extend (agent .middleware or [])
277277 self ._agent_middleware .extend (after_user_middlewares )
278278
279- if agent .limits .max_tokens is not None :
280- self ._agent_middleware .append (
281- _TokenLimitMiddleware (agent .limits .max_tokens )
282- )
283279 if agent .limits .max_steps is not None :
284280 self ._agent_middleware .append (_StepLimitMiddleware (agent .limits .max_steps ))
285281 if agent .limits .timeout is not None :
286282 self ._agent_middleware .append (_TimeoutLimitMiddleware (agent .limits .timeout ))
287283
288284 model_impl = _create_langchain_model (agent .model )
289285
290- lc_middleware : list [LC_AgentMiddleware ] = [
291- _Middleware (self ._agent_middleware , model_impl )
292- ]
286+ lc_middleware : list [LC_AgentMiddleware ] = [_Middleware (self ._agent_middleware )]
293287
294288 # This middleware is executed just after the tool execution and populates
295289 # the artifact field for failed tool calls, since in such cases we can't
@@ -605,6 +599,27 @@ async def awrap_tool_call(
605599 if _DEBUG :
606600 lc_middleware .append (_DEBUGMiddleware ())
607601
602+ if agent .limits .max_tokens is not None :
603+ _max_tokens = agent .limits .max_tokens
604+
605+ class _TokenLimitMiddleware (LC_AgentMiddleware ):
606+ @override
607+ async def awrap_model_call (
608+ self ,
609+ request : LC_ModelRequest ,
610+ handler : Callable [[LC_ModelRequest ], Awaitable [LC_ModelCallResult ]],
611+ ) -> LC_ModelCallResult :
612+ token_count = _get_approximate_token_counter (
613+ request .model , request .tools
614+ )(request .state ["messages" ])
615+
616+ if token_count >= _max_tokens :
617+ raise TokenLimitExceededException (token_limit = _max_tokens )
618+
619+ return await handler (request )
620+
621+ lc_middleware .append (_TokenLimitMiddleware ())
622+
608623 response_format = None
609624 if agent .output_schema is not None :
610625 if _supports_provider_strategy (model_impl ):
@@ -792,11 +807,9 @@ def _prepare_langchain_tools(agent_tools: Sequence[Tool]) -> list[BaseTool]:
792807
793808class _Middleware (LC_AgentMiddleware ):
794809 _middleware : list [AgentMiddleware ]
795- _model : BaseChatModel
796810
797- def __init__ (self , middleware : list [AgentMiddleware ], model : BaseChatModel ) -> None :
811+ def __init__ (self , middleware : list [AgentMiddleware ]) -> None :
798812 self ._middleware = middleware
799- self ._model = model
800813
801814 def _with_model_middleware (
802815 self , model_invoke : ModelMiddlewareHandler
@@ -869,7 +882,7 @@ async def awrap_model_call(
869882 request .state ["messages" ].append (request .runtime .context .retry )
870883 request .runtime .context .retry = False
871884
872- req = _convert_model_request_from_lc (request , self . _model )
885+ req = _convert_model_request_from_lc (request )
873886 final_handler = _convert_model_handler_from_lc (
874887 handler , original_request = request
875888 )
@@ -967,7 +980,7 @@ async def awrap_tool_call(
967980 call = _map_tool_call_from_langchain (request .tool_call )
968981
969982 if isinstance (call , ToolCall ):
970- req = _convert_tool_request_from_lc (request , self . _model )
983+ req = _convert_tool_request_from_lc (request )
971984 final_handler = _convert_tool_handler_from_lc (
972985 handler , original_request = request
973986 )
@@ -995,7 +1008,7 @@ async def awrap_tool_call(
9951008 artifact = sdk_result ,
9961009 )
9971010
998- req = _convert_subagent_request_from_lc (request , self . _model )
1011+ req = _convert_subagent_request_from_lc (request )
9991012 final_handler = _convert_subagent_handler_from_lc (
10001013 handler , original_request = request
10011014 )
@@ -1076,9 +1089,7 @@ async def _sdk_handler(request: ModelRequest) -> ModelResponse:
10761089 return _sdk_handler
10771090
10781091
1079- def _convert_model_request_from_lc (
1080- request : LC_ModelRequest , model : BaseChatModel
1081- ) -> ModelRequest :
1092+ def _convert_model_request_from_lc (request : LC_ModelRequest ) -> ModelRequest :
10821093 thread_id = request .runtime .context .thread_id
10831094
10841095 system_message = (
@@ -1087,12 +1098,12 @@ def _convert_model_request_from_lc(
10871098
10881099 return ModelRequest (
10891100 system_message = system_message ,
1090- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1101+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
10911102 )
10921103
10931104
10941105def _convert_tool_request_from_lc (
1095- request : LC_ToolCallRequest , model : BaseChatModel
1106+ request : LC_ToolCallRequest ,
10961107) -> ToolRequest :
10971108 assert isinstance (request .runtime .context , InvokeContext )
10981109 thread_id = request .runtime .context .thread_id
@@ -1101,13 +1112,12 @@ def _convert_tool_request_from_lc(
11011112 assert isinstance (tool_call , ToolCall ), "Expected tool call"
11021113 return ToolRequest (
11031114 call = tool_call ,
1104- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1115+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
11051116 )
11061117
11071118
11081119def _convert_subagent_request_from_lc (
11091120 request : LC_ToolCallRequest ,
1110- model : BaseChatModel ,
11111121) -> SubagentRequest :
11121122 assert isinstance (request .runtime .context , InvokeContext )
11131123 thread_id = request .runtime .context .thread_id
@@ -1116,7 +1126,7 @@ def _convert_subagent_request_from_lc(
11161126 assert isinstance (subagent_call , SubagentCall ), "Expected subagent call"
11171127 return SubagentRequest (
11181128 call = subagent_call ,
1119- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1129+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
11201130 )
11211131
11221132
@@ -1809,29 +1819,30 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
18091819
18101820
18111821def _convert_agent_state_from_langchain (
1812- state : LC_AgentState [Any ], model : BaseChatModel , thread_id : str
1822+ state : LC_AgentState [Any ], thread_id : str
18131823) -> AgentState :
18141824 messages = state ["messages" ]
1815- total_tokens_counter = _get_approximate_token_counter (model )
1816- total_tokens = total_tokens_counter (messages )
18171825 messages = [_map_message_from_langchain (m ) for m in state ["messages" ]]
18181826 return AgentState (
18191827 messages = messages ,
1820- total_steps = len (messages ),
1821- token_count = total_tokens ,
18221828 thread_id = thread_id ,
18231829 )
18241830
18251831
1826- def _get_approximate_token_counter (model : BaseChatModel ) -> LC_TokenCounter :
1832+ def _get_approximate_token_counter (
1833+ model : BaseChatModel , tools : list [BaseTool | dict [str , Any ]]
1834+ ) -> LC_TokenCounter :
18271835 """Tune parameters of approximate token counter based on model type."""
18281836
1837+ # TODO: consider using use_usage_metadata_scaling option once
1838+ # we expose token usage details from LLMs.
1839+
18291840 # NOTE: This is adapted from the backend provider library
18301841 # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
18311842 # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
18321843 if model ._llm_type == ANTHROPIC_CHAT_MODEL_TYPE : # pyright: ignore[reportPrivateUsage]
1833- return partial (count_tokens_approximately , chars_per_token = 3.3 )
1834- return count_tokens_approximately
1844+ return partial (count_tokens_approximately , tools = tools , chars_per_token = 3.3 )
1845+ return partial ( count_tokens_approximately , tools = tools )
18351846
18361847
18371848def _create_langchain_model (model : PredefinedModel ) -> BaseChatModel :
@@ -2017,25 +2028,6 @@ def check_tool_name(type: str, name: str) -> None:
20172028 raise _InvalidMessagesException ("last AIMessage has tool calls" )
20182029
20192030
2020- class _TokenLimitMiddleware (AgentMiddleware ):
2021- """Stops agent execution when the token count of messages passed to the model exceeds the given limit."""
2022-
2023- _limit : int
2024-
2025- def __init__ (self , limit : int ) -> None :
2026- self ._limit = limit
2027-
2028- @override
2029- async def model_middleware (
2030- self ,
2031- request : ModelRequest ,
2032- handler : ModelMiddlewareHandler ,
2033- ) -> ModelResponse :
2034- if request .state .token_count >= self ._limit :
2035- raise TokenLimitExceededException (token_limit = self ._limit )
2036- return await handler (request )
2037-
2038-
20392031class _StepLimitMiddleware (AgentMiddleware ):
20402032 """Stops agent execution when the number of steps taken reaches the given limit."""
20412033
@@ -2050,7 +2042,7 @@ async def model_middleware(
20502042 request : ModelRequest ,
20512043 handler : ModelMiddlewareHandler ,
20522044 ) -> ModelResponse :
2053- if request .state .total_steps >= self ._limit :
2045+ if len ( request .state .messages ) >= self ._limit :
20542046 raise StepsLimitExceededException (steps_limit = self ._limit )
20552047 return await handler (request )
20562048
0 commit comments