@@ -265,6 +265,7 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
265265 )
266266
267267 self ._agent_middleware = []
268+
268269 if agent .limits .max_structured_output_retires is not None :
269270 self ._agent_middleware .append (
270271 _StructuredOutputRetryLimitMiddleware (
@@ -276,20 +277,14 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
276277 self ._agent_middleware .extend (agent .middleware or [])
277278 self ._agent_middleware .extend (after_user_middlewares )
278279
279- if agent .limits .max_tokens is not None :
280- self ._agent_middleware .append (
281- _TokenLimitMiddleware (agent .limits .max_tokens )
282- )
283280 if agent .limits .max_steps is not None :
284281 self ._agent_middleware .append (_StepLimitMiddleware (agent .limits .max_steps ))
285282 if agent .limits .timeout is not None :
286283 self ._agent_middleware .append (_TimeoutLimitMiddleware (agent .limits .timeout ))
287284
288285 model_impl = _create_langchain_model (agent .model )
289286
290- lc_middleware : list [LC_AgentMiddleware ] = [
291- _Middleware (self ._agent_middleware , model_impl )
292- ]
287+ lc_middleware : list [LC_AgentMiddleware ] = [_Middleware (self ._agent_middleware )]
293288
294289 # This middleware is executed just after the tool execution and populates
295290 # the artifact field for failed tool calls, since in such cases we can't
@@ -605,6 +600,27 @@ async def awrap_tool_call(
605600 if _DEBUG :
606601 lc_middleware .append (_DEBUGMiddleware ())
607602
603+ if agent .limits .max_tokens is not None :
604+ _max_tokens = agent .limits .max_tokens
605+
606+ class _TokenLimitMiddleware (LC_AgentMiddleware ):
607+ @override
608+ async def awrap_model_call (
609+ self ,
610+ request : LC_ModelRequest ,
611+ handler : Callable [[LC_ModelRequest ], Awaitable [LC_ModelCallResult ]],
612+ ) -> LC_ModelCallResult :
613+ token_count = _get_approximate_token_counter (
614+ request .model , request .tools
615+ )(request .state ["messages" ])
616+
617+ if token_count >= _max_tokens :
618+ raise TokenLimitExceededException (token_limit = _max_tokens )
619+
620+ return await handler (request )
621+
622+ lc_middleware .append (_TokenLimitMiddleware ())
623+
608624 response_format = None
609625 if agent .output_schema is not None :
610626 if _supports_provider_strategy (model_impl ):
@@ -792,11 +808,9 @@ def _prepare_langchain_tools(agent_tools: Sequence[Tool]) -> list[BaseTool]:
792808
793809class _Middleware (LC_AgentMiddleware ):
794810 _middleware : list [AgentMiddleware ]
795- _model : BaseChatModel
796811
797- def __init__ (self , middleware : list [AgentMiddleware ], model : BaseChatModel ) -> None :
812+ def __init__ (self , middleware : list [AgentMiddleware ]) -> None :
798813 self ._middleware = middleware
799- self ._model = model
800814
801815 def _with_model_middleware (
802816 self , model_invoke : ModelMiddlewareHandler
@@ -869,7 +883,7 @@ async def awrap_model_call(
869883 request .state ["messages" ].append (request .runtime .context .retry )
870884 request .runtime .context .retry = False
871885
872- req = _convert_model_request_from_lc (request , self . _model )
886+ req = _convert_model_request_from_lc (request )
873887 final_handler = _convert_model_handler_from_lc (
874888 handler , original_request = request
875889 )
@@ -967,7 +981,7 @@ async def awrap_tool_call(
967981 call = _map_tool_call_from_langchain (request .tool_call )
968982
969983 if isinstance (call , ToolCall ):
970- req = _convert_tool_request_from_lc (request , self . _model )
984+ req = _convert_tool_request_from_lc (request )
971985 final_handler = _convert_tool_handler_from_lc (
972986 handler , original_request = request
973987 )
@@ -995,7 +1009,7 @@ async def awrap_tool_call(
9951009 artifact = sdk_result ,
9961010 )
9971011
998- req = _convert_subagent_request_from_lc (request , self . _model )
1012+ req = _convert_subagent_request_from_lc (request )
9991013 final_handler = _convert_subagent_handler_from_lc (
10001014 handler , original_request = request
10011015 )
@@ -1076,9 +1090,7 @@ async def _sdk_handler(request: ModelRequest) -> ModelResponse:
10761090 return _sdk_handler
10771091
10781092
1079- def _convert_model_request_from_lc (
1080- request : LC_ModelRequest , model : BaseChatModel
1081- ) -> ModelRequest :
1093+ def _convert_model_request_from_lc (request : LC_ModelRequest ) -> ModelRequest :
10821094 thread_id = request .runtime .context .thread_id
10831095
10841096 system_message = (
@@ -1087,12 +1099,12 @@ def _convert_model_request_from_lc(
10871099
10881100 return ModelRequest (
10891101 system_message = system_message ,
1090- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1102+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
10911103 )
10921104
10931105
10941106def _convert_tool_request_from_lc (
1095- request : LC_ToolCallRequest , model : BaseChatModel
1107+ request : LC_ToolCallRequest ,
10961108) -> ToolRequest :
10971109 assert isinstance (request .runtime .context , InvokeContext )
10981110 thread_id = request .runtime .context .thread_id
@@ -1101,13 +1113,12 @@ def _convert_tool_request_from_lc(
11011113 assert isinstance (tool_call , ToolCall ), "Expected tool call"
11021114 return ToolRequest (
11031115 call = tool_call ,
1104- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1116+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
11051117 )
11061118
11071119
11081120def _convert_subagent_request_from_lc (
11091121 request : LC_ToolCallRequest ,
1110- model : BaseChatModel ,
11111122) -> SubagentRequest :
11121123 assert isinstance (request .runtime .context , InvokeContext )
11131124 thread_id = request .runtime .context .thread_id
@@ -1116,7 +1127,7 @@ def _convert_subagent_request_from_lc(
11161127 assert isinstance (subagent_call , SubagentCall ), "Expected subagent call"
11171128 return SubagentRequest (
11181129 call = subagent_call ,
1119- state = _convert_agent_state_from_langchain (request .state , model , thread_id ),
1130+ state = _convert_agent_state_from_langchain (request .state , thread_id ),
11201131 )
11211132
11221133
@@ -1809,28 +1820,30 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
18091820
18101821
18111822def _convert_agent_state_from_langchain (
1812- state : LC_AgentState [Any ], model : BaseChatModel , thread_id : str
1823+ state : LC_AgentState [Any ], thread_id : str
18131824) -> AgentState :
18141825 messages = state ["messages" ]
1815- total_tokens_counter = _get_approximate_token_counter (model )
1816- total_tokens = total_tokens_counter (messages )
18171826 messages = [_map_message_from_langchain (m ) for m in state ["messages" ]]
18181827 return AgentState (
18191828 messages = messages ,
1820- token_count = total_tokens ,
18211829 thread_id = thread_id ,
18221830 )
18231831
18241832
1825- def _get_approximate_token_counter (model : BaseChatModel ) -> LC_TokenCounter :
1833+ def _get_approximate_token_counter (
1834+ model : BaseChatModel , tools : list [BaseTool | dict [str , Any ]]
1835+ ) -> LC_TokenCounter :
18261836 """Tune parameters of approximate token counter based on model type."""
18271837
1838+ # TODO: consider using use_usage_metadata_scaling option once
1839+ # we expose token usage details from LLMs.
1840+
18281841 # NOTE: This is adapted from the backend provider library
18291842 # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
18301843 # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
18311844 if model ._llm_type == ANTHROPIC_CHAT_MODEL_TYPE : # pyright: ignore[reportPrivateUsage]
1832- return partial (count_tokens_approximately , chars_per_token = 3.3 )
1833- return count_tokens_approximately
1845+ return partial (count_tokens_approximately , tools = tools , chars_per_token = 3.3 )
1846+ return partial ( count_tokens_approximately , tools = tools )
18341847
18351848
18361849def _create_langchain_model (model : PredefinedModel ) -> BaseChatModel :
@@ -2016,25 +2029,6 @@ def check_tool_name(type: str, name: str) -> None:
20162029 raise _InvalidMessagesException ("last AIMessage has tool calls" )
20172030
20182031
2019- class _TokenLimitMiddleware (AgentMiddleware ):
2020- """Stops agent execution when the token count of messages passed to the model exceeds the given limit."""
2021-
2022- _limit : int
2023-
2024- def __init__ (self , limit : int ) -> None :
2025- self ._limit = limit
2026-
2027- @override
2028- async def model_middleware (
2029- self ,
2030- request : ModelRequest ,
2031- handler : ModelMiddlewareHandler ,
2032- ) -> ModelResponse :
2033- if request .state .token_count >= self ._limit :
2034- raise TokenLimitExceededException (token_limit = self ._limit )
2035- return await handler (request )
2036-
2037-
20382032class _StepLimitMiddleware (AgentMiddleware ):
20392033 """Stops agent execution when the number of steps taken reaches the given limit."""
20402034
0 commit comments