diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2c6201e..14d4da5 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -28,6 +28,8 @@ jobs: sudo update-alternatives --set llvm-strip /usr/lib/llvm-21/bin/llvm-strip - name: Install musl target run: rustup target add x86_64-unknown-linux-musl + - name: Run tests + run: cargo test --verbose - name: Build run: | cargo build --release --target x86_64-unknown-linux-musl @@ -42,4 +44,4 @@ jobs: with: archive: false name: minecraft_filter.o - path: ./c/minecraft_filter.o + path: ./xdp/minecraft_filter.o diff --git a/.gitignore b/.gitignore index e22daa4..35739ec 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,9 @@ target/ *.tmp *.swp +# Generated at runtime on first launch +config.toml + # System files .DS_Store Thumbs.db diff --git a/Cargo.lock b/Cargo.lock index 56239f3..b573a5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,56 +23,6 @@ dependencies = [ "libc", ] -[[package]] -name = "anstream" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" -dependencies = [ - "anstyle", - "anstyle-parse", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "is_terminal_polyfill", - "utf8parse", -] - -[[package]] -name = "anstyle" -version = "1.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" - -[[package]] -name = "anstyle-parse" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" -dependencies = [ - "utf8parse", -] - -[[package]] -name = "anstyle-query" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" -dependencies = [ - "windows-sys 
0.61.2", -] - -[[package]] -name = "anstyle-wincon" -version = "3.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" -dependencies = [ - "anstyle", - "once_cell_polyfill", - "windows-sys 0.61.2", -] - [[package]] name = "anyhow" version = "1.0.102" @@ -181,52 +131,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" -[[package]] -name = "clap" -version = "4.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" -dependencies = [ - "clap_builder", - "clap_derive", -] - -[[package]] -name = "clap_builder" -version = "4.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" -dependencies = [ - "anstream", - "anstyle", - "clap_lex", - "strsim", -] - -[[package]] -name = "clap_derive" -version = "4.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2ce8604710f6733aa641a2b3731eaa1e8b3d9973d5e3565da11800813f997a9" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "clap_lex" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" - -[[package]] -name = "colorchoice" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" - [[package]] name = "colored" version = "2.2.0" @@ -337,15 +241,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" - -[[package]] -name = "heck" -version = "0.5.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "httpdate" @@ -379,20 +277,14 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.1" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.16.1", + "hashbrown 0.17.1", ] -[[package]] -name = "is_terminal_polyfill" -version = "1.70.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" - [[package]] name = "js-sys" version = "0.3.83" @@ -473,12 +365,6 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" -[[package]] -name = "once_cell_polyfill" -version = "1.70.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" - [[package]] name = "parking_lot" version = "0.12.5" @@ -522,30 +408,9 @@ dependencies = [ "lazy_static", "memchr", "parking_lot", - "protobuf", "thiserror 2.0.17", ] -[[package]] -name = "protobuf" -version = "3.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d65a1d4ddae7d8b5de68153b48f6aa3bba8cb002b243dbdbc55a5afbc98f99f4" -dependencies = [ - "once_cell", - "protobuf-support", - "thiserror 1.0.69", -] - -[[package]] -name = "protobuf-support" -version = "3.7.2" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e36c2f31e0a47f9280fb347ef5e461ffcd2c52dd520d8e216b52f93b0b0d7d6" -dependencies = [ - "thiserror 1.0.69", -] - [[package]] name = "quote" version = "1.0.45" @@ -576,6 +441,45 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_spanned" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6662b5879511e06e8999a8a235d848113e942c9124f211511b16466ee2995f26" +dependencies = [ + "serde_core", +] + [[package]] name = "shlex" version = "1.3.0" @@ -613,12 +517,6 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -[[package]] -name = "strsim" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" - [[package]] name = "syn" version = "2.0.117" @@ -683,16 +581,49 @@ dependencies = [ ] [[package]] -name = "unicode-ident" -version = "1.0.22" +name = "toml" +version = 
"1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +checksum = "81f3d15e84cbcd896376e6730314d59fb5a87f31e4b038454184435cd57defee" +dependencies = [ + "indexmap", + "serde_core", + "serde_spanned", + "toml_datetime", + "toml_parser", + "toml_writer", + "winnow", +] [[package]] -name = "utf8parse" -version = "0.2.2" +name = "toml_datetime" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_parser" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" +dependencies = [ + "winnow", +] + +[[package]] +name = "toml_writer" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" [[package]] name = "version_check" @@ -886,6 +817,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0592e1c9d151f854e6fd382574c3a0855250e1d9b2f99d9281c6e6391af352f1" + [[package]] name = "xdp-loader" version = "0.1.0" @@ -893,13 +830,13 @@ dependencies = [ "anyhow", "aya", "chrono", - "clap", "colored 3.1.1", "fern", "file-rotate", - "lazy_static", 
"log", "prometheus", + "serde", "signal-hook", "tiny_http", + "toml", ] diff --git a/Cargo.toml b/Cargo.toml index dbdb059..cb7d31b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,22 +3,30 @@ name = "xdp-loader" version = "0.1.0" edition = "2024" +# userspace loader lives in loader/ instead of the default src/ +[[bin]] +name = "xdp-loader" +path = "loader/main.rs" + [dependencies] aya = "0.13.1" log = "0.4.32" signal-hook = "0.4.4" anyhow = "1.0.102" -prometheus = "0.14" -lazy_static = "1.4" +# only the text encoder is used, default features would pull in protobuf +prometheus = { version = "0.14", default-features = false } tiny_http = "0.12" fern = { version = "0.7", features = ["colored"] } colored = "3.1" chrono = "0.4" file-rotate = "0.8" -clap = { version = "4.6", features = ["derive"] } +serde = { version = "1.0.228", features = ["derive"] } +toml = "1.1.2" [profile.release] -opt-level = 3 +# the packet hot path lives in the kernel; the loader itself is not +# performance sensitive, so optimize it for size +opt-level = "z" lto = "fat" codegen-units = 1 panic = "abort" diff --git a/LICENSE b/LICENSE index 73825a0..965ec51 100644 --- a/LICENSE +++ b/LICENSE @@ -1,3 +1,12 @@ +This project is licensed under the BSD-3-Clause license below. + +Exception: the BPF program sources in the xdp/ directory are dual-licensed +under the BSD-3-Clause license below OR the GNU General Public License +version 2 (GPL-2.0), at your option (SPDX: GPL-2.0 OR BSD-3-Clause). The +GPL option exists because the program uses GPL-only BPF kernel helpers +(bpf_timer_*) and the kernel only loads it with a GPL-compatible license +declaration ("Dual BSD/GPL"). + Copyright (c) 2025, Outfluencer Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index ffefda1..859b05f 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ The default filtered port is 25565. 
### Generate your filter binary Generate here: https://xdp.outfluencer.dev/ -And than just run the executable. +And then just run the executable. ### Prerequisites - Rust toolchain (stable) @@ -41,26 +41,93 @@ And than just run the executable. ``` The compiled binary will be at `target/release/xdp-loader`. -2. **Run the firewall**: +2. **Run the tests** (optional): + ```bash + cargo test + ``` + Besides the Rust unit tests this compiles the eBPF parsing code (VarInt + reader, packet inspectors) natively with ASan/UBSan and runs its C unit + tests (see `xdp/tests/protocol_test.c`). + +3. **Run the firewall**: ```bash sudo ./target/release/xdp-loader # Example: sudo ./target/release/xdp-loader eth0 ``` - To enable Prometheus metrics export: - ```bash - sudo ./target/release/xdp-loader eth0 --metrics-addr 0.0.0.0:1999 - ``` - Metrics available at: `http://host:1999/metrics` + To enable Prometheus metrics export, set `enabled = true` and an `addr` + (e.g. `"0.0.0.0:1999"`) in the `[metrics]` section of `config.toml` (see [Configuration](#configuration)), then + run the loader normally. Metrics are then available at: `http://host:1999/metrics` -**Note:** This project uses a persistent XDP loader. Usage of `XDP` programs requires the userspace program to stay running to manage maps. Stopping the loader will unload the firewall. +**Note:** This project uses a persistent XDP loader. The userspace program must stay running to keep the filter attached; all map state (throttle windows, verified connections) is managed in-kernel via `bpf_timer`. Stopping the loader will unload the firewall. Requires Linux kernel 5.15 or newer. ## Configuration -You can configure ports, features, and throttling behavior in the `build.rs` file. - -**After changing `build.rs`, you must recompile the project.** +Runtime behavior is controlled by a `config.toml` file next to the binary. +On first run it is created automatically with documented defaults; edit it and +restart the loader. Use `--config <path>` to point at a different file. 
+ +**`[filter]`** — what traffic is filtered and how strictly: + +| Option | Type | Default | Description | +|----------------|--------|---------|-------------| +| `start_port` | int | 25565 | First port of the inclusive filtered range. | +| `end_port` | int | 25565 | Last port of the inclusive filtered range. | +| `hit_count` | int | 10 | Max SYNs per source IP per throttle window (`0` disables throttling). | +| `hit_count_reset_secs` | int | 3 | Throttle window length in seconds; each IP's SYN counter resets in-kernel once its window expires. | +| `player_idle_timeout_secs` | int | 60 | Idle timeout for verified connections; an idle entry is removed in-kernel after one to two intervals, requiring a new handshake. | +| `online_names` | bool | true | Enforce online-mode usernames (≤16 chars). | + +**`[xdp]`** — how the program attaches and the capacity of its in-kernel tables +(the maps are preallocated, so higher limits cost kernel memory up front): + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `mode` | string | `"auto"` | XDP attach mode: `"auto"` (native if the NIC supports it, automatic fallback to generic), `"driver"` (force native, fails when unsupported) or `"skb"` (force generic; works everywhere, slower). | +| `max_pending_connections` | int | 16384 | Concurrent connections mid-handshake; the least-recently-used entry is evicted when full. | +| `max_player_connections` | int | 65535 | Concurrent verified player connections; size well above the expected player count. | +| `max_throttled_ips` | int | 65535 | Source IPs tracked by the SYN throttle; new SYNs are dropped while full (fail closed). | + +**`[metrics]`** — statistics collection and the Prometheus endpoint: + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `enabled` | bool | false | Collect packet statistics inside the eBPF program. 
| +| `addr` | string | (unset) | Address for the Prometheus HTTP endpoint (requires `enabled = true`). | +| `poll_secs` | int | 10 | How often the in-kernel statistics are read and published. | + +**`[logging]`** — console and file logging: + +| Option | Type | Default | Description | +|--------|------|---------|-------------| +| `level` | string | `"info"` | Verbosity: `"off"`, `"error"`, `"warn"`, `"info"`, `"debug"` or `"trace"`. Overridden by the `RUST_LOG` env variable. | +| `file_max_mb` | int | 100 | Rotate `xdp-loader.log` once it grows past this size; the 5 newest rotated files are kept. | + +The `[filter]` values and map capacities are pushed into the eBPF program at +load time (via `.rodata` globals and map definitions), so changing them only +requires a restart — **not a rebuild**. + +## Project Layout + +| Path | Purpose | +|------|---------| +| `xdp/minecraft_filter.c` | XDP entry point, BPF maps, conntrack state machine | +| `xdp/protocol.h` | Minecraft packet inspection (handshake, status, ping, login) | +| `xdp/varint.h` | Bounded VarInt reader | +| `xdp/common.h` | Bounds-check macros, flow key, connection states | +| `xdp/config.h` | Runtime configuration globals (patched by the loader) | +| `xdp/stats.h` | Statistics counters | +| `xdp/tests/` | Native unit tests for the parsing code (run via `cargo test`) | +| `loader/main.rs` | CLI entry point and process lifecycle | +| `loader/ebpf.rs` | Loads, configures and attaches the eBPF program | +| `loader/config.rs` | TOML configuration | +| `loader/metrics.rs` | Statistics polling and Prometheus endpoint | +| `loader/logging.rs` | Console + rotating file logging | +| `loader/shutdown.rs` | Signal handling and shutdown coordination | + +The eBPF program is compiled by `build.rs` and embedded into the loader +binary, so the released executable is fully self-contained. 
## Troubleshooting diff --git a/build.rs b/build.rs index e4c2bbc..4bdbe7e 100644 --- a/build.rs +++ b/build.rs @@ -1,84 +1,45 @@ -use std::env; use std::process::Command; -/// Reads an environment variable or returns the default value -fn env_or(key: &str, default: &str) -> String { - env::var(key).unwrap_or_else(|_| default.to_string()) -} - +/// Compiles the eBPF program to `xdp/minecraft_filter.o`, which `loader/ebpf.rs` +/// embeds into the loader binary at compile time. fn main() { - // Configuration via environment variables at build time - // Example: ONLY_ASCII_NAMES=1 START_PORT=25565 cargo build - let config = [ - ("ONLY_ASCII_NAMES", env_or("ONLY_ASCII_NAMES", "1")), - ("CONNECTION_THROTTLE", env_or("CONNECTION_THROTTLE", "1")), - ("START_PORT", env_or("START_PORT", "25565")), - ("END_PORT", env_or("END_PORT", "25565")), - ("PROMETHEUS_METRICS", env_or("PROMETHEUS_METRICS", "1")), - ("IP_AND_PORT_PER_CPU", env_or("IP_AND_PORT_PER_CPU", "0")), - ("IP_PER_CPU", env_or("IP_PER_CPU", "0")), - ("HIT_COUNT", env_or("HIT_COUNT", "10")), - ]; - - // Register custom cfg options to avoid warnings - for (key, _) in &config { - println!("cargo:rustc-check-cfg=cfg({})", key.to_lowercase()); + // explicit file list on purpose: watching the whole xdp/ directory would + // also watch the generated .o and recompile on every build + for source in [ + "xdp/minecraft_filter.c", + "xdp/common.h", + "xdp/config.h", + "xdp/protocol.h", + "xdp/stats.h", + "xdp/varint.h", + ] { + println!("cargo:rerun-if-changed={source}"); } - // 1. Export variables for Rust (access via env! macro, e.g. 
env!("START_PORT")) - for (key, value) in &config { - println!("cargo:rustc-env={}={}", key, value); - // Enable #[cfg(key)] if value is "1" - if value == "1" { - println!("cargo:rustc-cfg={}", key.to_lowercase()); - } - // Rerun if this env var changes - println!("cargo:rerun-if-env-changed={}", key); - } - - // clang -Wall -Wextra -Wno-language-extension-token -O2 -g -target bpf -mcpu=v3 -c minecraft_filter.c -o minecraft_filter.o - // 2. Compile the C code using clang directly - let mut command = Command::new("clang"); - - // Add config as -D define flags - for (key, value) in &config { - command.arg(format!("-D{}={}", key, value)); + let output = Command::new("clang") + .args([ + "-Wall", + "-Wextra", + "-Wno-language-extension-token", + "-O2", + "-g", + "-target", + "bpf", + "-mcpu=v3", + "-c", + "minecraft_filter.c", + "-o", + "minecraft_filter.o", + ]) + .current_dir("xdp") + .output() + .expect("failed to run clang, is LLVM/clang installed?"); + + if !output.status.success() { + panic!( + "clang compilation failed:\nstdout: {}\nstderr: {}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); } - - // Add compilation flags - command.args([ - "-Wall", - "-Wextra", - "-Wno-language-extension-token", - "-O2", - "-g", - "-target", "bpf", - "-mcpu=v3", - "-c", "minecraft_filter.c", - "-o", "minecraft_filter.o", - ]); - - // execute command in "c" directory - command.current_dir("c"); - - println!("cargo:warning=Compiling eBPF program..."); - match command.output() { - Ok(output) => { - if !output.status.success() { - panic!( - "clang compilation failed:\nstdout: {}\nstderr: {}", - String::from_utf8_lossy(&output.stdout), - String::from_utf8_lossy(&output.stderr) - ); - } - } - Err(e) => panic!("Failed to execute clang: {}", e), - } - - // 3. 
Re-run if relevant files change - println!("cargo:rerun-if-changed=c/minecraft_filter.c"); - println!("cargo:rerun-if-changed=c/common.h"); - println!("cargo:rerun-if-changed=c/minecraft_networking.h"); - println!("cargo:rerun-if-changed=c/stats.h"); - println!("cargo:rerun-if-changed=build.rs"); } diff --git a/c/common.h b/c/common.h deleted file mode 100644 index dfc44f6..0000000 --- a/c/common.h +++ /dev/null @@ -1,151 +0,0 @@ -#ifndef COMMON_H -#define COMMON_H - -#include - -#ifndef barrier_var -#define barrier_var(var) asm volatile("" : "+r"(var)) -#endif - -// maximum amount of retransmission packets before blocking -#define MAX_OUT_OF_ORDER 4 - -// STATE TRACKING -#define AWAIT_ACK 1 -#define AWAIT_MC_HANDSHAKE 2 -#define RECEIVED_LEGACY_PING 3 // this connection will be fully dropped -#define AWAIT_STATUS_REQUEST 4 -#define AWAIT_LOGIN 5 -#define AWAIT_PING 6 -#define PING_COMPLETE 7 -#define DIRECT_READ_STATUS_REQUEST 8 -#define DIRECT_READ_LOGIN 9 - -#define SECOND_TO_NANOS 1000000000ULL - - -// Checks bounds and returns 0 if out of bounds (does NOT increment ptr) -#define CHECK_BOUNDS_OR_RETURN(ptr, n, pend, dend) \ - do \ - { \ - if ((void *)(ptr) + (n) > (const void *)(dend)) \ - return 0; \ - barrier_var(ptr); \ - if ((void *)(ptr) + (n) > (const void *)(pend)) \ - return 0; \ - barrier_var(ptr); \ - } while (0) - -// checks bounds. if bad, returns 0. if good, increments ptr. -// usage: READ_OR_RETURN(reader_index, 2, payload_end, data_end); -#define READ_OR_RETURN(ptr, n, pend, dend) \ - do \ - { \ - if ((void *)(ptr) + (n) > (const void *)(dend)) \ - return 0; \ - barrier_var(ptr); \ - if ((void *)(ptr) + (n) > (const void *)(pend)) \ - return 0; \ - barrier_var(ptr); \ - ptr += (n); \ - } while (0) - -// Returns how many bytes a value occupies when encoded as a varint (compile-time). -// 7 bits per byte: 0-127 → 1, 128-16383 → 2, ... up to 5 bytes max. -#define VARINT_SIZE(n) \ - (((__u32)(n) <= 0x7F) ? 1 : \ - ((__u32)(n) <= 0x3FFF) ? 
2 : \ - ((__u32)(n) <= 0x1FFFFF) ? 3 : \ - ((__u32)(n) <= 0xFFFFFFF) ? 4 : 5) - -// reads a value into 'dest' and increments 'ptr', or returns 0 if OOB -#define READ_VAL_OR_RETURN(dest, ptr, pend, dend) \ - do \ - { \ - if ((void *)(ptr) + sizeof(dest) > (const void *)(dend)) \ - return 0; \ - barrier_var(ptr); \ - if ((void *)(ptr) + sizeof(dest) > (const void *)(pend)) \ - return 0; \ - barrier_var(ptr); \ - dest = *(__typeof__(dest) *)(ptr); \ - ptr += sizeof(dest); \ - } while (0) - -// if condition is false, returns 0 immediately. -#define ASSERT_OR_RETURN(cond) \ - do \ - { \ - if (!(cond)) \ - return 0; \ - } while (0) - -// if val is not in [min, max], returns 0 immediately. -#define ASSERT_IN_RANGE(val, min, max) \ - do \ - { \ - if ((val) < (min) || (val) > (max)) \ - return 0; \ - } while (0) - -// reads a varint into 'dest_struct', increments 'ptr', or returns 0 on failure. -#define VARINT_OR_DIE(dest_struct, ptr, pend, dend) \ - do \ - { \ - dest_struct = read_varint_sized(ptr, pend, 5, dend); \ - if (!(dest_struct).bytes) \ - return 0; \ - (ptr) += (dest_struct).bytes; \ - barrier_var(ptr); \ - } while (0) - -#define MAX_VARINT_OR_DIE(dest_struct, ptr, pend, dend, max) \ - do \ - { \ - dest_struct = read_varint_sized(ptr, pend, max, dend); \ - if (!(dest_struct).bytes) \ - return 0; \ - (ptr) += (dest_struct).bytes; \ - barrier_var(ptr); \ - } while (0) - - -struct ipv4_flow_key -{ - const __u32 src_ip; - const __u32 dst_ip; - const __u16 src_port; - const __u16 dst_port; -}; -_Static_assert(sizeof(struct ipv4_flow_key) == 12, "ipv4_flow_key size mismatch!"); - -struct initial_state -{ - __u16 state; // we only need u8, but padding.... - __u16 fails; // we only need u8, but padding.... 
- __s32 protocol; // minecraft protocol versions are signed - __u32 expected_sequence; -}; -_Static_assert(sizeof(struct initial_state) == 12, "initial_state size mismatch!"); - -static __always_inline struct ipv4_flow_key gen_ipv4_flow_key(const __u32 src_ip, const __u32 dst_ip, const __u16 src_port, const __u16 dst_port) -{ - struct ipv4_flow_key key = { - .src_ip = src_ip, - .dst_ip = dst_ip, - .src_port = src_port, - .dst_port = dst_port}; - return key; -} - -static __always_inline struct initial_state gen_initial_state(const __u16 state, const __s32 protocol, const __u32 expected_sequence) -{ - struct initial_state new_state = { - .state = state, - .fails = 0, - .protocol = protocol, - .expected_sequence = expected_sequence, - }; - return new_state; -} -#endif diff --git a/c/minecraft_filter.c b/c/minecraft_filter.c deleted file mode 100644 index ad88849..0000000 --- a/c/minecraft_filter.c +++ /dev/null @@ -1,403 +0,0 @@ -#ifndef HIT_COUNT -#define HIT_COUNT 10 -#endif - -#ifndef START_PORT -#define START_PORT 25565 -#endif -#ifndef END_PORT -#define END_PORT 25565 -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include "common.h" -#include "minecraft_networking.h" -#include "stats.h" - -struct -{ - __uint(type, -#if IP_AND_PORT_PER_CPU - BPF_MAP_TYPE_LRU_PERCPU_HASH -#else - BPF_MAP_TYPE_LRU_HASH -#endif - ); - __uint(max_entries, 4096); // max amount of 4096 concurrent initial connections - __type(key, struct ipv4_flow_key); // flow key - __type(value, struct initial_state); // initial state -} conntrack_map SEC(".maps"); - -struct -{ - __uint(type, -#if IP_AND_PORT_PER_CPU - BPF_MAP_TYPE_PERCPU_HASH -#else - BPF_MAP_TYPE_HASH -#endif - ); - __uint(max_entries, 65535); - __type(key, struct ipv4_flow_key); // flow key - __type(value, __u64); // last seen timestamp - __uint(pinning, LIBBPF_PIN_BY_NAME); -} player_connection_map SEC(".maps"); - -struct -{ - __uint(type, -#if IP_PER_CPU - BPF_MAP_TYPE_PERCPU_HASH 
-#else - BPF_MAP_TYPE_HASH -#endif - ); - __uint(max_entries, 65535); - __type(key, __u32); // ipv4 address - __type(value, __u32); // throttle hit counter - __uint(pinning, LIBBPF_PIN_BY_NAME); -} connection_throttle SEC(".maps"); - -#if PROMETHEUS_METRICS -struct -{ - __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); - __uint(max_entries, 1); - __type(key, __u32); - __type(value, struct statistics); - __uint(pinning, LIBBPF_PIN_BY_NAME); -} stats_map SEC(".maps"); -#endif - -static __always_inline __u8 detect_tcp_bypass(const struct tcphdr *tcp) -{ - if ((!tcp->syn && !tcp->ack && !tcp->fin && !tcp->rst) || // no SYN/ACK/FIN/RST flag - (tcp->syn && tcp->ack) || // SYN+ACK from external (unexpected) - tcp->urg) - { // drop if URG flag is set - return 1; - } - return 0; -} - -/* - * removes the connection from the conntrack_map - */ -static __always_inline void remove_connection(const struct statistics *stats_ptr, const struct ipv4_flow_key *flow_key) -{ - count_stats(stats_ptr, DROP_CONNECTION, 1); - bpf_map_delete_elem(&conntrack_map, flow_key); - (void)stats_ptr; // for compiler -} - -/* - * removes connection from conntrack map and puts it into the player map - * no more packets of this connection will be checked now - */ -static __always_inline __u32 switch_to_verified(const __u64 raw_packet_len, const struct statistics *stats_ptr, const struct ipv4_flow_key *flow_key) -{ - bpf_map_delete_elem(&conntrack_map, flow_key); - __u64 count = 1; - // timeout after 60 to 120 seconds. 
- if (bpf_map_update_elem(&player_connection_map, flow_key, &count, BPF_NOEXIST) < 0) - { - count_stats(stats_ptr, DROPPED_BYTES, raw_packet_len); - count_stats(stats_ptr, DROP_CONNECTION | DROPPED_PACKET, 1); - return XDP_DROP; - } - count_stats(stats_ptr, VERIFIED, 1); - // for compiler - (void)raw_packet_len; - (void)stats_ptr; - - return XDP_PASS; -} - -SEC("xdp") -__s32 minecraft_filter(struct xdp_md *ctx) -{ - const void *data = (const void *)(long)ctx->data; - const void *data_end = (const void *)(long)ctx->data_end; - - const struct ethhdr *eth = data; - if ((const void *)(eth + 1) > data_end) - { - return XDP_DROP; - } - - if (eth->h_proto != bpf_htons(ETH_P_IP)) - { - return XDP_PASS; - } - - const struct iphdr *ip = data + sizeof(struct ethhdr); - if ((const void *)(ip + 1) > data_end || ip->ihl < 5) - { - return XDP_DROP; - } - - if (ip->protocol != IPPROTO_TCP) - { - return XDP_PASS; - } - - const struct tcphdr *tcp = (const void *)ip + (ip->ihl * 4); - if ((const void *)(tcp + 1) > data_end) - { - return XDP_DROP; - } - - // check if TCP destination port matches mc server port - const __u16 dest_port = bpf_ntohs(tcp->dest); - -#if START_PORT == END_PORT - if (dest_port != START_PORT) - { - return XDP_PASS; // not for our service - } -#else - if (dest_port < START_PORT || dest_port > END_PORT) - { - return XDP_PASS; // not for our service - } -#endif - if (tcp->doff < 5) - { - return XDP_DROP; - } - - const __u32 tcp_hdr_len = tcp->doff * 4; - if ((const void *)tcp + tcp_hdr_len > data_end) - { - return XDP_DROP; - } - -#if PROMETHEUS_METRICS - __u32 key = 0; - struct statistics *stats_ptr = bpf_map_lookup_elem(&stats_map, &key); - if (!stats_ptr) - { - // this should be impossible - return XDP_DROP; - } -#else - struct statistics *stats_ptr = 0; -#endif - - const __u64 raw_packet_len = (__u64)(data_end - data); - count_stats(stats_ptr, INCOMING_BYTES, raw_packet_len); - - // additional TCP bypass checks for abnormal flags - if (detect_tcp_bypass(tcp)) 
- { - count_stats(stats_ptr, TCP_BYPASS, 1); - goto drop; - } - - const __u32 src_ip = ip->saddr; - - // stateless new connection checks - if (tcp->syn) - { - count_stats(stats_ptr, SYN_RECEIVE, 1); - -#if CONNECTION_THROTTLE - // connection throttle - // 10 connection per ip per 3 seconds, otherwise drop - __u32 *hit_counter = bpf_map_lookup_elem(&connection_throttle, &src_ip); - if (hit_counter) - { - if (*hit_counter > HIT_COUNT) - { - goto drop; - } - (*hit_counter)++; - } - else - { - __u32 new_counter = 1; - if (bpf_map_update_elem(&connection_throttle, &src_ip, &new_counter, BPF_NOEXIST) < 0) - { - goto drop; - } - } -#endif - // compute flow key - const struct ipv4_flow_key flow_key = gen_ipv4_flow_key(src_ip, ip->daddr, tcp->source, tcp->dest); - const struct initial_state new_state = gen_initial_state(AWAIT_ACK, 0, bpf_ntohl(tcp->seq) + 1); - if (bpf_map_update_elem(&conntrack_map, &flow_key, &new_state, BPF_ANY) < 0) - { - goto drop; - } - - return XDP_PASS; - } - - // compute flow key - const struct ipv4_flow_key flow_key = gen_ipv4_flow_key(src_ip, ip->daddr, tcp->source, tcp->dest); - __u64 *p_counter = bpf_map_lookup_elem(&player_connection_map, &flow_key); - if (p_counter) - { - (*p_counter)++; - return XDP_PASS; - } - - struct initial_state *initial_state = bpf_map_lookup_elem(&conntrack_map, &flow_key); - if (!initial_state) - { - goto drop; // no connection tracked, drop - } - - __u8 *tcp_payload = (__u8 *)((__u8 *)tcp + tcp_hdr_len); - - // total length of ip packet - const __u16 ip_tot_len = bpf_ntohs(ip->tot_len); - // total ip - ip header - tcp header = length of tcp payload - const __u16 tcp_payload_len = ip_tot_len - (ip->ihl * 4) - tcp_hdr_len; - // tcp payload end = start + length - const __u8 *tcp_payload_end = tcp_payload + tcp_payload_len; - - // tcp packet is split in multiple ethernet frames, we don't support that - if (tcp_payload_end > (__u8 *)data_end) - { - goto drop; - } - - __u32 state = initial_state->state; - if (state == 
AWAIT_ACK) - { - // not an ack or invalid ack number - if (!tcp->ack || initial_state->expected_sequence != bpf_ntohl(tcp->seq)) - { - goto drop; - } - - // set state here even tho we may retrun as we need the state for the next packet - initial_state->state = state = AWAIT_MC_HANDSHAKE; - - // we can drop original pure ack from the tcp 3 way handshake - // the backend will accept the first minecraft data packet as the ack of the 3 way handshake - // that's an elegant way to only let the backend accept connections that have a mc handshake in it. - // Only drop if there is no TCP payload; if there is payload, continue into payload inspection. - if (tcp_payload >= tcp_payload_end) - { - goto drop; - } - - // do not return here, the ack of the tcp handshake can contain application data - // return XDP_PASS; - } - - if (tcp_payload < tcp_payload_end) - { - - if (!tcp->ack) - { - goto drop_connection; - } - - // we fully track the tcp packet order with this check, - // this mean we can hard punish invalid packets below, as they are not out of order - // but invalid data - if (initial_state->expected_sequence != bpf_ntohl(tcp->seq)) - { - if (++initial_state->fails > MAX_OUT_OF_ORDER) - { - goto drop_connection; - } - goto drop; - } - - if (state == AWAIT_MC_HANDSHAKE) - { - // returns the next state - // if the login data or motd request is included in the same tcp data as the handshake - // the tcp_payload reader index will be updated to the next position - __s32 next_state = inspect_handshake(tcp_payload, tcp_payload_end, &initial_state->protocol, data_end, &tcp_payload); - // if the first packet has invalid length, we can block it - // even with retransmission this len should always be valid‚ - if (!next_state) - { - goto drop; - } - - // fully drop legacy ping - if (next_state == RECEIVED_LEGACY_PING) - { - goto drop_connection; - } - if (next_state == DIRECT_READ_STATUS_REQUEST) - { - goto read_status; - } - if (next_state == DIRECT_READ_LOGIN) - { - goto 
read_login; - } - initial_state->state = next_state; - goto update_state; - } - if (state == AWAIT_STATUS_REQUEST) - read_status: { - if (!inspect_status_request(tcp_payload, tcp_payload_end, data_end)) - { - goto drop; - } - initial_state->state = AWAIT_PING; - goto update_state; - } - if (state == AWAIT_PING) - { - if (!inspect_ping_request(tcp_payload, tcp_payload_end, data_end)) - { - goto drop; - } - initial_state->state = PING_COMPLETE; - goto update_state; - } - if (state == AWAIT_LOGIN) - read_login: { - - if (!inspect_login_packet(tcp_payload, tcp_payload_end, initial_state->protocol, data_end)) - { - goto drop; - } - // as tracking ends here we do not need to update the sequence - // initial_state->expected_sequence += tcp_payload_len; - goto switch_to_verified; - } - if (state == PING_COMPLETE) - { - goto drop_connection; - } - } else if (state == AWAIT_MC_HANDSHAKE) { - // no ack's are allowed, we are waiting for the handshake - // otherwise an attacker could bypass the 3 way handshake hack - goto drop; - } - return XDP_PASS; - -// Using this labels drastically reduce the file size -drop_connection: - remove_connection(stats_ptr, &flow_key); - goto drop; -drop: - count_stats(stats_ptr, DROPPED_PACKET, 1); - count_stats(stats_ptr, DROPPED_BYTES, raw_packet_len); - return XDP_DROP; -update_state: - initial_state->expected_sequence += tcp_payload_len; - count_stats(stats_ptr, STATE_SWITCH, 1); - return XDP_PASS; -switch_to_verified: - return switch_to_verified(raw_packet_len, stats_ptr, &flow_key); -} - -char _license[] SEC("license") = "Proprietary"; diff --git a/c/minecraft_helper.h b/c/minecraft_helper.h deleted file mode 100644 index 1173c2d..0000000 --- a/c/minecraft_helper.h +++ /dev/null @@ -1,103 +0,0 @@ -#ifndef MINECRAFT_HELPER_H -#define MINECRAFT_HELPER_H - -#include -#include "common.h" - -// if you are running a premium server, you can enable this, it drops weird usernames -#ifndef ONLY_ASCII_NAMES -#define ONLY_ASCII_NAMES 0 -#endif - -// 
general varint limits -#define UTF8_MAX_BYTES 3 -#define UUID_LEN 16 -#define MIN_VARINT_BYTES 1 -#define MAX_VARINT_BYTES 5 - -#define PACKET_ID_MIN MIN_VARINT_BYTES -#define PACKET_ID_MAX MAX_VARINT_BYTES - -// handshake packet -#define HANDSHAKE_VERSION_MIN MIN_VARINT_BYTES -#define HANDSHAKE_VERSION_MAX MAX_VARINT_BYTES - -#define HANDSHAKE_HOSTLEN_MIN MIN_VARINT_BYTES -#define HANDSHAKE_HOSTLEN_MAX MAX_VARINT_BYTES - -#define HANDSHAKE_HOST_DATA_MIN (0) -#define HANDSHAKE_HOST_DATA_MAX (255 * 3) - -#define HANDSHAKE_PORT_LEN (2) - -#define HANDSHAKE_INTENTION_MIN MIN_VARINT_BYTES -#define HANDSHAKE_INTENTION_MAX MAX_VARINT_BYTES - -#define HANDSHAKE_DATA_MIN (HANDSHAKE_VERSION_MIN + HANDSHAKE_HOSTLEN_MIN + HANDSHAKE_HOST_DATA_MIN + HANDSHAKE_PORT_LEN + HANDSHAKE_INTENTION_MIN) -#define HANDSHAKE_DATA_MAX (HANDSHAKE_VERSION_MAX + HANDSHAKE_HOSTLEN_MAX + HANDSHAKE_HOST_DATA_MAX + HANDSHAKE_PORT_LEN + HANDSHAKE_INTENTION_MAX) - -// login request packet -#define LOGIN_NAME_LEN_MIN MIN_VARINT_BYTES -#define LOGIN_NAME_LEN_MAX MAX_VARINT_BYTES - -#define LOGIN_NAME_DATA_MIN (1) // empty names are not possible -#define LOGIN_NAME_DATA_MAX (16 * (ONLY_ASCII_NAMES ? 
1 : UTF8_MAX_BYTES)) - -#define LOGIN_KEY_MIN 0 -#define LOGIN_KEY_MAX 512 - -#define LOGIN_SIGNATURE_MIN 0 -#define LOGIN_SIGNATURE_MAX 4096 - -#define LOGIN_PUBLIC_KEY_MIN (/*has key*/ 1) -#define LOGIN_PUBLIC_KEY_MAX (/*has key*/ 1 + /*expiry*/ 8 + /*length*/ MAX_VARINT_BYTES + LOGIN_KEY_MAX + /*length*/ MAX_VARINT_BYTES + LOGIN_SIGNATURE_MAX) - -#define LOGIN_HAS_UUID_LEN 1 -#define LOGIN_DATA_MIN (LOGIN_NAME_LEN_MIN + LOGIN_NAME_DATA_MIN) -#define LOGIN_DATA_MAX (LOGIN_NAME_LEN_MAX + LOGIN_NAME_DATA_MAX + LOGIN_PUBLIC_KEY_MAX + LOGIN_HAS_UUID_LEN + UUID_LEN) - -struct varint_value -{ - __s32 value; - __u32 bytes; // 1 to 5 bytes -}; - -static __always_inline struct varint_value varint(__s32 value, __u32 bytes) -{ - return (struct varint_value){value, bytes}; -} - -_Static_assert(sizeof(struct varint_value) == 8, "varint_value size mismatch!"); - -// Reads one varint byte, checks bounds, returns result if done, or continues -#define VARINT_BYTE(ptr, pend, dend, max, idx, shift, result) \ - do { \ - if ((max) < (idx)) \ - goto error; \ - if ((const void *)(ptr) + 1 > (const void *)(dend)) \ - goto error; \ - barrier_var(ptr); \ - if ((const void *)(ptr) + 1 > (const void *)(pend)) \ - goto error; \ - barrier_var(ptr); \ - __u8 _b = *(ptr)++; \ - (result) |= ((__s32)(_b & 0x7F) << (shift)); \ - if (!(_b & 0x80)) \ - return varint((result), (idx)); \ - } while (0) - -static __always_inline struct varint_value read_varint_sized(__u8 *start, const __u8 *payload_end, const __u8 max_size, const void *data_end) -{ - __s32 result = 0; - - VARINT_BYTE(start, payload_end, data_end, max_size, 1, 0, result); - VARINT_BYTE(start, payload_end, data_end, max_size, 2, 7, result); - VARINT_BYTE(start, payload_end, data_end, max_size, 3, 14, result); - VARINT_BYTE(start, payload_end, data_end, max_size, 4, 21, result); - VARINT_BYTE(start, payload_end, data_end, max_size, 5, 28, result); - -error: - return varint(0, 0); -} - -#endif \ No newline at end of file diff --git 
a/c/minecraft_networking.h b/c/minecraft_networking.h deleted file mode 100644 index 9fc3fe5..0000000 --- a/c/minecraft_networking.h +++ /dev/null @@ -1,161 +0,0 @@ -#ifndef MINECRAFT_NETWORKING_H -#define MINECRAFT_NETWORKING_H - -#include - -#include "minecraft_helper.h" -#include "common.h" - -// checks if the packet contains a valid ping request -static __always_inline __u8 inspect_ping_request(__u8 *start, const __u8 *payload_end, const void *data_end) -{ - struct varint_value varint; - - // max 9 bytes - MAX_VARINT_OR_DIE(varint, start, payload_end, data_end, VARINT_SIZE(0x09)); - ASSERT_OR_RETURN(varint.value == 0x09); - - // packet id - MAX_VARINT_OR_DIE(varint, start, payload_end, data_end, VARINT_SIZE(0x01)); - ASSERT_OR_RETURN(varint.value == 0x01); - - __u64 timestamp; - READ_VAL_OR_RETURN(timestamp, start, payload_end, data_end); - return start == payload_end; -} - -// checks if the packet contains a valid status request -static __always_inline __u8 inspect_status_request(__u8 *start, const __u8 *payload_end, const void *data_end) -{ - struct varint_value varint; - - // max 1 byte - MAX_VARINT_OR_DIE(varint, start, payload_end, data_end, VARINT_SIZE(0x01)); - ASSERT_OR_RETURN(varint.value == 0x01); - - // packet id - MAX_VARINT_OR_DIE(varint, start, payload_end, data_end, VARINT_SIZE(0x00)); - ASSERT_OR_RETURN(varint.value == 0x00); - - return start == payload_end; -} - -// checks if the packet contains a valid login request -// see https://github.com/SpigotMC/BungeeCord/blob/master/protocol/src/main/java/net/md_5/bungee/protocol/packet/LoginRequest.java -static __always_inline __u8 inspect_login_packet(__u8 *reader_index, const __u8 *payload_end, __s32 protocol_version, const void *data_end) -{ - // length of the packet - struct varint_value varint; - - // len 3 bytes varint max - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE((PACKET_ID_MAX + LOGIN_DATA_MAX))); - ASSERT_IN_RANGE(varint.value, PACKET_ID_MIN + 
LOGIN_DATA_MIN, PACKET_ID_MAX + LOGIN_DATA_MAX); - - // packet id - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE(0x00)); - ASSERT_OR_RETURN(varint.value == 0x00); - - // username length - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE(LOGIN_NAME_DATA_MAX)); - // bounce check, invalid username - ASSERT_IN_RANGE(varint.value, LOGIN_NAME_DATA_MIN, LOGIN_NAME_DATA_MAX); - // skip the username data - READ_OR_RETURN(reader_index, varint.value, payload_end, data_end); - - // 1_19 1_19_3 - if (protocol_version >= 759 && protocol_version < 761) - { - __u8 has_public_key; - READ_VAL_OR_RETURN(has_public_key, reader_index, payload_end, data_end); - if (has_public_key) - { - // public key length - READ_OR_RETURN(reader_index, 8, payload_end, data_end); - - // login key - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE(LOGIN_KEY_MAX)); - // assert reasonable size - ASSERT_IN_RANGE(varint.value, LOGIN_KEY_MIN, LOGIN_KEY_MAX); - // skip login key - READ_OR_RETURN(reader_index, varint.value, payload_end, data_end); - - // signaturey length - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE(LOGIN_SIGNATURE_MAX)); - // assert reasonable size - ASSERT_IN_RANGE(varint.value, LOGIN_SIGNATURE_MIN, LOGIN_SIGNATURE_MAX); - // skip signature - READ_OR_RETURN(reader_index, varint.value, payload_end, data_end); - } - } - // 1_19_1 - if (protocol_version >= 760) - { - // 1_20_2 - if (protocol_version >= 764) - { - // check space for uuid - READ_OR_RETURN(reader_index, 16, payload_end, data_end); - } - else - { - // check space for uuid and boolean - __u8 has_uuid; - READ_VAL_OR_RETURN(has_uuid, reader_index, payload_end, data_end); - if (has_uuid) - { - READ_OR_RETURN(reader_index, 16, payload_end, data_end); - } - } - } - // no data left to read, this is a valid login packet - return reader_index == payload_end; -} - -// check for valid handshake packet -// note: it happens that the 
handshake and login or status request are in the same packet, -// so we have to check for both cases here. this can also happen after retransmission. -static __always_inline __s32 inspect_handshake(__u8 *reader_index, const __u8 *payload_end, __s32 *protocol_version, const void *data_end, __u8 **current_reader_index) -{ - CHECK_BOUNDS_OR_RETURN(reader_index, 1, payload_end, data_end); - // check for legacy ping - if (reader_index[0] == (__u8)0xFE) - { - return RECEIVED_LEGACY_PING; - } - - struct varint_value varint; - // len 3 bytes varint max - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE((PACKET_ID_MAX + HANDSHAKE_DATA_MAX))); - ASSERT_IN_RANGE(varint.value, (PACKET_ID_MIN + HANDSHAKE_DATA_MIN), (PACKET_ID_MAX + HANDSHAKE_DATA_MAX)); - // packet id - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE(0x00)); - ASSERT_OR_RETURN(varint.value == 0x00); // packet id needs to be 0 - // protocol version - VARINT_OR_DIE(varint, reader_index, payload_end, data_end); - *protocol_version = varint.value; - // host len - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE(HANDSHAKE_HOST_DATA_MAX)); - ASSERT_IN_RANGE(varint.value, HANDSHAKE_HOST_DATA_MIN, HANDSHAKE_HOST_DATA_MAX); - // read host - READ_OR_RETURN(reader_index, varint.value, payload_end, data_end); - // read port - READ_OR_RETURN(reader_index, 2, payload_end, data_end); - // intention - MAX_VARINT_OR_DIE(varint, reader_index, payload_end, data_end, VARINT_SIZE(3)); - __s32 intention = varint.value; - __u8 support_transfer = *protocol_version >= 766; - - // valid intentions: 1 (status), 2 (login), 3 (login with transfer request) since 766 - ASSERT_OR_RETURN((intention == 1 || intention == 2 || (support_transfer && intention == 3))); - - // this packet contained exactly the handshake - if (reader_index == payload_end) - { - return intention == 1 ? 
AWAIT_STATUS_REQUEST : AWAIT_LOGIN; - } - - *current_reader_index = reader_index; - return intention == 1 ? DIRECT_READ_STATUS_REQUEST : DIRECT_READ_LOGIN; -} - -#endif diff --git a/loader/config.rs b/loader/config.rs new file mode 100644 index 0000000..c245b8b --- /dev/null +++ b/loader/config.rs @@ -0,0 +1,459 @@ +use anyhow::{Context, Result, bail}; +use log::LevelFilter; +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::path::Path; + +/// Default configuration file written on first run, with documentation for +/// every option. Mirrors [`Config::default`]. +pub const DEFAULT_CONFIG_TOML: &str = r#"# Minecraft XDP filter configuration. +# This file is applied at load time; restart the loader after editing. +# Every option is listed with its default value. + +[filter] +# Inclusive TCP destination port range to protect. +# Use the same value for both to protect a single port. +start_port = 25565 +end_port = 25565 + +# SYN connection throttle: max new connections (SYNs) per source IP within +# each throttle window (see hit_count_reset_secs). Set to 0 to disable +# throttling. +hit_count = 10 + +# Length of the throttle window in seconds, enforced inside the eBPF program: +# each source IP gets its own window starting at its first SYN, and the counter +# resets in-kernel once the window expires. e.g. hit_count = 10 with +# hit_count_reset_secs = 3 allows 10 new connections per source IP every +# 3 seconds. Must be between 1 and 86400 (one day). +hit_count_reset_secs = 3 + +# Idle timeout for verified player connections in seconds, enforced inside the +# eBPF program: a connection's entry is removed after one to two timeout +# intervals without packets, so a returning client has to redo the handshake. +# Must be between 1 and 86400 (one day). +player_idle_timeout_secs = 60 + +# Enforce online-mode username rules during login inspection. +# true -> usernames are limited to 16 characters (Mojang online mode). 
+# false -> allow the protocol maximum (offline / cracked servers). +online_names = true + +[xdp] +# How the filter attaches to the network interface. +# "auto" -> native driver mode if the NIC supports XDP, with automatic +# fallback to generic mode otherwise. Recommended. +# "driver" -> force native driver mode: fastest, but fails on NICs without +# XDP support (including most virtual machines). +# "skb" -> force generic (skb) mode: works on any interface but is +# slower; use it when the driver misbehaves in native mode. +mode = "auto" + +# Sizes of the in-kernel connection tables below. The maps are preallocated, +# so higher limits cost kernel memory up front, not per connection. + +# Connections that have not finished the Minecraft handshake yet. The +# least-recently-used entry is evicted when the table is full, bounding how +# many half-open connections an attacker can keep alive at once. +max_pending_connections = 16384 + +# Verified player connections. New connections cannot be verified while this +# table is full, so keep it well above the expected player count. +max_player_connections = 65535 + +# Source IPs tracked by the SYN throttle. When the table is full, new SYNs +# are dropped (fail closed) until expired windows are reclaimed in-kernel. +max_throttled_ips = 65535 + +[metrics] +# Collect packet statistics inside the eBPF program. Required for any metrics +# output. Adds a small per-packet cost, so it is disabled by default. +enabled = false + +# Address to expose Prometheus metrics on (only used when enabled = true). +# Leave commented out to collect stats without starting the HTTP server. +# addr = "0.0.0.0:1999" + +# How often the in-kernel statistics are read and published, in seconds. +poll_secs = 10 + +[logging] +# Verbosity of console and file logging: +# "off", "error", "warn", "info", "debug" or "trace". +# The RUST_LOG environment variable overrides this setting at runtime. 
+level = "info" + +# The log file (xdp-loader.log) is rotated once it grows past this many +# megabytes; the 5 most recent rotated files are kept. +file_max_mb = 100 +"#; + +/// Runtime configuration for the XDP filter, grouped like the TOML file. +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct Config { + pub filter: FilterConfig, + pub xdp: XdpConfig, + pub metrics: MetricsConfig, + pub logging: LoggingConfig, +} + +/// `[filter]` — which traffic is filtered and how strictly. +/// +/// These fields are pushed into the eBPF program's `volatile const` globals at +/// load time; see `load_and_attach` in `loader/ebpf.rs`. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct FilterConfig { + /// First port of the inclusive filtered range. Maps to `START_PORT`. + pub start_port: u16, + /// Last port of the inclusive filtered range. Maps to `END_PORT`. + pub end_port: u16, + /// Max SYNs per source IP per throttle window (0 = disabled). Maps to `HIT_COUNT`. + pub hit_count: u32, + /// Throttle window length in seconds; each IP's SYN counter resets in-kernel + /// once its window expires. Maps to `HIT_COUNT_RESET_NS` (converted to ns). + pub hit_count_reset_secs: u64, + /// Idle timeout for verified connections in seconds; an entry is removed + /// in-kernel after one to two intervals without packets. Maps to + /// `PLAYER_IDLE_NS` (converted to ns). + pub player_idle_timeout_secs: u64, + /// Enforce online-mode (<= 16 char) usernames. Maps to `ONLINE_NAMES`. + pub online_names: bool, +} + +impl Default for FilterConfig { + fn default() -> Self { + Self { + start_port: 25565, + end_port: 25565, + hit_count: 10, + hit_count_reset_secs: 3, + player_idle_timeout_secs: 60, + online_names: true, + } + } +} + +/// `[xdp]` — how the program attaches and the capacity of its maps. 
+/// +/// The capacities override the placeholder `max_entries` of the corresponding +/// map in `xdp/minecraft_filter.c` at load time. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct XdpConfig { + /// XDP attach mode. + pub mode: XdpMode, + /// Capacity of `conntrack_map`: concurrent unverified (mid-handshake) + /// connections, oldest evicted when full. + pub max_pending_connections: u32, + /// Capacity of `player_connection_map`: concurrent verified connections. + pub max_player_connections: u32, + /// Capacity of `connection_throttle`: source IPs with an active throttle + /// window. + pub max_throttled_ips: u32, +} + +impl Default for XdpConfig { + fn default() -> Self { + Self { + mode: XdpMode::Auto, + max_pending_connections: 16384, + max_player_connections: 65535, + max_throttled_ips: 65535, + } + } +} + +/// XDP attach mode, the `mode` option in `[xdp]`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum XdpMode { + /// Native driver mode when the NIC supports it, automatic fallback to + /// generic mode otherwise. + Auto, + /// Force native driver mode, fail when unsupported. + Driver, + /// Force generic (skb) mode. + Skb, +} + +impl fmt::Display for XdpMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + XdpMode::Auto => "auto", + XdpMode::Driver => "driver", + XdpMode::Skb => "skb", + }) + } +} + +/// `[metrics]` — statistics collection and the Prometheus endpoint. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct MetricsConfig { + /// Collect statistics inside the eBPF program. Maps to `PROMETHEUS`. + pub enabled: bool, + /// Optional address for the Prometheus HTTP endpoint (e.g. `0.0.0.0:1999`). + pub addr: Option, + /// Seconds between reads of the in-kernel statistics map. 
+ pub poll_secs: u64, +} + +impl Default for MetricsConfig { + fn default() -> Self { + Self { + enabled: false, + addr: None, + poll_secs: 10, + } + } +} + +/// `[logging]` — console/file log verbosity and rotation. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct LoggingConfig { + /// Log verbosity, overridable at runtime via the `RUST_LOG` env variable. + pub level: LogLevel, + /// Rotate the log file once it grows past this size in megabytes. + pub file_max_mb: u64, +} + +impl Default for LoggingConfig { + fn default() -> Self { + Self { + level: LogLevel::Info, + file_max_mb: 100, + } + } +} + +/// Log verbosity, the `level` option in `[logging]`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + Off, + Error, + Warn, + Info, + Debug, + Trace, +} + +impl From for LevelFilter { + fn from(level: LogLevel) -> Self { + match level { + LogLevel::Off => LevelFilter::Off, + LogLevel::Error => LevelFilter::Error, + LogLevel::Warn => LevelFilter::Warn, + LogLevel::Info => LevelFilter::Info, + LogLevel::Debug => LevelFilter::Debug, + LogLevel::Trace => LevelFilter::Trace, + } + } +} + +impl Config { + /// Load the configuration from `path`, creating a documented default file if + /// it does not exist yet. + /// + /// Runs before logging is initialized (the log level lives in here), so the + /// "wrote defaults" notice goes to stderr directly. 
+ pub fn load(path: &Path) -> Result { + if !path.exists() { + eprintln!( + "Config file '{}' not found, writing documented defaults", + path.display() + ); + std::fs::write(path, DEFAULT_CONFIG_TOML).with_context(|| { + format!("failed to write default config to '{}'", path.display()) + })?; + } + + let contents = std::fs::read_to_string(path) + .with_context(|| format!("failed to read config file '{}'", path.display()))?; + let config: Config = toml::from_str(&contents) + .with_context(|| format!("failed to parse config file '{}'", path.display()))?; + + config.validate()?; + Ok(config) + } + + /// Reject values the eBPF program or loader cannot represent sensibly. + fn validate(&self) -> Result<()> { + self.filter.validate()?; + self.xdp.validate()?; + self.metrics.validate()?; + self.logging.validate() + } +} + +/// Bail with a uniform message unless `min <= value <= max`. +fn check_range(option: &str, value: u64, min: u64, max: u64) -> Result<()> { + if value < min || value > max { + bail!("{option} must be between {min} and {max} (got {value})"); + } + Ok(()) +} + +impl FilterConfig { + fn validate(&self) -> Result<()> { + if self.start_port == 0 { + bail!("[filter] start_port must be >= 1"); + } + if self.start_port > self.end_port { + bail!( + "[filter] start_port ({}) must be <= end_port ({})", + self.start_port, + self.end_port + ); + } + check_range( + "[filter] hit_count_reset_secs", + self.hit_count_reset_secs, + 1, + 86_400, + )?; + check_range( + "[filter] player_idle_timeout_secs", + self.player_idle_timeout_secs, + 1, + 86_400, + ) + } +} + +impl XdpConfig { + fn validate(&self) -> Result<()> { + const MAX_ENTRIES: u64 = 1 << 20; + check_range( + "[xdp] max_pending_connections", + self.max_pending_connections as u64, + 1, + MAX_ENTRIES, + )?; + check_range( + "[xdp] max_player_connections", + self.max_player_connections as u64, + 1, + MAX_ENTRIES, + )?; + check_range( + "[xdp] max_throttled_ips", + self.max_throttled_ips as u64, + 1, + MAX_ENTRIES, 
+ ) + } +} + +impl MetricsConfig { + fn validate(&self) -> Result<()> { + check_range("[metrics] poll_secs", self.poll_secs, 1, 3_600) + } +} + +impl LoggingConfig { + fn validate(&self) -> Result<()> { + check_range("[logging] file_max_mb", self.file_max_mb, 1, 10_240) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn embedded_default_matches_struct_default() { + // The documented template shipped to users must parse and agree with + // Config::default (metrics addr is commented out -> None). + let parsed: Config = toml::from_str(DEFAULT_CONFIG_TOML).expect("default toml parses"); + assert_eq!(parsed, Config::default()); + parsed.validate().expect("default config is valid"); + } + + #[test] + fn unknown_keys_are_rejected() { + let err = toml::from_str::("nonsense_key = 1").unwrap_err(); + assert!(err.to_string().contains("nonsense_key")); + // also inside a section + assert!(toml::from_str::("[filter]\nnonsense_key = 1").is_err()); + } + + #[test] + fn partial_config_falls_back_to_defaults() { + let cfg: Config = + toml::from_str("[filter]\nhit_count = 0\n\n[metrics]\nenabled = true").unwrap(); + assert_eq!(cfg.filter.hit_count, 0); // throttle disabled + assert!(cfg.metrics.enabled); + assert_eq!(cfg.filter.start_port, 25565); // default preserved + assert_eq!(cfg.xdp.mode, XdpMode::Auto); // untouched section defaulted + } + + #[test] + fn rejects_inverted_port_range() { + let cfg: Config = toml::from_str("[filter]\nstart_port = 30000\nend_port = 25565").unwrap(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn rejects_zero_start_port() { + let cfg: Config = toml::from_str("[filter]\nstart_port = 0").unwrap(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn rejects_zero_reset_window() { + let cfg: Config = toml::from_str("[filter]\nhit_count_reset_secs = 0").unwrap(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn rejects_oversized_reset_window() { + let cfg: Config = toml::from_str("[filter]\nhit_count_reset_secs = 
86401").unwrap(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn rejects_zero_idle_timeout() { + let cfg: Config = toml::from_str("[filter]\nplayer_idle_timeout_secs = 0").unwrap(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn rejects_oversized_idle_timeout() { + let cfg: Config = toml::from_str("[filter]\nplayer_idle_timeout_secs = 86401").unwrap(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn rejects_unknown_xdp_mode() { + let err = toml::from_str::("[xdp]\nmode = \"hardware\"").unwrap_err(); + assert!(err.to_string().contains("hardware")); + } + + #[test] + fn rejects_zero_map_capacity() { + for option in [ + "max_pending_connections", + "max_player_connections", + "max_throttled_ips", + ] { + let cfg: Config = toml::from_str(&format!("[xdp]\n{option} = 0")).unwrap(); + assert!(cfg.validate().is_err(), "{option} = 0 must be rejected"); + } + } + + #[test] + fn rejects_zero_poll_interval() { + let cfg: Config = toml::from_str("[metrics]\npoll_secs = 0").unwrap(); + assert!(cfg.validate().is_err()); + } + + #[test] + fn rejects_unknown_log_level() { + let err = toml::from_str::("[logging]\nlevel = \"verbose\"").unwrap_err(); + assert!(err.to_string().contains("verbose")); + } +} diff --git a/loader/ebpf.rs b/loader/ebpf.rs new file mode 100644 index 0000000..adaa0e8 --- /dev/null +++ b/loader/ebpf.rs @@ -0,0 +1,95 @@ +use anyhow::{Context, Result}; +use aya::programs::{Xdp, XdpFlags}; +use aya::{Ebpf, EbpfLoader, include_bytes_aligned}; +use log::info; + +use crate::config::{Config, XdpMode}; + +/// Loads the embedded eBPF object, applies the runtime configuration and +/// attaches the XDP program to `interface`. +/// +/// The returned handle owns the attachment: dropping it detaches the filter. +/// No userspace involvement is needed while it runs, all map cleanup +/// (throttle windows, idle player connections) happens in-kernel via +/// `bpf_timer`. 
+pub fn load_and_attach(interface: &str, config: &Config) -> Result { + let object = + include_bytes_aligned!(concat!(env!("CARGO_MANIFEST_DIR"), "/xdp/minecraft_filter.o")); + info!("Loaded BPF object ({} bytes)", object.len()); + + // Push the runtime configuration into the program's `volatile const` + // globals (BPF .rodata). Each Rust type MUST match its C declaration in + // xdp/config.h exactly, since set_global() patches size_of::() bytes at + // the symbol's offset. `must_exist = true` fails loudly if a symbol is + // missing (e.g. renamed on the C side) instead of silently ignoring the + // configured value. + let filter = &config.filter; + let prometheus: u8 = config.metrics.enabled as u8; + let online_names: u8 = filter.online_names as u8; + let start_port: u32 = filter.start_port as u32; + let end_port: u32 = filter.end_port as u32; + let hit_count: u32 = filter.hit_count; + let hit_count_reset_ns: u64 = filter.hit_count_reset_secs * 1_000_000_000; + let player_idle_ns: u64 = filter.player_idle_timeout_secs * 1_000_000_000; + + let mut ebpf = EbpfLoader::new() + .set_global("PROMETHEUS", &prometheus, true) + .set_global("ONLINE_NAMES", &online_names, true) + .set_global("START_PORT", &start_port, true) + .set_global("END_PORT", &end_port, true) + .set_global("HIT_COUNT", &hit_count, true) + .set_global("HIT_COUNT_RESET_NS", &hit_count_reset_ns, true) + .set_global("PLAYER_IDLE_NS", &player_idle_ns, true) + // Replace the placeholder map capacities baked into the object with + // the configured ones. The names must match the map definitions in + // xdp/minecraft_filter.c; unlike set_global() there is no must_exist + // flag, a mismatched name would be silently ignored. 
+ .set_max_entries("conntrack_map", config.xdp.max_pending_connections) + .set_max_entries("player_connection_map", config.xdp.max_player_connections) + .set_max_entries("connection_throttle", config.xdp.max_throttled_ips) + .load(object) + .context("failed to load BPF program")?; + + let program: &mut Xdp = ebpf + .program_mut("minecraft_filter") + .context("program 'minecraft_filter' not found")? + .try_into()?; + program.load()?; + + // Auto does its own driver -> skb fallback instead of leaving the choice + // to the kernel, so the mode that is actually active can be logged. + let (link, mode) = match config.xdp.mode { + XdpMode::Auto => match program.attach(interface, XdpFlags::DRV_MODE) { + Ok(link) => (link, XdpMode::Driver), + Err(err) => { + info!("'{interface}' does not support native XDP ({err}), using generic skb mode"); + let link = program + .attach(interface, XdpFlags::SKB_MODE) + .with_context(|| format!("failed to attach to interface '{interface}'"))?; + (link, XdpMode::Skb) + } + }, + XdpMode::Driver => { + let link = program.attach(interface, XdpFlags::DRV_MODE).with_context(|| { + format!( + "failed to attach to interface '{interface}' in native driver mode \ + (the NIC driver may not support XDP; try mode = \"skb\")" + ) + })?; + (link, XdpMode::Driver) + } + XdpMode::Skb => { + let link = program + .attach(interface, XdpFlags::SKB_MODE) + .with_context(|| format!("failed to attach to interface '{interface}' in skb mode"))?; + (link, XdpMode::Skb) + } + }; + info!("BPF program attached to interface {interface} in {mode} mode ({link:?})"); + + for (name, _) in ebpf.maps() { + info!("Found map: {name}"); + } + + Ok(ebpf) +} diff --git a/loader/logging.rs b/loader/logging.rs new file mode 100644 index 0000000..cfb24dc --- /dev/null +++ b/loader/logging.rs @@ -0,0 +1,81 @@ +use anyhow::Result; +use colored::Colorize; +use fern::colors::{Color, ColoredLevelConfig}; +use file_rotate::compression::Compression; +use file_rotate::suffix::AppendCount; 
+use file_rotate::{ContentLimit, FileRotate}; +use log::LevelFilter; + +use crate::config::LoggingConfig; + +const TIMESTAMP_FORMAT: &str = "%Y-%m-%d %H:%M:%S"; +const LOG_FILE: &str = "xdp-loader.log"; +const LOG_FILES_KEPT: usize = 5; + +/// The configured level, overridable at runtime via `RUST_LOG`. +fn level_filter(config: &LoggingConfig) -> LevelFilter { + match std::env::var("RUST_LOG") { + Ok(var) => match var.to_lowercase().as_str() { + "off" => LevelFilter::Off, + "error" => LevelFilter::Error, + "warn" => LevelFilter::Warn, + "info" => LevelFilter::Info, + "debug" => LevelFilter::Debug, + "trace" => LevelFilter::Trace, + _ => config.level.into(), + }, + Err(_) => config.level.into(), + } +} + +/// Initializes logging to stdout (colored) and to a rotating log file, with +/// level and rotation size taken from `[logging]` in the config. +pub fn init(config: &LoggingConfig) -> Result<()> { + let colors = ColoredLevelConfig::new() + .debug(Color::Magenta) + .info(Color::Green) + .warn(Color::Yellow) + .error(Color::Red); + + let console = fern::Dispatch::new() + .format(move |out, message, record| { + out.finish(format_args!( + "{} {}{}{} {}", + chrono::Local::now() + .format(TIMESTAMP_FORMAT) + .to_string() + .white(), + "[".bright_black(), + colors.color(record.level()), + "]".bright_black(), + message + )) + }) + .chain(std::io::stdout()); + + let file = fern::Dispatch::new() + .format(|out, message, record| { + out.finish(format_args!( + "{} [{}] {}", + chrono::Local::now().format(TIMESTAMP_FORMAT), + record.level(), + message + )) + }) + .chain(Box::new(FileRotate::new( + LOG_FILE, + AppendCount::new(LOG_FILES_KEPT), + ContentLimit::Bytes(config.file_max_mb as usize * 1024 * 1024), + Compression::None, + #[cfg(unix)] + None, + )) as Box); + + fern::Dispatch::new() + .level(level_filter(config)) + .chain(console) + .chain(file) + .apply()?; + + Ok(()) +} diff --git a/loader/main.rs b/loader/main.rs new file mode 100644 index 0000000..10217d7 --- 
/dev/null +++ b/loader/main.rs @@ -0,0 +1,137 @@ +mod config; +mod ebpf; +mod logging; +mod metrics; +mod shutdown; + +use std::path::Path; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use log::{error, info}; + +use config::Config; +use shutdown::Shutdown; + +const USAGE: &str = "\ +Usage: xdp-loader [OPTIONS] + +Arguments: + Network interface to attach to + +Options: + -c, --config Path to the TOML configuration file, created with + defaults if missing [default: config.toml] + --license Print license information + -h, --help Print help + -V, --version Print version"; + +#[derive(Debug)] +struct Args { + /// Network interface to attach to, required unless --license is given. + interface: Option, + /// Path to the TOML configuration file. + config: String, + /// Print license information instead of running. + license: bool, +} + +/// Hand-rolled argument parsing (a CLI library would double the binary size +/// for three options). Prints help/version/errors and exits where appropriate. +fn parse_args() -> Args { + let mut args = Args { + interface: None, + config: "config.toml".into(), + license: false, + }; + + let mut argv = std::env::args().skip(1); + let invalid = |message: String| -> ! 
{ + eprintln!("error: {message}\n\n{USAGE}"); + std::process::exit(2); + }; + while let Some(arg) = argv.next() { + match arg.as_str() { + "-h" | "--help" => { + println!("{USAGE}"); + std::process::exit(0); + } + "-V" | "--version" => { + println!("xdp-loader {}", env!("CARGO_PKG_VERSION")); + std::process::exit(0); + } + "--license" => args.license = true, + "-c" | "--config" => match argv.next() { + Some(value) => args.config = value, + None => invalid(format!("missing value for '{arg} '")), + }, + _ => match arg.strip_prefix("--config=") { + Some(value) => args.config = value.into(), + None if arg.starts_with('-') => invalid(format!("unexpected option '{arg}'")), + None if args.interface.is_none() => args.interface = Some(arg), + None => invalid(format!("unexpected argument '{arg}'")), + }, + } + } + + if args.interface.is_none() && !args.license { + invalid("missing required argument ".into()); + } + args +} + +fn main() { + let args = parse_args(); + + if args.license { + println!(include_str!("../LICENSE")); + return; + } + + // the log level lives in the config, so it must be loaded before logging + // is up; until then errors can only go to stderr directly + let config = match Config::load(Path::new(&args.config)) { + Ok(config) => config, + Err(e) => { + eprintln!("failed to load config '{}': {e:#}", args.config); + std::process::exit(1); + } + }; + + logging::init(&config.logging).expect("Failed to setup logger"); + info!("Loading minecraft xdp filter v3 by Outfluencer..."); + info!("Loaded configuration: {config:?}"); + + let shutdown = Arc::new(Shutdown::new()); + shutdown::trigger_on_termination_signal(shutdown.clone()); + + if let Err(e) = run(&args, &config, &shutdown) { + error!("{e:#}"); + } + + shutdown.trigger(); + info!("Good bye!"); +} + +/// Attaches the XDP filter and keeps it alive until shutdown is triggered. +/// Dropping the `Ebpf` handle on return detaches the filter again. 
+fn run(args: &Args, config: &Config, shutdown: &Arc) -> Result<()> { + let interface = args + .interface + .as_deref() + .context("interface is required unless --license is specified")?; + + // keep the handle alive until the end of this function: dropping it + // detaches the XDP program + let mut ebpf = ebpf::load_and_attach(interface, config)?; + let stats_thread = metrics::start(&mut ebpf, config, shutdown)?; + + shutdown.wait(); + + if let Some(handle) = stats_thread { + handle + .join() + .map_err(|e| anyhow::anyhow!("track-stats thread panicked: {e:?}"))?; + } + Ok(()) +} diff --git a/loader/metrics.rs b/loader/metrics.rs new file mode 100644 index 0000000..57a3ec6 --- /dev/null +++ b/loader/metrics.rs @@ -0,0 +1,166 @@ +use std::sync::{Arc, LazyLock}; +use std::thread::{self, JoinHandle}; +use std::time::Duration; + +use anyhow::{Context, Result}; +use aya::maps::{MapData, PerCpuArray}; +use aya::{Ebpf, Pod}; +use log::{debug, error, info}; +use prometheus::{Encoder, IntCounter, TextEncoder, register_int_counter}; + +use crate::config::Config; +use crate::shutdown::Shutdown; + +/// Defines the userspace mirror of `struct statistics` (see xdp/stats.h) +/// together with one Prometheus counter per field, so the struct layout, the +/// per-cpu summing and the published metrics all stay in sync from a single +/// field list. Field order and types must match the C struct exactly. +macro_rules! statistics { + ($($field:ident => $metric:literal, $help:literal;)+) => { + #[repr(C)] + #[derive(Copy, Clone, Debug, Default)] + pub struct Statistics { + $(pub $field: u64,)+ + } + + // SAFETY: repr(C) struct of plain u64 fields without padding, valid + // for any bit pattern read from the map. 
+ unsafe impl Pod for Statistics {} + + impl Statistics { + fn add(&mut self, other: &Statistics) { + $(self.$field += other.$field;)+ + } + } + + struct Counters { + $($field: IntCounter,)+ + } + + impl Counters { + fn new() -> Self { + Self { + $($field: register_int_counter!($metric, $help).unwrap(),)+ + } + } + + /// Publishes cumulative map totals as counter increments. The + /// unpinned stats map and this process always die together, so + /// publishing the delta since the last poll keeps proper counter + /// semantics; on restart both reset together, which Prometheus + /// handles as a regular counter reset. + fn publish(&self, total: &Statistics) { + $(self.$field.inc_by(total.$field.saturating_sub(self.$field.get()));)+ + } + } + }; +} + +statistics! { + verified => "minecraft_verified_connections", "Total verified connections"; + dropped_packets => "minecraft_dropped_packets", "Total dropped packets"; + state_switches => "minecraft_state_switches", "Total state switches"; + drop_connection => "minecraft_dropped_connections", "Total dropped connections"; + syn => "minecraft_syn_packets", "Total SYN packets"; + tcp_bypass => "minecraft_tcp_bypass", "Total TCP bypass attempts"; + incoming_bytes => "minecraft_incoming_bytes", "Total incoming bytes"; + dropped_bytes => "minecraft_dropped_bytes", "Total dropped bytes"; +} + +// compile-time layout check, mirrors the _Static_assert in xdp/stats.h +const _: () = assert!(std::mem::size_of::() == 64); + +static COUNTERS: LazyLock = LazyLock::new(Counters::new); + +/// Starts the metrics machinery if enabled in the config: takes ownership of +/// the stats map, spawns the polling thread and (when an address is +/// configured) the HTTP endpoint. Returns the polling thread's handle. 
+pub fn start( + ebpf: &mut Ebpf, + config: &Config, + shutdown: &Arc, +) -> Result>> { + if !config.metrics.enabled { + return Ok(None); + } + + let map = ebpf + .take_map("stats_map") + .context("can't take map 'stats_map'")?; + let stats = PerCpuArray::try_from(map)?; + + match &config.metrics.addr { + Some(addr) => serve_http(addr.clone()), + None => info!("Metrics collection enabled but no addr set; HTTP endpoint disabled"), + } + + let poll_interval = Duration::from_secs(config.metrics.poll_secs); + let shutdown = shutdown.clone(); + let handle = thread::Builder::new() + .name("track-stats".into()) + .spawn(move || poll_loop(stats, shutdown, poll_interval))?; + Ok(Some(handle)) +} + +/// Sums the per-cpu slices of the stats map and publishes the totals every +/// `poll_interval` until shutdown (or a map read error) ends the loop. +fn poll_loop( + stats: PerCpuArray, + shutdown: Arc, + poll_interval: Duration, +) { + loop { + match stats.get(&0, 0) { + Ok(per_cpu) => { + let mut total = Statistics::default(); + for cpu_stats in per_cpu.iter() { + total.add(cpu_stats); + } + debug!("Stats: {total:?}"); + COUNTERS.publish(&total); + } + Err(e) => { + error!("Failed to read stats map: {e}"); + shutdown.trigger(); + return; + } + } + if !shutdown.sleep(poll_interval) { + return; + } + } +} + +/// Serves the Prometheus text endpoint on `addr` from its own thread. 
+fn serve_http(addr: String) { + thread::spawn(move || { + let server = match tiny_http::Server::http(&addr) { + Ok(server) => server, + Err(e) => { + error!("Failed to start metrics server: {e}"); + return; + } + }; + info!("Prometheus metrics server running on {addr}/metrics"); + for request in server.incoming_requests() { + if request.url() != "/metrics" { + let _ = request.respond(tiny_http::Response::empty(404)); + continue; + } + debug!("Received metrics request from {:?}", request.remote_addr()); + let mut buffer = vec![]; + if let Err(e) = TextEncoder::new().encode(&prometheus::gather(), &mut buffer) { + error!("Failed to encode metrics: {e}"); + continue; + } + let response = tiny_http::Response::from_data(buffer).with_header( + tiny_http::Header::from_bytes( + &b"Content-Type"[..], + &b"text/plain; version=0.0.4; charset=utf-8"[..], + ) + .unwrap(), + ); + let _ = request.respond(response); + } + }); +} diff --git a/loader/shutdown.rs b/loader/shutdown.rs new file mode 100644 index 0000000..a792db7 --- /dev/null +++ b/loader/shutdown.rs @@ -0,0 +1,85 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Condvar, Mutex, PoisonError}; +use std::thread; +use std::time::Duration; + +use log::{info, warn}; +use signal_hook::consts::TERM_SIGNALS; +use signal_hook::iterator::Signals; + +/// Coordinates shutdown between the main thread, the signal handler and the +/// stats thread: a flag that can only flip to "stopped", plus a condvar so +/// blocked threads wake up immediately when that happens. +pub struct Shutdown { + running: AtomicBool, + lock: Mutex<()>, + wakeup: Condvar, +} + +impl Shutdown { + pub fn new() -> Self { + Self { + running: AtomicBool::new(true), + lock: Mutex::new(()), + wakeup: Condvar::new(), + } + } + + pub fn is_running(&self) -> bool { + self.running.load(Ordering::SeqCst) + } + + /// Flips to "stopped" (idempotent) and wakes every waiting thread. 
+ pub fn trigger(&self) { + if self.running.swap(false, Ordering::SeqCst) { + info!("Shutting down..."); + } + // taking the lock orders this notify after any in-progress + // is_running check, so no waiter can miss the wakeup + let _guard = self.lock.lock().unwrap_or_else(PoisonError::into_inner); + self.wakeup.notify_all(); + } + + /// Blocks until [`Shutdown::trigger`] is called. + pub fn wait(&self) { + let mut guard = self.lock.lock().unwrap_or_else(PoisonError::into_inner); + while self.is_running() { + guard = self + .wakeup + .wait(guard) + .unwrap_or_else(PoisonError::into_inner); + } + } + + /// Sleeps for at most `timeout` (woken early by [`Shutdown::trigger`]) + /// and returns whether the process is still running afterwards. + pub fn sleep(&self, timeout: Duration) -> bool { + let guard = self.lock.lock().unwrap_or_else(PoisonError::into_inner); + if !self.is_running() { + return false; + } + drop( + self.wakeup + .wait_timeout(guard, timeout) + .unwrap_or_else(PoisonError::into_inner), + ); + self.is_running() + } +} + +impl Default for Shutdown { + fn default() -> Self { + Self::new() + } +} + +/// Spawns a thread that triggers `shutdown` on the first termination signal. 
+pub fn trigger_on_termination_signal(shutdown: Arc) { + let mut signals = Signals::new(TERM_SIGNALS).expect("Couldn't register signals"); + thread::spawn(move || { + if let Some(signal) = signals.forever().next() { + warn!("Received termination signal: {signal}"); + shutdown.trigger(); + } + }); +} diff --git a/src/common.rs b/src/common.rs deleted file mode 100644 index 3bd4e37..0000000 --- a/src/common.rs +++ /dev/null @@ -1,71 +0,0 @@ -use aya::Pod; -use std::hash::Hash; - -#[repr(C)] -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Ipv4AddrImpl { - pub data: u32, -} -unsafe impl Pod for Ipv4AddrImpl {} -const _: () = assert!(std::mem::size_of::() == 4); - -impl std::fmt::Display for Ipv4AddrImpl { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "[{}]", network_address_to_string(self.data)) - } -} - -/// Equivalent to `struct ipv4_flow_key` -#[repr(C)] -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] -pub struct Ipv4FlowKey { - pub src_ip: u32, - pub dst_ip: u32, - pub src_port: u16, - pub dst_port: u16, -} - -unsafe impl Pod for Ipv4FlowKey {} - -// Compile-time check: size == 12 bytes -const _: () = assert!(std::mem::size_of::() == 12); - -/// Equivalent to `struct statistics` -#[repr(C)] -#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Statistics { - pub verified: u64, - pub dropped_packets: u64, - pub state_switches: u64, - pub drop_connection: u64, - pub syn: u64, - pub tcp_bypass: u64, - pub incoming_bytes: u64, - pub dropped_bytes: u64, -} - -unsafe impl Pod for Statistics {} - -// Compile-time check: size == 64 bytes -const _: () = assert!(std::mem::size_of::() == 64); - -pub fn network_address_to_string(ip: u32) -> String { - std::net::Ipv4Addr::from(ip.swap_bytes()).to_string() -} - -pub fn network_port_to_normal(port: u16) -> u16 { - port.swap_bytes() -} - -impl std::fmt::Display for Ipv4FlowKey { - fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "[{}:{} -> {}:{}]", - network_address_to_string(self.src_ip), - network_port_to_normal(self.src_port), - network_address_to_string(self.dst_ip), - network_port_to_normal(self.dst_port) - ) - } -} diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 34f14d8..0000000 --- a/src/main.rs +++ /dev/null @@ -1,543 +0,0 @@ -use crate::common::Ipv4AddrImpl; -use crate::mapimpl::XdpMapAbstraction; -use anyhow::Context; -use anyhow::Result; -use aya::{ - Ebpf, include_bytes_aligned, - maps::{HashMap, MapData, PerCpuArray, PerCpuHashMap, PerCpuValues}, - programs::{Xdp, XdpFlags}, -}; -use clap::Parser; -use colored::Colorize; -use common::{Ipv4FlowKey, Statistics}; -use fern::colors::Color; -use file_rotate::{ContentLimit, FileRotate, compression::Compression, suffix::AppendCount}; -use lazy_static::lazy_static; -use log::LevelFilter; -use log::debug; -use log::warn; -use log::{error, info}; -#[cfg(prometheus_metrics)] -use prometheus::{IntGauge, register_int_gauge}; -use signal_hook::consts::TERM_SIGNALS; -use signal_hook::iterator::Signals; -use std::fmt::Display; -use std::{ - env, - sync::{ - Arc, Condvar, Mutex, - atomic::{AtomicBool, Ordering}, - }, - thread, - time::Duration, -}; - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Network interface to attach to - #[arg(required_unless_present = "license")] - interface: Option, - - /// Print license information - #[arg(long, action)] - license: bool, - - /// Address to bind the metrics server to - #[cfg(prometheus_metrics)] - #[arg(long)] - metrics_addr: Option, -} - -mod common; -mod mapimpl; - -const OLD_CONNECTION_TIMEOUT: u64 = 60; // every 60 seconds -const THROTTLE_CLEAR_CYCLE: u64 = 3; // every 3 seconds -#[cfg(prometheus_metrics)] -const STATS_TRACKING_CYCLE: u64 = 10; // every 10 seconds - -#[cfg(prometheus_metrics)] -lazy_static! 
{ - static ref INCOMING_BYTES: IntGauge = - register_int_gauge!("minecraft_incoming_bytes", "Total incoming bytes").unwrap(); - static ref DROPPED_BYTES: IntGauge = - register_int_gauge!("minecraft_dropped_bytes", "Total dropped bytes").unwrap(); - static ref VERIFIED: IntGauge = register_int_gauge!( - "minecraft_verified_connections", - "Total verified connections" - ) - .unwrap(); - static ref DROPPED_PACKETS: IntGauge = - register_int_gauge!("minecraft_dropped_packets", "Total dropped packets").unwrap(); - static ref STATE_SWITCHES: IntGauge = - register_int_gauge!("minecraft_state_switches", "Total state switches").unwrap(); - static ref DROP_CONNECTION: IntGauge = - register_int_gauge!("minecraft_dropped_connections", "Total dropped connections").unwrap(); - static ref SYN: IntGauge = - register_int_gauge!("minecraft_syn_packets", "Total SYN packets").unwrap(); - static ref TCP_BYPASS: IntGauge = - register_int_gauge!("minecraft_tcp_bypass", "Total TCP bypass attempts").unwrap(); -} - -fn setup_logger() -> Result<(), anyhow::Error> { - let colors = fern::colors::ColoredLevelConfig::new() - .debug(Color::Magenta) - .info(Color::Green) - .warn(Color::Yellow) - .error(Color::Red); - - let level_filter = match std::env::var("RUST_LOG") { - Ok(var) => match var.to_lowercase().as_str() { - "off" => LevelFilter::Off, - "error" => LevelFilter::Error, - "warn" => LevelFilter::Warn, - "info" => LevelFilter::Info, - "debug" => LevelFilter::Debug, - "trace" => LevelFilter::Trace, - _ => LevelFilter::Info, - }, - Err(_) => { - #[cfg(debug_assertions)] - { - LevelFilter::Debug - } - #[cfg(not(debug_assertions))] - { - LevelFilter::Info - } - } - }; - - let console_dispatch = fern::Dispatch::new() - .format(move |out, message, record| { - out.finish(format_args!( - "{} {}{}{} {}", - chrono::Local::now() - .format("%Y-%m-%d %H:%M:%S") - .to_string() - .white(), - "[".bright_black(), - colors.color(record.level()), - "]".bright_black(), - message - )) - }) - 
.chain(std::io::stdout()); - - let file_dispatch = fern::Dispatch::new() - .format(|out, message, record| { - out.finish(format_args!( - "{} [{}] {}", - chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), - record.level(), - message - )) - }) - .chain(Box::new(FileRotate::new( - "xdp-loader.log", - AppendCount::new(5), - ContentLimit::Bytes(100 * 1024 * 1024), // 100 MB - Compression::None, - #[cfg(unix)] - None, - )) as Box); - - fern::Dispatch::new() - .level(level_filter) - .chain(console_dispatch) - .chain(file_dispatch) - .apply()?; - - Ok(()) -} - -fn shutdown(running: Arc, condvar: Arc) { - if running.load(Ordering::SeqCst) { - info!("Shutting down..."); - running.store(false, Ordering::SeqCst); - condvar.notify_all(); - } -} - -fn main() { - let args = Args::parse(); - - if std::env::var("RUST_LOG").is_err() { - unsafe { - #[cfg(debug_assertions)] - std::env::set_var("RUST_LOG", "debug"); - #[cfg(not(debug_assertions))] - std::env::set_var("RUST_LOG", "info"); - } - } - - if args.license { - println!(include_str!("../LICENSE")); - return; - } - setup_logger().expect("Failed to setup logger"); - - info!("Loading minecraft xdp filter v2.1 by Outfluencer..."); - - let running = Arc::new(AtomicBool::new(true)); - let condvar = Arc::new(Condvar::new()); - - start_shutdown_hook(running.clone(), condvar.clone()); - - #[cfg(prometheus_metrics)] - if let Some(addr) = args.metrics_addr { - start_metrics_server(addr); - } - - let mut epbf: Option = None; - if let Some(interface) = args.interface { - match load(&interface, running.clone(), condvar.clone()) { - Err(e) => { - error!("Failed to load BPF program: {}", e); - } - Ok(value) => { - epbf = Some(value); - } - } - } else { - error!("Interface is required unless --license is specified"); - } - - shutdown(running, condvar); - drop(epbf); - - info!("Good bye!"); -} - -fn start_shutdown_hook(arc: Arc, condvar: Arc) { - let mut signals = Signals::new(TERM_SIGNALS).expect("Couldn't register signals"); - 
thread::spawn(move || { - for signal in signals.forever() { - warn!("Received termination signal: {signal}"); - shutdown(arc, condvar); - break; // Stop on first termination signal - } - }); -} -#[cfg(prometheus_metrics)] -fn start_metrics_server(addr: String) { - thread::spawn(move || { - let server = match tiny_http::Server::http(&addr) { - Ok(s) => s, - Err(e) => { - error!("Failed to start metrics server: {}", e); - return; - } - }; - info!("Prometheus metrics server running on {}/metrics", addr); - for request in server.incoming_requests() { - if request.url() == "/metrics" { - use prometheus::Encoder; - - debug!("Received metrics request from {:?}", request.remote_addr()); - let encoder = prometheus::TextEncoder::new(); - let metric_families = prometheus::gather(); - let mut buffer = vec![]; - if let Err(e) = encoder.encode(&metric_families, &mut buffer) { - error!("Failed to encode metrics: {}", e); - continue; - } - let response = tiny_http::Response::from_data(buffer) - .with_header(tiny_http::Header::from_bytes( - &b"Content-Type"[..], - &b"text/plain; version=0.0.4; charset=utf-8"[..], - ).unwrap()); - let _ = request.respond(response); - } else { - let _ = request.respond(tiny_http::Response::empty(404)); - } - } - }); -} - -fn load( - interface: &str, - running: Arc, - condvar: Arc, -) -> Result { - let data = include_bytes_aligned!(concat!(env!("CARGO_MANIFEST_DIR"), "/c/minecraft_filter.o")); - info!("Loaded BPF program (size: {})", data.len()); - - let mut ebpf = Ebpf::load(data)?; - - let programm: &mut Xdp = ebpf - .program_mut("minecraft_filter") - .ok_or_else(|| anyhow::anyhow!("Program 'minecraft_filter' not found"))? 
- .try_into()?; - programm.load()?; - - let result = programm.attach(interface, XdpFlags::empty())?; - info!( - "BPF program attached to interface: {} ({:?})", - interface, result - ); - - for (name, _) in ebpf.maps() { - info!("Found map: {}", name); - } - - let player_connection_map = { - let map = ebpf - .take_map("player_connection_map") - .ok_or_else(|| anyhow::anyhow!("Can't take map 'player_connection_map'"))?; - #[cfg(ip_and_port_per_cpu)] - { - info!("Using PerCpuHashMap for player_connection_map"); - PerCpuHashMap::::try_from(map) - .context("try to get player_connection_map PerCpuHashMap")? - } - #[cfg(not(ip_and_port_per_cpu))] - { - info!("Using HashMap for player_connection_map"); - HashMap::::try_from(map) - .context("try to get player_connection_map HashMap")? - } - }; - let player_connection_map_ref = Arc::new(Mutex::new(player_connection_map)); - - let connection_throttle = { - let map = ebpf - .take_map("connection_throttle") - .ok_or_else(|| anyhow::anyhow!("Can't take map 'connection_throttle'"))?; - - #[cfg(ip_per_cpu)] - { - info!("Using PerCpuHashMap for connection_throttle"); - PerCpuHashMap::::try_from(map) - .context("try to get connection_throttle PerCpuHashMap")? - } - #[cfg(not(ip_per_cpu))] - { - info!("Using HashMap for connection_throttle"); - HashMap::::try_from(map) - .context("try to get connection_throttle HashMap")? - } - }; - let connection_throttle_ref = Arc::new(Mutex::new(connection_throttle)); - - #[cfg(prometheus_metrics)] - let stats = { - let map = ebpf - .take_map("stats_map") - .ok_or_else(|| anyhow::anyhow!("Can't take map 'stats_map'"))?; - PerCpuArray::::try_from(map)? 
- }; - - #[cfg(prometheus_metrics)] - let stats_ref: Arc>> = Arc::new(Mutex::new(stats)); - - let handle1 = spawn_old_connection_clear( - "clear-old", - running.clone(), - condvar.clone(), - player_connection_map_ref, - )?; - let handle2 = spawn_connection_throttle_clear( - "clear-throttle", - running.clone(), - condvar.clone(), - connection_throttle_ref, - )?; - - #[cfg(prometheus_metrics)] - let handle4 = spawn_stats_thread( - "track-stats", - running.clone(), - condvar.clone(), - stats_ref.clone(), - )?; - - let _ = handle1 - .join() - .map_err(|e| anyhow::anyhow!("clear-old thread panicked: {:?}", e))?; - let _ = handle2 - .join() - .map_err(|e| anyhow::anyhow!("clear-throttle thread panicked: {:?}", e))?; - #[cfg(prometheus_metrics)] - let _ = handle4 - .join() - .map_err(|e| anyhow::anyhow!("track-stats thread panicked: {:?}", e))?; - - Ok(ebpf) -} - -#[cfg(prometheus_metrics)] -fn spawn_stats_thread( - name: &'static str, - running: Arc, - condvar: Arc, - stats_ref: Arc>>, -) -> Result, anyhow::Error> { - thread::Builder::new() - .name(name.into()) - .spawn(move || { - if let Err(e) = track_stats(running.clone(), condvar.clone(), stats_ref) { - error!("Failed to track stats: {:?}", e); - shutdown(running, condvar); - } - }) - .map_err(|e| e.into()) -} - -#[cfg(prometheus_metrics)] -fn track_stats( - running: Arc, - condvar: Arc, - stats_ref: Arc>>, -) -> Result<(), anyhow::Error> { - let dummy_mutex = Mutex::new(()); - while running.load(Ordering::SeqCst) { - let stats = stats_ref - .lock() - .map_err(|e| anyhow::anyhow!("Mutex poisoned: {}", e))?; - - let values = stats.get(&0, 0)?; - let mut total = Statistics::default(); - for cpu_stat in values.iter() { - total.incoming_bytes += cpu_stat.incoming_bytes; - total.dropped_bytes += cpu_stat.dropped_bytes; - total.verified += cpu_stat.verified; - total.dropped_packets += cpu_stat.dropped_packets; - total.state_switches += cpu_stat.state_switches; - total.drop_connection += cpu_stat.drop_connection; - 
total.syn += cpu_stat.syn; - total.tcp_bypass += cpu_stat.tcp_bypass; - } - debug!( - "Stats: Incoming: {} bytes, Dropped: {} bytes, Packets Dropped: {}, Verified: {}, Syn: {}, Bypass: {}, State Switches: {}, Drop Conn: {}", - total.incoming_bytes, - total.dropped_bytes, - total.dropped_packets, - total.verified, - total.syn, - total.tcp_bypass, - total.state_switches, - total.drop_connection, - ); - - // Update Prometheus metrics - INCOMING_BYTES.set(total.incoming_bytes as i64); - DROPPED_BYTES.set(total.dropped_bytes as i64); - VERIFIED.set(total.verified as i64); - DROPPED_PACKETS.set(total.dropped_packets as i64); - STATE_SWITCHES.set(total.state_switches as i64); - DROP_CONNECTION.set(total.drop_connection as i64); - SYN.set(total.syn as i64); - TCP_BYPASS.set(total.tcp_bypass as i64); - drop(stats); // release lock before waiting - - let guard = dummy_mutex - .lock() - .map_err(|e| anyhow::anyhow!("Dummy Mutex poisoned: {}", e))?; - let _ = condvar - .wait_timeout(guard, Duration::from_secs(STATS_TRACKING_CYCLE)) - .map_err(|e| anyhow::anyhow!("condvar wait_timeout poisoned: {}", e))?; - } - Ok(()) -} - -fn spawn_connection_throttle_clear( - name: &'static str, - running: Arc, - condvar: Arc, - connection_throttle_ref: Arc>, -) -> Result, anyhow::Error> -where - M: mapimpl::XdpMapAbstraction + Send + 'static, -{ - thread::Builder::new() - .name(name.into()) - .spawn(move || { - if let Err(e) = - connection_throttle_clear(running.clone(), condvar.clone(), connection_throttle_ref) - { - error!("Failed to clear connection throttles: {:?}", e); - shutdown(running, condvar); - } - }) - .map_err(|e| e.into()) -} - -fn connection_throttle_clear( - running: Arc, - condvar: Arc, - connection_throttle_ref: Arc>, -) -> Result<(), anyhow::Error> -where - M: XdpMapAbstraction + Send + 'static, -{ - let dummy_mutex = Mutex::new(()); - while running.load(Ordering::SeqCst) { - connection_throttle_ref - .lock() - .map_err(|e| anyhow::anyhow!("Mutex poisoned: {}", e))? 
- .clear()?; - let guard = dummy_mutex - .lock() - .map_err(|e| anyhow::anyhow!("Dummy Mutex poisoned: {}", e))?; - let _ = condvar - .wait_timeout(guard, Duration::from_secs(THROTTLE_CLEAR_CYCLE)) - .map_err(|e| anyhow::anyhow!("condvar wait_timeout poisoned: {}", e))?; - } - Ok(()) -} - -fn spawn_old_connection_clear( - name: &'static str, - running: Arc, - condvar: Arc, - player_connection_map_ref: Arc>, -) -> Result, anyhow::Error> -where - M: XdpMapAbstraction + Send + 'static, -{ - thread::Builder::new() - .name(name.into()) - .spawn(move || { - if let Err(e) = - clear_old_connections(running.clone(), condvar.clone(), player_connection_map_ref) - { - error!("Failed to clear old connections: {:?}", e); - shutdown(running, condvar); - } - }) - .map_err(|e| e.into()) -} - -fn clear_old_connections( - running: Arc, - condvar: Arc, - player_connection_map_ref: Arc>, -) -> Result<(), anyhow::Error> -where - M: XdpMapAbstraction + Send + 'static, -{ - let dummy_mutex = Mutex::new(()); - let mut last_seen: std::collections::HashMap = std::collections::HashMap::new(); - while running.load(Ordering::SeqCst) { - let mut current_snapshot = std::collections::HashMap::new(); - player_connection_map_ref - .lock() - .map_err(|e| anyhow::anyhow!("Mutex poisoned: {}", e))? 
- .remove_if(|key, counter| { - let stale = last_seen.get(key).is_some_and(|prev| *prev == *counter); - current_snapshot.insert(*key, *counter); - stale - })?; - last_seen = current_snapshot; - let guard = dummy_mutex - .lock() - .map_err(|e| anyhow::anyhow!("Mutex poisoned: {}", e))?; - let _ = condvar - .wait_timeout(guard, Duration::from_secs(OLD_CONNECTION_TIMEOUT)) - .map_err(|e| anyhow::anyhow!("condvar wait_timeout poisoned: {}", e))?; - } - Ok(()) -} - diff --git a/src/mapimpl.rs b/src/mapimpl.rs deleted file mode 100644 index 3256b32..0000000 --- a/src/mapimpl.rs +++ /dev/null @@ -1,85 +0,0 @@ -use aya::{ - Pod, - maps::{HashMap, MapData, PerCpuHashMap}, -}; -use log::debug; -use std::{fmt::Display, result::Result}; - -pub trait XdpMapAbstraction { - - fn clear(&mut self) -> Result<(), anyhow::Error>; - - fn remove_if bool>(&mut self, predicate: F) -> Result<(), anyhow::Error>; -} - -impl XdpMapAbstraction for HashMap { - - fn clear(&mut self) -> Result<(), anyhow::Error> { - self.remove_if(|_, _| true) - } - - fn remove_if(&mut self, mut predicate: F) -> Result<(), anyhow::Error> - where - F: FnMut(&K, &V) -> bool, - { - let mut keys = Vec::new(); - for item in self.iter() { - let (k, v) = item?; - if predicate(&k, &v) { - debug!("Removing {}: {} from map..", k, v); - keys.push(k); - } - } - - for k in keys { - self.remove(&k)?; - } - Ok(()) - } -} - -impl XdpMapAbstraction - for PerCpuHashMap -{ - fn clear(&mut self) -> Result<(), anyhow::Error> { - self.remove_if(|_, _| true) - } - - fn remove_if(&mut self, mut predicate: F) -> Result<(), anyhow::Error> - where - F: FnMut(&K, &V) -> bool, - { - let mut keys_to_remove = Vec::new(); - - for item in self.iter() { - let (k, values) = item?; - // With 4-tuple RSS, only ONE CPU should have a non-zero value - if let Some(val) = find_active_value(&values) { - if predicate(&k, &val) { - debug!("Removing {}: {} from per-cpu map..", k, val); - keys_to_remove.push(k); - } - } - } - - for k in keys_to_remove { - 
self.remove(&k)?; - } - Ok(()) - } -} - - -// Helper to check if a value is all zeros (empty/unused slot) -#[inline] -fn is_zero(v: &V) -> bool { - let bytes = unsafe { - std::slice::from_raw_parts(v as *const V as *const u8, std::mem::size_of::()) - }; - bytes.iter().all(|&b| b == 0) -} - -#[inline] -fn find_active_value(values: &aya::maps::PerCpuValues) -> Option { - values.iter().find(|v| !is_zero(*v)).copied() -} diff --git a/tests/c_unit_tests.rs b/tests/c_unit_tests.rs new file mode 100644 index 0000000..e24cb93 --- /dev/null +++ b/tests/c_unit_tests.rs @@ -0,0 +1,63 @@ +use std::path::Path; +use std::process::Command; + +/// Compiles `xdp/tests/protocol_test.c` natively and runs it. +/// +/// This exercises the exact parsing code the eBPF program is built from +/// (varint reader, packet inspectors, bounds-check macros) in userspace, +/// where it can be sanitized. clang is already required to build the +/// project at all, so this adds no new dependency. +#[test] +fn c_parser_unit_tests() { + let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR")); + let binary = Path::new(env!("CARGO_TARGET_TMPDIR")).join("protocol_test"); + let binary = binary.to_str().expect("tmpdir path is valid utf-8"); + + // ASan catches reads past the exact-size test buffers; the alignment + // check is disabled because the parser does unaligned reads on purpose + // (network data), like the kernel does. 
+ let sanitizer_flags = [ + "-fsanitize=address,undefined", + "-fno-sanitize=alignment", + "-fno-sanitize-recover=all", + ]; + let base_flags = [ + "-Wall", + "-Wextra", + "-O2", + "-g", + "-fno-strict-aliasing", + "xdp/tests/protocol_test.c", + "-o", + binary, + ]; + + let compile = |with_sanitizers: bool| { + let mut cmd = Command::new("clang"); + cmd.current_dir(manifest_dir); + if with_sanitizers { + cmd.args(sanitizer_flags); + } + cmd.args(base_flags); + cmd.output().expect("failed to run clang") + }; + + // fall back to a plain build where the sanitizer runtime is unavailable + let mut output = compile(true); + if !output.status.success() { + println!("note: sanitizer build failed, retrying without sanitizers"); + output = compile(false); + } + assert!( + output.status.success(), + "clang failed to compile the C unit tests:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + + let run = Command::new(binary) + .output() + .expect("failed to run the C unit test binary"); + println!("{}", String::from_utf8_lossy(&run.stdout)); + eprintln!("{}", String::from_utf8_lossy(&run.stderr)); + assert!(run.status.success(), "C unit tests reported failures"); +} diff --git a/xdp/common.h b/xdp/common.h new file mode 100644 index 0000000..258f3f5 --- /dev/null +++ b/xdp/common.h @@ -0,0 +1,133 @@ +#ifndef COMMON_H +#define COMMON_H + +#include + +// compiler barrier: prevents clang from merging or reordering the pointer +// arithmetic around bounds checks, which would make the verifier lose track +// of the checked range +#ifndef barrier_var +#define barrier_var(var) asm volatile("" : "+r"(var)) +#endif + +// maximum amount of out-of-order/retransmitted packets before a tracked +// connection is dropped entirely +#define MAX_OUT_OF_ORDER 4 + +/* + * Connection states stored in initial_state.state, plus pseudo states only + * returned by inspect_handshake() (RECEIVED_LEGACY_PING and the DIRECT_READ_* + * values, which signal that more protocol data follows in the same 
packet). + * STATE_INVALID doubles as the generic "parse failed" return value of all + * inspectors. + */ +enum connection_state +{ + STATE_INVALID = 0, + AWAIT_ACK = 1, + AWAIT_MC_HANDSHAKE = 2, + RECEIVED_LEGACY_PING = 3, // this connection will be fully dropped + AWAIT_STATUS_REQUEST = 4, + AWAIT_LOGIN = 5, + AWAIT_PING = 6, + PING_COMPLETE = 7, + DIRECT_READ_STATUS_REQUEST = 8, + DIRECT_READ_LOGIN = 9, +}; + +/* + * Bounds-check macros. + * + * All of them bail out of the CALLING function with `return 0` when the + * requested bytes are not fully inside both the TCP payload (pend) and the + * packet (dend). Checking against both bounds with a barrier in between is + * what convinces the verifier that every later access is safe. + */ + +// checks that [ptr, ptr + n) is in bounds; does NOT advance ptr +#define CHECK_BOUNDS_OR_RETURN(ptr, n, pend, dend) \ + do \ + { \ + if ((void *)(ptr) + (n) > (const void *)(dend)) \ + return 0; \ + barrier_var(ptr); \ + if ((void *)(ptr) + (n) > (const void *)(pend)) \ + return 0; \ + barrier_var(ptr); \ + } while (0) + +// checks that [ptr, ptr + n) is in bounds and advances ptr past those bytes +#define SKIP_OR_RETURN(ptr, n, pend, dend) \ + do \ + { \ + CHECK_BOUNDS_OR_RETURN(ptr, n, pend, dend); \ + (ptr) += (n); \ + } while (0) + +// reads a fixed-size value into dest and advances ptr past it +#define READ_VAL_OR_RETURN(dest, ptr, pend, dend) \ + do \ + { \ + CHECK_BOUNDS_OR_RETURN(ptr, sizeof(dest), pend, dend); \ + (dest) = *(const __typeof__(dest) *)(ptr); \ + (ptr) += sizeof(dest); \ + } while (0) + +// returns 0 from the calling function if the condition does not hold +#define ASSERT_OR_RETURN(cond) \ + do \ + { \ + if (!(cond)) \ + return 0; \ + } while (0) + +// returns 0 from the calling function if val is not in [min, max] +#define ASSERT_IN_RANGE_OR_RETURN(val, min, max) \ + do \ + { \ + if ((val) < (min) || (val) > (max)) \ + return 0; \ + } while (0) + +// key identifying one TCP flow (all fields in network 
byte order) +struct ipv4_flow_key +{ + const __u32 src_ip; + const __u32 dst_ip; + const __u16 src_port; + const __u16 dst_port; +}; +_Static_assert(sizeof(struct ipv4_flow_key) == 12, "ipv4_flow_key size mismatch!"); + +// per-connection tracking data while the handshake sequence is inspected +struct initial_state +{ + __u16 state; // enum connection_state; u16 to keep the struct padding-free + __u16 fails; // out-of-order packets seen so far (see MAX_OUT_OF_ORDER) + __s32 protocol; // minecraft protocol version (signed by protocol definition) + __u32 expected_sequence; +}; +_Static_assert(sizeof(struct initial_state) == 12, "initial_state size mismatch!"); + +static __always_inline struct ipv4_flow_key gen_ipv4_flow_key(const __u32 src_ip, const __u32 dst_ip, const __u16 src_port, const __u16 dst_port) +{ + struct ipv4_flow_key key = { + .src_ip = src_ip, + .dst_ip = dst_ip, + .src_port = src_port, + .dst_port = dst_port}; + return key; +} + +static __always_inline struct initial_state gen_initial_state(const __u16 state, const __s32 protocol, const __u32 expected_sequence) +{ + struct initial_state new_state = { + .state = state, + .fails = 0, + .protocol = protocol, + .expected_sequence = expected_sequence, + }; + return new_state; +} + +#endif diff --git a/xdp/config.h b/xdp/config.h new file mode 100644 index 0000000..2a9e097 --- /dev/null +++ b/xdp/config.h @@ -0,0 +1,28 @@ +#ifndef CONFIG_H +#define CONFIG_H + +#include + +/* + * Runtime configuration. + * + * The Rust loader patches these values into the BPF .rodata section via + * aya's set_global() (with must_exist) before the program is loaded, so all + * of them are always overridden; the zeros are only placeholders. The types + * must match the set_global() calls in loader/ebpf.rs exactly. Loaded + * standalone (without the loader), the all-zero config filters nothing. 
+ * + * Defining (not just declaring) these in a header is safe because the BPF + * program is a single translation unit; the test build replaces this header + * entirely via its include guard. + */ +volatile const __u8 PROMETHEUS = 0; // collect statistics in stats_map +volatile const __u8 ONLINE_NAMES = 0; // enforce online-mode usernames (max 16 chars) +volatile const __u32 START_PORT = 0; // first TCP port of the filtered range (inclusive) +volatile const __u32 END_PORT = 0; // last TCP port of the filtered range (inclusive) +volatile const __u32 HIT_COUNT = 0; // max SYNs per source ip per window, 0 disables the throttle +volatile const __u64 HIT_COUNT_RESET_NS = 0; // throttle window length in nanoseconds +volatile const __u64 PLAYER_IDLE_NS = 0; // idle check interval for verified connections in nanoseconds; + // entries are removed after one to two intervals without packets + +#endif diff --git a/xdp/minecraft_filter.c b/xdp/minecraft_filter.c new file mode 100644 index 0000000..6a5b10b --- /dev/null +++ b/xdp/minecraft_filter.c @@ -0,0 +1,557 @@ +/* + * minecraft_filter - XDP program protecting Minecraft Java Edition servers + * against L7 (D)DoS attacks. + * + * Every TCP packet for the filtered port range runs through a small state + * machine that validates the TCP handshake and the first Minecraft packets + * of the connection (handshake, then status+ping or login). Connections that + * complete the sequence are promoted to a verified fast path and are no + * longer inspected; everything else is dropped at the driver level. New + * connections are additionally rate limited per source ip. All map cleanup + * happens in-kernel via bpf_timer, userspace never has to touch the maps. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "config.h" +#include "protocol.h" +#include "stats.h" + +// fragment bits of iphdr->frag_off (kernel-internal net/ip.h, not uapi) +#ifndef IP_MF +#define IP_MF 0x2000 // flag: "more fragments" +#endif +#ifndef IP_OFFSET +#define IP_OFFSET 0x1FFF // "fragment offset" part +#endif + +/* ------------------------------------------------------------------------ + * Connection tracking of unverified connections + * --------------------------------------------------------------------- */ + +struct +{ + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, 16384); // placeholder, set by the loader ([xdp] max_pending_connections) + __type(key, struct ipv4_flow_key); // flow key + __type(value, struct initial_state); // inspection state machine data +} conntrack_map SEC(".maps"); + +/* ------------------------------------------------------------------------ + * Verified connections (players) + * --------------------------------------------------------------------- */ + +struct player_entry +{ + struct bpf_timer timer; // deletes the entry when the connection goes idle + __u64 packets; // incremented for every packet of this flow + __u64 last_packets; // snapshot taken by the idle check timer +}; +_Static_assert(sizeof(struct player_entry) == 32, "player_entry size mismatch!"); + +struct +{ + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 65535); // placeholder, set by the loader ([xdp] max_player_connections) + __type(key, struct ipv4_flow_key); // flow key + __type(value, struct player_entry); // idle timer + packet counter +} player_connection_map SEC(".maps"); + +/* + * bpf_timer callback: delete the verified connection if it was idle for a + * full interval, otherwise snapshot the counter and check again next interval + */ +static __s32 player_connection_idle_check(void *map, struct ipv4_flow_key *key, struct player_entry 
*entry) +{ + const __u64 packets = entry->packets; + if (packets == entry->last_packets) + { + bpf_map_delete_elem(map, key); + return 0; + } + entry->last_packets = packets; + bpf_timer_start(&entry->timer, PLAYER_IDLE_NS, 0); + return 0; +} + +/* ------------------------------------------------------------------------ + * Per-ip connection throttle + * --------------------------------------------------------------------- */ + +struct throttle_entry +{ + struct bpf_timer timer; // deletes the entry when the window expires + __u32 hits; // SYNs counted within the current window + __u32 pad; +}; +_Static_assert(sizeof(struct throttle_entry) == 24, "throttle_entry size mismatch!"); + +struct +{ + // plain HASH on purpose (no LRU): during a big attack the map fills up, + // inserts fail and ALL unverified traffic is dropped; only verified + // connections keep passing. Capacity recovers in-kernel as the per-entry + // timers fire and delete the expired windows. + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 65535); // placeholder, set by the loader ([xdp] max_throttled_ips) + __type(key, __u32); // ipv4 source address + __type(value, struct throttle_entry); // window timer + hit counter +} connection_throttle SEC(".maps"); + +// while connection_throttle is full, only retry inserting after this long +// (per core): a failed insert on a full preallocated map scans every cpu's +// freelist under spinlocks, so during that time we drop without even trying +#define THROTTLE_BACKOFF_NS (100ULL * 1000000ULL) // 100ms + +struct +{ + __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, __u64); // per-cpu: no insert retry before this ktime +} throttle_insert_backoff SEC(".maps"); + +/* + * bpf_timer callback: the throttle window of this ip is over. 
Entries that + * saw SYNs during the window are recycled (counter reset, timer re-armed) so + * repeat senders cause no map/timer churn; entries that were idle for the + * whole window are deleted. + */ +static __s32 throttle_window_expired(void *map, __u32 *key, struct throttle_entry *entry) +{ + if (__sync_lock_test_and_set(&entry->hits, 0) == 0) + { + bpf_map_delete_elem(map, key); + return 0; + } + bpf_timer_start(&entry->timer, HIT_COUNT_RESET_NS, 0); + return 0; +} + +/* ------------------------------------------------------------------------ + * Statistics (only used when PROMETHEUS is enabled, see stats.h) + * --------------------------------------------------------------------- */ + +struct +{ + __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); + __uint(max_entries, 1); + __type(key, __u32); + __type(value, struct statistics); +} stats_map SEC(".maps"); + +/* ------------------------------------------------------------------------ + * Helpers + * --------------------------------------------------------------------- */ + +// flag combinations that never occur on legitimate client traffic +static __always_inline __u8 detect_tcp_bypass(const struct tcphdr *tcp) +{ + if ((!tcp->syn && !tcp->ack && !tcp->fin && !tcp->rst) || // none of SYN/ACK/FIN/RST set + (tcp->syn && tcp->ack) || // SYN+ACK from outside is never a client + (tcp->syn && (tcp->fin || tcp->rst)) || // SYN+FIN / SYN+RST are always forged + tcp->urg) // URG is unused by the protocol + { + return 1; + } + return 0; +} + +/* + * The connection passed the full inspection sequence: move it from the + * conntrack map into the player map so its packets skip inspection from now + * on, and arm the idle timer that will eventually clean the entry up. 
+ */ +static __always_inline __u32 switch_to_verified(const __u64 raw_packet_len, struct statistics *stats_ptr, const struct ipv4_flow_key *flow_key) +{ + bpf_map_delete_elem(&conntrack_map, flow_key); + const struct player_entry fresh = {.packets = 1, .last_packets = 0}; + if (bpf_map_update_elem(&player_connection_map, flow_key, &fresh, BPF_NOEXIST) < 0) + { + goto drop; + } + struct player_entry *entry = bpf_map_lookup_elem(&player_connection_map, flow_key); + if (!entry) + { + goto drop; + } + if (bpf_timer_init(&entry->timer, &player_connection_map, CLOCK_MONOTONIC) < 0 || + bpf_timer_set_callback(&entry->timer, player_connection_idle_check) < 0 || + bpf_timer_start(&entry->timer, PLAYER_IDLE_NS, 0) < 0) + { + // never leak an entry that has no idle timer armed + bpf_map_delete_elem(&player_connection_map, flow_key); + goto drop; + } + count_stats(stats_ptr, VERIFIED, 1); + return XDP_PASS; +drop: + count_stats(stats_ptr, DROPPED_BYTES, raw_packet_len); + count_stats(stats_ptr, DROP_CONNECTION | DROPPED_PACKET, 1); + return XDP_DROP; +} + +/* ------------------------------------------------------------------------ + * XDP entry point + * --------------------------------------------------------------------- */ + +SEC("xdp") +__s32 minecraft_filter(struct xdp_md *ctx) +{ + const void *data = (const void *)(long)ctx->data; + const void *data_end = (const void *)(long)ctx->data_end; + + const struct ethhdr *eth = data; + if ((const void *)(eth + 1) > data_end) + { + return XDP_DROP; + } + + if (eth->h_proto != bpf_htons(ETH_P_IP)) + { + return XDP_PASS; + } + + const struct iphdr *ip = data + sizeof(struct ethhdr); + if ((const void *)(ip + 1) > data_end || ip->ihl < 5) + { + return XDP_DROP; + } + + if (ip->protocol != IPPROTO_TCP) + { + return XDP_PASS; + } + + // non-first fragments (fragment offset != 0) carry no tcp header, so the + // port check below cannot run on them: pass them up the stack so other + // services keep receiving their fragmented traffic. 
Safe for the filtered + // range because the matching first fragment is dropped after the port + // check, so reassembly never completes and the kernel discards the rest + // after the frag timeout. The ports can not be forged via fragment + // overlap either: they live in bytes 0-3 of the tcp header while the + // smallest non-first offset is 8 bytes, and linux >= 4.19 drops + // overlapping fragments outright + if (ip->frag_off & bpf_htons(IP_OFFSET)) + { + return XDP_PASS; + } + + const struct tcphdr *tcp = (const void *)ip + (ip->ihl * 4); + if ((const void *)(tcp + 1) > data_end) + { + return XDP_DROP; + } + + // everything outside the filtered port range is not our business + const __u16 dest_port = bpf_ntohs(tcp->dest); + if (dest_port < START_PORT || dest_port > END_PORT) + { + return XDP_PASS; + } + + // first fragment of a fragmented packet (MF set, offset 0) aimed at our + // range: the remaining payload is in fragments the state machine never + // sees, so after kernel reassembly the backend would receive uninspected + // data. Legitimate tcp does not fragment, the MSS keeps segments below + // the MTU. 
Dropping the first fragment makes reassembly impossible, the + // kernel discards the passed non-first fragments after the frag timeout + if (ip->frag_off & bpf_htons(IP_MF)) + { + return XDP_DROP; + } + + if (tcp->doff < 5) + { + return XDP_DROP; + } + + const __u32 tcp_hdr_len = tcp->doff * 4; + if ((const void *)tcp + tcp_hdr_len > data_end) + { + return XDP_DROP; + } + + struct statistics *stats_ptr = NULL; + if (PROMETHEUS) + { + __u32 key = 0; + stats_ptr = bpf_map_lookup_elem(&stats_map, &key); + if (!stats_ptr) + { + // per-cpu array index 0 always exists, this is unreachable + return XDP_DROP; + } + } + + const __u64 raw_packet_len = (__u64)(data_end - data); + count_stats(stats_ptr, INCOMING_BYTES, raw_packet_len); + + if (detect_tcp_bypass(tcp)) + { + count_stats(stats_ptr, TCP_BYPASS, 1); + goto drop; + } + + const __u32 src_ip = ip->saddr; + const struct ipv4_flow_key flow_key = gen_ipv4_flow_key(src_ip, ip->daddr, tcp->source, tcp->dest); + + // new connection: throttle it, then start tracking it + if (tcp->syn) + { + count_stats(stats_ptr, SYN_RECEIVE, 1); + + // drop SYNs carrying payload (e.g. 
TCP fast open): the data would + // reach the backend without ever passing the inspection state machine + if (bpf_ntohs(ip->tot_len) > (ip->ihl * 4) + tcp_hdr_len) + { + count_stats(stats_ptr, TCP_BYPASS, 1); + goto drop; + } + + if (HIT_COUNT) + { + // connection throttle, fully in kernel: every source ip gets its + // own window of HIT_COUNT_RESET_NS, opened by its first SYN and + // closed by the bpf_timer that deletes the entry again + struct throttle_entry *entry = bpf_map_lookup_elem(&connection_throttle, &src_ip); + if (entry) + { + if (entry->hits >= HIT_COUNT) + { + goto drop; + } + __sync_fetch_and_add(&entry->hits, 1); + } + else + { + __u32 zero = 0; + __u64 *backoff = bpf_map_lookup_elem(&throttle_insert_backoff, &zero); + if (!backoff) + { + // per-cpu array index 0 always exists, this is unreachable + goto drop; + } + const __u64 now = bpf_ktime_get_ns(); + if (now < *backoff) + { + // the map was full just before: fail closed without + // paying for another doomed insert attempt + goto drop; + } + const struct throttle_entry fresh = {.hits = 1, .pad = 0}; + const long err = bpf_map_update_elem(&connection_throttle, &src_ip, &fresh, BPF_NOEXIST); + if (err < 0) + { + if (err != -EEXIST) + { + // map full (attack): fail closed and back off + *backoff = now + THROTTLE_BACKOFF_NS; + } + goto drop; + } + entry = bpf_map_lookup_elem(&connection_throttle, &src_ip); + if (!entry) + { + goto drop; + } + if (bpf_timer_init(&entry->timer, &connection_throttle, CLOCK_MONOTONIC) < 0 || + bpf_timer_set_callback(&entry->timer, throttle_window_expired) < 0 || + bpf_timer_start(&entry->timer, HIT_COUNT_RESET_NS, 0) < 0) + { + // never leak an entry that has no expiry timer armed + bpf_map_delete_elem(&connection_throttle, &src_ip); + goto drop; + } + } + } + + // track the connection: the next packet has to be the ACK finishing + // the TCP handshake, with the sequence number following this SYN + const struct initial_state new_state = gen_initial_state(AWAIT_ACK, 0, 
bpf_ntohl(tcp->seq) + 1); + if (bpf_map_update_elem(&conntrack_map, &flow_key, &new_state, BPF_ANY) < 0) + { + goto drop; + } + + return XDP_PASS; + } + + // verified connections skip all further inspection + struct player_entry *player = bpf_map_lookup_elem(&player_connection_map, &flow_key); + if (player) + { + // non-atomic on purpose: racing increments (flow migrating cpus) can + // only lose single steps, never regress the counter across a window. + // The idle check thus only false-matches if the connection sent + // nothing for ~a full window, and minecraft clients keepalive every + // few seconds, so such a connection is dead anyway + player->packets++; + return XDP_PASS; + } + + struct initial_state *initial_state = bpf_map_lookup_elem(&conntrack_map, &flow_key); + if (!initial_state) + { + goto drop; // neither tracked nor verified + } + + const __u8 *tcp_payload = (const __u8 *)tcp + tcp_hdr_len; + + // ip total length - ip header - tcp header = length of the tcp payload + const __u16 ip_tot_len = bpf_ntohs(ip->tot_len); + const __u16 tcp_payload_len = ip_tot_len - (ip->ihl * 4) - tcp_hdr_len; + const __u8 *tcp_payload_end = tcp_payload + tcp_payload_len; + + // tcp packet split over multiple ethernet frames, we don't support that + if (tcp_payload_end > (const __u8 *)data_end) + { + goto drop; + } + + __u32 state = initial_state->state; + if (state == AWAIT_ACK) + { + // not an ack, or not the ack matching our SYN + if (!tcp->ack || initial_state->expected_sequence != bpf_ntohl(tcp->seq)) + { + goto drop; + } + + // advance the state machine before the early drop below, the next + // packet has to be matched against AWAIT_MC_HANDSHAKE + initial_state->state = state = AWAIT_MC_HANDSHAKE; + + // the empty ack finishing the TCP handshake is dropped on purpose: + // the backend will accept the first minecraft data packet as that + // ack, which elegantly limits backend connections to clients whose + // handshake passed inspection. 
If the ack already carries payload, + // fall through into payload inspection instead + if (tcp_payload >= tcp_payload_end) + { + goto drop; + } + } + + if (tcp_payload < tcp_payload_end) + { + // payload without an ack flag is never legitimate mid-handshake + if (!tcp->ack) + { + goto drop_connection; + } + + // we fully track the tcp sequence, so a mismatch here is either a + // retransmission or an out-of-order packet: drop the packet, and + // drop the whole connection once that happens too often. Everything + // that survives this check is exactly the in-order byte stream the + // backend would see, which is what allows the hard punishments below + if (initial_state->expected_sequence != bpf_ntohl(tcp->seq)) + { + if (++initial_state->fails > MAX_OUT_OF_ORDER) + { + goto drop_connection; + } + goto drop; + } + + if (state == AWAIT_MC_HANDSHAKE) + { + // if the status request or login packet is in the same tcp + // segment as the handshake, inspect_handshake returns a + // DIRECT_READ_* state and advances tcp_payload to the rest + const __s32 next_state = inspect_handshake(tcp_payload, tcp_payload_end, data_end, &initial_state->protocol, &tcp_payload); + if (!next_state) + { + // even with retransmissions the handshake of a legitimate + // client is always parseable, this connection is bogus + goto drop; + } + if (next_state == RECEIVED_LEGACY_PING) + { + goto drop_connection; + } + if (next_state == DIRECT_READ_STATUS_REQUEST) + { + goto read_status; + } + if (next_state == DIRECT_READ_LOGIN) + { + goto read_login; + } + initial_state->state = next_state; + goto update_state; + } + if (state == AWAIT_STATUS_REQUEST) + read_status: + { + if (!inspect_status_request(tcp_payload, tcp_payload_end, data_end)) + { + goto drop; + } + initial_state->state = AWAIT_PING; + goto update_state; + } + if (state == AWAIT_PING) + { + if (!inspect_ping_request(tcp_payload, tcp_payload_end, data_end)) + { + goto drop; + } + initial_state->state = PING_COMPLETE; + goto 
update_state; + } + if (state == AWAIT_LOGIN) + read_login: + { + if (!inspect_login_packet(tcp_payload, tcp_payload_end, data_end, initial_state->protocol)) + { + goto drop; + } + // tracking ends here, no need to update the expected sequence + return switch_to_verified(raw_packet_len, stats_ptr, &flow_key); + } + if (state == PING_COMPLETE) + { + // a finished ping flow has nothing more to say + goto drop_connection; + } + } + else if (state == AWAIT_MC_HANDSHAKE) + { + // empty acks are not allowed while the handshake is pending, + // otherwise an attacker could sit on a half-inspected connection + goto drop; + } + + // empty segments in the remaining states (pure acks, FIN/RST teardown) + return XDP_PASS; + +// shared exit paths: jumping here instead of duplicating these blocks keeps +// the generated program drastically smaller +drop_connection: + count_stats(stats_ptr, DROP_CONNECTION, 1); + bpf_map_delete_elem(&conntrack_map, &flow_key); + // fall through +drop: + count_stats(stats_ptr, DROPPED_PACKET, 1); + count_stats(stats_ptr, DROPPED_BYTES, raw_packet_len); + return XDP_DROP; +update_state: + initial_state->expected_sequence += tcp_payload_len; + count_stats(stats_ptr, STATE_SWITCH, 1); + return XDP_PASS; +} + +// must be GPL-compatible: the bpf_timer_* helpers used by the connection +// throttle are gpl_only, the kernel refuses to load them otherwise +char _license[] SEC("license") = "Dual BSD/GPL"; diff --git a/xdp/protocol.h b/xdp/protocol.h new file mode 100644 index 0000000..bddea11 --- /dev/null +++ b/xdp/protocol.h @@ -0,0 +1,215 @@ +#ifndef PROTOCOL_H +#define PROTOCOL_H + +#include + +#include "common.h" +#include "config.h" +#include "varint.h" + +/* + * Inspection of the first Minecraft protocol packets of a connection. + * + * Every inspector walks the TCP payload with a bounds-checked cursor and + * returns 0 (STATE_INVALID) as soon as anything does not look like a valid + * packet of the expected type. 
Wire format reference: + * https://github.com/SpigotMC/BungeeCord/tree/master/protocol + */ + +// size limits, derived from the protocol definitions +#define UTF8_MAX_BYTES 3 +#define UUID_LEN 16 + +#define PACKET_ID_MIN MIN_VARINT_BYTES +#define PACKET_ID_MAX MAX_VARINT_BYTES + +// handshake packet: protocol version, server host, server port, intention +#define HANDSHAKE_VERSION_MIN MIN_VARINT_BYTES +#define HANDSHAKE_VERSION_MAX MAX_VARINT_BYTES + +#define HANDSHAKE_HOSTLEN_MIN MIN_VARINT_BYTES +#define HANDSHAKE_HOSTLEN_MAX MAX_VARINT_BYTES + +#define HANDSHAKE_HOST_DATA_MIN (0) +#define HANDSHAKE_HOST_DATA_MAX (255 * UTF8_MAX_BYTES) + +#define HANDSHAKE_PORT_LEN (2) + +#define HANDSHAKE_INTENTION_MIN MIN_VARINT_BYTES +#define HANDSHAKE_INTENTION_MAX MAX_VARINT_BYTES + +#define HANDSHAKE_DATA_MIN (HANDSHAKE_VERSION_MIN + HANDSHAKE_HOSTLEN_MIN + HANDSHAKE_HOST_DATA_MIN + HANDSHAKE_PORT_LEN + HANDSHAKE_INTENTION_MIN) +#define HANDSHAKE_DATA_MAX (HANDSHAKE_VERSION_MAX + HANDSHAKE_HOSTLEN_MAX + HANDSHAKE_HOST_DATA_MAX + HANDSHAKE_PORT_LEN + HANDSHAKE_INTENTION_MAX) + +// login request packet: username, optional public key (1.19 to 1.19.2), uuid +#define LOGIN_NAME_LEN_MIN MIN_VARINT_BYTES +#define LOGIN_NAME_LEN_MAX MAX_VARINT_BYTES + +#define LOGIN_NAME_DATA_MIN (1) // empty names are not possible +#define LOGIN_NAME_DATA_MAX (16 * UTF8_MAX_BYTES) + +#define LOGIN_KEY_MIN 0 +#define LOGIN_KEY_MAX 512 + +#define LOGIN_SIGNATURE_MIN 0 +#define LOGIN_SIGNATURE_MAX 4096 + +#define LOGIN_PUBLIC_KEY_MIN (/*has key*/ 1) +#define LOGIN_PUBLIC_KEY_MAX (/*has key*/ 1 + /*expiry*/ 8 + /*length*/ MAX_VARINT_BYTES + LOGIN_KEY_MAX + /*length*/ MAX_VARINT_BYTES + LOGIN_SIGNATURE_MAX) + +#define LOGIN_HAS_UUID_LEN 1 +#define LOGIN_DATA_MIN (LOGIN_NAME_LEN_MIN + LOGIN_NAME_DATA_MIN) +#define LOGIN_DATA_MAX (LOGIN_NAME_LEN_MAX + LOGIN_NAME_DATA_MAX + LOGIN_PUBLIC_KEY_MAX + LOGIN_HAS_UUID_LEN + UUID_LEN) + +/* + * Validates the handshake packet and returns the resulting connection 
state, + * or 0 if the packet is invalid. A client may append the status request or + * login packet to the same TCP segment (also seen after retransmissions); in + * that case a DIRECT_READ_* state is returned and *resume_cursor points at + * the remaining payload so the caller can continue inspecting it. + */ +static __always_inline __s32 inspect_handshake(const __u8 *cursor, const __u8 *payload_end, const void *data_end, __s32 *protocol_version, const __u8 **resume_cursor) +{ + CHECK_BOUNDS_OR_RETURN(cursor, 1, payload_end, data_end); + // pre-1.7 clients open with 0xFE instead of a length prefix + if (cursor[0] == (__u8)0xFE) + { + return RECEIVED_LEGACY_PING; + } + + struct varint_value varint; + + // packet length + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(PACKET_ID_MAX + HANDSHAKE_DATA_MAX)); + ASSERT_IN_RANGE_OR_RETURN(varint.value, PACKET_ID_MIN + HANDSHAKE_DATA_MIN, PACKET_ID_MAX + HANDSHAKE_DATA_MAX); + + // packet id, must be 0 + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(0x00)); + ASSERT_OR_RETURN(varint.value == 0x00); + + // protocol version + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, MAX_VARINT_BYTES); + *protocol_version = varint.value; + + // host length, then skip the host data + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(HANDSHAKE_HOST_DATA_MAX)); + ASSERT_IN_RANGE_OR_RETURN(varint.value, HANDSHAKE_HOST_DATA_MIN, HANDSHAKE_HOST_DATA_MAX); + SKIP_OR_RETURN(cursor, varint.value, payload_end, data_end); + + // server port + SKIP_OR_RETURN(cursor, HANDSHAKE_PORT_LEN, payload_end, data_end); + + // intention: 1 (status), 2 (login), 3 (login via transfer, since 1.20.5) + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(3)); + const __s32 intention = varint.value; + const __u8 supports_transfer = *protocol_version >= 766; + ASSERT_OR_RETURN(intention == 1 || intention == 2 || (supports_transfer && intention == 3)); + + 
// packet contained exactly the handshake + if (cursor == payload_end) + { + return intention == 1 ? AWAIT_STATUS_REQUEST : AWAIT_LOGIN; + } + + // more protocol data follows in the same packet + *resume_cursor = cursor; + return intention == 1 ? DIRECT_READ_STATUS_REQUEST : DIRECT_READ_LOGIN; +} + +// returns 1 if the payload is exactly one valid status request packet +static __always_inline __u8 inspect_status_request(const __u8 *cursor, const __u8 *payload_end, const void *data_end) +{ + struct varint_value varint; + + // packet length, must be 1 + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(0x01)); + ASSERT_OR_RETURN(varint.value == 0x01); + + // packet id, must be 0 + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(0x00)); + ASSERT_OR_RETURN(varint.value == 0x00); + + return cursor == payload_end; +} + +// returns 1 if the payload is exactly one valid ping request packet +static __always_inline __u8 inspect_ping_request(const __u8 *cursor, const __u8 *payload_end, const void *data_end) +{ + struct varint_value varint; + + // packet length, must be 9 (packet id + 8 byte timestamp) + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(0x09)); + ASSERT_OR_RETURN(varint.value == 0x09); + + // packet id, must be 1 + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(0x01)); + ASSERT_OR_RETURN(varint.value == 0x01); + + __u64 timestamp; + READ_VAL_OR_RETURN(timestamp, cursor, payload_end, data_end); + return cursor == payload_end; +} + +// returns 1 if the payload is exactly one valid login request packet +static __always_inline __u8 inspect_login_packet(const __u8 *cursor, const __u8 *payload_end, const void *data_end, const __s32 protocol_version) +{ + struct varint_value varint; + + // packet length + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(PACKET_ID_MAX + LOGIN_DATA_MAX)); + ASSERT_IN_RANGE_OR_RETURN(varint.value, PACKET_ID_MIN + 
LOGIN_DATA_MIN, PACKET_ID_MAX + LOGIN_DATA_MAX); + + // packet id, must be 0 + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(0x00)); + ASSERT_OR_RETURN(varint.value == 0x00); + + // username length, then skip the username data + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(ONLINE_NAMES ? 16 : LOGIN_NAME_DATA_MAX)); + ASSERT_IN_RANGE_OR_RETURN(varint.value, LOGIN_NAME_DATA_MIN, ONLINE_NAMES ? 16 : LOGIN_NAME_DATA_MAX); + SKIP_OR_RETURN(cursor, varint.value, payload_end, data_end); + + // optional chat signing key, 1.19 (759) up to 1.19.3 (761) + if (protocol_version >= 759 && protocol_version < 761) + { + __u8 has_public_key; + READ_VAL_OR_RETURN(has_public_key, cursor, payload_end, data_end); + if (has_public_key) + { + // expiry timestamp + SKIP_OR_RETURN(cursor, 8, payload_end, data_end); + + // public key length, then skip the key + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(LOGIN_KEY_MAX)); + ASSERT_IN_RANGE_OR_RETURN(varint.value, LOGIN_KEY_MIN, LOGIN_KEY_MAX); + SKIP_OR_RETURN(cursor, varint.value, payload_end, data_end); + + // signature length, then skip the signature + READ_VARINT_OR_RETURN(varint, cursor, payload_end, data_end, VARINT_SIZE(LOGIN_SIGNATURE_MAX)); + ASSERT_IN_RANGE_OR_RETURN(varint.value, LOGIN_SIGNATURE_MIN, LOGIN_SIGNATURE_MAX); + SKIP_OR_RETURN(cursor, varint.value, payload_end, data_end); + } + } + + // uuid, optional from 1.19.1 (760), always present since 1.20.2 (764) + if (protocol_version >= 760) + { + if (protocol_version >= 764) + { + SKIP_OR_RETURN(cursor, UUID_LEN, payload_end, data_end); + } + else + { + __u8 has_uuid; + READ_VAL_OR_RETURN(has_uuid, cursor, payload_end, data_end); + if (has_uuid) + { + SKIP_OR_RETURN(cursor, UUID_LEN, payload_end, data_end); + } + } + } + + // valid only if the packet ends exactly here + return cursor == payload_end; +} + +#endif diff --git a/c/stats.h b/xdp/stats.h similarity index 60% rename from c/stats.h 
rename to xdp/stats.h index 24f22f2..6e29ca1 100644 --- a/c/stats.h +++ b/xdp/stats.h @@ -3,7 +3,10 @@ #include -// bitmask for statistics types +#include "config.h" + +// selects which statistics count_stats() increments; a bitmask so related +// counters can be bumped in one call (e.g. DROP_CONNECTION | DROPPED_PACKET) enum stats_mask { VERIFIED = 1u << 0, DROPPED_PACKET = 1u << 1, @@ -15,6 +18,7 @@ enum stats_mask { DROPPED_BYTES = 1u << 7, }; +// one per-cpu slot of stats_map; must match `Statistics` in loader/metrics.rs struct statistics { __u64 verified; @@ -26,15 +30,25 @@ struct statistics __u64 incoming_bytes; __u64 dropped_bytes; }; - _Static_assert(sizeof(struct statistics) == 64, "statistics size mismatch!"); /* - * the compiler will optimize this function well + * Adds `amount` to every counter selected by `bitmask`. + * + * stats_ptr is NULL whenever PROMETHEUS is 0 (the filter only looks it up + * when enabled), so the PROMETHEUS check below also guards the dereference. + * Since PROMETHEUS lives in .rodata, the verifier knows its value at load + * time and removes either the early return or the entire body as dead code; + * with constant bitmasks the compiler reduces each call to the few + * increments that are actually selected. 
*/ -#if PROMETHEUS_METRICS -static __always_inline void count_stats_impl(struct statistics *stats_ptr, const __u32 bitmask, const __u64 amount) +static __always_inline void count_stats(struct statistics *stats_ptr, const __u32 bitmask, const __u64 amount) { + if (!PROMETHEUS) + { + return; + } + if (bitmask & INCOMING_BYTES) { stats_ptr->incoming_bytes += amount; @@ -76,9 +90,4 @@ static __always_inline void count_stats_impl(struct statistics *stats_ptr, const } } -#define count_stats(stats_ptr, bitmask, amount) count_stats_impl(stats_ptr, bitmask, amount) -#else -#define count_stats(stats_ptr, bitmask, amount) #endif - -#endif \ No newline at end of file diff --git a/xdp/tests/protocol_test.c b/xdp/tests/protocol_test.c new file mode 100644 index 0000000..0358086 --- /dev/null +++ b/xdp/tests/protocol_test.c @@ -0,0 +1,704 @@ +/* + * Native unit tests for the eBPF parsing code (varint.h, protocol.h and the + * bounds-check macros in common.h). + * + * Compiled for the host (not for BPF) and executed by tests/c_unit_tests.rs + * as part of `cargo test`, with ASan/UBSan enabled when available. Every + * inspector call runs on an exact-size heap copy of the packet, so any read + * past data_end trips the address sanitizer instead of going unnoticed. + */ +#include +#include +#include +#include + +// provided by bpf/bpf_helpers.h in the BPF build +#ifndef __always_inline +#define __always_inline inline __attribute__((always_inline)) +#endif + +/* + * Stand-in for xdp/config.h, suppressed via its include guard: the real header + * defines the knobs `volatile const`, but the login tests need to flip + * ONLINE_NAMES at runtime. 
+ */ +#define CONFIG_H +static volatile __u8 ONLINE_NAMES = 1; + +#include "../common.h" +#include "../varint.h" +#include "../protocol.h" + +/* ------------------------------------------------------------------------ + * Tiny test framework + * --------------------------------------------------------------------- */ + +static unsigned checks_run = 0; +static unsigned checks_failed = 0; + +#define CHECK(cond) \ + do \ + { \ + checks_run++; \ + if (!(cond)) \ + { \ + checks_failed++; \ + printf("FAIL %s:%d in %s: %s\n", __FILE__, __LINE__, __func__, \ + #cond); \ + } \ + } while (0) + +/* ------------------------------------------------------------------------ + * Packet building helpers + * --------------------------------------------------------------------- */ + +// reference varint encoder, validated against the wiki.vg test vectors below +static __u32 write_varint(__u8 *out, __s32 value) +{ + __u32 v = (__u32)value; + __u32 n = 0; + do + { + __u8 byte = v & 0x7F; + v >>= 7; + if (v) + { + byte |= 0x80; + } + out[n++] = byte; + } while (v); + return n; +} + +struct buf +{ + __u8 b[2048]; + __u32 n; +}; + +static void put_varint(struct buf *p, __s32 value) +{ + p->n += write_varint(p->b + p->n, value); +} + +static void put_u8(struct buf *p, __u8 value) +{ + p->b[p->n++] = value; +} + +static void put_fill(struct buf *p, __u8 fill, __u32 count) +{ + memset(p->b + p->n, fill, count); + p->n += count; +} + +// prefixes a packet body with its length varint, like the protocol does +static struct buf packetize(const struct buf *body) +{ + struct buf pkt = {{0}, 0}; + put_varint(&pkt, (__s32)body->n); + memcpy(pkt.b + pkt.n, body->b, body->n); + pkt.n += body->n; + return pkt; +} + +/* ------------------------------------------------------------------------ + * Runners: every parse happens on an exact-size heap copy so ASan catches + * any access past data_end. `slack` adds extra bytes between payload_end + * and data_end to exercise the dual-bounds checks. 
+ * --------------------------------------------------------------------- */ + +static __u8 *heap_copy(const __u8 *bytes, __u32 len, __u32 slack) +{ + const __u32 size = len + slack; + __u8 *heap = malloc(size ? size : 1); // malloc(0) may return NULL + if (len) + { + memcpy(heap, bytes, len); + } + if (slack) + { + // 0x01 terminates a varint, so a parser that wrongly runs past + // payload_end produces a successful-looking parse the assertions + // can catch (instead of failing by coincidence) + memset(heap + len, 0x01, slack); + } + return heap; +} + +static struct varint_value run_varint_slack(const __u8 *bytes, __u32 payload_len, __u32 slack, __u8 max_size) +{ + __u8 *heap = heap_copy(bytes, payload_len, slack); + const struct varint_value v = + read_varint_sized(heap, heap + payload_len, max_size, heap + payload_len + slack); + free(heap); + return v; +} + +static struct varint_value run_varint(const __u8 *bytes, __u32 payload_len, __u8 max_size) +{ + return run_varint_slack(bytes, payload_len, 0, max_size); +} + +static __u8 run_status_slack(const __u8 *pkt, __u32 len, __u32 slack) +{ + __u8 *heap = heap_copy(pkt, len, slack); + const __u8 ok = inspect_status_request(heap, heap + len, heap + len + slack); + free(heap); + return ok; +} + +static __u8 run_status(const __u8 *pkt, __u32 len) +{ + return run_status_slack(pkt, len, 0); +} + +static __u8 run_ping(const __u8 *pkt, __u32 len) +{ + __u8 *heap = heap_copy(pkt, len, 0); + const __u8 ok = inspect_ping_request(heap, heap + len, heap + len); + free(heap); + return ok; +} + +static __u8 run_login(const __u8 *pkt, __u32 len, __s32 protocol) +{ + __u8 *heap = heap_copy(pkt, len, 0); + const __u8 ok = inspect_login_packet(heap, heap + len, heap + len, protocol); + free(heap); + return ok; +} + +static __s32 run_handshake(const __u8 *pkt, __u32 len, __s32 *proto_out, __u32 *resume_off) +{ + __u8 *heap = heap_copy(pkt, len, 0); + const __u8 *resume = NULL; + __s32 proto = 0; + const __s32 state = 
inspect_handshake(heap, heap + len, heap + len, &proto, &resume); + if (proto_out) + { + *proto_out = proto; + } + if (resume_off) + { + *resume_off = resume ? (__u32)(resume - heap) : 0; + } + free(heap); + return state; +} + +/* ------------------------------------------------------------------------ + * VARINT_SIZE (compile-time) + * --------------------------------------------------------------------- */ + +_Static_assert(VARINT_SIZE(0x00) == 1, "VARINT_SIZE(0)"); +_Static_assert(VARINT_SIZE(0x7F) == 1, "VARINT_SIZE(127)"); +_Static_assert(VARINT_SIZE(0x80) == 2, "VARINT_SIZE(128)"); +_Static_assert(VARINT_SIZE(0x3FFF) == 2, "VARINT_SIZE(16383)"); +_Static_assert(VARINT_SIZE(0x4000) == 3, "VARINT_SIZE(16384)"); +_Static_assert(VARINT_SIZE(0x1FFFFF) == 3, "VARINT_SIZE(2097151)"); +_Static_assert(VARINT_SIZE(0x200000) == 4, "VARINT_SIZE(2097152)"); +_Static_assert(VARINT_SIZE(0xFFFFFFF) == 4, "VARINT_SIZE(268435455)"); +_Static_assert(VARINT_SIZE(0x10000000) == 5, "VARINT_SIZE(268435456)"); + +/* ------------------------------------------------------------------------ + * Varint reader + * --------------------------------------------------------------------- */ + +// test vectors from the protocol documentation (wiki.vg) +static const struct +{ + __u8 bytes[5]; + __u32 len; + __s32 value; +} VARINT_VECTORS[] = { + {{0x00}, 1, 0}, + {{0x01}, 1, 1}, + {{0x02}, 1, 2}, + {{0x7F}, 1, 127}, + {{0x80, 0x01}, 2, 128}, + {{0xFF, 0x01}, 2, 255}, + {{0xDD, 0xC7, 0x01}, 3, 25565}, + {{0xFF, 0xFF, 0x7F}, 3, 2097151}, + {{0xFF, 0xFF, 0xFF, 0xFF, 0x07}, 5, 2147483647}, + {{0xFF, 0xFF, 0xFF, 0xFF, 0x0F}, 5, -1}, + {{0x80, 0x80, 0x80, 0x80, 0x08}, 5, -2147483647 - 1}, +}; + +static void test_varint_decodes_known_vectors(void) +{ + for (__u32 i = 0; i < sizeof(VARINT_VECTORS) / sizeof(VARINT_VECTORS[0]); i++) + { + const struct varint_value v = + run_varint(VARINT_VECTORS[i].bytes, VARINT_VECTORS[i].len, MAX_VARINT_BYTES); + CHECK(v.value == VARINT_VECTORS[i].value); + 
CHECK(v.bytes == VARINT_VECTORS[i].len); + + // the encoder used by the packet builders must produce these + // exact bytes, otherwise all later tests would test nothing + __u8 encoded[5] = {0}; + const __u32 n = write_varint(encoded, VARINT_VECTORS[i].value); + CHECK(n == VARINT_VECTORS[i].len); + CHECK(memcmp(encoded, VARINT_VECTORS[i].bytes, n) == 0); + } +} + +static void test_varint_roundtrip(void) +{ + static const __s32 VALUES[] = {0, 1, 2, 127, 128, 255, + 300, 16383, 16384, 25565, 2097151, 2097152, + -1, -25565, 268435455, 268435456, 2147483647, -2147483647 - 1}; + for (__u32 i = 0; i < sizeof(VALUES) / sizeof(VALUES[0]); i++) + { + __u8 encoded[5]; + const __u32 n = write_varint(encoded, VALUES[i]); + const struct varint_value v = run_varint(encoded, n, MAX_VARINT_BYTES); + CHECK(v.value == VALUES[i]); + CHECK(v.bytes == n); + } +} + +static void test_varint_rejects_truncated_input(void) +{ + const __u8 two_byte[] = {0x80, 0x01}; + CHECK(run_varint(two_byte, 1, MAX_VARINT_BYTES).bytes == 0); // continuation cut off + CHECK(run_varint(two_byte, 0, MAX_VARINT_BYTES).bytes == 0); // empty payload + + const __u8 three_byte[] = {0xDD, 0xC7, 0x01}; + CHECK(run_varint(three_byte, 2, MAX_VARINT_BYTES).bytes == 0); +} + +static void test_varint_respects_max_size(void) +{ + const __u8 two_byte[] = {0x80, 0x01}; + CHECK(run_varint(two_byte, 2, 1).bytes == 0); + CHECK(run_varint(two_byte, 2, 2).bytes == 2); + + const __u8 three_byte[] = {0xDD, 0xC7, 0x01}; + CHECK(run_varint(three_byte, 3, 2).bytes == 0); + CHECK(run_varint(three_byte, 3, 3).bytes == 3); +} + +static void test_varint_never_reads_past_payload_end(void) +{ + // a continuation byte at the end of the payload, with valid-looking + // bytes behind payload_end: must fail instead of reading the slack + const __u8 bytes[] = {0x80}; + CHECK(run_varint_slack(bytes, 1, 4, MAX_VARINT_BYTES).bytes == 0); +} + +static void test_varint_stops_at_terminator(void) +{ + // trailing bytes after a complete varint are 
someone else's business + const __u8 bytes[] = {0x01, 0xFF, 0xFF}; + const struct varint_value v = run_varint(bytes, 3, MAX_VARINT_BYTES); + CHECK(v.value == 1); + CHECK(v.bytes == 1); +} + +static void test_varint_rejects_overlong_encoding(void) +{ + // continuation bit still set on the fifth byte: a sixth byte would be + // required, which no 32-bit varint may have + const __u8 bytes[] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}; + CHECK(run_varint(bytes, 6, MAX_VARINT_BYTES).bytes == 0); +} + +/* ------------------------------------------------------------------------ + * Bounds-check macros + * --------------------------------------------------------------------- */ + +static __u8 wrap_check_bounds(const __u8 *cursor, __u32 n, const __u8 *pend, const void *dend) +{ + CHECK_BOUNDS_OR_RETURN(cursor, n, pend, dend); + return 1; +} + +static __u8 wrap_skip_twice(const __u8 *cursor, __u32 a, __u32 b, const __u8 *pend, const void *dend) +{ + SKIP_OR_RETURN(cursor, a, pend, dend); + SKIP_OR_RETURN(cursor, b, pend, dend); + return 1; +} + +static __u8 wrap_read_u64(const __u8 *cursor, const __u8 *pend, const void *dend, __u64 *out) +{ + READ_VAL_OR_RETURN(*out, cursor, pend, dend); + return 1; +} + +static void test_bounds_macros(void) +{ + __u8 *heap = malloc(16); + for (__u8 i = 0; i < 16; i++) + { + heap[i] = i; + } + const __u8 *pend = heap + 16; + + CHECK(wrap_check_bounds(heap, 16, pend, pend) == 1); + CHECK(wrap_check_bounds(heap, 17, pend, pend) == 0); + CHECK(wrap_check_bounds(heap + 16, 0, pend, pend) == 1); + // payload_end binds even when data_end leaves room + CHECK(wrap_check_bounds(heap, 10, heap + 8, pend) == 0); + + CHECK(wrap_skip_twice(heap, 8, 8, pend, pend) == 1); + CHECK(wrap_skip_twice(heap, 8, 9, pend, pend) == 0); + + __u64 value = 0; + CHECK(wrap_read_u64(heap, pend, pend, &value) == 1); + CHECK(value == 0x0706050403020100ULL); // little endian + // unaligned read is fine by design (network data) + CHECK(wrap_read_u64(heap + 1, pend, pend, &value) 
== 1); + CHECK(value == 0x0807060504030201ULL); + CHECK(wrap_read_u64(heap + 9, pend, pend, &value) == 0); + + free(heap); +} + +/* ------------------------------------------------------------------------ + * Status request: [len=0x01][id=0x00] + * --------------------------------------------------------------------- */ + +static void test_status_request(void) +{ + const __u8 valid[] = {0x01, 0x00}; + CHECK(run_status(valid, 2) == 1); + // bytes beyond payload_end must not change the verdict + CHECK(run_status_slack(valid, 2, 8) == 1); + + const __u8 wrong_len[] = {0x02, 0x00}; + CHECK(run_status(wrong_len, 2) == 0); + + const __u8 wrong_id[] = {0x01, 0x01}; + CHECK(run_status(wrong_id, 2) == 0); + + const __u8 trailing[] = {0x01, 0x00, 0x00}; + CHECK(run_status(trailing, 3) == 0); + + const __u8 non_canonical_len[] = {0x81, 0x00, 0x00}; + CHECK(run_status(non_canonical_len, 3) == 0); + + CHECK(run_status(valid, 1) == 0); + CHECK(run_status(valid, 0) == 0); +} + +/* ------------------------------------------------------------------------ + * Ping request: [len=0x09][id=0x01][8 byte timestamp] + * --------------------------------------------------------------------- */ + +static void test_ping_request(void) +{ + const __u8 valid[] = {0x09, 0x01, 1, 2, 3, 4, 5, 6, 7, 8}; + CHECK(run_ping(valid, 10) == 1); + + CHECK(run_ping(valid, 9) == 0); // truncated timestamp + + const __u8 wrong_id[] = {0x09, 0x00, 1, 2, 3, 4, 5, 6, 7, 8}; + CHECK(run_ping(wrong_id, 10) == 0); + + const __u8 wrong_len[] = {0x08, 0x01, 1, 2, 3, 4, 5, 6, 7, 8}; + CHECK(run_ping(wrong_len, 10) == 0); + + const __u8 trailing[] = {0x09, 0x01, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + CHECK(run_ping(trailing, 11) == 0); +} + +/* ------------------------------------------------------------------------ + * Handshake: [len][id=0x00][protocol][host len][host][port][intention] + * --------------------------------------------------------------------- */ + +static struct buf build_handshake(__s32 protocol, __u32 
host_len, __s32 intention) +{ + struct buf body = {{0}, 0}; + put_varint(&body, 0x00); // packet id + put_varint(&body, protocol); + put_varint(&body, (__s32)host_len); + put_fill(&body, 'h', host_len); + put_u8(&body, 0x63); // port 25565 + put_u8(&body, 0xDD); + put_varint(&body, intention); + return packetize(&body); +} + +static void test_handshake_intentions(void) +{ + __s32 proto = 0; + + struct buf status = build_handshake(763, 14, 1); + CHECK(run_handshake(status.b, status.n, &proto, NULL) == AWAIT_STATUS_REQUEST); + CHECK(proto == 763); + + struct buf login = build_handshake(763, 14, 2); + CHECK(run_handshake(login.b, login.n, &proto, NULL) == AWAIT_LOGIN); + + // intention 3 (transfer) exists since 1.20.5 (766) + struct buf transfer_new = build_handshake(766, 14, 3); + CHECK(run_handshake(transfer_new.b, transfer_new.n, &proto, NULL) == AWAIT_LOGIN); + struct buf transfer_old = build_handshake(765, 14, 3); + CHECK(run_handshake(transfer_old.b, transfer_old.n, &proto, NULL) == STATE_INVALID); + + struct buf intention_zero = build_handshake(763, 14, 0); + CHECK(run_handshake(intention_zero.b, intention_zero.n, &proto, NULL) == STATE_INVALID); + struct buf intention_four = build_handshake(763, 14, 4); + CHECK(run_handshake(intention_four.b, intention_four.n, &proto, NULL) == STATE_INVALID); +} + +static void test_handshake_legacy_ping(void) +{ + const __u8 legacy[] = {0xFE, 0x01}; + CHECK(run_handshake(legacy, 2, NULL, NULL) == RECEIVED_LEGACY_PING); + CHECK(run_handshake(legacy, 1, NULL, NULL) == RECEIVED_LEGACY_PING); + CHECK(run_handshake(legacy, 0, NULL, NULL) == STATE_INVALID); +} + +static void test_handshake_rejects_malformed(void) +{ + // wrong packet id + struct buf body = {{0}, 0}; + put_varint(&body, 0x01); // packet id 1 instead of 0 + put_varint(&body, 763); + put_varint(&body, 0); + put_u8(&body, 0x63); + put_u8(&body, 0xDD); + put_varint(&body, 1); + struct buf wrong_id = packetize(&body); + CHECK(run_handshake(wrong_id.b, wrong_id.n, NULL, 
NULL) == STATE_INVALID); + + // truncated: every prefix of a valid handshake must be rejected + struct buf valid = build_handshake(763, 14, 1); + for (__u32 len = 0; len < valid.n; len++) + { + CHECK(run_handshake(valid.b, len, NULL, NULL) == STATE_INVALID); + } + + // length field below the minimum (smallest possible body is 6 bytes) + const __u8 len_too_small[] = {0x03, 0x00, 0x00}; + CHECK(run_handshake(len_too_small, 3, NULL, NULL) == STATE_INVALID); + + // length field above the maximum (787) + struct buf len_too_big = {{0}, 0}; + put_varint(&len_too_big, 788); + put_fill(&len_too_big, 0x00, 8); + CHECK(run_handshake(len_too_big.b, len_too_big.n, NULL, NULL) == STATE_INVALID); + + // host longer than the protocol allows (255 * 3 = 765) + struct buf host_too_long = build_handshake(763, 766, 1); + CHECK(run_handshake(host_too_long.b, host_too_long.n, NULL, NULL) == STATE_INVALID); + + // the longest legal host must still parse + struct buf host_max = build_handshake(763, 765, 1); + CHECK(run_handshake(host_max.b, host_max.n, NULL, NULL) == AWAIT_STATUS_REQUEST); + + // empty host is allowed by the protocol bounds + struct buf host_empty = build_handshake(763, 0, 1); + CHECK(run_handshake(host_empty.b, host_empty.n, NULL, NULL) == AWAIT_STATUS_REQUEST); +} + +static void test_handshake_combined_with_status_request(void) +{ + struct buf pkt = build_handshake(763, 14, 1); + const __u32 handshake_len = pkt.n; + put_u8(&pkt, 0x01); // status request appended in the same segment + put_u8(&pkt, 0x00); + + __s32 proto = 0; + __u32 resume = 0; + CHECK(run_handshake(pkt.b, pkt.n, &proto, &resume) == DIRECT_READ_STATUS_REQUEST); + CHECK(resume == handshake_len); + CHECK(run_status(pkt.b + resume, pkt.n - resume) == 1); +} + +static void test_handshake_combined_with_login(void) +{ + struct buf pkt = build_handshake(765, 14, 2); + const __u32 handshake_len = pkt.n; + + struct buf login_body = {{0}, 0}; + put_varint(&login_body, 0x00); // packet id + put_varint(&login_body, 
5); // username length + put_fill(&login_body, 'a', 5); + put_fill(&login_body, 0xAB, 16); // uuid (1.20.2+) + struct buf login = packetize(&login_body); + memcpy(pkt.b + pkt.n, login.b, login.n); + pkt.n += login.n; + + __s32 proto = 0; + __u32 resume = 0; + CHECK(run_handshake(pkt.b, pkt.n, &proto, &resume) == DIRECT_READ_LOGIN); + CHECK(resume == handshake_len); + CHECK(proto == 765); + CHECK(run_login(pkt.b + resume, pkt.n - resume, proto) == 1); +} + +/* ------------------------------------------------------------------------ + * Login request, all protocol eras: + * < 759 [id][name] + * 759 (1.19) [id][name][key block] + * 760 (1.19.1/2) [id][name][key block][has uuid + uuid] + * 761+ (1.19.3) [id][name][has uuid + uuid] + * 764+ (1.20.2) [id][name][uuid] + * --------------------------------------------------------------------- */ + +static struct buf build_simple_login(__u32 name_len) +{ + struct buf body = {{0}, 0}; + put_varint(&body, 0x00); // packet id + put_varint(&body, (__s32)name_len); + put_fill(&body, 'a', name_len); + return body; // callers append era-specific fields, then packetize +} + +static void test_login_pre_1_19(void) +{ + struct buf body = build_simple_login(5); + struct buf pkt = packetize(&body); + CHECK(run_login(pkt.b, pkt.n, 47) == 1); // 1.8 + CHECK(run_login(pkt.b, pkt.n, 758) == 1); // 1.18.2 + + // 1.20.2+ requires a uuid after the name, so the same bytes must fail + CHECK(run_login(pkt.b, pkt.n, 764) == 0); + + CHECK(run_login(pkt.b, pkt.n - 1, 47) == 0); // truncated + put_u8(&pkt, 0x00); + CHECK(run_login(pkt.b, pkt.n, 47) == 0); // trailing byte + CHECK(run_login(pkt.b, 0, 47) == 0); // empty payload +} + +static void test_login_username_rules(void) +{ + // 16 chars is the online-mode maximum + struct buf name16 = build_simple_login(16); + struct buf pkt16 = packetize(&name16); + CHECK(run_login(pkt16.b, pkt16.n, 758) == 1); + + struct buf name17 = build_simple_login(17); + struct buf pkt17 = packetize(&name17); + 
CHECK(run_login(pkt17.b, pkt17.n, 758) == 0); + + // offline mode allows up to 48 bytes (16 chars * 3 utf-8 bytes) + ONLINE_NAMES = 0; + CHECK(run_login(pkt17.b, pkt17.n, 758) == 1); + struct buf name48 = build_simple_login(48); + struct buf pkt48 = packetize(&name48); + CHECK(run_login(pkt48.b, pkt48.n, 758) == 1); + struct buf name49 = build_simple_login(49); + struct buf pkt49 = packetize(&name49); + CHECK(run_login(pkt49.b, pkt49.n, 758) == 0); + ONLINE_NAMES = 1; + + // empty names are impossible + struct buf name0 = build_simple_login(0); + struct buf pkt0 = packetize(&name0); + CHECK(run_login(pkt0.b, pkt0.n, 758) == 0); + + // a non-canonical two-byte encoding of name length 16 must be rejected + // (the reader is limited to the canonical varint width) + struct buf body = {{0}, 0}; + put_varint(&body, 0x00); + put_u8(&body, 0x90); // 16 with a needless continuation byte + put_u8(&body, 0x00); + put_fill(&body, 'a', 16); + struct buf pkt = packetize(&body); + CHECK(run_login(pkt.b, pkt.n, 758) == 0); +} + +static void test_login_1_19_key_block(void) +{ + // 1.19 - 1.19.2 (759/760): optional chat signing key after the name + struct buf no_key = build_simple_login(5); + put_u8(&no_key, 0x00); // has_public_key = false + struct buf no_key_pkt = packetize(&no_key); + CHECK(run_login(no_key_pkt.b, no_key_pkt.n, 759) == 1); + + struct buf with_key = build_simple_login(5); + put_u8(&with_key, 0x01); // has_public_key = true + put_fill(&with_key, 0xEE, 8); // expiry timestamp + put_varint(&with_key, 16); // key length + put_fill(&with_key, 0xBB, 16); // key + put_varint(&with_key, 32); // signature length + put_fill(&with_key, 0xCC, 32); // signature + struct buf with_key_pkt = packetize(&with_key); + CHECK(run_login(with_key_pkt.b, with_key_pkt.n, 759) == 1); + + // key larger than the protocol maximum (512) + struct buf big_key = build_simple_login(5); + put_u8(&big_key, 0x01); + put_fill(&big_key, 0xEE, 8); + put_varint(&big_key, 513); + put_fill(&big_key, 0xBB, 
513); + put_varint(&big_key, 32); + put_fill(&big_key, 0xCC, 32); + struct buf big_key_pkt = packetize(&big_key); + CHECK(run_login(big_key_pkt.b, big_key_pkt.n, 759) == 0); + + // 1.19.3 (761) dropped the key block again, so a packet carrying one + // must be rejected there. (The no-key packet is coincidentally also a + // valid 761 packet: its 0x00 then reads as has_uuid = false.) + CHECK(run_login(with_key_pkt.b, with_key_pkt.n, 761) == 0); + CHECK(run_login(no_key_pkt.b, no_key_pkt.n, 761) == 1); +} + +static void test_login_uuid_eras(void) +{ + // 760 (1.19.1): key block, then optional uuid + struct buf v760 = build_simple_login(5); + put_u8(&v760, 0x00); // has_public_key = false + put_u8(&v760, 0x01); // has_uuid = true + put_fill(&v760, 0xAB, 16); // uuid + struct buf v760_pkt = packetize(&v760); + CHECK(run_login(v760_pkt.b, v760_pkt.n, 760) == 1); + + struct buf v760_no_uuid = build_simple_login(5); + put_u8(&v760_no_uuid, 0x00); // has_public_key = false + put_u8(&v760_no_uuid, 0x00); // has_uuid = false + struct buf v760_no_uuid_pkt = packetize(&v760_no_uuid); + CHECK(run_login(v760_no_uuid_pkt.b, v760_no_uuid_pkt.n, 760) == 1); + + // 761/762 (1.19.3/4): no key block, optional uuid + struct buf v762 = build_simple_login(5); + put_u8(&v762, 0x01); + put_fill(&v762, 0xAB, 16); + struct buf v762_pkt = packetize(&v762); + CHECK(run_login(v762_pkt.b, v762_pkt.n, 762) == 1); + + // 764+ (1.20.2): uuid always present, no flag + struct buf v765 = build_simple_login(5); + put_fill(&v765, 0xAB, 16); + struct buf v765_pkt = packetize(&v765); + CHECK(run_login(v765_pkt.b, v765_pkt.n, 765) == 1); + + // truncated uuid + struct buf v765_short = build_simple_login(5); + put_fill(&v765_short, 0xAB, 8); + struct buf v765_short_pkt = packetize(&v765_short); + CHECK(run_login(v765_short_pkt.b, v765_short_pkt.n, 765) == 0); +} + +/* --------------------------------------------------------------------- */ + +int main(void) +{ + test_varint_decodes_known_vectors(); + 
test_varint_roundtrip(); + test_varint_rejects_truncated_input(); + test_varint_respects_max_size(); + test_varint_never_reads_past_payload_end(); + test_varint_stops_at_terminator(); + test_varint_rejects_overlong_encoding(); + test_bounds_macros(); + test_status_request(); + test_ping_request(); + test_handshake_intentions(); + test_handshake_legacy_ping(); + test_handshake_rejects_malformed(); + test_handshake_combined_with_status_request(); + test_handshake_combined_with_login(); + test_login_pre_1_19(); + test_login_username_rules(); + test_login_1_19_key_block(); + test_login_uuid_eras(); + + printf("%u checks, %u failures\n", checks_run, checks_failed); + return checks_failed ? 1 : 0; +} diff --git a/xdp/varint.h b/xdp/varint.h new file mode 100644 index 0000000..40d0911 --- /dev/null +++ b/xdp/varint.h @@ -0,0 +1,88 @@ +#ifndef VARINT_H +#define VARINT_H + +#include + +#include "common.h" + +/* + * Bounded reader for Minecraft VarInts. + * + * VarInts encode 7 bits per byte, least significant group first; the high + * bit of each byte marks continuation. A 32-bit value therefore occupies at + * most 5 bytes. + */ + +#define MIN_VARINT_BYTES 1 +#define MAX_VARINT_BYTES 5 + +// number of bytes a compile-time constant occupies when varint-encoded +#define VARINT_SIZE(n) \ + (((__u32)(n) <= 0x7F) ? 1 : \ + ((__u32)(n) <= 0x3FFF) ? 2 : \ + ((__u32)(n) <= 0x1FFFFF) ? 3 : \ + ((__u32)(n) <= 0xFFFFFFF) ? 4 : 5) + +struct varint_value +{ + __s32 value; + __u32 bytes; // bytes consumed (1 to 5), 0 on parse failure +}; +_Static_assert(sizeof(struct varint_value) == 8, "varint_value size mismatch!"); + +static __always_inline struct varint_value varint(__s32 value, __u32 bytes) +{ + return (struct varint_value){value, bytes}; +} + +// One unrolled step of read_varint_sized(): bounds-check one byte, fold it +// into result, and return from the enclosing function once the continuation +// bit ends. 
Must be a macro because it returns/jumps on behalf of its caller; +// the manual unrolling keeps the parse loop verifier-friendly. +#define VARINT_BYTE(ptr, pend, dend, max, idx, shift, result) \ + do \ + { \ + if ((max) < (idx)) \ + goto error; \ + if ((const void *)(ptr) + 1 > (const void *)(dend)) \ + goto error; \ + barrier_var(ptr); \ + if ((const void *)(ptr) + 1 > (const void *)(pend)) \ + goto error; \ + barrier_var(ptr); \ + __u8 _b = *(ptr)++; \ + /* shift in unsigned: 0x0F << 28 would overflow __s32 */ \ + (result) |= (__s32)((__u32)(_b & 0x7F) << (shift)); \ + if (!(_b & 0x80)) \ + return varint((result), (idx)); \ + } while (0) + +// reads a varint of at most max_size bytes, never touching memory beyond +// payload_end/data_end; returns {0, 0} on any violation +static __always_inline struct varint_value read_varint_sized(const __u8 *cursor, const __u8 *payload_end, const __u8 max_size, const void *data_end) +{ + __s32 result = 0; + + VARINT_BYTE(cursor, payload_end, data_end, max_size, 1, 0, result); + VARINT_BYTE(cursor, payload_end, data_end, max_size, 2, 7, result); + VARINT_BYTE(cursor, payload_end, data_end, max_size, 3, 14, result); + VARINT_BYTE(cursor, payload_end, data_end, max_size, 4, 21, result); + VARINT_BYTE(cursor, payload_end, data_end, max_size, 5, 28, result); + +error: + return varint(0, 0); +} + +// reads a varint of at most max bytes into dest and advances ptr past it, +// or returns 0 from the calling function on failure +#define READ_VARINT_OR_RETURN(dest, ptr, pend, dend, max) \ + do \ + { \ + dest = read_varint_sized(ptr, pend, max, dend); \ + if (!(dest).bytes) \ + return 0; \ + (ptr) += (dest).bytes; \ + barrier_var(ptr); \ + } while (0) + +#endif