Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ with a discount using [this referral link](https://iproyal.com/?r=795836)! 🚀
```

```bash
usage: wireproxy [-h|--help] [-c|--config "<value>"] [-s|--silent]
[-d|--daemon] [-i|--info "<value>"] [-v|--version]
[-n|--configtest]
usage: wireproxy [-h|--help] [-c|--config "<value>"] [-s|--silent]
[-v|--verbose] [-d|--daemon] [-i|--info "<value>"]
[-V|--version] [-n|--configtest]

Userspace wireguard client for proxying

Expand All @@ -61,10 +61,11 @@ Arguments:
-h --help Print help information
-c --config Path of configuration file
Default paths: /etc/wireproxy/wireproxy.conf, $HOME/.config/wireproxy.conf
-s --silent Silent mode
-s --silent Logging: Set silent mode
-v --verbose Logging: Set verbose mode
-d --daemon Make wireproxy run in background
-i --info Specify the address and port for exposing health status
-v --version Print version
-V --version Print version
-n --configtest Configtest mode. Only check the configuration file for
validity.
```
Expand Down
14 changes: 10 additions & 4 deletions cmd/wireproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,12 @@ func main() {
parser := argparse.NewParser("wireproxy", "Userspace wireguard client for proxying")

config := parser.String("c", "config", &argparse.Options{Help: "Path of configuration file"})
silent := parser.Flag("s", "silent", &argparse.Options{Help: "Silent mode"})
silent := parser.Flag("s", "silent", &argparse.Options{Help: "Logging: Set silent mode"})
verbose := parser.Flag("v", "verbose", &argparse.Options{Help: "Logging: Set verbose mode"})

daemon := parser.Flag("d", "daemon", &argparse.Options{Help: "Make wireproxy run in background"})
info := parser.String("i", "info", &argparse.Options{Help: "Specify the address and port for exposing health status"})
printVerison := parser.Flag("v", "version", &argparse.Options{Help: "Print version"})
printVerison := parser.Flag("V", "version", &argparse.Options{Help: "Print version"})
configTest := parser.Flag("n", "configtest", &argparse.Options{Help: "Configtest mode. Only check the configuration file for validity."})

err := parser.Parse(args)
Expand Down Expand Up @@ -238,14 +240,18 @@ func main() {
// https://github.com/WireGuard/wireguard-go/blob/master/device/logger.go#L39
// so redirect STDOUT to STDERR, we don't want to print anything to STDOUT anyways
os.Stdout = os.NewFile(uintptr(syscall.Stderr), "/dev/stderr")
logLevel := device.LogLevelVerbose
logLevel := device.LogLevelError
if *silent {
logLevel = device.LogLevelSilent

} else if *verbose {
logLevel = device.LogLevelVerbose
}

lock("ready")
var errorLogger = device.NewLogger(logLevel, "wireproxy - ")

tun, err := wireproxy.StartWireguard(conf.Device, logLevel)
tun, err := wireproxy.StartWireguard(conf.Device, errorLogger)
if err != nil {
log.Fatal(err)
}
Expand Down
22 changes: 14 additions & 8 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import (
"encoding/base64"
"fmt"
"io"
"log"

//"log"
"net"
"net/http"
"strings"
Expand All @@ -21,6 +22,7 @@ type HTTPServer struct {
dial func(network, address string) (net.Conn, error)

authRequired bool
vtun *VirtualTun
}

func (s *HTTPServer) authenticate(req *http.Request) (int, error) {
Expand Down Expand Up @@ -55,6 +57,7 @@ func (s *HTTPServer) handleConn(req *http.Request, conn net.Conn) (peer net.Conn
addr = net.JoinHostPort(addr, port)
}

s.vtun.logger.Verbosef("Got HTTP Connect to %s", addr)
peer, err = s.dial("tcp", addr)
if err != nil {
return peer, fmt.Errorf("tun tcp dial failed: %w", err)
Expand All @@ -76,6 +79,7 @@ func (s *HTTPServer) handle(req *http.Request) (peer net.Conn, err error) {
addr = net.JoinHostPort(addr, port)
}

s.vtun.logger.Verbosef("Got HTTP GET to %s", addr)
peer, err = s.dial("tcp", addr)
if err != nil {
return peer, fmt.Errorf("tun tcp dial failed: %w", err)
Expand All @@ -95,7 +99,7 @@ func (s *HTTPServer) serve(conn net.Conn) {
var rd = bufio.NewReader(conn)
req, err := http.ReadRequest(rd)
if err != nil {
log.Printf("read request failed: %s\n", err)
s.vtun.logger.Errorf("read request failed: %s", err)
return
}

Expand All @@ -106,7 +110,7 @@ func (s *HTTPServer) serve(conn net.Conn) {
resp.Header.Set("Proxy-Authenticate", "Basic realm=\"Proxy\"")
}
_ = resp.Write(conn)
log.Println(err)
s.vtun.logger.Errorf("authenticate failed: %s", err)
return
}

Expand All @@ -118,15 +122,15 @@ func (s *HTTPServer) serve(conn net.Conn) {
peer, err = s.handle(req)
default:
_ = responseWith(req, http.StatusMethodNotAllowed).Write(conn)
log.Printf("unsupported protocol: %s\n", req.Method)
s.vtun.logger.Errorf("unsupported protocol: %s", req.Method)
return
}
if err != nil {
log.Printf("dial proxy failed: %s\n", err)
s.vtun.logger.Errorf("dial proxy failed: %s", err)
return
}
if peer == nil {
log.Println("dial proxy failed: peer nil")
s.vtun.logger.Errorf("dial proxy failed: peer nil")
return
}

Expand All @@ -149,15 +153,17 @@ func (s *HTTPServer) serve(conn net.Conn) {
func (s *HTTPServer) ListenAndServe(network, addr string) error {
server, err := net.Listen(network, addr)
if err != nil {
return fmt.Errorf("listen tcp failed: %w", err)
s.vtun.logger.Errorf("listen tcp failed: %w", err)
return err
}
defer func(server net.Listener) {
_ = server.Close()
}(server)
for {
conn, err := server.Accept()
if err != nil {
return fmt.Errorf("accept request failed: %w", err)
s.vtun.logger.Errorf("accept request failed: %w", err)
return err
}
go func(conn net.Conn) {
s.serve(conn)
Expand Down
60 changes: 31 additions & 29 deletions routine.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
)

// errorLogger is the logger to print error message
var errorLogger = log.New(os.Stderr, "ERROR: ", log.LstdFlags)
//var errorLogger = log.New(os.Stderr, "ERROR: ", log.LstdFlags)

// CredentialValidator stores the authentication data of a socks5 proxy
type CredentialValidator struct {
Expand All @@ -51,6 +51,7 @@ type VirtualTun struct {
// PingRecord stores the last time an IP was pinged
PingRecord map[string]uint64
PingRecordLock *sync.Mutex
logger *device.Logger
}

// RoutineSpawner spawns a routine (e.g. socks5, tcp static routes) after the configuration is parsed
Expand Down Expand Up @@ -111,7 +112,7 @@ func (d VirtualTun) Resolve(ctx context.Context, name string) (context.Context,
if err != nil {
return nil, nil, err
}

d.logger.Verbosef("Resolved %s to %s\n", name, addr)
return ctx, addr.AsSlice(), nil
}

Expand Down Expand Up @@ -170,6 +171,7 @@ func (config *HTTPConfig) SpawnRoutine(vt *VirtualTun) {
config: config,
dial: vt.Tnet.Dial,
auth: CredentialValidator{config.Username, config.Password},
vtun: vt,
}
if config.Username != "" || config.Password != "" {
server.authRequired = true
Expand All @@ -189,60 +191,60 @@ func (c CredentialValidator) Valid(username, password string) bool {
}

// connForward copy data from `from` to `to`
func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser) {
func connForward(from io.ReadWriteCloser, to io.ReadWriteCloser, logger *device.Logger) {
defer from.Close()
defer to.Close()

_, err := io.Copy(to, from)
if err != nil {
errorLogger.Printf("Cannot forward traffic: %s\n", err.Error())
logger.Errorf("Cannot forward traffic: %s\n", err.Error())
}
}

// tcpClientForward starts a new connection via wireguard and forward traffic from `conn`
func tcpClientForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
target, err := vt.resolveToAddrPort(raddr)
if err != nil {
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
vt.logger.Errorf("TCP Server Tunnel to %s: %s\n", target, err.Error())
return
}

tcpAddr := TCPAddrFromAddrPort(*target)

sconn, err := vt.Tnet.DialTCP(tcpAddr)
if err != nil {
errorLogger.Printf("TCP Client Tunnel to %s: %s\n", target, err.Error())
vt.logger.Errorf("TCP Client Tunnel to %s: %s\n", target, err.Error())
return
}

go connForward(sconn, conn)
go connForward(conn, sconn)
go connForward(sconn, conn, vt.logger)
go connForward(conn, sconn, vt.logger)
}

// STDIOTcpForward starts a new connection via wireguard and forward traffic from `conn`
func STDIOTcpForward(vt *VirtualTun, raddr *addressPort) {
target, err := vt.resolveToAddrPort(raddr)
if err != nil {
errorLogger.Printf("Name resolution error for %s: %s\n", raddr.address, err.Error())
vt.logger.Errorf("Name resolution error for %s: %s\n", raddr.address, err.Error())
return
}

// os.Stdout has previously been remapped to stderr, se we can't use it
stdout, err := os.OpenFile("/dev/stdout", os.O_WRONLY, 0)
if err != nil {
errorLogger.Printf("Failed to open /dev/stdout: %s\n", err.Error())
vt.logger.Errorf("Failed to open /dev/stdout: %s\n", err.Error())
return
}

tcpAddr := TCPAddrFromAddrPort(*target)
sconn, err := vt.Tnet.DialTCP(tcpAddr)
if err != nil {
errorLogger.Printf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error())
vt.logger.Errorf("TCP Client Tunnel to %s (%s): %s\n", target, tcpAddr, err.Error())
return
}

go connForward(os.Stdin, sconn)
go connForward(sconn, stdout)
go connForward(os.Stdin, sconn, vt.logger)
go connForward(sconn, stdout, vt.logger)
}

// SpawnRoutine spawns a local TCP server which acts as a proxy to the specified target
Expand Down Expand Up @@ -280,20 +282,20 @@ func (conf *STDIOTunnelConfig) SpawnRoutine(vt *VirtualTun) {
func tcpServerForward(vt *VirtualTun, raddr *addressPort, conn net.Conn) {
target, err := vt.resolveToAddrPort(raddr)
if err != nil {
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
vt.logger.Errorf("TCP Server Tunnel to %s: %s\n", target, err.Error())
return
}

tcpAddr := TCPAddrFromAddrPort(*target)

sconn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
errorLogger.Printf("TCP Server Tunnel to %s: %s\n", target, err.Error())
vt.logger.Errorf("TCP Server Tunnel to %s: %s\n", target, err.Error())
return
}

go connForward(sconn, conn)
go connForward(conn, sconn)
go connForward(sconn, conn, vt.logger)
go connForward(conn, sconn, vt.logger)

}

Expand All @@ -320,12 +322,12 @@ func (conf *TCPServerTunnelConfig) SpawnRoutine(vt *VirtualTun) {
}

func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("Health metric request: %s\n", r.URL.Path)
d.logger.Verbosef("Health metric request: %s\n", r.URL.Path)
switch path.Clean(r.URL.Path) {
case "/readyz":
body, err := json.Marshal(d.PingRecord)
if err != nil {
errorLogger.Printf("Failed to get device metrics: %s\n", err.Error())
d.logger.Errorf("Failed to get device metrics: %s\n", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -346,7 +348,7 @@ func (d VirtualTun) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "/metrics":
get, err := d.Dev.IpcGet()
if err != nil {
errorLogger.Printf("Failed to get device metrics: %s\n", err.Error())
d.logger.Errorf("Failed to get device metrics: %s\n", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -377,7 +379,7 @@ func (d VirtualTun) pingIPs() {
for _, addr := range d.Conf.CheckAlive {
socket, err := d.Tnet.Dial("ping", addr.String())
if err != nil {
errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error())
d.logger.Errorf("Failed to ping %s: %s\n", addr, err.Error())
continue
}

Expand All @@ -395,54 +397,54 @@ func (d VirtualTun) pingIPs() {
} else if addr.Is6() {
icmpBytes, _ = (&icmp.Message{Type: ipv6.ICMPTypeEchoRequest, Code: 0, Body: &requestPing}).Marshal(nil)
} else {
errorLogger.Printf("Failed to ping %s: invalid address: %s\n", addr, addr.String())
d.logger.Errorf("Failed to ping %s: invalid address: %s\n", addr, addr.String())
continue
}

_ = socket.SetReadDeadline(time.Now().Add(time.Duration(d.Conf.CheckAliveInterval) * time.Second))
_, err = socket.Write(icmpBytes)
if err != nil {
errorLogger.Printf("Failed to ping %s: %s\n", addr, err.Error())
d.logger.Errorf("Failed to ping %s: %s\n", addr, err.Error())
continue
}

addr := addr
go func() {
n, err := socket.Read(icmpBytes[:])
if err != nil {
errorLogger.Printf("Failed to read ping response from %s: %s\n", addr, err.Error())
d.logger.Errorf("Failed to read ping response from %s: %s\n", addr, err.Error())
return
}

replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil {
errorLogger.Printf("Failed to parse ping response from %s: %s\n", addr, err.Error())
d.logger.Errorf("Failed to parse ping response from %s: %s\n", addr, err.Error())
return
}

if addr.Is4() {
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type)
d.logger.Errorf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type)
return
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing)
d.logger.Errorf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing)
return
}
}

if addr.Is6() {
replyPing, ok := replyPacket.Body.(*icmp.RawBody)
if !ok {
errorLogger.Printf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type)
d.logger.Errorf("Failed to parse ping response from %s: invalid reply type: %s\n", addr, replyPacket.Type)
return
}

seq := binary.BigEndian.Uint16(replyPing.Data[2:4])
pongBody := replyPing.Data[4:]
if !bytes.Equal(pongBody, requestPing.Data) || int(seq) != requestPing.Seq {
errorLogger.Printf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing)
d.logger.Errorf("Failed to parse ping response from %s: invalid ping reply: %v\n", addr, replyPing)
return
}
}
Expand Down
Loading
Loading