Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions matrix/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def load_extension(self, extension: Extension) -> None:
for event_type, handlers in extension._event_handlers.items():
self._event_handlers[event_type].extend(handlers)

for hook_name, handlers in extension._hook_handlers.items():
self._hook_handlers[hook_name].extend(handlers)

self._checks.extend(extension._checks)
self._error_handlers.update(extension._error_handlers)
self._command_error_handlers.update(extension._command_error_handlers)
Expand Down Expand Up @@ -122,12 +125,19 @@ def _auto_register_events(self) -> None:
for attr in dir(self):
if not attr.startswith("on_"):
continue

coro = getattr(self, attr, None)
if inspect.iscoroutinefunction(coro):
try:
if not inspect.iscoroutinefunction(coro):
continue

try:
if attr in self.LIFECYCLE_EVENTS:
self.hook(coro)

if attr in self.EVENT_MAP:
self.event(coro)
except ValueError: # ignore unknown name
continue
except ValueError:
continue

async def _on_event(self, room: MatrixRoom, event: Event) -> None:
# ignore bot events
Expand All @@ -139,11 +149,16 @@ async def _on_event(self, room: MatrixRoom, event: Event) -> None:
return

try:
await self._dispatch(room, event)
await self._dispatch_matrix_event(room, event)
except Exception as error:
await self.on_error(error)

async def _dispatch(self, room: MatrixRoom, event: Event) -> None:
async def _dispatch(self, event_name: str, *args, **kwargs) -> None:
"""Fire all listeners registered for a named lifecycle event."""
for handler in self._hook_handlers.get(event_name, []):
await handler(*args, **kwargs)

async def _dispatch_matrix_event(self, room: MatrixRoom, event: Event) -> None:
"""Internal type-based fan-out plus optional command handling."""
for event_type, funcs in self._event_handlers.items():
if isinstance(event, event_type):
Expand Down Expand Up @@ -268,7 +283,7 @@ async def run(self) -> None:

self.scheduler.start()

await self.on_ready()
await self._dispatch("on_ready")
await self.client.sync_forever(timeout=30_000)

def start(self) -> None:
Expand Down
69 changes: 68 additions & 1 deletion matrix/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ class Registry:
"on_member_change": RoomMemberEvent,
}

LIFECYCLE_EVENTS: set[str] = {
"on_ready",
"on_error",
"on_command",
"on_command_error",
"on_command_invoke",
"on_load",
"on_unload",
}

def __init__(self, name: str, prefix: Optional[str] = None):
self.name = name
self.prefix = prefix
Expand All @@ -57,6 +67,7 @@ def __init__(self, name: str, prefix: Optional[str] = None):
self._scheduler: Scheduler = Scheduler()

self._event_handlers: Dict[Type[Event], List[Callback]] = defaultdict(list)
self._hook_handlers: Dict[str, List[Callback]] = defaultdict(list)
self._on_error: Optional[ErrorCallback] = None
self._error_handlers: Dict[type[Exception], ErrorCallback] = {}
self._command_error_handlers: Dict[type[Exception], CommandErrorCallback] = {}
Expand Down Expand Up @@ -217,7 +228,7 @@ def wrapper(f: Callback) -> Callback:
event_type = event_spec
else:
event_type = self.EVENT_MAP.get(f.__name__)
if event_type is None:
if 9 is None:
raise ValueError(f"Unknown event name: {f.__name__}")

return self.register_event(event_type, f)
Expand All @@ -238,6 +249,62 @@ def register_event(self, event_type: Type[Event], callback: Callback) -> Callbac
)
return callback

def hook(self, func: Optional[Callback], *, event_name: Optional[str] = None) -> Union[Callback, Callable[[Callback], Callback]]:
"""Decorator to register a coroutine as a lifecycle event hook.

Lifecycle events include things like ``on_ready``, ``on_command``,
and ``on_error``. If the event name is not provided, it is inferred
from the function name. Multiple handlers for the same lifecycle
event are supported and called in registration order.

## Example

```python
@bot.hook
async def on_ready():
print("Bot is ready!")

@bot.hook(event_name="on_command")
async def log_command(ctx):
print(f"Command invoked: {ctx.command}")
```
"""

def wrapper(f: Callback) -> Callback:
if not inspect.iscoroutinefunction(f):
raise TypeError("Lifecycle hooks must be coroutines")

name = event_name or f.__name__
if name not in self.LIFECYCLE_EVENTS:
raise ValueError(f"Unknown lifecycle event: {name}")

return self.register_hook(name, f)

if func is None:
return wrapper
return wrapper(func)

def register_hook(self, event_name: str, callback: Callback) -> Callback:
"""Register a lifecycle event hook directly for a given event name.

Prefer the :meth:`hook` decorator for typical use. This method
is useful when loading lifecycle hooks from an extension.
"""
if not inspect.iscoroutinefunction(callback):
raise TypeError("Lifecycle hooks must be coroutines")

if event_name not in self.LIFECYCLE_EVENTS:
raise ValueError(f"Unknown lifecycle event: {event_name}")

self._hook_handlers[event_name].append(callback)
logger.debug(
"registered lifecycle hook '%s' for event '%s' on %s",
callback.__name__,
event_name,
type(self).__name__,
)
return callback

def check(self, func: Callback) -> Callback:
"""Register a global check that must pass before any command is invoked.

Expand Down
Loading