From cd246acf1e285b14ec35c20bb961e4c0fc384054 Mon Sep 17 00:00:00 2001 From: Akihiro Suda Date: Sat, 28 Feb 2026 03:48:44 +0900 Subject: [PATCH] Preserve real client source IP in builtin port driver via IP_TRANSPARENT Use IP_TRANSPARENT socket option in the child process to bind outgoing connections to the real client IP:port, so backend services see the original source address instead of 127.0.0.1. This leverages CAP_NET_ADMIN in the user namespace (Linux 4.18+) and policy routing to complete TCP handshakes without iptables. Falls back gracefully to normal dial on older kernels or when routing setup fails. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Akihiro Suda --- pkg/port/builtin/builtin_test.go | 1 + pkg/port/builtin/child/child.go | 68 +++++++- pkg/port/builtin/msg/msg.go | 18 ++- pkg/port/builtin/parent/tcp/tcp.go | 2 +- pkg/port/builtin/parent/udp/udp.go | 2 +- pkg/port/testsuite/testsuite.go | 249 ++++++++++++++++++++++++++++- 6 files changed, 325 insertions(+), 15 deletions(-) diff --git a/pkg/port/builtin/builtin_test.go b/pkg/port/builtin/builtin_test.go index 24f1fc84..f5394eff 100644 --- a/pkg/port/builtin/builtin_test.go +++ b/pkg/port/builtin/builtin_test.go @@ -29,4 +29,5 @@ func TestBuiltIn(t *testing.T) { return d } testsuite.Run(t, pf) + testsuite.RunTCPTransparent(t, pf) } diff --git a/pkg/port/builtin/child/child.go b/pkg/port/builtin/child/child.go index f75bcc40..fb070d18 100644 --- a/pkg/port/builtin/child/child.go +++ b/pkg/port/builtin/child/child.go @@ -6,8 +6,11 @@ import ( "io" "net" "os" + "os/exec" "strconv" "strings" + "sync" + "syscall" "golang.org/x/sys/unix" @@ -25,7 +28,8 @@ func NewDriver(logWriter io.Writer) port.ChildDriver { } type childDriver struct { - logWriter io.Writer + logWriter io.Writer + routingSetup sync.Once } func (d *childDriver) RunChildDriver(opaque map[string]string, quit <-chan struct{}, detachedNetNSPath string) error { @@ -119,7 +123,6 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req 
*msg.Request) er } // dialProto does not need "4", "6" suffix dialProto := strings.TrimSuffix(strings.TrimSuffix(req.Proto, "6"), "4") - var dialer net.Dialer ip := req.IP if ip == "" { ip = "127.0.0.1" @@ -135,9 +138,24 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req *msg.Request) er } ip = p.String() } - targetConn, err := dialer.Dial(dialProto, net.JoinHostPort(ip, strconv.Itoa(req.Port))) - if err != nil { - return err + targetAddr := net.JoinHostPort(ip, strconv.Itoa(req.Port)) + + var targetConn net.Conn + var err error + if req.SourceIP != "" && req.SourcePort != 0 && dialProto == "tcp" { + d.routingSetup.Do(func() { d.setupTransparentRouting() }) + targetConn, err = transparentDial(dialProto, targetAddr, req.SourceIP, req.SourcePort) + if err != nil { + fmt.Fprintf(d.logWriter, "transparent dial failed, falling back: %v\n", err) + targetConn, err = nil, nil + } + } + if targetConn == nil { + var dialer net.Dialer + targetConn, err = dialer.Dial(dialProto, targetAddr) + if err != nil { + return err + } } defer targetConn.Close() // no effect on duplicated FD targetConnFiler, ok := targetConn.(filer) @@ -164,6 +182,46 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req *msg.Request) er return err } +// setupTransparentRouting sets up policy routing so that SYN-ACK packets +// from services to transparent-bound source IPs are routed back via loopback. +// This is safe because the "from 127.0.0.0/8" rule only matches loopback-sourced +// packets, leaving TAP traffic unaffected. 
+func (d *childDriver) setupTransparentRouting() { + cmds := [][]string{ + {"ip", "route", "add", "local", "default", "dev", "lo", "table", "100"}, + {"ip", "rule", "add", "from", "127.0.0.0/8", "lookup", "100", "priority", "100"}, + {"ip", "-6", "route", "add", "local", "default", "dev", "lo", "table", "100"}, + {"ip", "-6", "rule", "add", "from", "::1/128", "lookup", "100", "priority", "100"}, + } + for _, args := range cmds { + if out, err := exec.Command(args[0], args[1:]...).CombinedOutput(); err != nil { + fmt.Fprintf(d.logWriter, "transparent routing setup: %v: %s\n", err, out) + } + } +} + +// transparentDial dials targetAddr using IP_TRANSPARENT, binding to the given +// source IP and port so the backend service sees the real client address. +func transparentDial(dialProto, targetAddr, sourceIP string, sourcePort int) (net.Conn, error) { + dialer := net.Dialer{ + LocalAddr: &net.TCPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort}, + Control: func(network, address string, c syscall.RawConn) error { + var sockErr error + if err := c.Control(func(fd uintptr) { + if strings.Contains(network, "6") { + sockErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_TRANSPARENT, 1) + } else { + sockErr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_TRANSPARENT, 1) + } + }); err != nil { + return err + } + return sockErr + }, + } + return dialer.Dial(dialProto, targetAddr) +} + // filer is implemented by *net.TCPConn and *net.UDPConn type filer interface { File() (f *os.File, err error) diff --git a/pkg/port/builtin/msg/msg.go b/pkg/port/builtin/msg/msg.go index 9d6a8a78..aef2437d 100644 --- a/pkg/port/builtin/msg/msg.go +++ b/pkg/port/builtin/msg/msg.go @@ -25,6 +25,8 @@ type Request struct { Port int ParentIP string HostGatewayIP string + SourceIP string `json:",omitempty"` // real client IP for IP_TRANSPARENT + SourcePort int `json:",omitempty"` // real client port for IP_TRANSPARENT } // Reply may contain FD as OOB @@ -69,7 +71,9 @@ func hostGatewayIP() 
string { // ConnectToChild connects to the child UNIX socket, and obtains TCP or UDP socket FD // that corresponds to the port spec. -func ConnectToChild(c *net.UnixConn, spec port.Spec) (int, error) { +// sourceAddr is the real client address (e.g., from net.Conn.RemoteAddr()) for IP_TRANSPARENT support. +// Pass nil to skip source IP preservation. +func ConnectToChild(c *net.UnixConn, spec port.Spec, sourceAddr net.Addr) (int, error) { req := Request{ Type: RequestTypeConnect, Proto: spec.Proto, @@ -78,6 +82,10 @@ func ConnectToChild(c *net.UnixConn, spec port.Spec) (int, error) { ParentIP: spec.ParentIP, HostGatewayIP: hostGatewayIP(), } + if tcpAddr, ok := sourceAddr.(*net.TCPAddr); ok && tcpAddr != nil { + req.SourceIP = tcpAddr.IP.String() + req.SourcePort = tcpAddr.Port + } if _, err := lowlevelmsgutil.MarshalToWriter(c, &req); err != nil { return 0, err } @@ -114,7 +122,7 @@ func ConnectToChild(c *net.UnixConn, spec port.Spec) (int, error) { } // ConnectToChildWithSocketPath wraps ConnectToChild -func ConnectToChildWithSocketPath(socketPath string, spec port.Spec) (int, error) { +func ConnectToChildWithSocketPath(socketPath string, spec port.Spec, sourceAddr net.Addr) (int, error) { var dialer net.Dialer conn, err := dialer.Dial("unix", socketPath) if err != nil { @@ -122,13 +130,13 @@ func ConnectToChildWithSocketPath(socketPath string, spec port.Spec) (int, error } defer conn.Close() c := conn.(*net.UnixConn) - return ConnectToChild(c, spec) + return ConnectToChild(c, spec, sourceAddr) } // ConnectToChildWithRetry retries ConnectToChild every (i*5) milliseconds. 
-func ConnectToChildWithRetry(socketPath string, spec port.Spec, retries int) (int, error) { +func ConnectToChildWithRetry(socketPath string, spec port.Spec, retries int, sourceAddr net.Addr) (int, error) { for i := 0; i < retries; i++ { - fd, err := ConnectToChildWithSocketPath(socketPath, spec) + fd, err := ConnectToChildWithSocketPath(socketPath, spec, sourceAddr) if i == retries-1 && err != nil { return 0, err } diff --git a/pkg/port/builtin/parent/tcp/tcp.go b/pkg/port/builtin/parent/tcp/tcp.go index ddf73d8b..e3053ffd 100644 --- a/pkg/port/builtin/parent/tcp/tcp.go +++ b/pkg/port/builtin/parent/tcp/tcp.go @@ -59,7 +59,7 @@ func Run(socketPath string, spec port.Spec, stopCh <-chan struct{}, stoppedCh ch func copyConnToChild(c net.Conn, socketPath string, spec port.Spec, stopCh <-chan struct{}) error { defer c.Close() // get fd from the child as an SCM_RIGHTS cmsg - fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10) + fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, c.RemoteAddr()) if err != nil { return err } diff --git a/pkg/port/builtin/parent/udp/udp.go b/pkg/port/builtin/parent/udp/udp.go index 85126de7..2bcd0637 100644 --- a/pkg/port/builtin/parent/udp/udp.go +++ b/pkg/port/builtin/parent/udp/udp.go @@ -26,7 +26,7 @@ func Run(socketPath string, spec port.Spec, stopCh <-chan struct{}, stoppedCh ch Listener: c, BackendDial: func() (*net.UDPConn, error) { // get fd from the child as an SCM_RIGHTS cmsg - fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10) + fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, nil) if err != nil { return nil, err } diff --git a/pkg/port/testsuite/testsuite.go b/pkg/port/testsuite/testsuite.go index 8e4f89c9..f25afca1 100644 --- a/pkg/port/testsuite/testsuite.go +++ b/pkg/port/testsuite/testsuite.go @@ -1,6 +1,7 @@ package testsuite import ( + "bufio" "bytes" "context" "encoding/json" @@ -21,9 +22,10 @@ import ( ) const ( - reexecKeyMode = "rootlesskit-port-testsuite.mode" - 
reexecKeyOpaque = "rootlesskit-port-testsuite.opaque" - reexecKeyQuitFD = "rootlesskit-port-testsuite.quitfd" + reexecKeyMode = "rootlesskit-port-testsuite.mode" + reexecKeyOpaque = "rootlesskit-port-testsuite.opaque" + reexecKeyQuitFD = "rootlesskit-port-testsuite.quitfd" + reexecKeyEchoPort = "rootlesskit-port-testsuite.echoport" ) func Main(m *testing.M, cf func() port.ChildDriver) { @@ -31,6 +33,9 @@ func Main(m *testing.M, cf func() port.ChildDriver) { case "": os.Exit(m.Run()) case "child": + case "echoserver": + runEchoServer() + os.Exit(0) default: panic(fmt.Errorf("unknown mode: %q", mode)) } @@ -360,3 +365,241 @@ func isAddrInUse(err error) bool { return strings.Contains(msg, "address already in use") || strings.Contains(msg, "port is busy") } + +// runEchoServer is a re-exec mode that runs a minimal TCP server. +// It listens on 127.0.0.1:<port>, signals readiness by closing fd 3, +// accepts one connection, writes the remote address to stdout, and drains input. +func runEchoServer() { + portStr := os.Getenv(reexecKeyEchoPort) + if portStr == "" { + panic("echoserver: missing " + reexecKeyEchoPort) + } + ln, err := net.Listen("tcp", "127.0.0.1:"+portStr) + if err != nil { + panic(fmt.Errorf("echoserver: listen: %w", err)) + } + defer ln.Close() + // Signal readiness by closing fd 3 + readyW := os.NewFile(3, "ready") + readyW.Close() + + conn, err := ln.Accept() + if err != nil { + panic(fmt.Errorf("echoserver: accept: %w", err)) + } + defer conn.Close() + fmt.Fprintln(os.Stdout, conn.RemoteAddr().String()) + io.Copy(io.Discard, conn) +} + +func RunTCPTransparent(t *testing.T, pf func() port.ParentDriver) { + t.Run("TestTCPTransparent", func(t *testing.T) { TestTCPTransparent(t, pf()) }) +} + +func TestTCPTransparent(t *testing.T, d port.ParentDriver) { + ensureDeps(t, "nsenter") + t.Logf("creating USER+NET namespace") + opaque := d.OpaqueForChild() + opaqueJSON, err := json.Marshal(opaque) + if err != nil { + t.Fatal(err) + } + pr, pw, err := os.Pipe() + if 
err != nil { + t.Fatal(err) + } + cmd := exec.Command("/proc/self/exe") + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + cmd.Env = append([]string{ + reexecKeyMode + "=child", + reexecKeyOpaque + "=" + string(opaqueJSON), + reexecKeyQuitFD + "=3"}, os.Environ()...) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGKILL, + Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNET, + UidMappings: []syscall.SysProcIDMap{ + { + ContainerID: 0, + HostID: os.Geteuid(), + Size: 1, + }, + }, + GidMappings: []syscall.SysProcIDMap{ + { + ContainerID: 0, + HostID: os.Getegid(), + Size: 1, + }, + }, + } + cmd.ExtraFiles = []*os.File{pr} + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + defer func() { + pw.Close() + cmd.Wait() + }() + childPID := cmd.Process.Pid + if out, err := nsenterExec(childPID, "ip", "link", "set", "lo", "up"); err != nil { + t.Fatalf("%v, out=%s", err, string(out)) + } + testTCPTransparentWithPID(t, d, childPID) +} + +func testTCPTransparentWithPID(t *testing.T, d port.ParentDriver, childPID int) { + ensureDeps(t, "nsenter") + const childPort = 80 + + // Start parent driver + initComplete := make(chan struct{}) + quit := make(chan struct{}) + driverErr := make(chan error) + go func() { + cctx := &port.ChildContext{ + IP: nil, + } + driverErr <- d.RunParentDriver(initComplete, quit, cctx) + }() + select { + case <-initComplete: + case err := <-driverErr: + t.Fatal(err) + } + + // Start echo server inside the child namespace + exe, err := os.Executable() + if err != nil { + t.Fatal(err) + } + + // Pipe for readiness signaling (fd 3 in the echo server) + readyR, readyW, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + + // Pipe for capturing stdout (the remote address) + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + + echoCmd := exec.Command("nsenter", "-U", "--preserve-credential", "-n", + "-t", strconv.Itoa(childPID), + exe) + echoCmd.Env = append([]string{ + reexecKeyMode + "=echoserver", + 
reexecKeyEchoPort + "=" + strconv.Itoa(childPort), + }, os.Environ()...) + echoCmd.Stdout = stdoutW + echoCmd.Stderr = os.Stderr + echoCmd.ExtraFiles = []*os.File{readyW} // fd 3 + echoCmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGKILL, + } + if err := echoCmd.Start(); err != nil { + t.Fatal(err) + } + defer echoCmd.Process.Kill() + readyW.Close() + + // Wait for echo server readiness + io.ReadAll(readyR) + readyR.Close() + + // Close the write end of stdout pipe in parent so reads see EOF when echo server exits + stdoutW.Close() + + // Allocate a parent port and add port forwarding + parentPort, err := allocateAvailablePort("tcp") + if err != nil { + t.Fatal(err) + } + + var portStatus *port.Status + const maxAttempts = 10 + for attempt := 0; attempt < maxAttempts; attempt++ { + portStatus, err = d.AddPort(context.TODO(), + port.Spec{ + Proto: "tcp", + ParentIP: "127.0.0.1", + ParentPort: parentPort, + ChildPort: childPort, + }) + if err == nil { + break + } + if attempt == maxAttempts-1 || !isAddrInUse(err) { + t.Fatal(err) + } + parentPort, err = allocateAvailablePort("tcp") + if err != nil { + t.Fatal(err) + } + } + t.Logf("opened port: %+v", portStatus) + + // Dial the parent port + var conn net.Conn + for i := 0; i < 5; i++ { + var dialer net.Dialer + conn, err = dialer.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", parentPort)) + if err == nil { + break + } + time.Sleep(time.Duration(i*5) * time.Millisecond) + } + if err != nil { + t.Fatal(err) + } + + clientAddr := conn.LocalAddr().String() + t.Logf("client local address: %s", clientAddr) + + // Send data and close write side + if _, err := conn.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + if err := conn.(*net.TCPConn).CloseWrite(); err != nil { + t.Fatal(err) + } + + // Read the remote address the echo server saw + scanner := bufio.NewScanner(stdoutR) + if !scanner.Scan() { + t.Fatal("failed to read remote address from echo server") + } + serverSawAddr := scanner.Text() + 
t.Logf("server saw remote address: %s", serverSawAddr) + + conn.Close() + echoCmd.Wait() + + // Parse and verify: the echo server should see the client's IP and port + clientHost, clientPortStr, err := net.SplitHostPort(clientAddr) + if err != nil { + t.Fatalf("failed to parse client address %q: %v", clientAddr, err) + } + serverHost, serverPortStr, err := net.SplitHostPort(serverSawAddr) + if err != nil { + t.Fatalf("failed to parse server-seen address %q: %v", serverSawAddr, err) + } + + if clientHost != serverHost { + t.Errorf("IP mismatch: client=%s, server saw=%s", clientHost, serverHost) + } + if clientPortStr != serverPortStr { + t.Errorf("port mismatch: client=%s, server saw=%s", clientPortStr, serverPortStr) + } + + // Cleanup + if err := d.RemovePort(context.TODO(), portStatus.ID); err != nil { + t.Fatal(err) + } + quit <- struct{}{} + if err := <-driverErr; err != nil { + t.Fatal(err) + } +}