11import logging
22from functools import wraps
33from logging import Logger
4- from typing import List , Optional , Union , Callable , Awaitable
4+ from typing import List , Optional , Union , Callable , Awaitable , Tuple
55
66from slack_bolt .context .save_thread_context .async_save_thread_context import AsyncSaveThreadContext
77from slack_bolt .context .assistant .thread_context_store .async_store import AsyncAssistantThreadContextStore
@@ -42,16 +42,25 @@ def __init__(
4242 app_name : str = "assistant" ,
4343 thread_context_store : Optional [AsyncAssistantThreadContextStore ] = None ,
4444 logger : Optional [logging .Logger ] = None ,
45+ auto_inherit_app_middleware : bool = False ,
4546 ):
4647 self .app_name = app_name
4748 self .thread_context_store = thread_context_store
4849 self .base_logger = logger
50+ self .auto_inherit_app_middleware = auto_inherit_app_middleware
51+ self ._inherited_app_middleware : List [AsyncMiddleware ] = []
4952
5053 self ._thread_started_listeners = None
5154 self ._thread_context_changed_listeners = None
5255 self ._user_message_listeners = None
5356 self ._bot_message_listeners = None
5457
58+ def inherit_app_middleware (self , middleware : AsyncMiddleware ) -> None :
59+ if self .auto_inherit_app_middleware is False :
60+ return
61+
62+ self ._inherited_app_middleware .append (middleware )
63+
5564 def thread_started (
5665 self ,
5766 * args ,
@@ -268,7 +277,11 @@ async def async_process( # type: ignore[return]
268277 if listeners is not None :
269278 for listener in listeners :
270279 if listener is not None and await listener .async_matches (req = req , resp = resp ):
271- middleware_resp , next_was_not_called = await listener .run_async_middleware (req = req , resp = resp )
280+ middleware_resp , next_was_not_called = await self ._run_middleware (
281+ listener = listener ,
282+ req = req ,
283+ resp = resp ,
284+ )
272285 if next_was_not_called :
273286 if middleware_resp is not None :
274287 return middleware_resp
@@ -289,6 +302,33 @@ async def async_process( # type: ignore[return]
289302
290303 await next ()
291304
305+ async def _run_middleware (
306+ self ,
307+ * ,
308+ listener : AsyncListener ,
309+ req : AsyncBoltRequest ,
310+ resp : BoltResponse ,
311+ ) -> Tuple [Optional [BoltResponse ], bool ]:
312+ middleware = list (listener .middleware )
313+ if len (self ._inherited_app_middleware ) > 0 :
314+ insertion_index = 1 if len (middleware ) > 0 and isinstance (middleware [0 ], AsyncAttachingConversationKwargs ) else 0
315+ middleware = [
316+ * middleware [:insertion_index ],
317+ * self ._inherited_app_middleware ,
318+ * middleware [insertion_index :],
319+ ]
320+
321+ for m in middleware :
322+ middleware_state = {"next_called" : False }
323+
324+ async def next_ ():
325+ middleware_state ["next_called" ] = True
326+
327+ resp = await m .async_process (req = req , resp = resp , next = next_ ) # type: ignore[assignment]
328+ if not middleware_state ["next_called" ]:
329+ return resp , True
330+ return resp , False
331+
292332 def build_listener (
293333 self ,
294334 listener_or_functions : Union [AsyncListener , Callable , List [Callable ]],
@@ -302,8 +342,10 @@ def build_listener(
302342 if isinstance (listener_or_functions , AsyncListener ):
303343 return listener_or_functions
304344 elif isinstance (listener_or_functions , list ):
305- middleware = middleware if middleware else []
306- middleware .insert (0 , AsyncAttachingConversationKwargs (self .thread_context_store ))
345+ middleware = [
346+ AsyncAttachingConversationKwargs (self .thread_context_store ),
347+ * (middleware if middleware else []),
348+ ]
307349 functions = listener_or_functions
308350 ack_function = functions .pop (0 )
309351
0 commit comments