diff --git a/README.md b/README.md index 5129f1b..5128ab1 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,8 @@ pip install eric-sse *Features* * Send to one listener and broadcast -* SSE format was adopted by design, making the library suitable for such kind of model * Callbacks and threading support +* Support to SSE and concurrency batch process implementation * Sockets server prefab for offline inter process communication *Possible applications* diff --git a/eric_sse/connection.py b/eric_sse/connection.py index 323cf34..d997cb2 100644 --- a/eric_sse/connection.py +++ b/eric_sse/connection.py @@ -2,8 +2,9 @@ import eric_sse from eric_sse.listener import MessageQueueListener +from eric_sse.message import MessageContract from eric_sse.queues import Queue, InMemoryQueue - +from eric_sse.handlers import QueuingErrorHandler class Connection: """ @@ -16,6 +17,8 @@ def __init__(self, listener: MessageQueueListener, queue: Queue, connection_id: self.__listener = listener self.__queue = queue self.__id = connection_id or eric_sse.generate_uuid() + self.__queues_error_handlers: list[QueuingErrorHandler] = [] + @property def listener(self) -> MessageQueueListener: @@ -29,6 +32,26 @@ def queue(self) -> Queue: def id(self) -> str: return self.__id + def send_message(self, msg: MessageContract): + try: + self.__queue.push(msg) + except Exception as e: + for handler in self.__queues_error_handlers: + handler.handle_push_error(msg=msg, exception=e) + raise + + + def fetch_message(self) -> MessageContract: + try: + return self.__queue.pop() + except Exception as e: + for handler in self.__queues_error_handlers: + handler.handle_pop_error(exception=e) + raise e + + def register_queuing_error_handler(self, handler: QueuingErrorHandler): + self.__queues_error_handlers.append(handler) + class ConnectionsFactory(ABC): @abstractmethod diff --git a/eric_sse/entities.py b/eric_sse/entities.py index 64018e5..67c7bd4 100644 --- a/eric_sse/entities.py +++ b/eric_sse/entities.py @@ -8,47 +8,39 @@ from eric_sse.listener import MessageQueueListener from eric_sse.connection import Connection, ConnectionsFactory, InMemoryConnectionsFactory from eric_sse.message import MessageContract, Message -from eric_sse.queues import Queue +from eric_sse.handlers import ListenerErrorHandler logger = eric_sse.get_logger() -MESSAGE_TYPE_CLOSED = '_eric_channel_closed' -MESSAGE_TYPE_END_OF_STREAM = '_eric_channel_eof' -MESSAGE_TYPE_INTERNAL_ERROR = '_eric_error' - - class _ConnectionManager: - """Maintains relationships between listeners and queues""" + """Maintains relationships between listeners and connections.""" def __init__(self, channel_id: str): self.__channel_id = channel_id self.__listeners: dict[str, MessageQueueListener] = {} - self.__queues: dict[str, Queue] = {} self.__connections: dict[str, Connection] = {} def register_connection(self, connection: Connection): self.__connections[connection.listener.id] = connection - self.__queues[connection.listener.id] = connection.queue self.__listeners[connection.listener.id] = connection.listener def remove_listener(self, listener_id: str): try: del self.__connections[listener_id] - del self.__queues[listener_id] del self.__listeners[listener_id] except KeyError: raise InvalidListenerException(listener_id) from None - def get_queue(self, listener_id: str) -> Queue: + def get_listener(self, listener_id: str) -> MessageQueueListener: try: - return self.__queues[listener_id] + return self.__listeners[listener_id] except KeyError: - raise InvalidListenerException(f"Invalid listener {listener_id}") from None + raise InvalidListenerException(listener_id) from None - def get_listener(self, listener_id: str) -> MessageQueueListener: + def get_connection(self, listener_id: str) -> Connection: try: - return self.__listeners[listener_id] + return self.__connections[listener_id] except KeyError: - raise InvalidListenerException + raise InvalidListenerException(listener_id) from None def get_listeners(self) -> dict[str, MessageQueueListener]: """Returns a dict mapping listener ids to listeners""" @@ -82,6 +74,9 @@ def __init__( self.__connection_manager: _ConnectionManager = _ConnectionManager(self.__id) self.__connections_factory = connections_factory if connections_factory else InMemoryConnectionsFactory() + self.__listeners_error_handlers: list[ListenerErrorHandler] = [] + + @property def id(self) -> str: @@ -100,10 +95,9 @@ async def message_stream(self, listener: MessageQueueListener) -> AsyncIterable[ A message with type = 'error' is yield on invalid listener """ - try: - self.__connection_manager.get_listener(listener.id) - except InvalidListenerException: - raise + + # check that listener was registered + _ = self.__connection_manager.get_listener(listener.id) async def new_messages(): try: @@ -149,12 +143,15 @@ def register_listener(self, listener: MessageQueueListener): def register_connection(self, connection: Connection): """ - Register and existing connection. + Register an existing connection. **Warning**: Listener and queue should belong to the same classes returned by connection factory to avoid compatibility issues with persistence layer """ self.__connection_manager.register_connection(connection) + def register_listener_error_handler(self, handler: ListenerErrorHandler): + self.__listeners_error_handlers.append(handler) + def remove_listener(self, listener_id: str): self.__connection_manager.remove_listener(listener_id) @@ -166,21 +163,29 @@ def deliver_next(self, listener_id: str) -> MessageContract: """ listener = self.get_listener(listener_id) if listener.is_running(): - queue = self.__connection_manager.get_queue(listener.id) - msg = queue.pop() - listener.on_message(msg) + msg = self._get_connection(listener.id).fetch_message() + try: + listener.on_message(msg) + except Exception as e: + for handler in self.__listeners_error_handlers: + handler.handle_on_message_error(msg=msg, exception=e) + raise return msg raise NoMessagesException - def _get_queue(self, listener_id: str) -> Queue: - return self.__connection_manager.get_queue(listener_id) + def _get_connection(self, listener_id: str) -> Connection: + return self.__connection_manager.get_connection(listener_id) def dispatch(self, listener_id: str, msg: MessageContract): """Adds a message to listener's queue""" - queue = self._get_queue(listener_id) - queue.push(msg) + try: + self._get_connection(listener_id).send_message(msg) + except Exception: + logger.exception("Failed to dispatch message to listener_id=%s", listener_id) + raise + logger.debug(f"Dispatched {msg} to {listener_id}") def broadcast(self, msg: MessageContract): @@ -194,8 +199,3 @@ def get_listener(self, listener_id: str) -> MessageQueueListener: def get_connections(self) -> Iterable[Connection]: return self.__connection_manager.get_connections() - async def watch(self) -> AsyncIterable[Any]: - listener = self.add_listener() - listener.start() - return self.message_stream(listener) - diff --git a/eric_sse/handlers.py b/eric_sse/handlers.py new file mode 100644 index 0000000..206e3e0 --- /dev/null +++ b/eric_sse/handlers.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from eric_sse.message import MessageContract + +from eric_sse import get_logger +logger = get_logger() + +class QueuingErrorHandler: + + def handle_push_error(self, msg: MessageContract, exception: Exception): + pass + def handle_pop_error(self, exception: Exception): + pass + +class ListenerErrorHandler(ABC): + @abstractmethod + def handle_on_message_error(self, msg: MessageContract, exception: Exception): + pass diff --git a/eric_sse/interfaces.py b/eric_sse/interfaces.py index 93fc855..6267c0a 100644 --- a/eric_sse/interfaces.py +++ b/eric_sse/interfaces.py @@ -42,6 +42,11 @@ def delete(self, connection_id: str): class ConnectionRepositoryInterface(ABC): + @property + @abstractmethod + def connections_factory(self) -> ConnectionsFactory: + pass + @property @abstractmethod def queues_repository(self) -> QueueRepositoryInterface: @@ -58,7 +63,7 @@ def load_all(self, channel_id: str) -> Iterable[Connection]: pass @abstractmethod - def load_one(self, channel_id: str, connection_id: str) -> Connection: + def load_one(self, connection_id: str) -> Connection: """Loads a connection given the connection and channel id it belongs to.""" pass @@ -68,18 +73,13 @@ def persist(self, channel_id: str, connection: Connection): pass @abstractmethod - def delete(self, channel_id: str, connection_id: str): - """Deletes a connection given the connection and channel id it belongs to.""" + def delete(self, connection_id: str): + """Deletes a connection given its id.""" pass class ChannelRepositoryInterface(ABC): - @property - @abstractmethod - def connections_factory(self) -> ConnectionsFactory: - """The connections factory that will be injected into concrete channel instances.""" - pass @property @abstractmethod diff --git a/eric_sse/patterns.py b/eric_sse/patterns.py new file mode 100644 index 0000000..d82e917 --- /dev/null +++ b/eric_sse/patterns.py @@ -0,0 +1,18 @@ +from eric_sse.handlers import QueuingErrorHandler +from eric_sse.message import MessageContract +from eric_sse.queues import Queue +from eric_sse import get_logger + +logger = get_logger() + +class DeadLetterQueueHandler(QueuingErrorHandler): + def __init__(self, queue: Queue): + self.__queue = queue + + def handle_push_error(self, msg: MessageContract, exception: Exception): + try: + self.__queue.push(msg) + except Exception as e: + logger.exception(f"Dead-letter push failed. msg type: {msg.type} payload {msg.payload} {repr(e)}") + + diff --git a/eric_sse/prefabs.py b/eric_sse/prefabs.py index 354ddd2..55ffad9 100644 --- a/eric_sse/prefabs.py +++ b/eric_sse/prefabs.py @@ -80,7 +80,7 @@ async def process_queue(self, listener: MessageQueueListener) -> AsyncIterable[d loop = asyncio.get_running_loop() while there_are_pending_messages: try: - msg = self._get_queue(listener_id=listener.id).pop() + msg = self._get_connection(listener_id=listener.id).fetch_message() tasks.append(loop.run_in_executor(e, DataProcessingChannel._invoke_callback_and_return, listener.on_message, msg)) except NoMessagesException: @@ -177,7 +177,7 @@ def create(self, channel_data: dict) -> SSEChannel: """ :param dict channel_data: Fill it with SSEChannel constructor arguments, except for connections_factory that wil be injected by repository """ - return SSEChannel(**channel_data, connections_factory=self.connections_factory) + return SSEChannel(**channel_data, connections_factory=self.connections_repository.connections_factory) @staticmethod def _channel_to_dict(channel: SSEChannel) -> dict: diff --git a/eric_sse/repository.py b/eric_sse/repository.py index ec26388..9ac5e26 100644 --- a/eric_sse/repository.py +++ b/eric_sse/repository.py @@ -3,7 +3,7 @@ from eric_sse.connection import Connection, ConnectionsFactory from eric_sse.entities import AbstractChannel -from eric_sse.exception import ItemNotFound +from eric_sse.exception import ItemNotFound, InvalidChannelException from eric_sse.interfaces import ChannelRepositoryInterface, ConnectionRepositoryInterface, ListenerRepositoryInterface, \ QueueRepositoryInterface @@ -28,7 +28,7 @@ def upsert(self, key: str, value: Any): @abstractmethod def fetch_one(self, key: str) -> Any: - """Return value correspondant to key""" + """Return value corresponding to key""" pass @abstractmethod @@ -74,16 +74,10 @@ class AbstractChannelRepository(ChannelRepositoryInterface, ABC): def __init__( self, storage: KvStorage, - connections_repository: ConnectionRepositoryInterface, - connections_factory: ConnectionsFactory + connections_repository: ConnectionRepositoryInterface ): self.__storage = storage self.__connections_repository = connections_repository - self.__connections_factory = connections_factory - - @property - def connections_factory(self) -> ConnectionsFactory: - return self.__connections_factory @property def connections_repository(self) -> ConnectionRepositoryInterface: @@ -127,7 +121,7 @@ def persist(self, channel: AbstractChannel): self.__connections_repository.persist(channel_id=channel.id, connection=connection) for connection_id_to_remove in persisted_connections_ids - current_connections_ids: - self.__connections_repository.delete(channel_id=channel.id, connection_id=connection_id_to_remove) + self.__connections_repository.delete(connection_id=connection_id_to_remove) def delete(self, channel_id: str): try: @@ -135,7 +129,7 @@ def delete(self, channel_id: str): except ItemNotFound: return for connection in self.__connections_repository.load_all(channel_id=channel.id): - self.__connections_repository.delete(channel_id=channel_id, connection_id=connection.id) + self.__connections_repository.delete(connection_id=connection.id) self.__storage.delete(channel_id) class ConnectionRepository(ConnectionRepositoryInterface): @@ -143,17 +137,22 @@ class ConnectionRepository(ConnectionRepositoryInterface): Concrete Connection Repository Relies on :class:`~eric_sse.repository.KvStorage` abstraction for final writes of connections data, and on - correspondant repositories for related objects ones. + corresponding repositories for related objects ones. """ def __init__( self, storage: KvStorage, listeners_repository: ListenerRepositoryInterface, - queues_repository: QueueRepositoryInterface + queues_repository: QueueRepositoryInterface, + connections_factory:ConnectionsFactory ): self.__storage = storage self.__listeners_repository = listeners_repository self.__queues_repository = queues_repository + self.__connections_factory = connections_factory + + CONNECTIONS_BY_CHANNEL_PREFIX: str = 'ch_cn' + CONNECTIONS_PREFIX: str = 'cn_ch' @property def queues_repository(self) -> QueueRepositoryInterface: @@ -163,27 +162,47 @@ def queues_repository(self) -> QueueRepositoryInterface: def listeners_repository(self) -> ListenerRepositoryInterface: return self.__listeners_repository + @property + def connections_factory(self) -> ConnectionsFactory: + return self.__connections_factory + def _load_connection(self, connection_id: str) -> Connection: + try: + _ = self.__storage.fetch_one(f'{self.CONNECTIONS_PREFIX}:{connection_id}') + except ItemNotFound as e: + raise e from None + listener = self.__listeners_repository.load(connection_id=connection_id) queue = self.__queues_repository.load(connection_id=connection_id) return Connection(listener=listener, queue=queue, connection_id=connection_id) def load_all(self, channel_id: str) -> Iterable[Connection]: - for connection_data in self.__storage.fetch_by_prefix(channel_id): - yield self._load_connection(connection_data['id']) + for connection_data in self.__storage.fetch_by_prefix(f'{self.CONNECTIONS_BY_CHANNEL_PREFIX}:{channel_id}:'): + yield self._load_connection(connection_id=connection_data['cn_id']) - def load_one(self, channel_id: str, connection_id: str) -> Connection: - return self._load_connection(self.__storage.fetch_one(f'{channel_id}:{connection_id}')['id']) + def load_one(self, connection_id: str) -> Connection: + return self._load_connection(connection_id=connection_id) def persist(self, channel_id: str, connection: Connection): + self.__listeners_repository.persist(connection_id=connection.id, listener=connection.listener) self.__queues_repository.persist(connection_id=connection.id, queue=connection.queue) - self.__storage.upsert(f'{channel_id}:{connection.id}', {'id': connection.id}) + self.__storage.upsert(key=f'{self.CONNECTIONS_PREFIX}:{connection.id}', value={'ch_id': channel_id, 'cn_id': connection.id}) + self.__storage.upsert(key=f'{self.CONNECTIONS_BY_CHANNEL_PREFIX}:{channel_id}:{connection.id}', value={'ch_id': channel_id, 'cn_id': connection.id}) + + + def delete(self, connection_id: str): + try: + connection_data = self.__storage.fetch_one(key=f'{self.CONNECTIONS_PREFIX}:{connection_id}') + except ItemNotFound: + return + + channel_id = connection_data['ch_id'] - def delete(self, channel_id: str, connection_id: str): self.__listeners_repository.delete(connection_id=connection_id) self.__queues_repository.delete(connection_id=connection_id) - self.__storage.delete(key=f'{channel_id}:{connection_id}') + self.__storage.delete(key=f'{self.CONNECTIONS_PREFIX}:{connection_id}') + self.__storage.delete(key=f'{self.CONNECTIONS_BY_CHANNEL_PREFIX}:{channel_id}:{connection_id}') diff --git a/examples/producer_consumer.py b/examples/producer_consumer.py index bd796c7..05151c0 100644 --- a/examples/producer_consumer.py +++ b/examples/producer_consumer.py @@ -7,7 +7,6 @@ from random import uniform from time import sleep from eric_sse import get_logger -from eric_sse.entities import MESSAGE_TYPE_CLOSED from eric_sse.listener import MessageQueueListener from eric_sse.message import SignedMessage from eric_sse.prefabs import DataProcessingChannel @@ -20,7 +19,7 @@ class Producer: async def produce_num(c: DataProcessingChannel, l: MessageQueueListener, num: int): for i in range(0, num): c.dispatch(l.id, SignedMessage(msg_type='counter', msg_payload=i, sender_id='producer')) - c.dispatch(l.id, SignedMessage(msg_type=MESSAGE_TYPE_CLOSED, sender_id='producer')) + c.dispatch(l.id, SignedMessage(msg_type='end_of_stream', sender_id='producer')) class Consumer(MessageQueueListener): @@ -48,3 +47,4 @@ async def main(): print(m) asyncio.run(main()) + diff --git a/pyproject.toml b/pyproject.toml index a7e9c1d..15acf1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "eric-sse" -description = "A lightweight message dispatcher based on SSE protocol data transfer objects format" +description = "A lightweight and extensible asyncronous message dispatcher" requires-python = ">=3.10" version = "2.1.1.2" authors = [ diff --git a/test/mock/connection.py b/test/mock/connection.py new file mode 100644 index 0000000..d8cf74c --- /dev/null +++ b/test/mock/connection.py @@ -0,0 +1,44 @@ +from eric_sse.connection import ConnectionsFactory, Connection +from eric_sse.handlers import QueuingErrorHandler +from eric_sse.listener import MessageQueueListener +from eric_sse.message import MessageContract, Message +from eric_sse.queues import Queue + + +class BrokenListener(MessageQueueListener): + + def on_message(self, msg: MessageContract) -> None: + raise Exception() + + +class BrokenQueue(Queue): + def __init__(self, broken_push: bool = True, broken_pop: bool = True) -> None: + self.broken_push = broken_push + self.broken_pop = broken_pop + + def pop(self) -> MessageContract: + if self.broken_pop: + raise Exception() + return Message(msg_type='test') + + def push(self, message: MessageContract) -> None: + if self.broken_push: + raise Exception() + + +class BrokenConnectionFactory(ConnectionsFactory): + + def __init__( + self, + q_handlers: list[QueuingErrorHandler], + queue: Queue | None = None, + ): + self.q_handlers = q_handlers + self.queue = queue or BrokenQueue() + + + def create(self, listener: MessageQueueListener | None = None) -> Connection: + connection = Connection(listener=listener, queue=self.queue) + for handler in self.q_handlers: + connection.register_queuing_error_handler(handler) + return connection diff --git a/test/test_entities.py b/test/test_entities.py index c7708cc..67b383f 100644 --- a/test/test_entities.py +++ b/test/test_entities.py @@ -142,10 +142,14 @@ def test_error_handling(self): async def test_error_handling_async(self): listener = MessageQueueListener() - + msgs = [] with pytest.raises(InvalidListenerException): + async for msg in self.sut.message_stream(listener): + msgs.append(msg) async for _ in self.sut.message_stream(listener): pass + self.assertEqual(1, len(msgs)) + self.assertEqual('error', msgs[0].type) async def test_stream(self): listener = MessageQueueListenerMock() diff --git a/test/test_error_handlers.py b/test/test_error_handlers.py new file mode 100644 index 0000000..b8ff8f2 --- /dev/null +++ b/test/test_error_handlers.py @@ -0,0 +1,77 @@ +from unittest import IsolatedAsyncioTestCase +from unittest.mock import MagicMock + +from eric_sse.handlers import ListenerErrorHandler, QueuingErrorHandler +from eric_sse.message import Message +from test.mock.channel import FakeChannel +from test.mock.connection import BrokenListener, BrokenQueue, BrokenConnectionFactory + + +class ErrorsHandlingTestCase(IsolatedAsyncioTestCase): + def setUp(self): + self.listeners_handler_mock = MagicMock(ListenerErrorHandler) + self.listeners_handler_mock2 = MagicMock(ListenerErrorHandler) + self.queues_handler_mock = MagicMock(QueuingErrorHandler) + self.queues_handler_mock2 = MagicMock(QueuingErrorHandler) + + + def test_queues_handler(self): + + # Set up broken push + channel = FakeChannel( + connections_factory=BrokenConnectionFactory( + q_handlers=[self.queues_handler_mock, self.queues_handler_mock2], + queue=self.queues_handler_mock + ) + ) + + my_listener = BrokenListener() + channel.register_listener(my_listener) + my_listener.start() + + # act + msg = Message(msg_type='test') + with self.assertRaises(Exception) as context: + channel.dispatch(listener_id=my_listener.id, msg=msg) + self.queues_handler_mock.handle_push_error.assert_called_once_with(msg=msg, exception=context.exception) + + + # Set up broken pop + channel = FakeChannel( + connections_factory=BrokenConnectionFactory( + q_handlers=[ + self.queues_handler_mock, + self.queues_handler_mock2, + ], + queue=BrokenQueue(broken_push=False), + ) + ) + my_listener = BrokenListener() + channel.register_listener(my_listener) + my_listener.start() + + with self.assertRaises(Exception) as context: + channel.dispatch(listener_id=my_listener.id, msg=Message(msg_type='test')) + channel.deliver_next(listener_id=my_listener.id) + self.queues_handler_mock.handle_pop_error.assert_called_once_with(exception=context.exception) + self.queues_handler_mock2.handle_pop_error.assert_called_once_with(exception=context.exception) + + + def test_listeners_handler(self): + channel = FakeChannel() + channel.register_listener_error_handler(self.listeners_handler_mock) + my_listener = BrokenListener() + my_listener.start() + channel.register_listener(my_listener) + msg = Message(msg_type='test') + failed: bool = False + try: + channel.dispatch(listener_id=my_listener.id, msg=msg) + channel.deliver_next(listener_id=my_listener.id) + except Exception as e: + failed = True + self.listeners_handler_mock.handle_on_message_error.assert_called_once_with(msg=msg, exception=e) + + if not failed: + self.fail('Exception was not raised') + diff --git a/test/test_patterns.py b/test/test_patterns.py new file mode 100644 index 0000000..d1eb2fb --- /dev/null +++ b/test/test_patterns.py @@ -0,0 +1,34 @@ +from logging import Logger +from unittest import TestCase +from unittest.mock import MagicMock + +import pytest + +import eric_sse.patterns +from eric_sse.connection import Connection +from eric_sse.listener import MessageQueueListener +from eric_sse.message import Message +from eric_sse.queues import Queue +from eric_sse.patterns import DeadLetterQueueHandler +from test.mock.channel import FakeChannel +from test.mock.connection import BrokenQueue + +class DeadLetterQueueTestCase(TestCase): + def setUp(self): + self.dead_letter_queue = MagicMock(Queue) + self.logger = MagicMock(Logger) + + def test_handle(self): + channel = FakeChannel() + sut = DeadLetterQueueHandler(self.dead_letter_queue) + listener = MessageQueueListener() + + connection = Connection(listener=listener, queue=BrokenQueue()) + connection.register_queuing_error_handler(sut) + channel.register_connection(connection) + + with pytest.raises(Exception) as e: + message = Message("test") + channel.broadcast(message) + + self.dead_letter_queue.push.assert_called_with(message) diff --git a/test/test_repository.py b/test/test_repository.py index 1e69a2a..90fce2d 100644 --- a/test/test_repository.py +++ b/test/test_repository.py @@ -1,13 +1,14 @@ from unittest import TestCase from unittest.mock import MagicMock -from eric_sse.connection import Connection + +from eric_sse.connection import Connection, ConnectionsFactory from eric_sse.listener import MessageQueueListener from eric_sse.queues import InMemoryQueue from eric_sse.repository import ConnectionRepository, KvStorage from eric_sse.interfaces import ListenerRepositoryInterface, QueueRepositoryInterface - from eric_sse.repository import InMemoryStorage -from eric_sse.exception import ItemNotFound +from eric_sse.exception import ItemNotFound, InvalidChannelException + from test.mock.channel import FakeChannelRepository, FakeConnectionsFactory, FakeChannel @@ -45,12 +46,14 @@ class ConnectionsRepositoryTestCase(TestCase): def setUp(self): self.listeners_repository = MagicMock(spec=ListenerRepositoryInterface) self.queues_repository = MagicMock(spec=QueueRepositoryInterface) + self.connections_factory = MagicMock(spec=ConnectionsFactory) self.storage = MagicMock(spec=KvStorage) self.sut = ConnectionRepository( storage=self.storage, listeners_repository=self.listeners_repository, - queues_repository=self.queues_repository + queues_repository=self.queues_repository, + connections_factory=self.connections_factory ) def test_persist_operations_are_delegated_to_composites(self): @@ -62,7 +65,7 @@ def test_persist_operations_are_delegated_to_composites(self): self.sut.persist(channel_id='fake_channel', connection=connection) self.listeners_repository.persist.assert_called_once_with(connection_id=connection.id, listener=connection.listener) self.queues_repository.persist.assert_called_once_with(connection_id=connection.id, queue=connection.queue) - self.storage.upsert.assert_called_once() + self.storage.upsert.assert_called() def test_deletions_are_delegated_to_composites(self): connection = Connection( @@ -71,16 +74,22 @@ def test_deletions_are_delegated_to_composites(self): ) self.sut.persist(channel_id='fake_channel', connection=connection) - self.sut.delete(channel_id='fake_channel', connection_id=connection.id) + self.sut.delete(connection_id=connection.id) self.listeners_repository.delete.assert_called_once_with(connection_id=connection.id) self.queues_repository.delete.assert_called_once_with(connection_id=connection.id) - self.storage.delete.assert_called_once() + self.storage.delete.assert_called() - def error_handling(self): + def test_error_handling(self): + self.sut = ConnectionRepository( + storage=InMemoryStorage(), + listeners_repository=self.listeners_repository, + queues_repository=self.queues_repository, + connections_factory=self.connections_factory + ) with self.assertRaises(ItemNotFound): - self.sut.load_one('nonexistent_channel', 'nonexistent_connection') + self.sut.load_one('nonexistent_channel') @@ -95,12 +104,12 @@ def create_sut(self): connections_repository = ConnectionRepository( listeners_repository=self.listeners_repository, queues_repository=self.queues_repository, + connections_factory=self.connections_factory, storage=InMemoryStorage(), ) return FakeChannelRepository( storage=InMemoryStorage(), - connections_repository=connections_repository, - connections_factory=self.connections_factory + connections_repository=connections_repository ) def test_persistence(self): @@ -133,6 +142,44 @@ def test_missing_connections_are_deleted_on_persist(self): channel = sut.load_one(channel_id=channel.id) self.assertEqual(0, len([c for c in channel.get_connections()])) +""" +MessageQueueListenerMock + + + +class FullPathTestCase(IsolatedAsyncioTestCase): + def setUp(self): + self.listeners_repository = MagicMock(spec=ListenerRepositoryInterface) + self.queues_repository = MagicMock(spec=QueueRepositoryInterface) + self.connections_factory = FakeConnectionsFactory() + + def create_sut(self): + connections_repository = ConnectionRepository( + listeners_repository=self.listeners_repository, + queues_repository=self.queues_repository, + connections_factory=self.connections_factory, + storage=InMemoryStorage(), + ) + return FakeChannelRepository( + storage=InMemoryStorage(), + connections_repository=connections_repository + ) + async def test_one(self): + sut = self.create_sut() + channel = FakeChannel() + + listener = channel.add_listener() + message = Message(msg_type='test') + + channel.dispatch(listener.id, message) + channel.broadcast(message) + + sut.persist(channel=channel) + channel_clone = sut.load_one(channel_id=channel.id) + listener_clone = channel_clone.get_listener(listener.id) + async for received_message in channel_clone.message_stream(listener_clone): + self.assertEqual(received_message.msg_type, 'test') +""" \ No newline at end of file diff --git a/test/test_sse_channel_repository.py b/test/test_sse_channel_repository.py index 38a17f6..6e26a16 100644 --- a/test/test_sse_channel_repository.py +++ b/test/test_sse_channel_repository.py @@ -58,9 +58,9 @@ def _create_sut(self): connections_repository=FakeConnectionRepository( storage=connections_storage, listeners_repository=FakeListenerRepository(listeners_storage), - queues_repository=FakeQueueRepository(queues_storage) - ), - connections_factory=InMemoryConnectionsFactory() + queues_repository=FakeQueueRepository(queues_storage), + connections_factory=InMemoryConnectionsFactory() + ) ) return sut diff --git a/update_docs.sh b/update_docs.sh index 301e058..84450ed 100755 --- a/update_docs.sh +++ b/update_docs.sh @@ -1,4 +1,12 @@ #!/bin/bash +# shellcheck disable=SC2164 + +CURRENT_VERSION="$(poetry version --short)" +rm -rf docs_archive/"$CURRENT_VERSION" +cp -rf docs_markdown/ docs_archive/"$CURRENT_VERSION" + +git add docs_archive/"$CURRENT_VERSION"/* + cd docs rm -rf build/html/ rm -rf build/markdown/