diff --git a/cmd/cmd.go b/cmd/cmd.go index dbb55a3..4e46d87 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,33 +1,62 @@ package cmd import ( + "context" "encoding/json" "fmt" "io" "log" + "net" "net/http" "net/url" "runtime" + "strconv" "strings" "time" + "github.com/olekukonko/tablewriter" + "github.com/olekukonko/tablewriter/tw" "github.com/spf13/cobra" "github.com/ustclug/rsync-proxy/pkg/server" ) +const DefaultUnixSocketPath = "/run/rsync-proxy/rsync-proxy.sock" + var ( Version = "0.0.0" GitCommit = "$Format:%H$" // sha1 from git, output of $(git rev-parse HEAD) BuildDate = "1970-01-01T00:00:00Z" // build date in ISO8601 format, output of $(date -u +'%Y-%m-%dT%H:%M:%SZ') + + daemonSocket = DefaultUnixSocketPath + dialer = &net.Dialer{} ) -func SendReloadRequest(addr string, stdout, stderr io.Writer) error { - client := http.Client{ +func makeHttpClient(addr string) *http.Client { + addrFamily := "tcp" + if strings.HasPrefix(addr, "/") { + addrFamily = "unix" + } + return &http.Client{ Timeout: time.Second * 10, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, addrFamily, addr) + }, + }, } +} + +func httpGet(addr string, path string) (*http.Response, error) { + return makeHttpClient(addr).Get("http://." + path) +} + +func httpPost(addr string, path string, contentType string, body io.Reader) (*http.Response, error) { + return makeHttpClient(addr).Post("http://."+path, contentType, body) +} - resp, err := client.Post(fmt.Sprintf("http://%s/reload", addr), "application/json", nil) +func SendReloadRequest(addr string, stdout, stderr io.Writer) error { + resp, err := httpPost(addr, "/reload", "application/json", nil) if err != nil { return err } @@ -44,7 +73,7 @@ func SendReloadRequest(addr string, stdout, stderr io.Writer) error { } func SendConnectionsRequest(addr string, stdout, stderr io.Writer) error { - resp, err := http.Get(fmt.Sprintf("http://%s/status", addr)) + resp, err := httpGet(addr, "/status") if err != nil { return err } @@ -77,19 +106,43 @@ func SendConnectionsRequest(addr string, stdout, stderr io.Writer) error { return nil } - _, _ = fmt.Fprintln(stdout, "=== Active Connections ===") + table := tablewriter.NewTable( + stdout, + tablewriter.WithRendition(tw.Rendition{ + Borders: tw.BorderNone, + Settings: tw.Settings{ + Lines: tw.LinesNone, + Separators: tw.SeparatorsNone, + }, + }), + tablewriter.WithPadding(tw.Padding{ + Right: " ", + Overwrite: true, + }), + tablewriter.WithHeaderAutoFormat(tw.Off), + tablewriter.WithAlignment(tw.Alignment{ + tw.AlignRight, // Index + tw.AlignRight, // RemoteAddr + tw.AlignDefault, // Module + tw.AlignRight, // UpstreamAddr + tw.AlignDefault, // ConnectedAt + tw.AlignRight, // ReceivedBytes + tw.AlignRight, // SentBytes + }), + ) + table.Header("Index", "Remote", "Module", "Upstream", "Connected", "Received", "Sent") for _, conn := range result.Connections { - _, _ = fmt.Fprintf(stdout, "Index: %d, Addr: %s, Module: %s, Upstream: %s, Connected: %s, Recv: %d bytes, Send: %d bytes\n", - conn.Index, + _ = table.Append([]string{ + strconv.Itoa(conn.Index), conn.RemoteAddr, conn.Module, conn.UpstreamAddr, - conn.ConnectedAt.Format("2006-01-02 15:04:05"), - conn.ReceivedBytes, - conn.SentBytes) + conn.ConnectedAt.Format(time.DateTime), + strconv.FormatInt(conn.ReceivedBytes, 10), + strconv.FormatInt(conn.SentBytes, 10), + }) } - _, _ = fmt.Fprintln(stdout, "==========================") - return nil + return table.Render() } func printVersion(out io.Writer, pretty bool) error { @@ -116,31 +169,25 @@ func printVersion(out io.Writer, pretty bool) error { }) } -func newConnectionsCmd(s *server.Server) *cobra.Command { +func newConnectionsCmd() *cobra.Command { c := &cobra.Command{ Use: "connections", Short: "Show active connections", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - if err := s.ReadConfigFromFile(false); err != nil { - return fmt.Errorf("load config: %w", err) - } - return SendConnectionsRequest(s.HTTPListenAddr, cmd.OutOrStdout(), cmd.ErrOrStderr()) + return SendConnectionsRequest(daemonSocket, cmd.OutOrStdout(), cmd.ErrOrStderr()) }, } return c } -func newReloadCmd(s *server.Server) *cobra.Command { +func newReloadCmd() *cobra.Command { c := &cobra.Command{ Use: "reload", Short: "Inform server to reload config", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - if err := s.ReadConfigFromFile(false); err != nil { - return fmt.Errorf("load config: %w", err) - } - return SendReloadRequest(s.HTTPListenAddr, cmd.OutOrStdout(), cmd.ErrOrStderr()) + return SendReloadRequest(daemonSocket, cmd.OutOrStdout(), cmd.ErrOrStderr()) }, } return c @@ -164,13 +211,14 @@ func newUpstreamModulesCmd(s *server.Server) *cobra.Command { if err != nil { return fmt.Errorf("parse rsync url: %w", err) } + rsyncHost := parsed.Host if parsed.Host == "" { - return fmt.Errorf("invalid rsync url: missing host") - } - if parsed.Path != "" && parsed.Path != "/" { + // Unix socket + rsyncHost = parsed.Path + } else if parsed.Path != "" && parsed.Path != "/" { return fmt.Errorf("invalid rsync url: path is not allowed") } - modules, err := s.DiscoverModulesWithProxyProtocol(parsed.Host, useProxyProtocol) + modules, err := s.DiscoverModulesWithProxyProtocol(rsyncHost, useProxyProtocol) if err != nil { return err } @@ -239,12 +287,13 @@ func New() *cobra.Command { SilenceUsage: true, } pFlags := c.PersistentFlags() + pFlags.StringVarP(&daemonSocket, "host", "H", DefaultUnixSocketPath, "Daemon socket to connect to") pFlags.StringVarP(&s.ConfigPath, "config", "c", "/etc/rsync-proxy/config.toml", "Path to config file") pFlags.BoolVarP(&version, "version", "V", false, "Print version and exit") c.AddCommand( - newConnectionsCmd(s), - newReloadCmd(s), + newConnectionsCmd(), + newReloadCmd(), newUpstreamModulesCmd(s), newVersionCmd(), ) diff --git a/go.mod b/go.mod index cba2d31..e6992f2 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,27 @@ module github.com/ustclug/rsync-proxy go 1.26 require ( + github.com/olekukonko/tablewriter v1.1.4 github.com/pelletier/go-toml v1.9.5 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.8.1 ) require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clipperhouse/displaywidth v0.10.0 // indirect + github.com/clipperhouse/uax29/v2 v2.6.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect + github.com/olekukonko/errors v1.2.0 // indirect + github.com/olekukonko/ll v0.1.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + golang.org/x/sys v0.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 267223c..f56071f 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,31 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/clipperhouse/displaywidth v0.10.0 h1:GhBG8WuerxjFQQYeuZAeVTuyxuX+UraiZGD4HJQ3Y8g= +github.com/clipperhouse/displaywidth v0.10.0/go.mod h1:XqJajYsaiEwkxOj4bowCTMcT1SgvHo9flfF3jQasdbs= +github.com/clipperhouse/uax29/v2 v2.6.0 h1:z0cDbUV+aPASdFb2/ndFnS9ts/WNXgTNNGFoKXuhpos= +github.com/clipperhouse/uax29/v2 v2.6.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= +github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj4EljqMiZsIcE09mmF8XsD5AYOJc= +github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0= +github.com/olekukonko/errors v1.2.0 h1:10Zcn4GeV59t/EGqJc8fUjtFT/FuUh5bTMzZ1XwmCRo= +github.com/olekukonko/errors v1.2.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= +github.com/olekukonko/ll v0.1.6 h1:lGVTHO+Qc4Qm+fce/2h2m5y9LvqaW+DCN7xW9hsU3uA= +github.com/olekukonko/ll v0.1.6/go.mod h1:NVUmjBb/aCtUpjKk75BhWrOlARz3dqsM+OtszpY4o88= +github.com/olekukonko/tablewriter v1.1.4 h1:ORUMI3dXbMnRlRggJX3+q7OzQFDdvgbN9nVWj1drm6I= +github.com/olekukonko/tablewriter v1.1.4/go.mod h1:+kedxuyTtgoZLwif3P1Em4hARJs+mVnzKxmsCL/C5RY= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -22,6 +44,9 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/logging/file.go b/pkg/logging/file.go index 4f4f9df..59dcace 100644 --- a/pkg/logging/file.go +++ b/pkg/logging/file.go @@ -1,6 +1,7 @@ package logging import ( + "fmt" "io" "log" "os" @@ -13,9 +14,6 @@ type FileLogger struct { f *os.File l *log.Logger mu sync.Mutex - - F func(string, ...any) - Ln func(...any) } func NewFileLogger(filename string) (l *FileLogger, err error) { @@ -24,9 +22,6 @@ func NewFileLogger(filename string) (l *FileLogger, err error) { filename: filename, f: nil, l: logger, - - F: logger.Printf, - Ln: logger.Println, } if filename != "" { @@ -37,6 +32,18 @@ func NewFileLogger(filename string) (l *FileLogger, err error) { return } +func (l *FileLogger) F(format string, v ...any) { + if err := l.l.Output(2, fmt.Sprintf(format, v...)); err != nil { + log.Printf("logging output failed: %v", err) + } +} + +func (l *FileLogger) Ln(v ...any) { + if err := l.l.Output(2, fmt.Sprint(v...)); err != nil { + log.Printf("logging output failed: %v", err) + } +} + func (l *FileLogger) SetFlags(flag int) { l.l.SetFlags(flag) } diff --git a/pkg/queue/queue.go b/pkg/queue/queue.go index 012f830..92ef967 100644 --- a/pkg/queue/queue.go +++ b/pkg/queue/queue.go @@ -89,7 +89,7 @@ func (q *Queue) makeHandle(ch chan Status) *Handle { // Move next queued handle to active queued func (q *Queue) popHead() { head := q.queued[0] - head.ch <- Status{Ok: true} + trySend(head.ch, Status{Ok: true}) close(head.ch) q.active = append(q.active, head) q.queued = q.queued[1:] @@ -147,9 +147,14 @@ func (h *internalHandle) release() { func (q *Queue) broadcastStatus() { surplus := len(q.active) - q.max for i := range q.queued { - select { - case q.queued[i].ch <- Status{Index: surplus + i, Max: surplus + len(q.queued)}: - default: - } + trySend(q.queued[i].ch, Status{Index: surplus + i, Max: surplus + len(q.queued)}) + } +} + +func trySend[T any](ch chan T, obj T) { + select { + case <-ch: + default: } + ch <- obj } diff --git a/pkg/server/server.go b/pkg/server/server.go index 1f05c23..77e517e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -26,7 +26,9 @@ import ( ) const ( - TCPBufferSize = 256 + ReadBufferSize = 256 + + defaultRsyncPortString = "873" ) var ( @@ -35,6 +37,13 @@ var ( // See https://github.com/RsyncProject/rsync/blob/a6312e60c95e5ebb5764eaf18eb07be23420ebc6/clientserver.c#L203 RsyncdServerVersion = []byte("@RSYNCD: 32.0 sha512 sha256 sha1 md5 md4\n") RsyncdExit = []byte("@RSYNCD: EXIT\n") + + bufPool = &sync.Pool{ + New: func() any { + buf := make([]byte, ReadBufferSize) + return &buf + }, + } // pool of (*[]byte) ) const lineFeed = '\n' @@ -106,7 +115,6 @@ type Server struct { reloadLock sync.RWMutex dialer net.Dialer - bufPool sync.Pool // name -> upstream targets modules map[string][]Target upstreams []upstreamConfig @@ -118,10 +126,9 @@ type Server struct { connIndex atomic.Uint32 connInfo sync.Map - TCPListener *net.TCPListener - // TLSListener is not a TCPListener + TCPListener net.Listener TLSListener net.Listener - HTTPListener *net.TCPListener + HTTPListener net.Listener } type countingReader struct { @@ -141,12 +148,6 @@ func New() *Server { accessLog, _ := logging.NewFileLogger("") errorLog, _ := logging.NewFileLogger("") s := &Server{ - bufPool: sync.Pool{ - New: func() any { - buf := make([]byte, TCPBufferSize) - return &buf - }, - }, dialer: net.Dialer{}, // customize keep alive interval? accessLog: accessLog, errorLog: errorLog, @@ -197,8 +198,7 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { return fmt.Errorf("upstream=%s must set modules or discover_modules", upstreamName) } addr := v.Address - _, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { + if err := validateTCPOrUnixAddr(addr); err != nil { return fmt.Errorf("resolve address: %w, upstream=%s, address=%s", err, upstreamName, addr) } upstreams = append(upstreams, upstreamConfig{ @@ -362,26 +362,20 @@ func (s *Server) discoverConfiguredModules(ctx context.Context, upstreams []upst func (s *Server) discoverModulesFromUpstream(ctx context.Context, upstream upstreamConfig) ([]string, error) { addr := upstream.Target.Addr - _, _, err := net.SplitHostPort(addr) - if err != nil { - if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" { - addr = net.JoinHostPort(addr, "873") - } else { - return nil, fmt.Errorf("invalid address: %w", err) - } - } - conn, err := s.dialer.DialContext(ctx, "tcp", addr) + addr = addDefaultTCPPort(addr, defaultRsyncPortString) + conn, err := dialContextTCPOrUnix(ctx, s.dialer, addr) if err != nil { return nil, fmt.Errorf("dial: %w", err) } defer conn.Close() if upstream.Target.UseProxyProtocol { - if err := writeProxyProtocolHeader(conn, conn.LocalAddr(), conn.RemoteAddr(), s.WriteTimeout); err != nil { + err := writeProxyProtocolHeader(conn, conn.LocalAddr(), conn.RemoteAddr(), s.WriteTimeout) + if err != nil { return nil, fmt.Errorf("send proxy protocol header: %w", err) } } - reader := bufio.NewReaderSize(conn, TCPBufferSize) + reader := bufio.NewReaderSize(conn, ReadBufferSize) if _, err := writeWithTimeout(conn, RsyncdServerVersion, s.WriteTimeout); err != nil { return nil, fmt.Errorf("send version: %w", err) } @@ -496,12 +490,12 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err s.connInfo.Store(index, &info) defer s.connInfo.Delete(index) - bufPtr := s.bufPool.Get().(*[]byte) - defer s.bufPool.Put(bufPtr) + bufPtr := bufPool.Get().(*[]byte) + defer bufPool.Put(bufPtr) buf := *bufPtr addr := downConn.RemoteAddr().String() - ip := downConn.RemoteAddr().(*net.TCPAddr).IP.String() + ip := netAddrToString(downConn.RemoteAddr()) writeTimeout := s.WriteTimeout readTimeout := s.ReadTimeout @@ -599,31 +593,31 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err } } - conn, err := s.dialer.DialContext(ctx, "tcp", upstreamAddr) + upConn, err := dialContextTCPOrUnix(ctx, s.dialer, upstreamAddr) if err != nil { return fmt.Errorf("dial to upstream: %s: %w", upstreamAddr, err) } - upConn := conn.(*net.TCPConn) defer upConn.Close() - upIp := upConn.RemoteAddr().(*net.TCPAddr).IP.String() + upAddr := netAddrToString(upConn.RemoteAddr()) if useProxyProtocol { - if err := writeProxyProtocolHeader(upConn, downConn.RemoteAddr(), upConn.RemoteAddr(), writeTimeout); err != nil { - return fmt.Errorf("send proxy protocol header to upstream %s: %w", upIp, err) + err := writeProxyProtocolHeader(upConn, downConn.RemoteAddr(), upConn.RemoteAddr(), s.WriteTimeout) + if err != nil { + return fmt.Errorf("send proxy protocol header to upstream %s: %w", upAddr, err) } } _, err = writeWithTimeout(upConn, rsyncdClientVersion, writeTimeout) if err != nil { - return fmt.Errorf("send version to upstream %s: %w", upIp, err) + return fmt.Errorf("send version to upstream %s: %w", upAddr, err) } n, err = readLine(upConn, buf, readTimeout) if err != nil { - return fmt.Errorf("read version from upstream %s: %w", upIp, err) + return fmt.Errorf("read version from upstream %s: %w", upAddr, err) } data = buf[:n] if !bytes.HasPrefix(data, RsyncdVersionPrefix) { - return fmt.Errorf("unknown version from upstream %s: %s", upIp, data) + return fmt.Errorf("unknown version from upstream %s: %s", upAddr, data) } // send back the motd @@ -637,7 +631,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err _, err = writeWithTimeout(upConn, []byte(moduleName+"\n"), writeTimeout) if err != nil { - return fmt.Errorf("send module to upstream %s: %w", upIp, err) + return fmt.Errorf("send module to upstream %s: %w", upAddr, err) } s.accessLog.F("client %s starts requesting module %s", ip, moduleName) @@ -648,7 +642,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err _ = downConn.SetDeadline(zeroTime) // Use countingReader to track bytes in real-time - // and are with the client, not upstream + // and are relative to the client, not upstream downReader := &countingReader{reader: downConn, counter: &info.ReceivedBytes} upReader := &countingReader{reader: upConn, counter: &info.SentBytes} @@ -671,11 +665,12 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err }() select { case <-receivedClosed: - _ = upConn.SetLinger(0) - _ = upConn.CloseRead() + if err := closeRead(upConn, true); err != nil { + s.errorLog.F("close upstream read: %v", err) + } case <-sentClosed: - if closeReader, ok := downConn.(interface{ CloseRead() error }); ok { - _ = closeReader.CloseRead() + if err := closeRead(downConn, false); err != nil { + s.errorLog.F("close downstream read: %v", err) } } @@ -764,7 +759,7 @@ func (s *Server) runHTTPServer() error { } func (s *Server) Listen() error { - l1, err := net.Listen("tcp", s.ListenAddr) + l1, err := listenTCPOrUnix(s.ListenAddr) if err != nil { return fmt.Errorf("create tcp listener: %w", err) } @@ -773,7 +768,7 @@ func (s *Server) Listen() error { var lTLS net.Listener if s.TLSListenAddr != "" { - lTLS, err = net.Listen("tcp", s.TLSListenAddr) + lTLS, err = listenTCPOrUnix(s.TLSListenAddr) if err != nil { _ = l1.Close() return fmt.Errorf("create tls listener: %w", err) @@ -783,7 +778,7 @@ func (s *Server) Listen() error { lTLS = tls.NewListener(lTLS, &tls.Config{GetCertificate: s.getTLSCertificate}) } - l2, err := net.Listen("tcp", s.HTTPListenAddr) + l2, err := listenTCPOrUnix(s.HTTPListenAddr) if err != nil { _ = l1.Close() if lTLS != nil { @@ -794,9 +789,9 @@ func (s *Server) Listen() error { s.HTTPListenAddr = l2.Addr().String() log.Printf("[INFO] HTTP server listening on %s", s.HTTPListenAddr) - s.TCPListener = l1.(*net.TCPListener) + s.TCPListener = l1 s.TLSListener = lTLS - s.HTTPListener = l2.(*net.TCPListener) + s.HTTPListener = l2 return nil } @@ -817,6 +812,13 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn) { defer s.activeConnCount.Add(-1) connIndex := s.connIndex.Add(1) + defer func() { + err := recover() + if err != nil { + s.errorLog.F("handleConn panicked: %s", err) + } + }() + err := s.relay(ctx, connIndex, conn) if err != nil { s.errorLog.F("handleConn: %s", err) @@ -837,7 +839,7 @@ func (s *Server) runRsyncServer(ctx context.Context, listener net.Listener, acce } func (s *Server) Run() error { - errC := make(chan error) + errC := make(chan error, 1) go func() { err := s.runHTTPServer() if err != nil { diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 8a1f5b9..94ed105 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -81,7 +81,7 @@ func TestMotdFromServer(t *testing.T) { proxyMotd := "Hello\n" srv.Motd = proxyMotd - l := strings.Repeat("a", TCPBufferSize) + l := strings.Repeat("a", ReadBufferSize) serverMotd := fmt.Sprintf("%s\n%s\n\n", l, l) fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { diff --git a/pkg/server/utils.go b/pkg/server/utils.go index efeb120..c5381ae 100644 --- a/pkg/server/utils.go +++ b/pkg/server/utils.go @@ -1,8 +1,11 @@ package server import ( + "context" "fmt" "net" + "os" + "strings" "time" ) @@ -14,23 +17,54 @@ func writeWithTimeout(conn net.Conn, buf []byte, timeout time.Duration) (n int, return } -func writeProxyProtocolHeader(conn net.Conn, sourceAddr, destAddr net.Addr, timeout time.Duration) error { - sourceTCP, ok := sourceAddr.(*net.TCPAddr) - if !ok { - return fmt.Errorf("invalid source address type %T", sourceAddr) +func netAddrToString(addr net.Addr) string { + switch addr := addr.(type) { + case *net.TCPAddr: + return addr.IP.String() + case *net.UnixAddr: + return addr.String() + default: + return addr.String() } - destTCP, ok := destAddr.(*net.TCPAddr) - if !ok { - return fmt.Errorf("invalid destination address type %T", destAddr) +} + +func writeProxyProtocolHeader(conn net.Conn, sourceAddr, destAddr net.Addr, writeTimeout time.Duration) error { + h, err := generateProxyProtocolHeader(sourceAddr, destAddr) + if err != nil { + return err + } + _, err = writeWithTimeout(conn, []byte(h), writeTimeout) + return err +} + +func generateProxyProtocolHeader(sourceAddr, destAddr net.Addr) (string, error) { + var ( + sourceIP, destIP net.IP + sourcePort, destPort int + ) + switch sourceTCP := sourceAddr.(type) { + case *net.TCPAddr: + sourceIP, sourcePort = sourceTCP.IP, sourceTCP.Port + case *net.UnixAddr: + sourceIP, sourcePort = net.IPv4(127, 0, 0, 1), 1 + default: + return "", fmt.Errorf("invalid source address type %T", sourceAddr) + } + + switch destTCP := destAddr.(type) { + case *net.TCPAddr: + destIP, destPort = destTCP.IP, destTCP.Port + case *net.UnixAddr: + destIP, destPort = net.IPv4(127, 0, 0, 1), 1 + default: + return "", fmt.Errorf("invalid destination address type %T", destAddr) } ipVersion := "TCP4" - if sourceTCP.IP.To4() == nil { + if sourceIP.To4() == nil { ipVersion = "TCP6" } - proxyHeader := fmt.Sprintf("PROXY %s %s %s %d %d\r\n", ipVersion, sourceTCP.IP.String(), destTCP.IP.String(), sourceTCP.Port, destTCP.Port) - _, err := writeWithTimeout(conn, []byte(proxyHeader), timeout) - return err + return fmt.Sprintf("PROXY %s %s %s %d %d\r\n", ipVersion, sourceIP.String(), destIP.String(), sourcePort, destPort), nil } // readLine will read as much as it can until the last read character is a newline character. @@ -54,3 +88,61 @@ func readLine(conn net.Conn, buf []byte, timeout time.Duration) (n int, err erro } } } + +func listenTCPOrUnix(addr string) (net.Listener, error) { + if strings.HasPrefix(addr, "/") { + os.Remove(addr) + l, err := net.Listen("unix", addr) + if err != nil { + return l, err + } + err = os.Chmod(addr, 0o660) + return l, err + } + return net.Listen("tcp", addr) +} + +func dialContextTCPOrUnix(ctx context.Context, dialer net.Dialer, addr string) (net.Conn, error) { + if strings.HasPrefix(addr, "/") { + return dialer.DialContext(ctx, "unix", addr) + } + return dialer.DialContext(ctx, "tcp", addr) +} + +func addDefaultTCPPort(addr string, defaultPort string) string { + if strings.HasPrefix(addr, "/") { + // don't touch Unix address + return addr + } else { + _, _, err := net.SplitHostPort(addr) + if err != nil { + if addrErr, ok := err.(*net.AddrError); ok && addrErr.Err == "missing port in address" { + return net.JoinHostPort(addr, defaultPort) + } + // invalid address, return as-is + } + } + return addr +} + +func validateTCPOrUnixAddr(addr string) error { + if strings.HasPrefix(addr, "/") { + _, err := net.ResolveUnixAddr("unix", addr) + return err + } + _, err := net.ResolveTCPAddr("tcp", addr) + return err +} + +func closeRead(conn net.Conn, setLinger bool) error { + if setLinger { + if tcpConn, ok := conn.(*net.TCPConn); ok { + _ = tcpConn.SetLinger(0) + } + } + + if closeReader, ok := conn.(interface{ CloseRead() error }); ok { + return closeReader.CloseRead() + } + return nil +} diff --git a/pkg/server/utils_test.go b/pkg/server/utils_test.go index 57704ea..26963c3 100644 --- a/pkg/server/utils_test.go +++ b/pkg/server/utils_test.go @@ -1,8 +1,10 @@ package server import ( + "context" "io" "net" + "path/filepath" "testing" "time" @@ -61,10 +63,56 @@ func TestReadLine(t *testing.T) { {'\n'}, }} - buf := make([]byte, TCPBufferSize) + buf := make([]byte, ReadBufferSize) n, err := readLine(c, buf, time.Minute) require.NoError(t, err) got := buf[:n] expected := []byte("@RSYNCD: 31.0\n") assert.Equal(t, expected, got, "unexpected data") } + +func TestListenAndDialUnixSocket(t *testing.T) { + addr := filepath.Join(t.TempDir(), "rsync-proxy.sock") + + listener, err := listenTCPOrUnix(addr) + require.NoError(t, err) + defer func() { + require.NoError(t, listener.Close()) + }() + + accepted := make(chan net.Conn, 1) + acceptErr := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + acceptErr <- err + return + } + accepted <- conn + }() + + conn, err := dialContextTCPOrUnix(context.Background(), net.Dialer{}, addr) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close()) + }() + + select { + case err := <-acceptErr: + require.NoError(t, err) + case acceptedConn := <-accepted: + defer func() { + require.NoError(t, acceptedConn.Close()) + }() + assert.IsType(t, &net.UnixConn{}, acceptedConn) + case <-time.After(time.Second): + t.Fatal("timed out waiting for unix socket accept") + } + + assert.IsType(t, &net.UnixConn{}, conn) + assert.FileExists(t, addr) + + info, err := net.ResolveUnixAddr("unix", addr) + require.NoError(t, err) + assert.Equal(t, info.String(), netAddrToString(info)) +}