diff --git a/server.py b/server.py index 9c6e8bc..45859a2 100644 --- a/server.py +++ b/server.py @@ -6,17 +6,18 @@ import time from collections import OrderedDict from logging import getLogger +from typing import Any, TypedDict, cast from OpenSSL import crypto from twisted.internet import reactor, ssl +from twisted.internet.defer import Deferred from twisted.internet.interfaces import ITCPTransport -from twisted.internet.protocol import Factory, defer, connectionDone +from twisted.internet.protocol import Factory, connectionDone, defer from twisted.internet.task import LoopingCall from twisted.protocols.basic import LineReceiver +from twisted.protocols.haproxy._wrapper import HAProxyWrappingFactory from twisted.python import log, usage -from twisted.internet.defer import Deferred from twisted.python.failure import Failure -from typing import Any, TypedDict, cast logger = getLogger("remote-server") @@ -363,27 +364,32 @@ class Options(usage.Options): ["network-interface", "i", "::", "Interface to listen on"], ["port", "p", "6837", "Server port"], ] + optFlags = [ + ["no-ssl", "n", "Disable SSL"], + ] # Exclude from coverage as it's hard to unit test. def main() -> Deferred[None]: # pragma: no cover + sslContext: ssl.CertificateOptions | None = None # Read options from CLI. config = Options() config.parseOptions() - # Open SSL keys. - privkey = open(config["privkey"]).read() - certData = open(config["certificate"], "rb").read() - chain = open(config["chain"], "rb").read() log.startLogging(sys.stdout) - # Initialise encryption - privkey = crypto.load_privatekey(crypto.FILETYPE_PEM, privkey) - certificate = crypto.load_certificate(crypto.FILETYPE_PEM, certData) - chain = crypto.load_certificate(crypto.FILETYPE_PEM, chain) - contextFactory = ssl.CertificateOptions( - privateKey=privkey, - certificate=certificate, - extraCertChain=[chain], - ) + if not config["no-ssl"]: + # Initialise encryption + # Open SSL keys. + privkey = open(config["privkey"]).read() + certData = open(config["certificate"], "rb").read() + chain = open(config["chain"], "rb").read() + privkey = crypto.load_privatekey(crypto.FILETYPE_PEM, privkey) + certificate = crypto.load_certificate(crypto.FILETYPE_PEM, certData) + chain = crypto.load_certificate(crypto.FILETYPE_PEM, chain) + sslContext = ssl.CertificateOptions( + privateKey=privkey, + certificate=certificate, + extraCertChain=[chain], + ) # Initialise the server state machine state = ServerState() if os.path.isfile(config["motd"]): @@ -393,11 +399,15 @@ def main() -> Deferred[None]: # pragma: no cover state.motd = None # Set up the machinery of the server. factory = RemoteServerFactory(state) + wrappedFactory = HAProxyWrappingFactory(factory) looper = LoopingCall(factory.pingConnectedClients) looper.start(PING_INTERVAL) factory.protocol = Handler # Start running the server. - reactor.listenSSL(int(config["port"]), factory, contextFactory, interface=config["network-interface"]) + if config["no-ssl"]: + reactor.listenTCP(int(config["port"]), wrappedFactory, interface=config["network-interface"]) + else: + reactor.listenSSL(int(config["port"]), factory, sslContext, interface=config["network-interface"]) reactor.run() return defer.Deferred()