-
Notifications
You must be signed in to change notification settings - Fork 46
fix: concurrent map access for mutexes for in-flight messages #319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
47f7ef4
d6e7edb
abf50b7
12c8ca6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -21,6 +21,7 @@ import ( | |||||||||||||||||||||||
| "os/signal" | ||||||||||||||||||||||||
| "strconv" | ||||||||||||||||||||||||
| "strings" | ||||||||||||||||||||||||
| "sync" | ||||||||||||||||||||||||
| "time" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| "github.com/mendersoftware/go-lib-micro/ws" | ||||||||||||||||||||||||
|
|
@@ -94,7 +95,8 @@ type PortForwardCmd struct { | |||||||||||||||||||||||
| sessionID string | ||||||||||||||||||||||||
| bindingHost string | ||||||||||||||||||||||||
| portMappings []portMapping | ||||||||||||||||||||||||
| recvChans map[string]chan *ws.ProtoMsg | ||||||||||||||||||||||||
| recvChanMu sync.RWMutex | ||||||||||||||||||||||||
| recvChans map[string]chan<- *ws.ProtoMsg | ||||||||||||||||||||||||
| running bool | ||||||||||||||||||||||||
| stop chan struct{} | ||||||||||||||||||||||||
| err error | ||||||||||||||||||||||||
|
|
@@ -184,7 +186,7 @@ func NewPortForwardCmd(cmd *cobra.Command, args []string) (*PortForwardCmd, erro | |||||||||||||||||||||||
| deviceID: args[0], | ||||||||||||||||||||||||
| bindingHost: bindingHost, | ||||||||||||||||||||||||
| portMappings: portMappings, | ||||||||||||||||||||||||
| recvChans: make(map[string]chan *ws.ProtoMsg), | ||||||||||||||||||||||||
| recvChans: make(map[string]chan<- *ws.ProtoMsg), | ||||||||||||||||||||||||
| stop: make(chan struct{}), | ||||||||||||||||||||||||
| }, nil | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
@@ -198,6 +200,12 @@ func (c *PortForwardCmd) Run() error { | |||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| func (c *PortForwardCmd) registerRecvChan(connectionID string, recvChan chan<- *ws.ProtoMsg) { | ||||||||||||||||||||||||
| c.recvChanMu.Lock() | ||||||||||||||||||||||||
| defer c.recvChanMu.Unlock() | ||||||||||||||||||||||||
| c.recvChans[connectionID] = recvChan | ||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if (and even if it is a plausible comment here) the hash already contains an entry for the |
||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| func (c *PortForwardCmd) run() error { | ||||||||||||||||||||||||
| ctx, cancelContext := context.WithCancel(context.Background()) | ||||||||||||||||||||||||
| defer cancelContext() | ||||||||||||||||||||||||
|
|
@@ -239,14 +247,14 @@ func (c *PortForwardCmd) run() error { | |||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||
| return err | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans) | ||||||||||||||||||||||||
| go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan) | ||||||||||||||||||||||||
| case protocolUDP: | ||||||||||||||||||||||||
| forwarder, err := NewUDPPortForwarder(c.bindingHost, portMapping.LocalPort, | ||||||||||||||||||||||||
| portMapping.RemoteHost, portMapping.RemotePort) | ||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||
| return err | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans) | ||||||||||||||||||||||||
| go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan) | ||||||||||||||||||||||||
| default: | ||||||||||||||||||||||||
| return errors.New("unknown protocol: " + portMapping.Protocol) | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
@@ -411,9 +419,11 @@ func (c *PortForwardCmd) processIncomingMessages( | |||||||||||||||||||||||
| m.Header.MsgType == wspf.MessageTypePortForwardStop) { | ||||||||||||||||||||||||
| connectionID, _ := m.Header.Properties[wspf.PropertyConnectionID].(string) | ||||||||||||||||||||||||
| if connectionID != "" { | ||||||||||||||||||||||||
| c.recvChanMu.RLock() | ||||||||||||||||||||||||
| if recvChan, ok := c.recvChans[connectionID]; ok { | ||||||||||||||||||||||||
| recvChan <- m | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| c.recvChanMu.RUnlock() | ||||||||||||||||||||||||
|
Comment on lines
+422
to
+426
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isnt it better like that:
Suggested change
|
||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I totally trust you as this is quite hard to review with large portions of code moved. I did open side by side both parts of the code and compared, but you know. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,6 @@ import ( | |
| "net" | ||
| "os" | ||
| "strconv" | ||
| "sync" | ||
|
|
||
| "github.com/google/uuid" | ||
| "github.com/mendersoftware/go-lib-micro/ws" | ||
|
|
@@ -37,7 +36,6 @@ type TCPPortForwarder struct { | |
| listen net.Listener | ||
| remoteHost string | ||
| remotePort uint16 | ||
| mutexAck map[string]*sync.Mutex | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there is also one in https://github.com/mendersoftware/mender-cli/pull/303/files#diff-e1a621ea821c78f5bf64ce3fe74d6dbfc3211e7cba95be68139c627cc989813aR98 |
||
| } | ||
|
|
||
| func NewTCPPortForwarder( | ||
|
|
@@ -55,15 +53,14 @@ func NewTCPPortForwarder( | |
| listen: listen, | ||
| remoteHost: remoteHost, | ||
| remotePort: remotePort, | ||
| mutexAck: map[string]*sync.Mutex{}, | ||
| }, nil | ||
| } | ||
|
|
||
| func (p *TCPPortForwarder) Run( | ||
| ctx context.Context, | ||
| sessionID string, | ||
| msgChan chan *ws.ProtoMsg, | ||
| recvChans map[string]chan *ws.ProtoMsg, | ||
| registerRecvChan func(string, chan<- *ws.ProtoMsg), | ||
| ) { | ||
| // listen for new connections | ||
| defer p.listen.Close() | ||
|
|
@@ -74,6 +71,15 @@ func (p *TCPPortForwarder) Run( | |
| for { | ||
| conn, err := p.listen.Accept() | ||
| if err != nil { | ||
| close(acceptedConnections) | ||
| fmt.Fprintf(os.Stderr, | ||
| "error accepting new connection on socket %s: %s\n", | ||
| p.listen.Addr(), err.Error(), | ||
| ) | ||
| fmt.Fprintf(os.Stderr, | ||
| "closing listening socket %s\n", | ||
| p.listen.Addr(), | ||
| ) | ||
|
Comment on lines
+75
to
+82
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are new errors that have not been reported before in that form, could we please mention it explicitly so there is a plain change log for that?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should be consistent about using |
||
| return | ||
| } | ||
| fmt.Printf( | ||
|
|
@@ -88,17 +94,94 @@ func (p *TCPPortForwarder) Run( | |
| // handle new connections | ||
| for { | ||
| select { | ||
| case conn := <-acceptedConnections: | ||
| case conn, open := <-acceptedConnections: | ||
| if !open { | ||
| return | ||
alfrunes marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| connectionUUID, _ := uuid.NewUUID() | ||
| connectionID := connectionUUID.String() | ||
| recvChan := make(chan *ws.ProtoMsg, portForwardTCPChannelSize) | ||
| recvChans[connectionID] = recvChan | ||
| registerRecvChan(connectionID, recvChan) | ||
| go p.handleRequest(ctx, conn, sessionID, connectionID, recvChan, msgChan) | ||
| case <-ctx.Done(): | ||
| return | ||
| } | ||
| } | ||
| } | ||
| func (p *TCPPortForwarder) handleInboundMessages( | ||
| ctx context.Context, | ||
| conn net.Conn, | ||
| sessionID string, | ||
| connectionID string, | ||
| ackChan <-chan struct{}, | ||
| msgChan chan<- *ws.ProtoMsg, | ||
| recvChan <-chan *ws.ProtoMsg) { | ||
| sendStopMessage := true | ||
| defer func() { | ||
| conn.Close() | ||
| if sendStopMessage { | ||
| m := &ws.ProtoMsg{ | ||
| Header: ws.ProtoHdr{ | ||
| Proto: ws.ProtoTypePortForward, | ||
| MsgType: wspf.MessageTypePortForwardStop, | ||
| SessionID: sessionID, | ||
| Properties: map[string]interface{}{ | ||
| wspf.PropertyConnectionID: connectionID, | ||
| }, | ||
| }, | ||
| } | ||
| msgChan <- m | ||
| } | ||
| }() | ||
|
|
||
| for { | ||
| var ( | ||
| m *ws.ProtoMsg | ||
| open bool | ||
| ) | ||
| select { | ||
| case m, open = <-recvChan: | ||
| if !open { | ||
| return | ||
| } | ||
| if m.Header.Proto == ws.ProtoTypePortForward && | ||
| m.Header.MsgType == wspf.MessageTypePortForwardStop { | ||
| sendStopMessage = false | ||
| return | ||
| } else if m.Header.Proto == ws.ProtoTypePortForward && | ||
| m.Header.MsgType == wspf.MessageTypePortForward { | ||
| _, err := conn.Write(m.Body) | ||
| if err != nil { | ||
| if !errors.Is(err, net.ErrClosed) { | ||
| fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) | ||
| } | ||
| return | ||
| } else { | ||
| // send the ack | ||
| m = &ws.ProtoMsg{ | ||
| Header: ws.ProtoHdr{ | ||
| Proto: ws.ProtoTypePortForward, | ||
| MsgType: wspf.MessageTypePortForwardAck, | ||
| SessionID: sessionID, | ||
| Properties: map[string]interface{}{ | ||
| wspf.PropertyConnectionID: connectionID, | ||
| }, | ||
| }, | ||
| } | ||
| msgChan <- m | ||
| } | ||
| } else if m.Header.Proto == ws.ProtoTypePortForward && | ||
| m.Header.MsgType == wspf.MessageTypePortForwardAck { | ||
| _, open = <-ackChan | ||
| if !open { | ||
| return | ||
| } | ||
| } | ||
| case <-ctx.Done(): | ||
| return | ||
| } | ||
| } | ||
| } | ||
|
|
||
| func (p *TCPPortForwarder) handleRequest( | ||
| ctx context.Context, | ||
|
|
@@ -108,12 +191,9 @@ func (p *TCPPortForwarder) handleRequest( | |
| recvChan chan *ws.ProtoMsg, | ||
| msgChan chan *ws.ProtoMsg, | ||
| ) { | ||
| defer conn.Close() | ||
|
|
||
| p.mutexAck[connectionID] = &sync.Mutex{} | ||
| defer func() { | ||
| delete(p.mutexAck, connectionID) | ||
| }() | ||
| ackChan := make(chan struct{}) | ||
| defer func() { close(ackChan) }() | ||
|
|
||
| errChan := make(chan error) | ||
| dataChan := make(chan []byte) | ||
|
|
@@ -142,81 +222,18 @@ func (p *TCPPortForwarder) handleRequest( | |
| } | ||
| msgChan <- m | ||
|
|
||
| sendStopMessage := true | ||
| defer func() { | ||
| conn.Close() | ||
| if sendStopMessage { | ||
| m := &ws.ProtoMsg{ | ||
| Header: ws.ProtoHdr{ | ||
| Proto: ws.ProtoTypePortForward, | ||
| MsgType: wspf.MessageTypePortForwardStop, | ||
| SessionID: sessionID, | ||
| Properties: map[string]interface{}{ | ||
| wspf.PropertyConnectionID: connectionID, | ||
| }, | ||
| }, | ||
| } | ||
| msgChan <- m | ||
| } | ||
| }() | ||
|
|
||
| // go routine to handle the network connection | ||
| go p.handleRequestConnection(dataChan, errChan, conn) | ||
|
|
||
| // go routine to handle received messages | ||
| go func(connectionID string) { | ||
| for { | ||
| select { | ||
| case m := <-recvChan: | ||
| if m.Header.Proto == ws.ProtoTypePortForward && | ||
| m.Header.MsgType == wspf.MessageTypePortForwardStop { | ||
| sendStopMessage = false | ||
| return | ||
| } else if m.Header.Proto == ws.ProtoTypePortForward && | ||
| m.Header.MsgType == wspf.MessageTypePortForward { | ||
| _, err := conn.Write(m.Body) | ||
| if err != nil { | ||
| if errors.Unwrap(err) != net.ErrClosed { | ||
| fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) | ||
| } | ||
| } else { | ||
| // send the ack | ||
| m := &ws.ProtoMsg{ | ||
| Header: ws.ProtoHdr{ | ||
| Proto: ws.ProtoTypePortForward, | ||
| MsgType: wspf.MessageTypePortForwardAck, | ||
| SessionID: sessionID, | ||
| Properties: map[string]interface{}{ | ||
| wspf.PropertyConnectionID: connectionID, | ||
| }, | ||
| }, | ||
| } | ||
| msgChan <- m | ||
| } | ||
| } else if m.Header.Proto == ws.ProtoTypePortForward && | ||
| m.Header.MsgType == wspf.MessageTypePortForwardAck { | ||
| if m, ok := p.mutexAck[connectionID]; ok { | ||
| m.Unlock() | ||
| } | ||
| } | ||
| case <-ctx.Done(): | ||
| return | ||
| } | ||
| } | ||
| }(connectionID) | ||
| go p.handleInboundMessages(ctx, conn, sessionID, connectionID, ackChan, msgChan, recvChan) | ||
|
|
||
| // go routine to handle sent messages | ||
| for { | ||
| for err == nil { | ||
| select { | ||
| case err := <-errChan: | ||
| if err != io.EOF { | ||
| fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) | ||
| } | ||
| return | ||
| case data := <-dataChan: | ||
| // lock the ack mutex, we don't allow more than one in-flight message | ||
| p.mutexAck[connectionID].Lock() | ||
| case err = <-errChan: | ||
|
|
||
| case data := <-dataChan: | ||
| m := &ws.ProtoMsg{ | ||
| Header: ws.ProtoHdr{ | ||
| Proto: ws.ProtoTypePortForward, | ||
|
|
@@ -229,10 +246,23 @@ func (p *TCPPortForwarder) handleRequest( | |
| Body: data, | ||
| } | ||
| msgChan <- m | ||
| // wait for the ack to be received before processing more data | ||
| select { | ||
| case ackChan <- struct{}{}: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do I read this correctly, that now we ack every message and wait for ack with each message? I am having some doubts if that is the correct way of doing things. I know it is preexisting, but doubts I have.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it has always been a peer-to-peer stop and wait and was previously achieved with the mutex dance 👯 |
||
| case <-ctx.Done(): | ||
| err = ctx.Err() | ||
|
|
||
| case err = <-errChan: | ||
|
|
||
| } | ||
| case <-ctx.Done(): | ||
| return | ||
| err = ctx.Err() | ||
| } | ||
| } | ||
|
|
||
| if err != io.EOF { | ||
| fmt.Fprintf(os.Stderr, "error: %v\n", err.Error()) | ||
| } | ||
| } | ||
|
|
||
| func (p *TCPPortForwarder) handleRequestConnection( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.