Skip to content

Commit 18fae4a

Browse files
committed
Add assistant middleware inheritance option
1 parent e8222b5 commit 18fae4a

5 files changed

Lines changed: 367 additions & 20 deletions

File tree

slack_bolt/app/app.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -683,19 +683,26 @@ def middleware_func(logger, body, next):
683683
self._middleware_list.append(middleware)
684684
if isinstance(middleware, Assistant) and middleware.thread_context_store is not None:
685685
self._assistant_thread_context_store = middleware.thread_context_store
686+
elif not isinstance(middleware, Assistant):
687+
self._inherit_app_middleware_for_assistants(middleware)
686688
elif callable(middleware_or_callable):
687-
self._middleware_list.append(
688-
CustomMiddleware(
689-
app_name=self.name,
690-
func=middleware_or_callable,
691-
base_logger=self._base_logger,
692-
)
689+
middleware = CustomMiddleware(
690+
app_name=self.name,
691+
func=middleware_or_callable,
692+
base_logger=self._base_logger,
693693
)
694+
self._middleware_list.append(middleware)
695+
self._inherit_app_middleware_for_assistants(middleware)
694696
return middleware_or_callable
695697
else:
696698
raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})")
697699
return None
698700

701+
def _inherit_app_middleware_for_assistants(self, middleware: Middleware) -> None:
702+
for registered_middleware in self._middleware_list[:-1]:
703+
if isinstance(registered_middleware, Assistant):
704+
registered_middleware.inherit_app_middleware(middleware)
705+
699706
# -------------------------
700707
# AI Agents & Assistants
701708

slack_bolt/app/async_app.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -710,19 +710,26 @@ async def middleware_func(logger, body, next):
710710
self._async_middleware_list.append(middleware)
711711
if isinstance(middleware, AsyncAssistant) and middleware.thread_context_store is not None:
712712
self._assistant_thread_context_store = middleware.thread_context_store
713+
elif not isinstance(middleware, AsyncAssistant):
714+
self._inherit_app_middleware_for_assistants(middleware)
713715
elif callable(middleware_or_callable):
714-
self._async_middleware_list.append(
715-
AsyncCustomMiddleware(
716-
app_name=self.name,
717-
func=middleware_or_callable,
718-
base_logger=self._base_logger,
719-
)
716+
middleware = AsyncCustomMiddleware(
717+
app_name=self.name,
718+
func=middleware_or_callable,
719+
base_logger=self._base_logger,
720720
)
721+
self._async_middleware_list.append(middleware)
722+
self._inherit_app_middleware_for_assistants(middleware)
721723
return middleware_or_callable
722724
else:
723725
raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})")
724726
return None
725727

728+
def _inherit_app_middleware_for_assistants(self, middleware: AsyncMiddleware) -> None:
729+
for registered_middleware in self._async_middleware_list[:-1]:
730+
if isinstance(registered_middleware, AsyncAssistant):
731+
registered_middleware.inherit_app_middleware(middleware)
732+
726733
def assistant(self, assistant: AsyncAssistant) -> Optional[Callable]:
727734
return self.middleware(assistant)
728735

slack_bolt/middleware/assistant/assistant.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from functools import wraps
33
from logging import Logger
4-
from typing import List, Optional, Union, Callable
4+
from typing import List, Optional, Union, Callable, Tuple
55

66
from slack_bolt.context.save_thread_context import SaveThreadContext
77
from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore
@@ -42,16 +42,25 @@ def __init__(
4242
app_name: str = "assistant",
4343
thread_context_store: Optional[AssistantThreadContextStore] = None,
4444
logger: Optional[logging.Logger] = None,
45+
auto_inherit_app_middleware: bool = False,
4546
):
4647
self.app_name = app_name
4748
self.thread_context_store = thread_context_store
4849
self.base_logger = logger
50+
self.auto_inherit_app_middleware = auto_inherit_app_middleware
51+
self._inherited_app_middleware: List[Middleware] = []
4952

5053
self._thread_started_listeners = None
5154
self._thread_context_changed_listeners = None
5255
self._user_message_listeners = None
5356
self._bot_message_listeners = None
5457

58+
def inherit_app_middleware(self, middleware: Middleware) -> None:
59+
if self.auto_inherit_app_middleware is False:
60+
return
61+
62+
self._inherited_app_middleware.append(middleware)
63+
5564
def thread_started(
5665
self,
5766
*args,
@@ -237,7 +246,11 @@ def process( # type: ignore[return]
237246
if listeners is not None:
238247
for listener in listeners:
239248
if listener.matches(req=req, resp=resp):
240-
middleware_resp, next_was_not_called = listener.run_middleware(req=req, resp=resp)
249+
middleware_resp, next_was_not_called = self._run_middleware(
250+
listener=listener,
251+
req=req,
252+
resp=resp,
253+
)
241254
if next_was_not_called:
242255
if middleware_resp is not None:
243256
return middleware_resp
@@ -258,6 +271,33 @@ def process( # type: ignore[return]
258271

259272
next()
260273

274+
def _run_middleware(
275+
self,
276+
*,
277+
listener: Listener,
278+
req: BoltRequest,
279+
resp: BoltResponse,
280+
) -> Tuple[Optional[BoltResponse], bool]:
281+
middleware = list(listener.middleware)
282+
if len(self._inherited_app_middleware) > 0:
283+
insertion_index = 1 if len(middleware) > 0 and isinstance(middleware[0], AttachingConversationKwargs) else 0
284+
middleware = [
285+
*middleware[:insertion_index],
286+
*self._inherited_app_middleware,
287+
*middleware[insertion_index:],
288+
]
289+
290+
for m in middleware:
291+
middleware_state = {"next_called": False}
292+
293+
def next_():
294+
middleware_state["next_called"] = True
295+
296+
resp = m.process(req=req, resp=resp, next=next_) # type: ignore[assignment]
297+
if not middleware_state["next_called"]:
298+
return resp, True
299+
return resp, False
300+
261301
def build_listener(
262302
self,
263303
listener_or_functions: Union[Listener, Callable, List[Callable]],
@@ -271,8 +311,10 @@ def build_listener(
271311
if isinstance(listener_or_functions, Listener):
272312
return listener_or_functions
273313
elif isinstance(listener_or_functions, list):
274-
middleware = middleware if middleware else []
275-
middleware.insert(0, AttachingConversationKwargs(self.thread_context_store))
314+
middleware = [
315+
AttachingConversationKwargs(self.thread_context_store),
316+
*(middleware if middleware else []),
317+
]
276318
functions = listener_or_functions
277319
ack_function = functions.pop(0)
278320

slack_bolt/middleware/assistant/async_assistant.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from functools import wraps
33
from logging import Logger
4-
from typing import List, Optional, Union, Callable, Awaitable
4+
from typing import List, Optional, Union, Callable, Awaitable, Tuple
55

66
from slack_bolt.context.save_thread_context.async_save_thread_context import AsyncSaveThreadContext
77
from slack_bolt.context.assistant.thread_context_store.async_store import AsyncAssistantThreadContextStore
@@ -42,16 +42,25 @@ def __init__(
4242
app_name: str = "assistant",
4343
thread_context_store: Optional[AsyncAssistantThreadContextStore] = None,
4444
logger: Optional[logging.Logger] = None,
45+
auto_inherit_app_middleware: bool = False,
4546
):
4647
self.app_name = app_name
4748
self.thread_context_store = thread_context_store
4849
self.base_logger = logger
50+
self.auto_inherit_app_middleware = auto_inherit_app_middleware
51+
self._inherited_app_middleware: List[AsyncMiddleware] = []
4952

5053
self._thread_started_listeners = None
5154
self._thread_context_changed_listeners = None
5255
self._user_message_listeners = None
5356
self._bot_message_listeners = None
5457

58+
def inherit_app_middleware(self, middleware: AsyncMiddleware) -> None:
59+
if self.auto_inherit_app_middleware is False:
60+
return
61+
62+
self._inherited_app_middleware.append(middleware)
63+
5564
def thread_started(
5665
self,
5766
*args,
@@ -268,7 +277,11 @@ async def async_process( # type: ignore[return]
268277
if listeners is not None:
269278
for listener in listeners:
270279
if listener is not None and await listener.async_matches(req=req, resp=resp):
271-
middleware_resp, next_was_not_called = await listener.run_async_middleware(req=req, resp=resp)
280+
middleware_resp, next_was_not_called = await self._run_middleware(
281+
listener=listener,
282+
req=req,
283+
resp=resp,
284+
)
272285
if next_was_not_called:
273286
if middleware_resp is not None:
274287
return middleware_resp
@@ -289,6 +302,33 @@ async def async_process( # type: ignore[return]
289302

290303
await next()
291304

305+
async def _run_middleware(
306+
self,
307+
*,
308+
listener: AsyncListener,
309+
req: AsyncBoltRequest,
310+
resp: BoltResponse,
311+
) -> Tuple[Optional[BoltResponse], bool]:
312+
middleware = list(listener.middleware)
313+
if len(self._inherited_app_middleware) > 0:
314+
insertion_index = 1 if len(middleware) > 0 and isinstance(middleware[0], AsyncAttachingConversationKwargs) else 0
315+
middleware = [
316+
*middleware[:insertion_index],
317+
*self._inherited_app_middleware,
318+
*middleware[insertion_index:],
319+
]
320+
321+
for m in middleware:
322+
middleware_state = {"next_called": False}
323+
324+
async def next_():
325+
middleware_state["next_called"] = True
326+
327+
resp = await m.async_process(req=req, resp=resp, next=next_) # type: ignore[assignment]
328+
if not middleware_state["next_called"]:
329+
return resp, True
330+
return resp, False
331+
292332
def build_listener(
293333
self,
294334
listener_or_functions: Union[AsyncListener, Callable, List[Callable]],
@@ -302,8 +342,10 @@ def build_listener(
302342
if isinstance(listener_or_functions, AsyncListener):
303343
return listener_or_functions
304344
elif isinstance(listener_or_functions, list):
305-
middleware = middleware if middleware else []
306-
middleware.insert(0, AsyncAttachingConversationKwargs(self.thread_context_store))
345+
middleware = [
346+
AsyncAttachingConversationKwargs(self.thread_context_store),
347+
*(middleware if middleware else []),
348+
]
307349
functions = listener_or_functions
308350
ack_function = functions.pop(0)
309351

0 commit comments

Comments
 (0)