diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..b9127bc --- /dev/null +++ b/.editorconfig @@ -0,0 +1,18 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true + +[*.py] +charset = utf-8 +indent_style = tab +tab_width = 4 +trim_trailing_whitespace = true +max_line_length = 110 + +[*.{yml,yaml}] +indent_style = space +charset = utf-8 +indent_size = 2 +trim_trailing_whitespace = true diff --git a/.github/actions/setup-environment/action.yml b/.github/actions/setup-environment/action.yml new file mode 100644 index 0000000..d04462e --- /dev/null +++ b/.github/actions/setup-environment/action.yml @@ -0,0 +1,11 @@ +name: 'Setup CI Environment' +description: 'Setup the environment in which CI scripts for the NVDA Remote Access server can be run' + +runs: + using: "composite" + steps: + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v6.1.0 + - name: Setup environment + shell: bash + run: uv sync --dev diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 61b7086..0c2f0d7 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -6,14 +6,10 @@ jobs: coverage: name: Check coverage with coverage.py runs-on: ubuntu-latest - permissions: - pull-requests: write + permissions: {} steps: - uses: actions/checkout@v4.2.2 - - name: Install the latest version of uv - uses: astral-sh/setup-uv@v6.1.0 - - name: Setup environment - run: uv sync --dev + - uses: ./.github/actions/setup-environment - name: Run unit tests run: uv run coverage run - name: Report coverage diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 0000000..2480733 --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,21 @@ +name: Check types with Pyright + +on: + push: + branches: + - main + + pull_request: + branches: + - main + +jobs: + pyright: + name: Check types with pyright + runs-on: ubuntu-latest + permissions: {} + steps: + - uses: actions/checkout@v4.2.2 + - uses: ./.github/actions/setup-environment + - name: Run pyright + run: uv run pyright >> $GITHUB_STEP_SUMMARY diff --git a/pyproject.toml b/pyproject.toml index 2aad0a5..72883e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ reportOptionalMemberAccess = false # The following option causes problems due to dynamic member access reportUnknownArgumentType = false + [tool.ruff] line-length = 110 diff --git a/server.py b/server.py index 3159508..9c6e8bc 100644 --- a/server.py +++ b/server.py @@ -1,4 +1,3 @@ -import io import json import os import random @@ -11,55 +10,104 @@ from OpenSSL import crypto from twisted.internet import reactor, ssl from twisted.internet.interfaces import ITCPTransport -from twisted.internet.protocol import Factory, defer +from twisted.internet.protocol import Factory, defer, connectionDone from twisted.internet.task import LoopingCall from twisted.protocols.basic import LineReceiver from twisted.python import log, usage +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure +from typing import Any, TypedDict, cast logger = getLogger("remote-server") -PING_INTERVAL = 300 -INITIAL_TIMEOUT = 30 +PING_INTERVAL: int = 300 +INITIAL_TIMEOUT: int = 30 # Expiration time for generated keys, in seconds -GENERATED_KEY_EXPIRATION_TIME = 60 * 60 * 24 # One day +GENERATED_KEY_EXPIRATION_TIME: int = 60 * 60 * 24 # One day -class Channel(object): - def __init__(self, key, server_state=None): - self.clients = OrderedDict() +class UserDict(TypedDict): + """Typed dictionary representing a user. + + Keys in this dictionary cannot be renamed, as clients rely on them. + """ + + id: int + connection_type: str | None + + +class Message(TypedDict): + """Type hints for protocol messages. + + Keys in this dictionary cannot be renamed, as clients rely on them. + """ + + type: str + + +class Channel: + """Collection of connected users in the one "session".""" + + def __init__(self, key: str, serverState: "ServerState | None" = None) -> None: + """Constructor + + :param key: Unique identifier of this channel. + :param serverState: Server state, defaults to None + """ + self.clients: OrderedDict[int, User] = OrderedDict() self.key = key - self.server_state = server_state + self.serverState = serverState + + def addClient(self, client: "User") -> None: + """Joined when a new user wants to join the channel. - def add_client(self, client): - if client.protocol.protocol_version == 1: # pragma: no cover - protocol v1 is not tested - ids = [c.user_id for c in self.clients.values()] - msg = dict(type="channel_joined", channel=self.key, user_ids=ids, origin=client.user_id) + :param client: The new channel member. + """ + if client.protocol.protocolVersion == 1: # pragma: no cover - protocol v1 is not tested + ids = [c.userId for c in self.clients.values()] + msg = dict(type="channel_joined", channel=self.key, user_ids=ids, origin=client.userId) else: - clients = [i.as_dict() for i in self.clients.values()] - msg = dict(type="channel_joined", channel=self.key, origin=client.user_id, clients=clients) + clients = [i.asDict() for i in self.clients.values()] + msg = dict(type="channel_joined", channel=self.key, origin=client.userId, clients=clients) client.send(**msg) - for existing_client in self.clients.values(): - if existing_client.protocol.protocol_version == 1: # pragma: no cover - protocol v1 is not tested - existing_client.send(type="client_joined", user_id=client.user_id) + for existingClient in self.clients.values(): + if existingClient.protocol.protocolVersion == 1: # pragma: no cover - protocol v1 is not tested + existingClient.send(type="client_joined", user_id=client.userId) else: - existing_client.send(type="client_joined", client=client.as_dict()) - self.clients[client.user_id] = client + existingClient.send(type="client_joined", client=client.asDict()) + self.clients[client.userId] = client - def remove_connection(self, con): - if con.user_id in self.clients: - del self.clients[con.user_id] + def removeConnection(self, con: "User") -> None: + """Called when a user leaves the channel. + + :param con: The leaving channel member. + """ + if con.userId in self.clients: + del self.clients[con.userId] for client in self.clients.values(): - if client.protocol.protocol_version == 1: # pragma: no cover - protocol v1 is not tested - client.send(type="client_left", user_id=con.user_id) + if client.protocol.protocolVersion == 1: # pragma: no cover - protocol v1 is not tested + client.send(type="client_left", user_id=con.userId) else: - client.send(type="client_left", client=con.as_dict()) + client.send(type="client_left", client=con.asDict()) if not self.clients: - self.server_state.remove_channel(self.key) - - def ping_clients(self): - self.send_to_clients({"type": "ping"}) - - def send_to_clients(self, obj, exclude=None, origin=None): + self.serverState.removeChannel(self.key) + + def pingClients(self) -> None: + """Ping clients to ensure they're still connected.""" + self.sendToClients({"type": "ping"}) + + def sendToClients( + self, + obj: dict[str, Any], + exclude: "User | None" = None, + origin: int | None = None, + ) -> None: + """Broadcast a message to all users in this channel. + + :param obj: Message to send. + :param exclude: User to exclude from the broadcast, defaults to None + :param origin: Originating user, defaults to None + """ for client in self.clients.values(): if client is exclude: continue @@ -67,168 +115,237 @@ def send_to_clients(self, obj, exclude=None, origin=None): class Handler(LineReceiver): + """Handle sending and receiving messages.""" + delimiter = b"\n" - connection_id = 0 + connectionId = 0 MAX_LENGTH = 20 * 1048576 - def __init__(self): - self.connection_id = Handler.connection_id + 1 - Handler.connection_id += 1 - self.protocol_version = 1 + def __init__(self) -> None: + self.connectionId = Handler.connectionId + 1 + Handler.connectionId += 1 + self.protocolVersion = 1 - def connectionMade(self): - logger.info("Connection %d from %s" % (self.connection_id, self.transport.getPeer())) + def connectionMade(self) -> None: + """Called when a user first connects.""" + logger.info("Connection %d from %s", self.connectionId, self.transport.getPeer()) # We use a non-tcp transport for unit testing, # which doesn't support setTcpNoDelay. if isinstance(self.transport, ITCPTransport): # pragma: no cover - self.transport.setTcpNoDelay(True) - self.bytes_sent = 0 - self.bytes_received = 0 + # Methods of Zope interfaces don't take self, so pyright thinks this call has too many arguments + self.transport.setTcpNoDelay(True) # pyright: ignore [reportCallIssue] + self.bytesSent = 0 + self.bytesReceived = 0 self.user = User(protocol=self) - self.cleanup_timer = reactor.callLater(INITIAL_TIMEOUT, self.cleanup) - self.user.send_motd() + self.cleanupTimer = reactor.callLater(INITIAL_TIMEOUT, self.cleanup) + self.user.sendMotd() - def connectionLost(self, reason): + def connectionLost(self, reason: Failure = connectionDone) -> None: + """Called when the connection is dropped.""" logger.info( - "Connection %d lost, bytes sent: %d received: %d" - % (self.connection_id, self.bytes_sent, self.bytes_received), + "Connection %d lost, bytes sent: %d received: %d", + self.connectionId, + self.bytesSent, + self.bytesReceived, ) - self.user.connection_lost() + self.user.connectionLost() if ( - self.cleanup_timer is not None and not self.cleanup_timer.cancelled + self.cleanupTimer is not None and not self.cleanupTimer.cancelled ): # pragma: no cover - not sure how to trigger this - self.cleanup_timer.cancel() + self.cleanupTimer.cancel() - def lineReceived(self, line): - self.bytes_received += len(line) + def lineReceived(self, line: bytes) -> None: + """Called when a new line (a command) has been received. + + :param line: The incoming line. + """ + self.bytesReceived += len(line) try: parsed = json.loads(line) if not isinstance(parsed, dict): raise ValueError except ValueError: - logger.warn("Unable to parse %r" % line) + logger.warning("Unable to parse %r", line) self.transport.loseConnection() return + cast(dict[str, Any], parsed) if "type" not in parsed: - logger.warning("Invalid object received: %r" % parsed) + logger.warning("Invalid object received: %r", parsed) return parsed.pop("origin", None) # Remove an existing origin, we know where the message comes from. if self.user.channel is not None: - self.user.channel.send_to_clients(parsed, exclude=self.user, origin=self.user.user_id) + self.user.channel.sendToClients(parsed, exclude=self.user, origin=self.user.userId) return elif not hasattr(self, "do_" + parsed["type"]): - logger.warning("No function for type %s" % parsed["type"]) + logger.warning("No function for type %s", parsed["type"]) return getattr(self, "do_" + parsed["type"])(parsed) - def do_join(self, obj): - if "channel" not in obj or not obj["channel"]: + def do_join(self, obj: dict[str, str]) -> None: + """Called when receiving a "join" message.""" + if ( + "channel" not in obj + or not obj["channel"] + or "connection_type" not in obj + or not obj["connection_type"] + ): self.send(type="error", error="invalid_parameters") return - self.user.join(obj["channel"], connection_type=obj.get("connection_type")) - self.cleanup_timer.cancel() + self.user.join(obj["channel"], connectionType=obj["connection_type"]) + self.cleanupTimer.cancel() - def do_protocol_version(self, obj): + def do_protocol_version(self, obj: dict[str, int | str]) -> None: + """Called when a "protocol_version" message is received.""" + # TODO: Why don't we send an error message back? if "version" not in obj: return - self.protocol_version = obj["version"] + try: + self.protocolVersion = int(obj["version"]) + except ValueError: + return + + def do_generate_key(self, obj: dict[str, str]) -> None: + """Called when a "generate_key" message is received.""" + self.user.generateKey() - def do_generate_key(self, obj): - self.user.generate_key() + def send(self, origin: int | None = None, **msg: Any) -> None: + """Send a message. - def send(self, origin=None, **msg): - if self.protocol_version > 1 and origin: + :param origin: Originating user of the message, defaults to None + """ + if self.protocolVersion > 1 and origin: msg["origin"] = origin obj = json.dumps(msg).encode("ascii") - self.bytes_sent += len(obj) + self.bytesSent += len(obj) self.sendLine(obj) - def cleanup(self): - logger.info("Connection %d timed out" % self.connection_id) + def cleanup(self) -> None: + """Clean up this connection.""" + logger.info("Connection %d timed out", self.connectionId) self.transport.abortConnection() - self.cleanup_timer = None + self.cleanupTimer = None -class User(object): - user_id = 0 +class User: + """A single connected user.""" - def __init__(self, protocol): + userId = 0 + + def __init__(self, protocol: Handler) -> None: + """Initializer. + + :param protocol: The Handler through which this user connected. + """ self.protocol = protocol - self.channel = None - self.server_state = self.protocol.factory.server_state - self.connection_type = None - self.user_id = User.user_id + 1 - User.user_id += 1 - - def as_dict(self): - return dict(id=self.user_id, connection_type=self.connection_type) - - def generate_key(self): - ip = self.protocol.transport.getPeer().host - if ip in self.server_state.generated_ips and time.time() - self.server_state.generated_ips[ip] < 1: + self.channel: Channel | None = None + self.serverState: ServerState = self.protocol.factory.serverState + self.connectionType = None + self.userId = User.userId + 1 + User.userId += 1 + + def asDict(self) -> UserDict: + """Get a representation of this user suitable for sending over the wire.""" + return UserDict(id=self.userId, connection_type=self.connectionType) + + def generateKey(self) -> str | None: + """Generate a key for the user. + + :return: A channel key, or None if too many keys have been requested. + + :postcondition: The key will be temporarily persisted so that future key generation requests don't result in duplicate keys. + """ + ip: str = self.protocol.transport.getPeer().host # type: ignore + if ip in self.serverState.generatedIps and time.time() - self.serverState.generatedIps[ip] < 1: self.send(type="error", message="too many keys") self.protocol.transport.loseConnection() return - key = "".join([random.choice(string.digits) for i in range(7)]) - while key in self.server_state.generated_keys or key in self.server_state.channels.keys(): - key = "".join([random.choice(string.digits) for i in range(7)]) - self.server_state.generated_keys.add(key) - self.server_state.generated_ips[ip] = time.time() - reactor.callLater(GENERATED_KEY_EXPIRATION_TIME, lambda: self.server_state.generated_keys.remove(key)) + key = "".join([random.choice(string.digits) for _ in range(7)]) + while key in self.serverState.generatedKeys or key in self.serverState.channels.keys(): + key = "".join([random.choice(string.digits) for _ in range(7)]) + self.serverState.generatedKeys.add(key) + self.serverState.generatedIps[ip] = time.time() + reactor.callLater(GENERATED_KEY_EXPIRATION_TIME, lambda: self.serverState.generatedKeys.remove(key)) if key: # pragma: no cover - I can't work out why this branch is here. When would this be False? self.send(type="generate_key", key=key) return key - def connection_lost(self): + def connectionLost(self) -> None: + """Remove this user when they disconnect.""" if ( self.channel is not None ): # pragma: no branch - we don't care about the alternative, as it's a no-op - self.channel.remove_connection(self) + self.channel.removeConnection(self) - def join(self, channel, connection_type): + def join(self, channel: str, connectionType: str) -> None: + """Add this user to a channel. + + :param channel: Key of the channel to join. If no channel with this key exists, a new channel will be created. + :param connectionType: Leader ("master") or follower ("slave"). + """ if self.channel: self.send(type="error", error="already_joined") return - self.connection_type = connection_type - self.channel = self.server_state.find_or_create_channel(channel) - self.channel.add_client(self) + self.connectionType = connectionType + self.channel = self.serverState.findOrCreateChannel(channel) + self.channel.addClient(self) # TODO: Work out if this is ever called. - def do_generate_key(self): # pragma: no cover - key = self.generate_key() + def do_generate_key(self) -> None: # pragma: no cover + """Not sure what calls this?""" + key = self.generateKey() if key: self.send(type="generate_key", key=key) - def send(self, **obj): + def send(self, **obj: Any) -> None: + """Send a message to this user.""" self.protocol.send(**obj) - def send_motd(self): - if self.server_state.motd is not None: - self.send(type="motd", motd=self.server_state.motd) + def sendMotd(self) -> None: + """Send the message of the day to this user.""" + if self.serverState.motd is not None: + self.send(type="motd", motd=self.serverState.motd) class RemoteServerFactory(Factory): - def __init__(self, server_state): - self.server_state = server_state + """Factory to add common functionality to connections.""" + + def __init__(self, serverState: "ServerState") -> None: + """Initializer. - def ping_connected_clients(self): - for channel in self.server_state.channels.values(): - channel.ping_clients() + :param serverState: Status tracking object. + """ + self.serverState = serverState + def pingConnectedClients(self) -> None: + """Ping all users in all channels to determine if they're still connected.""" + for channel in self.serverState.channels.values(): + channel.pingClients() -class ServerState(object): - def __init__(self): - self.channels = {} + +class ServerState: + """Object that tracks the status of the server.""" + + def __init__(self) -> None: + self.channels: dict[str, Channel] = {} # Set of already generated keys - self.generated_keys = set() - # Dictionary of ips to generated time for people who have generated keys. - self.generated_ips = {} + self.generatedKeys: set[str] = set() + # Mapping of IPs to generated time for people who have generated keys. + self.generatedIps: dict[str, float] = {} self.motd: str | None = None - def remove_channel(self, channel): + def removeChannel(self, channel: str) -> None: + """Close a channel. + + :param channel: Key of the channel to remove. + """ del self.channels[channel] - def find_or_create_channel(self, name): + def findOrCreateChannel(self, name: str) -> Channel: + """Find an existing channel, or create one if one doesn't already exist. + + :param name: Key of the channel to find/create. + :return: The found or created channel. + """ if name in self.channels: channel = self.channels[name] else: @@ -249,32 +366,38 @@ class Options(usage.Options): # Exclude from coverage as it's hard to unit test. -def main(): # pragma: no cover +def main() -> Deferred[None]: # pragma: no cover + # Read options from CLI. config = Options() config.parseOptions() + # Open SSL keys. privkey = open(config["privkey"]).read() - certData = open(config["certificate"]).read() - chain = open(config["chain"]).read() + certData = open(config["certificate"], "rb").read() + chain = open(config["chain"], "rb").read() log.startLogging(sys.stdout) + # Initialise encryption privkey = crypto.load_privatekey(crypto.FILETYPE_PEM, privkey) certificate = crypto.load_certificate(crypto.FILETYPE_PEM, certData) chain = crypto.load_certificate(crypto.FILETYPE_PEM, chain) - context_factory = ssl.CertificateOptions( + contextFactory = ssl.CertificateOptions( privateKey=privkey, certificate=certificate, extraCertChain=[chain], ) + # Initialise the server state machine state = ServerState() - if os.path.exists(config["motd"]): - with io.open(config["motd"], encoding="utf-8") as fp: + if os.path.isfile(config["motd"]): + with open(config["motd"], "r", encoding="utf-8") as fp: state.motd = fp.read().strip() else: state.motd = None - f = RemoteServerFactory(state) - l = LoopingCall(f.ping_connected_clients) - l.start(PING_INTERVAL) - f.protocol = Handler - reactor.listenSSL(int(config["port"]), f, context_factory, interface=config["network-interface"]) + # Set up the machinery of the server. + factory = RemoteServerFactory(state) + looper = LoopingCall(factory.pingConnectedClients) + looper.start(PING_INTERVAL) + factory.protocol = Handler + # Start running the server. + reactor.listenSSL(int(config["port"]), factory, contextFactory, interface=config["network-interface"]) reactor.run() return defer.Deferred() diff --git a/test.py b/test.py index 8eca766..55bdf87 100644 --- a/test.py +++ b/test.py @@ -2,11 +2,11 @@ from itertools import islice import json import random -from typing import Any, Final, NamedTuple +from typing import Any, Final, NamedTuple, cast from unittest import mock from twisted.internet import reactor -from twisted.internet.protocol import Protocol, connectionDone +from twisted.internet.protocol import connectionDone from twisted.internet.task import Clock from twisted.internet.testing import StringTransport from twisted.trial import unittest @@ -25,7 +25,7 @@ class Client(NamedTuple): """Structure representing a client connection to the server.""" - protocol: Protocol + protocol: Handler """Serverside protocol. Write to this to represent the client sending to the server.""" transport: StringTransport @@ -36,17 +36,17 @@ def mockUser(id: int) -> mock.MagicMock: """Create a MagicMock representing a user.""" return mock.MagicMock( spec=User, - user_id=id, + userId=id, protocol=MockHandler(), as_dict=lambda: dict(id=id, connection_type="dummy"), ) -def MockHandler(protocol_version: int = 2, serverState: ServerState | None = None) -> mock.MagicMock: +def MockHandler(protocolVersion: int = 2, serverState: ServerState | None = None) -> mock.MagicMock: """Return a MagicMock representing a Handler.""" return mock.MagicMock( spec=Handler, - protocol_version=protocol_version, + protocolVersion=protocolVersion, factory=mockRemoteServerFactory(serverState=serverState or ServerState()), ) @@ -56,7 +56,7 @@ def mockChannel(key: str, clients: Iterable[User]) -> mock.MagicMock: return mock.MagicMock( speck=Channel, key=key, - clients={client.user_id: client for client in clients}, + clients={client.userId: client for client in clients}, ) @@ -64,7 +64,7 @@ def mockRemoteServerFactory(serverState: ServerState) -> mock.MagicMock: """Return a MagicMock representing a RemoteServerFactory.""" return mock.MagicMock( spec=RemoteServerFactory, - server_state=serverState, + serverState=serverState, ) @@ -79,9 +79,9 @@ def setUp(self) -> None: def test_addClient(self): """Test adding a client to a channel.""" oldUsers = [mockUser(id=id) for id in range(3)] - self.channel.clients.update({user.user_id: user for user in oldUsers}) + self.channel.clients.update({user.userId: user for user in oldUsers}) newUser = mockUser(id=4) - self.channel.add_client(newUser) + self.channel.addClient(newUser) self.assertEqual(newUser, self.channel.clients[4]) newUser.send.assert_called_once() self.assertEqual(newUser.send.call_args.kwargs["type"], "channel_joined") @@ -94,10 +94,10 @@ def test_removeConnection(self): allUsers = [mockUser(id=id) for id in range(4)] leavingUser = allUsers[1] leftUsers = [user for user in allUsers if user is not leavingUser] - self.channel.clients.update({user.user_id: user for user in allUsers}) - self.assertIs(self.channel.clients[leavingUser.user_id], leavingUser) - self.channel.remove_connection(leavingUser) - self.assertNotIn(leavingUser.user_id, self.channel.clients) + self.channel.clients.update({user.userId: user for user in allUsers}) + self.assertIs(self.channel.clients[leavingUser.userId], leavingUser) + self.channel.removeConnection(leavingUser) + self.assertNotIn(leavingUser.userId, self.channel.clients) self.assertNotIn(leavingUser, self.channel.clients.values()) for leftUser in leftUsers: leftUser.send.assert_called_once() @@ -107,9 +107,9 @@ def test_removeConnection_notJoined(self): """Test removing a client from a channel of which it isn't a member does nothing.""" memberUsers = [mockUser(id=id) for id in range(4)] nonmemberUser = memberUsers.pop(2) - oldChannelClients = {user.user_id: user for user in memberUsers} + oldChannelClients = {user.userId: user for user in memberUsers} self.channel.clients.update(oldChannelClients) - self.channel.remove_connection(nonmemberUser) + self.channel.removeConnection(nonmemberUser) # NOTE: The current implementation sends client_left messages to the remaining clients, # even if the client wasn't in the channel to begin with. # Sending these messages is already covered in another test. @@ -118,23 +118,23 @@ def test_removeConnection_notJoined(self): def test_cleanup(self): """Test removing the last client removes the channel from the server state.""" user = mockUser(id=1) - self.channel.add_client(user) - self.channel.remove_connection(user) + self.channel.addClient(user) + self.channel.removeConnection(user) self.assertNotIn("channel", self.state.channels) def test_sendToClients_all(self): """Test sending to all clients in the channel.""" users = [mockUser(id) for id in range(4)] - self.channel.clients.update({user.user_id: user for user in users}) - self.channel.send_to_clients({"this": "is a message"}, origin=99) + self.channel.clients.update({user.userId: user for user in users}) + self.channel.sendToClients({"this": "is a message"}, origin=99) for user in users: user.send.assert_called_once_with(this="is a message", origin=99) def test_sendToClients_except(self): """Test sending to all clients but one in the channel.""" users = [mockUser(id) for id in range(4)] - self.channel.clients.update({user.user_id: user for user in users}) - self.channel.send_to_clients({"this": "is a message"}, origin=99, exclude=users[2]) + self.channel.clients.update({user.userId: user for user in users}) + self.channel.sendToClients({"this": "is a message"}, origin=99, exclude=users[2]) for user in users: if user is users[2]: user.send.assert_not_called() @@ -144,8 +144,8 @@ def test_sendToClients_except(self): def test_ping(self): """Test pinging the clients in the channel.""" users = [mockUser(id) for id in range(4)] - self.channel.clients.update({user.user_id: user for user in users}) - self.channel.ping_clients() + self.channel.clients.update({user.userId: user for user in users}) + self.channel.pingClients() for user in users: user.send.assert_called_once_with(type="ping", origin=None) @@ -154,15 +154,15 @@ class TestUser(unittest.TestCase): """Test the User class.""" def setUp(self) -> None: - User.user_id = 0 + User.userId = 0 def tearDown(self) -> None: - User.user_id = 0 + User.userId = 0 def test_consecutiveUserCreation(self): """Test that creating several users sequentially creates them with sequential user IDs.""" users = (User(mock.Mock(Handler)) for _ in range(10)) - self.assertSequenceEqual(list(map(lambda user: user.user_id, users)), range(1, 11)) + self.assertSequenceEqual(list(map(lambda user: user.userId, users)), range(1, 11)) def test_join(self): """Test that adding a user to a channel works as expected.""" @@ -171,7 +171,7 @@ def test_join(self): user = User(MockHandler(serverState=serverState)) user.join(CHANNEL_ID, "master") self.assertIs(user.channel, serverState.channels[CHANNEL_ID]) - self.assertIs(user, serverState.channels[CHANNEL_ID].clients[user.user_id]) + self.assertIs(user, serverState.channels[CHANNEL_ID].clients[user.userId]) def test_join_alreadyJoined(self): """Test that adding a user who is already in a channel to a new channel fails.""" @@ -201,7 +201,7 @@ def test_findOrCreateChannel_create(self): extantChannels = self._addChannels() oldChannels = self.serverState.channels.copy() self.assertNotIn("newChannel", self.serverState.channels) - newChannel = self.serverState.find_or_create_channel("newChannel") + newChannel = self.serverState.findOrCreateChannel("newChannel") self.assertIn("newChannel", self.serverState.channels) self.assertIs(self.serverState.channels["newChannel"], newChannel) self.assertNotIn(newChannel, extantChannels) @@ -213,7 +213,7 @@ def test_findOrCreateChannel_find(self): oldChannels = self.serverState.channels.copy() self.assertIn("c", self.serverState.channels) expectedChannel = extantChannels[2] - foundChannel = self.serverState.find_or_create_channel("c") + foundChannel = self.serverState.findOrCreateChannel("c") self.assertIs(expectedChannel, foundChannel) self.assertEqual(oldChannels, self.serverState.channels) @@ -222,15 +222,15 @@ class TestRemoteServerFactory(unittest.TestCase): """Test the RemoteServerFactory class.""" def test_pingClients(self): - """Test that calling ping_connected_clients calls ping_clients on all channels, regardless of size.""" + """Test that calling ping_connected_clients calls pingClients on all channels, regardless of size.""" serverState = ServerState() factory = RemoteServerFactory(serverState) userIterator = (mockUser(id) for id in range(10)) channels = tuple(mockChannel(key=chr(n + 65), clients=islice(userIterator, n)) for n in range(5)) serverState.channels.update({channel.key: channel for channel in channels}) - factory.ping_connected_clients() + factory.pingConnectedClients() for channel in channels: - channel.ping_clients.assert_called_once() + channel.pingClients.assert_called_once() class BaseServerTestCase(unittest.TestCase): @@ -241,8 +241,8 @@ class BaseServerTestCase(unittest.TestCase): def setUp(self) -> None: # Ensure we're starting from a common baseline - self._oldUserId = User.user_id - User.user_id = 0 + self._oldUserId = User.userId + User.userId = 0 self.state = ServerState() self.factory = RemoteServerFactory(self.state) self.factory.protocol = Handler @@ -251,14 +251,18 @@ def setUp(self) -> None: def tearDown(self) -> None: # Put things back how they were when we found them - User.user_id = self._oldUserId + User.userId = self._oldUserId def _createClient(self) -> Client: """Create a client-server connection.""" - protocol = self.factory.buildProtocol(("127.0.0.1", 0)) + # A (host, port) tuple works fine here. + # Even using twisted.internet.address.IPv4Address` here doesn't work, + # as pyright doesn't understand Zope interfaces. + protocol = self.factory.buildProtocol(("127.0.0.1", 0)) # pyright: ignore [reportArgumentType] transport = StringTransport() protocol.makeConnection(transport) - return Client(protocol=protocol, transport=transport) + assert protocol is not None # Needed to shut pyright up + return Client(protocol=cast(Handler, protocol), transport=transport) def _connectClient(self, protocolVersion: int = 2) -> Client: """Create and initialize a new connection.""" @@ -300,7 +304,7 @@ def setUp(self) -> None: self.protocol, self.transport = self._connectClient() random.seed(self.RANDOM_SEED) - def _test(self, serverReceived: bytes, clientReceived: bytes) -> None: + def _test(self, serverReceived: dict[str, Any], clientReceived: dict[str, Any]) -> None: self.protocol.dataReceived(json.dumps(serverReceived).encode() + b"\n") self.assertEqual(json.loads(self.transport.value().decode()), clientReceived) self.transport.clear() @@ -309,9 +313,9 @@ def test_generateKey(self): """Test that requesting the server to generate a key returns the expected result, and temporarily persists the key to avoid collisions.""" key = self.EXPECTED_KEYS[0] self._test({"type": "generate_key"}, {"type": "generate_key", "key": key}) - self.assertIn(key, self.state.generated_keys, "Key was not persisted where expected.") + self.assertIn(key, self.state.generatedKeys, "Key was not persisted where expected.") self.clock.advance(GENERATED_KEY_EXPIRATION_TIME) - self.assertNotIn(key, self.state.generated_keys, "Key was not removed after expiration.") + self.assertNotIn(key, self.state.generatedKeys, "Key was not removed after expiration.") @mock.patch("time.time", return_value=12345) def test_repeated_generateKey_ok(self, mock_time: mock.MagicMock): @@ -444,10 +448,19 @@ def test_join_withoutChannel(self): def test_protocol_version_withoutVersion(self): """Test that sending a 'protocol_version' message without a 'version' returns nothing.""" - client = self._connectClient() - # TODO: Work out how to find the associated Handler and check that its protocol version doesn't change. + client = self._createClient() + oldProtocolVersion = client.protocol.protocolVersion self._send(client, {"type": "protocol_version"}) self.assertIsNone(self._receive(client)) + self.assertEqual(client.protocol.protocolVersion, oldProtocolVersion) + + def test_protocol_version_withInvalidVersion(self): + """Test that sending a 'protocol_version' message with a non-integer 'version' returns nothing.""" + client = self._createClient() + oldProtocolVersion = client.protocol.protocolVersion + self._send(client, {"type": "protocol_version", "version": "NaN"}) + self.assertIsNone(self._receive(client)) + self.assertEqual(client.protocol.protocolVersion, oldProtocolVersion) def test_inactivityCausesDisconnection(self): """Test that connecting without joining a channel causes disconnection."""