diff --git a/README.md b/README.md index 5c774e0e..03c8d524 100644 --- a/README.md +++ b/README.md @@ -50,9 +50,9 @@ with a discount using [this referral link](https://iproyal.com/?r=795836)! 🚀 ``` ```bash -usage: wireproxy [-h|--help] [-c|--config ""] [-s|--silent] - [-d|--daemon] [-i|--info ""] [-v|--version] - [-n|--configtest] +usage: wireproxy [-h|--help] [-c|--config ""] [-s|--silent] + [-v|--verbose] [-d|--daemon] [-i|--info ""] + [-V|--version] [-n|--configtest] Userspace wireguard client for proxying @@ -61,10 +61,11 @@ Arguments: -h --help Print help information -c --config Path of configuration file Default paths: /etc/wireproxy/wireproxy.conf, $HOME/.config/wireproxy.conf - -s --silent Silent mode + -s --silent Logging: Set silent mode + -v --verbose Logging: Set verbose mode -d --daemon Make wireproxy run in background -i --info Specify the address and port for exposing health status - -v --version Print version + -V --version Print version -n --configtest Configtest mode. Only check the configuration file for validity. ``` diff --git a/cmd/wireproxy/main.go b/cmd/wireproxy/main.go index 1a383a72..fbf9f96e 100644 --- a/cmd/wireproxy/main.go +++ b/cmd/wireproxy/main.go @@ -176,10 +176,12 @@ func main() { parser := argparse.NewParser("wireproxy", "Userspace wireguard client for proxying") config := parser.String("c", "config", &argparse.Options{Help: "Path of configuration file"}) - silent := parser.Flag("s", "silent", &argparse.Options{Help: "Silent mode"}) + silent := parser.Flag("s", "silent", &argparse.Options{Help: "Logging: Set silent mode"}) + verbose := parser.Flag("v", "verbose", &argparse.Options{Help: "Logging: Set verbose mode"}) + daemon := parser.Flag("d", "daemon", &argparse.Options{Help: "Make wireproxy run in background"}) info := parser.String("i", "info", &argparse.Options{Help: "Specify the address and port for exposing health status"}) - printVerison := parser.Flag("v", "version", &argparse.Options{Help: "Print version"}) + printVerison := parser.Flag("V", "version", &argparse.Options{Help: "Print version"}) configTest := parser.Flag("n", "configtest", &argparse.Options{Help: "Configtest mode. Only check the configuration file for validity."}) err := parser.Parse(args) @@ -238,14 +240,18 @@ func main() { // https://github.com/WireGuard/wireguard-go/blob/master/device/logger.go#L39 // so redirect STDOUT to STDERR, we don't want to print anything to STDOUT anyways os.Stdout = os.NewFile(uintptr(syscall.Stderr), "/dev/stderr") - logLevel := device.LogLevelVerbose + logLevel := device.LogLevelError if *silent { logLevel = device.LogLevelSilent + + } else if *verbose { + logLevel = device.LogLevelVerbose } lock("ready") + var errorLogger = device.NewLogger(logLevel, "wireproxy - ") - tun, err := wireproxy.StartWireguard(conf.Device, logLevel) + tun, err := wireproxy.StartWireguard(conf.Device, errorLogger) if err != nil { log.Fatal(err) } diff --git a/http.go b/http.go index 88a7ef48..1636adbc 100644 --- a/http.go +++ b/http.go @@ -6,7 +6,8 @@ import ( "encoding/base64" "fmt" "io" - "log" + + //"log" "net" "net/http" "strings" @@ -21,6 +22,7 @@ type HTTPServer struct { dial func(network, address string) (net.Conn, error) authRequired bool + vtun *VirtualTun } func (s *HTTPServer) authenticate(req *http.Request) (int, error) { @@ -55,6 +57,7 @@ func (s *HTTPServer) handleConn(req *http.Request, conn net.Conn) (peer net.Conn addr = net.JoinHostPort(addr, port) } + s.vtun.logger.Verbosef("Got HTTP Connect to %s", addr) peer, err = s.dial("tcp", addr) if err != nil { return peer, fmt.Errorf("tun tcp dial failed: %w", err) @@ -76,6 +79,7 @@ func (s *HTTPServer) handle(req *http.Request) (peer net.Conn, err error) { addr = net.JoinHostPort(addr, port) } + s.vtun.logger.Verbosef("Got HTTP GET to %s", addr) peer, err = s.dial("tcp", addr) if err != nil { return peer, fmt.Errorf("tun tcp dial failed: %w", err) @@ -95,7 +99,7 @@ func (s *HTTPServer) serve(conn net.Conn) { var rd = bufio.NewReader(conn) req, err := http.ReadRequest(rd) if err != nil { - log.Printf("read request failed: %s\n", err) + s.vtun.logger.Errorf("read request failed: %s", err) return } @@ -106,7 +110,7 @@ func (s *HTTPServer) serve(conn net.Conn) { resp.Header.Set("Proxy-Authenticate", "Basic realm=\"Proxy\"") } _ = resp.Write(conn) - log.Println(err) + s.vtun.logger.Errorf("authenticate failed: %s", err) return } @@ -118,15 +122,15 @@ func (s *HTTPServer) serve(conn net.Conn) { peer, err = s.handle(req) default: _ = responseWith(req, http.StatusMethodNotAllowed).Write(conn) - log.Printf("unsupported protocol: %s\n", req.Method) + s.vtun.logger.Errorf("unsupported protocol: %s", req.Method) return } if err != nil { - log.Printf("dial proxy failed: %s\n", err) + s.vtun.logger.Errorf("dial proxy failed: %s", err) return } if peer == nil { - log.Println("dial proxy failed: peer nil") + s.vtun.logger.Errorf("dial proxy failed: peer nil") return } @@ -149,7 +153,8 @@ func (s *HTTPServer) serve(conn net.Conn) { func (s *HTTPServer) ListenAndServe(network, addr string) error { server, err := net.Listen(network, addr) if err != nil { - return fmt.Errorf("listen tcp failed: %w", err) + s.vtun.logger.Errorf("listen tcp failed: %w", err) + return err } defer func(server net.Listener) { _ = server.Close() @@ -157,7 +162,8 @@ func (s *HTTPServer) ListenAndServe(network, addr string) error { for { conn, err := server.Accept() if err != nil { - return fmt.Errorf("accept request failed: %w", err) + s.vtun.logger.Errorf("accept request failed: %w", err) + return err } go func(conn net.Conn) { s.serve(conn) diff --git a/routine.go b/routine.go index 855758a6..99536b12 100644 --- a/routine.go +++ b/routine.go @@ -34,7 +34,7 @@ import ( ) // errorLogger is the logger to print error message -var errorLogger = log.New(os.Stderr, "ERROR: ", log.LstdFlags) +//var errorLogger = log.New(os.Stderr, "ERROR: ", log.LstdFlags) // CredentialValidator stores the authentication data of a socks5 proxy type CredentialValidator struct { @@ -51,6 +51,7 @@ type VirtualTun struct { // PingRecord stores the last time an IP was pinged PingRecord map[string]uint64 PingRecordLock *sync.Mutex + logger *device.Logger } // RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed @@ -111,7 +112,7 @@ func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context, if err != nil { return nil, nil, err } - + d.logger.Verbosef("Resolved %s to %s\n", name, addr) return ctx, addr.AsSlice(), nil } @@ -170,6 +171,7 @@ func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) { config: config, dial: vt.Tnet.Dial, auth: CredentialValidator{config.Username, config.Password}, + vtun: vt, } if config.Username != "" || config.Password != "" { server.authRequired = true @@ -189,13 +191,13 @@ func (c CredentialValidator) Valid(username, password string) bool { } // connForward copy data from `from` to `to` -func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser) { +func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser, logger *device.Logger) { defer from.Close() defer to.Close() _, err := io.Copy(to, from) if err != nil { - errorLogger.Printf("Cannot forward traffic: %s\n", err.Error()) + logger.Errorf("Cannot forward traffic: %s\n", err.Error()) } } @@ -203,7 +205,7 @@ func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser) { func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { target, err := vt.resolveToAddrPort(raddr) if err != nil { - errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error()) + vt.logger.Errorf("TCP Server Tunnel to %s: %s\n", target, err.Error()) return } @@ -211,38 +213,38 @@ func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { sconn, err := vt.Tnet.DialTCP(tcpAddr) if err != nil { - errorLogger.Printf("TCP Client Tunnel to %s: %s\n", target, err.Error()) + vt.logger.Errorf("TCP Client Tunnel to %s: %s\n", target, err.Error()) return } - go connForward(sconn, conn) - go connForward(conn, sconn) + go connForward(sconn, conn, vt.logger) + go connForward(conn, sconn, vt.logger) } // STDIOTcpForward starts a new connection via wireguard and forward traffic from `conn` func STDIOTcpForward(vt *VirtualTun, raddr *addressPort) { target, err := vt.resolveToAddrPort(raddr) if err != nil { - errorLogger.Printf("Name resolution error for %s: %s\n", raddr.address, err.Error()) + vt.logger.Errorf("Name resolution error for %s: %s\n", raddr.address, err.Error()) return } // os.Stdout has previously been remapped to stderr, se we can't use it stdout, err := os.OpenFile("/dev/stdout", os.O_WRONLY, 0) if err != nil { - errorLogger.Printf("Failed to open /dev/stdout: %s\n", err.Error()) + vt.logger.Errorf("Failed to open /dev/stdout: %s\n", err.Error()) return } tcpAddr := TCPAddrFromAddrPort(*target) sconn, err := vt.Tnet.DialTCP(tcpAddr) if err != nil { - errorLogger.Printf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error()) + vt.logger.Errorf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error()) return } - go connForward(os.Stdin, sconn) - go connForward(sconn, stdout) + go connForward(os.Stdin, sconn, vt.logger) + go connForward(sconn, stdout, vt.logger) } // SpawnRoutine spawns a local TCP server which acts as a proxy to the specified target @@ -280,7 +282,7 @@ func (conf *STDIOTunnelConfig) SpawnRoutine(vt *VirtualTun) { func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { target, err := vt.resolveToAddrPort(raddr) if err != nil { - errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error()) + vt.logger.Errorf("TCP Server Tunnel to %s: %s\n", target, err.Error()) return } @@ -288,12 +290,12 @@ func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) { sconn, err := net.DialTCP("tcp", nil, tcpAddr) if err != nil { - errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error()) + vt.logger.Errorf("TCP Server Tunnel to %s: %s\n", target, err.Error()) return } - go connForward(sconn, conn) - go connForward(conn, sconn) + go connForward(sconn, conn, vt.logger) + go connForward(conn, sconn, vt.logger) } @@ -320,12 +322,12 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) { } func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) { - log.Printf("Health metric request: %s\n", r.URL.Path) + d.logger.Verbosef("Health metric request: %s\n", r.URL.Path) switch path.Clean(r.URL.Path) { case "/readyz": body, err := json.Marshal(d.PingRecord) if err != nil { - errorLogger.Printf("Failed to get device metrics: %s\n", err.Error()) + d.logger.Errorf("Failed to get device metrics: %s\n", err.Error()) w.WriteHeader(http.StatusInternalServerError) return } @@ -346,7 +348,7 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/metrics": get, err := d.Dev.IpcGet() if err != nil { - errorLogger.Printf("Failed to get device metrics: %s\n", err.Error()) + d.logger.Errorf("Failed to get device metrics: %s\n", err.Error()) w.WriteHeader(http.StatusInternalServerError) return } @@ -377,7 +379,7 @@ func (d VirtualTun) pingIPs() { for _, addr := range d.Conf.CheckAlive { socket, err := d.Tnet.Dial("ping", addr.String()) if err != nil { - errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error()) + d.logger.Errorf("Failed to ping %s: %s\n", addr, err.Error()) continue } @@ -395,14 +397,14 @@ func (d VirtualTun) pingIPs() { } else if addr.Is6() { icmpBytes, _ = (&icmp.Message{Type: ipv6.ICMPTypeEchoRequest, Code: 0, Body: &requestPing}).Marshal(nil) } else { - errorLogger.Printf("Failed to ping %s: invalid address: %s\n", addr, addr.String()) + d.logger.Errorf("Failed to ping %s: invalid address: %s\n", addr, addr.String()) continue } _ = socket.SetReadDeadline(time.Now().Add(time.Duration(d.Conf.CheckAliveInterval) * time.Second)) _, err = socket.Write(icmpBytes) if err != nil { - errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error()) + d.logger.Errorf("Failed to ping %s: %s\n", addr, err.Error()) continue } @@ -410,24 +412,24 @@ func (d VirtualTun) pingIPs() { go func() { n, err := socket.Read(icmpBytes[:]) if err != nil { - errorLogger.Printf("Failed to read ping response from %s: %s\n", addr, err.Error()) + d.logger.Errorf("Failed to read ping response from %s: %s\n", addr, err.Error()) return } replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) if err != nil { - errorLogger.Printf("Failed to parse ping response from %s: %s\n", addr, err.Error()) + d.logger.Errorf("Failed to parse ping response from %s: %s\n", addr, err.Error()) return } if addr.Is4() { replyPing, ok := replyPacket.Body.(*icmp.Echo) if !ok { - errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type) + d.logger.Errorf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type) return } if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { - errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) + d.logger.Errorf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) return } } @@ -435,14 +437,14 @@ func (d VirtualTun) pingIPs() { if addr.Is6() { replyPing, ok := replyPacket.Body.(*icmp.RawBody) if !ok { - errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type) + d.logger.Errorf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type) return } seq := binary.BigEndian.Uint16(replyPing.Data[2:4]) pongBody := replyPing.Data[4:] if !bytes.Equal(pongBody, requestPing.Data) || int(seq) != requestPing.Seq { - errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) + d.logger.Errorf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing) return } } diff --git a/wireguard.go b/wireguard.go index 71a2960e..de5443bb 100644 --- a/wireguard.go +++ b/wireguard.go @@ -60,7 +60,7 @@ func CreateIPCRequest(conf *DeviceConfig) (*DeviceSetting, error) { } // StartWireguard creates a tun interface on netstack given a configuration -func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) { +func StartWireguard(conf *DeviceConfig, logger *device.Logger) (*VirtualTun, error) { setting, err := CreateIPCRequest(conf) if err != nil { return nil, err @@ -70,7 +70,7 @@ func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) { if err != nil { return nil, err } - dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(logLevel, "")) + dev := device.NewDevice(tun, conn.NewDefaultBind(), logger) err = dev.IpcSet(setting.IpcRequest) if err != nil { return nil, err @@ -88,5 +88,6 @@ func StartWireguard(conf *DeviceConfig, logLevel int) (*VirtualTun, error) { SystemDNS: len(setting.DNS) == 0, PingRecord: make(map[string]uint64), PingRecordLock: new(sync.Mutex), + logger: logger, }, nil }