diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0cd1081 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +*.so +*.egg +*.egg-info/ +dist/ +build/ diff --git a/README.md b/README.md index 3cde3dc..623ab4f 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,9 @@ Fast and simple to setup MTProto proxy written in Python. 1. `git clone -b stable https://github.com/alexbers/mtprotoproxy.git; cd mtprotoproxy` 2. *(optional, recommended)* edit *config.py*, set **PORT**, **USERS** and **AD_TAG** -3. `docker-compose up -d` (or just `python3 mtprotoproxy.py` if you don't like Docker) -4. *(optional, get a link to share the proxy)* `docker-compose logs` +3. `docker build -t mtprotoproxy .` +4. `docker-compose up -d` (or just `python3 mtprotoproxy.py` if you don't like Docker) +5. *(optional, get a link to share the proxy)* `docker-compose logs` ![Demo](https://alexbers.com/mtprotoproxy/install_demo_v2.gif) diff --git a/config.py b/config.py index 54b614e..ccc7f3b 100644 --- a/config.py +++ b/config.py @@ -1,27 +1,34 @@ +import os + + +def str_to_bool(value): + """Convert string to boolean.""" + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.lower() in ("true", "1", "yes", "on") + return bool(value) + + PORT = 443 # name -> secret (32 hex chars) USERS = { - "tg": "00000000000000000000000000000001", - # "tg2": "0123456789abcdef0123456789abcdef", + "tg": os.environ.get("TG_KEY", "00000000000000000000000000000001"), + # "tg2": "0123456789abcdef0123456789abcdef", } -MODES = { - # Classic mode, easy to detect - "classic": False, - - # Makes the proxy harder to detect - # Can be incompatible with very old clients - "secure": False, +# Makes the proxy harder to detect +# Can be incompatible with very old clients +SECURE_ONLY = str_to_bool(os.environ.get("SECURE_ONLY", "True")) - # Makes the proxy even more hard to detect - # Can be incompatible with old clients - "tls": True -} +# Makes the proxy even more hard to detect +# Compatible only with the recent clients +TLS_ONLY = str_to_bool(os.environ.get("TLS_ONLY", "True")) -# The domain for TLS mode, bad clients are proxied there +# The domain for TLS, bad clients are proxied there # Use random existing domain, proxy checks it on start -# TLS_DOMAIN = "www.google.com" +TLS_DOMAIN = os.environ.get("TLS_DOMAIN", "www.google.com") # Tag for advertising, obtainable from @MTProxybot -# AD_TAG = "3c09c680b76ee91a4c25ad51f742267d" +AD_TAG = os.environ.get("AD_TAG", "3c09c680b76ee91a4c25ad51f742267d") diff --git a/docker-compose.yml b/docker-compose.yml index f490bf6..0284737 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,16 +1,16 @@ version: '2.0' services: mtprotoproxy: - build: . + image: mtprotoproxy restart: unless-stopped network_mode: "host" + environment: + - TG_KEY=00000000000000000000000000000001 + - SECURE_ONLY=true + - TLS_ONLY=true + - TLS_DOMAIN=www.drive.google.com + - AD_TAG=3c09c680b76ee91a4c25ad51f742267d + volumes: - ./config.py:/home/tgproxy/config.py - ./mtprotoproxy.py:/home/tgproxy/mtprotoproxy.py - - /etc/localtime:/etc/localtime:ro - logging: - driver: "json-file" - options: - max-file: "10" - max-size: "10m" -# mem_limit: 1024m diff --git a/mtprotoproxy.py b/mtprotoproxy.py index fb686ab..57fe0f0 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -24,37 +24,53 @@ TG_DATACENTER_PORT = 443 TG_DATACENTERS_V4 = [ - "149.154.175.50", "149.154.167.51", "149.154.175.100", - "149.154.167.91", "149.154.171.5" + "149.154.175.50", + "149.154.167.51", + "149.154.175.100", + "149.154.167.91", + "149.154.171.5", ] TG_DATACENTERS_V6 = [ - "2001:b28:f23d:f001::a", "2001:67c:04e8:f002::a", "2001:b28:f23d:f003::a", - "2001:67c:04e8:f004::a", "2001:b28:f23f:f005::a" + "2001:b28:f23d:f001::a", + "2001:67c:04e8:f002::a", + "2001:b28:f23d:f003::a", + "2001:67c:04e8:f004::a", + "2001:b28:f23f:f005::a", ] # This list will be updated in the runtime TG_MIDDLE_PROXIES_V4 = { - 1: [("149.154.175.50", 8888)], -1: [("149.154.175.50", 8888)], - 2: [("149.154.161.144", 8888)], -2: [("149.154.161.144", 8888)], - 3: [("149.154.175.100", 8888)], -3: [("149.154.175.100", 8888)], - 4: [("91.108.4.136", 8888)], -4: [("149.154.165.109", 8888)], - 5: [("91.108.56.183", 8888)], -5: [("91.108.56.183", 8888)] + 1: [("149.154.175.50", 8888)], + -1: [("149.154.175.50", 8888)], + 2: [("149.154.161.144", 8888)], + -2: [("149.154.161.144", 8888)], + 3: [("149.154.175.100", 8888)], + -3: [("149.154.175.100", 8888)], + 4: [("91.108.4.136", 8888)], + -4: [("149.154.165.109", 8888)], + 5: [("91.108.56.183", 8888)], + -5: [("91.108.56.183", 8888)], } TG_MIDDLE_PROXIES_V6 = { - 1: [("2001:b28:f23d:f001::d", 8888)], -1: [("2001:b28:f23d:f001::d", 8888)], - 2: [("2001:67c:04e8:f002::d", 80)], -2: [("2001:67c:04e8:f002::d", 80)], - 3: [("2001:b28:f23d:f003::d", 8888)], -3: [("2001:b28:f23d:f003::d", 8888)], - 4: [("2001:67c:04e8:f004::d", 8888)], -4: [("2001:67c:04e8:f004::d", 8888)], - 5: [("2001:b28:f23f:f005::d", 8888)], -5: [("2001:b28:f23f:f005::d", 8888)] + 1: [("2001:b28:f23d:f001::d", 8888)], + -1: [("2001:b28:f23d:f001::d", 8888)], + 2: [("2001:67c:04e8:f002::d", 80)], + -2: [("2001:67c:04e8:f002::d", 80)], + 3: [("2001:b28:f23d:f003::d", 8888)], + -3: [("2001:b28:f23d:f003::d", 8888)], + 4: [("2001:67c:04e8:f004::d", 8888)], + -4: [("2001:67c:04e8:f004::d", 8888)], + 5: [("2001:b28:f23f:f005::d", 8888)], + -5: [("2001:b28:f23f:f005::d", 8888)], } PROXY_SECRET = bytes.fromhex( - "c4f9faca9678e6bb48ad6c7e2ce5c0d24430645d554addeb55419e034da62721" + - "d046eaab6e52ab14a95a443ecfb3463e79a05a66612adf9caeda8be9a80da698" + - "6fb0a6ff387af84d88ef3a6413713e5c3377f6e1a3d47d99f5e0c56eece8f05c" + - "54c490b079e31bef82ff0ee8f2b0a32756d249c5f21269816cb7061b265db212" + "c4f9faca9678e6bb48ad6c7e2ce5c0d24430645d554addeb55419e034da62721" + + "d046eaab6e52ab14a95a443ecfb3463e79a05a66612adf9caeda8be9a80da698" + + "6fb0a6ff387af84d88ef3a6413713e5c3377f6e1a3d47d99f5e0c56eece8f05c" + + "54c490b079e31bef82ff0ee8f2b0a32756d249c5f21269816cb7061b265db212" ) SKIP_LEN = 8 @@ -75,7 +91,7 @@ PADDING_FILLER = b"\x04\x00\x00\x00" MIN_MSG_LEN = 12 -MAX_MSG_LEN = 2 ** 24 +MAX_MSG_LEN = 2**24 STAT_DURATION_BUCKETS = [0.1, 0.5, 1, 2, 5, 15, 60, 300, 600, 1800, 2**31 - 1] @@ -111,7 +127,9 @@ def init_config(): conf_dict = {} conf_dict["PORT"] = int(sys.argv[1]) secrets = sys.argv[2].split(",") - conf_dict["USERS"] = {"user%d" % i: secrets[i].zfill(32) for i in range(len(secrets))} + conf_dict["USERS"] = { + "user%d" % i: secrets[i].zfill(32) for i in range(len(secrets)) + } conf_dict["MODES"] = {"classic": False, "secure": True, "tls": True} if len(sys.argv) > 3: conf_dict["AD_TAG"] = sys.argv[3] @@ -122,14 +140,17 @@ def init_config(): conf_dict = {k: v for k, v in conf_dict.items() if k.isupper()} conf_dict.setdefault("PORT", 3256) - conf_dict.setdefault("USERS", {"tg": "00000000000000000000000000000000"}) + conf_dict.setdefault("USERS", {"tg": "00000000000000000000000000000000"}) conf_dict["AD_TAG"] = bytes.fromhex(conf_dict.get("AD_TAG", "")) for user, secret in conf_dict["USERS"].items(): if not re.fullmatch("[0-9a-fA-F]{32}", secret): fixed_secret = re.sub(r"[^0-9a-fA-F]", "", secret).zfill(32)[:32] - print_err("Bad secret for user %s, should be 32 hex chars, got %s. " % (user, secret)) + print_err( + "Bad secret for user %s, should be 32 hex chars, got %s. " + % (user, secret) + ) print_err("Changing it to %s" % fixed_secret) conf_dict["USERS"][user] = fixed_secret @@ -218,7 +239,9 @@ def init_config(): # expiration date for users in format of day/month/year conf_dict.setdefault("USER_EXPIRATIONS", {}) for user in conf_dict["USER_EXPIRATIONS"]: - expiration = datetime.datetime.strptime(conf_dict["USER_EXPIRATIONS"][user], "%d/%m/%Y") + expiration = datetime.datetime.strptime( + conf_dict["USER_EXPIRATIONS"][user], "%d/%m/%Y" + ) conf_dict["USER_EXPIRATIONS"][user] = expiration # the data quota for user @@ -237,13 +260,15 @@ def init_config(): conf_dict.setdefault("STATS_PRINT_PERIOD", 600) # delay in seconds between middle proxy info updates - conf_dict.setdefault("PROXY_INFO_UPDATE_PERIOD", 24*60*60) + conf_dict.setdefault("PROXY_INFO_UPDATE_PERIOD", 24 * 60 * 60) # delay in seconds between time getting, zero means disabled - conf_dict.setdefault("GET_TIME_PERIOD", 10*60) + conf_dict.setdefault("GET_TIME_PERIOD", 10 * 60) # delay in seconds between getting the length of certificate on the mask host - conf_dict.setdefault("GET_CERT_LEN_PERIOD", random.randrange(4*60*60, 6*60*60)) + conf_dict.setdefault( + "GET_CERT_LEN_PERIOD", random.randrange(4 * 60 * 60, 6 * 60 * 60) + ) # max socket buffer size to the client direction, the more the faster, but more RAM hungry # can be the tuple (low, users_margin, high) for the adaptive case. If no much users, use high @@ -253,13 +278,13 @@ def init_config(): conf_dict.setdefault("TO_TG_BUFSIZE", 65536) # keepalive period for clients in secs - conf_dict.setdefault("CLIENT_KEEPALIVE", 10*60) + conf_dict.setdefault("CLIENT_KEEPALIVE", 10 * 60) # drop client after this timeout if the handshake fail conf_dict.setdefault("CLIENT_HANDSHAKE_TIMEOUT", random.randrange(5, 15)) # if client doesn't confirm data for this number of seconds, it is dropped - conf_dict.setdefault("CLIENT_ACK_TIMEOUT", 5*60) + conf_dict.setdefault("CLIENT_ACK_TIMEOUT", 5 * 60) # telegram servers connect timeout in seconds conf_dict.setdefault("TG_CONNECT_TIMEOUT", 10) @@ -299,9 +324,17 @@ def apply_upstream_proxy_settings(): # apply socks settings in place if config.SOCKS5_HOST and config.SOCKS5_PORT: import socks - print_err("Socket-proxy mode activated, it is incompatible with advertising and uvloop") - socks.set_default_proxy(socks.PROXY_TYPE_SOCKS5, config.SOCKS5_HOST, config.SOCKS5_PORT, - username=config.SOCKS5_USER, password=config.SOCKS5_PASS) + + print_err( + "Socket-proxy mode activated, it is incompatible with advertising and uvloop" + ) + socks.set_default_proxy( + socks.PROXY_TYPE_SOCKS5, + config.SOCKS5_HOST, + config.SOCKS5_PORT, + username=config.SOCKS5_USER, + password=config.SOCKS5_PASS, + ) if not hasattr(socket, "origsocket"): socket.origsocket = socket.socket socket.socket = socks.socksocket @@ -315,7 +348,7 @@ def try_use_cryptography_module(): from cryptography.hazmat.backends import default_backend class CryptographyEncryptorAdapter: - __slots__ = ('encryptor', 'decryptor') + __slots__ = ("encryptor", "decryptor") def __init__(self, cipher): self.encryptor = cipher.encryptor() @@ -361,7 +394,7 @@ def use_slow_bundled_cryptography_module(): print(msg, flush=True, file=sys.stderr) class BundledEncryptorAdapter: - __slots__ = ('mode', ) + __slots__ = ("mode",) def __init__(self, mode): self.mode = mode @@ -381,6 +414,7 @@ def create_aes_ctr(key, iv): def create_aes_cbc(key, iv): mode = pyaes.AESModeOfOperationCBC(key, iv) return BundledEncryptorAdapter(mode) + return create_aes_ctr, create_aes_cbc @@ -465,13 +499,13 @@ def __init__(self): def getrandbits(self, k): numbytes = (k + 7) // 8 - return int.from_bytes(self.getrandbytes(numbytes), 'big') >> (numbytes * 8 - k) + return int.from_bytes(self.getrandbytes(numbytes), "big") >> (numbytes * 8 - k) def getrandbytes(self, n): CHUNK_SIZE = 512 while n > len(self.buffer): - data = int.to_bytes(super().getrandbits(CHUNK_SIZE*8), CHUNK_SIZE, "big") + data = int.to_bytes(super().getrandbits(CHUNK_SIZE * 8), CHUNK_SIZE, "big") self.buffer += self.encryptor.encrypt(data) result = self.buffer[:n] @@ -490,46 +524,63 @@ def __init__(self): async def open_tg_connection(self, host, port, init_func=None): task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize()) - reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=config.TG_CONNECT_TIMEOUT) + reader_tgt, writer_tgt = await asyncio.wait_for( + task, timeout=config.TG_CONNECT_TIMEOUT + ) set_keepalive(writer_tgt.get_extra_info("socket")) - set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize()) + set_bufsizes( + writer_tgt.get_extra_info("socket"), + get_to_clt_bufsize(), + get_to_tg_bufsize(), + ) if init_func: - return await asyncio.wait_for(init_func(host, port, reader_tgt, writer_tgt), - timeout=config.TG_CONNECT_TIMEOUT) + return await asyncio.wait_for( + init_func(host, port, reader_tgt, writer_tgt), + timeout=config.TG_CONNECT_TIMEOUT, + ) return reader_tgt, writer_tgt def register_host_port(self, host, port, init_func): if (host, port, init_func) not in self.pools: self.pools[(host, port, init_func)] = [] - while len(self.pools[(host, port, init_func)]) < TgConnectionPool.MAX_CONNS_IN_POOL: - connect_task = asyncio.ensure_future(self.open_tg_connection(host, port, init_func)) + while ( + len(self.pools[(host, port, init_func)]) + < TgConnectionPool.MAX_CONNS_IN_POOL + ): + connect_task = asyncio.ensure_future( + self.open_tg_connection(host, port, init_func) + ) self.pools[(host, port, init_func)].append(connect_task) async def get_connection(self, host, port, init_func=None): - self.register_host_port(host, port, init_func) - - ret = None - for task in self.pools[(host, port, init_func)][::]: - if task.done(): - if task.exception(): - self.pools[(host, port, init_func)].remove(task) - continue - - reader, writer, *other = task.result() - if writer.transport.is_closing(): - self.pools[(host, port, init_func)].remove(task) - continue - - if not ret: - self.pools[(host, port, init_func)].remove(task) - ret = (reader, writer, *other) - - self.register_host_port(host, port, init_func) - if ret: - return ret + # Connection pool disabled to prevent Telegram showing "updating" status + # for users with fewer connections. This improves user experience by + # establishing fresh connections instead of reusing pooled ones. + # Original pooling logic commented out below: + # self.register_host_port(host, port, init_func) + + # ret = None + # for task in self.pools[(host, port, init_func)][::]: + # if task.done(): + # if task.exception(): + # self.pools[(host, port, init_func)].remove(task) + # continue + + # reader, writer, *other = task.result() + # if writer.transport.is_closing(): + # self.pools[(host, port, init_func)].remove(task) + # continue + + # if not ret: + # self.pools[(host, port, init_func)].remove(task) + # ret = (reader, writer, *other) + + # self.register_host_port(host, port, init_func) + # if ret: + # return ret return await self.open_tg_connection(host, port, init_func) @@ -537,7 +588,7 @@ async def get_connection(self, host, port, init_func=None): class LayeredStreamReaderBase: - __slots__ = ("upstream", ) + __slots__ = ("upstream",) def __init__(self, upstream): self.upstream = upstream @@ -550,7 +601,7 @@ async def readexactly(self, n): class LayeredStreamWriterBase: - __slots__ = ("upstream", ) + __slots__ = ("upstream",) def __init__(self, upstream): self.upstream = upstream @@ -579,7 +630,7 @@ def transport(self): class FakeTLSStreamReader(LayeredStreamReaderBase): - __slots__ = ('buf', ) + __slots__ = ("buf",) def __init__(self, upstream): self.upstream = upstream @@ -630,14 +681,14 @@ def __init__(self, upstream): def write(self, data, extra={}): MAX_CHUNK_SIZE = 16384 + 24 for start in range(0, len(data), MAX_CHUNK_SIZE): - end = min(start+MAX_CHUNK_SIZE, len(data)) - self.upstream.write(b"\x17\x03\x03" + int.to_bytes(end-start, 2, "big")) - self.upstream.write(data[start: end]) + end = min(start + MAX_CHUNK_SIZE, len(data)) + self.upstream.write(b"\x17\x03\x03" + int.to_bytes(end - start, 2, "big")) + self.upstream.write(data[start:end]) return len(data) class CryptoWrappedStreamReader(LayeredStreamReaderBase): - __slots__ = ('decryptor', 'block_size', 'buf') + __slots__ = ("decryptor", "block_size", "buf") def __init__(self, upstream, decryptor, block_size=1): self.upstream = upstream @@ -675,7 +726,7 @@ async def readexactly(self, n): class CryptoWrappedStreamWriter(LayeredStreamWriterBase): - __slots__ = ('encryptor', 'block_size') + __slots__ = ("encryptor", "block_size") def __init__(self, upstream, encryptor, block_size=1): self.upstream = upstream @@ -684,15 +735,17 @@ def __init__(self, upstream, encryptor, block_size=1): def write(self, data, extra={}): if len(data) % self.block_size != 0: - print_err("BUG: writing %d bytes not aligned to block size %d" % ( - len(data), self.block_size)) + print_err( + "BUG: writing %d bytes not aligned to block size %d" + % (len(data), self.block_size) + ) return 0 q = self.encryptor.encrypt(data) return self.upstream.write(q) class MTProtoFrameStreamReader(LayeredStreamReaderBase): - __slots__ = ('seq_no', ) + __slots__ = ("seq_no",) def __init__(self, upstream, seq_no=0): self.upstream = upstream @@ -706,7 +759,7 @@ async def read(self, buf_size): msg_len_bytes = await self.upstream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") - len_is_bad = (msg_len % len(PADDING_FILLER) != 0) + len_is_bad = msg_len % len(PADDING_FILLER) != 0 if not MIN_MSG_LEN <= msg_len <= MAX_MSG_LEN or len_is_bad: print_err("msg_len is bad, closing connection", msg_len) return b"" @@ -730,7 +783,7 @@ async def read(self, buf_size): class MTProtoFrameStreamWriter(LayeredStreamWriterBase): - __slots__ = ('seq_no', ) + __slots__ = ("seq_no",) def __init__(self, upstream, seq_no=0): self.upstream = upstream @@ -745,7 +798,9 @@ def write(self, msg, extra={}): checksum = int.to_bytes(binascii.crc32(msg_without_checksum), 4, "little") full_msg = msg_without_checksum + checksum - padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER)) + padding = PADDING_FILLER * ( + (-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER) + ) return self.upstream.write(full_msg + padding) @@ -762,7 +817,7 @@ async def read(self, buf_size): extra["QUICKACK_FLAG"] = True msg_len -= 0x80 - if msg_len == 0x7f: + if msg_len == 0x7F: msg_len_bytes = await self.upstream.readexactly(3) msg_len = int.from_bytes(msg_len_bytes, "little") @@ -777,11 +832,14 @@ class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase): __slots__ = () def write(self, data, extra={}): - SMALL_PKT_BORDER = 0x7f - LARGE_PKT_BORGER = 256 ** 3 + SMALL_PKT_BORDER = 0x7F + LARGE_PKT_BORGER = 256**3 if len(data) % 4 != 0: - print_err("BUG: MTProtoFrameStreamWriter attempted to send msg with len", len(data)) + print_err( + "BUG: MTProtoFrameStreamWriter attempted to send msg with len", + len(data), + ) return 0 if extra.get("SIMPLE_ACK"): @@ -792,7 +850,9 @@ def write(self, data, extra={}): if len_div_four < SMALL_PKT_BORDER: return self.upstream.write(bytes([len_div_four]) + data) elif len_div_four < LARGE_PKT_BORGER: - return self.upstream.write(b'\x7f' + int.to_bytes(len_div_four, 3, 'little') + data) + return self.upstream.write( + b"\x7f" + int.to_bytes(len_div_four, 3, "little") + data + ) else: print_err("Attempted to send too large pkt len =", len(data)) return 0 @@ -821,7 +881,7 @@ def write(self, data, extra={}): if extra.get("SIMPLE_ACK"): return self.upstream.write(data) else: - return self.upstream.write(int.to_bytes(len(data), 4, 'little') + data) + return self.upstream.write(int.to_bytes(len(data), 4, "little") + data) class MTProtoSecureIntermediateFrameStreamReader(LayeredStreamReaderBase): @@ -856,7 +916,7 @@ def write(self, data, extra={}): else: padding_len = myrandom.randrange(MAX_PADDING_LEN) padding = myrandom.getrandbytes(padding_len) - padded_data_len_bytes = int.to_bytes(len(data) + padding_len, 4, 'little') + padded_data_len_bytes = int.to_bytes(len(data) + padding_len, 4, "little") return self.upstream.write(padded_data_len_bytes + data + padding) @@ -867,7 +927,7 @@ async def read(self, msg): RPC_PROXY_ANS = b"\x0d\xda\x03\x44" RPC_CLOSE_EXT = b"\xa2\x34\xb6\x5e" RPC_SIMPLE_ACK = b"\x9b\x40\xac\x3b" - RPC_UNKNOWN = b'\xdf\xa2\x30\x57' + RPC_UNKNOWN = b"\xdf\xa2\x30\x57" data = await self.upstream.read(1) @@ -894,7 +954,7 @@ async def read(self, msg): class ProxyReqStreamWriter(LayeredStreamWriterBase): - __slots__ = ('remote_ip_port', 'our_ip_port', 'out_conn_id', 'proto_tag') + __slots__ = ("remote_ip_port", "our_ip_port", "out_conn_id", "proto_tag") def __init__(self, upstream, cl_ip, cl_port, my_ip, my_port, proto_tag): self.upstream = upstream @@ -978,7 +1038,9 @@ def set_keepalive(sock, interval=40, attempts=5): def set_ack_timeout(sock, timeout): if hasattr(socket, "TCP_USER_TIMEOUT"): - try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout*1000) + try_setsockopt( + sock, socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout * 1000 + ) def set_bufsizes(sock, recv_buf, send_buf): @@ -996,7 +1058,7 @@ def gen_x25519_public_key(): # generates some number which has square root by modulo P P = 2**255 - 19 n = myrandom.randrange(P) - return int.to_bytes((n*n) % P, length=32, byteorder="little") + return int.to_bytes((n * n) % P, length=32, byteorder="little") async def connect_reader_to_writer(reader, writer): @@ -1051,7 +1113,9 @@ async def handle_bad_client(reader_clt, writer_clt, handshake): task_srv_to_clt = asyncio.ensure_future(srv_to_clt) task_clt_to_srv = asyncio.ensure_future(clt_to_srv) - await asyncio.wait([task_srv_to_clt, task_clt_to_srv], return_when=asyncio.FIRST_COMPLETED) + await asyncio.wait( + [task_srv_to_clt, task_clt_to_srv], return_when=asyncio.FIRST_COMPLETED + ) task_srv_to_clt.cancel() task_clt_to_srv.cancel() @@ -1062,7 +1126,7 @@ async def handle_bad_client(reader_clt, writer_clt, handshake): # if the server closed the connection with RST or FIN-RST, copy them to the client if not writer_srv.transport.is_closing(): # workaround for uvloop, it doesn't fire exceptions on write_eof - sock = writer_srv.get_extra_info('socket') + sock = writer_srv.get_extra_info("socket") raw_sock = socket.socket(sock.family, sock.type, sock.proto, sock.fileno()) try: raw_sock.shutdown(socket.SHUT_WR) @@ -1107,23 +1171,27 @@ async def handle_fake_tls_handshake(handshake, reader, writer, peer): tls_extensions = b"\x00\x2e" + b"\x00\x33\x00\x24" + b"\x00\x1d\x00\x20" tls_extensions += gen_x25519_public_key() + b"\x00\x2b\x00\x02\x03\x04" - digest = handshake[DIGEST_POS:DIGEST_POS+DIGEST_LEN] + digest = handshake[DIGEST_POS : DIGEST_POS + DIGEST_LEN] if digest[:DIGEST_HALFLEN] in used_handshakes: last_clients_with_same_handshake[peer[0]] += 1 return False sess_id_len = handshake[SESSION_ID_LEN_POS] - sess_id = handshake[SESSION_ID_POS:SESSION_ID_POS+sess_id_len] + sess_id = handshake[SESSION_ID_POS : SESSION_ID_POS + sess_id_len] for user in config.USERS: secret = bytes.fromhex(config.USERS[user]) - msg = handshake[:DIGEST_POS] + b"\x00"*DIGEST_LEN + handshake[DIGEST_POS+DIGEST_LEN:] + msg = ( + handshake[:DIGEST_POS] + + b"\x00" * DIGEST_LEN + + handshake[DIGEST_POS + DIGEST_LEN :] + ) computed_digest = hmac.new(secret, msg, digestmod=hashlib.sha256).digest() xored_digest = bytes(digest[i] ^ computed_digest[i] for i in range(DIGEST_LEN)) - digest_good = xored_digest.startswith(b"\x00" * (DIGEST_LEN-4)) + digest_good = xored_digest.startswith(b"\x00" * (DIGEST_LEN - 4)) if not digest_good: continue @@ -1132,8 +1200,10 @@ async def handle_fake_tls_handshake(handshake, reader, writer, peer): client_time_is_ok = TIME_SKEW_MIN < time.time() - timestamp < TIME_SKEW_MAX # some clients fail to read unix time and send the time since boot instead - client_time_is_small = timestamp < 60*60*24*1000 - accept_bad_time = config.IGNORE_TIME_SKEW or is_time_skewed or client_time_is_small + client_time_is_small = timestamp < 60 * 60 * 24 * 1000 + accept_bad_time = ( + config.IGNORE_TIME_SKEW or is_time_skewed or client_time_is_small + ) if not client_time_is_ok and not accept_bad_time: last_clients_with_time_skew[peer[0]] = (time.time() - timestamp) // 60 @@ -1141,7 +1211,7 @@ async def handle_fake_tls_handshake(handshake, reader, writer, peer): http_data = myrandom.getrandbytes(fake_cert_len) - srv_hello = TLS_VERS + b"\x00"*DIGEST_LEN + bytes([sess_id_len]) + sess_id + srv_hello = TLS_VERS + b"\x00" * DIGEST_LEN + bytes([sess_id_len]) + sess_id srv_hello += TLS_CIPHERSUITE + b"\x00" + tls_extensions hello_pkt = b"\x16" + TLS_VERS + int.to_bytes(len(srv_hello) + 4, 2, "big") @@ -1149,8 +1219,14 @@ async def handle_fake_tls_handshake(handshake, reader, writer, peer): hello_pkt += TLS_CHANGE_CIPHER + TLS_APP_HTTP2_HDR hello_pkt += int.to_bytes(len(http_data), 2, "big") + http_data - computed_digest = hmac.new(secret, msg=digest+hello_pkt, digestmod=hashlib.sha256).digest() - hello_pkt = hello_pkt[:DIGEST_POS] + computed_digest + hello_pkt[DIGEST_POS+DIGEST_LEN:] + computed_digest = hmac.new( + secret, msg=digest + hello_pkt, digestmod=hashlib.sha256 + ).digest() + hello_pkt = ( + hello_pkt[:DIGEST_POS] + + computed_digest + + hello_pkt[DIGEST_POS + DIGEST_LEN :] + ) writer.write(hello_pkt) await writer.drain() @@ -1194,8 +1270,8 @@ async def handle_proxy_protocol(reader, peer=None): _, proxy_fam, *proxy_addr = header[:-2].split(b" ") if proxy_fam in (PROXY_TCP4, PROXY_TCP6): if len(proxy_addr) == 4: - src_addr = proxy_addr[0].decode('ascii') - src_port = int(proxy_addr[2].decode('ascii')) + src_addr = proxy_addr[0].decode("ascii") + src_port = int(proxy_addr[2].decode("ascii")) return (src_addr, src_port) elif proxy_fam == PROXY_UNKNOWN: return peer @@ -1205,19 +1281,19 @@ async def handle_proxy_protocol(reader, peer=None): if header.startswith(PROXY2_SIGNATURE): # proxy header v2 proxy_ver = header[12] - if proxy_ver & 0xf0 != 0x20: + if proxy_ver & 0xF0 != 0x20: return False proxy_len = int.from_bytes(header[14:16], "big") proxy_addr = await reader.readexactly(proxy_len) if proxy_ver == 0x21: proxy_fam = header[13] >> 4 if proxy_fam == PROXY2_AF_INET: - if proxy_len >= (4 + 2)*2: + if proxy_len >= (4 + 2) * 2: src_addr = socket.inet_ntop(socket.AF_INET, proxy_addr[:4]) src_port = int.from_bytes(proxy_addr[8:10], "big") return (src_addr, src_port) elif proxy_fam == PROXY2_AF_INET6: - if proxy_len >= (16 + 2)*2: + if proxy_len >= (16 + 2) * 2: src_addr = socket.inet_ntop(socket.AF_INET6, proxy_addr[:16]) src_port = int.from_bytes(proxy_addr[32:34], "big") return (src_addr, src_port) @@ -1268,7 +1344,9 @@ async def handle_handshake(reader, writer): if is_tls_handshake: handshake += await reader.readexactly(tls_handshake_len) - tls_handshake_result = await handle_fake_tls_handshake(handshake, reader, writer, peer) + tls_handshake_result = await handle_fake_tls_handshake( + handshake, reader, writer, peer + ) if not tls_handshake_result: await handle_bad_client(reader, writer, handshake) @@ -1281,9 +1359,9 @@ async def handle_handshake(reader, writer): return False handshake += await reader.readexactly(HANDSHAKE_LEN - len(handshake)) - dec_prekey_and_iv = handshake[SKIP_LEN:SKIP_LEN+PREKEY_LEN+IV_LEN] + dec_prekey_and_iv = handshake[SKIP_LEN : SKIP_LEN + PREKEY_LEN + IV_LEN] dec_prekey, dec_iv = dec_prekey_and_iv[:PREKEY_LEN], dec_prekey_and_iv[PREKEY_LEN:] - enc_prekey_and_iv = handshake[SKIP_LEN:SKIP_LEN+PREKEY_LEN+IV_LEN][::-1] + enc_prekey_and_iv = handshake[SKIP_LEN : SKIP_LEN + PREKEY_LEN + IV_LEN][::-1] enc_prekey, enc_iv = enc_prekey_and_iv[:PREKEY_LEN], enc_prekey_and_iv[PREKEY_LEN:] if dec_prekey_and_iv in used_handshakes: @@ -1302,8 +1380,12 @@ async def handle_handshake(reader, writer): decrypted = decryptor.decrypt(handshake) - proto_tag = decrypted[PROTO_TAG_POS:PROTO_TAG_POS+4] - if proto_tag not in (PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE): + proto_tag = decrypted[PROTO_TAG_POS : PROTO_TAG_POS + 4] + if proto_tag not in ( + PROTO_TAG_ABRIDGED, + PROTO_TAG_INTERMEDIATE, + PROTO_TAG_SECURE, + ): continue if proto_tag == PROTO_TAG_SECURE: @@ -1315,7 +1397,9 @@ async def handle_handshake(reader, writer): if not config.MODES["classic"]: continue - dc_idx = int.from_bytes(decrypted[DC_IDX_POS:DC_IDX_POS+2], "little", signed=True) + dc_idx = int.from_bytes( + decrypted[DC_IDX_POS : DC_IDX_POS + 2], "little", signed=True + ) if config.REPLAY_CHECK_LEN > 0: while len(used_handshakes) >= config.REPLAY_CHECK_LEN: @@ -1339,9 +1423,14 @@ async def handle_handshake(reader, writer): async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): RESERVED_NONCE_FIRST_CHARS = [b"\xef"] - RESERVED_NONCE_BEGININGS = [b"\x48\x45\x41\x44", b"\x50\x4F\x53\x54", - b"\x47\x45\x54\x20", b"\xee\xee\xee\xee", - b"\xdd\xdd\xdd\xdd", b"\x16\x03\x01\x02"] + RESERVED_NONCE_BEGININGS = [ + b"\x48\x45\x41\x44", + b"\x50\x4f\x53\x54", + b"\x47\x45\x54\x20", + b"\xee\xee\xee\xee", + b"\xdd\xdd\xdd\xdd", + b"\x16\x03\x01\x02", + ] RESERVED_NONCE_CONTINUES = [b"\x00\x00\x00\x00"] global my_ip_info @@ -1359,12 +1448,19 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): dc = TG_DATACENTERS_V4[dc_idx] try: - reader_tgt, writer_tgt = await tg_connection_pool.get_connection(dc, TG_DATACENTER_PORT) + reader_tgt, writer_tgt = await tg_connection_pool.get_connection( + dc, TG_DATACENTER_PORT + ) except ConnectionRefusedError as E: - print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) + print_err( + "Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT + ) return False except ConnectionAbortedError as E: - print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E)) + print_err( + "The Telegram server connection is bad: %d (%s %s) %s" + % (dc_idx, addr, port, E) + ) return False except (OSError, asyncio.TimeoutError) as E: print_err("Unable to connect to", dc, TG_DATACENTER_PORT) @@ -1380,18 +1476,18 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): continue break - rnd[PROTO_TAG_POS:PROTO_TAG_POS+4] = proto_tag + rnd[PROTO_TAG_POS : PROTO_TAG_POS + 4] = proto_tag if dec_key_and_iv: - rnd[SKIP_LEN:SKIP_LEN+KEY_LEN+IV_LEN] = dec_key_and_iv[::-1] + rnd[SKIP_LEN : SKIP_LEN + KEY_LEN + IV_LEN] = dec_key_and_iv[::-1] rnd = bytes(rnd) - dec_key_and_iv = rnd[SKIP_LEN:SKIP_LEN+KEY_LEN+IV_LEN][::-1] + dec_key_and_iv = rnd[SKIP_LEN : SKIP_LEN + KEY_LEN + IV_LEN][::-1] dec_key, dec_iv = dec_key_and_iv[:KEY_LEN], dec_key_and_iv[KEY_LEN:] decryptor = create_aes_ctr(key=dec_key, iv=int.from_bytes(dec_iv, "big")) - enc_key_and_iv = rnd[SKIP_LEN:SKIP_LEN+KEY_LEN+IV_LEN] + enc_key_and_iv = rnd[SKIP_LEN : SKIP_LEN + KEY_LEN + IV_LEN] enc_key, enc_iv = enc_key_and_iv[:KEY_LEN], enc_key_and_iv[KEY_LEN:] encryptor = create_aes_ctr(key=enc_key, iv=int.from_bytes(enc_iv, "big")) @@ -1406,9 +1502,19 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): return reader_tgt, writer_tgt -def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_port, purpose, - clt_ip, srv_port, middleproxy_secret, clt_ipv6=None, - srv_ipv6=None): +def get_middleproxy_aes_key_and_iv( + nonce_srv, + nonce_clt, + clt_ts, + srv_ip, + clt_port, + purpose, + clt_ip, + srv_port, + middleproxy_secret, + clt_ipv6=None, + srv_ipv6=None, +): EMPTY_IP = b"\x00\x00\x00\x00" if not clt_ip or not srv_ip: @@ -1416,7 +1522,9 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por srv_ip = EMPTY_IP s = bytearray() - s += nonce_srv + nonce_clt + clt_ts + srv_ip + clt_port + purpose + clt_ip + srv_port + s += ( + nonce_srv + nonce_clt + clt_ts + srv_ip + clt_port + purpose + clt_ip + srv_port + ) s += middleproxy_secret + nonce_srv if clt_ipv6 and srv_ipv6: @@ -1433,7 +1541,7 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por async def middleproxy_handshake(host, port, reader_tgt, writer_tgt): - """ The most logic of middleproxy handshake, launched in pool """ + """The most logic of middleproxy handshake, launched in pool""" START_SEQ_NO = -2 NONCE_LEN = 16 @@ -1464,17 +1572,25 @@ async def middleproxy_handshake(host, port, reader_tgt, writer_tgt): raise ConnectionAbortedError("bad rpc answer length") rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = ( - ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32] + ans[:4], + ans[4:8], + ans[8:12], + ans[12:16], + ans[16:32], ) - if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES: + if ( + rpc_type != RPC_NONCE + or rpc_key_selector != key_selector + or rpc_schema != CRYPTO_AES + ): raise ConnectionAbortedError("bad rpc answer") # get keys - tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')[:2] - my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')[:2] + tg_ip, tg_port = writer_tgt.upstream.get_extra_info("peername")[:2] + my_ip, my_port = writer_tgt.upstream.get_extra_info("sockname")[:2] - use_ipv6_tg = (":" in tg_ip) + use_ipv6_tg = ":" in tg_ip if not use_ipv6_tg: if my_ip_info["ipv4"]: @@ -1500,14 +1616,32 @@ async def middleproxy_handshake(host, port, reader_tgt, writer_tgt): my_port_bytes = int.to_bytes(my_port, 2, "little") enc_key, enc_iv = get_middleproxy_aes_key_and_iv( - nonce_srv=rpc_nonce, nonce_clt=nonce, clt_ts=crypto_ts, srv_ip=tg_ip_bytes, - clt_port=my_port_bytes, purpose=b"CLIENT", clt_ip=my_ip_bytes, srv_port=tg_port_bytes, - middleproxy_secret=PROXY_SECRET, clt_ipv6=my_ipv6_bytes, srv_ipv6=tg_ipv6_bytes) + nonce_srv=rpc_nonce, + nonce_clt=nonce, + clt_ts=crypto_ts, + srv_ip=tg_ip_bytes, + clt_port=my_port_bytes, + purpose=b"CLIENT", + clt_ip=my_ip_bytes, + srv_port=tg_port_bytes, + middleproxy_secret=PROXY_SECRET, + clt_ipv6=my_ipv6_bytes, + srv_ipv6=tg_ipv6_bytes, + ) dec_key, dec_iv = get_middleproxy_aes_key_and_iv( - nonce_srv=rpc_nonce, nonce_clt=nonce, clt_ts=crypto_ts, srv_ip=tg_ip_bytes, - clt_port=my_port_bytes, purpose=b"SERVER", clt_ip=my_ip_bytes, srv_port=tg_port_bytes, - middleproxy_secret=PROXY_SECRET, clt_ipv6=my_ipv6_bytes, srv_ipv6=tg_ipv6_bytes) + nonce_srv=rpc_nonce, + nonce_clt=nonce, + clt_ts=crypto_ts, + srv_ip=tg_ip_bytes, + clt_port=my_port_bytes, + purpose=b"SERVER", + clt_ip=my_ip_bytes, + srv_port=tg_port_bytes, + middleproxy_secret=PROXY_SECRET, + clt_ipv6=my_ipv6_bytes, + srv_ipv6=tg_ipv6_bytes, + ) encryptor = create_aes_cbc(key=enc_key, iv=enc_iv) decryptor = create_aes_cbc(key=dec_key, iv=dec_iv) @@ -1518,18 +1652,26 @@ async def middleproxy_handshake(host, port, reader_tgt, writer_tgt): # TODO: pass client ip and port here for statistics handshake = RPC_HANDSHAKE + RPC_FLAGS + SENDER_PID + PEER_PID - writer_tgt.upstream = CryptoWrappedStreamWriter(writer_tgt.upstream, encryptor, block_size=16) + writer_tgt.upstream = CryptoWrappedStreamWriter( + writer_tgt.upstream, encryptor, block_size=16 + ) writer_tgt.write(handshake) await writer_tgt.drain() - reader_tgt.upstream = CryptoWrappedStreamReader(reader_tgt.upstream, decryptor, block_size=16) + reader_tgt.upstream = CryptoWrappedStreamReader( + reader_tgt.upstream, decryptor, block_size=16 + ) handshake_ans = await reader_tgt.read(1) if len(handshake_ans) != RPC_HANDSHAKE_ANS_LEN: raise ConnectionAbortedError("bad rpc handshake answer length") handshake_type, handshake_flags, handshake_sender_pid, handshake_peer_pid = ( - handshake_ans[:4], handshake_ans[4:8], handshake_ans[8:20], handshake_ans[20:32]) + handshake_ans[:4], + handshake_ans[4:8], + handshake_ans[8:20], + handshake_ans[20:32], + ) if handshake_type != RPC_HANDSHAKE or handshake_peer_pid != SENDER_PID: raise ConnectionAbortedError("bad rpc handshake answer") @@ -1540,7 +1682,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): global my_ip_info global tg_connection_pool - use_ipv6_tg = (my_ip_info["ipv6"] and (config.PREFER_IPV6 or not my_ip_info["ipv4"])) + use_ipv6_tg = my_ip_info["ipv6"] and (config.PREFER_IPV6 or not my_ip_info["ipv4"]) if use_ipv6_tg: if dc_idx not in TG_MIDDLE_PROXIES_V6: @@ -1555,16 +1697,26 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): ret = await tg_connection_pool.get_connection(addr, port, middleproxy_handshake) reader_tgt, writer_tgt, my_ip, my_port = ret except ConnectionRefusedError as E: - print_err("The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port)) + print_err( + "The Telegram server %d (%s %s) is refusing connections" + % (dc_idx, addr, port) + ) return False except ConnectionAbortedError as E: - print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E)) + print_err( + "The Telegram server connection is bad: %d (%s %s) %s" + % (dc_idx, addr, port, E) + ) return False except (OSError, asyncio.TimeoutError) as E: - print_err("Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port)) + print_err( + "Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port) + ) return False - writer_tgt = ProxyReqStreamWriter(writer_tgt, cl_ip, cl_port, my_ip, my_port, proto_tag) + writer_tgt = ProxyReqStreamWriter( + writer_tgt, cl_ip, cl_port, my_ip, my_port, proto_tag + ) reader_tgt = ProxyReqStreamReader(reader_tgt) return reader_tgt, writer_tgt @@ -1588,9 +1740,13 @@ async def tg_connect_reader_to_writer(rd, wr, user, rd_buf_size, is_upstream): return else: if is_upstream: - update_user_stats(user, octets_from_client=len(data), msgs_from_client=1) + update_user_stats( + user, octets_from_client=len(data), msgs_from_client=1 + ) else: - update_user_stats(user, octets_to_client=len(data), msgs_to_client=1) + update_user_stats( + user, octets_to_client=len(data), msgs_to_client=1 + ) wr.write(data, extra) await wr.drain() @@ -1600,15 +1756,21 @@ async def tg_connect_reader_to_writer(rd, wr, user, rd_buf_size, is_upstream): async def handle_client(reader_clt, writer_clt): - set_keepalive(writer_clt.get_extra_info("socket"), config.CLIENT_KEEPALIVE, attempts=3) + set_keepalive( + writer_clt.get_extra_info("socket"), config.CLIENT_KEEPALIVE, attempts=3 + ) set_ack_timeout(writer_clt.get_extra_info("socket"), config.CLIENT_ACK_TIMEOUT) - set_bufsizes(writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize()) + set_bufsizes( + writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize() + ) update_stats(connects_all=1) try: - clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt), - timeout=config.CLIENT_HANDSHAKE_TIMEOUT) + clt_data = await asyncio.wait_for( + handle_handshake(reader_clt, writer_clt), + timeout=config.CLIENT_HANDSHAKE_TIMEOUT, + ) except asyncio.TimeoutError: update_stats(handshake_timeouts=1) return @@ -1621,11 +1783,13 @@ async def handle_client(reader_clt, writer_clt): update_user_stats(user, connects=1) - connect_directly = (not config.USE_MIDDLE_PROXY or disable_middle_proxy) + connect_directly = not config.USE_MIDDLE_PROXY or disable_middle_proxy if connect_directly: if config.FAST_MODE: - tg_data = await do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=enc_key_and_iv) + tg_data = await do_direct_handshake( + proto_tag, dc_idx, dec_key_and_iv=enc_key_and_iv + ) else: tg_data = await do_direct_handshake(proto_tag, dc_idx) else: @@ -1637,6 +1801,7 @@ async def handle_client(reader_clt, writer_clt): reader_tg, writer_tg = tg_data if connect_directly and config.FAST_MODE: + class FakeEncryptor: def encrypt(self, data): return data @@ -1661,34 +1826,37 @@ def decrypt(self, data): else: return - tg_to_clt = tg_connect_reader_to_writer(reader_tg, writer_clt, user, - get_to_clt_bufsize(), False) - clt_to_tg = tg_connect_reader_to_writer(reader_clt, writer_tg, - user, get_to_tg_bufsize(), True) + tg_to_clt = tg_connect_reader_to_writer( + reader_tg, writer_clt, user, get_to_clt_bufsize(), False + ) + clt_to_tg = tg_connect_reader_to_writer( + reader_clt, writer_tg, user, get_to_tg_bufsize(), True + ) task_tg_to_clt = asyncio.ensure_future(tg_to_clt) task_clt_to_tg = asyncio.ensure_future(clt_to_tg) update_user_stats(user, curr_connects=1) tcp_limit_hit = ( - user in config.USER_MAX_TCP_CONNS and - user_stats[user]["curr_connects"] > config.USER_MAX_TCP_CONNS[user] + user in config.USER_MAX_TCP_CONNS + and user_stats[user]["curr_connects"] > config.USER_MAX_TCP_CONNS[user] ) user_expired = ( - user in config.USER_EXPIRATIONS and - datetime.datetime.now() > config.USER_EXPIRATIONS[user] + user in config.USER_EXPIRATIONS + and datetime.datetime.now() > config.USER_EXPIRATIONS[user] ) - user_data_quota_hit = ( - user in config.USER_DATA_QUOTA and - (user_stats[user]["octets_to_client"] + - user_stats[user]["octets_from_client"] > config.USER_DATA_QUOTA[user]) + user_data_quota_hit = user in config.USER_DATA_QUOTA and ( + user_stats[user]["octets_to_client"] + user_stats[user]["octets_from_client"] + > config.USER_DATA_QUOTA[user] ) if (not tcp_limit_hit) and (not user_expired) and (not user_data_quota_hit): start = time.time() - await asyncio.wait([task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED) + await asyncio.wait( + [task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED + ) update_durations(time.time() - start) update_user_stats(user, curr_connects=-1) @@ -1728,7 +1896,7 @@ def make_metrics_pkt(metrics): for tag, tag_val in val.items(): if tag == "val": continue - tag_val = tag_val.replace('"', r'\"') + tag_val = tag_val.replace('"', r"\"") tags.append('%s="%s"' % (tag, tag_val)) pkt_body_list.append("%s{%s} %s" % (name, ",".join(tags), val["val"])) else: @@ -1740,7 +1908,9 @@ def make_metrics_pkt(metrics): pkt_header_list.append("Connection: close") pkt_header_list.append("Content-Length: %d" % len(pkt_body)) pkt_header_list.append("Content-Type: text/plain; version=0.0.4; charset=utf-8") - pkt_header_list.append("Date: %s" % time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime())) + pkt_header_list.append( + "Date: %s" % time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) + ) pkt_header = "\r\n".join(pkt_header_list) @@ -1764,38 +1934,75 @@ async def handle_metrics(reader, writer): try: metrics = [] - metrics.append(["uptime", "counter", "proxy uptime", time.time() - proxy_start_time]) - metrics.append(["connects_bad", "counter", "connects with bad secret", - stats["connects_bad"]]) - metrics.append(["connects_all", "counter", "incoming connects", stats["connects_all"]]) - metrics.append(["handshake_timeouts", "counter", "number of timed out handshakes", - stats["handshake_timeouts"]]) + metrics.append( + ["uptime", "counter", "proxy uptime", time.time() - proxy_start_time] + ) + metrics.append( + [ + "connects_bad", + "counter", + "connects with bad secret", + stats["connects_bad"], + ] + ) + metrics.append( + ["connects_all", "counter", "incoming connects", stats["connects_all"]] + ) + metrics.append( + [ + "handshake_timeouts", + "counter", + "number of timed out handshakes", + stats["handshake_timeouts"], + ] + ) if config.METRICS_EXPORT_LINKS: for link in proxy_links: link_as_metric = link.copy() link_as_metric["val"] = 1 - metrics.append(["proxy_link_info", "counter", - "the proxy link info", link_as_metric]) + metrics.append( + [ + "proxy_link_info", + "counter", + "the proxy link info", + link_as_metric, + ] + ) bucket_start = 0 for bucket in STAT_DURATION_BUCKETS: bucket_end = bucket if bucket != STAT_DURATION_BUCKETS[-1] else "+Inf" metric = { "bucket": "%s-%s" % (bucket_start, bucket_end), - "val": stats["connects_with_duration_le_%s" % str(bucket)] + "val": stats["connects_with_duration_le_%s" % str(bucket)], } - metrics.append(["connects_by_duration", "counter", "connects by duration", metric]) + metrics.append( + ["connects_by_duration", "counter", "connects by duration", metric] + ) bucket_start = bucket_end user_metrics_desc = [ ["user_connects", "counter", "user connects", "connects"], ["user_connects_curr", "gauge", "current user connects", "curr_connects"], - ["user_octets", "counter", "octets proxied for user", - "octets_from_client+octets_to_client"], - ["user_msgs", "counter", "msgs proxied for user", - "msgs_from_client+msgs_to_client"], - ["user_octets_from", "counter", "octets proxied from user", "octets_from_client"], + [ + "user_octets", + "counter", + "octets proxied for user", + "octets_from_client+octets_to_client", + ], + [ + "user_msgs", + "counter", + "msgs proxied for user", + "msgs_from_client+msgs_to_client", + ], + [ + "user_octets_from", + "counter", + "octets proxied from user", + "octets_from_client", + ], ["user_octets_to", "counter", "octets proxied to user", "octets_to_client"], ["user_msgs_from", "counter", "msgs proxied from user", "msgs_from_client"], ["user_msgs_to", "counter", "msgs proxied to user", "msgs_to_client"], @@ -1833,10 +2040,16 @@ async def stats_printer(): print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S")) for user, stat in user_stats.items(): - print("%s: %d connects (%d current), %.2f MB, %d msgs" % ( - user, stat["connects"], stat["curr_connects"], - (stat["octets_from_client"] + stat["octets_to_client"]) / 1000000, - stat["msgs_from_client"] + stat["msgs_to_client"])) + print( + "%s: %d connects (%d current), %.2f MB, %d msgs" + % ( + user, + stat["connects"], + stat["curr_connects"], + (stat["octets_from_client"] + stat["octets_to_client"]) / 1000000, + stat["msgs_from_client"] + stat["msgs_to_client"], + ) + ) print(flush=True) if last_client_ips: @@ -1861,12 +2074,13 @@ async def stats_printer(): async def make_https_req(url, host="core.telegram.org"): - """ Make request, return resp body and headers. """ + """Make request, return resp body and headers.""" SSL_PORT = 443 url_data = urllib.parse.urlparse(url) - HTTP_REQ_TEMPLATE = "\r\n".join(["GET %s HTTP/1.1", "Host: %s", - "Connection: close"]) + "\r\n\r\n" + HTTP_REQ_TEMPLATE = ( + "\r\n".join(["GET %s HTTP/1.1", "Host: %s", "Connection: close"]) + "\r\n\r\n" + ) reader, writer = await asyncio.open_connection(url_data.netloc, SSL_PORT, ssl=True) req = HTTP_REQ_TEMPLATE % (urllib.parse.quote(url_data.path), host) writer.write(req.encode("utf8")) @@ -1934,8 +2148,10 @@ async def get_tls_record(reader): record4_type, record4 = await get_tls_record(reader) if record4_type != 23: return b"" - msg = ("The MASK_HOST %s sent some TLS record before certificate record, this makes the " + - "proxy more detectable") % config.MASK_HOST + msg = ( + "The MASK_HOST %s sent some TLS record before certificate record, this makes the " + + "proxy more detectable" + ) % config.MASK_HOST print_err(msg) return record4 @@ -1956,29 +2172,40 @@ async def get_mask_host_cert_len(): await asyncio.sleep(MASK_ENABLING_CHECK_PERIOD) continue - task = get_encrypted_cert(config.MASK_HOST, config.MASK_PORT, config.TLS_DOMAIN) + task = get_encrypted_cert( + config.MASK_HOST, config.MASK_PORT, config.TLS_DOMAIN + ) cert = await asyncio.wait_for(task, timeout=GET_CERT_TIMEOUT) if cert: if len(cert) < MIN_CERT_LEN: - msg = ("The MASK_HOST %s returned several TLS records, this is not supported" % - config.MASK_HOST) + msg = ( + "The MASK_HOST %s returned several TLS records, this is not supported" + % config.MASK_HOST + ) print_err(msg) elif len(cert) != fake_cert_len: fake_cert_len = len(cert) - print_err("Got cert from the MASK_HOST %s, its length is %d" % - (config.MASK_HOST, fake_cert_len)) + print_err( + "Got cert from the MASK_HOST %s, its length is %d" + % (config.MASK_HOST, fake_cert_len) + ) else: - print_err("The MASK_HOST %s is not TLS 1.3 host, this is not recommended" % - config.MASK_HOST) + print_err( + "The MASK_HOST %s is not TLS 1.3 host, this is not recommended" + % config.MASK_HOST + ) except ConnectionRefusedError: - print_err("The MASK_HOST %s is refusing connections, this is not recommended" % - config.MASK_HOST) + print_err( + "The MASK_HOST %s is refusing connections, this is not recommended" + % config.MASK_HOST + ) except (TimeoutError, asyncio.TimeoutError): - print_err("Got timeout while getting TLS handshake from MASK_HOST %s" % - config.MASK_HOST) + print_err( + "Got timeout while getting TLS handshake from MASK_HOST %s" + % config.MASK_HOST + ) except Exception as E: - print_err("Failed to connect to MASK_HOST %s: %s" % ( - config.MASK_HOST, E)) + print_err("Failed to connect to MASK_HOST %s: %s" % (config.MASK_HOST, E)) await asyncio.sleep(config.GET_CERT_LEN_PERIOD) @@ -1998,11 +2225,15 @@ async def get_srv_time(): for line in headers.split(b"\r\n"): if not line.startswith(b"Date: "): continue - line = line[len("Date: "):].decode() + line = line[len("Date: ") :].decode() srv_time = datetime.datetime.strptime(line, "%a, %d %b %Y %H:%M:%S %Z") now_time = datetime.datetime.utcnow() - is_time_skewed = (now_time-srv_time).total_seconds() > MAX_TIME_SKEW - if is_time_skewed and config.USE_MIDDLE_PROXY and not disable_middle_proxy: + is_time_skewed = (now_time - srv_time).total_seconds() > MAX_TIME_SKEW + if ( + is_time_skewed + and config.USE_MIDDLE_PROXY + and not disable_middle_proxy + ): print_err("Time skew detected, please set the clock") print_err("Server time:", srv_time, "your time:", now_time) print_err("Disabling advertising to continue serving") @@ -2130,7 +2361,9 @@ def print_tg_info(): if config.PORT == 3256: print("The default port 3256 is used, this is not recommended", flush=True) if not config.MODES["classic"] and not config.MODES["secure"]: - print("Since you have TLS only mode enabled the best port is 443", flush=True) + print( + "Since you have TLS only mode enabled the best port is 443", flush=True + ) print_default_warning = True if not config.MY_DOMAIN: @@ -2146,14 +2379,14 @@ def print_tg_info(): for ip in ip_addrs: if config.MODES["classic"]: params = {"server": ip, "port": config.PORT, "secret": secret} - params_encodeded = urllib.parse.urlencode(params, safe=':') + params_encodeded = urllib.parse.urlencode(params, safe=":") classic_link = "tg://proxy?{}".format(params_encodeded) proxy_links.append({"user": user, "link": classic_link}) print("{}: {}".format(user, classic_link), flush=True) if config.MODES["secure"]: params = {"server": ip, "port": config.PORT, "secret": "dd" + secret} - params_encodeded = urllib.parse.urlencode(params, safe=':') + params_encodeded = urllib.parse.urlencode(params, safe=":") dd_link = "tg://proxy?{}".format(params_encodeded) proxy_links.append({"user": user, "link": dd_link}) print("{}: {}".format(user, dd_link), flush=True) @@ -2164,21 +2397,31 @@ def print_tg_info(): # tls_secret = bytes.fromhex("ee" + secret) + config.TLS_DOMAIN.encode() # tls_secret_base64 = base64.urlsafe_b64encode(tls_secret) params = {"server": ip, "port": config.PORT, "secret": tls_secret} - params_encodeded = urllib.parse.urlencode(params, safe=':') + params_encodeded = urllib.parse.urlencode(params, safe=":") tls_link = "tg://proxy?{}".format(params_encodeded) proxy_links.append({"user": user, "link": tls_link}) print("{}: {}".format(user, tls_link), flush=True) - if secret in ["00000000000000000000000000000000", "0123456789abcdef0123456789abcdef", - "00000000000000000000000000000001"]: - msg = "The default secret {} is used, this is not recommended".format(secret) + if secret in [ + "00000000000000000000000000000000", + "0123456789abcdef0123456789abcdef", + "00000000000000000000000000000001", + ]: + msg = "The default secret {} is used, this is not recommended".format( + secret + ) print(msg, flush=True) - random_secret = "".join(myrandom.choice("0123456789abcdef") for i in range(32)) + random_secret = "".join( + myrandom.choice("0123456789abcdef") for i in range(32) + ) print("You can change it to this random secret:", random_secret, flush=True) print_default_warning = True if config.TLS_DOMAIN == "www.google.com": - print("The default TLS_DOMAIN www.google.com is used, this is not recommended", flush=True) + print( + "The default TLS_DOMAIN www.google.com is used, this is not recommended", + flush=True, + ) msg = "You should use random existing domain instead, bad clients are proxied there" print(msg, flush=True) print_default_warning = True @@ -2190,10 +2433,13 @@ def print_tg_info(): def setup_files_limit(): try: import resource + soft_fd_limit, hard_fd_limit = resource.getrlimit(resource.RLIMIT_NOFILE) resource.setrlimit(resource.RLIMIT_NOFILE, (hard_fd_limit, hard_fd_limit)) except (ValueError, OSError): - print("Failed to increase the limit of opened files", flush=True, file=sys.stderr) + print( + "Failed to increase the limit of opened files", flush=True, file=sys.stderr + ) except ImportError: pass @@ -2204,14 +2450,17 @@ def setup_asyncio(): def setup_signals(): - if hasattr(signal, 'SIGUSR1'): + if hasattr(signal, "SIGUSR1"): + def debug_signal(signum, frame): import pdb + pdb.set_trace() signal.signal(signal.SIGUSR1, debug_signal) - if hasattr(signal, 'SIGUSR2'): + if hasattr(signal, "SIGUSR2"): + def reload_signal(signum, frame): init_config() ensure_users_in_user_stats() @@ -2228,6 +2477,7 @@ def try_setup_uvloop(): return try: import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) print_err("Found uvloop, using it for optimal performance") except ImportError: @@ -2253,12 +2503,11 @@ def loop_exception_handler(loop, context): if isinstance(exception, OSError): IGNORE_ERRNO = { 10038, # operation on non-socket on Windows, likely because fd == -1 - 121, # the semaphore timeout period has expired on Windows + 121, # the semaphore timeout period has expired on Windows } FORCE_CLOSE_ERRNO = { - 113, # no route to host - + 113, # no route to host } if exception.errno in IGNORE_ERRNO: return @@ -2277,30 +2526,43 @@ def create_servers(loop): has_unix = hasattr(socket, "AF_UNIX") if config.LISTEN_ADDR_IPV4: - task = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV4, config.PORT, - limit=get_to_tg_bufsize(), reuse_port=reuse_port) + task = asyncio.start_server( + handle_client_wrapper, + config.LISTEN_ADDR_IPV4, + config.PORT, + limit=get_to_tg_bufsize(), + reuse_port=reuse_port, + ) servers.append(loop.run_until_complete(task)) if config.LISTEN_ADDR_IPV6 and socket.has_ipv6: - task = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV6, config.PORT, - limit=get_to_tg_bufsize(), reuse_port=reuse_port) + task = asyncio.start_server( + handle_client_wrapper, + config.LISTEN_ADDR_IPV6, + config.PORT, + limit=get_to_tg_bufsize(), + reuse_port=reuse_port, + ) servers.append(loop.run_until_complete(task)) if config.LISTEN_UNIX_SOCK and has_unix: remove_unix_socket(config.LISTEN_UNIX_SOCK) - task = asyncio.start_unix_server(handle_client_wrapper, config.LISTEN_UNIX_SOCK, - limit=get_to_tg_bufsize()) + task = asyncio.start_unix_server( + handle_client_wrapper, config.LISTEN_UNIX_SOCK, limit=get_to_tg_bufsize() + ) servers.append(loop.run_until_complete(task)) os.chmod(config.LISTEN_UNIX_SOCK, 0o666) if config.METRICS_PORT is not None: if config.METRICS_LISTEN_ADDR_IPV4: - task = asyncio.start_server(handle_metrics, config.METRICS_LISTEN_ADDR_IPV4, - config.METRICS_PORT) + task = asyncio.start_server( + handle_metrics, config.METRICS_LISTEN_ADDR_IPV4, config.METRICS_PORT + ) servers.append(loop.run_until_complete(task)) if config.METRICS_LISTEN_ADDR_IPV6 and socket.has_ipv6: - task = asyncio.start_server(handle_metrics, config.METRICS_LISTEN_ADDR_IPV6, - config.METRICS_PORT) + task = asyncio.start_server( + handle_metrics, config.METRICS_LISTEN_ADDR_IPV6, config.METRICS_PORT + ) servers.append(loop.run_until_complete(task)) return servers