forked from NVDARemote/remote-server
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathserver.py
More file actions
431 lines (361 loc) · 14 KB
/
server.py
File metadata and controls
431 lines (361 loc) · 14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
# Copyright 2020-2025 NV Access Limited, Christopher Toth, Tyler Spivey
#
# This file is part of the NVDA Remote Access Relay Server.
#
# NVDA Remote Access Relay Server is free software: you can redistribute it and/or modify it under the terms
# of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version.
#
# NVDA Remote Access Relay Server is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero
# General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License along with NVDA Remote Access Relay
# Server. If not, see <https://www.gnu.org/licenses/>.
import json
import os
import random
import secrets
import string
import sys
import time
from collections import OrderedDict
from logging import getLogger
from typing import Any, TypedDict, cast

from OpenSSL import crypto
from twisted.internet import reactor, ssl
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import ITCPTransport
from twisted.internet.protocol import Factory, connectionDone, defer
from twisted.internet.task import LoopingCall
from twisted.protocols.basic import LineReceiver
from twisted.protocols.haproxy._wrapper import HAProxyWrappingFactory
from twisted.python import log, usage
from twisted.python.failure import Failure
# Module-level logger used for connection lifecycle and protocol warnings.
logger = getLogger("remote-server")
# Seconds between rounds of server-initiated pings to every channel member
# (the LoopingCall period in main()).
PING_INTERVAL: int = 300
# Seconds a new connection may stay open without successfully joining a channel
# before Handler.cleanup() aborts it.
INITIAL_TIMEOUT: int = 30
# Expiration time for generated keys, in seconds
GENERATED_KEY_EXPIRATION_TIME: int = 60 * 60 * 24 # One day
class UserDict(TypedDict):
	"""Typed dictionary representing a user, as sent to other channel members.

	Produced by ``User.asDict`` and embedded in ``channel_joined``, ``client_joined``
	and ``client_left`` messages for protocol version 2+ clients.
	Keys in this dictionary cannot be renamed, as clients rely on them.
	"""

	# Server-assigned numeric user ID.
	id: int
	# The connection type the user supplied when joining; None until the user has joined.
	connection_type: str | None
class Message(TypedDict):
	"""Type hints for protocol messages.

	Protocol messages are JSON objects carrying at least a ``type`` key; incoming
	objects without one are logged and ignored by ``Handler.lineReceived``.
	Keys in this dictionary cannot be renamed, as clients rely on them.
	"""

	# Message kind, e.g. "join", "generate_key", "ping".
	# NOTE(review): this TypedDict is not referenced elsewhere in this file.
	type: str
class Channel:
"""Collection of connected users in the one "session"."""
def __init__(self, key: str, serverState: "ServerState | None" = None) -> None:
"""Constructor
:param key: Unique identifier of this channel.
:param serverState: Server state, defaults to None
"""
self.clients: OrderedDict[int, User] = OrderedDict()
self.key = key
self.serverState = serverState
def addClient(self, client: "User") -> None:
"""Joined when a new user wants to join the channel.
:param client: The new channel member.
"""
if client.protocol.protocolVersion == 1: # pragma: no cover - protocol v1 is not tested
ids = [c.userId for c in self.clients.values()]
msg = dict(type="channel_joined", channel=self.key, user_ids=ids, origin=client.userId)
else:
clients = [i.asDict() for i in self.clients.values()]
msg = dict(type="channel_joined", channel=self.key, origin=client.userId, clients=clients)
client.send(**msg)
for existingClient in self.clients.values():
if existingClient.protocol.protocolVersion == 1: # pragma: no cover - protocol v1 is not tested
existingClient.send(type="client_joined", user_id=client.userId)
else:
existingClient.send(type="client_joined", client=client.asDict())
self.clients[client.userId] = client
def removeConnection(self, con: "User") -> None:
"""Called when a user leaves the channel.
:param con: The leaving channel member.
"""
if con.userId in self.clients:
del self.clients[con.userId]
for client in self.clients.values():
if client.protocol.protocolVersion == 1: # pragma: no cover - protocol v1 is not tested
client.send(type="client_left", user_id=con.userId)
else:
client.send(type="client_left", client=con.asDict())
if not self.clients:
self.serverState.removeChannel(self.key)
def pingClients(self) -> None:
"""Ping clients to ensure they're still connected."""
self.sendToClients({"type": "ping"})
def sendToClients(
self,
obj: dict[str, Any],
exclude: "User | None" = None,
origin: int | None = None,
) -> None:
"""Broadcast a message to all users in this channel.
:param obj: Message to send.
:param exclude: User to exclude from the broadcast, defaults to None
:param origin: Originating user, defaults to None
"""
for client in self.clients.values():
if client is exclude:
continue
client.send(origin=origin, **obj)
class Handler(LineReceiver):
	"""Protocol for one client connection.

	Receives newline-delimited JSON objects and sends replies in the same format.
	"""

	delimiter = b"\n"
	# Class-level counter; each instance takes the next value as its own connectionId (used in logs).
	connectionId = 0
	# Longest line accepted by LineReceiver, in bytes (20 MiB).
	MAX_LENGTH = 20 * 1048576

	def __init__(self) -> None:
		self.connectionId = Handler.connectionId + 1
		Handler.connectionId += 1
		# Clients are treated as protocol v1 until they send a "protocol_version" message.
		self.protocolVersion = 1

	def connectionMade(self) -> None:
		"""Called when a user first connects."""
		logger.info("Connection %d from %s", self.connectionId, self.transport.getPeer())
		# Zope interfaces must be checked with providedBy(). isinstance() only walks class
		# inheritance, and transports declare ITCPTransport via @implementer, so it never matched.
		# The transport used for unit testing doesn't provide ITCPTransport (nor setTcpNoDelay).
		if ITCPTransport.providedBy(self.transport):  # pragma: no cover
			# Methods of Zope interfaces don't take self, so pyright thinks this call has too many arguments
			self.transport.setTcpNoDelay(True)  # pyright: ignore [reportCallIssue]
		self.bytesSent = 0
		self.bytesReceived = 0
		self.user = User(protocol=self)
		# Abort the connection unless the client joins a channel within INITIAL_TIMEOUT seconds.
		self.cleanupTimer = reactor.callLater(INITIAL_TIMEOUT, self.cleanup)
		self.user.sendMotd()

	def connectionLost(self, reason: Failure = connectionDone) -> None:
		"""Called when the connection is dropped."""
		logger.info(
			"Connection %d lost, bytes sent: %d received: %d",
			self.connectionId,
			self.bytesSent,
			self.bytesReceived,
		)
		self.user.connectionLost()
		self._cancelCleanupTimer()

	def _cancelCleanupTimer(self) -> None:
		"""Cancel the join timeout if it is still pending.

		Cancelling a delayed call that has already run or been cancelled raises, and
		cleanup() sets the timer to None once it fires, so check both first.
		"""
		if self.cleanupTimer is not None and self.cleanupTimer.active():
			self.cleanupTimer.cancel()

	def lineReceived(self, line: bytes) -> None:
		"""Called when a new line (a command) has been received.

		Before the user has joined a channel, the message ``type`` selects a ``do_<type>``
		method on this handler. Once joined, every message is relayed to the other channel
		members with this user's ID as ``origin``.

		:param line: The incoming line.
		"""
		self.bytesReceived += len(line)
		try:
			parsed = json.loads(line)
		except (ValueError, RecursionError):
			# ValueError covers malformed JSON and undecodable bytes; RecursionError covers absurd nesting.
			parsed = None
		if not isinstance(parsed, dict):
			logger.warning("Unable to parse %r", line)
			self.transport.loseConnection()
			return
		message = cast(dict[str, Any], parsed)
		if "type" not in message:
			logger.warning("Invalid object received: %r", message)
			return
		message.pop("origin", None)  # Remove an existing origin, we know where the message comes from.
		if self.user.channel is not None:
			self.user.channel.sendToClients(message, exclude=self.user, origin=self.user.userId)
			return
		messageType = message["type"]
		# Only a string can name a handler; other JSON values would raise TypeError on concatenation.
		handler = getattr(self, "do_" + messageType, None) if isinstance(messageType, str) else None
		if handler is None:
			logger.warning("No function for type %s", messageType)
			return
		handler(message)

	def do_join(self, obj: dict[str, str]) -> None:
		"""Called when receiving a "join" message.

		``channel`` and ``connection_type`` must be non-empty strings, otherwise an
		``invalid_parameters`` error is sent back.
		"""
		channel = obj.get("channel")
		connectionType = obj.get("connection_type")
		# The channel is used as a dictionary key, so unhashable values (e.g. JSON lists) must be rejected here.
		if not (isinstance(channel, str) and channel and isinstance(connectionType, str) and connectionType):
			self.send(type="error", error="invalid_parameters")
			return
		self.user.join(channel, connectionType=connectionType)
		self._cancelCleanupTimer()

	def do_protocol_version(self, obj: dict[str, int | str]) -> None:
		"""Called when a "protocol_version" message is received.

		A missing or unusable version is ignored and the current version is kept.
		"""
		# TODO: Why don't we send an error message back?
		if "version" not in obj:
			return
		try:
			self.protocolVersion = int(obj["version"])
		except (TypeError, ValueError, OverflowError):
			# TypeError: null/list/object; ValueError: non-numeric string or NaN; OverflowError: Infinity.
			return

	def do_generate_key(self, obj: dict[str, str]) -> None:
		"""Called when a "generate_key" message is received."""
		self.user.generateKey()

	def send(self, origin: int | None = None, **msg: Any) -> None:
		"""Serialize a message as JSON and send it as one line.

		:param origin: Originating user of the message, defaults to None.
			Only included for clients speaking protocol version 2 or later.
		"""
		if self.protocolVersion > 1 and origin:
			msg["origin"] = origin
		# json.dumps escapes non-ASCII characters by default, so the ASCII encode cannot fail.
		obj = json.dumps(msg).encode("ascii")
		self.bytesSent += len(obj)
		self.sendLine(obj)

	def cleanup(self) -> None:
		"""Abort a connection that did not join a channel within INITIAL_TIMEOUT seconds."""
		logger.info("Connection %d timed out", self.connectionId)
		self.transport.abortConnection()
		self.cleanupTimer = None
class User:
	"""A single connected user."""

	# Class-level counter; each instance takes the next value as its userId.
	userId = 0

	def __init__(self, protocol: Handler) -> None:
		"""Initializer.

		:param protocol: The Handler through which this user connected; its factory provides ``serverState``.
		"""
		self.protocol = protocol
		self.channel: Channel | None = None
		self.serverState: ServerState = self.protocol.factory.serverState
		# Set when the user joins a channel; None until then.
		self.connectionType = None
		self.userId = User.userId + 1
		User.userId += 1

	def asDict(self) -> UserDict:
		"""Get a representation of this user suitable for sending over the wire."""
		return UserDict(id=self.userId, connection_type=self.connectionType)

	def generateKey(self) -> str | None:
		"""Generate a fresh channel key and send it to the user.

		Each client IP may request at most one key per second. A faster request receives a
		"too many keys" error and the connection is closed.

		:return: A channel key, or None if too many keys have been requested.
		:postcondition: The key will be persisted for GENERATED_KEY_EXPIRATION_TIME seconds so that
			future key generation requests don't result in duplicate keys.
		"""
		ip: str = self.protocol.transport.getPeer().host  # type: ignore
		state = self.serverState
		lastGenerated = state.generatedIps.get(ip)
		if lastGenerated is not None and time.time() - lastGenerated < 1:
			self.send(type="error", message="too many keys")
			self.protocol.transport.loseConnection()
			return None
		key = self._randomKey()
		# Never hand out a key that is still reserved or that names an open channel.
		while key in state.generatedKeys or key in state.channels:
			key = self._randomKey()
		state.generatedKeys.add(key)
		# NOTE(review): nothing in this file removes entries from generatedIps, so it grows with the
		# number of distinct client IPs.
		state.generatedIps[ip] = time.time()
		reactor.callLater(GENERATED_KEY_EXPIRATION_TIME, state.generatedKeys.remove, key)
		self.send(type="generate_key", key=key)
		return key

	@staticmethod
	def _randomKey() -> str:
		"""Return a random 7-digit channel key.

		Anyone who knows a channel key can join that channel and exchange messages with its
		members, so keys come from the cryptographically secure ``secrets`` module, not ``random``.
		"""
		return "".join(secrets.choice(string.digits) for _ in range(7))

	def connectionLost(self) -> None:
		"""Remove this user from its channel when they disconnect."""
		if (
			self.channel is not None
		):  # pragma: no branch - we don't care about the alternative, as it's a no-op
			self.channel.removeConnection(self)

	def join(self, channel: str, connectionType: str) -> None:
		"""Add this user to a channel.

		:param channel: Key of the channel to join. If no channel with this key exists, a new channel will be created.
		:param connectionType: Leader ("master") or follower ("slave").
		"""
		if self.channel is not None:
			self.send(type="error", error="already_joined")
			return
		self.connectionType = connectionType
		self.channel = self.serverState.findOrCreateChannel(channel)
		self.channel.addClient(self)

	# TODO: Work out if this is ever called.
	def do_generate_key(self) -> None:  # pragma: no cover
		"""Not sure what calls this?

		NOTE(review): generateKey() already sends the key, so this would send it twice.
		"""
		key = self.generateKey()
		if key:
			self.send(type="generate_key", key=key)

	def send(self, **obj: Any) -> None:
		"""Send a message to this user."""
		self.protocol.send(**obj)

	def sendMotd(self) -> None:
		"""Send the message of the day to this user, if one is configured."""
		if self.serverState.motd is not None:
			self.send(type="motd", motd=self.serverState.motd)
class RemoteServerFactory(Factory):
	"""Twisted factory that gives every connection access to the shared server state."""

	def __init__(self, serverState: "ServerState") -> None:
		"""Create the factory.

		:param serverState: Shared status-tracking object, read by each User as ``protocol.factory.serverState``.
		"""
		self.serverState = serverState

	def pingConnectedClients(self) -> None:
		"""Send a ping to every member of every open channel, to check they are still connected."""
		channels = self.serverState.channels
		for key in channels:
			channels[key].pingClients()
class ServerState:
	"""Mutable, process-wide state of the relay server."""

	def __init__(self) -> None:
		# Open channels, keyed by channel key.
		self.channels: dict[str, Channel] = {}
		# Keys handed out by key generation that are still reserved.
		self.generatedKeys: set[str] = set()
		# Client IP -> time of that IP's most recent key generation.
		self.generatedIps: dict[str, float] = {}
		# Message of the day for new connections; None means no MOTD is sent.
		self.motd: str | None = None

	def removeChannel(self, channel: str) -> None:
		"""Forget a channel.

		:param channel: Key of the channel to remove.
		:raises KeyError: If no channel is registered under that key.
		"""
		self.channels.pop(channel)

	def findOrCreateChannel(self, name: str) -> Channel:
		"""Return the channel registered under ``name``, registering a new one if needed.

		:param name: Key of the channel to find/create.
		:return: The found or created channel.
		"""
		channel = self.channels.get(name)
		if channel is None:
			channel = Channel(name, self)
			self.channels[name] = channel
		return channel
class Options(usage.Options):
	# Command-line options for main(), parsed by twisted.python.usage.
	# Each optParameters entry is [long name, short name, default value, help text].
	optParameters = [
		# PEM files, only read when --no-ssl is not given.
		["certificate", "c", "cert", "SSL certificate"],
		["privkey", "k", "privkey", "SSL private key"],
		["chain", "C", "chain", "SSL chain"],
		# Path to a UTF-8 text file; no MOTD is used if the path is not an existing file.
		["motd", "m", "motd", "MOTD"],
		# Address to bind; "::" is the IPv6 wildcard address.
		["network-interface", "i", "::", "Interface to listen on"],
		# Kept as a string here; main() converts it with int().
		["port", "p", "6837", "Server port"],
	]
	# Boolean switches: [long name, short name, help text].
	optFlags = [
		["no-ssl", "n", "Disable SSL"],
	]
def _loadSslContext(privkeyPath: str, certificatePath: str, chainPath: str) -> ssl.CertificateOptions:  # pragma: no cover
	"""Build the server's TLS options from PEM files.

	:param privkeyPath: Path to the PEM-encoded private key.
	:param certificatePath: Path to the PEM-encoded server certificate.
	:param chainPath: Path to a PEM file with the intermediate certificate(s).
		Every certificate in the file is served as part of the chain, not just the first.
	:return: Certificate options for ``reactor.listenSSL``.
	:raises ValueError: If the chain file contains no certificate.
	"""
	with open(privkeyPath, "rb") as fp:
		privkey = crypto.load_privatekey(crypto.FILETYPE_PEM, fp.read())
	with open(certificatePath, "rb") as fp:
		certificate = crypto.load_certificate(crypto.FILETYPE_PEM, fp.read())
	with open(chainPath, "rb") as fp:
		chainData = fp.read()
	# load_certificate only reads the first PEM block, so split the file into individual certificates.
	endMarker = b"-----END CERTIFICATE-----"
	chain = [
		crypto.load_certificate(crypto.FILETYPE_PEM, pem + endMarker)
		for pem in chainData.split(endMarker)
		if pem.strip()
	]
	if not chain:
		raise ValueError(f"No certificate found in chain file {chainPath!r}")
	return ssl.CertificateOptions(
		privateKey=privkey,
		certificate=certificate,
		extraCertChain=chain,
	)


# Exclude from coverage as it's hard to unit test.
def main() -> Deferred[None]:  # pragma: no cover
	"""Configure and run the relay server.

	Parses the command line, loads the TLS certificates unless --no-ssl is given, reads the MOTD,
	starts the periodic client ping and listens for connections. Blocks in ``reactor.run()``
	until the reactor stops.

	:return: A new, unfired Deferred.
	"""
	sslContext: ssl.CertificateOptions | None = None
	# Read options from CLI.
	config = Options()
	config.parseOptions()
	log.startLogging(sys.stdout)
	if not config["no-ssl"]:
		# Initialise encryption.
		sslContext = _loadSslContext(config["privkey"], config["certificate"], config["chain"])
	# Initialise the server state machine.
	state = ServerState()
	if os.path.isfile(config["motd"]):
		with open(config["motd"], "r", encoding="utf-8") as fp:
			state.motd = fp.read().strip()
	else:
		state.motd = None
	# Set up the machinery of the server.
	factory = RemoteServerFactory(state)
	wrappedFactory = HAProxyWrappingFactory(factory)
	looper = LoopingCall(factory.pingConnectedClients)
	looper.start(PING_INTERVAL)
	factory.protocol = Handler
	# Start running the server.
	port = int(config["port"])
	interface = config["network-interface"]
	if config["no-ssl"]:
		reactor.listenTCP(port, wrappedFactory, interface=interface)
	else:
		# NOTE(review): the TLS listener gets the factory without HAProxyWrappingFactory, so PROXY
		# protocol support only exists in --no-ssl mode. Presumably intentional (TLS terminated by a
		# proxy in front of a plain-TCP server); confirm.
		reactor.listenSSL(port, factory, sslContext, interface=interface)
	reactor.run()
	return defer.Deferred()
if __name__ == "__main__":
	# Script entry point: main() blocks in reactor.run() and only returns once the reactor stops.
	main()