Skip to content

Commit ee9fb84

Browse files
committed
Add assistant middleware inheritance option
1 parent e8222b5 commit ee9fb84

5 files changed

Lines changed: 339 additions & 16 deletions

File tree

slack_bolt/app/app.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -683,19 +683,26 @@ def middleware_func(logger, body, next):
683683
self._middleware_list.append(middleware)
684684
if isinstance(middleware, Assistant) and middleware.thread_context_store is not None:
685685
self._assistant_thread_context_store = middleware.thread_context_store
686+
elif not isinstance(middleware, Assistant):
687+
self._inherit_app_middleware_for_assistants(middleware)
686688
elif callable(middleware_or_callable):
687-
self._middleware_list.append(
688-
CustomMiddleware(
689-
app_name=self.name,
690-
func=middleware_or_callable,
691-
base_logger=self._base_logger,
692-
)
689+
middleware = CustomMiddleware(
690+
app_name=self.name,
691+
func=middleware_or_callable,
692+
base_logger=self._base_logger,
693693
)
694+
self._middleware_list.append(middleware)
695+
self._inherit_app_middleware_for_assistants(middleware)
694696
return middleware_or_callable
695697
else:
696698
raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})")
697699
return None
698700

701+
def _inherit_app_middleware_for_assistants(self, middleware: Middleware) -> None:
702+
for registered_middleware in self._middleware_list[:-1]:
703+
if isinstance(registered_middleware, Assistant):
704+
registered_middleware.inherit_app_middleware(middleware)
705+
699706
# -------------------------
700707
# AI Agents & Assistants
701708

slack_bolt/app/async_app.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -710,19 +710,26 @@ async def middleware_func(logger, body, next):
710710
self._async_middleware_list.append(middleware)
711711
if isinstance(middleware, AsyncAssistant) and middleware.thread_context_store is not None:
712712
self._assistant_thread_context_store = middleware.thread_context_store
713+
elif not isinstance(middleware, AsyncAssistant):
714+
self._inherit_app_middleware_for_assistants(middleware)
713715
elif callable(middleware_or_callable):
714-
self._async_middleware_list.append(
715-
AsyncCustomMiddleware(
716-
app_name=self.name,
717-
func=middleware_or_callable,
718-
base_logger=self._base_logger,
719-
)
716+
middleware = AsyncCustomMiddleware(
717+
app_name=self.name,
718+
func=middleware_or_callable,
719+
base_logger=self._base_logger,
720720
)
721+
self._async_middleware_list.append(middleware)
722+
self._inherit_app_middleware_for_assistants(middleware)
721723
return middleware_or_callable
722724
else:
723725
raise BoltError(f"Unexpected type for a middleware ({type(middleware_or_callable)})")
724726
return None
725727

728+
def _inherit_app_middleware_for_assistants(self, middleware: AsyncMiddleware) -> None:
729+
for registered_middleware in self._async_middleware_list[:-1]:
730+
if isinstance(registered_middleware, AsyncAssistant):
731+
registered_middleware.inherit_app_middleware(middleware)
732+
726733
def assistant(self, assistant: AsyncAssistant) -> Optional[Callable]:
727734
return self.middleware(assistant)
728735

slack_bolt/middleware/assistant/assistant.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,43 @@ def __init__(
4242
app_name: str = "assistant",
4343
thread_context_store: Optional[AssistantThreadContextStore] = None,
4444
logger: Optional[logging.Logger] = None,
45+
auto_inherit_app_middleware: bool = False,
4546
):
4647
self.app_name = app_name
4748
self.thread_context_store = thread_context_store
4849
self.base_logger = logger
50+
self.auto_inherit_app_middleware = auto_inherit_app_middleware
51+
self._inherited_app_middleware: List[Middleware] = []
4952

5053
self._thread_started_listeners = None
5154
self._thread_context_changed_listeners = None
5255
self._user_message_listeners = None
5356
self._bot_message_listeners = None
5457

58+
def inherit_app_middleware(self, middleware: Middleware) -> None:
59+
if self.auto_inherit_app_middleware is False:
60+
return
61+
62+
insertion_index = len(self._inherited_app_middleware) + 1
63+
self._inherited_app_middleware.append(middleware)
64+
for listener in self._listeners:
65+
listener_middleware = list(listener.middleware)
66+
listener_middleware.insert(insertion_index, middleware)
67+
listener.middleware = listener_middleware
68+
69+
@property
70+
def _listeners(self) -> List[Listener]:
71+
listeners: List[Listener] = []
72+
for listener_list in [
73+
self._thread_started_listeners,
74+
self._thread_context_changed_listeners,
75+
self._user_message_listeners,
76+
self._bot_message_listeners,
77+
]:
78+
if listener_list is not None:
79+
listeners.extend(listener_list)
80+
return listeners
81+
5582
def thread_started(
5683
self,
5784
*args,
@@ -271,8 +298,11 @@ def build_listener(
271298
if isinstance(listener_or_functions, Listener):
272299
return listener_or_functions
273300
elif isinstance(listener_or_functions, list):
274-
middleware = middleware if middleware else []
275-
middleware.insert(0, AttachingConversationKwargs(self.thread_context_store))
301+
middleware = [
302+
AttachingConversationKwargs(self.thread_context_store),
303+
*self._inherited_app_middleware,
304+
*(middleware if middleware else []),
305+
]
276306
functions = listener_or_functions
277307
ack_function = functions.pop(0)
278308

slack_bolt/middleware/assistant/async_assistant.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,43 @@ def __init__(
4242
app_name: str = "assistant",
4343
thread_context_store: Optional[AsyncAssistantThreadContextStore] = None,
4444
logger: Optional[logging.Logger] = None,
45+
auto_inherit_app_middleware: bool = False,
4546
):
4647
self.app_name = app_name
4748
self.thread_context_store = thread_context_store
4849
self.base_logger = logger
50+
self.auto_inherit_app_middleware = auto_inherit_app_middleware
51+
self._inherited_app_middleware: List[AsyncMiddleware] = []
4952

5053
self._thread_started_listeners = None
5154
self._thread_context_changed_listeners = None
5255
self._user_message_listeners = None
5356
self._bot_message_listeners = None
5457

58+
def inherit_app_middleware(self, middleware: AsyncMiddleware) -> None:
59+
if self.auto_inherit_app_middleware is False:
60+
return
61+
62+
insertion_index = len(self._inherited_app_middleware) + 1
63+
self._inherited_app_middleware.append(middleware)
64+
for listener in self._listeners:
65+
listener_middleware = list(listener.middleware)
66+
listener_middleware.insert(insertion_index, middleware)
67+
listener.middleware = listener_middleware
68+
69+
@property
70+
def _listeners(self) -> List[AsyncListener]:
71+
listeners: List[AsyncListener] = []
72+
for listener_list in [
73+
self._thread_started_listeners,
74+
self._thread_context_changed_listeners,
75+
self._user_message_listeners,
76+
self._bot_message_listeners,
77+
]:
78+
if listener_list is not None:
79+
listeners.extend(listener_list)
80+
return listeners
81+
5582
def thread_started(
5683
self,
5784
*args,
@@ -302,8 +329,11 @@ def build_listener(
302329
if isinstance(listener_or_functions, AsyncListener):
303330
return listener_or_functions
304331
elif isinstance(listener_or_functions, list):
305-
middleware = middleware if middleware else []
306-
middleware.insert(0, AsyncAttachingConversationKwargs(self.thread_context_store))
332+
middleware = [
333+
AsyncAttachingConversationKwargs(self.thread_context_store),
334+
*self._inherited_app_middleware,
335+
*(middleware if middleware else []),
336+
]
307337
functions = listener_or_functions
308338
ack_function = functions.pop(0)
309339

0 commit comments

Comments
 (0)