diff --git a/cmd/portforward.go b/cmd/portforward.go index b20b3f3b..82111fc1 100644 --- a/cmd/portforward.go +++ b/cmd/portforward.go @@ -21,6 +21,7 @@ import ( "os/signal" "strconv" "strings" + "sync" "time" "github.com/mendersoftware/go-lib-micro/ws" @@ -94,7 +95,8 @@ type PortForwardCmd struct { sessionID string bindingHost string portMappings []portMapping - recvChans map[string]chan *ws.ProtoMsg + recvChanMu sync.RWMutex + recvChans map[string]chan<- *ws.ProtoMsg running bool stop chan struct{} err error @@ -184,7 +186,7 @@ func NewPortForwardCmd(cmd *cobra.Command, args []string) (*PortForwardCmd, erro deviceID: args[0], bindingHost: bindingHost, portMappings: portMappings, - recvChans: make(map[string]chan *ws.ProtoMsg), + recvChans: make(map[string]chan<- *ws.ProtoMsg), stop: make(chan struct{}), }, nil } @@ -198,6 +200,12 @@ func (c *PortForwardCmd) Run() error { } } +func (c *PortForwardCmd) registerRecvChan(connectionID string, recvChan chan<- *ws.ProtoMsg) { + c.recvChanMu.Lock() + defer c.recvChanMu.Unlock() + c.recvChans[connectionID] = recvChan +} + func (c *PortForwardCmd) run() error { ctx, cancelContext := context.WithCancel(context.Background()) defer cancelContext() @@ -239,14 +247,14 @@ func (c *PortForwardCmd) run() error { if err != nil { return err } - go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans) + go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan) case protocolUDP: forwarder, err := NewUDPPortForwarder(c.bindingHost, portMapping.LocalPort, portMapping.RemoteHost, portMapping.RemotePort) if err != nil { return err } - go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans) + go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan) default: return errors.New("unknown protocol: " + portMapping.Protocol) } @@ -411,9 +419,11 @@ func (c *PortForwardCmd) processIncomingMessages( m.Header.MsgType == wspf.MessageTypePortForwardStop) { connectionID, _ := m.Header.Properties[wspf.PropertyConnectionID].(string) if connectionID != "" { + c.recvChanMu.RLock() if recvChan, ok := c.recvChans[connectionID]; ok { recvChan <- m } + c.recvChanMu.RUnlock() } } } diff --git a/cmd/portforward_tcp.go b/cmd/portforward_tcp.go index 963d1eab..f8031dca 100644 --- a/cmd/portforward_tcp.go +++ b/cmd/portforward_tcp.go @@ -22,7 +22,6 @@ import ( "net" "os" "strconv" - "sync" "github.com/google/uuid" "github.com/mendersoftware/go-lib-micro/ws" @@ -37,7 +36,6 @@ type TCPPortForwarder struct { listen net.Listener remoteHost string remotePort uint16 - mutexAck map[string]*sync.Mutex } func NewTCPPortForwarder( @@ -55,7 +53,6 @@ func NewTCPPortForwarder( listen: listen, remoteHost: remoteHost, remotePort: remotePort, - mutexAck: map[string]*sync.Mutex{}, }, nil } @@ -63,7 +60,7 @@ func (p *TCPPortForwarder) Run( ctx context.Context, sessionID string, msgChan chan *ws.ProtoMsg, - recvChans map[string]chan *ws.ProtoMsg, + registerRecvChan func(string, chan<- *ws.ProtoMsg), ) { // listen for new connections defer p.listen.Close() @@ -74,6 +71,15 @@ func (p *TCPPortForwarder) Run( for { conn, err := p.listen.Accept() if err != nil { + close(acceptedConnections) + fmt.Fprintf(os.Stderr, + "error accepting new connection on socket %s: %s\n", + p.listen.Addr(), err.Error(), + ) + fmt.Fprintf(os.Stderr, + "closing listening socket %s\n", + p.listen.Addr(), + ) return } fmt.Printf( @@ -88,17 +94,94 @@ func (p *TCPPortForwarder) Run( // handle new connections for { select { - case conn := <-acceptedConnections: + case conn, open := <-acceptedConnections: + if !open { + return + } connectionUUID, _ := uuid.NewUUID() connectionID := connectionUUID.String() recvChan := make(chan *ws.ProtoMsg, portForwardTCPChannelSize) - recvChans[connectionID] = recvChan + registerRecvChan(connectionID, recvChan) go p.handleRequest(ctx, conn, sessionID, connectionID, recvChan, msgChan) case <-ctx.Done(): return } } } +func (p *TCPPortForwarder) handleInboundMessages( + ctx context.Context, + conn net.Conn, + sessionID string, + connectionID string, + ackChan <-chan struct{}, + msgChan chan<- *ws.ProtoMsg, + recvChan <-chan *ws.ProtoMsg) { + sendStopMessage := true + defer func() { + conn.Close() + if sendStopMessage { + m := &ws.ProtoMsg{ + Header: ws.ProtoHdr{ + Proto: ws.ProtoTypePortForward, + MsgType: wspf.MessageTypePortForwardStop, + SessionID: sessionID, + Properties: map[string]interface{}{ + wspf.PropertyConnectionID: connectionID, + }, + }, + } + msgChan <- m + } + }() + + for { + var ( + m *ws.ProtoMsg + open bool + ) + select { + case m, open = <-recvChan: + if !open { + return + } + if m.Header.Proto == ws.ProtoTypePortForward && + m.Header.MsgType == wspf.MessageTypePortForwardStop { + sendStopMessage = false + return + } else if m.Header.Proto == ws.ProtoTypePortForward && + m.Header.MsgType == wspf.MessageTypePortForward { + _, err := conn.Write(m.Body) + if err != nil { + if !errors.Is(err, net.ErrClosed) { + fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) + } + return + } else { + // send the ack + m = &ws.ProtoMsg{ + Header: ws.ProtoHdr{ + Proto: ws.ProtoTypePortForward, + MsgType: wspf.MessageTypePortForwardAck, + SessionID: sessionID, + Properties: map[string]interface{}{ + wspf.PropertyConnectionID: connectionID, + }, + }, + } + msgChan <- m + } + } else if m.Header.Proto == ws.ProtoTypePortForward && + m.Header.MsgType == wspf.MessageTypePortForwardAck { + _, open = <-ackChan + if !open { + return + } + } + case <-ctx.Done(): + return + } + } +} func (p *TCPPortForwarder) handleRequest( ctx context.Context, @@ -108,12 +191,9 @@ func (p *TCPPortForwarder) handleRequest( recvChan chan *ws.ProtoMsg, msgChan chan *ws.ProtoMsg, ) { - defer conn.Close() - p.mutexAck[connectionID] = &sync.Mutex{} - defer func() { - delete(p.mutexAck, connectionID) - }() + ackChan := make(chan struct{}) + defer func() { close(ackChan) }() errChan := make(chan error) dataChan := make(chan []byte) @@ -142,81 +222,18 @@ func (p *TCPPortForwarder) handleRequest( } msgChan <- m - sendStopMessage := true - defer func() { - conn.Close() - if sendStopMessage { - m := &ws.ProtoMsg{ - Header: ws.ProtoHdr{ - Proto: ws.ProtoTypePortForward, - MsgType: wspf.MessageTypePortForwardStop, - SessionID: sessionID, - Properties: map[string]interface{}{ - wspf.PropertyConnectionID: connectionID, - }, - }, - } - msgChan <- m - } - }() - // go routine to handle the network connection go p.handleRequestConnection(dataChan, errChan, conn) // go routine to handle received messages - go func(connectionID string) { - for { - select { - case m := <-recvChan: - if m.Header.Proto == ws.ProtoTypePortForward && - m.Header.MsgType == wspf.MessageTypePortForwardStop { - sendStopMessage = false - return - } else if m.Header.Proto == ws.ProtoTypePortForward && - m.Header.MsgType == wspf.MessageTypePortForward { - _, err := conn.Write(m.Body) - if err != nil { - if errors.Unwrap(err) != net.ErrClosed { - fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) - } - } else { - // send the ack - m := &ws.ProtoMsg{ - Header: ws.ProtoHdr{ - Proto: ws.ProtoTypePortForward, - MsgType: wspf.MessageTypePortForwardAck, - SessionID: sessionID, - Properties: map[string]interface{}{ - wspf.PropertyConnectionID: connectionID, - }, - }, - } - msgChan <- m - } - } else if m.Header.Proto == ws.ProtoTypePortForward && - m.Header.MsgType == wspf.MessageTypePortForwardAck { - if m, ok := p.mutexAck[connectionID]; ok { - m.Unlock() - } - } - case <-ctx.Done(): - return - } - } - }(connectionID) + go p.handleInboundMessages(ctx, conn, sessionID, connectionID, ackChan, msgChan, recvChan) // go routine to handle sent messages - for { + for err == nil { select { - case err := <-errChan: - if err != io.EOF { - fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) - } - return - case data := <-dataChan: - // lock the ack mutex, we don't allow more than one in-flight message - p.mutexAck[connectionID].Lock() + case err = <-errChan: + case data := <-dataChan: m := &ws.ProtoMsg{ Header: ws.ProtoHdr{ Proto: ws.ProtoTypePortForward, @@ -229,10 +246,23 @@ func (p *TCPPortForwarder) handleRequest( Body: data, } msgChan <- m + // wait for the ack to be received before processing more data + select { + case ackChan <- struct{}{}: + case <-ctx.Done(): + err = ctx.Err() + + case err = <-errChan: + + } case <-ctx.Done(): - return + err = ctx.Err() } } + + if err != io.EOF { + fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) + } } func (p *TCPPortForwarder) handleRequestConnection( diff --git a/cmd/portforward_udp.go b/cmd/portforward_udp.go index a9e43ca6..b045b396 100644 --- a/cmd/portforward_udp.go +++ b/cmd/portforward_udp.go @@ -73,7 +73,7 @@ func (p *UDPPortForwarder) Run( ctx context.Context, sessionID string, msgChan chan *ws.ProtoMsg, - recvChans map[string]chan *ws.ProtoMsg, + registerRecvChan func(string, chan<- *ws.ProtoMsg), ) { // listen for new connections defer p.conn.Close() @@ -81,7 +81,7 @@ func (p *UDPPortForwarder) Run( connectionUUID, _ := uuid.NewUUID() connectionID := connectionUUID.String() recvChan := make(chan *ws.ProtoMsg, portForwardUDPChannelSize) - recvChans[connectionID] = recvChan + registerRecvChan(connectionID, recvChan) protocol := portforward.PortForwardProtocol(wspf.PortForwardProtocolUDP) portforwardNew := &wspf.PortForwardNew{