Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions cmd/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"os/signal"
"strconv"
"strings"
"sync"
"time"

"github.com/mendersoftware/go-lib-micro/ws"
Expand Down Expand Up @@ -94,7 +95,8 @@ type PortForwardCmd struct {
sessionID string
bindingHost string
portMappings []portMapping
recvChans map[string]chan *ws.ProtoMsg
recvChanMu sync.RWMutex
recvChans map[string]chan<- *ws.ProtoMsg
running bool
stop chan struct{}
err error
Expand Down Expand Up @@ -184,7 +186,7 @@ func NewPortForwardCmd(cmd *cobra.Command, args []string) (*PortForwardCmd, erro
deviceID: args[0],
bindingHost: bindingHost,
portMappings: portMappings,
recvChans: make(map[string]chan *ws.ProtoMsg),
recvChans: make(map[string]chan<- *ws.ProtoMsg),
stop: make(chan struct{}),
}, nil
}
Expand All @@ -198,6 +200,12 @@ func (c *PortForwardCmd) Run() error {
}
}

func (c *PortForwardCmd) registerRecvChan(connectionID string, recvChan chan<- *ws.ProtoMsg) {
c.recvChanMu.Lock()
defer c.recvChanMu.Unlock()
c.recvChans[connectionID] = recvChan
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if (and even if it is a plausible comment here) the hash already contains an entry for the connectionID?

}

func (c *PortForwardCmd) run() error {
ctx, cancelContext := context.WithCancel(context.Background())
defer cancelContext()
Expand Down Expand Up @@ -239,14 +247,14 @@ func (c *PortForwardCmd) run() error {
if err != nil {
return err
}
go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans)
go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan)
case protocolUDP:
forwarder, err := NewUDPPortForwarder(c.bindingHost, portMapping.LocalPort,
portMapping.RemoteHost, portMapping.RemotePort)
if err != nil {
return err
}
go forwarder.Run(ctx, c.sessionID, msgChan, c.recvChans)
go forwarder.Run(ctx, c.sessionID, msgChan, c.registerRecvChan)
default:
return errors.New("unknown protocol: " + portMapping.Protocol)
}
Expand Down Expand Up @@ -411,9 +419,11 @@ func (c *PortForwardCmd) processIncomingMessages(
m.Header.MsgType == wspf.MessageTypePortForwardStop) {
connectionID, _ := m.Header.Properties[wspf.PropertyConnectionID].(string)
if connectionID != "" {
c.recvChanMu.RLock()
if recvChan, ok := c.recvChans[connectionID]; ok {
recvChan <- m
}
c.recvChanMu.RUnlock()
Comment on lines +422 to +426
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isnt it better like that:

Suggested change
c.recvChanMu.RLock()
if recvChan, ok := c.recvChans[connectionID]; ok {
recvChan <- m
}
c.recvChanMu.RUnlock()
c.recvChanMu.RLock()
recvChan, ok := c.recvChans[connectionID]; ok {
c.recvChanMu.RUnlock()
if ok {
recvChan <- m
}

}
}
}
Expand Down
188 changes: 109 additions & 79 deletions cmd/portforward_tcp.go
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally trust you as this is quite hard to review with large portions of code moved. I did open side by side both parts of the code and compared, but you know.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"net"
"os"
"strconv"
"sync"

"github.com/google/uuid"
"github.com/mendersoftware/go-lib-micro/ws"
Expand All @@ -37,7 +36,6 @@ type TCPPortForwarder struct {
listen net.Listener
remoteHost string
remotePort uint16
mutexAck map[string]*sync.Mutex
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

func NewTCPPortForwarder(
Expand All @@ -55,15 +53,14 @@ func NewTCPPortForwarder(
listen: listen,
remoteHost: remoteHost,
remotePort: remotePort,
mutexAck: map[string]*sync.Mutex{},
}, nil
}

func (p *TCPPortForwarder) Run(
ctx context.Context,
sessionID string,
msgChan chan *ws.ProtoMsg,
recvChans map[string]chan *ws.ProtoMsg,
registerRecvChan func(string, chan<- *ws.ProtoMsg),
) {
// listen for new connections
defer p.listen.Close()
Expand All @@ -74,6 +71,15 @@ func (p *TCPPortForwarder) Run(
for {
conn, err := p.listen.Accept()
if err != nil {
close(acceptedConnections)
fmt.Fprintf(os.Stderr,
"error accepting new connection on socket %s: %s\n",
p.listen.Addr(), err.Error(),
)
fmt.Fprintf(os.Stderr,
"closing listening socket %s\n",
p.listen.Addr(),
)
Comment on lines +75 to +82
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are new errors that have not been reported before in that form, could we please mention it explicitly so there is a plain change log for that?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should be consistent about using log (of slog) instead of printing to standard error?

return
}
fmt.Printf(
Expand All @@ -88,17 +94,94 @@ func (p *TCPPortForwarder) Run(
// handle new connections
for {
select {
case conn := <-acceptedConnections:
case conn, open := <-acceptedConnections:
if !open {
return
}
connectionUUID, _ := uuid.NewUUID()
connectionID := connectionUUID.String()
recvChan := make(chan *ws.ProtoMsg, portForwardTCPChannelSize)
recvChans[connectionID] = recvChan
registerRecvChan(connectionID, recvChan)
go p.handleRequest(ctx, conn, sessionID, connectionID, recvChan, msgChan)
case <-ctx.Done():
return
}
}
}
func (p *TCPPortForwarder) handleInboundMessages(
ctx context.Context,
conn net.Conn,
sessionID string,
connectionID string,
ackChan <-chan struct{},
msgChan chan<- *ws.ProtoMsg,
recvChan <-chan *ws.ProtoMsg) {
sendStopMessage := true
defer func() {
conn.Close()
if sendStopMessage {
m := &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypePortForward,
MsgType: wspf.MessageTypePortForwardStop,
SessionID: sessionID,
Properties: map[string]interface{}{
wspf.PropertyConnectionID: connectionID,
},
},
}
msgChan <- m
}
}()

for {
var (
m *ws.ProtoMsg
open bool
)
select {
case m, open = <-recvChan:
if !open {
return
}
if m.Header.Proto == ws.ProtoTypePortForward &&
m.Header.MsgType == wspf.MessageTypePortForwardStop {
sendStopMessage = false
return
} else if m.Header.Proto == ws.ProtoTypePortForward &&
m.Header.MsgType == wspf.MessageTypePortForward {
_, err := conn.Write(m.Body)
if err != nil {
if !errors.Is(err, net.ErrClosed) {
fmt.Fprintf(os.Stderr, "error: %v\n", err.Error())
}
return
} else {
// send the ack
m = &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypePortForward,
MsgType: wspf.MessageTypePortForwardAck,
SessionID: sessionID,
Properties: map[string]interface{}{
wspf.PropertyConnectionID: connectionID,
},
},
}
msgChan <- m
}
} else if m.Header.Proto == ws.ProtoTypePortForward &&
m.Header.MsgType == wspf.MessageTypePortForwardAck {
_, open = <-ackChan
if !open {
return
}
}
case <-ctx.Done():
return
}
}
}

func (p *TCPPortForwarder) handleRequest(
ctx context.Context,
Expand All @@ -108,12 +191,9 @@ func (p *TCPPortForwarder) handleRequest(
recvChan chan *ws.ProtoMsg,
msgChan chan *ws.ProtoMsg,
) {
defer conn.Close()

p.mutexAck[connectionID] = &sync.Mutex{}
defer func() {
delete(p.mutexAck, connectionID)
}()
ackChan := make(chan struct{})
defer func() { close(ackChan) }()

errChan := make(chan error)
dataChan := make(chan []byte)
Expand Down Expand Up @@ -142,81 +222,18 @@ func (p *TCPPortForwarder) handleRequest(
}
msgChan <- m

sendStopMessage := true
defer func() {
conn.Close()
if sendStopMessage {
m := &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypePortForward,
MsgType: wspf.MessageTypePortForwardStop,
SessionID: sessionID,
Properties: map[string]interface{}{
wspf.PropertyConnectionID: connectionID,
},
},
}
msgChan <- m
}
}()

// go routine to handle the network connection
go p.handleRequestConnection(dataChan, errChan, conn)

// go routine to handle received messages
go func(connectionID string) {
for {
select {
case m := <-recvChan:
if m.Header.Proto == ws.ProtoTypePortForward &&
m.Header.MsgType == wspf.MessageTypePortForwardStop {
sendStopMessage = false
return
} else if m.Header.Proto == ws.ProtoTypePortForward &&
m.Header.MsgType == wspf.MessageTypePortForward {
_, err := conn.Write(m.Body)
if err != nil {
if errors.Unwrap(err) != net.ErrClosed {
fmt.Fprintf(os.Stderr, "error: %v\n", err.Error())
}
} else {
// send the ack
m := &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypePortForward,
MsgType: wspf.MessageTypePortForwardAck,
SessionID: sessionID,
Properties: map[string]interface{}{
wspf.PropertyConnectionID: connectionID,
},
},
}
msgChan <- m
}
} else if m.Header.Proto == ws.ProtoTypePortForward &&
m.Header.MsgType == wspf.MessageTypePortForwardAck {
if m, ok := p.mutexAck[connectionID]; ok {
m.Unlock()
}
}
case <-ctx.Done():
return
}
}
}(connectionID)
go p.handleInboundMessages(ctx, conn, sessionID, connectionID, ackChan, msgChan, recvChan)

// go routine to handle sent messages
for {
for err == nil {
select {
case err := <-errChan:
if err != io.EOF {
fmt.Fprintf(os.Stderr, "error: %v\n", err.Error())
}
return
case data := <-dataChan:
// lock the ack mutex, we don't allow more than one in-flight message
p.mutexAck[connectionID].Lock()
case err = <-errChan:

case data := <-dataChan:
m := &ws.ProtoMsg{
Header: ws.ProtoHdr{
Proto: ws.ProtoTypePortForward,
Expand All @@ -229,10 +246,23 @@ func (p *TCPPortForwarder) handleRequest(
Body: data,
}
msgChan <- m
// wait for the ack to be received before processing more data
select {
case ackChan <- struct{}{}:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do I read this correctly, that now we ack every message and wait for ack with each message? I am having some doubts if that is the correct way of doing things. I know it is preexisting, but doubts I have.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it has always been a peer-to-peer stop and wait and was previously achieved with the mutex dance 👯
I noticed that in the previous implementation that despite using a map of mutexes, there's only a single entry in the map, so I replaced it with a chan primitive.
There is an epic of improving the protocol to move the transmission control to the only faulty link (NATS), but that will come with a separate change of the protocol. We cannot just change this on the client side or we risk connections hanging due to lost packets.

case <-ctx.Done():
err = ctx.Err()

case err = <-errChan:

}
case <-ctx.Done():
return
err = ctx.Err()
}
}

if err != io.EOF {
fmt.Fprintf(os.Stderr, "error: %v\n", err.Error())
}
}

func (p *TCPPortForwarder) handleRequestConnection(
Expand Down
4 changes: 2 additions & 2 deletions cmd/portforward_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ func (p *UDPPortForwarder) Run(
ctx context.Context,
sessionID string,
msgChan chan *ws.ProtoMsg,
recvChans map[string]chan *ws.ProtoMsg,
registerRecvChan func(string, chan<- *ws.ProtoMsg),
) {
// listen for new connections
defer p.conn.Close()

connectionUUID, _ := uuid.NewUUID()
connectionID := connectionUUID.String()
recvChan := make(chan *ws.ProtoMsg, portForwardUDPChannelSize)
recvChans[connectionID] = recvChan
registerRecvChan(connectionID, recvChan)

protocol := portforward.PortForwardProtocol(wspf.PortForwardProtocolUDP)
portforwardNew := &wspf.PortForwardNew{
Expand Down