From 47f7ef4cd686b9e06fbf4fca2ab591d9ee1ac855 Mon Sep 17 00:00:00 2001 From: Alf-Rune Siqveland Date: Mon, 8 Dec 2025 15:49:24 +0100 Subject: [PATCH 1/4] fix: concurrent map access for mutexes for in-flight messages This commit replaces the map with chan to better handle in flight messages while listening to concurrent events. This also solves the concurrency issue where the application crashes due to concurrent map writes. Signed-off-by: Alf-Rune Siqveland --- cmd/portforward_tcp.go | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/cmd/portforward_tcp.go b/cmd/portforward_tcp.go index 963d1eab..f5114984 100644 --- a/cmd/portforward_tcp.go +++ b/cmd/portforward_tcp.go @@ -22,7 +22,6 @@ import ( "net" "os" "strconv" - "sync" "github.com/google/uuid" "github.com/mendersoftware/go-lib-micro/ws" @@ -37,7 +36,6 @@ type TCPPortForwarder struct { listen net.Listener remoteHost string remotePort uint16 - mutexAck map[string]*sync.Mutex } func NewTCPPortForwarder( @@ -55,7 +53,6 @@ func NewTCPPortForwarder( listen: listen, remoteHost: remoteHost, remotePort: remotePort, - mutexAck: map[string]*sync.Mutex{}, }, nil } @@ -110,10 +107,8 @@ func (p *TCPPortForwarder) handleRequest( ) { defer conn.Close() - p.mutexAck[connectionID] = &sync.Mutex{} - defer func() { - delete(p.mutexAck, connectionID) - }() + ackChan := make(chan struct{}) + defer func() { close(ackChan) }() errChan := make(chan error) dataChan := make(chan []byte) @@ -195,9 +190,7 @@ func (p *TCPPortForwarder) handleRequest( } } else if m.Header.Proto == ws.ProtoTypePortForward && m.Header.MsgType == wspf.MessageTypePortForwardAck { - if m, ok := p.mutexAck[connectionID]; ok { - m.Unlock() - } + <-ackChan } case <-ctx.Done(): return @@ -206,17 +199,11 @@ func (p *TCPPortForwarder) handleRequest( }(connectionID) // go routine to handle sent messages - for { + for err == nil { select { - case err := <-errChan: - if err != io.EOF { - fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) - } - return - case data := <-dataChan: - // lock the ack mutex, we don't allow more than one in-flight message - p.mutexAck[connectionID].Lock() + case err = <-errChan: + case data := <-dataChan: m := &ws.ProtoMsg{ Header: ws.ProtoHdr{ Proto: ws.ProtoTypePortForward, @@ -229,10 +216,23 @@ func (p *TCPPortForwarder) handleRequest( Body: data, } msgChan <- m + // wait for the ack to be received before processing more data + select { + case ackChan <- struct{}{}: + case <-ctx.Done(): + err = ctx.Err() + + case err = <-errChan: + + } case <-ctx.Done(): - return + err = ctx.Err() } } + + if err != io.EOF { + fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) + } } func (p *TCPPortForwarder) handleRequestConnection( From d6e7edb6acb1faf11f311a1d7caf47563de354d0 Mon Sep 17 00:00:00 2001 From: Alf-Rune Siqveland Date: Mon, 8 Dec 2025 16:35:25 +0100 Subject: [PATCH 2/4] chore: lower cyclomatic complexity for (*TCPPortForwarder).handleRequest Signed-off-by: Alf-Rune Siqveland --- cmd/portforward_tcp.go | 121 ++++++++++++++++++++++------------------- 1 file changed, 64 insertions(+), 57 deletions(-) diff --git a/cmd/portforward_tcp.go b/cmd/portforward_tcp.go index f5114984..afa8f51d 100644 --- a/cmd/portforward_tcp.go +++ b/cmd/portforward_tcp.go @@ -96,6 +96,69 @@ func (p *TCPPortForwarder) Run( } } } +func (p *TCPPortForwarder) handleInboundMessages( + ctx context.Context, + conn net.Conn, + sessionID string, + connectionID string, + ackChan <-chan struct{}, + msgChan chan<- *ws.ProtoMsg, + recvChan <-chan *ws.ProtoMsg) { + sendStopMessage := true + defer func() { + conn.Close() + if sendStopMessage { + m := &ws.ProtoMsg{ + Header: ws.ProtoHdr{ + Proto: ws.ProtoTypePortForward, + MsgType: wspf.MessageTypePortForwardStop, + SessionID: sessionID, + Properties: map[string]interface{}{ + wspf.PropertyConnectionID: connectionID, + }, + }, + } + msgChan <- m + } + }() + + for { + select { + case m := <-recvChan: + if m.Header.Proto == ws.ProtoTypePortForward && + m.Header.MsgType == wspf.MessageTypePortForwardStop { + sendStopMessage = false + return + } else if m.Header.Proto == ws.ProtoTypePortForward && + m.Header.MsgType == wspf.MessageTypePortForward { + _, err := conn.Write(m.Body) + if err != nil { + if errors.Unwrap(err) != net.ErrClosed { + fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) + } + } else { + // send the ack + m := &ws.ProtoMsg{ + Header: ws.ProtoHdr{ + Proto: ws.ProtoTypePortForward, + MsgType: wspf.MessageTypePortForwardAck, + SessionID: sessionID, + Properties: map[string]interface{}{ + wspf.PropertyConnectionID: connectionID, + }, + }, + } + msgChan <- m + } + } else if m.Header.Proto == ws.ProtoTypePortForward && + m.Header.MsgType == wspf.MessageTypePortForwardAck { + <-ackChan + } + case <-ctx.Done(): + return + } + } +} func (p *TCPPortForwarder) handleRequest( ctx context.Context, @@ -105,7 +168,6 @@ func (p *TCPPortForwarder) handleRequest( recvChan chan *ws.ProtoMsg, msgChan chan *ws.ProtoMsg, ) { - defer conn.Close() ackChan := make(chan struct{}) defer func() { close(ackChan) }() @@ -137,66 +199,11 @@ func (p *TCPPortForwarder) handleRequest( } msgChan <- m - sendStopMessage := true - defer func() { - conn.Close() - if sendStopMessage { - m := &ws.ProtoMsg{ - Header: ws.ProtoHdr{ - Proto: ws.ProtoTypePortForward, - MsgType: wspf.MessageTypePortForwardStop, - SessionID: sessionID, - Properties: map[string]interface{}{ - wspf.PropertyConnectionID: connectionID, - }, - }, - } - msgChan <- m - } - }() - // go routine to handle the network connection go p.handleRequestConnection(dataChan, errChan, conn) // go routine to handle received messages - go func(connectionID string) { - for { - select { - case m := <-recvChan: - if m.Header.Proto == ws.ProtoTypePortForward && - m.Header.MsgType == wspf.MessageTypePortForwardStop { - sendStopMessage = false - return - } else if m.Header.Proto == ws.ProtoTypePortForward && - m.Header.MsgType == wspf.MessageTypePortForward { - _, err := conn.Write(m.Body) - if err != nil { - if errors.Unwrap(err) != net.ErrClosed { - fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) - } - } else { - // send the ack - m := &ws.ProtoMsg{ - Header: ws.ProtoHdr{ - Proto: ws.ProtoTypePortForward, - MsgType: wspf.MessageTypePortForwardAck, - SessionID: sessionID, - Properties: map[string]interface{}{ - wspf.PropertyConnectionID: connectionID, - }, - }, - } - msgChan <- m - } - } else if m.Header.Proto == ws.ProtoTypePortForward && - m.Header.MsgType == wspf.MessageTypePortForwardAck { - <-ackChan - } - case <-ctx.Done(): - return - } - } - }(connectionID) + go p.handleInboundMessages(ctx, conn, sessionID, connectionID, ackChan, msgChan, recvChan) // go routine to handle sent messages for err == nil { From abf50b77b5202f3432636204404515cf24a53bfa Mon Sep 17 00:00:00 2001 From: Alf-Rune Siqveland Date: Tue, 9 Dec 2025 12:42:07 +0100 Subject: [PATCH 3/4] chore: close listening socket when failing to accept connection Signed-off-by: Alf-Rune Siqveland --- cmd/portforward_tcp.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/cmd/portforward_tcp.go b/cmd/portforward_tcp.go index afa8f51d..9dc5b170 100644 --- a/cmd/portforward_tcp.go +++ b/cmd/portforward_tcp.go @@ -71,6 +71,15 @@ func (p *TCPPortForwarder) Run( for { conn, err := p.listen.Accept() if err != nil { + close(acceptedConnections) + fmt.Fprintf(os.Stderr, + "error accepting new connection on socket %s: %s\n", + p.listen.Addr(), err.Error(), + ) + fmt.Fprintf(os.Stderr, + "closing listening socket %s\n", + p.listen.Addr(), + ) return } fmt.Printf( @@ -85,7 +94,10 @@ func (p *TCPPortForwarder) Run( // handle new connections for { select { - case conn := <-acceptedConnections: + case conn, open := <-acceptedConnections: + if !open { + return + } connectionUUID, _ := uuid.NewUUID() connectionID := connectionUUID.String() recvChan := make(chan *ws.ProtoMsg, portForwardTCPChannelSize) From 12c8ca663cef6e879cc39ebeef6874b9b846ff9d Mon Sep 17 00:00:00 2001 From: Alf-Rune Siqveland Date: Fri, 12 Dec 2025 16:54:59 +0100 Subject: [PATCH 4/4] fix: protect port forward receive chan map from concurrent read/write Signed-off-by: Alf-Rune Siqveland --- cmd/portforward.go | 18 ++++++++++++++---- cmd/portforward_tcp.go | 23 +++++++++++++++++------ cmd/portforward_udp.go | 4 ++-- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/cmd/portforward.go b/cmd/portforward.go index b20b3f3b..82111fc1 100644 --- a/cmd/portforward.go +++ b/cmd/portforward.go @@ -21,6 +21,7 @@ import ( "os/signal" "strconv" "strings" + "sync" "time" "github.com/mendersoftware/go-lib-micro/ws" @@ -94,7 +95,8 @@ type PortForwardCmd struct { sessionID string bindingHost string portMappings []portMapping - recvChans map[string]chan *ws.ProtoMsg + recvChanMu sync.RWMutex + recvChans map[string]chan<- *ws.ProtoMsg running bool stop chan struct{} err error @@ -184,7 +186,7 @@ func NewPortForwardCmd(cmd *cobra.Command, args []string) (*PortForwardCmd, erro deviceID: args[0], bindingHost: bindingHost, portMappings: portMappings, - recvChans: make(map[string]chan *ws.ProtoMsg), + recvChans: make(map[string]chan<- *ws.ProtoMsg), stop: make(chan struct{}), }, nil } @@ -198,6 +200,12 @@ func (c *PortForwardCmd) Run() error { } } +func (c *PortForwardCmd) registerRecvChan(connectionID string, recvChan chan<- *ws.ProtoMsg) { + c.recvChanMu.Lock() + defer c.recvChanMu.Unlock() + c.recvChans[connectionID] = recvChan +} + func (c *PortForwardCmd) run() error { ctx, cancelContext := context.WithCancel(context.Background()) defer cancelContext() @@ -239,14 +247,14 @@ func (c *PortForwardCmd) run() error { if err != nil { return err } - go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans) + go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan) case protocolUDP: forwarder, err := NewUDPPortForwarder(c.bindingHost, portMapping.LocalPort, portMapping.RemoteHost, portMapping.RemotePort) if err != nil { return err } - go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans) + go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan) default: return errors.New("unknown protocol: " + portMapping.Protocol) } @@ -411,9 +419,11 @@ func (c *PortForwardCmd) processIncomingMessages( m.Header.MsgType == wspf.MessageTypePortForwardStop) { connectionID, _ := m.Header.Properties[wspf.PropertyConnectionID].(string) if connectionID != "" { + c.recvChanMu.RLock() if recvChan, ok := c.recvChans[connectionID]; ok { recvChan <- m } + c.recvChanMu.RUnlock() } } } diff --git a/cmd/portforward_tcp.go b/cmd/portforward_tcp.go index 9dc5b170..f8031dca 100644 --- a/cmd/portforward_tcp.go +++ b/cmd/portforward_tcp.go @@ -60,7 +60,7 @@ func (p *TCPPortForwarder) Run( ctx context.Context, sessionID string, msgChan chan *ws.ProtoMsg, - recvChans map[string]chan *ws.ProtoMsg, + registerRecvChan func(string, chan<- *ws.ProtoMsg), ) { // listen for new connections defer p.listen.Close() @@ -101,7 +101,7 @@ func (p *TCPPortForwarder) Run( connectionUUID, _ := uuid.NewUUID() connectionID := connectionUUID.String() recvChan := make(chan *ws.ProtoMsg, portForwardTCPChannelSize) - recvChans[connectionID] = recvChan + registerRecvChan(connectionID, recvChan) go p.handleRequest(ctx, conn, sessionID, connectionID, recvChan, msgChan) case <-ctx.Done(): return @@ -135,8 +135,15 @@ func (p *TCPPortForwarder) handleInboundMessages( }() for { + var ( + m *ws.ProtoMsg + open bool + ) select { - case m := <-recvChan: + case m, open = <-recvChan: + if !open { + return + } if m.Header.Proto == ws.ProtoTypePortForward && m.Header.MsgType == wspf.MessageTypePortForwardStop { sendStopMessage = false @@ -145,12 +152,13 @@ func (p *TCPPortForwarder) handleInboundMessages( m.Header.MsgType == wspf.MessageTypePortForward { _, err := conn.Write(m.Body) if err != nil { - if errors.Unwrap(err) != net.ErrClosed { + if !errors.Is(err, net.ErrClosed) { fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) } + return } else { // send the ack - m := &ws.ProtoMsg{ + m = &ws.ProtoMsg{ Header: ws.ProtoHdr{ Proto: ws.ProtoTypePortForward, MsgType: wspf.MessageTypePortForwardAck, @@ -164,7 +172,10 @@ func (p *TCPPortForwarder) handleInboundMessages( } } else if m.Header.Proto == ws.ProtoTypePortForward && m.Header.MsgType == wspf.MessageTypePortForwardAck { - <-ackChan + _, open = <-ackChan + if !open { + return + } } case <-ctx.Done(): return diff --git a/cmd/portforward_udp.go b/cmd/portforward_udp.go index a9e43ca6..b045b396 100644 --- a/cmd/portforward_udp.go +++ b/cmd/portforward_udp.go @@ -73,7 +73,7 @@ func (p *UDPPortForwarder) Run( ctx context.Context, sessionID string, msgChan chan *ws.ProtoMsg, - recvChans map[string]chan *ws.ProtoMsg, + registerRecvChan func(string, chan<- *ws.ProtoMsg), ) { // listen for new connections defer p.conn.Close() @@ -81,7 +81,7 @@ func (p *UDPPortForwarder) Run( connectionUUID, _ := uuid.NewUUID() connectionID := connectionUUID.String() recvChan := make(chan *ws.ProtoMsg, portForwardUDPChannelSize) - recvChans[connectionID] = recvChan + registerRecvChan(connectionID, recvChan) protocol := portforward.PortForwardProtocol(wspf.PortForwardProtocolUDP) portforwardNew := &wspf.PortForwardNew{