diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9857f9c..374a0a0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,6 +12,7 @@ jobs: name: Build Linux binaries runs-on: ubuntu-latest strategy: + fail-fast: false matrix: arch: [amd64, arm64] steps: @@ -21,7 +22,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.24.2' - name: Install cross-compilation tools if: matrix.arch == 'arm64' @@ -29,27 +30,45 @@ jobs: sudo apt-get update sudo apt-get install -y gcc-aarch64-linux-gnu - - name: Build binaries + - name: Build package run: | - # Build Go binary - GOOS=linux GOARCH=${{ matrix.arch }} CGO_ENABLED=0 \ - go build -ldflags="-s -w -X main.version=${{ github.event.release.tag_name }}" \ - -o wrapguard . - - # Build C library if [ "${{ matrix.arch }}" = "arm64" ]; then - aarch64-linux-gnu-gcc -fPIC -shared -Wall -O2 \ - -o libwrapguard.so lib/intercept.c -ldl -lpthread + export C_COMPILER=aarch64-linux-gnu-gcc else - gcc -fPIC -shared -Wall -O2 \ - -o libwrapguard.so lib/intercept.c -ldl -lpthread + export C_COMPILER=gcc fi + make build \ + TARGET_GOOS=linux \ + TARGET_GOARCH=${{ matrix.arch }} \ + TARGET_DIR=dist/linux-${{ matrix.arch }} \ + C_COMPILER="$C_COMPILER" + - name: Create release archive + id: package + run: | + archive="wrapguard-${{ github.event.release.tag_name }}-linux-${{ matrix.arch }}.tar.gz" + tar -C "dist/linux-${{ matrix.arch }}" -czf "$archive" \ + wrapguard libwrapguard.so \ + -C "$GITHUB_WORKSPACE" README.md example-wg0.conf + echo "archive=$archive" >> "$GITHUB_OUTPUT" + + - name: Validate release archive run: | - chmod +x wrapguard - tar -czf wrapguard-${{ github.event.release.tag_name }}-linux-${{ matrix.arch }}.tar.gz \ - wrapguard libwrapguard.so README.md example-wg0.conf + archive="${{ steps.package.outputs.archive }}" + verify_dir="$(mktemp -d)" + tar -xzf "$archive" -C "$verify_dir" + test -x "$verify_dir/wrapguard" + test -f "$verify_dir/libwrapguard.so" + test -f "$verify_dir/README.md" + test -f "$verify_dir/example-wg0.conf" + "$verify_dir/wrapguard" --version + "$verify_dir/wrapguard" --help + + - name: Generate checksum + run: | + archive="${{ steps.package.outputs.archive }}" + sha256sum "$archive" > "$archive.sha256" - name: Upload Release Asset uses: actions/upload-release-asset@v1 @@ -57,14 +76,25 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ github.event.release.upload_url }} - asset_path: ./wrapguard-${{ github.event.release.tag_name }}-linux-${{ matrix.arch }}.tar.gz - asset_name: wrapguard-${{ github.event.release.tag_name }}-linux-${{ matrix.arch }}.tar.gz + asset_path: ./${{ steps.package.outputs.archive }} + asset_name: ${{ steps.package.outputs.archive }} asset_content_type: application/gzip + - name: Upload Checksum Asset + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ github.event.release.upload_url }} + asset_path: ./${{ steps.package.outputs.archive }}.sha256 + asset_name: ${{ steps.package.outputs.archive }}.sha256 + asset_content_type: text/plain + build-macos: name: Build macOS binaries runs-on: macos-latest strategy: + fail-fast: false matrix: arch: [amd64, arm64] steps: @@ -74,24 +104,104 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.24.2' - - name: Build binaries + - name: Build package run: | - # Build Go binary - GOOS=darwin GOARCH=${{ matrix.arch }} CGO_ENABLED=0 \ - go build -ldflags="-s -w -X main.version=${{ github.event.release.tag_name }}" \ - -o wrapguard . - - # Build C library (dylib for macOS) - clang -fPIC -shared -Wall -O2 \ - -o libwrapguard.dylib lib/intercept.c -ldl -lpthread + make build \ + TARGET_GOOS=darwin \ + TARGET_GOARCH=${{ matrix.arch }} \ + TARGET_DIR=dist/darwin-${{ matrix.arch }} \ + C_COMPILER=clang - name: Create release archive + id: package run: | - chmod +x wrapguard - tar -czf wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz \ - wrapguard libwrapguard.dylib README.md example-wg0.conf + archive="wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz" + tar -C "dist/darwin-${{ matrix.arch }}" -czf "$archive" \ + wrapguard libwrapguard.dylib \ + -C "$GITHUB_WORKSPACE" README.md example-wg0.conf + echo "archive=$archive" >> "$GITHUB_OUTPUT" + + - name: Validate release archive + run: | + archive="${{ steps.package.outputs.archive }}" + verify_dir="$(mktemp -d)" + tar -xzf "$archive" -C "$verify_dir" + test -x "$verify_dir/wrapguard" + test -f "$verify_dir/libwrapguard.dylib" + test -f "$verify_dir/README.md" + test -f "$verify_dir/example-wg0.conf" + "$verify_dir/wrapguard" --version + "$verify_dir/wrapguard" --help + + - name: Generate checksum + run: | + archive="${{ steps.package.outputs.archive }}" + shasum -a 256 "$archive" > "$archive.sha256" + + - name: Upload workflow artifact + uses: actions/upload-artifact@v4 + with: + name: wrapguard-macos-${{ matrix.arch }} + path: | + ${{ steps.package.outputs.archive }} + ${{ steps.package.outputs.archive }}.sha256 + if-no-files-found: error + + verify-macos-release-archives: + name: Verify macOS release archives + needs: build-macos + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + arch: [amd64, arm64] + steps: + - name: Download packaged archive + uses: actions/download-artifact@v4 + with: + name: wrapguard-macos-${{ matrix.arch }} + path: ${{ runner.temp }}/wrapguard-macos-${{ matrix.arch }} + + - name: Validate archive contents + run: | + artifact_dir="${{ runner.temp }}/wrapguard-macos-${{ matrix.arch }}" + archive="wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz" + archive_path="$artifact_dir/$archive" + checksum_path="$archive_path.sha256" + verify_dir="$(mktemp -d)" + + test -f "$archive_path" + test -f "$checksum_path" + + expected_sum="$(awk '{print $1}' "$checksum_path")" + actual_sum="$(shasum -a 256 "$archive_path" | awk '{print $1}')" + test "$actual_sum" = "$expected_sum" + + tar -xzf "$archive_path" -C "$verify_dir" + test -x "$verify_dir/wrapguard" + test -f "$verify_dir/libwrapguard.dylib" + test -f "$verify_dir/README.md" + test -f "$verify_dir/example-wg0.conf" + chmod +x "$verify_dir/wrapguard" + "$verify_dir/wrapguard" --version + "$verify_dir/wrapguard" --help + + publish-macos-release-assets: + name: Publish macOS release assets + needs: verify-macos-release-archives + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + arch: [amd64, arm64] + steps: + - name: Download packaged archive + uses: actions/download-artifact@v4 + with: + name: wrapguard-macos-${{ matrix.arch }} + path: ${{ runner.temp }}/wrapguard-macos-${{ matrix.arch }} - name: Upload Release Asset uses: actions/upload-release-asset@v1 @@ -99,7 +209,16 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ github.event.release.upload_url }} - asset_path: ./wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz + asset_path: ${{ runner.temp }}/wrapguard-macos-${{ matrix.arch }}/wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz asset_name: wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz asset_content_type: application/gzip + - name: Upload Checksum Asset + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ github.event.release.upload_url }} + asset_path: ${{ runner.temp }}/wrapguard-macos-${{ matrix.arch }}/wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz.sha256 + asset_name: wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz.sha256 + asset_content_type: text/plain diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3f8b8c3..fdef14b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,8 +8,12 @@ on: jobs: test: - name: Test - runs-on: ubuntu-latest + name: Test (${{ matrix.os }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] steps: - name: Check out code @@ -18,7 +22,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24.2' - name: Cache Go modules uses: actions/cache@v4 @@ -26,9 +30,9 @@ jobs: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go-1.23-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go-1.24.2-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go-1.23- + ${{ runner.os }}-go-1.24.2- - name: Download dependencies run: go mod download @@ -37,12 +41,33 @@ jobs: run: go mod verify - name: Run tests - run: go test -v -race -coverprofile=coverage.out ./... + run: | + if [ "${{ matrix.os }}" = "ubuntu-latest" ]; then + go test -v -race -coverprofile=coverage.out ./... + else + go test -v ./... + fi - - name: Run tests with coverage - run: go test -cover ./... + - name: Build package + run: make build + + - name: Verify build outputs + run: | + test -f wrapguard + if [ "${{ matrix.os }}" = "macos-latest" ]; then + test -f libwrapguard.dylib + else + test -f libwrapguard.so + fi + ./wrapguard --version + ./wrapguard --help + + - name: Smoke test packaged macOS archive + if: matrix.os == 'macos-latest' + run: make smoke-macos - name: Upload coverage reports to Codecov + if: matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v4 with: file: ./coverage.out @@ -53,7 +78,7 @@ jobs: lint: name: Lint runs-on: ubuntu-latest - + steps: - name: Check out code uses: actions/checkout@v4 @@ -61,7 +86,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.23' + go-version: '1.24.2' - name: Run go vet run: go vet ./... @@ -73,34 +98,3 @@ jobs: gofmt -d . exit 1 fi - - build: - name: Build - runs-on: ubuntu-latest - - steps: - - name: Check out code - uses: actions/checkout@v4 - - - name: Set up Go - uses: actions/setup-go@v5 - with: - go-version: '1.23' - - - name: Install build dependencies - run: sudo apt-get update && sudo apt-get install -y gcc - - - name: Build binary - run: make build - - - name: Verify binary exists - run: | - ls -la wrapguard - ls -la libwrapguard.so - file wrapguard - file libwrapguard.so - - - name: Test binary runs - run: | - ./wrapguard --version - ./wrapguard --help diff --git a/Makefile b/Makefile index b6f443a..39de9ff 100644 --- a/Makefile +++ b/Makefile @@ -1,36 +1,103 @@ -.PHONY: all build clean test +SHELL := /bin/bash + +.PHONY: all build build-target build-linux build-linux-amd64 build-linux-arm64 build-macos build-macos-amd64 build-macos-arm64 build-macos-universal build-all clean test test-coverage deps fmt lint smoke-macos smoke-macos-browser help # Build variables GO_MODULE = github.com/puzed/wrapguard BINARY_NAME = wrapguard -LIBRARY_NAME = libwrapguard.so -VERSION = 1.0.0-dev - -# Build flags +VERSION ?= 1.0.0-dev +DIST_DIR ?= dist +TARGET_GOOS ?= $(shell go env GOOS) +TARGET_GOARCH ?= $(shell go env GOARCH) +TARGET_DIR ?= . GO_BUILD_FLAGS = -ldflags="-s -w -X main.version=$(VERSION)" -C_BUILD_FLAGS = -shared -fPIC -ldl +LIBRARY_NAME = $(if $(filter darwin,$(TARGET_GOOS)),libwrapguard.dylib,libwrapguard.so) +BROWSER_APP ?= +BROWSER_ARGS_TEMPLATE ?= --no-remote -profile __PROFILE__ +BROWSER_PROFILE_DIR ?= +SMOKE_URL ?= http://icanhazip.com +WG_CONFIG ?= +WG_LOG_FILE ?= /tmp/wrapguard-browser-smoke.log + +ifeq ($(TARGET_GOOS),darwin) + ifeq ($(TARGET_GOARCH),amd64) + DARWIN_ARCH = x86_64 + else ifeq ($(TARGET_GOARCH),arm64) + DARWIN_ARCH = arm64 + else + DARWIN_ARCH = $(TARGET_GOARCH) + endif + C_COMPILER ?= clang + GO_CGO_ENV = CGO_ENABLED=1 CC="$(C_COMPILER)" CGO_CFLAGS="-arch $(DARWIN_ARCH)" CGO_LDFLAGS="-arch $(DARWIN_ARCH)" + C_ARCH_FLAGS = -arch $(DARWIN_ARCH) + C_SHARED_FLAGS = -dynamiclib + C_WARNING_FLAGS = -Wall -Wextra -Wpedantic -O2 + C_LINK_FLAGS = -Wl,-undefined,dynamic_lookup +else + C_COMPILER ?= gcc + GO_CGO_ENV = CGO_ENABLED=1 CC="$(C_COMPILER)" + C_ARCH_FLAGS = + C_SHARED_FLAGS = -shared -fPIC + C_WARNING_FLAGS = -Wall -Wextra -Wpedantic -O2 + C_LINK_FLAGS = -ldl +endif # Default target all: build -# Build both Go binary and C library -build: $(BINARY_NAME) $(LIBRARY_NAME) - -# Build Go binary -$(BINARY_NAME): *.go go.mod go.sum - @echo "Building Go binary..." - go mod tidy - go build $(GO_BUILD_FLAGS) -o $(BINARY_NAME) . - -# Build C library -$(LIBRARY_NAME): lib/intercept.c - @echo "Building C library..." - gcc $(C_BUILD_FLAGS) -o $(LIBRARY_NAME) lib/intercept.c +# Build the current host platform into the requested output directory +build: build-target + +build-target: + @mkdir -p "$(TARGET_DIR)" + @echo "Building $(TARGET_GOOS)/$(TARGET_GOARCH) into $(TARGET_DIR)..." + @GOOS=$(TARGET_GOOS) GOARCH=$(TARGET_GOARCH) $(GO_CGO_ENV) go build $(GO_BUILD_FLAGS) -o "$(TARGET_DIR)/$(BINARY_NAME)" . + @$(C_COMPILER) $(C_ARCH_FLAGS) $(C_SHARED_FLAGS) $(C_WARNING_FLAGS) lib/intercept.c $(C_LINK_FLAGS) -o "$(TARGET_DIR)/$(LIBRARY_NAME)" + +build-linux: TARGET_GOOS = linux +build-linux: TARGET_DIR = $(DIST_DIR)/linux-$(TARGET_GOARCH) +build-linux: build-target + +build-linux-amd64: + @$(MAKE) build-linux TARGET_GOARCH=amd64 + +build-linux-arm64: + @$(MAKE) build-linux TARGET_GOARCH=arm64 C_COMPILER=aarch64-linux-gnu-gcc + +build-macos: TARGET_GOOS = darwin +build-macos: TARGET_DIR = $(DIST_DIR)/darwin-$(TARGET_GOARCH) +build-macos: build-target + +build-macos-amd64: + @$(MAKE) build-macos TARGET_GOARCH=amd64 + +build-macos-arm64: + @$(MAKE) build-macos TARGET_GOARCH=arm64 + +build-macos-universal: TARGET_GOOS = darwin +build-macos-universal: + @if [ "$$(uname -s)" != "Darwin" ]; then \ + echo "build-macos-universal must be run on macOS"; \ + exit 1; \ + fi + @set -euo pipefail; \ + stage_dir="$$(mktemp -d)"; \ + final_dir="$(DIST_DIR)/darwin-universal"; \ + trap 'rm -rf "$$stage_dir"' EXIT; \ + $(MAKE) build-target TARGET_GOOS=darwin TARGET_GOARCH=amd64 TARGET_DIR="$$stage_dir/amd64" C_COMPILER=clang; \ + $(MAKE) build-target TARGET_GOOS=darwin TARGET_GOARCH=arm64 TARGET_DIR="$$stage_dir/arm64" C_COMPILER=clang; \ + mkdir -p "$$final_dir"; \ + lipo -create "$$stage_dir/amd64/$(BINARY_NAME)" "$$stage_dir/arm64/$(BINARY_NAME)" -output "$$final_dir/$(BINARY_NAME)"; \ + lipo -create "$$stage_dir/amd64/$(LIBRARY_NAME)" "$$stage_dir/arm64/$(LIBRARY_NAME)" -output "$$final_dir/$(LIBRARY_NAME)"; \ + chmod +x "$$final_dir/$(BINARY_NAME)"; \ + echo "Built universal macOS binaries in $$final_dir" + +build-all: build-linux-amd64 build-linux-arm64 build-macos-amd64 build-macos-arm64 # Clean build artifacts clean: @echo "Cleaning build artifacts..." - rm -f $(BINARY_NAME) $(LIBRARY_NAME) + rm -rf "$(DIST_DIR)" "$(BINARY_NAME)" "$(LIBRARY_NAME)" go clean # Run tests @@ -45,7 +112,7 @@ test-coverage: # Build debug version debug: GO_BUILD_FLAGS = -ldflags="-X main.version=$(VERSION)-debug" -debug: C_BUILD_FLAGS += -g -O0 +debug: C_WARNING_FLAGS += -g -O0 debug: build # Install dependencies @@ -63,19 +130,60 @@ lint: @echo "Running linter..." go vet ./... -# Build for multiple platforms -build-all: build-linux build-darwin - -build-linux: - @echo "Building for Linux..." - GOOS=linux GOARCH=amd64 go build $(GO_BUILD_FLAGS) -o $(BINARY_NAME)-linux-amd64 . - gcc $(C_BUILD_FLAGS) -o libwrapguard-linux-amd64.so lib/intercept.c - -build-darwin: - @echo "Building for macOS..." - GOOS=darwin GOARCH=amd64 go build $(GO_BUILD_FLAGS) -o $(BINARY_NAME)-darwin-amd64 . - GOOS=darwin GOARCH=arm64 go build $(GO_BUILD_FLAGS) -o $(BINARY_NAME)-darwin-arm64 . - gcc $(C_BUILD_FLAGS) -o libwrapguard-darwin.dylib lib/intercept.c +# Validate a local macOS package end to end +smoke-macos: + @if [ "$$(uname -s)" != "Darwin" ]; then \ + echo "smoke-macos must be run on macOS"; \ + exit 1; \ + fi + @set -euo pipefail; \ + $(MAKE) build-macos TARGET_GOARCH=$(TARGET_GOARCH); \ + staging="$$(mktemp -d)"; \ + package_dir="$$staging/package"; \ + verify_dir="$$staging/verify"; \ + mkdir -p "$$package_dir" "$$verify_dir"; \ + cp "$(DIST_DIR)/darwin-$(TARGET_GOARCH)/$(BINARY_NAME)" "$$package_dir/"; \ + cp "$(DIST_DIR)/darwin-$(TARGET_GOARCH)/$(LIBRARY_NAME)" "$$package_dir/"; \ + cp README.md example-wg0.conf "$$package_dir/"; \ + tar -C "$$package_dir" -czf "$$staging/$(BINARY_NAME)-macos-smoke.tar.gz" $(BINARY_NAME) $(LIBRARY_NAME) README.md example-wg0.conf; \ + tar -xzf "$$staging/$(BINARY_NAME)-macos-smoke.tar.gz" -C "$$verify_dir"; \ + test -x "$$verify_dir/$(BINARY_NAME)"; \ + test -f "$$verify_dir/$(LIBRARY_NAME)"; \ + chmod +x "$$verify_dir/$(BINARY_NAME)"; \ + "$$verify_dir/$(BINARY_NAME)" --version; \ + "$$verify_dir/$(BINARY_NAME)" --help; \ + rm -rf "$$staging" + +# Launch an experimental macOS browser target through WrapGuard with a fresh profile +smoke-macos-browser: + @if [ "$$(uname -s)" != "Darwin" ]; then \ + echo "smoke-macos-browser must be run on macOS"; \ + exit 1; \ + fi + @if [ -z "$(WG_CONFIG)" ]; then \ + echo "WG_CONFIG=/path/to/config.conf is required"; \ + exit 1; \ + fi + @if [ -z "$(BROWSER_APP)" ]; then \ + echo "BROWSER_APP=/Applications/LibreWolf.app/Contents/MacOS/librewolf is required"; \ + exit 1; \ + fi + @set -euo pipefail; \ + $(MAKE) build; \ + profile_dir="$(BROWSER_PROFILE_DIR)"; \ + if [ -z "$$profile_dir" ]; then \ + profile_dir="$$(mktemp -d /tmp/wrapguard-browser-profile.XXXXXX)"; \ + echo "Using temporary browser profile: $$profile_dir"; \ + else \ + mkdir -p "$$profile_dir"; \ + echo "Using browser profile: $$profile_dir"; \ + fi; \ + args_template='$(BROWSER_ARGS_TEMPLATE)'; \ + browser_args="$${args_template//__PROFILE__/$$profile_dir}"; \ + echo "Logging to $(WG_LOG_FILE)"; \ + echo "Suggested validation URL: $(SMOKE_URL)"; \ + echo "Launching $(BROWSER_APP) $$browser_args"; \ + eval "./$(BINARY_NAME) --config=\"$(WG_CONFIG)\" --log-level=debug --log-file=\"$(WG_LOG_FILE)\" -- \"$(BROWSER_APP)\" $$browser_args" # Run demo demo: build @@ -85,15 +193,24 @@ demo: build # Help help: @echo "Available targets:" - @echo " all - Build both binary and library (default)" - @echo " build - Build both binary and library" - @echo " clean - Clean build artifacts" - @echo " test - Run tests" - @echo " test-coverage- Run tests with coverage" - @echo " debug - Build debug version" - @echo " deps - Install dependencies" - @echo " fmt - Format Go code" - @echo " lint - Run linter" - @echo " build-all - Build for multiple platforms" - @echo " demo - Run demo" - @echo " help - Show this help" \ No newline at end of file + @echo " all - Build the current host platform (default)" + @echo " build - Build the current host platform" + @echo " build-linux - Build a Linux package into dist/" + @echo " build-linux-amd64 - Build a Linux amd64 package into dist/" + @echo " build-linux-arm64 - Build a Linux arm64 package into dist/" + @echo " build-macos - Build a macOS package into dist/" + @echo " build-macos-amd64 - Build a macOS amd64 package into dist/" + @echo " build-macos-arm64 - Build a macOS arm64 package into dist/" + @echo " build-macos-universal - Build a universal macOS package into dist/" + @echo " build-all - Build all packaged Linux and macOS variants" + @echo " clean - Clean build artifacts" + @echo " test - Run tests" + @echo " test-coverage - Run tests with coverage" + @echo " debug - Build debug version" + @echo " deps - Install dependencies" + @echo " fmt - Format Go code" + @echo " lint - Run go vet" + @echo " smoke-macos - Validate a local macOS package end to end" + @echo " smoke-macos-browser - Launch a macOS browser target via WrapGuard with a fresh profile" + @echo " demo - Run demo" + @echo " help - Show this help" diff --git a/README.md b/README.md index 4a5e497..7bf67d3 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,14 @@ # WrapGuard - Userspace WireGuard Proxy -WrapGuard enables any application to transparently route ALL network traffic through a WireGuard VPN without requiring container privileges or kernel modules. +WrapGuard enables applications to route network traffic through a WireGuard VPN from userspace without requiring container privileges or kernel modules. + +Linux is the primary production target today. macOS build, packaging, and regression checks run in CI, and the macOS runtime path has proven direct-launch TCP routing for CLI targets, but it is still experimental and limited to targets that can accept injection. ## Features - **Pure Userspace**: No TUN interface creation, no NET_ADMIN capability needed -- **Transparent Interception**: Uses LD_PRELOAD to intercept all network calls -- **Bidirectional Support**: Both incoming and outgoing connections work +- **Transparent Interception**: Uses platform-specific dynamic library injection on supported hosts +- **Bidirectional Support**: Both incoming and outgoing connections work on supported Linux builds - **Standard Config**: Uses standard WireGuard configuration files ## Installation @@ -15,6 +17,8 @@ WrapGuard enables any application to transparently route ALL network traffic thr Download pre-compiled binaries for Linux and macOS from the [releases page](https://github.com/puzed/wrapguard/releases). +Linux releases are production-targeted. macOS archives are packaged and validated for layout consistency plus release-archive smoke checks, but macOS support should still be treated as experimental until the runtime launcher work is complete. + **No additional dependencies required** - WrapGuard is a single binary that includes everything needed to create WireGuard connections. You don't need WireGuard installed on your host machine, kernel modules, or any other VPN software. ### Building from Source @@ -24,8 +28,22 @@ make build ``` This will create: -- `wrapguard` - The main executable (single binary with embedded library) -- `libwrapguard.so` - The LD_PRELOAD library +- `wrapguard` - The main executable +- `libwrapguard.so` on Linux or `libwrapguard.dylib` on macOS - the injected library that lives next to the binary + +## Support Matrix + +- Linux `amd64` and `arm64`: supported for production use. +- macOS 14 Sonoma and macOS 15 Sequoia on `amd64` and `arm64`: experimental direct-launch support for targets that can be launched as an executable path. +- If you are experimenting with a simple `.app` bundle, WrapGuard can resolve the inner executable when `Contents/MacOS` contains a single clear candidate; `open -a` remains unsupported. +- Direct CLI launches through a non-SIP shell binary are supported in the same experimental sense as other direct macOS CLI targets. +- `open -a` and launch targets that depend on Apple-managed app launchers: unsupported. +- macOS releases newer than Sequoia may work, but they are not yet part of the documented QA matrix. +- macOS SIP-protected binaries in locations such as `/usr/bin`, `/bin`, `/System`, `/sbin`, and `/usr/libexec`: rejected before launch and unsupported. +- Third-party signed GUI apps, hardened-runtime apps, sandboxed helpers, and browser-style multi-process apps may still launch but are not supported and may be unstable under injection. +- Routed outbound TCP is the documented and tested macOS path today. UDP is not tunneled on macOS, and WrapGuard may intentionally suppress likely QUIC `UDP/443` connect attempts so browsers fall back to the TCP path instead of leaking through UDP. +- IPv6 remains outside the production support statement on macOS. +- Windows: unsupported. ## Usage @@ -52,6 +70,62 @@ wrapguard --config=~/wg0.conf --log-level=debug -- curl https://icanhazip.com wrapguard --config=~/wg0.conf --log-level=info --log-file=/tmp/wrapguard.log -- curl https://icanhazip.com ``` +## macOS Guide + +- Launch the target executable directly with WrapGuard. +- For a simple `.app` bundle with a single clear executable in `Contents/MacOS`, you can also pass the bundle path and WrapGuard will resolve the inner executable for you. +- If a bundle contains multiple executable candidates in `Contents/MacOS`, WrapGuard will fail closed and ask you to launch the inner executable explicitly. +- If you want a shell session under WrapGuard on macOS, use a non-SIP shell binary launched directly by WrapGuard. Apple-protected shells in `/bin` are not a supported injection target. +- `open -a AppName` is not equivalent to launching the inner executable directly and is not a supported wrapping path. +- Apple-protected binaries in locations such as `/usr/bin`, `/bin`, `/System`, `/sbin`, and `/usr/libexec` are blocked by SIP and are unsupported. +- Browser-style GUI apps such as Firefox/LibreWolf-class multi-process browsers are still experimental even when TCP interception works. Helper, GPU, compositor, and sandboxed subprocesses may become unstable under DYLD injection. +- For repeatable experimental browser checks, use the documented harness in [docs/macos-browser-validation.md](docs/macos-browser-validation.md) instead of ad hoc launch commands. +- If a browser-style app shows different results on soft refresh versus hard refresh, treat that as a sign that the app may be using cache, service-worker, or alternate transport paths rather than assuming the tunnel path itself is broken. +- Routed outbound TCP is the documented and tested macOS path today. Wrapped UDP and wrapped IPv6 traffic are not yet production-ready on macOS, and broader non-blocking/browser socket compatibility is still under active validation. +- DNS lookups are still resolved by the host network stack. WrapGuard currently routes post-resolution IP-literal TCP destinations through the tunnel, but it does not intercept resolver APIs or tunnel DNS itself. +- Localhost and loopback traffic are intentionally left on the host stack and are not routed through the injected SOCKS path. + +### Experimental GUI Behavior + +Current expected behavior for experimental macOS GUI launches: + +- launch the real executable path directly through WrapGuard +- for `.app` bundles, WrapGuard may resolve `Contents/MacOS/...` automatically only when there is a single clear executable candidate +- if a browser or GUI app needs an already-running app instance, `open -a`, an app launcher service, or a handoff into another unwrapped session, that path is outside the supported model +- the most reliable validation flow is a fresh profile plus a direct inner-executable launch +- if the app stays in the directly launched process tree and accepts DYLD injection, routed outbound TCP can work + +Current unsupported or risky app classes on macOS: + +- Apple-protected or SIP-protected binaries +- hardened-runtime apps that reject injected libraries +- app launchers that immediately hand off to another already-running process or daemon +- apps whose critical helper processes cannot tolerate DYLD injection +- sandboxed GUI apps whose networking or compositor helpers break under interposition + +### Current Browser Transport Decision + +The current experimental browser stance on macOS is: + +- routed TCP is the supported browser transport path +- UDP and native QUIC tunneling are not currently supported on macOS +- WrapGuard may suppress likely browser QUIC / `HTTP/3` `UDP/443` traffic so the browser falls back toward the proven TCP path +- browser support should therefore be treated as experimental direct-launch TCP support, not as full browser-transport equivalence yet + +### macOS Troubleshooting + +- Run `wrapguard --doctor [target]` to check the local runtime layout, selected injection mode, and macOS preflight restrictions before you try a real launch. +- If you pass a target, `--doctor` also validates that launch target before WrapGuard starts the tunnel stack. If you omit the target, it only checks the local runtime artifacts. +- `--doctor` accepts resolvable `.app` bundles, prints the inner executable it selected, and still rejects SIP-protected launch paths on macOS before launch. +- Run `wrapguard --config= --self-test` to validate that WrapGuard can start its IPC/SOCKS stack, inject the library, and observe an intercepted outbound `connect()`. +- If the child starts but WrapGuard reports that the interceptor never announced readiness, confirm that the target is not SIP-protected and that `libwrapguard.dylib` is present next to `wrapguard`. +- If traffic is unchanged, rerun with `--log-level=debug` and check for the startup lines showing `DYLD_INSERT_LIBRARIES`, the resolved dylib path, and the interceptor load message. +- If a hostname-based connection still resolves through the host instead of the VPN, that is expected today. WrapGuard only routes the resulting IP-literal TCP destination through SOCKS and the WireGuard userspace netstack. +- If you see the interceptor load and outbound `connect()` calls being intercepted but the request still exits with the host IP, treat that as a tunnel-path failure. Routed outbound TCP now fails through the WireGuard userspace netstack instead of silently falling back to a direct host connection, so this points to peer reachability, exit-node routing, or upstream WireGuard configuration rather than dylib injection. +- If you are validating a browser-like app and it still reports the host IP after a soft refresh, compare it against a hard refresh and a direct CLI probe such as `curl`; browser caching, service workers, QUIC/HTTP3, or helper-process instability can all change the observed path. +- If codesigning or hardened runtime prevents injection, use an unsigned or developer-controlled binary for now. Notarized and hardened-runtime compatibility is not part of the current experimental support statement. +- If a browser or GUI app loads, intercepts some traffic, and then crashes later, treat that as a helper-process compatibility issue rather than proof that the tunnel path is broken. + ## Routing WrapGuard supports policy-based routing to direct traffic through specific WireGuard peers. @@ -140,14 +214,16 @@ PersistentKeepalive = 25 ## How It Works 1. **Main Process**: Parses config, initializes WireGuard userspace implementation -2. **LD_PRELOAD Library**: Intercepts network system calls (socket, connect, send, recv, etc.) +2. **Injected Library**: Intercepts network system calls using the host-specific dynamic loading path 3. **Virtual Network Stack**: Routes packets between intercepted connections and WireGuard tunnel -4. **Memory-based TUN**: No kernel interface needed, packets processed entirely in memory +4. **Userspace TUN/Netstack**: No kernel interface needed, packets are handled entirely in memory by the WireGuard userspace netstack ## Limitations - Linux and macOS only (Windows is not supported) -- TCP and UDP protocols only +- macOS runtime support is experimental and should not be assumed for SIP-protected binaries, hardened-runtime apps, `open -a` launcher paths, or launch targets that cannot accept injected libraries +- Routed TCP is the primary documented path today +- Wrapped UDP and wrapped IPv6 are not yet documented as supported on macOS - Performance overhead due to userspace packet processing ## Development @@ -170,6 +246,14 @@ go test -cover ./... # Build the main binary make build +# Validate a local macOS package layout end to end +make smoke-macos + +# Launch an experimental macOS browser target with a fresh profile +make smoke-macos-browser \ + WG_CONFIG=../NL-US-PA-16.conf \ + BROWSER_APP="/Applications/LibreWolf.app/Contents/MacOS/librewolf" + # Build with debug information make debug diff --git a/config.go b/config.go index 9272ab5..2c72d0e 100644 --- a/config.go +++ b/config.go @@ -345,7 +345,7 @@ func ApplyCLIRoutes(config *WireGuardConfig, exitNode string, routes []string) e peer.RoutingPolicies = append(peer.RoutingPolicies, policy) peerFound = true - if logger != nil { + if CurrentLogger() != nil { logger.Infof("Added route %s via peer %s", cidr, peerIP) } break diff --git a/docs/macos-browser-validation.md b/docs/macos-browser-validation.md new file mode 100644 index 0000000..94fc6a6 --- /dev/null +++ b/docs/macos-browser-validation.md @@ -0,0 +1,101 @@ +# macOS Browser Validation + +This guide gives us one repeatable way to validate experimental browser support on macOS without drifting between ad hoc commands. + +## Goals + +- launch the browser by its real inner executable, not `open -a` +- use a fresh profile for each validation run +- keep WrapGuard logs in a known location +- distinguish "browser never starts", "browser starts but does not browse", and "browser browses through the VPN" + +## Preconditions + +- run on macOS +- build the repo-root artifacts with `make build` +- use a non-SIP target application +- use a real WireGuard config file that already works with a CLI probe + +Validate the tunnel path first: + +```bash +./wrapguard --config=../NL-US-PA-16.conf -- /opt/homebrew/opt/curl/bin/curl https://icanhazip.com +``` + +If that does not return the VPN IP, stop there. Browser validation is not the next thing to debug. + +## Repeatable Harness + +Use the new `make smoke-macos-browser` target. + +LibreWolf / Firefox-style example: + +```bash +make smoke-macos-browser \ + WG_CONFIG=../NL-US-PA-16.conf \ + BROWSER_APP="/Applications/LibreWolf.app/Contents/MacOS/librewolf" \ + BROWSER_ARGS_TEMPLATE="--no-remote -profile __PROFILE__" \ + WG_LOG_FILE=/tmp/wrapguard-librewolf.log +``` + +Brave / Chromium-style example: + +```bash +make smoke-macos-browser \ + WG_CONFIG=../NL-US-PA-16.conf \ + BROWSER_APP="/Applications/Brave Browser.app/Contents/MacOS/Brave Browser" \ + BROWSER_ARGS_TEMPLATE="--user-data-dir=__PROFILE__ --no-first-run --no-default-browser-check --new-window" \ + WG_LOG_FILE=/tmp/wrapguard-brave.log +``` + +How it works: + +- `__PROFILE__` is replaced with a fresh temporary profile directory unless `BROWSER_PROFILE_DIR` is provided +- WrapGuard is rebuilt first so the binary and dylib stay in sync +- WrapGuard logs go to `WG_LOG_FILE` +- the target browser is launched directly through WrapGuard with debug logging enabled + +## Manual Validation Flow + +After launch: + +1. open `http://icanhazip.com` +2. confirm the page shows the VPN IP, not the host IP +3. refresh several times +4. open DevTools and confirm the browser remains stable +5. compare the browser result against a direct CLI probe if anything looks wrong + +## Result Buckets + +Record each run in one of these buckets: + +- `startup-failed` + - browser never reaches a usable window + - log focus: handshake, helper startup, GPU/compositor failures +- `startup-only` + - browser window opens, but pages do not load + - log focus: intercepted `CONNECT` traffic versus local IPC-only noise +- `tunneled` + - browser reaches a page and the public IP is the VPN IP + - keep notes on refresh behavior and DevTools stability +- `regressed` + - browser used to work in the same setup but no longer does + - capture the exact command, app version, and log path + +## What To Watch For + +- repeated `AF_UNIX` / local IPC logs are expected and are not proof of a network leak +- GPU/helper warnings may still appear even when browsing works +- browser-visible host IP after a soft refresh can still indicate cache, service-worker, or QUIC / HTTP3 behavior rather than a total loss of TCP interception +- on macOS today, UDP is not a supported tunneled transport; WrapGuard only tries to suppress likely browser QUIC traffic enough to push browsers back toward the proven TCP path + +## Current Known-Good Result + +Current experimental known-good validation on Apple Silicon: + +- LibreWolf launches via its inner executable +- `http://icanhazip.com` shows VPN IP `146.70.156.18` +- repeated refreshes keep showing the VPN IP +- DevTools opens without the earlier recursion crash + +That is a real breakthrough, but it is still not the final product bar for all GUI apps. Broader app coverage and automated regression checks are still open. diff --git a/docs/release-notes-macos.md b/docs/release-notes-macos.md new file mode 100644 index 0000000..c14bc3b --- /dev/null +++ b/docs/release-notes-macos.md @@ -0,0 +1,57 @@ +# WrapGuard macOS Release Notes + +Use this template when cutting a macOS release. Fill in the version-specific details before publishing. + +## Release Summary + +- Version: `vX.Y.Z` +- Release date: `YYYY-MM-DD` +- Supported architectures: `arm64`, `amd64` +- Packaging: `wrapguard--darwin-arm64.tar.gz` and `wrapguard--darwin-amd64.tar.gz` + +## Support Matrix + +- macOS 14 Sonoma: experimental direct-launch support for targets that can be launched as an executable path. +- macOS 15 Sequoia: experimental direct-launch support for targets that can be launched as an executable path. +- `.app` bundle launching: supported only when WrapGuard can resolve a single clear executable in `Contents/MacOS`; otherwise launch the inner executable directly. +- `open -a` launch paths: unsupported. +- System binaries under `/bin`, `/sbin`, `/System`, `/usr/bin`, and `/usr/libexec`: unsupported. +- Browser-style GUI apps: experimental and not considered production-supported. +- Direct inner-executable launches remain the only supported experimental GUI model; launcher handoff into an already-running app session is outside the support statement. + +## Example Commands + +```bash +# Preflight a launch target +wrapguard --doctor /usr/local/bin/curl + +# Run a direct CLI command through WrapGuard +wrapguard --config=wg0.conf -- curl https://icanhazip.com + +# Run the shared experimental browser harness +make smoke-macos-browser \ + WG_CONFIG=wg0.conf \ + BROWSER_APP="/Applications/LibreWolf.app/Contents/MacOS/librewolf" + +# Inspect the packaged build locally +tar -tzf wrapguard-vX.Y.Z-darwin-arm64.tar.gz +``` + +## Known Limitations + +- macOS support is CLI-oriented and relies on direct launching of the target executable. +- SIP-protected system binaries are rejected before launch. +- GUI applications may load when launched via their inner executable, but they can still become unstable if helper processes are not compatible with DYLD injection. +- GUI validation should be treated as direct-launch validation only; if an app hands work off to another already-running session, WrapGuard no longer controls the real process tree. +- TCP routing is the documented macOS path; UDP and IPv6 remain outside the production support statement unless explicitly validated for a release. +- On current macOS builds, WrapGuard may deliberately suppress likely QUIC `UDP/443` connect attempts to encourage TCP fallback rather than claim full UDP tunneling support. +- Non-blocking socket behavior is improved but still under active validation; broader browser/socket-state compatibility still needs more regression coverage. +- The current browser transport decision is explicit: experimental macOS browser validation is a TCP-path claim, not a full `HTTP/3` / QUIC support claim. + +## Validation Notes + +- Confirm the packaged archive contains `wrapguard`, `libwrapguard.dylib`, `README.md`, and `example-wg0.conf`. +- Confirm `wrapguard --version` and `wrapguard --help` succeed from the unpacked archive on a clean macOS runner. +- If a GUI app was used for validation, record the inner executable path that was launched and note any helper-process instability. +- Record any manual CLI validation performed against a real WireGuard configuration. +- Record whether browser-style validation was performed with hard refresh and soft refresh comparisons, and note any QUIC, cache, or helper-process behavior that affected the result. diff --git a/go.mod b/go.mod index 28ae09e..0c02ec8 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,11 @@ require ( ) require ( + github.com/google/btree v1.1.2 // indirect golang.org/x/crypto v0.46.0 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sys v0.39.0 // indirect + golang.org/x/time v0.7.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect + gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c // indirect ) diff --git a/index.html b/index.html new file mode 100644 index 0000000..0cfd5d9 --- /dev/null +++ b/index.html @@ -0,0 +1 @@ +146.70.156.18 diff --git a/interceptor_smoke_test.go b/interceptor_smoke_test.go new file mode 100644 index 0000000..b806aef --- /dev/null +++ b/interceptor_smoke_test.go @@ -0,0 +1,1853 @@ +package main + +import ( + "context" + "fmt" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "syscall" + "testing" + "time" +) + +func interceptSourcePath(t *testing.T) string { + t.Helper() + + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("failed to resolve test source path") + } + return filepath.Join(filepath.Dir(file), "lib", "intercept.c") +} + +func TestInjectedLibraryHandshakeAndConnectSmoke(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skipf("unsupported runtime platform for smoke test: %s", runtime.GOOS) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping smoke test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "connect-probe") + if err := buildConnectProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build connect probe: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := reserveUnusedPort(t) + cmd := exec.CommandContext(ctx, helperBinary, "203.0.113.1:443") + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, false, false) + + if output, err := cmd.CombinedOutput(); err != nil { + t.Logf("probe exited with %v: %s", err, strings.TrimSpace(string(output))) + } + + deadline := time.After(5 * time.Second) + var sawReady bool + var sawConnect bool + for !(sawReady && sawConnect) { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed before expected messages arrived") + } + switch msg.Type { + case "READY": + sawReady = true + case "CONNECT": + sawConnect = true + if msg.Addr == "" { + t.Fatal("CONNECT message did not include address") + } + } + case <-deadline: + t.Fatalf("timed out waiting for READY and CONNECT messages (ready=%v connect=%v)", sawReady, sawConnect) + } + } +} + +func TestInjectedLibraryBypassesLocalhostConnects(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skipf("unsupported runtime platform for smoke test: %s", runtime.GOOS) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping localhost bypass test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "connect-probe") + if err := buildConnectProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build connect probe: %v", err) + } + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := reserveUnusedPort(t) + cmd := exec.Command(helperBinary, "127.0.0.1:9") + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, false, false) + + if output, err := cmd.CombinedOutput(); err != nil { + t.Logf("probe exited with %v: %s", err, strings.TrimSpace(string(output))) + } + + deadline := time.After(2 * time.Second) + sawReady := false + for { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed unexpectedly") + } + switch msg.Type { + case "READY": + sawReady = true + case "CONNECT": + t.Fatalf("localhost connect should not be intercepted, saw CONNECT for %q", msg.Addr) + } + case <-deadline: + if !sawReady { + t.Fatal("timed out waiting for READY message") + } + return + } + } +} + +func TestInterceptorSourceKeepsUnixDomainConnectsBypassed(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "if (addr->sa_family != AF_INET && addr->sa_family != AF_INET6) {", + "return 0; // Only intercept IP connections", + "case AF_UNIX:", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected unix-domain bypass snippet: %q", snippet) + } + } +} + +func TestInjectedLibraryReportsBindSmoke(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skipf("unsupported runtime platform for smoke test: %s", runtime.GOOS) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping bind smoke test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "bind-probe") + if err := buildBindProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build bind probe: %v", err) + } + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := reserveUnusedPort(t) + cmd := exec.Command(helperBinary) + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, false, false) + + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("bind probe failed: %v: %s", err, strings.TrimSpace(string(output))) + } + + deadline := time.After(5 * time.Second) + var sawReady bool + var sawBind bool + for !(sawReady && sawBind) { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed before expected messages arrived") + } + switch msg.Type { + case "READY": + sawReady = true + case "BIND": + sawBind = true + if msg.Port == 0 { + t.Fatal("BIND message did not include a concrete port") + } + } + case <-deadline: + t.Fatalf("timed out waiting for READY and BIND messages (ready=%v bind=%v)", sawReady, sawBind) + } + } +} + +func TestInjectedLibraryHandlesNonBlockingConnectSmoke(t *testing.T) { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + t.Skipf("unsupported runtime platform for smoke test: %s", runtime.GOOS) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping non-blocking smoke test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "nonblocking-connect-probe") + if err := buildNonBlockingConnectProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build non-blocking connect probe: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := startSOCKSSuccessServer(t) + cmd := exec.CommandContext(ctx, helperBinary, "203.0.113.1:443") + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, false, false) + + output, err := cmd.CombinedOutput() + if ctx.Err() == context.DeadlineExceeded { + t.Fatalf("non-blocking connect probe timed out, output: %s", strings.TrimSpace(string(output))) + } + if err != nil { + t.Fatalf("non-blocking probe failed: %v: %s", err, strings.TrimSpace(string(output))) + } + + deadline := time.After(5 * time.Second) + var sawReady bool + var sawConnect bool + for !(sawReady && sawConnect) { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed before expected messages arrived") + } + switch msg.Type { + case "READY": + sawReady = true + case "CONNECT": + sawConnect = true + if msg.Addr == "" { + t.Fatal("CONNECT message did not include address") + } + } + case <-deadline: + t.Fatalf("timed out waiting for READY and CONNECT messages (ready=%v connect=%v)", sawReady, sawConnect) + } + } +} + +func TestInjectedLibraryVirtualizesGetpeernameAfterSOCKSConnect(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skipf("getpeername virtualization is only exercised on Linux now, runtime=%s", runtime.GOOS) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping getpeername virtualization test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "getpeername-probe") + if err := buildGetPeerNameProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build getpeername probe: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := startSOCKSSuccessServer(t) + cmd := exec.CommandContext(ctx, helperBinary, "203.0.113.1:443") + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, false, false) + + output, err := cmd.CombinedOutput() + if ctx.Err() == context.DeadlineExceeded { + t.Fatalf("getpeername probe timed out, output: %s", strings.TrimSpace(string(output))) + } + if err != nil { + t.Fatalf("getpeername probe failed: %v: %s", err, strings.TrimSpace(string(output))) + } + + deadline := time.After(5 * time.Second) + var sawReady bool + var sawConnect bool + for !(sawReady && sawConnect) { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed before expected messages arrived") + } + switch msg.Type { + case "READY": + sawReady = true + case "CONNECT": + sawConnect = true + } + case <-deadline: + t.Fatalf("timed out waiting for READY and CONNECT messages (ready=%v connect=%v)", sawReady, sawConnect) + } + } +} + +func TestInterceptorSourceClearsVirtualPeerStateBeforeFallbackConnects(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "static void forget_virtual_peer(int sockfd)", + "forget_virtual_peer(sockfd);\n errno = EHOSTUNREACH;", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected virtual-peer cleanup snippet: %q", snippet) + } + } + fallbacks := []string{ + "forget_virtual_peer(sockfd);\n#ifdef __APPLE__\n return raw_connect_call(sockfd, addr, addrlen);", + "forget_virtual_peer(sockfd);\n#ifdef __APPLE__\n return raw_connect_call(sockfd, addr, addrlen);\n#else\n return call_real_connect(sockfd, addr, addrlen);", + } + foundFallback := false + for _, snippet := range fallbacks { + if strings.Contains(content, snippet) { + foundFallback = true + break + } + } + if !foundFallback { + t.Fatalf("interceptor source missing expected virtual-peer fallback connect cleanup") + } +} + +func TestInterceptorSourceDeclaresMacOSInterpositionEntryPoints(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "#ifdef __APPLE__", + "DYLD_INTERPOSE(wrapguard_connect, connect)", + "DYLD_INTERPOSE(wrapguard_bind, bind)", + "DYLD_INTERPOSE(wrapguard_connectx, connectx)", + "DYLD_INTERPOSE(wrapguard_sendto, sendto)", + "DYLD_INTERPOSE(wrapguard_sendmsg, sendmsg)", + "return raw_bind_call(sockfd, addr, addrlen);", + "dlsym(RTLD_NEXT, \"connect\")", + "dlsym(RTLD_NEXT, \"bind\")", + "dlsym(RTLD_NEXT, \"getpeername\")", + "dlsym(RTLD_NEXT, \"connectx\")", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected macOS interposition snippet: %q", snippet) + } + } +} + +func TestInterceptorSourceUsesMacOSSafeDebugIPCForQUICSuppression(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "send_ipc_message(\"DEBUG\"", + "log_debugf(\"Blocking likely QUIC UDP sendto() to %s\", addr_str);", + "log_debugf(\"Blocking likely QUIC UDP sendmsg() to %s\", addr_str);", + "DYLD_INTERPOSE(wrapguard_sendto, sendto)", + "DYLD_INTERPOSE(wrapguard_sendmsg, sendmsg)", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected macOS QUIC observability snippet: %q", snippet) + } + } +} + +func TestInterceptorSourceBlocksConnectedDarwinUDPSendPathsViaPeerLookup(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "if (target == NULL || target_len < (socklen_t)sizeof(sa_family_t)) {", + "if (call_real_getpeername(sockfd, (struct sockaddr *)&target_storage, &target_len) != 0) {", + "if (!should_block_udp_target(target)) {", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected connected-UDP suppression snippet: %q", snippet) + } + } +} + +func TestInterceptorSourceKeepsDarwinGetpeernameOutOfTheInterposeTable(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + if strings.Contains(content, "DYLD_INTERPOSE(wrapguard_getpeername, getpeername)") { + t.Fatal("Darwin should not interpose getpeername anymore; that regression breaks browser socket-thread behavior") + } + requiredSnippets := []string{ + "#ifndef __APPLE__", + "int getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected Darwin getpeername guard snippet: %q", snippet) + } + } +} + +func TestInterceptorSourceKeepsExpectReadyOneShotForDarwinChildren(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "expect_ready_cached = expect_ready_enabled();", + "if (expect_ready_cached) {\n unsetenv(\"WRAPGUARD_EXPECT_READY\");\n }", + "if (expect_ready_cached && ipc_path != NULL && socks_port != 0) {\n send_ipc_message(\"READY\", -1, socks_port, NULL);\n }", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected one-shot READY snippet: %q", snippet) + } + } +} + +func TestInterceptorSourcePreservesErrnoAcrossMacOSDebugIPC(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "int saved_errno = errno;", + "send_ipc_message(\"DEBUG\", -1, 0, message);", + "send_ipc_message(\"ERROR\", -1, 0, message);", + "errno = saved_errno;", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected errno-preservation snippet: %q", snippet) + } + } +} + +func TestInterceptorSourceAvoidsPreMutatingConnectxOutputsBeforeFallback(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "if (endpoints == NULL || endpoints->sae_dstaddr == NULL || endpoints->sae_dstaddrlen < (socklen_t)sizeof(sa_family_t)) {\n return call_real_connectx(sockfd, endpoints, associd, flags, iov, iovcnt, len, connid);\n }", + "if (len != NULL) {\n *len = 0;\n }", + "if (connid != NULL) {\n *connid = SAE_CONNID_ANY;\n }", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected connectx compatibility snippet: %q", snippet) + } + } +} + +func TestInterceptorSourcePinsMozillaHelperPassthroughPolicyOnMacOS(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "#include ", + "static int should_passthrough_mozilla_process(char *const *argv)", + "if (str_equals(base, \"plugin-container\")) {", + "if (argc > 1 && str_equals(argv[argc - 1], \"socket\")) {", + "passthrough_mode_cached = should_passthrough_current_process();", + "if (passthrough_mode_cached) {", + "return raw_connect_call(sockfd, addr, addrlen);", + "return raw_connectx_call(sockfd, endpoints, associd, flags, iov, iovcnt, len, connid);", + "int suppress_debug_log = addr->sa_family == AF_UNIX && sock_type == SOCK_DGRAM;", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected Mozilla helper passthrough snippet: %q", snippet) + } + } +} + +func TestInterceptorSourcePinsVirtualPeerBookkeepingForGetpeername(t *testing.T) { + data, err := os.ReadFile(interceptSourcePath(t)) + if err != nil { + t.Fatalf("failed to read interceptor source: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + "static int wrapguard_getpeername_impl(int sockfd, struct sockaddr *addr, socklen_t *addrlen)", + "if (lookup_virtual_peer(sockfd, addr, addrlen)) {", + "remember_virtual_peer(sockfd, addr, addrlen);", + "#ifndef __APPLE__", + "int getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {", + } + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("interceptor source missing expected virtual-peer snippet: %q", snippet) + } + } +} + +func TestInjectedLibraryBlocksLikelyQUICUDPConnectOnDarwin(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-only QUIC suppression smoke test") + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping UDP suppression smoke test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "udp-connect-probe") + if err := buildUDPConnectProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build udp connect probe: %v", err) + } + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := reserveUnusedPort(t) + cmd := exec.Command(helperBinary, "203.0.113.1:443") + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, false, false) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("udp connect probe failed: %v: %s", err, strings.TrimSpace(string(output))) + } + + deadline := time.After(2 * time.Second) + sawReady := false + for { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed unexpectedly") + } + switch msg.Type { + case "READY": + sawReady = true + case "CONNECT": + t.Fatalf("UDP/443 connect should not have been tunneled through SOCKS, saw CONNECT for %q", msg.Addr) + } + case <-deadline: + if !sawReady { + t.Fatal("timed out waiting for READY message") + } + return + } + } +} + +func TestInjectedLibraryInterceptsConnectxOnDarwin(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-only connectx smoke test") + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping connectx smoke test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "connectx-probe") + if err := buildConnectxProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build connectx probe: %v", err) + } + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := reserveUnusedPort(t) + cmd := exec.Command(helperBinary, "203.0.113.1:443") + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, true, false) + + if output, err := cmd.CombinedOutput(); err != nil { + t.Logf("connectx probe exited with %v: %s", err, strings.TrimSpace(string(output))) + } + + deadline := time.After(5 * time.Second) + var sawReady bool + var sawConnect bool + for !(sawReady && sawConnect) { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed before expected messages arrived") + } + switch msg.Type { + case "READY": + sawReady = true + case "CONNECT": + sawConnect = true + if msg.Addr == "" { + t.Fatal("CONNECT message did not include address") + } + } + case <-deadline: + t.Fatalf("timed out waiting for READY and CONNECT messages from connectx probe (ready=%v connect=%v)", sawReady, sawConnect) + } + } +} + +func TestInjectedLibraryAppliesMozillaRolePolicyOnDarwin(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-only Mozilla role policy smoke test") + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping Mozilla role policy smoke test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "nonblocking-role-probe") + if err := buildNonBlockingRoleProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build non-blocking role probe: %v", err) + } + + tests := []struct { + name string + linkName string + args []string + wantConnect bool + }{ + { + name: "socket-process-stays-intercepted", + linkName: "plugin-container", + args: []string{"203.0.113.1:443", "socket"}, + wantConnect: true, + }, + { + name: "gpu-helper-stays-passthrough", + linkName: "gpu-helper", + args: []string{"203.0.113.1:443"}, + wantConnect: false, + }, + { + name: "librewolf-main-process-stays-intercepted", + linkName: "librewolf", + args: []string{"203.0.113.1:443"}, + wantConnect: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + roleBinary := filepath.Join(helperDir, tt.linkName) + if err := os.Link(helperBinary, roleBinary); err != nil { + if err := os.Symlink(helperBinary, roleBinary); err != nil { + t.Fatalf("failed to create role probe alias: %v / %v", err, err) + } + } + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := startSOCKSSuccessServer(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, roleBinary, tt.args...) + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, true, false) + + output, err := cmd.CombinedOutput() + if ctx.Err() == context.DeadlineExceeded { + t.Fatalf("role probe timed out, output: %s", strings.TrimSpace(string(output))) + } + if err != nil { + t.Fatalf("role probe failed: %v: %s", err, strings.TrimSpace(string(output))) + } + + deadline := time.After(5 * time.Second) + sawReady := false + sawConnect := false + for { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed unexpectedly") + } + switch msg.Type { + case "READY": + sawReady = true + case "CONNECT": + sawConnect = true + } + case <-deadline: + if !sawReady { + t.Fatal("timed out waiting for READY from role probe") + } + if sawConnect != tt.wantConnect { + t.Fatalf("CONNECT visibility mismatch for %s: got %v want %v", tt.linkName, sawConnect, tt.wantConnect) + } + return + } + } + }) + } +} + +func TestInjectedLibraryStripsMacOSInjectionEnvForDescendantsInCompatMode(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-only GUI compatibility smoke test") + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping GUI compatibility smoke test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + envDumpBinary := filepath.Join(helperDir, "env-dump-probe") + if err := buildEnvDumpProbeForTest(t, cc, envDumpBinary); err != nil { + t.Fatalf("failed to build env dump probe: %v", err) + } + + spawnBinary := filepath.Join(helperDir, "spawn-child-probe") + if err := buildSpawnChildProbeForTest(t, cc, spawnBinary); err != nil { + t.Fatalf("failed to build spawn child probe: %v", err) + } + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + outputPath := filepath.Join(t.TempDir(), "child-env.txt") + socksPort := reserveUnusedPort(t) + cmd := exec.Command(spawnBinary, envDumpBinary, outputPath) + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, true, true) + + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("spawn child probe failed: %v: %s", err, strings.TrimSpace(string(output))) + } + + sawReady := false + deadline := time.After(5 * time.Second) + for !sawReady { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed before READY arrived") + } + if msg.Type == "READY" { + sawReady = true + } + case <-deadline: + t.Fatal("timed out waiting for READY from injected parent") + } + } + + data, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("failed to read child env output: %v", err) + } + + got := string(data) + for _, key := range []string{ + "DYLD_INSERT_LIBRARIES", + "DYLD_FORCE_FLAT_NAMESPACE", + envWrapGuardExpectRDY, + envWrapGuardIPCPath, + envWrapGuardSOCKSPort, + envWrapGuardDebug, + envWrapGuardDebugIPC, + envWrapGuardBlockUDP, + envWrapGuardNoInherit, + } { + if strings.Contains(got, key+"=") && !strings.Contains(got, key+"=\n") { + t.Fatalf("child unexpectedly inherited %s: %q", key, got) + } + } +} + +func TestInjectedLibrarySuppressesMacOSUDP443SendtoAndSendmsgSmoke(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-only QUIC send-path suppression smoke test") + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping UDP send-path suppression smoke test: %v", err) + } + + helperDir := t.TempDir() + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + libraryPath := filepath.Join(helperDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperBinary := filepath.Join(helperDir, "udp-send-probe") + if err := buildUDPSendProbeForTest(t, cc, helperBinary); err != nil { + t.Fatalf("failed to build udp send probe: %v", err) + } + + for _, mode := range []string{"sendto", "sendmsg", "connected-sendto", "connected-sendmsg"} { + t.Run(mode, func(t *testing.T) { + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + subID, ch := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + socksPort := reserveUnusedPort(t) + cmd := exec.Command(helperBinary, "203.0.113.1:443", mode) + cmd.Env = buildChildEnv(os.Environ(), cfg, libraryPath, ipcServer.SocketPath(), socksPort, true, false) + + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("udp %s probe failed: %v: %s", mode, err, strings.TrimSpace(string(output))) + } + if strings.Contains(string(output), "WrapGuard DYLD:") || strings.Contains(string(output), "WrapGuard LD_PRELOAD:") { + t.Fatalf("udp %s probe should not emit recursive interceptor stderr logging: %s", mode, strings.TrimSpace(string(output))) + } + + deadline := time.After(5 * time.Second) + sawReady := false + sawDebug := false + for !(sawReady && sawDebug) { + select { + case msg, ok := <-ch: + if !ok { + t.Fatal("ipc subscription closed before expected messages arrived") + } + switch msg.Type { + case "READY": + sawReady = true + case "DEBUG": + if strings.Contains(msg.Addr, "Blocking likely QUIC UDP") { + sawDebug = true + } + case "CONNECT": + t.Fatalf("udp %s send path should not emit CONNECT, saw %q", mode, msg.Addr) + } + case <-deadline: + t.Fatalf("timed out waiting for READY and QUIC debug messages for %s (ready=%v debug=%v)", mode, sawReady, sawDebug) + } + } + }) + } +} + +func findCCompiler() (string, error) { + candidates := []string{"cc", "clang", "gcc"} + for _, candidate := range candidates { + if path, err := exec.LookPath(candidate); err == nil { + return path, nil + } + } + return "", fmt.Errorf("no C compiler found in PATH") +} + +func buildInterceptLibraryForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := interceptSourcePath(t) + args := []string{"-Wall", "-Wextra", "-Werror"} + if runtime.GOOS == "darwin" { + args = append(args, "-dynamiclib", "-o", outputPath, sourcePath) + } else { + args = append(args, "-shared", "-fPIC", "-o", outputPath, sourcePath, "-ldl") + } + + cmd := exec.Command(cc, args...) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildConnectProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "connect_probe.c") + source := `#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) { + if (argc != 2) { + return 2; + } + + char input[256]; + memset(input, 0, sizeof(input)); + strncpy(input, argv[1], sizeof(input) - 1); + + char *sep = strrchr(input, ':'); + if (sep == NULL) { + return 3; + } + + *sep = '\0'; + const char *host = input; + int port = atoi(sep + 1); + + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + return 4; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + if (inet_pton(AF_INET, host, &addr.sin_addr) != 1) { + close(fd); + return 5; + } + + (void)connect(fd, (struct sockaddr *)&addr, sizeof(addr)); + close(fd); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0644); err != nil { + return err + } + + args := []string{"-Wall", "-Wextra", "-Werror"} + if runtime.GOOS == "darwin" { + args = append(args, "-Wno-deprecated-declarations") + } + args = append(args, "-o", outputPath, sourcePath) + cmd := exec.Command(cc, args...) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildNonBlockingConnectProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "nonblocking_connect_probe.c") + source := `#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) { + if (argc != 2) { + return 2; + } + + char input[256]; + memset(input, 0, sizeof(input)); + strncpy(input, argv[1], sizeof(input) - 1); + + char *sep = strrchr(input, ':'); + if (sep == NULL) { + return 3; + } + + *sep = '\0'; + const char *host = input; + int port = atoi(sep + 1); + + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + return 4; + } + + int flags = fcntl(fd, F_GETFL, 0); + if (flags < 0) { + close(fd); + return 5; + } + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) != 0) { + close(fd); + return 6; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + if (inet_pton(AF_INET, host, &addr.sin_addr) != 1) { + close(fd); + return 7; + } + + if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) != -1 || errno != EINPROGRESS) { + close(fd); + return 8; + } + + fd_set writefds; + FD_ZERO(&writefds); + FD_SET(fd, &writefds); + + struct timeval timeout; + timeout.tv_sec = 5; + timeout.tv_usec = 0; + + int ready = select(fd + 1, NULL, &writefds, NULL, &timeout); + if (ready != 1) { + close(fd); + return 9; + } + + int so_error = -1; + socklen_t so_error_len = sizeof(so_error); + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &so_error, &so_error_len) != 0) { + close(fd); + return 10; + } + if (so_error != 0) { + close(fd); + return 11; + } + + close(fd); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + args := []string{"-Wall", "-Wextra", "-Werror"} + if runtime.GOOS == "darwin" { + args = append(args, "-Wno-deprecated-declarations") + } + args = append(args, "-o", outputPath, sourcePath) + cmd := exec.Command(cc, args...) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildConnectxProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "connectx_probe.c") + source := `#include +#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) { + if (argc != 2) { + return 2; + } + + char input[256]; + memset(input, 0, sizeof(input)); + strncpy(input, argv[1], sizeof(input) - 1); + + char *sep = strrchr(input, ':'); + if (sep == NULL) { + return 3; + } + + *sep = '\0'; + const char *host = input; + int port = atoi(sep + 1); + + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + return 4; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + if (inet_pton(AF_INET, host, &addr.sin_addr) != 1) { + close(fd); + return 5; + } + + sa_endpoints_t endpoints; + memset(&endpoints, 0, sizeof(endpoints)); + endpoints.sae_dstaddr = (const struct sockaddr *)&addr; + endpoints.sae_dstaddrlen = sizeof(addr); + + (void)connectx(fd, &endpoints, SAE_ASSOCID_ANY, 0, NULL, 0, NULL, NULL); + close(fd); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + args := []string{"-Wall", "-Wextra", "-Werror"} + if runtime.GOOS == "darwin" { + args = append(args, "-Wno-deprecated-declarations") + } + args = append(args, "-o", outputPath, sourcePath) + cmd := exec.Command(cc, args...) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildEnvDumpProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "env_dump_probe.c") + source := `#include +#include + +int main(int argc, char **argv) { + if (argc != 2) { + return 2; + } + + FILE *fp = fopen(argv[1], "w"); + if (fp == NULL) { + return 3; + } + + const char *keys[] = { + "DYLD_INSERT_LIBRARIES", + "DYLD_FORCE_FLAT_NAMESPACE", + "WRAPGUARD_EXPECT_READY", + "WRAPGUARD_IPC_PATH", + "WRAPGUARD_SOCKS_PORT", + "WRAPGUARD_DEBUG", + "WRAPGUARD_DEBUG_IPC", + "WRAPGUARD_BLOCK_UDP_443", + "WRAPGUARD_MACOS_NO_INHERIT", + }; + + for (size_t i = 0; i < sizeof(keys) / sizeof(keys[0]); ++i) { + const char *value = getenv(keys[i]); + fprintf(fp, "%s=%s\n", keys[i], value ? value : ""); + } + + fclose(fp); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + cmd := exec.Command(cc, "-Wall", "-Wextra", "-Werror", "-o", outputPath, sourcePath) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildNonBlockingRoleProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "nonblocking_role_probe.c") + source := `#include +#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) { + if (argc < 2) { + return 2; + } + + char input[256]; + memset(input, 0, sizeof(input)); + strncpy(input, argv[1], sizeof(input) - 1); + + char *sep = strrchr(input, ':'); + if (sep == NULL) { + return 3; + } + + *sep = '\0'; + const char *host = input; + int port = atoi(sep + 1); + + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + return 4; + } + + int flags = fcntl(fd, F_GETFL, 0); + if (flags < 0 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) != 0) { + close(fd); + return 5; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + if (inet_pton(AF_INET, host, &addr.sin_addr) != 1) { + close(fd); + return 6; + } + + if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) != 0 && errno != EINPROGRESS) { + close(fd); + return 7; + } + + close(fd); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + cmd := exec.Command(cc, "-Wall", "-Wextra", "-Werror", "-o", outputPath, sourcePath) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildSpawnChildProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "spawn_child_probe.c") + source := `#include + +int main(int argc, char **argv) { + if (argc != 3) { + return 2; + } + + char *child_argv[] = {argv[1], argv[2], NULL}; + execv(argv[1], child_argv); + return 3; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + cmd := exec.Command(cc, "-Wall", "-Wextra", "-Werror", "-o", outputPath, sourcePath) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildGetPeerNameProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "getpeername_probe.c") + source := `#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) { + if (argc != 2) { + return 2; + } + + char input[256]; + memset(input, 0, sizeof(input)); + strncpy(input, argv[1], sizeof(input) - 1); + + char *sep = strrchr(input, ':'); + if (sep == NULL) { + return 3; + } + + *sep = '\0'; + const char *host = input; + int port = atoi(sep + 1); + + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + return 4; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + if (inet_pton(AF_INET, host, &addr.sin_addr) != 1) { + close(fd); + return 5; + } + + if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) != 0) { + close(fd); + return 6; + } + + struct sockaddr_in peer; + memset(&peer, 0, sizeof(peer)); + socklen_t peer_len = sizeof(peer); + if (getpeername(fd, (struct sockaddr *)&peer, &peer_len) != 0) { + close(fd); + return 7; + } + + char peer_ip[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &peer.sin_addr, peer_ip, sizeof(peer_ip)) == NULL) { + close(fd); + return 8; + } + + if (strcmp(peer_ip, host) != 0 || ntohs(peer.sin_port) != port) { + close(fd); + return 9; + } + + close(fd); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + cmd := exec.Command(cc, "-Wall", "-Wextra", "-Werror", "-o", outputPath, sourcePath) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildUDPSendProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "udp_send_probe.c") + source := `#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int raw_udp_connect(int fd, const struct sockaddr *addr, socklen_t addrlen) { + return (int)syscall(SYS_connect, fd, addr, addrlen); +} + +static int parse_target(const char *input, struct sockaddr_in *addr) { + char copy[256]; + memset(copy, 0, sizeof(copy)); + strncpy(copy, input, sizeof(copy) - 1); + + char *sep = strrchr(copy, ':'); + if (sep == NULL) { + return 1; + } + + *sep = '\0'; + addr->sin_family = AF_INET; + addr->sin_port = htons(atoi(sep + 1)); + return inet_pton(AF_INET, copy, &addr->sin_addr) == 1 ? 0 : 2; +} + +int main(int argc, char **argv) { + if (argc != 3) { + return 2; + } + + int fd = socket(AF_INET, SOCK_DGRAM, 0); + if (fd < 0) { + return 4; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + if (parse_target(argv[1], &addr) != 0) { + close(fd); + return 5; + } + + const char *mode = argv[2]; + const char payload[] = "quic"; + + if (strcmp(mode, "sendto") == 0) { + ssize_t sent = sendto(fd, payload, sizeof(payload) - 1, 0, (struct sockaddr *)&addr, sizeof(addr)); + if (sent != -1 || errno != EHOSTUNREACH) { + close(fd); + return 6; + } + } else if (strcmp(mode, "sendmsg") == 0) { + struct iovec iov; + memset(&iov, 0, sizeof(iov)); + iov.iov_base = (void *)payload; + iov.iov_len = sizeof(payload) - 1; + + struct msghdr msg; + memset(&msg, 0, sizeof(msg)); + msg.msg_name = &addr; + msg.msg_namelen = sizeof(addr); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ssize_t sent = sendmsg(fd, &msg, 0); + if (sent != -1 || errno != EHOSTUNREACH) { + close(fd); + return 7; + } + } else if (strcmp(mode, "connected-sendto") == 0) { + if (raw_udp_connect(fd, (struct sockaddr *)&addr, sizeof(addr)) != 0) { + close(fd); + return 9; + } + + errno = 0; + ssize_t sent = sendto(fd, payload, sizeof(payload) - 1, 0, NULL, 0); + if (sent != -1 || errno != EHOSTUNREACH) { + close(fd); + return 10; + } + } else if (strcmp(mode, "connected-sendmsg") == 0) { + if (raw_udp_connect(fd, (struct sockaddr *)&addr, sizeof(addr)) != 0) { + close(fd); + return 11; + } + + struct iovec iov; + memset(&iov, 0, sizeof(iov)); + iov.iov_base = (void *)payload; + iov.iov_len = sizeof(payload) - 1; + + struct msghdr msg; + memset(&msg, 0, sizeof(msg)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + errno = 0; + ssize_t sent = sendmsg(fd, &msg, 0); + if (sent != -1 || errno != EHOSTUNREACH) { + close(fd); + return 12; + } + } else { + close(fd); + return 8; + } + + close(fd); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + args := []string{"-Wall", "-Wextra", "-Werror"} + if runtime.GOOS == "darwin" { + args = append(args, "-Wno-deprecated-declarations") + } + args = append(args, "-o", outputPath, sourcePath) + cmd := exec.Command(cc, args...) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildBindProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "bind_probe.c") + source := `#include +#include +#include +#include +#include +#include + +int main(void) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + return 2; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) != 0) { + close(fd); + return 3; + } + + if (listen(fd, 1) != 0) { + close(fd); + return 4; + } + + close(fd); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + cmd := exec.Command(cc, "-Wall", "-Wextra", "-Werror", "-o", outputPath, sourcePath) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func buildUDPConnectProbeForTest(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "udp_connect_probe.c") + source := `#include +#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) { + if (argc != 2) { + return 2; + } + + char input[256]; + memset(input, 0, sizeof(input)); + strncpy(input, argv[1], sizeof(input) - 1); + + char *sep = strrchr(input, ':'); + if (sep == NULL) { + return 3; + } + + *sep = '\0'; + const char *host = input; + int port = atoi(sep + 1); + + int fd = socket(AF_INET, SOCK_DGRAM, 0); + if (fd < 0) { + return 4; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + if (inet_pton(AF_INET, host, &addr.sin_addr) != 1) { + close(fd); + return 5; + } + + if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) != -1 || errno != EHOSTUNREACH) { + close(fd); + return 6; + } + + close(fd); + return 0; +}` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + cmd := exec.Command(cc, "-Wall", "-Wextra", "-Werror", "-o", outputPath, sourcePath) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(output))) + } + return nil +} + +func startSOCKSSuccessServer(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to start SOCKS success listener: %v", err) + } + t.Cleanup(func() { + _ = listener.Close() + }) + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + if errorsIsClosedConn(err) { + return + } + continue + } + go func(conn net.Conn) { + defer conn.Close() + + header := make([]byte, 3) + if _, err := io.ReadFull(conn, header); err != nil { + return + } + if _, err := conn.Write([]byte{0x05, 0x00}); err != nil { + return + } + + reqHeader := make([]byte, 4) + if _, err := io.ReadFull(conn, reqHeader); err != nil { + return + } + + var rest int + switch reqHeader[3] { + case 0x01: + rest = 4 + 2 + case 0x03: + domainLen := make([]byte, 1) + if _, err := io.ReadFull(conn, domainLen); err != nil { + return + } + rest = int(domainLen[0]) + 2 + case 0x04: + rest = 16 + 2 + default: + return + } + + if rest > 0 { + payload := make([]byte, rest) + if _, err := io.ReadFull(conn, payload); err != nil { + return + } + } + + _, _ = conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x12, 0x34}) + }(conn) + } + }() + + return listener.Addr().(*net.TCPAddr).Port +} + +func errorsIsClosedConn(err error) bool { + return err != nil && (err == net.ErrClosed || err == syscall.EINVAL || strings.Contains(err.Error(), "use of closed network connection")) +} + +func reserveUnusedPort(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to reserve port: %v", err) + } + defer listener.Close() + + return listener.Addr().(*net.TCPAddr).Port +} diff --git a/ipc.go b/ipc.go index 058efee..62e0171 100644 --- a/ipc.go +++ b/ipc.go @@ -7,19 +7,26 @@ import ( "net" "os" "path/filepath" + "sync" + "time" ) type IPCMessage struct { - Type string `json:"type"` // "CONNECT" or "BIND" - FD int `json:"fd"` - Port int `json:"port"` - Addr string `json:"addr"` + Type string `json:"type"` // "READY", "CONNECT", "BIND", and debug/transport events + FD int `json:"fd"` + Port int `json:"port"` + Addr string `json:"addr"` + PID int `json:"pid"` + Detail string `json:"detail"` } type IPCServer struct { listener net.Listener socketPath string msgChan chan IPCMessage + mu sync.Mutex + nextSubID int + subs map[int]chan IPCMessage } func NewIPCServer() (*IPCServer, error) { @@ -38,6 +45,7 @@ func NewIPCServer() (*IPCServer, error) { listener: listener, socketPath: socketPath, msgChan: make(chan IPCMessage, 100), + subs: make(map[int]chan IPCMessage), } // Start accepting connections @@ -68,15 +76,29 @@ func (s *IPCServer) handleConnection(conn net.Conn) { var msg IPCMessage if err := json.Unmarshal([]byte(line), &msg); err != nil { - fmt.Printf("IPC: Failed to parse message: %v\n", err) + logger.Warnf("IPC failed to parse message: %v", err) continue } - // Send message to channel (non-blocking) + s.dispatchMessage(msg) + } +} + +func (s *IPCServer) dispatchMessage(msg IPCMessage) { + select { + case s.msgChan <- msg: + default: + logger.Warnf("IPC message channel full, dropping %s from pid %d", msg.Type, msg.PID) + } + + s.mu.Lock() + defer s.mu.Unlock() + + for id, ch := range s.subs { select { - case s.msgChan <- msg: + case ch <- msg: default: - fmt.Printf("IPC: Message channel full, dropping message\n") + logger.Warnf("IPC subscriber %d channel full, dropping %s from pid %d", id, msg.Type, msg.PID) } } } @@ -89,11 +111,64 @@ func (s *IPCServer) MessageChan() <-chan IPCMessage { return s.msgChan } +func (s *IPCServer) Subscribe() (int, <-chan IPCMessage) { + s.mu.Lock() + defer s.mu.Unlock() + + id := s.nextSubID + s.nextSubID++ + ch := make(chan IPCMessage, 32) + s.subs[id] = ch + + return id, ch +} + +func (s *IPCServer) Unsubscribe(id int) { + s.mu.Lock() + defer s.mu.Unlock() + + ch, ok := s.subs[id] + if !ok { + return + } + delete(s.subs, id) + close(ch) +} + +func (s *IPCServer) WaitForMessageType(msgType string, timeout time.Duration) (IPCMessage, error) { + subID, ch := s.Subscribe() + defer s.Unsubscribe(subID) + + timer := time.NewTimer(timeout) + defer timer.Stop() + + for { + select { + case msg, ok := <-ch: + if !ok { + return IPCMessage{}, fmt.Errorf("ipc subscriber closed while waiting for %s", msgType) + } + if msg.Type == msgType { + return msg, nil + } + case <-timer.C: + return IPCMessage{}, fmt.Errorf("timed out waiting for IPC message type %s", msgType) + } + } +} + func (s *IPCServer) Close() error { if s.listener != nil { s.listener.Close() } + s.mu.Lock() + for id, ch := range s.subs { + delete(s.subs, id) + close(ch) + } + s.mu.Unlock() + // Clean up socket file if s.socketPath != "" { os.Remove(s.socketPath) diff --git a/launcher_contract_test.go b/launcher_contract_test.go new file mode 100644 index 0000000..18d1a2a --- /dev/null +++ b/launcher_contract_test.go @@ -0,0 +1,504 @@ +package main + +import ( + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestInjectionConfigForGOOS(t *testing.T) { + tests := []struct { + goos string + want injectionConfig + }{ + { + goos: "linux", + want: injectionConfig{ + LibraryName: "libwrapguard.so", + LibraryEnvVar: "LD_PRELOAD", + }, + }, + { + goos: "darwin", + want: injectionConfig{ + LibraryName: "libwrapguard.dylib", + LibraryEnvVar: "DYLD_INSERT_LIBRARIES", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.goos, func(t *testing.T) { + got, err := injectionConfigForGOOS(tt.goos) + if err != nil { + t.Fatalf("injectionConfigForGOOS(%q) returned error: %v", tt.goos, err) + } + if got != tt.want { + t.Fatalf("injectionConfigForGOOS(%q) = %+v, want %+v", tt.goos, got, tt.want) + } + }) + } +} + +func TestResolveInjectedLibraryPath(t *testing.T) { + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + tmpDir := t.TempDir() + execPath := filepath.Join(tmpDir, "wrapguard") + libPath := filepath.Join(tmpDir, cfg.LibraryName) + + if err := os.WriteFile(libPath, []byte("test"), 0o644); err != nil { + t.Fatalf("failed to create dummy library: %v", err) + } + + gotPath, gotCfg, err := resolveInjectedLibraryPath(execPath) + if err != nil { + t.Fatalf("resolveInjectedLibraryPath failed: %v", err) + } + + if gotPath != libPath { + t.Fatalf("resolveInjectedLibraryPath() path = %q, want %q", gotPath, libPath) + } + if gotCfg != cfg { + t.Fatalf("resolveInjectedLibraryPath() config = %+v, want %+v", gotCfg, cfg) + } +} + +func TestResolveInjectedLibraryPathMissing(t *testing.T) { + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + tmpDir := t.TempDir() + execPath := filepath.Join(tmpDir, "wrapguard") + _ = filepath.Join(tmpDir, cfg.LibraryName) + + _, _, err = resolveInjectedLibraryPath(execPath) + if err == nil { + t.Fatal("resolveInjectedLibraryPath should fail when the platform library is missing") + } + if !strings.Contains(err.Error(), cfg.LibraryName) { + t.Fatalf("expected missing-library error to mention %q, got %v", cfg.LibraryName, err) + } +} + +func TestResolveInjectedLibraryPathFallsBackToInvokedPathDirectory(t *testing.T) { + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + execDir := t.TempDir() + execPath := filepath.Join(execDir, "wrapguard") + + pathDir := t.TempDir() + invokedPath := filepath.Join(pathDir, "wrapguard-on-path") + libPath := filepath.Join(pathDir, cfg.LibraryName) + + if err := os.WriteFile(invokedPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("failed to create invoked-path fixture: %v", err) + } + if err := os.WriteFile(libPath, []byte("test"), 0o644); err != nil { + t.Fatalf("failed to create dummy library: %v", err) + } + + oldArgs0 := os.Args[0] + oldPath := os.Getenv("PATH") + defer func() { + os.Args[0] = oldArgs0 + _ = os.Setenv("PATH", oldPath) + }() + + os.Args[0] = "wrapguard-on-path" + if err := os.Setenv("PATH", pathDir); err != nil { + t.Fatalf("failed to update PATH: %v", err) + } + + gotPath, gotCfg, err := resolveInjectedLibraryPath(execPath) + if err != nil { + t.Fatalf("resolveInjectedLibraryPath failed: %v", err) + } + if gotPath != libPath { + t.Fatalf("resolveInjectedLibraryPath() path = %q, want %q", gotPath, libPath) + } + if gotCfg != cfg { + t.Fatalf("resolveInjectedLibraryPath() config = %+v, want %+v", gotCfg, cfg) + } +} + +func TestBuildChildEnvUsesPlatformInjectionVariable(t *testing.T) { + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + got := buildChildEnv( + []string{ + "PATH=/usr/bin", + "DYLD_FORCE_FLAT_NAMESPACE=0", + "UNRELATED=value", + }, + cfg, + "/tmp/"+cfg.LibraryName, + "/tmp/wrapguard.sock", + 4242, + true, + false, + ) + + if _, ok := envValue(got, "PATH"); !ok { + t.Fatal("PATH should be preserved") + } + if value, ok := envValue(got, "UNRELATED"); !ok || value != "value" { + t.Fatalf("unrelated environment should be preserved, got %q, present=%v", value, ok) + } + if value, ok := envValue(got, cfg.LibraryEnvVar); !ok || value != "/tmp/"+cfg.LibraryName { + t.Fatalf("%s not injected correctly: got %q, present=%v", cfg.LibraryEnvVar, value, ok) + } + if value, ok := envValue(got, "WRAPGUARD_IPC_PATH"); !ok || value != "/tmp/wrapguard.sock" { + t.Fatalf("WRAPGUARD_IPC_PATH not injected correctly: got %q, present=%v", value, ok) + } + if value, ok := envValue(got, "WRAPGUARD_SOCKS_PORT"); !ok || value != "4242" { + t.Fatalf("WRAPGUARD_SOCKS_PORT not injected correctly: got %q, present=%v", value, ok) + } + if value, ok := envValue(got, envWrapGuardExpectRDY); !ok || value != "1" { + t.Fatalf("%s not injected correctly: got %q, present=%v", envWrapGuardExpectRDY, value, ok) + } + if cfg.LibraryEnvVar == "DYLD_INSERT_LIBRARIES" { + if value, ok := envValue(got, envWrapGuardBlockUDP); !ok || value != "1" { + t.Fatalf("%s should be enabled on macOS, got %q, present=%v", envWrapGuardBlockUDP, value, ok) + } + if value, ok := envValue(got, envWrapGuardDebugIPC); !ok || value != "1" { + t.Fatalf("%s should be enabled for macOS debug launches, got %q, present=%v", envWrapGuardDebugIPC, value, ok) + } + } else if _, ok := envValue(got, envWrapGuardBlockUDP); ok { + t.Fatalf("%s should not be injected on Linux", envWrapGuardBlockUDP) + } else if _, ok := envValue(got, envWrapGuardDebugIPC); ok { + t.Fatalf("%s should not be injected on Linux", envWrapGuardDebugIPC) + } else if _, ok := envValue(got, envWrapGuardNoInherit); ok { + t.Fatalf("%s should not be injected unless GUI compatibility mode is enabled", envWrapGuardNoInherit) + } + if value, ok := envValue(got, "WRAPGUARD_DEBUG"); !ok || value != "1" { + t.Fatalf("WRAPGUARD_DEBUG should be enabled in debug mode, got %q, present=%v", value, ok) + } + + if currentPlatformName() == "darwin" { + if _, ok := envValue(got, "DYLD_FORCE_FLAT_NAMESPACE"); ok { + t.Fatalf("DYLD_FORCE_FLAT_NAMESPACE should not be injected for Darwin DYLD_INTERPOSE launches") + } + } else if value, ok := envValue(got, "DYLD_FORCE_FLAT_NAMESPACE"); !ok || value != "0" { + t.Fatalf("DYLD_FORCE_FLAT_NAMESPACE should remain unchanged on Linux, got %q, present=%v", value, ok) + } +} + +func TestValidateLaunchTarget(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-specific launch target validation only applies on Darwin") + } + + bundlePath, innerExecutable := writeAppBundleFixture(t, "Example") + + details, err := validateLaunchTargetWithLibrary(bundlePath, "") + if err != nil { + t.Fatalf("validateLaunchTargetWithLibrary rejected app bundle: %v", err) + } + if details == nil { + t.Fatal("validateLaunchTargetWithLibrary returned nil details for app bundle") + } + if details.ResolvedPath != innerExecutable { + t.Fatalf("validateLaunchTargetWithLibrary resolved %q, want %q", details.ResolvedPath, innerExecutable) + } +} + +func TestResolveAppBundleExecutablePathRejectsMultipleCandidates(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-specific app bundle resolution only applies on Darwin") + } + + bundlePath := filepath.Join(t.TempDir(), "Example.app") + macOSDir := filepath.Join(bundlePath, "Contents", "MacOS") + if err := os.MkdirAll(macOSDir, 0o755); err != nil { + t.Fatalf("failed to create app bundle directory: %v", err) + } + + for _, name := range []string{"First", "Second"} { + path := filepath.Join(macOSDir, name) + if err := os.WriteFile(path, []byte("#!/bin/sh\n"), 0o755); err != nil { + t.Fatalf("failed to create executable candidate %s: %v", name, err) + } + } + + _, err := resolveAppBundleExecutablePath(bundlePath) + if err == nil || !strings.Contains(err.Error(), "multiple executable candidates in Contents/MacOS") { + t.Fatalf("expected multiple-candidate failure, got %v", err) + } +} + +func TestResolveAppBundleExecutablePathRejectsMissingExecutables(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-specific app bundle resolution only applies on Darwin") + } + + bundlePath := filepath.Join(t.TempDir(), "Example.app") + macOSDir := filepath.Join(bundlePath, "Contents", "MacOS") + if err := os.MkdirAll(macOSDir, 0o755); err != nil { + t.Fatalf("failed to create app bundle directory: %v", err) + } + + readmePath := filepath.Join(macOSDir, "README.txt") + if err := os.WriteFile(readmePath, []byte("not executable"), 0o644); err != nil { + t.Fatalf("failed to create non-executable file: %v", err) + } + + _, err := resolveAppBundleExecutablePath(bundlePath) + if err == nil || !strings.Contains(err.Error(), "does not contain an executable in Contents/MacOS") { + t.Fatalf("expected missing-executable failure, got %v", err) + } +} + +func TestResolveScriptInterpreter(t *testing.T) { + scriptPath := filepath.Join(t.TempDir(), "script.sh") + if err := os.WriteFile(scriptPath, []byte("#!/usr/bin/env sh\necho ok\n"), 0o755); err != nil { + t.Fatalf("failed to create script: %v", err) + } + + got, ok, err := resolveScriptInterpreter(scriptPath) + if err != nil { + t.Fatalf("resolveScriptInterpreter returned error: %v", err) + } + if !ok { + t.Fatal("resolveScriptInterpreter should detect script shebang") + } + + want, err := exec.LookPath("sh") + if err != nil { + t.Fatalf("failed to resolve sh for test: %v", err) + } + if got != want { + t.Fatalf("resolveScriptInterpreter() = %q, want %q", got, want) + } +} + +func TestResolveScriptInterpreterForBinary(t *testing.T) { + binaryPath := filepath.Join(t.TempDir(), "binary") + if err := os.WriteFile(binaryPath, []byte("not a script"), 0o755); err != nil { + t.Fatalf("failed to create dummy file: %v", err) + } + + got, ok, err := resolveScriptInterpreter(binaryPath) + if err != nil { + t.Fatalf("resolveScriptInterpreter returned error: %v", err) + } + if ok { + t.Fatalf("resolveScriptInterpreter unexpectedly detected interpreter %q", got) + } +} + +func TestValidateLaunchTargetRejectsProtectedInterpreter(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-specific launch target validation only applies on Darwin") + } + + scriptPath := filepath.Join(t.TempDir(), "script.sh") + if err := os.WriteFile(scriptPath, []byte("#!/bin/sh\necho ok\n"), 0o755); err != nil { + t.Fatalf("failed to create script: %v", err) + } + + if _, err := validateLaunchTargetWithLibrary(scriptPath, ""); err == nil || !strings.Contains(err.Error(), "SIP-protected interpreter") { + t.Fatalf("expected SIP-protected interpreter rejection, got %v", err) + } +} + +func TestBuildChildEnvMergesExistingInjectionLibraryValue(t *testing.T) { + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + existingValue := "/opt/existing/preload" + if cfg.LibraryEnvVar == "DYLD_INSERT_LIBRARIES" { + existingValue = "/opt/existing/a.dylib:/opt/existing/b.dylib" + } + + got := buildChildEnv( + []string{cfg.LibraryEnvVar + "=" + existingValue}, + cfg, + "/tmp/"+cfg.LibraryName, + "/tmp/wrapguard.sock", + 4242, + false, + false, + ) + + value, ok := envValue(got, cfg.LibraryEnvVar) + if !ok { + t.Fatalf("%s not found in child environment", cfg.LibraryEnvVar) + } + if !strings.HasPrefix(value, "/tmp/"+cfg.LibraryName) { + t.Fatalf("%s should prepend wrapguard library, got %q", cfg.LibraryEnvVar, value) + } + if !strings.Contains(value, existingValue) { + t.Fatalf("%s should preserve existing value %q, got %q", cfg.LibraryEnvVar, existingValue, value) + } +} + +func TestBuildChildEnvEnablesMacOSNoInheritWhenRequested(t *testing.T) { + cfg, err := injectionConfigForGOOS("darwin") + if err != nil { + t.Fatalf("injectionConfigForGOOS(darwin) failed: %v", err) + } + + got := buildChildEnv( + nil, + cfg, + "/tmp/"+cfg.LibraryName, + "/tmp/wrapguard.sock", + 4242, + false, + true, + ) + + if value, ok := envValue(got, envWrapGuardNoInherit); !ok || value != "1" { + t.Fatalf("%s should be enabled in macOS GUI compatibility mode, got %q, present=%v", envWrapGuardNoInherit, value, ok) + } +} + +func TestBuildChildEnvOverridesInheritedWrapGuardStateOnDarwin(t *testing.T) { + cfg, err := injectionConfigForGOOS("darwin") + if err != nil { + t.Fatalf("injectionConfigForGOOS(darwin) failed: %v", err) + } + + got := buildChildEnv( + []string{ + "DYLD_INSERT_LIBRARIES=/tmp/old-a.dylib:/tmp/old-b.dylib", + "DYLD_FORCE_FLAT_NAMESPACE=1", + envWrapGuardIPCPath + "=/tmp/old.sock", + envWrapGuardSOCKSPort + "=9999", + envWrapGuardExpectRDY + "=0", + envWrapGuardDebug + "=0", + envWrapGuardDebugIPC + "=0", + envWrapGuardBlockUDP + "=0", + envWrapGuardNoInherit + "=0", + }, + cfg, + "/tmp/"+cfg.LibraryName, + "/tmp/new.sock", + 4242, + true, + true, + ) + + if value, ok := envValue(got, cfg.LibraryEnvVar); !ok || !strings.HasPrefix(value, "/tmp/"+cfg.LibraryName) { + t.Fatalf("%s should be reinjected with the current dylib first, got %q present=%v", cfg.LibraryEnvVar, value, ok) + } + if _, ok := envValue(got, "DYLD_FORCE_FLAT_NAMESPACE"); ok { + t.Fatal("DYLD_FORCE_FLAT_NAMESPACE should be stripped for Darwin DYLD_INTERPOSE launches") + } + if value, ok := envValue(got, envWrapGuardIPCPath); !ok || value != "/tmp/new.sock" { + t.Fatalf("%s = %q, present=%v, want /tmp/new.sock", envWrapGuardIPCPath, value, ok) + } + if value, ok := envValue(got, envWrapGuardSOCKSPort); !ok || value != "4242" { + t.Fatalf("%s = %q, present=%v, want 4242", envWrapGuardSOCKSPort, value, ok) + } + if value, ok := envValue(got, envWrapGuardExpectRDY); !ok || value != "1" { + t.Fatalf("%s = %q, present=%v, want 1", envWrapGuardExpectRDY, value, ok) + } + if value, ok := envValue(got, envWrapGuardDebug); !ok || value != "1" { + t.Fatalf("%s = %q, present=%v, want 1", envWrapGuardDebug, value, ok) + } + if value, ok := envValue(got, envWrapGuardDebugIPC); !ok || value != "1" { + t.Fatalf("%s = %q, present=%v, want 1", envWrapGuardDebugIPC, value, ok) + } + if value, ok := envValue(got, envWrapGuardBlockUDP); !ok || value != "1" { + t.Fatalf("%s = %q, present=%v, want 1", envWrapGuardBlockUDP, value, ok) + } + if value, ok := envValue(got, envWrapGuardNoInherit); !ok || value != "1" { + t.Fatalf("%s = %q, present=%v, want 1", envWrapGuardNoInherit, value, ok) + } +} + +func TestInitialHandshakeTimeout(t *testing.T) { + tests := []struct { + name string + goos string + target string + want string + }{ + { + name: "linux-default", + goos: "linux", + target: "/usr/bin/curl", + want: "3s", + }, + { + name: "darwin-cli-default", + goos: "darwin", + target: "/opt/homebrew/bin/curl", + want: "3s", + }, + { + name: "darwin-app-bundle", + goos: "darwin", + target: "/Applications/LibreWolf.app", + want: "15s", + }, + { + name: "darwin-inner-app-executable", + goos: "darwin", + target: "/Applications/LibreWolf.app/Contents/MacOS/librewolf", + want: "15s", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := initialHandshakeTimeout(tt.goos, tt.target).String(); got != tt.want { + t.Fatalf("initialHandshakeTimeout(%q, %q) = %s, want %s", tt.goos, tt.target, got, tt.want) + } + }) + } +} + +func writeAppBundleFixture(t *testing.T, bundleName string) (string, string) { + t.Helper() + + sourceExecutable, err := filepath.Abs(os.Args[0]) + if err != nil { + t.Fatalf("failed to resolve test executable path: %v", err) + } + + bundlePath := filepath.Join(t.TempDir(), bundleName+".app") + macOSDir := filepath.Join(bundlePath, "Contents", "MacOS") + if err := os.MkdirAll(macOSDir, 0o755); err != nil { + t.Fatalf("failed to create app bundle directory: %v", err) + } + + innerExecutable := filepath.Join(macOSDir, bundleName) + sourceData, err := os.ReadFile(sourceExecutable) + if err != nil { + t.Fatalf("failed to read source executable: %v", err) + } + if err := os.WriteFile(innerExecutable, sourceData, 0o755); err != nil { + t.Fatalf("failed to write bundle executable: %v", err) + } + + return bundlePath, innerExecutable +} + +func envValue(env []string, key string) (string, bool) { + prefix := key + "=" + for _, kv := range env { + if strings.HasPrefix(kv, prefix) { + return strings.TrimPrefix(kv, prefix), true + } + } + return "", false +} diff --git a/lib/intercept.c b/lib/intercept.c index a637606..57da97e 100644 --- a/lib/intercept.c +++ b/lib/intercept.c @@ -1,4 +1,6 @@ +#ifdef __linux__ #define _GNU_SOURCE +#endif #include #include #include @@ -11,293 +13,1349 @@ #include #include #include +#include #include +#include +#include +#include +#ifdef __APPLE__ +#include +#endif // Function pointers for original functions static int (*real_connect)(int sockfd, const struct sockaddr *addr, socklen_t addrlen) = NULL; static int (*real_bind)(int sockfd, const struct sockaddr *addr, socklen_t addrlen) = NULL; +static int (*real_getpeername)(int sockfd, struct sockaddr *addr, socklen_t *addrlen) = NULL; +static int (*real_close)(int fd) = NULL; +#ifdef __APPLE__ +static int (*real_connectx)(int sockfd, const sa_endpoints_t *endpoints, sae_associd_t associd, unsigned int flags, const struct iovec *iov, unsigned int iovcnt, size_t *len, sae_connid_t *connid) = NULL; +#endif // Global variables for configuration static char *ipc_path = NULL; static int socks_port = 0; static int initialized = 0; +static int debug_mode_cached = 0; +static int debug_ipc_cached = 0; +static int block_udp_443_cached = 0; +static int macos_no_inherit_cached = 0; +static int passthrough_mode_cached = 0; +static int expect_ready_cached = 0; +static __thread int internal_connect_guard = 0; + +static int wrapguard_connect_impl(int sockfd, const struct sockaddr *addr, socklen_t addrlen); +static int wrapguard_bind_impl(int sockfd, const struct sockaddr *addr, socklen_t addrlen); +#ifdef __APPLE__ +static int wrapguard_connectx_impl(int sockfd, const sa_endpoints_t *endpoints, sae_associd_t associd, unsigned int flags, const struct iovec *iov, unsigned int iovcnt, size_t *len, sae_connid_t *connid); +#endif +#ifdef __APPLE__ +static ssize_t wrapguard_sendto_impl(int sockfd, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen); +static ssize_t wrapguard_sendmsg_impl(int sockfd, const struct msghdr *msg, int flags); +#endif +static int raw_connect_call(int sockfd, const struct sockaddr *addr, socklen_t addrlen); +#ifdef __APPLE__ +static int raw_bind_call(int sockfd, const struct sockaddr *addr, socklen_t addrlen); +#endif +static int raw_close_call(int fd); +static int wait_for_socket(int sockfd, int for_write, int timeout_seconds); +static int recv_exact_with_timeout(int sockfd, unsigned char *buf, size_t len, int timeout_seconds); +static int send_all_with_timeout(int sockfd, const unsigned char *buf, size_t len, int timeout_seconds); +static const char *family_name(sa_family_t family); +static const char *socket_type_name(int sock_type); +#ifndef __APPLE__ +static int call_real_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen); +#endif +static int call_real_bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen); +static int call_real_getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen); +static int call_real_close(int fd); +#ifdef __APPLE__ +static int call_real_connectx(int sockfd, const sa_endpoints_t *endpoints, sae_associd_t associd, unsigned int flags, const struct iovec *iov, unsigned int iovcnt, size_t *len, sae_connid_t *connid); +#endif +static int is_loopback_connect(const struct sockaddr *addr); +static int is_nonblocking_socket(int sockfd); +static int block_udp_443_enabled(void); +static int should_block_udp_target(const struct sockaddr *addr); +#ifdef __APPLE__ +static int should_block_udp_send_target(int sockfd, const struct sockaddr *addr, socklen_t addrlen, char *buf, size_t buf_len); +#endif +static int sockaddr_port(const struct sockaddr *addr); +static void format_sockaddr(const struct sockaddr *addr, char *buf, size_t buf_len); +static void remember_virtual_peer(int sockfd, const struct sockaddr *addr, socklen_t addrlen); +#ifndef __APPLE__ +static int lookup_virtual_peer(int sockfd, struct sockaddr *addr, socklen_t *addrlen); +#endif +static void forget_virtual_peer(int sockfd); +static void send_ipc_message(const char *type, int fd, int port, const char *addr); +static int should_passthrough_current_process(void); +static int expect_ready_enabled(void); + +#ifdef __APPLE__ +static const char *process_basename(const char *path) { + if (path == NULL) { + return ""; + } + const char *base = strrchr(path, '/'); + return base ? base + 1 : path; +} + +static int str_equals(const char *value, const char *expected) { + return value != NULL && expected != NULL && strcmp(value, expected) == 0; +} + +static int should_passthrough_mozilla_process(char *const *argv) { + if (argv == NULL || argv[0] == NULL) { + return 0; + } + + const char *base = process_basename(argv[0]); + if (str_equals(base, "plugin-container")) { + size_t argc = 0; + while (argv[argc] != NULL) { + argc++; + } + if (argc > 1 && str_equals(argv[argc - 1], "socket")) { + return 0; + } + return 1; + } + + if (strstr(base, "GPU Helper") != NULL || str_equals(base, "gpu-helper")) { + return 1; + } + + return 0; +} +#endif + + +struct virtual_peer_entry { + int fd; + struct sockaddr_storage addr; + socklen_t addrlen; + struct virtual_peer_entry *next; +}; + +static pthread_mutex_t virtual_peer_mutex = PTHREAD_MUTEX_INITIALIZER; +static struct virtual_peer_entry *virtual_peers = NULL; + +#ifdef __APPLE__ +#define DYLD_INTERPOSE(_replacement, _replacee) \ + __attribute__((used)) static struct { \ + const void *replacement; \ + const void *replacee; \ + } _interpose_##_replacee \ + __attribute__((section("__DATA,__interpose"))) = { \ + (const void *)(unsigned long)&_replacement, \ + (const void *)(unsigned long)&_replacee \ + }; +#endif + +static const char *debug_prefix(void) { +#ifdef __APPLE__ + return "WrapGuard DYLD: "; +#else + return "WrapGuard LD_PRELOAD: "; +#endif +} + +static int debug_enabled(void) { + if (initialized) { + return debug_mode_cached && !passthrough_mode_cached; + } + char *debug_mode = getenv("WRAPGUARD_DEBUG"); + return debug_mode != NULL && strcmp(debug_mode, "1") == 0; +} + +static int debug_ipc_enabled(void) { +#ifdef __APPLE__ + if (initialized) { + return debug_ipc_cached; + } + char *value = getenv("WRAPGUARD_DEBUG_IPC"); + return value != NULL && strcmp(value, "1") == 0; +#else + return 0; +#endif +} + +#ifdef __APPLE__ +static void write_stderr_line(const char *prefix, const char *message) { + char buffer[768]; + int written = snprintf(buffer, sizeof(buffer), "%s%s\n", prefix ? prefix : "", message ? message : ""); + if (written <= 0) { + return; + } + + size_t len = (size_t)written; + if (len >= sizeof(buffer)) { + len = sizeof(buffer) - 1; + } + + size_t offset = 0; + while (offset < len) { + ssize_t chunk = write(STDERR_FILENO, buffer + offset, len - offset); + if (chunk < 0) { + if (errno == EINTR) { + continue; + } + return; + } + offset += (size_t)chunk; + } +} +#endif + +static void log_debugf(const char *fmt, ...) { + if (!debug_enabled()) { + return; + } + int saved_errno = errno; + + char message[512]; + va_list ap; + va_start(ap, fmt); + vsnprintf(message, sizeof(message), fmt, ap); + va_end(ap); + +#ifdef __APPLE__ + if (debug_ipc_enabled() && ipc_path != NULL) { + send_ipc_message("DEBUG", -1, 0, message); + errno = saved_errno; + return; + } + write_stderr_line(debug_prefix(), message); +#else + fprintf(stderr, "%s%s\n", debug_prefix(), message); +#endif + errno = saved_errno; +} + +static void log_errorf(const char *fmt, ...) { + int saved_errno = errno; + + char message[512]; + va_list ap; + va_start(ap, fmt); + vsnprintf(message, sizeof(message), fmt, ap); + va_end(ap); + +#ifdef __APPLE__ + if (debug_ipc_enabled() && ipc_path != NULL) { + send_ipc_message("ERROR", -1, 0, message); + errno = saved_errno; + return; + } + write_stderr_line(debug_prefix(), message); +#else + fprintf(stderr, "%s%s\n", debug_prefix(), message); +#endif + errno = saved_errno; +} + +static int block_udp_443_enabled(void) { + if (initialized) { + return block_udp_443_cached; + } + char *value = getenv("WRAPGUARD_BLOCK_UDP_443"); + return value != NULL && strcmp(value, "1") == 0; +} + +static int macos_no_inherit_enabled(void) { +#ifdef __APPLE__ + if (initialized) { + return macos_no_inherit_cached; + } + char *value = getenv("WRAPGUARD_MACOS_NO_INHERIT"); + return value != NULL && strcmp(value, "1") == 0; +#else + return 0; +#endif +} + +static int expect_ready_enabled(void) { + if (initialized) { + return expect_ready_cached; + } + char *value = getenv("WRAPGUARD_EXPECT_READY"); + return value != NULL && strcmp(value, "1") == 0; +} + +static const char *family_name(sa_family_t family) { + switch (family) { + case AF_UNIX: + return "AF_UNIX"; + case AF_INET: + return "AF_INET"; + case AF_INET6: + return "AF_INET6"; + default: + return "AF_UNKNOWN"; + } +} + +static const char *socket_type_name(int sock_type) { + switch (sock_type) { + case SOCK_STREAM: + return "SOCK_STREAM"; + case SOCK_DGRAM: + return "SOCK_DGRAM"; +#ifdef SOCK_SEQPACKET + case SOCK_SEQPACKET: + return "SOCK_SEQPACKET"; +#endif + default: + return "SOCK_OTHER"; + } +} + +#ifndef __APPLE__ +static int call_real_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { + if (real_connect == NULL) { + errno = ENOSYS; + return -1; + } + return real_connect(sockfd, addr, addrlen); +} +#endif + +static int call_real_bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { +#ifdef __APPLE__ + return raw_bind_call(sockfd, addr, addrlen); +#else + if (real_bind == NULL) { + errno = ENOSYS; + return -1; + } + return real_bind(sockfd, addr, addrlen); +#endif +} + +static int call_real_getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { + if (real_getpeername == NULL) { + errno = ENOSYS; + return -1; + } + return real_getpeername(sockfd, addr, addrlen); +} + +static int call_real_close(int fd) { +#ifdef __APPLE__ + return raw_close_call(fd); +#else + if (real_close == NULL) { + return raw_close_call(fd); + } + return real_close(fd); +#endif +} + +#ifdef __APPLE__ +static int call_real_connectx(int sockfd, const sa_endpoints_t *endpoints, sae_associd_t associd, unsigned int flags, const struct iovec *iov, unsigned int iovcnt, size_t *len, sae_connid_t *connid) { + if (real_connectx == NULL) { + errno = ENOSYS; + return -1; + } + return real_connectx(sockfd, endpoints, associd, flags, iov, iovcnt, len, connid); +} +#endif + +#ifdef __APPLE__ +static int raw_connectx_call(int sockfd, const sa_endpoints_t *endpoints, sae_associd_t associd, unsigned int flags, const struct iovec *iov, unsigned int iovcnt, size_t *len, sae_connid_t *connid) { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + return (int)syscall(SYS_connectx, sockfd, endpoints, associd, flags, iov, iovcnt, len, connid); +#pragma clang diagnostic pop +} +#endif + +static int raw_connect_call(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { +#ifdef __APPLE__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + return (int)syscall(SYS_connect, sockfd, addr, addrlen); +#ifdef __APPLE__ +#pragma clang diagnostic pop +#endif +} + +#ifdef __APPLE__ +static int raw_bind_call(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { +#ifdef __APPLE__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + return (int)syscall(SYS_bind, sockfd, addr, addrlen); +#ifdef __APPLE__ +#pragma clang diagnostic pop +#endif +} +#endif + +static int raw_close_call(int fd) { +#ifdef __APPLE__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + return (int)syscall(SYS_close, fd); +#ifdef __APPLE__ +#pragma clang diagnostic pop +#endif +} + +static int wait_for_socket(int sockfd, int for_write, int timeout_seconds) { + for (;;) { + fd_set fds; + FD_ZERO(&fds); + FD_SET(sockfd, &fds); + + struct timeval timeout; + timeout.tv_sec = timeout_seconds; + timeout.tv_usec = 0; + + int result = select(sockfd + 1, for_write ? NULL : &fds, for_write ? &fds : NULL, NULL, &timeout); + if (result >= 0) { + return result; + } + if (errno != EINTR) { + return -1; + } + } +} + +static int recv_exact_with_timeout(int sockfd, unsigned char *buf, size_t len, int timeout_seconds) { + size_t offset = 0; + + while (offset < len) { + int select_result = wait_for_socket(sockfd, 0, timeout_seconds); + if (select_result <= 0) { + return -1; + } + + ssize_t chunk = recv(sockfd, buf + offset, len - offset, 0); + if (chunk <= 0) { + if (chunk < 0 && (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)) { + continue; + } + return -1; + } + + offset += (size_t)chunk; + } + + return 0; +} + +static int send_all_with_timeout(int sockfd, const unsigned char *buf, size_t len, int timeout_seconds) { + size_t offset = 0; + + while (offset < len) { + int select_result = wait_for_socket(sockfd, 1, timeout_seconds); + if (select_result <= 0) { + return -1; + } + + ssize_t chunk = send(sockfd, buf + offset, len - offset, 0); + if (chunk < 0) { + if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) { + continue; + } + return -1; + } + + offset += (size_t)chunk; + } + + return 0; +} + +static int is_loopback_connect(const struct sockaddr *addr) { + if (addr == NULL) { + return 0; + } + + if (addr->sa_family == AF_INET) { + const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr; + uint32_t ip = ntohl(in_addr->sin_addr.s_addr); + return (ip & 0xFF000000) == 0x7F000000; + } + + if (addr->sa_family == AF_INET6) { + const struct sockaddr_in6 *in6_addr = (const struct sockaddr_in6 *)addr; + return IN6_IS_ADDR_LOOPBACK(&in6_addr->sin6_addr); + } + + return 0; +} + +static int is_nonblocking_socket(int sockfd) { + int flags = fcntl(sockfd, F_GETFL); + if (flags < 0) { + return 0; + } + return (flags & O_NONBLOCK) != 0; +} + +static int sockaddr_port(const struct sockaddr *addr) { + if (addr == NULL) { + return 0; + } + if (addr->sa_family == AF_INET) { + return ntohs(((const struct sockaddr_in *)addr)->sin_port); + } + if (addr->sa_family == AF_INET6) { + return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port); + } + return 0; +} + +static void format_sockaddr(const struct sockaddr *addr, char *buf, size_t buf_len) { + if (buf_len == 0) { + return; + } + if (addr == NULL) { + snprintf(buf, buf_len, "NULL"); + return; + } + + if (addr->sa_family == AF_INET) { + const struct sockaddr_in *in_addr = (const struct sockaddr_in *)addr; + char ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &in_addr->sin_addr, ip_str, sizeof(ip_str)); + snprintf(buf, buf_len, "%s:%d", ip_str, ntohs(in_addr->sin_port)); + return; + } + + if (addr->sa_family == AF_INET6) { + const struct sockaddr_in6 *in6_addr = (const struct sockaddr_in6 *)addr; + char ip_str[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, &in6_addr->sin6_addr, ip_str, sizeof(ip_str)); + snprintf(buf, buf_len, "[%s]:%d", ip_str, ntohs(in6_addr->sin6_port)); + return; + } + + snprintf(buf, buf_len, "%s(%d)", family_name(addr->sa_family), addr->sa_family); +} + +static void remember_virtual_peer(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { + if (addr == NULL || addrlen == 0) { + return; + } + + if ((size_t)addrlen > sizeof(struct sockaddr_storage)) { + addrlen = sizeof(struct sockaddr_storage); + } + + pthread_mutex_lock(&virtual_peer_mutex); + + struct virtual_peer_entry *entry = virtual_peers; + while (entry != NULL) { + if (entry->fd == sockfd) { + memset(&entry->addr, 0, sizeof(entry->addr)); + memcpy(&entry->addr, addr, (size_t)addrlen); + entry->addrlen = addrlen; + pthread_mutex_unlock(&virtual_peer_mutex); + return; + } + entry = entry->next; + } + + entry = (struct virtual_peer_entry *)calloc(1, sizeof(*entry)); + if (entry != NULL) { + entry->fd = sockfd; + memcpy(&entry->addr, addr, (size_t)addrlen); + entry->addrlen = addrlen; + entry->next = virtual_peers; + virtual_peers = entry; + } + + pthread_mutex_unlock(&virtual_peer_mutex); +} + +#ifndef __APPLE__ +static int lookup_virtual_peer(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { + int found = 0; + + if (addr == NULL || addrlen == NULL) { + return 0; + } + + pthread_mutex_lock(&virtual_peer_mutex); + + struct virtual_peer_entry *entry = virtual_peers; + while (entry != NULL) { + if (entry->fd == sockfd) { + socklen_t copy_len = entry->addrlen; + if (*addrlen < copy_len) { + copy_len = *addrlen; + } + memcpy(addr, &entry->addr, (size_t)copy_len); + *addrlen = entry->addrlen; + found = 1; + break; + } + entry = entry->next; + } + + pthread_mutex_unlock(&virtual_peer_mutex); + return found; +} +#endif + +static void forget_virtual_peer(int sockfd) { + pthread_mutex_lock(&virtual_peer_mutex); + + struct virtual_peer_entry *entry = virtual_peers; + struct virtual_peer_entry *prev = NULL; + while (entry != NULL) { + if (entry->fd == sockfd) { + if (prev == NULL) { + virtual_peers = entry->next; + } else { + prev->next = entry->next; + } + free(entry); + break; + } + prev = entry; + entry = entry->next; + } + + pthread_mutex_unlock(&virtual_peer_mutex); +} + +static int should_block_udp_target(const struct sockaddr *addr) { + if (!block_udp_443_enabled()) { + return 0; + } + if (addr == NULL) { + return 0; + } + if (addr->sa_family != AF_INET && addr->sa_family != AF_INET6) { + return 0; + } + if (is_loopback_connect(addr)) { + return 0; + } + return sockaddr_port(addr) == 443; +} + +#ifdef __APPLE__ +static int should_block_udp_send_target(int sockfd, const struct sockaddr *addr, socklen_t addrlen, char *buf, size_t buf_len) { + struct sockaddr_storage target_storage; + const struct sockaddr *target = addr; + socklen_t target_len = addrlen; + + if (target == NULL || target_len < (socklen_t)sizeof(sa_family_t)) { + memset(&target_storage, 0, sizeof(target_storage)); + target_len = sizeof(target_storage); + if (call_real_getpeername(sockfd, (struct sockaddr *)&target_storage, &target_len) != 0) { + return 0; + } + target = (const struct sockaddr *)&target_storage; + } + + if (!should_block_udp_target(target)) { + return 0; + } + + if (buf != NULL && buf_len > 0) { + format_sockaddr(target, buf, buf_len); + } + + return 1; +} +#endif // Initialize the library -static void init_library() { +static void init_library(void) { if (initialized) return; + debug_mode_cached = debug_enabled(); + debug_ipc_cached = debug_ipc_enabled(); + block_udp_443_cached = block_udp_443_enabled(); + macos_no_inherit_cached = macos_no_inherit_enabled(); + passthrough_mode_cached = should_passthrough_current_process(); + expect_ready_cached = expect_ready_enabled(); initialized = 1; + +#ifdef __APPLE__ + if (expect_ready_cached) { + unsetenv("WRAPGUARD_EXPECT_READY"); + } + + if (passthrough_mode_cached) { + char *ipc_path_env = getenv("WRAPGUARD_IPC_PATH"); + if (ipc_path_env != NULL) { + ipc_path = strdup(ipc_path_env); + } + char *socks_port_str = getenv("WRAPGUARD_SOCKS_PORT"); + if (socks_port_str != NULL) { + socks_port = atoi(socks_port_str); + } + + if (expect_ready_cached && ipc_path != NULL && socks_port != 0) { + send_ipc_message("READY", -1, socks_port, NULL); + } + return; + } +#endif // Load original functions +#ifdef __APPLE__ + dlerror(); + real_connect = (int (*)(int, const struct sockaddr *, socklen_t))dlsym(RTLD_NEXT, "connect"); + const char *connect_err = dlerror(); + dlerror(); + real_bind = (int (*)(int, const struct sockaddr *, socklen_t))dlsym(RTLD_NEXT, "bind"); + const char *bind_err = dlerror(); + dlerror(); + real_getpeername = (int (*)(int, struct sockaddr *, socklen_t *))dlsym(RTLD_NEXT, "getpeername"); + const char *getpeername_err = dlerror(); + dlerror(); + real_connectx = (int (*)(int, const sa_endpoints_t *, sae_associd_t, unsigned int, const struct iovec *, unsigned int, size_t *, sae_connid_t *))dlsym(RTLD_NEXT, "connectx"); + const char *connectx_err = dlerror(); + const char *close_err = NULL; +#else + dlerror(); real_connect = dlsym(RTLD_NEXT, "connect"); + const char *connect_err = dlerror(); + dlerror(); real_bind = dlsym(RTLD_NEXT, "bind"); + const char *bind_err = dlerror(); + dlerror(); + real_getpeername = dlsym(RTLD_NEXT, "getpeername"); + const char *getpeername_err = dlerror(); + dlerror(); + real_close = dlsym(RTLD_NEXT, "close"); + const char *close_err = dlerror(); + const char *connectx_err = "unsupported"; +#endif + + if (real_connect == NULL || real_bind == NULL || real_getpeername == NULL +#ifndef __APPLE__ + || real_close == NULL +#else + || real_connectx == NULL +#endif + ) { + log_errorf("Failed to resolve original socket symbols (connect=%s bind=%s getpeername=%s connectx=%s close=%s)", + connect_err ? connect_err : "unknown", + bind_err ? bind_err : "unknown", + getpeername_err ? getpeername_err : "unknown", + connectx_err ? connectx_err : "unknown", + close_err ? close_err : "syscall"); + return; + } // Get configuration from environment - ipc_path = getenv("WRAPGUARD_IPC_PATH"); + char *ipc_path_env = getenv("WRAPGUARD_IPC_PATH"); + if (ipc_path_env != NULL) { + ipc_path = strdup(ipc_path_env); + } char *socks_port_str = getenv("WRAPGUARD_SOCKS_PORT"); if (socks_port_str) { socks_port = atoi(socks_port_str); } + +#ifdef __APPLE__ + if (macos_no_inherit_cached) { + unsetenv("DYLD_INSERT_LIBRARIES"); + unsetenv("DYLD_FORCE_FLAT_NAMESPACE"); + unsetenv("WRAPGUARD_IPC_PATH"); + unsetenv("WRAPGUARD_SOCKS_PORT"); + unsetenv("WRAPGUARD_DEBUG"); + unsetenv("WRAPGUARD_DEBUG_IPC"); + unsetenv("WRAPGUARD_BLOCK_UDP_443"); + unsetenv("WRAPGUARD_MACOS_NO_INHERIT"); + } +#endif - // Debug output (only in debug mode) - char *debug_mode = getenv("WRAPGUARD_DEBUG"); - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Initialized\n"); - fprintf(stderr, "WrapGuard LD_PRELOAD: IPC path: %s\n", ipc_path ? ipc_path : "NULL"); - fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS port: %d\n", socks_port); + if (debug_enabled()) { + log_debugf("Initialized"); + log_debugf("IPC path: %s", ipc_path ? ipc_path : "NULL"); + log_debugf("SOCKS port: %d", socks_port); +#ifdef __APPLE__ + log_debugf("Resolved real symbols connect=%p bind=%p getpeername=%p connectx=%p close=%p", (void *)real_connect, (void *)real_bind, (void *)real_getpeername, (void *)real_connectx, (void *)real_close); +#else + log_debugf("Resolved real symbols connect=%p bind=%p getpeername=%p close=%p", (void *)real_connect, (void *)real_bind, (void *)real_getpeername, (void *)real_close); +#endif + if (block_udp_443_enabled()) { + log_debugf("Likely QUIC UDP/443 suppression is enabled"); + } } if (!ipc_path || socks_port == 0) { - fprintf(stderr, "WrapGuard: Missing environment variables\n"); + log_errorf("WrapGuard: Missing environment variables"); + return; + } + + if (expect_ready_cached) { + send_ipc_message("READY", -1, socks_port, NULL); + if (debug_enabled()) { + log_debugf("Interceptor loaded and announced readiness"); + } } } +static int should_passthrough_current_process(void) { +#ifdef __APPLE__ + char ***argv_ptr = _NSGetArgv(); + if (argv_ptr == NULL) { + return 0; + } + return should_passthrough_mozilla_process(*argv_ptr); +#else + return 0; +#endif +} + +__attribute__((constructor)) +static void wrapguard_constructor(void) { + init_library(); +} + // Check if an address should be intercepted static int should_intercept_connect(const struct sockaddr *addr) { if (addr->sa_family != AF_INET && addr->sa_family != AF_INET6) { return 0; // Only intercept IP connections } - - if (addr->sa_family == AF_INET) { - struct sockaddr_in *in_addr = (struct sockaddr_in *)addr; - - // Don't intercept localhost connections (except when connecting to our SOCKS proxy) - uint32_t ip = ntohl(in_addr->sin_addr.s_addr); - if ((ip & 0xFF000000) == 0x7F000000) { // 127.x.x.x - int port = ntohs(in_addr->sin_port); - if (port == socks_port) { - return 0; // Don't intercept connections to our own SOCKS proxy - } - } - - return 1; // Intercept all other connections + + if (is_loopback_connect(addr)) { + return 0; } - - // TODO: Handle IPv6 if needed - return 0; + + return 1; } // Send IPC message static void send_ipc_message(const char *type, int fd, int port, const char *addr) { + int saved_errno = errno; if (!ipc_path) return; + if (internal_connect_guard > 0) { + errno = saved_errno; + return; + } + + internal_connect_guard++; int sock = socket(AF_UNIX, SOCK_STREAM, 0); - if (sock < 0) return; + if (sock < 0) { + internal_connect_guard--; + return; + } struct sockaddr_un sun; memset(&sun, 0, sizeof(sun)); sun.sun_family = AF_UNIX; strncpy(sun.sun_path, ipc_path, sizeof(sun.sun_path) - 1); - if (connect(sock, (struct sockaddr *)&sun, sizeof(sun)) == 0) { + int ipc_connect_result = raw_connect_call(sock, (struct sockaddr *)&sun, sizeof(sun)); + if (ipc_connect_result == 0) { char message[512]; snprintf(message, sizeof(message), - "{\"type\":\"%s\",\"fd\":%d,\"port\":%d,\"addr\":\"%s\"}\n", - type, fd, port, addr ? addr : ""); - - write(sock, message, strlen(message)); + "{\"type\":\"%s\",\"fd\":%d,\"port\":%d,\"addr\":\"%s\",\"pid\":%d}\n", + type, fd, port, addr ? addr : "", (int)getpid()); + + size_t message_len = strlen(message); + size_t offset = 0; + while (offset < message_len) { + ssize_t written = write(sock, message + offset, message_len - offset); + if (written < 0) { + if (errno == EINTR) { + continue; + } + break; + } + offset += (size_t)written; + } } - close(sock); + raw_close_call(sock); + internal_connect_guard--; + errno = saved_errno; +} + +#ifdef __APPLE__ +static ssize_t raw_sendto_call(int sockfd, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + return (ssize_t)syscall(SYS_sendto, sockfd, buf, len, flags, dest_addr, addrlen); +#pragma clang diagnostic pop +} + +static ssize_t raw_sendmsg_call(int sockfd, const struct msghdr *msg, int flags) { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + return (ssize_t)syscall(SYS_sendmsg, sockfd, msg, flags); +#pragma clang diagnostic pop +} + +static ssize_t wrapguard_sendto_impl(int sockfd, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { + init_library(); + if (passthrough_mode_cached) { + return raw_sendto_call(sockfd, buf, len, flags, dest_addr, addrlen); + } + + int sock_type = 0; + socklen_t sock_type_len = sizeof(sock_type); + if (getsockopt(sockfd, SOL_SOCKET, SO_TYPE, &sock_type, &sock_type_len) != 0 || sock_type != SOCK_DGRAM) { + return raw_sendto_call(sockfd, buf, len, flags, dest_addr, addrlen); + } + + char addr_str[INET6_ADDRSTRLEN + 32]; + if (should_block_udp_send_target(sockfd, dest_addr, addrlen, addr_str, sizeof(addr_str))) { + log_debugf("Blocking likely QUIC UDP sendto() to %s", addr_str); + errno = EHOSTUNREACH; + return -1; + } + + return raw_sendto_call(sockfd, buf, len, flags, dest_addr, addrlen); +} + +static ssize_t wrapguard_sendmsg_impl(int sockfd, const struct msghdr *msg, int flags) { + init_library(); + if (passthrough_mode_cached) { + return raw_sendmsg_call(sockfd, msg, flags); + } + + if (msg == NULL) { + return raw_sendmsg_call(sockfd, msg, flags); + } + + int sock_type = 0; + socklen_t sock_type_len = sizeof(sock_type); + if (getsockopt(sockfd, SOL_SOCKET, SO_TYPE, &sock_type, &sock_type_len) != 0 || sock_type != SOCK_DGRAM) { + return raw_sendmsg_call(sockfd, msg, flags); + } + + char addr_str[INET6_ADDRSTRLEN + 32]; + if (should_block_udp_send_target(sockfd, (const struct sockaddr *)msg->msg_name, msg->msg_namelen, addr_str, sizeof(addr_str))) { + log_debugf("Blocking likely QUIC UDP sendmsg() to %s", addr_str); + errno = EHOSTUNREACH; + return -1; + } + + return raw_sendmsg_call(sockfd, msg, flags); } +#endif // SOCKS5 connection helper static int socks5_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { - char *debug_mode = getenv("WRAPGUARD_DEBUG"); - - if (addr->sa_family != AF_INET) { + (void)addrlen; + int was_nonblocking = is_nonblocking_socket(sockfd); + + if (addr->sa_family != AF_INET && addr->sa_family != AF_INET6) { errno = EAFNOSUPPORT; return -1; } - - struct sockaddr_in *target = (struct sockaddr_in *)addr; + struct sockaddr_in socks_addr; memset(&socks_addr, 0, sizeof(socks_addr)); socks_addr.sin_family = AF_INET; socks_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); socks_addr.sin_port = htons(socks_port); - // Connect to SOCKS5 proxy - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Connecting to SOCKS5 proxy at 127.0.0.1:%d\n", socks_port); + if (debug_enabled()) { + log_debugf("Connecting to SOCKS5 proxy at 127.0.0.1:%d", socks_port); } - int connect_result = real_connect(sockfd, (struct sockaddr *)&socks_addr, sizeof(socks_addr)); + int connect_result = raw_connect_call(sockfd, (struct sockaddr *)&socks_addr, sizeof(socks_addr)); if (connect_result != 0 && errno != EINPROGRESS) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Failed to connect to SOCKS5 proxy: %s\n", strerror(errno)); + forget_virtual_peer(sockfd); + log_errorf("Failed to connect to SOCKS5 proxy: %s", strerror(errno)); return -1; } - // For non-blocking sockets, we need to wait for connection to complete - if (errno == EINPROGRESS) { - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Non-blocking connect in progress, waiting...\n"); - } - fd_set write_fds; - FD_ZERO(&write_fds); - FD_SET(sockfd, &write_fds); - - struct timeval timeout = {5, 0}; // 5 second timeout - int select_result = select(sockfd + 1, NULL, &write_fds, NULL, &timeout); + if (connect_result != 0 && errno == EINPROGRESS) { + if (debug_enabled()) { + log_debugf("Non-blocking connect in progress, waiting..."); + } + int select_result = wait_for_socket(sockfd, 1, 5); if (select_result <= 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Timeout waiting for SOCKS5 connection\n"); + forget_virtual_peer(sockfd); + log_errorf("Timeout waiting for SOCKS5 connection"); return -1; } - // Check if connection actually succeeded - int so_error; + int so_error = 0; socklen_t len = sizeof(so_error); - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0 || so_error != 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS5 connection failed: %s\n", strerror(so_error)); + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &so_error, &len) != 0) { + forget_virtual_peer(sockfd); + log_errorf("SOCKS5 connection failed while reading SO_ERROR: %s", strerror(errno)); + return -1; + } + if (so_error != 0) { + forget_virtual_peer(sockfd); + log_errorf("SOCKS5 connection failed: %s", strerror(so_error)); + errno = so_error; return -1; } } - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Connected to SOCKS5 proxy, starting handshake\n"); + if (debug_enabled()) { + log_debugf("Connected to SOCKS5 proxy, starting handshake"); } - // SOCKS5 handshake - unsigned char handshake[] = {0x05, 0x01, 0x00}; // Version 5, 1 method, no auth - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Sending SOCKS5 handshake\n"); + unsigned char handshake[] = {0x05, 0x01, 0x00}; + if (debug_enabled()) { + log_debugf("Sending SOCKS5 handshake"); } - if (send(sockfd, handshake, 3, 0) != 3) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Failed to send SOCKS5 handshake\n"); + if (send_all_with_timeout(sockfd, handshake, sizeof(handshake), 5) != 0) { + forget_virtual_peer(sockfd); + log_errorf("Failed to send SOCKS5 handshake: %s", strerror(errno)); return -1; } unsigned char response[2]; - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Waiting for SOCKS5 handshake response\n"); + if (debug_enabled()) { + log_debugf("Waiting for SOCKS5 handshake response"); } - - // Wait for response with timeout (non-blocking socket issue) - fd_set read_fds; - FD_ZERO(&read_fds); - FD_SET(sockfd, &read_fds); - - struct timeval timeout = {5, 0}; // 5 second timeout - int select_result = select(sockfd + 1, &read_fds, NULL, NULL, &timeout); - if (select_result <= 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Timeout waiting for SOCKS5 handshake response\n"); - return -1; - } - - int recv_bytes = recv(sockfd, response, 2, 0); - if (recv_bytes != 2) { - fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS5 handshake response failed, got %d bytes, errno: %s\n", recv_bytes, strerror(errno)); + + if (recv_exact_with_timeout(sockfd, response, sizeof(response), 5) != 0) { + forget_virtual_peer(sockfd); + log_errorf("Timeout waiting for SOCKS5 handshake response"); return -1; } if (response[0] != 0x05 || response[1] != 0x00) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Invalid SOCKS5 handshake response: %02x %02x\n", response[0], response[1]); + forget_virtual_peer(sockfd); + log_errorf("Invalid SOCKS5 handshake response: %02x %02x", response[0], response[1]); return -1; } - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS5 handshake successful\n"); + if (debug_enabled()) { + log_debugf("SOCKS5 handshake successful"); } - // SOCKS5 connect request - unsigned char connect_req[10]; - connect_req[0] = 0x05; // Version - connect_req[1] = 0x01; // Connect command - connect_req[2] = 0x00; // Reserved - connect_req[3] = 0x01; // IPv4 address type - memcpy(&connect_req[4], &target->sin_addr, 4); // IP address - memcpy(&connect_req[8], &target->sin_port, 2); // Port - - if (send(sockfd, connect_req, 10, 0) != 10) { + unsigned char connect_req[22]; + size_t connect_req_len = 0; + connect_req[connect_req_len++] = 0x05; + connect_req[connect_req_len++] = 0x01; + connect_req[connect_req_len++] = 0x00; + + if (addr->sa_family == AF_INET) { + const struct sockaddr_in *target = (const struct sockaddr_in *)addr; + connect_req[connect_req_len++] = 0x01; + memcpy(&connect_req[connect_req_len], &target->sin_addr, 4); + connect_req_len += 4; + memcpy(&connect_req[connect_req_len], &target->sin_port, 2); + connect_req_len += 2; + } else { + const struct sockaddr_in6 *target6 = (const struct sockaddr_in6 *)addr; + connect_req[connect_req_len++] = 0x04; + memcpy(&connect_req[connect_req_len], &target6->sin6_addr, 16); + connect_req_len += 16; + memcpy(&connect_req[connect_req_len], &target6->sin6_port, 2); + connect_req_len += 2; + } + + if (send_all_with_timeout(sockfd, connect_req, connect_req_len, 15) != 0) { + forget_virtual_peer(sockfd); + log_errorf("Failed to send SOCKS5 connect request: %s", strerror(errno)); return -1; } - - // Read SOCKS5 response with timeout - unsigned char connect_resp[10]; - - FD_ZERO(&read_fds); - FD_SET(sockfd, &read_fds); - timeout.tv_sec = 5; - timeout.tv_usec = 0; - - select_result = select(sockfd + 1, &read_fds, NULL, NULL, &timeout); - if (select_result <= 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: Timeout waiting for SOCKS5 connect response\n"); + + unsigned char connect_resp_header[4]; + if (recv_exact_with_timeout(sockfd, connect_resp_header, sizeof(connect_resp_header), 15) != 0) { + forget_virtual_peer(sockfd); + log_errorf("Timeout waiting for SOCKS5 connect response"); return -1; } - - if (recv(sockfd, connect_resp, 10, 0) != 10 || connect_resp[0] != 0x05 || connect_resp[1] != 0x00) { - fprintf(stderr, "WrapGuard LD_PRELOAD: SOCKS5 connect failed\n"); + + if (connect_resp_header[0] != 0x05 || connect_resp_header[1] != 0x00) { + forget_virtual_peer(sockfd); + log_errorf("SOCKS5 connect failed: version=%02x status=%02x", connect_resp_header[0], connect_resp_header[1]); errno = ECONNREFUSED; return -1; } - - return 0; // Success + + size_t addr_bytes = 0; + switch (connect_resp_header[3]) { + case 0x01: + addr_bytes = 4 + 2; + break; + case 0x03: { + unsigned char domain_len = 0; + if (recv_exact_with_timeout(sockfd, &domain_len, 1, 15) != 0) { + forget_virtual_peer(sockfd); + log_errorf("Timed out reading SOCKS5 domain length"); + errno = ECONNREFUSED; + return -1; + } + addr_bytes = (size_t)domain_len + 2; + break; + } + case 0x04: + addr_bytes = 16 + 2; + break; + default: + forget_virtual_peer(sockfd); + log_errorf("SOCKS5 connect failed: unsupported atyp=%02x", connect_resp_header[3]); + errno = ECONNREFUSED; + return -1; + } + + if (addr_bytes > 0) { + unsigned char addr_buf[258]; + if (recv_exact_with_timeout(sockfd, addr_buf, addr_bytes, 15) != 0) { + forget_virtual_peer(sockfd); + log_errorf("Timed out reading SOCKS5 connect address payload"); + errno = ECONNREFUSED; + return -1; + } + } + + if (debug_enabled()) { + log_debugf("SOCKS5 connect successful"); + } + + remember_virtual_peer(sockfd, addr, addrlen); + + if (was_nonblocking) { + if (debug_enabled()) { + log_debugf("Preserving non-blocking connect semantics after SOCKS5 handshake"); + } + errno = EINPROGRESS; + return -1; + } + + return 0; } -// Intercepted connect function -int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { +static int wrapguard_connect_impl(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { init_library(); - - // Convert address to string for logging - char addr_str[INET_ADDRSTRLEN + 16]; - if (addr->sa_family == AF_INET) { - struct sockaddr_in *in_addr = (struct sockaddr_in *)addr; - char ip_str[INET_ADDRSTRLEN]; - inet_ntop(AF_INET, &in_addr->sin_addr, ip_str, INET_ADDRSTRLEN); - snprintf(addr_str, sizeof(addr_str), "%s:%d", ip_str, ntohs(in_addr->sin_port)); - } else { - strcpy(addr_str, "unknown"); + if (passthrough_mode_cached) { +#ifdef __APPLE__ + return raw_connect_call(sockfd, addr, addrlen); +#else + return call_real_connect(sockfd, addr, addrlen); +#endif + } + + if (addr == NULL || addrlen < (socklen_t)sizeof(addr->sa_family)) { +#ifdef __APPLE__ + return raw_connect_call(sockfd, addr, addrlen); +#else + return call_real_connect(sockfd, addr, addrlen); +#endif + } + + if (internal_connect_guard > 0) { +#ifdef __APPLE__ + return raw_connect_call(sockfd, addr, addrlen); +#else + return call_real_connect(sockfd, addr, addrlen); +#endif + } + + int sock_type = 0; + socklen_t sock_type_len = sizeof(sock_type); + if (getsockopt(sockfd, SOL_SOCKET, SO_TYPE, &sock_type, &sock_type_len) != 0) { + if (debug_enabled()) { + log_debugf("Failed to read socket type for fd=%d: %s", sockfd, strerror(errno)); + } +#ifdef __APPLE__ + return raw_connect_call(sockfd, addr, addrlen); +#else + return call_real_connect(sockfd, addr, addrlen); +#endif } - char *debug_mode = getenv("WRAPGUARD_DEBUG"); - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: connect() called for %s\n", addr_str); + char addr_str[INET6_ADDRSTRLEN + 32]; + format_sockaddr(addr, addr_str, sizeof(addr_str)); + int suppress_debug_log = addr->sa_family == AF_UNIX && sock_type == SOCK_DGRAM; + + if (debug_enabled() && !suppress_debug_log) { + log_debugf("connect() called for %s family=%s type=%s", addr_str, family_name(addr->sa_family), socket_type_name(sock_type)); + } + + if (sock_type == SOCK_DGRAM && should_block_udp_target(addr)) { + if (debug_enabled()) { + log_debugf("Blocking likely QUIC UDP flow to %s", addr_str); + } + forget_virtual_peer(sockfd); + errno = EHOSTUNREACH; + return -1; + } + + if (sock_type != SOCK_STREAM) { + if (debug_enabled() && !suppress_debug_log) { + log_debugf("NOT intercepting %s because socket type is %s", addr_str, socket_type_name(sock_type)); + } + forget_virtual_peer(sockfd); +#ifdef __APPLE__ + return raw_connect_call(sockfd, addr, addrlen); +#else + return call_real_connect(sockfd, addr, addrlen); +#endif } if (!should_intercept_connect(addr)) { - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: NOT intercepting %s\n", addr_str); + if (debug_enabled() && !suppress_debug_log) { + log_debugf("NOT intercepting %s family=%s", addr_str, family_name(addr->sa_family)); } - return real_connect(sockfd, addr, addrlen); + forget_virtual_peer(sockfd); +#ifdef __APPLE__ + return raw_connect_call(sockfd, addr, addrlen); +#else + return call_real_connect(sockfd, addr, addrlen); +#endif } - if (debug_mode && strcmp(debug_mode, "1") == 0) { - fprintf(stderr, "WrapGuard LD_PRELOAD: INTERCEPTING %s, routing through SOCKS5\n", addr_str); + if (debug_enabled()) { + log_debugf("INTERCEPTING %s, routing through SOCKS5", addr_str); } - // Send IPC message send_ipc_message("CONNECT", sockfd, 0, addr_str); - - // Route through SOCKS5 return socks5_connect(sockfd, addr, addrlen); } -// Intercepted bind function -int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { +static int wrapguard_bind_impl(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { init_library(); + if (passthrough_mode_cached) { +#ifdef __APPLE__ + return raw_bind_call(sockfd, addr, addrlen); +#else + return call_real_bind(sockfd, addr, addrlen); +#endif + } + + if (addr == NULL || addrlen < (socklen_t)sizeof(addr->sa_family)) { + return call_real_bind(sockfd, addr, addrlen); + } - // Call original bind first - int result = real_bind(sockfd, addr, addrlen); + int result = call_real_bind(sockfd, addr, addrlen); - // If bind succeeded and it's a TCP socket, notify the main process - if (result == 0 && addr->sa_family == AF_INET) { - struct sockaddr_in *in_addr = (struct sockaddr_in *)addr; - int port = ntohs(in_addr->sin_port); + if (result == 0 && (addr->sa_family == AF_INET || addr->sa_family == AF_INET6)) { + int port = 0; + if (addr->sa_family == AF_INET) { + struct sockaddr_in *in_addr = (struct sockaddr_in *)addr; + port = ntohs(in_addr->sin_port); + } else { + struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)addr; + port = ntohs(in6_addr->sin6_port); + } - // Get the actual port if it was auto-assigned (port 0) if (port == 0) { - struct sockaddr_in actual_addr; + struct sockaddr_storage actual_addr; socklen_t actual_len = sizeof(actual_addr); if (getsockname(sockfd, (struct sockaddr *)&actual_addr, &actual_len) == 0) { - port = ntohs(actual_addr.sin_port); + if (actual_addr.ss_family == AF_INET) { + port = ntohs(((struct sockaddr_in *)&actual_addr)->sin_port); + } else if (actual_addr.ss_family == AF_INET6) { + port = ntohs(((struct sockaddr_in6 *)&actual_addr)->sin6_port); + } } } - // Check if it's a TCP socket - int sock_type; + int sock_type = 0; socklen_t opt_len = sizeof(sock_type); if (getsockopt(sockfd, SOL_SOCKET, SO_TYPE, &sock_type, &opt_len) == 0 && sock_type == SOCK_STREAM) { - // Send IPC message to set up port forwarding send_ipc_message("BIND", sockfd, port, NULL); } } return result; -} \ No newline at end of file +} + +#ifdef __APPLE__ +static int wrapguard_connectx_impl(int sockfd, const sa_endpoints_t *endpoints, sae_associd_t associd, unsigned int flags, const struct iovec *iov, unsigned int iovcnt, size_t *len, sae_connid_t *connid) { + init_library(); + if (passthrough_mode_cached) { + return raw_connectx_call(sockfd, endpoints, associd, flags, iov, iovcnt, len, connid); + } + + if (endpoints == NULL || endpoints->sae_dstaddr == NULL || endpoints->sae_dstaddrlen < (socklen_t)sizeof(sa_family_t)) { + return call_real_connectx(sockfd, endpoints, associd, flags, iov, iovcnt, len, connid); + } + + if (internal_connect_guard > 0) { + return call_real_connectx(sockfd, endpoints, associd, flags, iov, iovcnt, len, connid); + } + + if (endpoints->sae_srcif != 0 || endpoints->sae_srcaddr != NULL || endpoints->sae_srcaddrlen != 0 || iov != NULL || iovcnt != 0 || associd != SAE_ASSOCID_ANY || flags != 0) { + if (debug_enabled()) { + log_debugf("Falling back to real connectx() because advanced endpoints/options are in use"); + } + return call_real_connectx(sockfd, endpoints, associd, flags, iov, iovcnt, len, connid); + } + + if (len != NULL) { + *len = 0; + } + if (connid != NULL) { + *connid = SAE_CONNID_ANY; + } + + return wrapguard_connect_impl(sockfd, endpoints->sae_dstaddr, endpoints->sae_dstaddrlen); +} +#endif + + #ifndef __APPLE__ +static int wrapguard_getpeername_impl(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { + init_library(); + if (passthrough_mode_cached) { + return call_real_getpeername(sockfd, addr, addrlen); + } + + if (lookup_virtual_peer(sockfd, addr, addrlen)) { + return 0; + } + + return call_real_getpeername(sockfd, addr, addrlen); +} +#endif + +static int wrapguard_close_impl(int fd) { + if (initialized) { + forget_virtual_peer(fd); + } + return call_real_close(fd); +} + +#ifdef __APPLE__ +static int wrapguard_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { + return wrapguard_connect_impl(sockfd, addr, addrlen); +} + +static int wrapguard_bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { + return wrapguard_bind_impl(sockfd, addr, addrlen); +} + +static int wrapguard_connectx(int sockfd, const sa_endpoints_t *endpoints, sae_associd_t associd, unsigned int flags, const struct iovec *iov, unsigned int iovcnt, size_t *len, sae_connid_t *connid) { + return wrapguard_connectx_impl(sockfd, endpoints, associd, flags, iov, iovcnt, len, connid); +} + +static ssize_t wrapguard_sendto(int sockfd, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { + return wrapguard_sendto_impl(sockfd, buf, len, flags, dest_addr, addrlen); +} + +static ssize_t wrapguard_sendmsg(int sockfd, const struct msghdr *msg, int flags) { + return wrapguard_sendmsg_impl(sockfd, msg, flags); +} + +static int wrapguard_close(int fd) { + return wrapguard_close_impl(fd); +} + +DYLD_INTERPOSE(wrapguard_connect, connect) +DYLD_INTERPOSE(wrapguard_bind, bind) +DYLD_INTERPOSE(wrapguard_connectx, connectx) +DYLD_INTERPOSE(wrapguard_sendto, sendto) +DYLD_INTERPOSE(wrapguard_sendmsg, sendmsg) +DYLD_INTERPOSE(wrapguard_close, close) +#else +int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { + return wrapguard_connect_impl(sockfd, addr, addrlen); +} + +int bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { + return wrapguard_bind_impl(sockfd, addr, addrlen); +} + +int getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { + return wrapguard_getpeername_impl(sockfd, addr, addrlen); +} + +int close(int fd) { + return wrapguard_close_impl(fd); +} + +#endif diff --git a/logger.go b/logger.go index cc19c7a..6c58438 100644 --- a/logger.go +++ b/logger.go @@ -7,6 +7,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "time" ) @@ -102,14 +103,54 @@ func (l *Logger) Debugf(format string, args ...interface{}) { l.log(LogLevelDebug, format, args...) } +type globalLogger struct { + ptr atomic.Pointer[Logger] +} + +func (g *globalLogger) Load() *Logger { + return g.ptr.Load() +} + +func (g *globalLogger) Store(l *Logger) { + g.ptr.Store(l) +} + +func (g *globalLogger) Errorf(format string, args ...interface{}) { + if l := g.Load(); l != nil { + l.Errorf(format, args...) + } +} + +func (g *globalLogger) Warnf(format string, args ...interface{}) { + if l := g.Load(); l != nil { + l.Warnf(format, args...) + } +} + +func (g *globalLogger) Infof(format string, args ...interface{}) { + if l := g.Load(); l != nil { + l.Infof(format, args...) + } +} + +func (g *globalLogger) Debugf(format string, args ...interface{}) { + if l := g.Load(); l != nil { + l.Debugf(format, args...) + } +} + // Global logger instance -var logger *Logger +var logger globalLogger func init() { // Default logger to stderr with info level - logger = NewLogger(LogLevelInfo, os.Stderr) + logger.Store(NewLogger(LogLevelInfo, os.Stderr)) } func SetGlobalLogger(l *Logger) { - logger = l + logger.Store(l) +} + +func CurrentLogger() *Logger { + return logger.Load() } diff --git a/logger_test.go b/logger_test.go index c74ae88..f987e7f 100644 --- a/logger_test.go +++ b/logger_test.go @@ -304,7 +304,7 @@ func TestLogger_ConcurrentAccess(t *testing.T) { func TestSetGlobalLogger(t *testing.T) { // Save original logger - originalLogger := logger + originalLogger := CurrentLogger() // Create a new logger var buf bytes.Buffer @@ -314,7 +314,7 @@ func TestSetGlobalLogger(t *testing.T) { SetGlobalLogger(testLogger) // Verify it was set - if logger != testLogger { + if CurrentLogger() != testLogger { t.Error("global logger not set correctly") } @@ -324,15 +324,16 @@ func TestSetGlobalLogger(t *testing.T) { func TestGlobalLoggerInitialization(t *testing.T) { // The global logger should be initialized in init() - if logger == nil { + current := CurrentLogger() + if current == nil { t.Error("global logger not initialized") } - if logger.level != LogLevelInfo { - t.Errorf("expected default log level %v, got %v", LogLevelInfo, logger.level) + if current.level != LogLevelInfo { + t.Errorf("expected default log level %v, got %v", LogLevelInfo, current.level) } - if logger.output != os.Stderr { + if current.output != os.Stderr { t.Error("expected default output to be os.Stderr") } } diff --git a/main.go b/main.go index 98bb546..c41b994 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,6 @@ import ( "os" "os/exec" "os/signal" - "path/filepath" "strings" "syscall" "time" @@ -48,6 +47,9 @@ func printUsage() { help += " --route= Add routing policy (CIDR:peerIP)\n" help += " --log-level= Set log level (error, warn, info, debug)\n" help += " --log-file= Set file to write logs to (default: terminal)\n" + help += " --doctor [target] Run local runtime preflight checks\n" + help += " --self-test Launch a built-in injection self-test\n" + help += " --macos-gui-compat macOS only: do not let helper subprocesses inherit DYLD injection\n" help += " --help Show this help message\n" help += " --version Show version information\n\n" @@ -77,6 +79,10 @@ func main() { var configPath string var showHelp bool var showVersion bool + var doctorMode bool + var selfTestMode bool + var macOSGUICompat bool + var internalSelfTestProbe string var logLevelStr string var logFile string var exitNode string @@ -84,9 +90,13 @@ func main() { flag.StringVar(&configPath, "config", "", "Path to WireGuard configuration file") flag.BoolVar(&showHelp, "help", false, "Show help message") flag.BoolVar(&showVersion, "version", false, "Show version information") + flag.BoolVar(&doctorMode, "doctor", false, "Run local runtime preflight checks") + flag.BoolVar(&selfTestMode, "self-test", false, "Launch a built-in injection self-test") + flag.BoolVar(&macOSGUICompat, "macos-gui-compat", false, "macOS only: stop DYLD injection from being inherited by helper subprocesses") flag.StringVar(&logLevelStr, "log-level", "info", "Set log level (error, warn, info, debug)") flag.StringVar(&logFile, "log-file", "", "Set file to write logs to (default: terminal)") flag.StringVar(&exitNode, "exit-node", "", "Route all traffic through specified peer IP (e.g., 10.0.0.3)") + flag.StringVar(&internalSelfTestProbe, "internal-self-test-probe", "", "internal self-test probe") flag.Func("route", "Add routing policy (format: CIDR:peerIP, e.g., 192.168.1.0/24:10.0.0.3)", func(value string) error { routes = append(routes, value) return nil @@ -94,6 +104,10 @@ func main() { flag.Usage = printUsage flag.Parse() + if internalSelfTestProbe != "" { + os.Exit(runInternalSelfTestProbe(internalSelfTestProbe)) + } + if showVersion { fmt.Printf("wrapguard version %s\n", version) os.Exit(0) @@ -104,7 +118,7 @@ func main() { os.Exit(0) } - if configPath == "" { + if configPath == "" && !doctorMode { printUsage() os.Exit(1) } @@ -133,12 +147,47 @@ func main() { SetGlobalLogger(logger) args := flag.Args() - if len(args) == 0 { + if doctorMode { + execPath, err := os.Executable() + if err != nil { + logger.Errorf("Failed to get executable path: %v", err) + os.Exit(1) + } + + target := "" + if len(args) > 0 { + target = args[0] + } + os.Exit(runDoctor(execPath, target, os.Stdout)) + } + + if len(args) == 0 && !selfTestMode { fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m No command specified\n") printUsage() os.Exit(1) } + execPath, err := os.Executable() + if err != nil { + fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m Failed to get executable path: %v\n", err) + os.Exit(1) + } + + libPath, injectCfg, err := resolveInjectedLibraryPath(execPath) + if err != nil { + fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m Failed to resolve injection library: %v\n", err) + os.Exit(1) + } + + var launchDetails *launchTargetDetails + if !selfTestMode && len(args) > 0 { + launchDetails, err = validateLaunchTargetWithLibrary(args[0], libPath) + if err != nil { + fmt.Fprintf(os.Stderr, "\n\033[31m✗ Error:\033[0m Launch target is not supported on this platform: %v\n", err) + os.Exit(1) + } + } + // Parse WireGuard configuration config, err := ParseConfig(configPath) if err != nil { @@ -154,7 +203,7 @@ func main() { } } - // Create IPC server for communication with LD_PRELOAD library + // Create IPC server for communication with the injected library. ipcServer, err := NewIPCServer() if err != nil { logger.Errorf("Failed to start IPC server: %v", err) @@ -165,6 +214,8 @@ func main() { // Create context for cancellation ctx, cancel := context.WithCancel(context.Background()) defer cancel() + stopIPCLogger := startIPCEventLogger(ctx, ipcServer, logLevel == LogLevelDebug) + defer stopIPCLogger() // Start WireGuard tunnel logger.Infof("Creating WireGuard tunnel...") @@ -197,28 +248,39 @@ func main() { if len(config.Peers) > 0 { logger.Infof("Peer endpoint: %s", config.Peers[0].Endpoint) } - logger.Infof("Launching: [%s]", strings.Join(args, " ")) + if !selfTestMode { + logger.Infof("Launching: [%s]", strings.Join(args, " ")) + } - // Get path to our LD_PRELOAD library - execPath, err := os.Executable() - if err != nil { - logger.Errorf("Failed to get executable path: %v", err) - os.Exit(1) + logger.Infof("Injection mode: %s", injectCfg.LibraryEnvVar) + logger.Infof("Injection library: %s", libPath) + + if selfTestMode { + os.Exit(runSelfTest(ctx, ipcServer, socksServer, execPath, libPath, injectCfg, logLevel == LogLevelDebug)) } - libPath := filepath.Join(filepath.Dir(execPath), "libwrapguard.so") // Prepare child process - cmd := exec.Command(args[0], args[1:]...) + launchTarget := args[0] + if launchDetails != nil && launchDetails.ResolvedPath != "" { + launchTarget = launchDetails.ResolvedPath + } + cmd := exec.Command(launchTarget, args[1:]...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + cmd.SysProcAttr = childSysProcAttr() + + cmd.Env = buildChildEnv(os.Environ(), injectCfg, libPath, ipcServer.SocketPath(), socksServer.Port(), logLevel == LogLevelDebug, macOSGUICompat) + logger.Debugf("Child environment prepared with %s=%s", injectCfg.LibraryEnvVar, libPath) + logger.Debugf("Child environment prepared with %s=%s", envWrapGuardIPCPath, ipcServer.SocketPath()) + logger.Debugf("Child environment prepared with %s=%d", envWrapGuardSOCKSPort, socksServer.Port()) + if macOSGUICompat && currentPlatformName() == "darwin" { + logger.Infof("macOS GUI compatibility mode enabled: helper subprocesses will not inherit WrapGuard DYLD injection") + logger.Debugf("Child environment prepared with %s=1", envWrapGuardNoInherit) + } - // Set LD_PRELOAD and IPC socket path - cmd.Env = append(os.Environ(), - fmt.Sprintf("LD_PRELOAD=%s", libPath), - fmt.Sprintf("WRAPGUARD_IPC_PATH=%s", ipcServer.SocketPath()), - fmt.Sprintf("WRAPGUARD_SOCKS_PORT=%d", socksServer.Port()), - ) + readySubID, readyCh := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(readySubID) // Start the child process if err := cmd.Start(); err != nil { @@ -226,40 +288,32 @@ func main() { os.Exit(1) } - // Handle signals - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - // Wait for child process or signal done := make(chan error, 1) go func() { done <- cmd.Wait() }() - select { - case err := <-done: - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - os.Exit(exitErr.ExitCode()) - } - logger.Errorf("Child process error: %v", err) - os.Exit(1) - } - // Exit cleanly when child process completes successfully - os.Exit(0) - case sig := <-sigChan: - logger.Infof("Received signal %v, shutting down...", sig) - // Forward signal to child process - if cmd.Process != nil { - cmd.Process.Signal(sig) - } - // Wait for child to exit - select { - case <-done: - case <-time.After(5 * time.Second): - logger.Warnf("Child process did not exit gracefully, killing...") - cmd.Process.Kill() - } + readyTimeout := initialHandshakeTimeout(currentPlatformName(), launchTarget) + readyMsg, err := waitForIPCMessage(readyCh, done, readyTimeout, "READY") + if err != nil { + cancel() + _ = socksServer.Close() + _ = ipcServer.Close() + _ = signalWrappedProcess(cmd, syscall.SIGKILL) + logger.Errorf("Injected library handshake failed: %v", err) os.Exit(1) } + + logger.Infof("Interceptor handshake completed from pid %d", readyMsg.PID) + + // Handle signals + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + exitCode := waitForWrappedCommand(cmd, done, sigChan, func() { + cancel() + _ = socksServer.Close() + _ = ipcServer.Close() + }, 5*time.Second) + os.Exit(exitCode) } diff --git a/packaging_regression_test.go b/packaging_regression_test.go new file mode 100644 index 0000000..7e81caf --- /dev/null +++ b/packaging_regression_test.go @@ -0,0 +1,77 @@ +package main + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestReleaseWorkflowPackagesExpectedMacOSArtifacts(t *testing.T) { + data, err := os.ReadFile(filepath.Join(".github", "workflows", "release.yml")) + if err != nil { + t.Fatalf("failed to read release workflow: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + `archive="wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz"`, + `wrapguard libwrapguard.dylib`, + `test -f "$verify_dir/libwrapguard.dylib"`, + `name: Verify macOS release archives`, + `uses: actions/download-artifact@v4`, + `needs: verify-macos-release-archives`, + `asset_name: wrapguard-${{ github.event.release.tag_name }}-darwin-${{ matrix.arch }}.tar.gz`, + } + + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("release workflow missing required macOS archive snippet: %q", snippet) + } + } +} + +func TestSmokeMacOSTargetValidatesExpectedRuntimeArtifacts(t *testing.T) { + data, err := os.ReadFile("Makefile") + if err != nil { + t.Fatalf("failed to read Makefile: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + `cp "$(DIST_DIR)/darwin-$(TARGET_GOARCH)/$(BINARY_NAME)" "$$package_dir/";`, + `cp "$(DIST_DIR)/darwin-$(TARGET_GOARCH)/$(LIBRARY_NAME)" "$$package_dir/";`, + `tar -C "$$package_dir" -czf "$$staging/$(BINARY_NAME)-macos-smoke.tar.gz" $(BINARY_NAME) $(LIBRARY_NAME) README.md example-wg0.conf;`, + `test -f "$$verify_dir/$(LIBRARY_NAME)";`, + `build-macos-universal`, + `lipo -create "$$stage_dir/amd64/$(BINARY_NAME)" "$$stage_dir/arm64/$(BINARY_NAME)" -output "$$final_dir/$(BINARY_NAME)";`, + } + + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("Makefile missing required macOS smoke packaging snippet: %q", snippet) + } + } +} + +func TestMacOSReleaseNotesTemplateDocumentsSupportMatrix(t *testing.T) { + data, err := os.ReadFile(filepath.Join("docs", "release-notes-macos.md")) + if err != nil { + t.Fatalf("failed to read macOS release notes template: %v", err) + } + + content := string(data) + requiredSnippets := []string{ + `## Support Matrix`, + `macOS 14 Sonoma`, + `macOS 15 Sequoia`, + `## Known Limitations`, + `## Example Commands`, + } + + for _, snippet := range requiredSnippets { + if !strings.Contains(content, snippet) { + t.Fatalf("release notes template missing required snippet: %q", snippet) + } + } +} diff --git a/platform.go b/platform.go new file mode 100644 index 0000000..b692c3a --- /dev/null +++ b/platform.go @@ -0,0 +1,561 @@ +package main + +import ( + "bufio" + "debug/macho" + "errors" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "slices" + "strings" + "time" +) + +const ( + envWrapGuardIPCPath = "WRAPGUARD_IPC_PATH" + envWrapGuardSOCKSPort = "WRAPGUARD_SOCKS_PORT" + envWrapGuardDebug = "WRAPGUARD_DEBUG" + envWrapGuardDebugIPC = "WRAPGUARD_DEBUG_IPC" + envWrapGuardBlockUDP = "WRAPGUARD_BLOCK_UDP_443" + envWrapGuardNoInherit = "WRAPGUARD_MACOS_NO_INHERIT" + envWrapGuardExpectRDY = "WRAPGUARD_EXPECT_READY" +) + +type injectionConfig struct { + LibraryName string + LibraryEnvVar string + RequiresFlatNamespace bool +} + +func currentInjectionConfig() (injectionConfig, error) { + return injectionConfigForGOOS(runtime.GOOS) +} + +func currentPlatformName() string { + return runtime.GOOS +} + +func injectionConfigForGOOS(goos string) (injectionConfig, error) { + switch goos { + case "linux": + return injectionConfig{ + LibraryName: "libwrapguard.so", + LibraryEnvVar: "LD_PRELOAD", + }, nil + case "darwin": + return injectionConfig{ + LibraryName: "libwrapguard.dylib", + LibraryEnvVar: "DYLD_INSERT_LIBRARIES", + }, nil + default: + return injectionConfig{}, fmt.Errorf("unsupported platform: %s", goos) + } +} + +func resolveInjectedLibraryPath(execPath string) (string, injectionConfig, error) { + cfg, err := currentInjectionConfig() + if err != nil { + return "", injectionConfig{}, err + } + + candidateDirs := []string{filepath.Dir(execPath)} + if invokedPath, err := exec.LookPath(os.Args[0]); err == nil { + candidateDirs = append(candidateDirs, filepath.Dir(invokedPath)) + } + + var statErrs []string + seen := make(map[string]struct{}, len(candidateDirs)) + for _, dir := range candidateDirs { + if dir == "" { + continue + } + dir = filepath.Clean(dir) + if _, ok := seen[dir]; ok { + continue + } + seen[dir] = struct{}{} + + libPath := filepath.Join(dir, cfg.LibraryName) + if _, err := os.Stat(libPath); err == nil { + return libPath, cfg, nil + } else if os.IsNotExist(err) { + statErrs = append(statErrs, libPath) + continue + } else { + return "", injectionConfig{}, fmt.Errorf("failed to stat injection library %s: %w", libPath, err) + } + } + + return "", injectionConfig{}, fmt.Errorf("required injection library not found; searched: %s", strings.Join(statErrs, ", ")) +} + +func buildChildEnv(baseEnv []string, cfg injectionConfig, libraryPath, ipcPath string, socksPort int, debug bool, macOSNoInherit bool) []string { + envMap := make(map[string]string, len(baseEnv)+6) + envOrder := make([]string, 0, len(baseEnv)+6) + + for _, entry := range baseEnv { + parts := strings.SplitN(entry, "=", 2) + key := parts[0] + value := "" + if len(parts) == 2 { + value = parts[1] + } + if _, exists := envMap[key]; !exists { + envOrder = append(envOrder, key) + } + envMap[key] = value + } + + setEnv := func(key, value string) { + if _, exists := envMap[key]; !exists { + envOrder = append(envOrder, key) + } + envMap[key] = value + } + unsetEnv := func(key string) { + delete(envMap, key) + } + + setEnv(cfg.LibraryEnvVar, mergeInjectionLibraryValue(cfg, envMap[cfg.LibraryEnvVar], libraryPath)) + if cfg.RequiresFlatNamespace { + setEnv("DYLD_FORCE_FLAT_NAMESPACE", "1") + } else if cfg.LibraryEnvVar == "DYLD_INSERT_LIBRARIES" { + unsetEnv("DYLD_FORCE_FLAT_NAMESPACE") + } + setEnv(envWrapGuardExpectRDY, "1") + setEnv(envWrapGuardIPCPath, ipcPath) + setEnv(envWrapGuardSOCKSPort, fmt.Sprintf("%d", socksPort)) + if cfg.LibraryEnvVar == "DYLD_INSERT_LIBRARIES" { + setEnv(envWrapGuardBlockUDP, "1") + if macOSNoInherit { + setEnv(envWrapGuardNoInherit, "1") + } + } + if debug { + setEnv(envWrapGuardDebug, "1") + if cfg.LibraryEnvVar == "DYLD_INSERT_LIBRARIES" { + setEnv(envWrapGuardDebugIPC, "1") + } + } + + result := make([]string, 0, len(envOrder)) + for _, key := range envOrder { + value, ok := envMap[key] + if !ok { + continue + } + result = append(result, fmt.Sprintf("%s=%s", key, value)) + } + + return result +} + +func initialHandshakeTimeout(goos, requestedTarget string) time.Duration { + if goos != "darwin" { + return 3 * time.Second + } + + target := strings.TrimSpace(requestedTarget) + if strings.Contains(target, ".app/") || strings.HasSuffix(target, ".app") { + return 15 * time.Second + } + + return 3 * time.Second +} + +func mergeInjectionLibraryValue(cfg injectionConfig, existingValue, libraryPath string) string { + if strings.TrimSpace(existingValue) == "" { + return libraryPath + } + + separator := ":" + if cfg.LibraryEnvVar == "LD_PRELOAD" { + separator = " " + } + + for _, entry := range strings.FieldsFunc(existingValue, func(r rune) bool { + return r == ':' || r == ' ' || r == '\t' + }) { + if entry == libraryPath { + return existingValue + } + } + + return libraryPath + separator + existingValue +} + +type launchTargetDetails struct { + RequestedPath string + ResolvedPath string + InjectionTargetPath string + UsedInterpreter bool + InterpreterPath string +} + +type launchTargetSecurityInfo struct { + SigningStatus string + HardenedRuntime string + InspectionNotice string +} + +func validateLaunchTargetWithLibrary(command, libraryPath string) (*launchTargetDetails, error) { + if runtime.GOOS != "darwin" { + return &launchTargetDetails{RequestedPath: command}, nil + } + + details := &launchTargetDetails{ + RequestedPath: command, + } + + resolvedPath := command + var err error + if strings.HasSuffix(command, ".app") { + resolvedPath, err = resolveAppBundleExecutablePath(command) + if err != nil { + return nil, err + } + details.ResolvedPath = resolvedPath + } else { + resolvedPath, err = exec.LookPath(command) + if err != nil { + return nil, fmt.Errorf("failed to resolve launch target %q: %w", command, err) + } + } + + if !filepath.IsAbs(resolvedPath) { + resolvedPath, err = filepath.Abs(resolvedPath) + if err != nil { + return nil, fmt.Errorf("failed to resolve launch target path: %w", err) + } + } + + details.InjectionTargetPath = resolvedPath + + if interpreterPath, ok, err := resolveScriptInterpreter(resolvedPath); err != nil { + return nil, err + } else if ok { + details.UsedInterpreter = true + details.InterpreterPath = interpreterPath + details.InjectionTargetPath = interpreterPath + } + + protectedPrefixes := []string{ + "/System/", + "/bin/", + "/sbin/", + "/usr/bin/", + "/usr/libexec/", + } + for _, prefix := range protectedPrefixes { + if strings.HasPrefix(details.InjectionTargetPath, prefix) { + if details.UsedInterpreter { + return nil, fmt.Errorf("launch target %s uses SIP-protected interpreter %s and cannot be wrapped via DYLD injection", resolvedPath, details.InjectionTargetPath) + } + return nil, fmt.Errorf("launch target %s is protected by macOS SIP and cannot be wrapped via DYLD injection", details.InjectionTargetPath) + } + } + + if libraryPath != "" { + targetArchs, err := machOArchitectures(details.InjectionTargetPath) + if err != nil { + return nil, fmt.Errorf("failed to inspect launch target architecture for %s: %w", details.InjectionTargetPath, err) + } + + libraryArchs, err := machOArchitectures(libraryPath) + if err != nil { + return nil, fmt.Errorf("failed to inspect injection library architecture for %s: %w", libraryPath, err) + } + + if !archSetsOverlap(targetArchs, libraryArchs) { + return nil, fmt.Errorf( + "launch target architecture %s is incompatible with injection library architecture %s", + strings.Join(targetArchs, ", "), + strings.Join(libraryArchs, ", "), + ) + } + } + + return details, nil +} + +func validateLaunchTarget(command string) error { + _, err := validateLaunchTargetWithLibrary(command, "") + return err +} + +func inspectLaunchTargetSecurityInfo(targetPath, codesignPath string) (launchTargetSecurityInfo, error) { + if codesignPath == "" { + var err error + codesignPath, err = exec.LookPath("codesign") + if err != nil { + return launchTargetSecurityInfo{ + SigningStatus: "unknown", + HardenedRuntime: "unknown", + }, fmt.Errorf("codesign tool not found: %w", err) + } + } + + cmd := exec.Command(codesignPath, "-dv", "--verbose=4", targetPath) + output, err := cmd.CombinedOutput() + info := parseLaunchTargetSecurityInfo(string(output)) + if info.SigningStatus == "" { + info.SigningStatus = "unknown" + } + if info.HardenedRuntime == "" { + info.HardenedRuntime = "unknown" + } + + lowerOutput := strings.ToLower(string(output)) + if err == nil { + return info, nil + } + + if strings.Contains(lowerOutput, "code object is not signed at all") { + if info.SigningStatus == "unknown" { + info.SigningStatus = "unsigned" + } + if info.HardenedRuntime == "unknown" { + info.HardenedRuntime = "disabled" + } + return info, nil + } + + if info.SigningStatus != "unknown" || info.HardenedRuntime != "unknown" { + return info, nil + } + + return info, fmt.Errorf("failed to inspect code signature metadata: %w", err) +} + +func parseLaunchTargetSecurityInfo(output string) launchTargetSecurityInfo { + lowerOutput := strings.ToLower(output) + info := launchTargetSecurityInfo{ + SigningStatus: "unknown", + HardenedRuntime: "unknown", + } + + switch { + case strings.Contains(lowerOutput, "code object is not signed at all"): + info.SigningStatus = "unsigned" + info.HardenedRuntime = "disabled" + case strings.Contains(lowerOutput, "signature=adhoc"): + info.SigningStatus = "ad-hoc" + case strings.Contains(lowerOutput, "authority="): + info.SigningStatus = "signed" + } + + if strings.Contains(lowerOutput, "flags=") { + if strings.Contains(lowerOutput, "runtime") { + info.HardenedRuntime = "enabled" + if info.SigningStatus == "signed" { + info.InspectionNotice = "DYLD injection may still be rejected at runtime by the target's hardened runtime policy" + } + } else if info.HardenedRuntime == "unknown" { + info.HardenedRuntime = "disabled" + } + } + + if info.SigningStatus == "ad-hoc" && info.HardenedRuntime == "unknown" { + info.HardenedRuntime = "disabled" + } + + return info +} + +func reportLaunchTargetSecurityInfo(output io.Writer, targetPath, codesignPath string) error { + info, err := inspectLaunchTargetSecurityInfo(targetPath, codesignPath) + fmt.Fprintf(output, "doctor: target-signing=%s\n", info.SigningStatus) + fmt.Fprintf(output, "doctor: target-hardened-runtime=%s\n", info.HardenedRuntime) + if info.InspectionNotice != "" { + fmt.Fprintf(output, "doctor: advisory: %s\n", info.InspectionNotice) + } + return err +} + +func resolveAppBundleExecutablePath(bundlePath string) (string, error) { + absBundlePath, err := filepath.Abs(bundlePath) + if err != nil { + return "", fmt.Errorf("failed to resolve app bundle path %s: %w", bundlePath, err) + } + + info, err := os.Stat(absBundlePath) + if err != nil { + return "", fmt.Errorf("failed to inspect app bundle %s: %w", absBundlePath, err) + } + if !info.IsDir() { + return "", fmt.Errorf("%s is not a macOS app bundle directory", absBundlePath) + } + + macOSDir := filepath.Join(absBundlePath, "Contents", "MacOS") + entries, err := os.ReadDir(macOSDir) + if err != nil { + return "", fmt.Errorf("failed to inspect app bundle executable directory %s: %w", macOSDir, err) + } + + baseName := strings.TrimSuffix(filepath.Base(absBundlePath), ".app") + var candidatePath string + candidateNames := make([]string, 0, len(entries)) + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + entryInfo, err := entry.Info() + if err != nil { + return "", fmt.Errorf("failed to inspect app bundle executable %s: %w", filepath.Join(macOSDir, entry.Name()), err) + } + if entryInfo.Mode()&0o111 == 0 { + continue + } + + candidateNames = append(candidateNames, entry.Name()) + fullPath := filepath.Join(macOSDir, entry.Name()) + if entry.Name() == baseName { + return fullPath, nil + } + if candidatePath == "" { + candidatePath = fullPath + } + } + + if len(candidateNames) == 1 { + return candidatePath, nil + } + if len(candidateNames) > 1 { + slices.Sort(candidateNames) + return "", fmt.Errorf( + "app bundle %s has multiple executable candidates in Contents/MacOS: %s", + absBundlePath, + strings.Join(candidateNames, ", "), + ) + } + + return "", fmt.Errorf("app bundle %s does not contain an executable in Contents/MacOS", absBundlePath) +} + +func resolveScriptInterpreter(path string) (string, bool, error) { + file, err := os.Open(path) + if err != nil { + return "", false, fmt.Errorf("failed to inspect launch target %s: %w", path, err) + } + defer file.Close() + + reader := bufio.NewReader(file) + line, err := reader.ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return "", false, fmt.Errorf("failed to read launch target %s: %w", path, err) + } + if !strings.HasPrefix(line, "#!") { + return "", false, nil + } + + fields := strings.Fields(strings.TrimSpace(strings.TrimPrefix(line, "#!"))) + if len(fields) == 0 { + return "", false, nil + } + + interpreter := fields[0] + if !filepath.IsAbs(interpreter) { + resolved, err := exec.LookPath(interpreter) + if err != nil { + return "", false, fmt.Errorf("failed to resolve script interpreter %q for %s: %w", interpreter, path, err) + } + interpreter = resolved + } + + if filepath.Base(interpreter) == "env" { + for _, arg := range fields[1:] { + if strings.HasPrefix(arg, "-") { + continue + } + resolved, err := exec.LookPath(arg) + if err == nil { + interpreter = resolved + } + break + } + } + + interpreter, err = filepath.Abs(interpreter) + if err != nil { + return "", false, fmt.Errorf("failed to resolve interpreter path for %s: %w", path, err) + } + + return interpreter, true, nil +} + +func machOArchitectures(path string) ([]string, error) { + if fat, err := macho.OpenFat(path); err == nil { + defer fat.Close() + + archs := make([]string, 0, len(fat.Arches)) + for _, arch := range fat.Arches { + archName := machoCPUArchName(arch.Cpu) + if archName == "" { + archName = fmt.Sprintf("cpu-%d", arch.Cpu) + } + archs = append(archs, archName) + } + return compactArchitectures(archs), nil + } + + file, err := macho.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + archName := machoCPUArchName(file.Cpu) + if archName == "" { + archName = fmt.Sprintf("cpu-%d", file.Cpu) + } + return []string{archName}, nil +} + +func machoCPUArchName(cpu macho.Cpu) string { + switch cpu { + case macho.CpuAmd64: + return "amd64" + case macho.CpuArm64: + return "arm64" + default: + return "" + } +} + +func compactArchitectures(archs []string) []string { + if len(archs) == 0 { + return nil + } + + seen := make(map[string]struct{}, len(archs)) + result := make([]string, 0, len(archs)) + for _, arch := range archs { + if arch == "" { + continue + } + if _, ok := seen[arch]; ok { + continue + } + seen[arch] = struct{}{} + result = append(result, arch) + } + slices.Sort(result) + return result +} + +func archSetsOverlap(left, right []string) bool { + for _, lhs := range left { + for _, rhs := range right { + if lhs == rhs { + return true + } + } + } + return false +} diff --git a/process_group_other.go b/process_group_other.go new file mode 100644 index 0000000..e47a714 --- /dev/null +++ b/process_group_other.go @@ -0,0 +1,19 @@ +//go:build !linux && !darwin + +package main + +import ( + "os/exec" + "syscall" +) + +func childSysProcAttr() *syscall.SysProcAttr { + return nil +} + +func signalWrappedProcess(cmd *exec.Cmd, sig syscall.Signal) error { + if cmd == nil || cmd.Process == nil { + return nil + } + return cmd.Process.Signal(sig) +} diff --git a/process_group_unix.go b/process_group_unix.go new file mode 100644 index 0000000..b2b9934 --- /dev/null +++ b/process_group_unix.go @@ -0,0 +1,25 @@ +//go:build linux || darwin + +package main + +import ( + "os/exec" + "syscall" +) + +func childSysProcAttr() *syscall.SysProcAttr { + return &syscall.SysProcAttr{Setpgid: true} +} + +func signalWrappedProcess(cmd *exec.Cmd, sig syscall.Signal) error { + if cmd == nil || cmd.Process == nil { + return nil + } + + pgid, err := syscall.Getpgid(cmd.Process.Pid) + if err == nil && pgid > 0 { + return syscall.Kill(-pgid, sig) + } + + return cmd.Process.Signal(sig) +} diff --git a/routing.go b/routing.go index 04805ad..9a52e1d 100644 --- a/routing.go +++ b/routing.go @@ -42,7 +42,7 @@ func NewRoutingEngine(config *WireGuardConfig) *RoutingEngine { for _, allowedIP := range peer.AllowedIPs { prefix, err := netip.ParsePrefix(allowedIP) if err != nil { - if logger != nil { + if CurrentLogger() != nil { logger.Warnf("Invalid AllowedIP %s for peer %d: %v", allowedIP, peerIdx, err) } continue diff --git a/runtime_helpers.go b/runtime_helpers.go new file mode 100644 index 0000000..910a79b --- /dev/null +++ b/runtime_helpers.go @@ -0,0 +1,336 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "strings" + "syscall" + "time" +) + +const selfTestProbeTarget = "203.0.113.1:443" + +func startIPCEventLogger(ctx context.Context, server *IPCServer, enabled bool) func() { + if !enabled || server == nil { + return func() {} + } + + subID, ch := server.Subscribe() + done := make(chan struct{}) + + go func() { + defer close(done) + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-ch: + if !ok { + return + } + switch msg.Type { + case "READY": + logger.Debugf("Interceptor READY from pid %d", msg.PID) + case "CONNECT": + logger.Debugf("Interceptor CONNECT from pid %d to %s", msg.PID, msg.Addr) + case "BIND": + logger.Debugf("Interceptor BIND from pid %d on port %d", msg.PID, msg.Port) + case "DEBUG": + if msg.Detail != "" { + logger.Debugf("Interceptor DEBUG from pid %d: %s", msg.PID, msg.Detail) + } else { + logger.Debugf("Interceptor DEBUG from pid %d addr=%s port=%d", msg.PID, msg.Addr, msg.Port) + } + case "UDP_BLOCK": + logger.Debugf("Interceptor UDP_BLOCK from pid %d to %s (%s)", msg.PID, msg.Addr, msg.Detail) + case "UDP_SEND": + logger.Debugf("Interceptor UDP_SEND from pid %d to %s (%s)", msg.PID, msg.Addr, msg.Detail) + case "ERROR": + if msg.Detail != "" { + logger.Warnf("Interceptor ERROR from pid %d: %s", msg.PID, msg.Detail) + } else { + logger.Warnf("Interceptor ERROR from pid %d: %s", msg.PID, msg.Addr) + } + default: + logger.Debugf("IPC event %s from pid %d addr=%s port=%d detail=%s", msg.Type, msg.PID, msg.Addr, msg.Port, msg.Detail) + } + } + } + }() + + return func() { + server.Unsubscribe(subID) + <-done + } +} + +func waitForIPCMessage(msgCh <-chan IPCMessage, done <-chan error, timeout time.Duration, wantType string) (IPCMessage, error) { + timer := time.NewTimer(timeout) + defer timer.Stop() + + for { + select { + case msg, ok := <-msgCh: + if !ok { + return IPCMessage{}, fmt.Errorf("ipc subscriber closed while waiting for %s", wantType) + } + if msg.Type == wantType { + return msg, nil + } + case err := <-done: + for { + select { + case msg, ok := <-msgCh: + if !ok { + goto childExit + } + if msg.Type == wantType { + return msg, nil + } + default: + goto childExit + } + } + childExit: + if err != nil { + return IPCMessage{}, fmt.Errorf("child exited before %s: %w", wantType, err) + } + return IPCMessage{}, fmt.Errorf("child exited before %s", wantType) + case <-timer.C: + return IPCMessage{}, fmt.Errorf("timed out waiting for %s", wantType) + } + } +} + +func waitForWrappedCommand(cmd *exec.Cmd, done <-chan error, sigCh <-chan os.Signal, onTerminate func(), gracePeriod time.Duration) int { + select { + case err := <-done: + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return exitErr.ExitCode() + } + logger.Errorf("Child process error: %v", err) + return 1 + } + return 0 + case sig := <-sigCh: + logger.Infof("Received signal %v, shutting down...", sig) + if onTerminate != nil { + onTerminate() + } + + sysSig, ok := sig.(syscall.Signal) + if !ok { + sysSig = syscall.SIGTERM + } + if err := signalWrappedProcess(cmd, sysSig); err != nil && !errors.Is(err, os.ErrProcessDone) { + logger.Warnf("Failed to forward signal %v to child: %v", sig, err) + } + + select { + case <-done: + case <-time.After(gracePeriod): + logger.Warnf("Child process did not exit gracefully, killing...") + _ = signalWrappedProcess(cmd, syscall.SIGKILL) + <-done + } + + return 1 + } +} + +func probeIPCReachability(socketPath string) error { + conn, err := net.DialTimeout("unix", socketPath, time.Second) + if err != nil { + return err + } + return conn.Close() +} + +func probeSOCKSReachability(port int) error { + conn, err := net.DialTimeout("tcp", net.JoinHostPort("127.0.0.1", fmt.Sprintf("%d", port)), time.Second) + if err != nil { + return err + } + defer conn.Close() + + if _, err := conn.Write([]byte{0x05, 0x01, 0x00}); err != nil { + return err + } + + reply := make([]byte, 2) + if _, err := io.ReadFull(conn, reply); err != nil { + return err + } + if reply[0] != 0x05 { + return fmt.Errorf("unexpected SOCKS version byte %d", reply[0]) + } + + return nil +} + +func runDoctor(execPath, launchTarget string, output io.Writer) int { + if output == nil { + output = os.Stdout + } + + libPath, injectCfg, err := resolveInjectedLibraryPath(execPath) + if err != nil { + fmt.Fprintf(output, "doctor: runtime library check failed: %v\n", err) + return 1 + } + + fmt.Fprintf(output, "doctor: platform=%s injection=%s library=%s\n", currentPlatformName(), injectCfg.LibraryEnvVar, libPath) + + if launchTarget == "" { + fmt.Fprintln(output, "doctor: no launch target supplied; preflight completed for local runtime artifacts only") + return 0 + } + + resolvedTarget := launchTarget + var details *launchTargetDetails + if currentPlatformName() == "darwin" && strings.HasSuffix(launchTarget, ".app") { + details, err = validateLaunchTargetWithLibrary(launchTarget, libPath) + if err != nil { + fmt.Fprintf(output, "doctor: launch target unsupported: %v\n", err) + return 1 + } + if details != nil && details.ResolvedPath != "" { + resolvedTarget = details.ResolvedPath + } + } else { + if lookupPath, err := exec.LookPath(launchTarget); err == nil { + resolvedTarget = lookupPath + } else { + fmt.Fprintf(output, "doctor: target lookup failed: %v\n", err) + return 1 + } + + details, err = validateLaunchTargetWithLibrary(launchTarget, libPath) + if err != nil { + fmt.Fprintf(output, "doctor: launch target unsupported: %v\n", err) + return 1 + } + } + + fmt.Fprintf(output, "doctor: target=%s\n", resolvedTarget) + if details != nil && details.UsedInterpreter { + fmt.Fprintf(output, "doctor: script interpreter=%s\n", details.InterpreterPath) + } + if currentPlatformName() == "darwin" { + reportDarwinLaunchTargetAdvisories(output, details) + } + if currentPlatformName() == "darwin" && details != nil && details.InjectionTargetPath != "" { + if targetArchs, err := machOArchitectures(details.InjectionTargetPath); err == nil && len(targetArchs) > 0 { + fmt.Fprintf(output, "doctor: target-arch=%s\n", strings.Join(targetArchs, ",")) + } + if libraryArchs, err := machOArchitectures(libPath); err == nil && len(libraryArchs) > 0 { + fmt.Fprintf(output, "doctor: library-arch=%s\n", strings.Join(libraryArchs, ",")) + } + } + + if currentPlatformName() == "darwin" { + if err := reportLaunchTargetSecurityInfo(output, resolvedTarget, ""); err != nil { + fmt.Fprintf(output, "doctor: advisory: failed to inspect code signature metadata: %v\n", err) + } + } + + fmt.Fprintln(output, "doctor: launch target passed preflight") + return 0 +} + +func reportDarwinLaunchTargetAdvisories(output io.Writer, details *launchTargetDetails) { + if output == nil || details == nil { + return + } + + if strings.HasSuffix(details.RequestedPath, ".app") && details.ResolvedPath != "" { + fmt.Fprintf(output, "doctor: app-bundle-resolved=%s\n", details.ResolvedPath) + } + + injectionTarget := details.InjectionTargetPath + if injectionTarget == "" { + injectionTarget = details.ResolvedPath + } + if strings.Contains(injectionTarget, ".app/Contents/MacOS/") { + fmt.Fprintln(output, "doctor: advisory: macOS GUI launches are experimental and only supported through the directly launched inner executable path") + fmt.Fprintln(output, "doctor: advisory: if this app hands work off to an already-running session or external launcher, WrapGuard will not control the real process tree") + } +} + +func runSelfTest(ctx context.Context, ipcServer *IPCServer, socksServer *SOCKS5Server, execPath, libPath string, injectCfg injectionConfig, debug bool) int { + if err := probeIPCReachability(ipcServer.SocketPath()); err != nil { + logger.Errorf("Self-test failed: IPC socket is not reachable: %v", err) + return 1 + } + logger.Infof("Self-test check passed: IPC socket is reachable") + + if err := probeSOCKSReachability(socksServer.Port()); err != nil { + logger.Errorf("Self-test failed: SOCKS listener is not reachable: %v", err) + return 1 + } + logger.Infof("Self-test check passed: SOCKS listener is reachable") + + subID, events := ipcServer.Subscribe() + defer ipcServer.Unsubscribe(subID) + + cmd := exec.CommandContext(ctx, execPath, "--internal-self-test-probe="+selfTestProbeTarget) + cmd.Env = buildChildEnv(os.Environ(), injectCfg, libPath, ipcServer.SocketPath(), socksServer.Port(), debug, false) + cmd.SysProcAttr = childSysProcAttr() + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + logger.Errorf("Self-test failed to start probe child: %v", err) + return 1 + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + readyMsg, err := waitForIPCMessage(events, done, 3*time.Second, "READY") + if err != nil { + logger.Errorf("Self-test failed: %v", err) + _ = signalWrappedProcess(cmd, syscall.SIGKILL) + return 1 + } + logger.Infof("Self-test check passed: interceptor READY from pid %d", readyMsg.PID) + + connectMsg, err := waitForIPCMessage(events, done, 5*time.Second, "CONNECT") + if err == nil { + logger.Infof("Self-test check passed: intercepted outbound connect from pid %d to %s", connectMsg.PID, connectMsg.Addr) + } else if dialErr := socksServer.WaitForDial(selfTestProbeTarget, 5*time.Second); dialErr == nil { + logger.Infof("Self-test check passed: SOCKS server observed intercepted outbound connect to %s", selfTestProbeTarget) + } else { + logger.Errorf("Self-test failed: %v", err) + _ = signalWrappedProcess(cmd, syscall.SIGKILL) + return 1 + } + + _ = signalWrappedProcess(cmd, syscall.SIGTERM) + select { + case <-done: + case <-time.After(2 * time.Second): + _ = signalWrappedProcess(cmd, syscall.SIGKILL) + <-done + } + + logger.Infof("Self-test completed successfully") + return 0 +} + +func runInternalSelfTestProbe(target string) int { + dialer := net.Dialer{Timeout: 2 * time.Second} + conn, err := dialer.Dial("tcp", target) + if err == nil { + _ = conn.Close() + } + return 0 +} diff --git a/runtime_helpers_test.go b/runtime_helpers_test.go new file mode 100644 index 0000000..c04e367 --- /dev/null +++ b/runtime_helpers_test.go @@ -0,0 +1,1299 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "syscall" + "testing" + "time" +) + +type synchronizedBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (b *synchronizedBuffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Write(p) +} + +func (b *synchronizedBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + +func TestBuildChildEnvPropagatesAcrossReexec(t *testing.T) { + if os.Getenv("TEST_WRAPGUARD_REEXEC_HELPER") == "1" { + outputPath := os.Getenv("TEST_WRAPGUARD_REEXEC_OUTPUT") + env := filterEnv(os.Environ(), "TEST_WRAPGUARD_REEXEC_HELPER") + env = append(env, + "TEST_WRAPGUARD_REEXEC_GRANDCHILD=1", + "TEST_WRAPGUARD_REEXEC_OUTPUT="+outputPath, + ) + if err := syscall.Exec(os.Args[0], []string{os.Args[0], "-test.run=TestBuildChildEnvPropagatesAcrossReexec"}, env); err != nil { + os.Exit(3) + } + return + } + + if os.Getenv("TEST_WRAPGUARD_REEXEC_GRANDCHILD") == "1" { + outputPath := os.Getenv("TEST_WRAPGUARD_REEXEC_OUTPUT") + payload := map[string]string{ + "library": os.Getenv(currentTestInjectionVar()), + "ipc": os.Getenv(envWrapGuardIPCPath), + "socks": os.Getenv(envWrapGuardSOCKSPort), + "debug": os.Getenv(envWrapGuardDebug), + "debugIP": os.Getenv(envWrapGuardDebugIPC), + "custom": os.Getenv("WRAPGUARD_CUSTOM_SENTINEL"), + } + data, _ := json.Marshal(payload) + _ = os.WriteFile(outputPath, data, 0o644) + return + } + + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping reexec propagation test: %v", err) + } + + workDir := t.TempDir() + libraryPath := filepath.Join(workDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + outputPath := filepath.Join(t.TempDir(), "env.json") + env := buildChildEnv( + append(os.Environ(), "WRAPGUARD_CUSTOM_SENTINEL=kept"), + cfg, + libraryPath, + filepath.Join(workDir, "wrapguard.sock"), + 45678, + true, + false, + ) + + cmd := exec.Command(os.Args[0], "-test.run=TestBuildChildEnvPropagatesAcrossReexec") + cmd.Env = append(env, + "TEST_WRAPGUARD_REEXEC_HELPER=1", + "TEST_WRAPGUARD_REEXEC_OUTPUT="+outputPath, + ) + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("reexec helper failed: %v: %s", err, string(output)) + } + + data, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("failed to read grandchild env output: %v", err) + } + + var payload map[string]string + if err := json.Unmarshal(data, &payload); err != nil { + t.Fatalf("failed to decode grandchild env output: %v", err) + } + + if got := payload["library"]; got != libraryPath { + t.Fatalf("grandchild %s = %q", cfg.LibraryEnvVar, got) + } + if got := payload["ipc"]; got == "" { + t.Fatal("grandchild missing WRAPGUARD_IPC_PATH") + } + if got := payload["socks"]; got != "45678" { + t.Fatalf("grandchild WRAPGUARD_SOCKS_PORT = %q", got) + } + if got := payload["debug"]; got != "1" { + t.Fatalf("grandchild WRAPGUARD_DEBUG = %q", got) + } + if currentPlatformName() == "darwin" { + if got := payload["debugIP"]; got != "1" { + t.Fatalf("grandchild %s = %q", envWrapGuardDebugIPC, got) + } + } else if got := payload["debugIP"]; got != "" { + t.Fatalf("grandchild %s should be empty on Linux, got %q", envWrapGuardDebugIPC, got) + } + if got := payload["custom"]; got != "kept" { + t.Fatalf("grandchild custom sentinel = %q", got) + } +} + +func TestBuildChildEnvPropagatesThroughShellChild(t *testing.T) { + if os.Getenv("TEST_WRAPGUARD_SHELL_ENV_HELPER") == "1" { + outputPath := os.Getenv("TEST_WRAPGUARD_SHELL_ENV_OUTPUT") + payload := map[string]string{ + "library": os.Getenv(currentTestInjectionVar()), + "ipc": os.Getenv(envWrapGuardIPCPath), + "socks": os.Getenv(envWrapGuardSOCKSPort), + "debug": os.Getenv(envWrapGuardDebug), + "debugIP": os.Getenv(envWrapGuardDebugIPC), + "custom": os.Getenv("WRAPGUARD_CUSTOM_SENTINEL"), + } + data, _ := json.Marshal(payload) + _ = os.WriteFile(outputPath, data, 0o644) + return + } + + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping shell propagation test: %v", err) + } + + workDir := t.TempDir() + libraryPath := filepath.Join(workDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libraryPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + outputPath := filepath.Join(t.TempDir(), "shell-env.json") + env := buildChildEnv( + append(os.Environ(), "WRAPGUARD_CUSTOM_SENTINEL=kept"), + cfg, + libraryPath, + filepath.Join(workDir, "wrapguard.sock"), + 34567, + true, + false, + ) + + shellPath, err := exec.LookPath("sh") + if err != nil { + t.Skip("shell not available") + } + if runtime.GOOS == "darwin" { + if err := validateLaunchTarget(shellPath); err != nil { + t.Skipf("skipping shell propagation test for protected shell %s: %v", shellPath, err) + } + } + + cmd := exec.Command(shellPath, "-c", `exec "$1" -test.run=TestBuildChildEnvPropagatesThroughShellChild`, "wrapguard-shell-test", os.Args[0]) + cmd.Env = append(env, + "TEST_WRAPGUARD_SHELL_ENV_HELPER=1", + "TEST_WRAPGUARD_SHELL_ENV_OUTPUT="+outputPath, + ) + + if output, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("shell helper failed: %v: %s", err, string(output)) + } + + data, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("failed to read shell env output: %v", err) + } + + var payload map[string]string + if err := json.Unmarshal(data, &payload); err != nil { + t.Fatalf("failed to decode shell env output: %v", err) + } + + if got := payload["library"]; got != libraryPath { + t.Fatalf("shell child %s = %q", cfg.LibraryEnvVar, got) + } + if got := payload["ipc"]; got == "" { + t.Fatal("shell child missing WRAPGUARD_IPC_PATH") + } + if got := payload["socks"]; got != "34567" { + t.Fatalf("shell child WRAPGUARD_SOCKS_PORT = %q", got) + } + if got := payload["debug"]; got != "1" { + t.Fatalf("shell child WRAPGUARD_DEBUG = %q", got) + } + if currentPlatformName() == "darwin" { + if got := payload["debugIP"]; got != "1" { + t.Fatalf("shell child %s = %q", envWrapGuardDebugIPC, got) + } + } else if got := payload["debugIP"]; got != "" { + t.Fatalf("shell child %s should be empty on Linux, got %q", envWrapGuardDebugIPC, got) + } + if got := payload["custom"]; got != "kept" { + t.Fatalf("shell child custom sentinel = %q", got) + } +} + +func TestRunDoctorReportsLocalPreflightWithoutLaunchTarget(t *testing.T) { + execPath := writeDoctorRuntimeFixture(t, true) + + var output bytes.Buffer + if exitCode := runDoctor(execPath, "", &output); exitCode != 0 { + t.Fatalf("runDoctor exit code = %d, want 0", exitCode) + } + + got := output.String() + if !strings.Contains(got, "doctor: platform=") { + t.Fatalf("runDoctor output missing platform summary: %q", got) + } + if !strings.Contains(got, "doctor: no launch target supplied; preflight completed for local runtime artifacts only") { + t.Fatalf("runDoctor output missing no-target preflight message: %q", got) + } +} + +func TestRunDoctorReportsMissingRuntimeLibrary(t *testing.T) { + execPath := writeDoctorRuntimeFixture(t, false) + + var output bytes.Buffer + if exitCode := runDoctor(execPath, "", &output); exitCode != 1 { + t.Fatalf("runDoctor exit code = %d, want 1", exitCode) + } + + got := output.String() + if !strings.Contains(got, "doctor: runtime library check failed:") { + t.Fatalf("runDoctor output missing runtime library failure: %q", got) + } +} + +func TestRunDoctorReportsMissingLaunchTarget(t *testing.T) { + execPath := writeDoctorRuntimeFixture(t, true) + + var output bytes.Buffer + if exitCode := runDoctor(execPath, "definitely-not-a-real-wrapguard-target", &output); exitCode != 1 { + t.Fatalf("runDoctor exit code = %d, want 1", exitCode) + } + + got := output.String() + if !strings.Contains(got, "doctor: target lookup failed:") && !strings.Contains(got, "doctor: launch target unsupported: failed to resolve launch target") { + t.Fatalf("runDoctor output missing launch-target lookup failure: %q", got) + } +} + +func TestRunDoctorAcceptsDirectExecutableLaunchTarget(t *testing.T) { + var execPath string + if runtime.GOOS == "darwin" { + binaryPath, err := filepath.Abs("wrapguard") + if err != nil { + t.Fatalf("failed to resolve bundled wrapguard path: %v", err) + } + if _, err := os.Stat(binaryPath); err != nil { + t.Skipf("skipping bundled-runtime doctor test: %v", err) + } + if _, err := os.Stat(filepath.Join(filepath.Dir(binaryPath), "libwrapguard.dylib")); err != nil { + t.Skipf("skipping bundled-runtime doctor test: %v", err) + } + execPath = binaryPath + } else { + execPath = writeDoctorRuntimeFixture(t, true) + } + + target, err := filepath.Abs(os.Args[0]) + if err != nil { + t.Fatalf("failed to resolve test binary path: %v", err) + } + + var output bytes.Buffer + if exitCode := runDoctor(execPath, target, &output); exitCode != 0 { + t.Fatalf("runDoctor exit code = %d, want 0; output=%q", exitCode, output.String()) + } + + got := output.String() + if !strings.Contains(got, "doctor: target=") { + t.Fatalf("runDoctor output missing target summary: %q", got) + } + if !strings.Contains(got, "doctor: launch target passed preflight") { + t.Fatalf("runDoctor output missing success message: %q", got) + } +} + +func TestRunDoctorLaunchTargetsOnDarwin(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("macOS-specific launch target validation only applies on Darwin") + } + + execPath := writeDoctorRuntimeFixture(t, true) + appTarget, innerExecutable := writeAppBundleFixture(t, "Example") + + tests := []struct { + name string + target string + want []string + exitCode int + }{ + { + name: "sip-protected-shell", + target: "/bin/sh", + want: []string{"doctor: launch target unsupported:"}, + exitCode: 1, + }, + { + name: "app-bundle", + target: appTarget, + want: []string{ + "doctor: app-bundle-resolved=" + innerExecutable, + "doctor: advisory: macOS GUI launches are experimental and only supported through the directly launched inner executable path", + "doctor: advisory: if this app hands work off to an already-running session or external launcher, WrapGuard will not control the real process tree", + "doctor: launch target passed preflight", + }, + exitCode: 0, + }, + { + name: "inner-executable", + target: innerExecutable, + want: []string{ + "doctor: target=" + innerExecutable, + "doctor: advisory: macOS GUI launches are experimental and only supported through the directly launched inner executable path", + "doctor: advisory: if this app hands work off to an already-running session or external launcher, WrapGuard will not control the real process tree", + "doctor: launch target passed preflight", + }, + exitCode: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var output bytes.Buffer + if exitCode := runDoctor(execPath, tt.target, &output); exitCode != tt.exitCode { + t.Fatalf("runDoctor exit code = %d, want %d", exitCode, tt.exitCode) + } + + got := output.String() + for _, want := range tt.want { + if !strings.Contains(got, want) { + t.Fatalf("runDoctor output missing expected message %q: %q", want, got) + } + } + }) + } +} + +func TestReportLaunchTargetSecurityInfoFormatsSigningStates(t *testing.T) { + targetPath := filepath.Join(t.TempDir(), "target") + if err := os.WriteFile(targetPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("failed to create target fixture: %v", err) + } + + tests := []struct { + name string + codesignOutput string + exitCode int + wantLines []string + }{ + { + name: "unsigned", + codesignOutput: "code object is not signed at all\n", + exitCode: 1, + wantLines: []string{ + "doctor: target-signing=unsigned", + "doctor: target-hardened-runtime=disabled", + }, + }, + { + name: "ad-hoc", + codesignOutput: "Executable=/tmp/target\nIdentifier=wrapguard.test\nSignature=adhoc\n", + exitCode: 0, + wantLines: []string{ + "doctor: target-signing=ad-hoc", + "doctor: target-hardened-runtime=disabled", + }, + }, + { + name: "signed-hardened-runtime", + codesignOutput: "Executable=/tmp/target\nAuthority=Developer ID Application: Example (ABCDE12345)\nflags=0x10000(runtime)\n", + exitCode: 0, + wantLines: []string{ + "doctor: target-signing=signed", + "doctor: target-hardened-runtime=enabled", + "doctor: advisory: DYLD injection may still be rejected at runtime by the target's hardened runtime policy", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + codesignPath := writeCodesignFixture(t, tt.codesignOutput, tt.exitCode) + + var output bytes.Buffer + if err := reportLaunchTargetSecurityInfo(&output, targetPath, codesignPath); err != nil { + t.Fatalf("reportLaunchTargetSecurityInfo returned error: %v", err) + } + + got := output.String() + for _, wantLine := range tt.wantLines { + if !strings.Contains(got, wantLine) { + t.Fatalf("reportLaunchTargetSecurityInfo output missing %q: %q", wantLine, got) + } + } + }) + } +} + +func TestReportLaunchTargetSecurityInfoFallsBackWhenCodesignMissing(t *testing.T) { + targetPath := filepath.Join(t.TempDir(), "target") + if err := os.WriteFile(targetPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("failed to create target fixture: %v", err) + } + + var output bytes.Buffer + missingCodesign := filepath.Join(t.TempDir(), "missing-codesign") + if err := reportLaunchTargetSecurityInfo(&output, targetPath, missingCodesign); err == nil { + t.Fatal("expected reportLaunchTargetSecurityInfo to fail when codesign is missing") + } + + got := output.String() + if !strings.Contains(got, "doctor: target-signing=unknown") { + t.Fatalf("fallback output missing signing status: %q", got) + } + if !strings.Contains(got, "doctor: target-hardened-runtime=unknown") { + t.Fatalf("fallback output missing hardened runtime status: %q", got) + } +} + +func TestParseLaunchTargetSecurityInfoDistinguishesCommonCodesignStates(t *testing.T) { + tests := []struct { + name string + text string + want launchTargetSecurityInfo + }{ + { + name: "signed-without-runtime", + text: "Executable=/tmp/target\nAuthority=Developer ID Application: Example (ABCDE12345)\nflags=0x0(none)\n", + want: launchTargetSecurityInfo{ + SigningStatus: "signed", + HardenedRuntime: "disabled", + }, + }, + { + name: "adhoc-with-runtime-flag", + text: "Executable=/tmp/target\nSignature=adhoc\nflags=0x10000(runtime)\n", + want: launchTargetSecurityInfo{ + SigningStatus: "ad-hoc", + HardenedRuntime: "enabled", + }, + }, + { + name: "unparsed-output-stays-unknown", + text: "some unexpected codesign output\n", + want: launchTargetSecurityInfo{ + SigningStatus: "unknown", + HardenedRuntime: "unknown", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseLaunchTargetSecurityInfo(tt.text) + if got.SigningStatus != tt.want.SigningStatus { + t.Fatalf("SigningStatus = %q, want %q", got.SigningStatus, tt.want.SigningStatus) + } + if got.HardenedRuntime != tt.want.HardenedRuntime { + t.Fatalf("HardenedRuntime = %q, want %q", got.HardenedRuntime, tt.want.HardenedRuntime) + } + }) + } +} + +func TestInspectLaunchTargetSecurityInfoUsesParsedMetadataEvenWhenCodesignExitsNonZero(t *testing.T) { + targetPath := filepath.Join(t.TempDir(), "target") + if err := os.WriteFile(targetPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("failed to create target fixture: %v", err) + } + + codesignPath := writeCodesignFixture(t, "Executable=/tmp/target\nAuthority=Developer ID Application: Example (ABCDE12345)\nflags=0x10000(runtime)\n", 1) + + info, err := inspectLaunchTargetSecurityInfo(targetPath, codesignPath) + if err != nil { + t.Fatalf("inspectLaunchTargetSecurityInfo returned error: %v", err) + } + if info.SigningStatus != "signed" { + t.Fatalf("SigningStatus = %q, want signed", info.SigningStatus) + } + if info.HardenedRuntime != "enabled" { + t.Fatalf("HardenedRuntime = %q, want enabled", info.HardenedRuntime) + } + if !strings.Contains(info.InspectionNotice, "DYLD injection may still be rejected") { + t.Fatalf("InspectionNotice = %q, want hardened-runtime advisory", info.InspectionNotice) + } +} + +func TestInspectLaunchTargetSecurityInfoReturnsErrorWhenCodesignOutputIsUnusable(t *testing.T) { + targetPath := filepath.Join(t.TempDir(), "target") + if err := os.WriteFile(targetPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("failed to create target fixture: %v", err) + } + + codesignPath := writeCodesignFixture(t, "codesign blew up\n", 1) + + info, err := inspectLaunchTargetSecurityInfo(targetPath, codesignPath) + if err == nil { + t.Fatal("expected inspectLaunchTargetSecurityInfo to fail for unusable codesign output") + } + if info.SigningStatus != "unknown" { + t.Fatalf("SigningStatus = %q, want unknown", info.SigningStatus) + } + if info.HardenedRuntime != "unknown" { + t.Fatalf("HardenedRuntime = %q, want unknown", info.HardenedRuntime) + } +} + +func TestWaitForWrappedCommandForwardsSignal(t *testing.T) { + if os.Getenv("TEST_WRAPGUARD_SIGNAL_HELPER") == "1" { + runSignalHelper(os.Getenv("TEST_WRAPGUARD_SIGNAL_FILE"), os.Getenv("TEST_WRAPGUARD_IGNORE_SIGNAL") == "1") + return + } + + oldLogger := CurrentLogger() + SetGlobalLogger(NewLogger(LogLevelDebug, io.Discard)) + defer SetGlobalLogger(oldLogger) + + signalFile := filepath.Join(t.TempDir(), "signal.txt") + cmd := exec.Command(os.Args[0], "-test.run=TestWaitForWrappedCommandForwardsSignal") + cmd.Env = append(os.Environ(), + "TEST_WRAPGUARD_SIGNAL_HELPER=1", + "TEST_WRAPGUARD_SIGNAL_FILE="+signalFile, + ) + cmd.SysProcAttr = childSysProcAttr() + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start signal helper: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + waitForFile(t, signalFile, false) + sigCh := make(chan os.Signal, 1) + sigCh <- syscall.SIGTERM + + if exitCode := waitForWrappedCommand(cmd, done, sigCh, nil, time.Second); exitCode != 1 { + t.Fatalf("waitForWrappedCommand exit code = %d, want 1", exitCode) + } + + waitForFile(t, signalFile, true) + data, err := os.ReadFile(signalFile) + if err != nil { + t.Fatalf("failed to read signal file: %v", err) + } + if string(data) != "terminated:terminated\n" { + t.Fatalf("signal helper output = %q", string(data)) + } +} + +func TestWaitForWrappedCommandReturnsChildExitCode(t *testing.T) { + if os.Getenv("TEST_WRAPGUARD_EXIT_HELPER") == "1" { + os.Exit(7) + } + + oldLogger := CurrentLogger() + SetGlobalLogger(NewLogger(LogLevelDebug, io.Discard)) + defer SetGlobalLogger(oldLogger) + + cmd := exec.Command(os.Args[0], "-test.run=TestWaitForWrappedCommandReturnsChildExitCode") + cmd.Env = append(os.Environ(), "TEST_WRAPGUARD_EXIT_HELPER=1") + cmd.SysProcAttr = childSysProcAttr() + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start exit helper: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + sigCh := make(chan os.Signal, 1) + if exitCode := waitForWrappedCommand(cmd, done, sigCh, nil, time.Second); exitCode != 7 { + t.Fatalf("waitForWrappedCommand exit code = %d, want 7", exitCode) + } +} + +func TestWaitForWrappedCommandKillsHungChildAfterGracePeriod(t *testing.T) { + if os.Getenv("TEST_WRAPGUARD_SIGNAL_HELPER") == "1" { + runSignalHelper(os.Getenv("TEST_WRAPGUARD_SIGNAL_FILE"), os.Getenv("TEST_WRAPGUARD_IGNORE_SIGNAL") == "1") + return + } + + oldLogger := CurrentLogger() + SetGlobalLogger(NewLogger(LogLevelDebug, io.Discard)) + defer SetGlobalLogger(oldLogger) + + signalFile := filepath.Join(t.TempDir(), "signal.txt") + cmd := exec.Command(os.Args[0], "-test.run=TestWaitForWrappedCommandKillsHungChildAfterGracePeriod") + cmd.Env = append(os.Environ(), + "TEST_WRAPGUARD_SIGNAL_HELPER=1", + "TEST_WRAPGUARD_SIGNAL_FILE="+signalFile, + "TEST_WRAPGUARD_IGNORE_SIGNAL=1", + ) + cmd.SysProcAttr = childSysProcAttr() + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start hanging signal helper: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + waitForFile(t, signalFile, false) + sigCh := make(chan os.Signal, 1) + sigCh <- syscall.SIGTERM + + if exitCode := waitForWrappedCommand(cmd, done, sigCh, nil, 200*time.Millisecond); exitCode != 1 { + t.Fatalf("waitForWrappedCommand exit code = %d, want 1", exitCode) + } + + waitForFile(t, signalFile, true) +} + +func TestWaitForWrappedCommandRunsTerminateHookOnSignal(t *testing.T) { + if os.Getenv("TEST_WRAPGUARD_SIGNAL_HELPER") == "1" { + runSignalHelper(os.Getenv("TEST_WRAPGUARD_SIGNAL_FILE"), false) + return + } + + oldLogger := CurrentLogger() + SetGlobalLogger(NewLogger(LogLevelDebug, io.Discard)) + defer SetGlobalLogger(oldLogger) + + signalFile := filepath.Join(t.TempDir(), "signal.txt") + cmd := exec.Command(os.Args[0], "-test.run=TestWaitForWrappedCommandRunsTerminateHookOnSignal") + cmd.Env = append(os.Environ(), + "TEST_WRAPGUARD_SIGNAL_HELPER=1", + "TEST_WRAPGUARD_SIGNAL_FILE="+signalFile, + ) + cmd.SysProcAttr = childSysProcAttr() + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start signal helper: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + waitForFile(t, signalFile, false) + sigCh := make(chan os.Signal, 1) + sigCh <- syscall.SIGTERM + + called := false + if exitCode := waitForWrappedCommand(cmd, done, sigCh, func() { called = true }, time.Second); exitCode != 1 { + t.Fatalf("waitForWrappedCommand exit code = %d, want 1", exitCode) + } + if !called { + t.Fatal("terminate hook was not invoked") + } +} + +func currentTestInjectionVar() string { + cfg, err := currentInjectionConfig() + if err != nil { + return "" + } + return cfg.LibraryEnvVar +} + +func writeCodesignFixture(t *testing.T, stdout string, exitCode int) string { + t.Helper() + + scriptPath := filepath.Join(t.TempDir(), "codesign") + script := "#!/bin/sh\ncat <<'EOF'\n" + stdout + "EOF\nexit " + strconv.Itoa(exitCode) + "\n" + if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil { + t.Fatalf("failed to create codesign fixture: %v", err) + } + return scriptPath +} + +func runSignalHelper(signalFile string, ignore bool) { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(sigs) + + _ = os.WriteFile(signalFile, []byte("ready\n"), 0o644) + sig := <-sigs + _ = os.WriteFile(signalFile, []byte("terminated:"+sig.String()+"\n"), 0o644) + if ignore { + select {} + } + os.Exit(0) +} + +func waitForFile(t *testing.T, path string, wantTermination bool) { + t.Helper() + + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + data, err := os.ReadFile(path) + if err == nil { + content := string(data) + if wantTermination { + if content != "" && content != "ready\n" { + return + } + } else if content == "ready\n" { + return + } + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatalf("timed out waiting for signal helper file state (termination=%v)", wantTermination) +} + +func waitForOutputContains(t *testing.T, output interface{ String() string }, want ...string) { + t.Helper() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + got := output.String() + allPresent := true + for _, needle := range want { + if !strings.Contains(got, needle) { + allPresent = false + break + } + } + if allPresent { + return + } + time.Sleep(10 * time.Millisecond) + } + + t.Fatalf("timed out waiting for log output to contain %q; got %q", want, output.String()) +} + +func filterEnv(env []string, dropKey string) []string { + filtered := make([]string, 0, len(env)) + prefix := dropKey + "=" + for _, entry := range env { + if len(entry) >= len(prefix) && entry[:len(prefix)] == prefix { + continue + } + filtered = append(filtered, entry) + } + return filtered +} + +func TestProbeSOCKSReachabilityFailsWhenPortClosed(t *testing.T) { + port := reserveUnusedPort(t) + + if err := probeSOCKSReachability(port); err == nil { + t.Fatal("expected probeSOCKSReachability to fail for a closed port") + } +} + +func TestProbeSOCKSReachabilitySucceedsAgainstSOCKSListener(t *testing.T) { + tunnel := &Tunnel{ourIP: mustParseIPAddr("10.150.0.2")} + server, err := NewSOCKS5Server(tunnel) + if err != nil { + t.Fatalf("NewSOCKS5Server failed: %v", err) + } + defer server.Close() + + if err := probeSOCKSReachability(server.Port()); err != nil { + t.Fatalf("probeSOCKSReachability failed: %v", err) + } +} + +func TestProbeIPCReachabilityFailsWhenSocketMissing(t *testing.T) { + socketPath := filepath.Join(t.TempDir(), "missing.sock") + if err := probeIPCReachability(socketPath); err == nil { + t.Fatal("expected probeIPCReachability to fail for a missing socket") + } +} + +func TestProbeIPCReachabilitySucceedsForLiveServer(t *testing.T) { + server, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer server.Close() + + if err := probeIPCReachability(server.SocketPath()); err != nil { + t.Fatalf("probeIPCReachability failed: %v", err) + } +} + +func TestWaitForIPCMessageReturnsChildExitError(t *testing.T) { + msgCh := make(chan IPCMessage) + done := make(chan error, 1) + done <- errors.New("boom") + + _, err := waitForIPCMessage(msgCh, done, 100*time.Millisecond, "READY") + if err == nil { + t.Fatal("expected waitForIPCMessage to return an error") + } + if err.Error() != "child exited before READY: boom" { + t.Fatalf("unexpected waitForIPCMessage error: %v", err) + } +} + +func TestWaitForIPCMessageReturnsMessageBeforeTimeout(t *testing.T) { + msgCh := make(chan IPCMessage, 1) + done := make(chan error, 1) + msgCh <- IPCMessage{Type: "READY", PID: 42} + + msg, err := waitForIPCMessage(msgCh, done, time.Second, "READY") + if err != nil { + t.Fatalf("waitForIPCMessage failed: %v", err) + } + if msg.PID != 42 { + t.Fatalf("unexpected PID %d", msg.PID) + } +} + +func TestWaitForIPCMessageTimesOutWhileIgnoringUnrelatedMessages(t *testing.T) { + msgCh := make(chan IPCMessage, 2) + done := make(chan error, 1) + msgCh <- IPCMessage{Type: "DEBUG", PID: 7} + msgCh <- IPCMessage{Type: "UDP_SEND", PID: 8} + + _, err := waitForIPCMessage(msgCh, done, 100*time.Millisecond, "READY") + if err == nil { + t.Fatal("expected waitForIPCMessage to time out") + } + if err.Error() != "timed out waiting for READY" { + t.Fatalf("unexpected waitForIPCMessage error: %v", err) + } +} + +func TestProbeSOCKSReachabilityRejectsNonSOCKSServer(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer listener.Close() + + errCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + errCh <- err + return + } + defer conn.Close() + _, _ = conn.Write([]byte("nope")) + errCh <- nil + }() + + if err := probeSOCKSReachability(listener.Addr().(*net.TCPAddr).Port); err == nil { + t.Fatal("expected probeSOCKSReachability to reject a non-SOCKS server") + } + if err := <-errCh; err != nil { + t.Fatalf("helper server failed: %v", err) + } +} + +func TestProbeSOCKSReachabilityRejectsTruncatedHandshake(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer listener.Close() + + errCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + errCh <- err + return + } + defer conn.Close() + _, _ = conn.Write([]byte{0x05}) + errCh <- nil + }() + + if err := probeSOCKSReachability(listener.Addr().(*net.TCPAddr).Port); err == nil { + t.Fatal("expected probeSOCKSReachability to reject a truncated SOCKS handshake") + } + if err := <-errCh; err != nil { + t.Fatalf("helper server failed: %v", err) + } +} + +func TestStartIPCEventLoggerLogsTransportEvents(t *testing.T) { + oldLogger := CurrentLogger() + var output synchronizedBuffer + SetGlobalLogger(NewLogger(LogLevelDebug, &output)) + defer SetGlobalLogger(oldLogger) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + server, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer server.Close() + + stop := startIPCEventLogger(ctx, server, true) + defer stop() + + server.dispatchMessage(IPCMessage{Type: "READY", PID: 101}) + server.dispatchMessage(IPCMessage{Type: "CONNECT", PID: 101, Addr: "203.0.113.10:443"}) + server.dispatchMessage(IPCMessage{Type: "DEBUG", PID: 101, Detail: "browser debug detail"}) + server.dispatchMessage(IPCMessage{Type: "UDP_BLOCK", PID: 101, Addr: "203.0.113.11:443", Detail: "sendmsg"}) + server.dispatchMessage(IPCMessage{Type: "UDP_SEND", PID: 101, Addr: "203.0.113.12:443", Detail: "connected-sendto"}) + server.dispatchMessage(IPCMessage{Type: "ERROR", PID: 101, Detail: "simulated failure"}) + + waitForOutputContains(t, &output, + "Interceptor READY from pid 101", + "Interceptor CONNECT from pid 101 to 203.0.113.10:443", + "Interceptor DEBUG from pid 101: browser debug detail", + "Interceptor UDP_BLOCK from pid 101 to 203.0.113.11:443 (sendmsg)", + "Interceptor UDP_SEND from pid 101 to 203.0.113.12:443 (connected-sendto)", + "Interceptor ERROR from pid 101: simulated failure", + ) +} + +func TestRunSelfTestReportsClosedSOCKSListener(t *testing.T) { + oldLogger := CurrentLogger() + var output bytes.Buffer + SetGlobalLogger(NewLogger(LogLevelDebug, &output)) + defer SetGlobalLogger(oldLogger) + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + tunnel := &Tunnel{ourIP: mustParseIPAddr("10.150.0.2")} + socksServer, err := NewSOCKS5Server(tunnel) + if err != nil { + t.Fatalf("NewSOCKS5Server failed: %v", err) + } + socksServer.Close() + + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + exitCode := runSelfTest(context.Background(), ipcServer, socksServer, os.Args[0], "", cfg, true) + if exitCode != 1 { + t.Fatalf("runSelfTest exit code = %d, want 1", exitCode) + } + + got := output.String() + if !strings.Contains(got, "Self-test failed: SOCKS listener is not reachable") { + t.Fatalf("runSelfTest output missing SOCKS failure diagnostic: %q", got) + } +} + +func TestRunSelfTestFailsWhenChildExitsBeforeReady(t *testing.T) { + oldLogger := CurrentLogger() + var output bytes.Buffer + SetGlobalLogger(NewLogger(LogLevelDebug, &output)) + defer SetGlobalLogger(oldLogger) + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + tunnel := &Tunnel{ourIP: mustParseIPAddr("10.150.0.2")} + socksServer, err := NewSOCKS5Server(tunnel) + if err != nil { + t.Fatalf("NewSOCKS5Server failed: %v", err) + } + defer socksServer.Close() + + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + helperPath := filepath.Join(t.TempDir(), "self-test-no-ready.sh") + if err := os.WriteFile(helperPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatalf("failed to write self-test helper: %v", err) + } + + exitCode := runSelfTest(context.Background(), ipcServer, socksServer, helperPath, "", cfg, true) + if exitCode != 1 { + t.Fatalf("runSelfTest exit code = %d, want 1", exitCode) + } + + got := output.String() + if !strings.Contains(got, "Self-test failed: child exited before READY") { + t.Fatalf("runSelfTest output missing READY failure diagnostic: %q", got) + } +} + +func TestRunSelfTestFailsWhenConnectNeverArrives(t *testing.T) { + oldLogger := CurrentLogger() + var output bytes.Buffer + SetGlobalLogger(NewLogger(LogLevelDebug, &output)) + defer SetGlobalLogger(oldLogger) + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + tunnel := &Tunnel{ourIP: mustParseIPAddr("10.150.0.2")} + socksServer, err := NewSOCKS5Server(tunnel) + if err != nil { + t.Fatalf("NewSOCKS5Server failed: %v", err) + } + defer socksServer.Close() + + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping self-test connect timeout fixture: %v", err) + } + + fixtureDir := t.TempDir() + libPath := filepath.Join(fixtureDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperPath := filepath.Join(fixtureDir, "self-test-ready-only") + if err := buildIPCReadyOnlyHelper(t, helperPath); err != nil { + t.Fatalf("failed to build self-test helper: %v", err) + } + + exitCode := runSelfTest(context.Background(), ipcServer, socksServer, helperPath, libPath, cfg, true) + if exitCode != 1 { + t.Fatalf("runSelfTest exit code = %d, want 1", exitCode) + } + + got := output.String() + if !strings.Contains(got, "Self-test check passed: interceptor READY") { + t.Fatalf("runSelfTest output missing READY success diagnostic: %q", got) + } + if !strings.Contains(got, "Self-test failed: child exited before CONNECT") { + t.Fatalf("runSelfTest output missing CONNECT failure diagnostic: %q", got) + } +} + +func TestRunSelfTestSucceedsWithInjectedProbe(t *testing.T) { + oldLogger := CurrentLogger() + var output bytes.Buffer + SetGlobalLogger(NewLogger(LogLevelDebug, &output)) + defer SetGlobalLogger(oldLogger) + + ipcServer, err := NewIPCServer() + if err != nil { + t.Fatalf("NewIPCServer failed: %v", err) + } + defer ipcServer.Close() + + tunnel := &Tunnel{ourIP: mustParseIPAddr("10.150.0.2")} + socksServer, err := NewSOCKS5Server(tunnel) + if err != nil { + t.Fatalf("NewSOCKS5Server failed: %v", err) + } + defer socksServer.Close() + + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping self-test success fixture: %v", err) + } + + fixtureDir := t.TempDir() + libPath := filepath.Join(fixtureDir, cfg.LibraryName) + if err := buildInterceptLibraryForTest(t, cc, libPath); err != nil { + t.Fatalf("failed to build intercept library: %v", err) + } + + helperPath := filepath.Join(fixtureDir, "self-test-connect-probe") + if err := buildSelfTestConnectProbe(t, cc, helperPath); err != nil { + t.Fatalf("failed to build self-test probe helper: %v", err) + } + + exitCode := runSelfTest(context.Background(), ipcServer, socksServer, helperPath, libPath, cfg, true) + if exitCode != 0 { + t.Fatalf("runSelfTest exit code = %d, want 0; output=%q", exitCode, output.String()) + } + + got := output.String() + for _, want := range []string{ + "Self-test check passed: IPC socket is reachable", + "Self-test check passed: SOCKS listener is reachable", + "Self-test check passed: interceptor READY", + "Self-test check passed: intercepted outbound connect", + "Self-test completed successfully", + } { + if !strings.Contains(got, want) { + t.Fatalf("runSelfTest output missing %q: %q", want, got) + } + } +} + +func writeDoctorRuntimeFixture(t *testing.T, includeLibrary bool) string { + t.Helper() + + cfg, err := currentInjectionConfig() + if err != nil { + t.Fatalf("currentInjectionConfig failed: %v", err) + } + + workDir := t.TempDir() + execPath := filepath.Join(workDir, "wrapguard") + if err := os.WriteFile(execPath, []byte("test"), 0o755); err != nil { + t.Fatalf("failed to create doctor exec fixture: %v", err) + } + + if includeLibrary { + libPath := filepath.Join(workDir, cfg.LibraryName) + if runtime.GOOS == "darwin" { + cc, err := findCCompiler() + if err != nil { + t.Skipf("skipping doctor fixture test: %v", err) + } + if err := buildInterceptLibraryForTest(t, cc, libPath); err != nil { + t.Fatalf("failed to build doctor library fixture: %v", err) + } + } else if err := os.WriteFile(libPath, []byte("test"), 0o644); err != nil { + t.Fatalf("failed to create doctor library fixture: %v", err) + } + } + + return execPath +} + +func buildIPCReadyOnlyHelper(t *testing.T, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "main.go") + source := `package main + +import ( + "encoding/json" + "net" + "os" +) + +func main() { + socketPath := os.Getenv("WRAPGUARD_IPC_PATH") + if socketPath == "" { + os.Exit(2) + } + + conn, err := net.Dial("unix", socketPath) + if err != nil { + os.Exit(3) + } + defer conn.Close() + + msg := map[string]any{ + "type": "READY", + "pid": os.Getpid(), + } + if err := json.NewEncoder(conn).Encode(msg); err != nil { + os.Exit(4) + } +} +` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + cmd := exec.Command("go", "build", "-o", outputPath, sourcePath) + if output, err := cmd.CombinedOutput(); err != nil { + return errors.New(strings.TrimSpace(string(output))) + } + return nil +} + +func buildSelfTestConnectProbe(t *testing.T, cc, outputPath string) error { + t.Helper() + + sourcePath := filepath.Join(t.TempDir(), "self_test_connect_probe.c") + source := `#include +#include +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) { + const char *prefix = "--internal-self-test-probe="; + const size_t prefix_len = strlen(prefix); + const char *target = NULL; + + for (int i = 1; i < argc; i++) { + if (strncmp(argv[i], prefix, prefix_len) == 0) { + target = argv[i] + prefix_len; + break; + } + } + + if (target == NULL || *target == '\0') { + return 2; + } + + char input[256]; + memset(input, 0, sizeof(input)); + strncpy(input, target, sizeof(input) - 1); + + char *sep = strrchr(input, ':'); + if (sep == NULL) { + return 3; + } + + *sep = '\0'; + const char *host = input; + int port = atoi(sep + 1); + if (port <= 0 || port > 65535) { + return 4; + } + + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + return 5; + } + + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons((unsigned short)port); + if (inet_pton(AF_INET, host, &addr.sin_addr) != 1) { + close(fd); + return 6; + } + + (void)connect(fd, (struct sockaddr *)&addr, sizeof(addr)); + close(fd); + return 0; +} +` + if err := os.WriteFile(sourcePath, []byte(source), 0o644); err != nil { + return err + } + + args := []string{"-Wall", "-Wextra", "-Werror", "-o", outputPath, sourcePath} + if runtime.GOOS == "darwin" { + args = []string{"-Wall", "-Wextra", "-Werror", "-Wno-deprecated-declarations", "-o", outputPath, sourcePath} + } + cmd := exec.Command(cc, args...) + if output, err := cmd.CombinedOutput(); err != nil { + return errors.New(strings.TrimSpace(string(output))) + } + return nil +} diff --git a/socks.go b/socks.go index 1111410..ebe9f6b 100644 --- a/socks.go +++ b/socks.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "strconv" + "sync" + "time" "github.com/armon/go-socks5" ) @@ -14,67 +16,88 @@ type SOCKS5Server struct { listener net.Listener port int tunnel *Tunnel + dials chan string + wg sync.WaitGroup } -func NewSOCKS5Server(tunnel *Tunnel) (*SOCKS5Server, error) { - // Create SOCKS5 server with custom dialer that routes WireGuard IPs through the tunnel - socksConfig := &socks5.Config{ - Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { - logger.Debugf("SOCKS5 dial request: %s %s", network, addr) - - // Parse the address to check if it's a WireGuard IP - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, fmt.Errorf("invalid address format: %w", err) - } +func buildSOCKS5Dial(tunnel *Tunnel, socksPort int, baseDial func(context.Context, string, string) (net.Conn, error), onDial func(string, string)) func(context.Context, string, string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + if onDial != nil { + onDial(network, addr) + } + logger.Debugf("SOCKS5 dial request: %s %s", network, addr) - // Check if this is a WireGuard IP that should be routed through the tunnel - ip := net.ParseIP(host) - if ip != nil { - // Use routing engine to find appropriate peer - portNum, _ := strconv.Atoi(port) - peer, peerIdx := tunnel.router.FindPeerForDestination(ip, portNum, "tcp") - if peer != nil { - logger.Debugf("Routing %s through WireGuard tunnel via peer %d (endpoint: %s)", addr, peerIdx, peer.Endpoint) - return tunnel.DialWireGuard(ctx, network, host, port) - } + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("invalid address format: %w", err) + } + + ip := net.ParseIP(host) + if ip != nil && ip.IsLoopback() { + portNum, _ := strconv.Atoi(port) + if portNum == socksPort { + return nil, fmt.Errorf("refusing recursive SOCKS dial to localhost:%d", socksPort) } + return baseDial(ctx, network, addr) + } - // For non-WireGuard IPs, use normal dialing - logger.Debugf("Using normal dial for %s", addr) - dialer := &net.Dialer{} - conn, err := dialer.DialContext(ctx, network, addr) - if err != nil { - logger.Debugf("SOCKS5 dial failed for %s: %v", addr, err) - } else { - logger.Debugf("SOCKS5 dial succeeded for %s", addr) + if tunnel != nil && tunnel.router != nil && ip != nil { + portNum, _ := strconv.Atoi(port) + peer, peerIdx := tunnel.router.FindPeerForDestination(ip, portNum, normalizeNetworkProtocol(network)) + if peer != nil { + logger.Debugf("Routing %s through WireGuard tunnel via peer %d (endpoint: %s)", addr, peerIdx, peer.Endpoint) + return tunnel.DialWireGuard(ctx, network, host, port) } - return conn, err - }, + } + + logger.Debugf("Using normal dial for %s", addr) + conn, err := baseDial(ctx, network, addr) + if err != nil { + logger.Debugf("SOCKS5 dial failed for %s: %v", addr, err) + } else { + logger.Debugf("SOCKS5 dial succeeded for %s", addr) + } + return conn, err } +} - server, err := socks5.New(socksConfig) - if err != nil { - return nil, fmt.Errorf("failed to create SOCKS5 server: %w", err) +func NewSOCKS5Server(tunnel *Tunnel) (*SOCKS5Server, error) { + if tunnel == nil { + return nil, fmt.Errorf("tunnel is required") } - // Listen on localhost for SOCKS5 connections listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, fmt.Errorf("failed to listen for SOCKS5 connections: %w", err) } port := listener.Addr().(*net.TCPAddr).Port + baseDialer := (&net.Dialer{}).DialContext s := &SOCKS5Server{ - server: server, listener: listener, port: port, tunnel: tunnel, + dials: make(chan string, 32), + } + socksConfig := &socks5.Config{} + socksConfig.Dial = buildSOCKS5Dial(tunnel, port, baseDialer, func(_ string, addr string) { + select { + case s.dials <- addr: + default: + } + }) + server, err := socks5.New(socksConfig) + if err != nil { + _ = listener.Close() + return nil, fmt.Errorf("failed to create SOCKS5 server: %w", err) } + s.server = server // Start serving in background + s.wg.Add(1) go func() { + defer s.wg.Done() if err := server.Serve(listener); err != nil { // Log error but don't crash - server might be shutting down logger.Debugf("SOCKS5 server stopped: %v", err) @@ -88,9 +111,27 @@ func (s *SOCKS5Server) Port() int { return s.port } +func (s *SOCKS5Server) WaitForDial(addr string, timeout time.Duration) error { + timer := time.NewTimer(timeout) + defer timer.Stop() + + for { + select { + case got := <-s.dials: + if got == addr { + return nil + } + case <-timer.C: + return fmt.Errorf("timed out waiting for SOCKS dial to %s", addr) + } + } +} + func (s *SOCKS5Server) Close() error { if s.listener != nil { - return s.listener.Close() + err := s.listener.Close() + s.wg.Wait() + return err } return nil } diff --git a/socks_test.go b/socks_test.go index a9b1ad9..637da0f 100644 --- a/socks_test.go +++ b/socks_test.go @@ -1,6 +1,8 @@ package main import ( + "context" + "fmt" "net" "net/netip" "testing" @@ -161,13 +163,10 @@ func TestSOCKS5Server_ListenerAddress(t *testing.T) { } func TestSOCKS5Server_NilTunnel(t *testing.T) { - // Test behavior with nil tunnel (should not panic but may fail) + // Test behavior with nil tunnel is a clean error. _, err := NewSOCKS5Server(nil) - - // This will likely panic or fail, which is acceptable behavior - // We just want to ensure it doesn't crash the test suite - if err != nil { - t.Logf("NewSOCKS5Server with nil tunnel failed as expected: %v", err) + if err == nil { + t.Fatal("expected error for nil tunnel") } } @@ -245,6 +244,175 @@ func TestSOCKS5Server_ServerRunning(t *testing.T) { } } +func TestBuildSOCKS5DialBypassesLoopback(t *testing.T) { + tunnel := &Tunnel{ + ourIP: mustParseIPAddr("10.150.0.2"), + router: NewRoutingEngine(&WireGuardConfig{ + Peers: []PeerConfig{ + { + AllowedIPs: []string{"0.0.0.0/0"}, + }, + }, + }), + } + + var dialed []string + dial := buildSOCKS5Dial(tunnel, 1080, func(ctx context.Context, network, addr string) (net.Conn, error) { + dialed = append(dialed, network+" "+addr) + return nil, fmt.Errorf("base dial invoked") + }, nil) + + _, err := dial(context.Background(), "tcp", "127.0.0.1:8080") + if err == nil || err.Error() != "base dial invoked" { + t.Fatalf("expected base dialer error, got %v", err) + } + if len(dialed) != 1 || dialed[0] != "tcp 127.0.0.1:8080" { + t.Fatalf("unexpected base dial invocations: %v", dialed) + } +} + +func TestBuildSOCKS5DialRejectsRecursiveLoopbackPort(t *testing.T) { + dial := buildSOCKS5Dial(&Tunnel{ourIP: mustParseIPAddr("10.150.0.2")}, 1080, func(ctx context.Context, network, addr string) (net.Conn, error) { + t.Fatalf("base dialer should not be used for recursive SOCKS target") + return nil, nil + }, nil) + + if _, err := dial(context.Background(), "tcp", "127.0.0.1:1080"); err == nil || err.Error() != "refusing recursive SOCKS dial to localhost:1080" { + t.Fatalf("unexpected recursive dial result: %v", err) + } +} + +func TestBuildSOCKS5DialLeavesHostnamesOnBaseDialer(t *testing.T) { + tunnel := &Tunnel{ + ourIP: mustParseIPAddr("10.150.0.2"), + router: NewRoutingEngine(&WireGuardConfig{ + Peers: []PeerConfig{ + { + AllowedIPs: []string{"0.0.0.0/0"}, + }, + }, + }), + } + + var dialed []string + dial := buildSOCKS5Dial(tunnel, 1080, func(ctx context.Context, network, addr string) (net.Conn, error) { + dialed = append(dialed, network+" "+addr) + return nil, fmt.Errorf("base dial invoked") + }, nil) + + _, err := dial(context.Background(), "tcp", "example.com:443") + if err == nil || err.Error() != "base dial invoked" { + t.Fatalf("expected base dialer error, got %v", err) + } + if len(dialed) != 1 || dialed[0] != "tcp example.com:443" { + t.Fatalf("unexpected base dial invocations: %v", dialed) + } +} + +func TestBuildSOCKS5DialRoutesMatchedDestinationsThroughTunnel(t *testing.T) { + config := &WireGuardConfig{ + Interface: InterfaceConfig{ + Address: "10.150.0.2/24", + }, + Peers: []PeerConfig{ + { + PublicKey: "route-peer", + Endpoint: "route.example.com:51820", + AllowedIPs: []string{"10.200.0.0/16"}, + RoutingPolicies: []RoutingPolicy{ + { + DestinationCIDR: "198.51.100.0/24", + Protocol: "tcp", + PortRange: PortRange{Start: 443, End: 443}, + Priority: 10, + }, + }, + }, + }, + } + + var ( + gotNetwork string + gotAddress string + ) + tunnel := &Tunnel{ + ourIP: mustParseIPAddr("10.150.0.2"), + config: config, + router: NewRoutingEngine(config), + dialFn: func(ctx context.Context, network, address string) (net.Conn, error) { + gotNetwork = network + gotAddress = address + return nil, fmt.Errorf("tunnel dial invoked") + }, + } + + dial := buildSOCKS5Dial(tunnel, 1080, func(ctx context.Context, network, addr string) (net.Conn, error) { + t.Fatalf("base dialer should not be used for routed destination") + return nil, nil + }, nil) + + _, err := dial(context.Background(), "tcp4", "198.51.100.25:443") + if err == nil || err.Error() != "tunnel dial invoked" { + t.Fatalf("expected tunnel dialer error, got %v", err) + } + if gotNetwork != "tcp4" { + t.Fatalf("tunnel dial network = %q, want tcp4", gotNetwork) + } + if gotAddress != "198.51.100.25:443" { + t.Fatalf("tunnel dial address = %q, want 198.51.100.25:443", gotAddress) + } +} + +func TestBuildSOCKS5DialFallsBackToBaseDialerForUnroutedIP(t *testing.T) { + tunnel := &Tunnel{ + ourIP: mustParseIPAddr("10.150.0.2"), + router: NewRoutingEngine(&WireGuardConfig{ + Peers: []PeerConfig{ + { + AllowedIPs: []string{"10.200.0.0/16"}, + }, + }, + }), + } + + var dialed []string + dial := buildSOCKS5Dial(tunnel, 1080, func(ctx context.Context, network, addr string) (net.Conn, error) { + dialed = append(dialed, network+" "+addr) + return nil, fmt.Errorf("base dial invoked") + }, nil) + + _, err := dial(context.Background(), "tcp4", "198.51.100.25:443") + if err == nil || err.Error() != "base dial invoked" { + t.Fatalf("expected base dialer error, got %v", err) + } + if len(dialed) != 1 || dialed[0] != "tcp4 198.51.100.25:443" { + t.Fatalf("unexpected base dial invocations: %v", dialed) + } +} + +func TestBuildSOCKS5DialPropagatesBaseDialFailure(t *testing.T) { + tunnel := &Tunnel{ + ourIP: mustParseIPAddr("10.150.0.2"), + } + + called := false + dial := buildSOCKS5Dial(tunnel, 1080, func(ctx context.Context, network, addr string) (net.Conn, error) { + called = true + return nil, fmt.Errorf("proxy unreachable") + }, nil) + + conn, err := dial(context.Background(), "tcp", "example.com:443") + if err == nil || err.Error() != "proxy unreachable" { + t.Fatalf("expected base dial error, got conn=%v err=%v", conn, err) + } + if conn != nil { + t.Fatalf("expected nil conn on dial failure, got %v", conn) + } + if !called { + t.Fatal("base dialer was not invoked") + } +} + // Helper function to parse IP addresses for testing func mustParseIPAddr(s string) netip.Addr { ip, err := netip.ParseAddr(s) diff --git a/tunnel.go b/tunnel.go index 0cf6408..7e85359 100644 --- a/tunnel.go +++ b/tunnel.go @@ -8,22 +8,26 @@ import ( "net/netip" "os" "strconv" + "strings" "sync" "time" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" ) type Tunnel struct { - device *device.Device - tun *MemoryTUN - ourIP netip.Addr - connMap map[string]*TunnelConn - mutex sync.RWMutex - router *RoutingEngine // Add routing engine - config *WireGuardConfig // Keep config reference + device *device.Device + tun tun.Device + ourIP netip.Addr + connMap map[string]*TunnelConn + mutex sync.RWMutex + router *RoutingEngine // Add routing engine + config *WireGuardConfig // Keep config reference + dialFn func(ctx context.Context, network, address string) (net.Conn, error) + listenFn func(*net.TCPAddr) (net.Listener, error) } type TunnelConn struct { @@ -125,33 +129,43 @@ func (m *MemoryTUN) Close() error { } func NewTunnel(ctx context.Context, config *WireGuardConfig) (*Tunnel, error) { + _ = ctx + // Get our WireGuard IP ourIP, err := config.GetInterfaceIP() if err != nil { return nil, fmt.Errorf("failed to parse interface IP: %w", err) } - // Create memory TUN - memTun := NewMemoryTUN("wg0", 1420) + dnsServers, err := parseDNSAddrs(config.Interface.DNS) + if err != nil { + return nil, fmt.Errorf("failed to parse interface DNS servers: %w", err) + } + + tunDevice, tnet, err := netstack.CreateNetTUN([]netip.Addr{ourIP}, dnsServers, 1420) + if err != nil { + return nil, fmt.Errorf("failed to create userspace netstack tunnel: %w", err) + } tunnel := &Tunnel{ - tun: memTun, + tun: tunDevice, ourIP: ourIP, connMap: make(map[string]*TunnelConn), config: config, router: NewRoutingEngine(config), + dialFn: tnet.DialContext, + listenFn: func(addr *net.TCPAddr) (net.Listener, error) { + return tnet.ListenTCP(addr) + }, } - // Set tunnel reference in TUN for packet handling - memTun.tunnel = tunnel - // Create WireGuard device logger := device.NewLogger( device.LogLevelSilent, fmt.Sprintf("[%s] ", "wg"), ) - dev := device.NewDevice(memTun, conn.NewDefaultBind(), logger) + dev := device.NewDevice(tunDevice, conn.NewDefaultBind(), logger) // Configure device if err := configureDevice(dev, config); err != nil { @@ -244,9 +258,10 @@ func (t *Tunnel) handleIncomingPacket(packet []byte) { // DialContext creates a connection through WireGuard func (t *Tunnel) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - // For now, return an error since we need the WireGuard interface to be configured - // In a full implementation, this would send packets through the WireGuard tunnel - return nil, fmt.Errorf("WireGuard tunnel dial not implemented - requires system WireGuard interface or full TCP/IP stack") + if t.dialFn == nil { + return nil, fmt.Errorf("WireGuard tunnel dialer is not initialized") + } + return t.dialFn(ctx, network, address) } func (t *Tunnel) createTCPSyn(dstIP net.IP, dstPort int) []byte { @@ -278,9 +293,22 @@ func (t *Tunnel) createTCPSyn(dstIP net.IP, dstPort int) []byte { } func (t *Tunnel) Listen(network, address string) (net.Listener, error) { - // For incoming connections, we need to listen on our WireGuard IP - // This is a placeholder - real implementation would handle TCP listening - return net.Listen("tcp", fmt.Sprintf("%s%s", t.ourIP.String(), address)) + if t.listenFn == nil { + return nil, fmt.Errorf("WireGuard tunnel listener is not initialized") + } + + switch normalizeNetworkProtocol(network) { + case "tcp": + default: + return nil, fmt.Errorf("unsupported listen network %q", network) + } + + tcpAddr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + return nil, fmt.Errorf("failed to resolve listen address %q: %w", address, err) + } + + return t.listenFn(tcpAddr) } // IsWireGuardIP checks if an IP is in the WireGuard network @@ -307,39 +335,29 @@ func (t *Tunnel) DialWireGuard(ctx context.Context, network, host, port string) } // Find the appropriate peer using routing engine - peer, peerIdx := t.router.FindPeerForDestination(ip, portNum, network) + peer, peerIdx := t.router.FindPeerForDestination(ip, portNum, normalizeNetworkProtocol(network)) if peer == nil { return nil, fmt.Errorf("no route to %s:%s", host, port) } logger.Debugf("WireGuard tunnel: routing %s:%s through peer %d (endpoint: %s)", host, port, peerIdx, peer.Endpoint) - - // For now, fall back to hostname translation for testing - // In a production system, this would send packets through the WireGuard tunnel - // to the selected peer - var realHost string - switch host { - case "10.150.0.2": - realHost = "node-server-1" - case "10.150.0.3": - realHost = "node-server-2" - default: - // In a real implementation, we would encapsulate and send through the tunnel - // For now, try direct connection as fallback - logger.Warnf("No hostname mapping for %s, attempting direct connection", host) - realHost = host + if t.dialFn == nil { + return nil, fmt.Errorf("WireGuard tunnel dialer is not initialized") } - dialer := &net.Dialer{} - return dialer.DialContext(ctx, network, realHost+":"+port) + address := net.JoinHostPort(host, port) + return t.dialFn(ctx, network, address) } func (t *Tunnel) Close() error { if t.device != nil { t.device.Close() + t.device = nil + t.tun = nil } if t.tun != nil { t.tun.Close() + t.tun = nil } return nil } @@ -385,3 +403,31 @@ func mustParsePort(s string) int { p, _ := strconv.Atoi(s) return p } + +func normalizeNetworkProtocol(network string) string { + switch { + case strings.HasPrefix(network, "tcp"): + return "tcp" + case strings.HasPrefix(network, "udp"): + return "udp" + default: + return network + } +} + +func parseDNSAddrs(entries []string) ([]netip.Addr, error) { + if len(entries) == 0 { + return nil, nil + } + + addrs := make([]netip.Addr, 0, len(entries)) + for _, entry := range entries { + addr, err := netip.ParseAddr(strings.TrimSpace(entry)) + if err != nil { + return nil, fmt.Errorf("invalid DNS address %q: %w", entry, err) + } + addrs = append(addrs, addr) + } + + return addrs, nil +} diff --git a/tunnel_test.go b/tunnel_test.go index 7097d27..233e604 100644 --- a/tunnel_test.go +++ b/tunnel_test.go @@ -2,9 +2,16 @@ package main import ( "context" + "crypto/rand" + "encoding/hex" + "fmt" + "io" "net" + "strings" "testing" "time" + + "golang.org/x/crypto/curve25519" ) func TestNewMemoryTUN(t *testing.T) { @@ -251,49 +258,158 @@ func TestTunnel_DialWireGuard(t *testing.T) { } ourIP, _ := config.GetInterfaceIP() + var ( + gotNetwork string + gotAddress string + ) tunnel := &Tunnel{ ourIP: ourIP, config: config, router: NewRoutingEngine(config), + dialFn: func(ctx context.Context, network, address string) (net.Conn, error) { + gotNetwork = network + gotAddress = address + return nil, fmt.Errorf("dial blocked in test") + }, } - // Use a timeout context to prevent hanging on connection attempts - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - // Test dialing known WireGuard IPs (fallback mode) tests := []struct { name string host string port string + network string expectError bool + wantAddress string }{ - {"node-server-1", "10.150.0.2", "8080", false}, - {"node-server-2", "10.150.0.3", "8080", false}, - {"unknown WireGuard IP", "10.150.0.99", "8080", true}, + {"default-route target", "104.16.185.241", "443", "tcp", false, "104.16.185.241:443"}, + {"overlay target", "10.150.0.3", "8080", "tcp4", false, "10.150.0.3:8080"}, + {"no route", "2001:4860:4860::8888", "53", "udp6", true, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - conn, err := tunnel.DialWireGuard(ctx, "tcp", tt.host, tt.port) + gotNetwork = "" + gotAddress = "" + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + conn, err := tunnel.DialWireGuard(ctx, tt.network, tt.host, tt.port) if tt.expectError { - if err == nil { - t.Error("expected error but got none") - if conn != nil { - conn.Close() - } + if err == nil || err.Error() != fmt.Sprintf("no route to %s:%s", tt.host, tt.port) { + t.Fatalf("expected no-route error, got conn=%v err=%v", conn, err) } - } else { - // Note: This will likely fail in test environment since - // node-server-1 and node-server-2 don't exist, but we test - // that the function doesn't panic and handles the mapping - if err != nil { - // Expected in test environment - t.Logf("DialWireGuard failed as expected in test environment: %v", err) - } else if conn != nil { - conn.Close() + if gotAddress != "" || gotNetwork != "" { + t.Fatalf("dialer should not have been invoked on no-route, got network=%q address=%q", gotNetwork, gotAddress) } + return + } + + if err == nil || err.Error() != "dial blocked in test" { + t.Fatalf("expected test dialer error, got conn=%v err=%v", conn, err) + } + if gotNetwork != tt.network { + t.Fatalf("dial network = %q, want %q", gotNetwork, tt.network) + } + if gotAddress != tt.wantAddress { + t.Fatalf("dial address = %q, want %q", gotAddress, tt.wantAddress) + } + if conn != nil { + conn.Close() + } + }) + } +} + +func TestTunnel_DialContext(t *testing.T) { + tunnel := &Tunnel{ + dialFn: func(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" || address != "203.0.113.10:443" { + t.Fatalf("unexpected dial args: network=%q address=%q", network, address) + } + return nil, fmt.Errorf("dial blocked in test") + }, + } + + _, err := tunnel.DialContext(context.Background(), "tcp", "203.0.113.10:443") + if err == nil || err.Error() != "dial blocked in test" { + t.Fatalf("expected injected dialer error, got %v", err) + } +} + +func TestTunnel_DialContextRequiresDialer(t *testing.T) { + _, err := (&Tunnel{}).DialContext(context.Background(), "tcp", "203.0.113.10:443") + if err == nil || err.Error() != "WireGuard tunnel dialer is not initialized" { + t.Fatalf("expected missing dialer error, got %v", err) + } +} + +func TestTunnel_Listen(t *testing.T) { + tunnel := &Tunnel{ + listenFn: func(addr *net.TCPAddr) (net.Listener, error) { + if addr.Port != 8080 { + t.Fatalf("unexpected listen port: %d", addr.Port) + } + return nil, fmt.Errorf("listen blocked in test") + }, + } + + _, err := tunnel.Listen("tcp", ":8080") + if err == nil || err.Error() != "listen blocked in test" { + t.Fatalf("expected injected listen error, got %v", err) + } +} + +func TestTunnel_ListenRejectsInvalidAddress(t *testing.T) { + tunnel := &Tunnel{ + listenFn: func(addr *net.TCPAddr) (net.Listener, error) { + t.Fatal("listenFn should not be called when address resolution fails") + return nil, nil + }, + } + + _, err := tunnel.Listen("tcp", "not-a-valid-listen-address") + if err == nil || !strings.Contains(err.Error(), "failed to resolve listen address") { + t.Fatalf("expected listen address resolution error, got %v", err) + } +} + +func TestTunnel_ListenRejectsUnsupportedNetwork(t *testing.T) { + tunnel := &Tunnel{} + + _, err := tunnel.Listen("udp", ":8080") + if err == nil || err.Error() != `WireGuard tunnel listener is not initialized` { + t.Fatalf("expected uninitialized listener error, got %v", err) + } +} + +func TestTunnel_DialWireGuardRejectsInvalidInputs(t *testing.T) { + tunnel := &Tunnel{ + router: NewRoutingEngine(&WireGuardConfig{ + Peers: []PeerConfig{ + { + AllowedIPs: []string{"0.0.0.0/0"}, + }, + }, + }), + } + + tests := []struct { + name string + host string + port string + wantErr string + }{ + {name: "invalid-host", host: "example.com", port: "443", wantErr: "invalid IP address: example.com"}, + {name: "invalid-port", host: "203.0.113.10", port: "not-a-port", wantErr: "invalid port: not-a-port"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tunnel.DialWireGuard(context.Background(), "tcp", tt.host, tt.port) + if err == nil || err.Error() != tt.wantErr { + t.Fatalf("expected %q, got %v", tt.wantErr, err) } }) } @@ -591,3 +707,264 @@ func TestTunnel_Close(t *testing.T) { t.Error("TUN should be closed after tunnel close") } } + +func TestTunnel_EndToEndTCPAcrossWireGuard(t *testing.T) { + serverPriv, serverPub := mustGenerateWireGuardKeyPair(t) + clientPriv, clientPub := mustGenerateWireGuardKeyPair(t) + + serverUDP, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to reserve server UDP port: %v", err) + } + serverPort := serverUDP.LocalAddr().(*net.UDPAddr).Port + _ = serverUDP.Close() + + serverConfig := &WireGuardConfig{ + Interface: InterfaceConfig{ + PrivateKey: serverPriv, + Address: "10.150.0.1/24", + ListenPort: serverPort, + }, + Peers: []PeerConfig{ + { + PublicKey: clientPub, + AllowedIPs: []string{"10.150.0.2/32"}, + }, + }, + } + + clientConfig := &WireGuardConfig{ + Interface: InterfaceConfig{ + PrivateKey: clientPriv, + Address: "10.150.0.2/24", + }, + Peers: []PeerConfig{ + { + PublicKey: serverPub, + Endpoint: fmt.Sprintf("127.0.0.1:%d", serverPort), + AllowedIPs: []string{"10.150.0.0/24"}, + PersistentKeepalive: 1, + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + serverTunnel, err := NewTunnel(ctx, serverConfig) + if err != nil { + t.Fatalf("failed to create server tunnel: %v", err) + } + defer serverTunnel.Close() + + clientTunnel, err := NewTunnel(ctx, clientConfig) + if err != nil { + t.Fatalf("failed to create client tunnel: %v", err) + } + defer clientTunnel.Close() + + listener, err := serverTunnel.Listen("tcp", ":8080") + if err != nil { + t.Fatalf("failed to listen over server tunnel: %v", err) + } + defer listener.Close() + + serverErrCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + serverErrCh <- fmt.Errorf("accept failed: %w", err) + return + } + defer conn.Close() + + buf := make([]byte, 32) + n, err := conn.Read(buf) + if err != nil { + serverErrCh <- fmt.Errorf("server read failed: %w", err) + return + } + if string(buf[:n]) != "ping-over-wrapguard" { + serverErrCh <- fmt.Errorf("unexpected payload %q", string(buf[:n])) + return + } + + if _, err := io.WriteString(conn, "pong-from-peer"); err != nil { + serverErrCh <- fmt.Errorf("server write failed: %w", err) + return + } + + serverErrCh <- nil + }() + + var clientConn net.Conn + deadline := time.Now().Add(6 * time.Second) + for { + clientConn, err = clientTunnel.DialWireGuard(ctx, "tcp", "10.150.0.1", "8080") + if err == nil { + break + } + if time.Now().After(deadline) { + t.Fatalf("failed to dial peer over WireGuard: %v", err) + } + time.Sleep(150 * time.Millisecond) + } + defer clientConn.Close() + + if _, err := io.WriteString(clientConn, "ping-over-wrapguard"); err != nil { + t.Fatalf("client write failed: %v", err) + } + + reply := make([]byte, 32) + n, err := clientConn.Read(reply) + if err != nil { + t.Fatalf("client read failed: %v", err) + } + if string(reply[:n]) != "pong-from-peer" { + t.Fatalf("unexpected reply %q", string(reply[:n])) + } + + select { + case err := <-serverErrCh: + if err != nil { + t.Fatal(err) + } + case <-time.After(2 * time.Second): + t.Fatal("server handler did not complete") + } +} + +func TestNormalizeNetworkProtocol(t *testing.T) { + tests := []struct { + input string + want string + }{ + {input: "tcp", want: "tcp"}, + {input: "tcp4", want: "tcp"}, + {input: "tcp6", want: "tcp"}, + {input: "udp", want: "udp"}, + {input: "udp6", want: "udp"}, + {input: "ping", want: "ping"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + if got := normalizeNetworkProtocol(tt.input); got != tt.want { + t.Fatalf("normalizeNetworkProtocol(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestTunnelDialWireGuardNormalizesPolicyProtocol(t *testing.T) { + config := &WireGuardConfig{ + Interface: InterfaceConfig{ + Address: "10.150.0.2/24", + }, + Peers: []PeerConfig{ + { + PublicKey: "policy-peer", + Endpoint: "policy.example.com:51820", + AllowedIPs: []string{"10.200.0.0/16"}, + RoutingPolicies: []RoutingPolicy{ + { + DestinationCIDR: "203.0.113.0/24", + Protocol: "tcp", + PortRange: PortRange{Start: 443, End: 443}, + Priority: 100, + }, + }, + }, + }, + } + + var ( + gotNetwork string + gotAddress string + ) + tunnel := &Tunnel{ + config: config, + router: NewRoutingEngine(config), + dialFn: func(ctx context.Context, network, address string) (net.Conn, error) { + gotNetwork = network + gotAddress = address + return nil, fmt.Errorf("dial blocked in test") + }, + } + + _, err := tunnel.DialWireGuard(context.Background(), "tcp4", "203.0.113.10", "443") + if err == nil || err.Error() != "dial blocked in test" { + t.Fatalf("expected test dialer error, got %v", err) + } + if gotNetwork != "tcp4" { + t.Fatalf("dial network = %q, want tcp4", gotNetwork) + } + if gotAddress != "203.0.113.10:443" { + t.Fatalf("dial address = %q, want 203.0.113.10:443", gotAddress) + } +} + +func TestTunnel_DialWireGuardRejectsHostnames(t *testing.T) { + tunnel := &Tunnel{ + router: NewRoutingEngine(&WireGuardConfig{ + Peers: []PeerConfig{ + { + AllowedIPs: []string{"0.0.0.0/0"}, + }, + }, + }), + } + + _, err := tunnel.DialWireGuard(context.Background(), "tcp", "example.com", "443") + if err == nil || err.Error() != "invalid IP address: example.com" { + t.Fatalf("expected hostname rejection, got %v", err) + } +} + +func TestParseDNSAddrs(t *testing.T) { + got, err := parseDNSAddrs([]string{"8.8.8.8", " 2001:4860:4860::8888 "}) + if err != nil { + t.Fatalf("parseDNSAddrs returned error: %v", err) + } + if len(got) != 2 { + t.Fatalf("expected 2 DNS addresses, got %d", len(got)) + } + if got[0].String() != "8.8.8.8" || got[1].String() != "2001:4860:4860::8888" { + t.Fatalf("unexpected DNS addresses: %v", got) + } +} + +func TestParseDNSAddrsEmptyInput(t *testing.T) { + got, err := parseDNSAddrs(nil) + if err != nil { + t.Fatalf("parseDNSAddrs returned error for empty input: %v", err) + } + if got != nil { + t.Fatalf("expected nil DNS addrs for empty input, got %v", got) + } +} + +func TestParseDNSAddrsInvalid(t *testing.T) { + if _, err := parseDNSAddrs([]string{"not-an-ip"}); err == nil { + t.Fatal("expected invalid DNS address error") + } +} + +func mustGenerateWireGuardKeyPair(t *testing.T) (privateHex, publicHex string) { + t.Helper() + + var privateKey [32]byte + if _, err := rand.Read(privateKey[:]); err != nil { + t.Fatalf("failed to generate private key: %v", err) + } + + privateKey[0] &= 248 + privateKey[31] = (privateKey[31] & 127) | 64 + + publicKey, err := curve25519.X25519(privateKey[:], curve25519.Basepoint) + if err != nil { + t.Fatalf("failed to derive public key: %v", err) + } + + return hex.EncodeToString(privateKey[:]), hex.EncodeToString(publicKey) +} diff --git a/wget-log b/wget-log new file mode 100644 index 0000000..1846b4e --- /dev/null +++ b/wget-log @@ -0,0 +1,11 @@ +--2026-03-23 22:23:25-- https://icanhazip.com/ +Resolving icanhazip.com (icanhazip.com)... 104.16.185.241, 104.16.184.241, 2606:4700::6810:b8f1, ... +Connecting to icanhazip.com (icanhazip.com)|104.16.185.241|:443... connected. +HTTP request sent, awaiting response... 200 OK +Length: 14 [text/plain] +Saving to: ‘index.html’ + + index.html 0%[ ] 0 --.-KB/s index.html 100%[==========================================================>] 14 --.-KB/s in 0s + +2026-03-23 22:23:26 (1.67 MB/s) - ‘index.html’ saved [14/14] +