From c5c36416f7cf6f427262bcefb7a2d7a4d65c1ee8 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 13:38:38 +0200 Subject: [PATCH 1/9] Minor improvements to help with k0sctl migration Signed-off-by: Kimmo Lehto --- client.go | 24 +++++++++++++- client_test.go | 34 +++++++++++++++++++ protocol/connection.go | 10 ++++++ protocol/localhost/connection.go | 12 ++++++- protocol/openssh/connection.go | 57 ++++++++++++++++++++++++++++---- protocol/ssh/connection.go | 7 +++- protocol/winrm/connection.go | 22 +++++++++++- rigtest/mockrunner.go | 30 ++++++++++++++--- 8 files changed, 182 insertions(+), 14 deletions(-) diff --git a/client.go b/client.go index 67bc2110..301aa481 100644 --- a/client.go +++ b/client.go @@ -379,7 +379,19 @@ func (c *Client) OS() (*os.Release, error) { return os, nil } -// Protocol returns the protocol used to connect to the host. +// IsConnected returns true if the underlying connection is currently active. +// For SSH this sends a real keepalive probe; for other protocols it reflects +// whether Connect has been called and Disconnect has not. +func (c *Client) IsConnected() bool { + if c.connection == nil { + return false + } + return c.connection.IsConnected() +} + +// Protocol returns the protocol family used to connect to the host: +// "SSH", "WinRM", or "Local". Both the native SSH and OpenSSH +// implementations return "SSH". func (c *Client) Protocol() string { if c.connection == nil { return "uninitialized" @@ -387,6 +399,16 @@ func (c *Client) Protocol() string { return c.connection.Protocol() } +// ProtocolName returns the specific protocol implementation name: +// "SSH", "OpenSSH", "WinRM", or "Local". Use this for logging or +// diagnostics where the distinction between native SSH and OpenSSH matters. +func (c *Client) ProtocolName() string { + if c.connection == nil { + return "uninitialized" + } + return c.connection.ProtocolName() +} + // Address returns the address of the host. func (c *Client) Address() string { if c.connection != nil { diff --git a/client_test.go b/client_test.go index 3700db63..0a47f521 100644 --- a/client_test.go +++ b/client_test.go @@ -103,6 +103,40 @@ func TestClientPackageManagerErrorFallback(t *testing.T) { require.ErrorIs(t, err, mockErr) } +func TestClientReconnect(t *testing.T) { + conn := rigtest.NewMockConnection() + conn.AddCommandOutput(rigtest.Match("echo hello"), "hello") + + client, err := rig.NewClient(rig.WithConnection(conn)) + require.NoError(t, err) + + require.NoError(t, client.Connect(context.Background())) + require.True(t, client.IsConnected()) + + out, err := client.ExecOutput("echo hello") + require.NoError(t, err) + require.Equal(t, "hello", out) + + client.Disconnect() + require.False(t, client.IsConnected()) + + require.NoError(t, client.Connect(context.Background())) + require.True(t, client.IsConnected()) + + out, err = client.ExecOutput("echo hello") + require.NoError(t, err) + require.Equal(t, "hello", out) +} + +func TestClientProtocolName(t *testing.T) { + conn := rigtest.NewMockConnection() + client, err := rig.NewClient(rig.WithConnection(conn)) + require.NoError(t, err) + + require.Equal(t, "mock", client.Protocol()) + require.Equal(t, "mock", client.ProtocolName()) +} + type testConfig struct { Hosts []*testHost `yaml:"hosts"` } diff --git a/protocol/connection.go b/protocol/connection.go index 11aa53f6..be4e8e57 100644 --- a/protocol/connection.go +++ b/protocol/connection.go @@ -50,7 +50,17 @@ type InteractiveExecer interface { // Connection is the minimum interface for protocol implementations. type Connection interface { fmt.Stringer + // Protocol returns the protocol family: "SSH", "WinRM", or "Local". + // Both the native SSH and OpenSSH implementations return "SSH". Protocol() string + // ProtocolName returns the specific implementation name: "SSH", "OpenSSH", + // "WinRM", or "Local". Use this for logging or diagnostics where the + // distinction between native SSH and OpenSSH matters. + ProtocolName() string + // IsConnected returns true if the connection is currently active. + // For SSH, this sends a real keepalive probe. For other protocols, + // it reflects whether the connection has been established and not closed. + IsConnected() bool IPAddress() string ProcessStarter WindowsChecker diff --git a/protocol/localhost/connection.go b/protocol/localhost/connection.go index bc5c595a..b6039495 100644 --- a/protocol/localhost/connection.go +++ b/protocol/localhost/connection.go @@ -26,11 +26,16 @@ func NewConnection() (*Connection, error) { return &Connection{}, nil } -// Protocol returns the protocol name, "Local". +// Protocol returns the protocol family, "Local". func (c *Connection) Protocol() string { return "Local" } +// ProtocolName returns the implementation name, "Local". +func (c *Connection) ProtocolName() string { + return "Local" +} + // IPAddress returns the connection address. func (c *Connection) IPAddress() string { return "127.0.0.1" @@ -46,6 +51,11 @@ func (c *Connection) IsWindows() bool { return runtime.GOOS == "windows" } +// IsConnected always returns true for localhost — there is no connection to lose. +func (c *Connection) IsConnected() bool { + return true +} + // StartProcess executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that // blocks until the command finishes and returns an error if the exit code is not zero. func (c *Connection) StartProcess(ctx context.Context, cmd string, stdin io.Reader, stdout, stderr io.Writer) (protocol.Waiter, error) { diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index 23ea787e..cffea645 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -24,6 +24,13 @@ var ( errNotConnected = errors.New("not connected") ) +// isHostKeyError reports whether ssh stderr output indicates a host key +// verification failure. These are fatal and should not be retried. +func isHostKeyError(stderr string) bool { + return strings.Contains(stderr, "Host key verification failed") || + strings.Contains(stderr, "REMOTE HOST IDENTIFICATION HAS CHANGED") +} + // Connection is a rig.Connection implementation that uses the system openssh client "ssh" to connect to remote hosts. // The connection is by default multiplexec over a control master, so that subsequent connections don't need to re-authenticate. type Connection struct { @@ -44,8 +51,13 @@ func NewConnection(cfg Config) (*Connection, error) { return &Connection{Config: cfg}, nil } -// Protocol returns the protocol name. +// Protocol returns the protocol family, "SSH". func (c *Connection) Protocol() string { + return "SSH" +} + +// ProtocolName returns the implementation name, "OpenSSH". +func (c *Connection) ProtocolName() string { return "OpenSSH" } @@ -140,8 +152,17 @@ func (c *Connection) Connect(ctx context.Context) error { } if c.DisableMultiplexing { - // just run a noop command to check connectivity - if _, err := c.StartProcess(ctx, "exit 0", nil, nil, nil); err != nil { + // Run a noop to check connectivity. Capture stderr to detect host key failures. + errBuf := bytes.NewBuffer(nil) + proc, err := c.StartProcess(ctx, "exit 0", nil, nil, errBuf) + if err == nil { + err = proc.Wait() + } + if err != nil { + errOut := errBuf.String() + if isHostKeyError(errOut) { + return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut) + } return fmt.Errorf("failed to connect: %w", err) } c.isConnected = true @@ -175,7 +196,11 @@ func (c *Connection) Connect(ctx context.Context) error { log.Trace(ctx, "starting ssh control master", log.KeyHost, c, log.KeyCommand, strings.Join(args, " ")) if err := cmd.Run(); err != nil { c.isConnected = false - return fmt.Errorf("failed to start ssh multiplexing control master: %w (%s)", err, errBuf.String()) + errOut := errBuf.String() + if isHostKeyError(errOut) { + return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut) + } + return fmt.Errorf("failed to start ssh multiplexing control master: %w (%s)", err, errOut) } c.isConnected = true @@ -261,9 +286,29 @@ func (c *Connection) String() string { return c.name } -// IsConnected returns true if the connection is connected. +// IsConnected returns true if the connection is alive. For multiplexed +// connections this probes the control master via ssh -O check. For +// non-multiplexed connections it runs a no-op command over a fresh session. func (c *Connection) IsConnected() bool { - return c.isConnected + if !c.isConnected { + return false + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if !c.DisableMultiplexing { + controlPath, ok := c.Options["ControlPath"].(string) + if !ok { + return false + } + args := []string{"-O", "check", "-S", controlPath} + args = append(args, c.args()...) + return exec.CommandContext(ctx, "ssh", args...).Run() == nil + } + proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) + if err != nil { + return false + } + return proc.Wait() == nil } // Disconnect disconnects from the remote host. If multiplexing is enabled, this will close the control master. diff --git a/protocol/ssh/connection.go b/protocol/ssh/connection.go index 87cd255f..26708420 100644 --- a/protocol/ssh/connection.go +++ b/protocol/ssh/connection.go @@ -151,11 +151,16 @@ func (c *Connection) SetDefaults(ctx context.Context) { }) } -// Protocol returns the protocol name, "SSH". +// Protocol returns the protocol family, "SSH". func (c *Connection) Protocol() string { return "SSH" } +// ProtocolName returns the implementation name, "SSH". +func (c *Connection) ProtocolName() string { + return "SSH" +} + // IPAddress returns the connection address. func (c *Connection) IPAddress() string { return c.Address diff --git a/protocol/winrm/connection.go b/protocol/winrm/connection.go index 57794770..a61d5152 100644 --- a/protocol/winrm/connection.go +++ b/protocol/winrm/connection.go @@ -54,11 +54,16 @@ func NewConnection(cfg Config, opts ...Option) (*Connection, error) { return c, nil } -// Protocol returns the protocol name, "WinRM". +// Protocol returns the protocol family, "WinRM". func (c *Connection) Protocol() string { return "WinRM" } +// ProtocolName returns the implementation name, "WinRM". +func (c *Connection) ProtocolName() string { + return "WinRM" +} + // IPAddress returns the connection address. func (c *Connection) IPAddress() string { return c.Address @@ -78,6 +83,21 @@ func (c *Connection) IsWindows() bool { return true } +// IsConnected returns true if the WinRM connection is alive by running a no-op +// command. WinRM is stateless HTTP so the only real liveness test is a probe. +func (c *Connection) IsConnected() bool { + if c.client == nil { + return false + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) + if err != nil { + return false + } + return proc.Wait() == nil +} + func (c *Connection) loadCertificates() error { c.caCert = nil if c.CACertPath != "" { diff --git a/rigtest/mockrunner.go b/rigtest/mockrunner.go index afe1d22a..6e4e1dd0 100644 --- a/rigtest/mockrunner.go +++ b/rigtest/mockrunner.go @@ -107,10 +107,11 @@ func NewMockRunner() *MockRunner { // MockConnection is a mock client. It can be used to simulate a client in tests. type MockConnection struct { log.LoggerInjectable - commands []string + commands []string *MockStarter - Windows bool - mu sync.Mutex + Windows bool + mu sync.Mutex + connected bool } // NewMockConnection creates a new mock connection. @@ -120,6 +121,21 @@ func NewMockConnection() *MockConnection { } } +// Connect marks the mock connection as connected. +func (m *MockConnection) Connect(_ context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + m.connected = true + return nil +} + +// Disconnect marks the mock connection as disconnected. +func (m *MockConnection) Disconnect() { + m.mu.Lock() + defer m.mu.Unlock() + m.connected = false +} + // IsWindows returns true if the runner's connection is set to be a Windows client. func (m *MockRunner) IsWindows() bool { return m.Windows @@ -142,7 +158,13 @@ func (m *MockConnection) IsWindows() bool { return m.Windows } func (m *MockConnection) String() string { return "mockclient" } // Protocol returns the protocol of the client. -func (m *MockConnection) Protocol() string { return "mock" } +func (m *MockConnection) Protocol() string { return "mock" } +func (m *MockConnection) ProtocolName() string { return "mock" } +func (m *MockConnection) IsConnected() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.connected +} // IPAddress returns the IP address of the client. func (m *MockConnection) IPAddress() string { return "mock" } From e932852d3ee223342ffb70dd8c88e037b91bc562 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 14:22:21 +0200 Subject: [PATCH 2/9] lint + corrections Signed-off-by: Kimmo Lehto --- protocol/connection.go | 6 ++++-- protocol/openssh/connection.go | 12 +++--------- rigtest/mockrunner.go | 18 ++++++++++++------ 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/protocol/connection.go b/protocol/connection.go index be4e8e57..b65ea70c 100644 --- a/protocol/connection.go +++ b/protocol/connection.go @@ -58,8 +58,10 @@ type Connection interface { // distinction between native SSH and OpenSSH matters. ProtocolName() string // IsConnected returns true if the connection is currently active. - // For SSH, this sends a real keepalive probe. For other protocols, - // it reflects whether the connection has been established and not closed. + // All implementations perform an active liveness probe (e.g. SSH keepalive, + // ssh -O check for OpenSSH multiplexing, or a no-op command for WinRM). + // Localhost always returns true. Callers should be aware this may block + // up to a timeout (typically 10s) and may cause side-effects on the remote. IsConnected() bool IPAddress() string ProcessStarter diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index cffea645..aafca623 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -183,15 +183,8 @@ func (c *Connection) Connect(ctx context.Context) error { args = append(args, c.args()...) cmd := exec.CommandContext(ctx, "ssh", args...) - stderr, err := cmd.StderrPipe() - if err != nil { - return fmt.Errorf("create stderr pipe: %w", err) - } - defer stderr.Close() errBuf := bytes.NewBuffer(nil) - go func() { - _, _ = io.Copy(errBuf, stderr) - }() + cmd.Stderr = errBuf log.Trace(ctx, "starting ssh control master", log.KeyHost, c, log.KeyCommand, strings.Join(args, " ")) if err := cmd.Run(); err != nil { @@ -300,7 +293,8 @@ func (c *Connection) IsConnected() bool { if !ok { return false } - args := []string{"-O", "check", "-S", controlPath} + args := make([]string, 0, 4+len(c.args())) + args = append(args, "-O", "check", "-S", controlPath) args = append(args, c.args()...) return exec.CommandContext(ctx, "ssh", args...).Run() == nil } diff --git a/rigtest/mockrunner.go b/rigtest/mockrunner.go index 6e4e1dd0..6fedcd98 100644 --- a/rigtest/mockrunner.go +++ b/rigtest/mockrunner.go @@ -107,11 +107,11 @@ func NewMockRunner() *MockRunner { // MockConnection is a mock client. It can be used to simulate a client in tests. type MockConnection struct { log.LoggerInjectable - commands []string + commands []string *MockStarter - Windows bool - mu sync.Mutex - connected bool + Windows bool + mu sync.Mutex + connected bool } // NewMockConnection creates a new mock connection. @@ -157,9 +157,15 @@ func (m *MockConnection) IsWindows() bool { return m.Windows } // String returns the string representation of the client. func (m *MockConnection) String() string { return "mockclient" } +const mockProtocol = "mock" + // Protocol returns the protocol of the client. -func (m *MockConnection) Protocol() string { return "mock" } -func (m *MockConnection) ProtocolName() string { return "mock" } +func (m *MockConnection) Protocol() string { return mockProtocol } + +// ProtocolName returns the protocol name of the client. +func (m *MockConnection) ProtocolName() string { return mockProtocol } + +// IsConnected returns true if the client is connected. func (m *MockConnection) IsConnected() bool { m.mu.Lock() defer m.mu.Unlock() From 70a6a65b3c9b58585e2aa7ceba3c467510befdfa Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 14:40:01 +0200 Subject: [PATCH 3/9] review fixes Signed-off-by: Kimmo Lehto --- client.go | 5 +++-- protocol/openssh/connection.go | 24 ++++++++++++++---------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 301aa481..7131c73c 100644 --- a/client.go +++ b/client.go @@ -380,8 +380,9 @@ func (c *Client) OS() (*os.Release, error) { } // IsConnected returns true if the underlying connection is currently active. -// For SSH this sends a real keepalive probe; for other protocols it reflects -// whether Connect has been called and Disconnect has not. +// This delegates to the protocol connection's IsConnected, which may perform +// an active liveness probe (e.g. ssh -O check for OpenSSH multiplexed sessions, +// or a no-op command for WinRM/non-multiplexed SSH) and may block up to a timeout. func (c *Client) IsConnected() bool { if c.connection == nil { return false diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index aafca623..8fcc153b 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -288,21 +288,25 @@ func (c *Connection) IsConnected() bool { } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + var alive bool if !c.DisableMultiplexing { controlPath, ok := c.Options["ControlPath"].(string) - if !ok { - return false + if ok { + args := make([]string, 0, 4+len(c.args())) + args = append(args, "-O", "check", "-S", controlPath) + args = append(args, c.args()...) + alive = exec.CommandContext(ctx, "ssh", args...).Run() == nil + } + } else { + proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) + if err == nil { + alive = proc.Wait() == nil } - args := make([]string, 0, 4+len(c.args())) - args = append(args, "-O", "check", "-S", controlPath) - args = append(args, c.args()...) - return exec.CommandContext(ctx, "ssh", args...).Run() == nil } - proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) - if err != nil { - return false + if !alive { + c.isConnected = false } - return proc.Wait() == nil + return alive } // Disconnect disconnects from the remote host. If multiplexing is enabled, this will close the control master. From dc48408c7da1706920a9a3fe21cc41914c4ca0fc Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 14:48:52 +0200 Subject: [PATCH 4/9] review fixes Signed-off-by: Kimmo Lehto --- protocol/openssh/connection.go | 37 +++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index 8fcc153b..c99c5ad0 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -38,6 +38,7 @@ type Connection struct { Config `yaml:",inline"` isConnected bool + controlPath string controlMutex sync.Mutex isWindows *bool @@ -197,6 +198,9 @@ func (c *Connection) Connect(ctx context.Context) error { } c.isConnected = true + if cp, ok := c.Options["ControlPath"].(string); ok { + c.controlPath = cp + } log.Trace(ctx, "started ssh multipliexing control master", log.KeyHost, c) return nil @@ -210,13 +214,12 @@ func (c *Connection) closeControl() error { return nil } - controlPath, ok := c.Options["ControlPath"].(string) - if !ok { + if c.controlPath == "" { return ErrControlPathNotSet } args := make([]string, 0, 4+len(c.args())+1) - args = append(args, "-O", "exit", "-S", controlPath) + args = append(args, "-O", "exit", "-S", c.controlPath) args = append(args, c.args()...) args = append(args, c.userhost()) @@ -288,25 +291,27 @@ func (c *Connection) IsConnected() bool { } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - var alive bool if !c.DisableMultiplexing { - controlPath, ok := c.Options["ControlPath"].(string) - if ok { - args := make([]string, 0, 4+len(c.args())) - args = append(args, "-O", "check", "-S", controlPath) - args = append(args, c.args()...) - alive = exec.CommandContext(ctx, "ssh", args...).Run() == nil + if c.controlPath == "" { + // Control path not known (e.g. set via ssh_config -F); skip probe + // to avoid incorrectly marking a live connection as disconnected. + return true } - } else { - proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) - if err == nil { - alive = proc.Wait() == nil + args := make([]string, 0, 4+len(c.args())) + args = append(args, "-O", "check", "-S", c.controlPath) + args = append(args, c.args()...) + if exec.CommandContext(ctx, "ssh", args...).Run() != nil { + c.isConnected = false + return false } + return true } - if !alive { + proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) + if err != nil || proc.Wait() != nil { c.isConnected = false + return false } - return alive + return true } // Disconnect disconnects from the remote host. If multiplexing is enabled, this will close the control master. From d54dc0e3deab18230c033a94f8f075ec8a2e599d Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 15:02:26 +0200 Subject: [PATCH 5/9] openssh fixes Signed-off-by: Kimmo Lehto --- protocol/openssh/connection.go | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index c99c5ad0..95d6633a 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -148,11 +148,14 @@ func (c *Connection) args() []string { // Connect connects to the remote host. If multiplexing is enabled, this will start a control master. If multiplexing is disabled, this will just run a noop command to check connectivity. func (c *Connection) Connect(ctx context.Context) error { + c.controlMutex.Lock() if c.isConnected { + c.controlMutex.Unlock() return nil } if c.DisableMultiplexing { + c.controlMutex.Unlock() // Run a noop to check connectivity. Capture stderr to detect host key failures. errBuf := bytes.NewBuffer(nil) proc, err := c.StartProcess(ctx, "exit 0", nil, nil, errBuf) @@ -162,15 +165,16 @@ func (c *Connection) Connect(ctx context.Context) error { if err != nil { errOut := errBuf.String() if isHostKeyError(errOut) { - return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut) + return fmt.Errorf("%w: host key verification failed: %v (%s)", protocol.ErrNonRetryable, err, errOut) } return fmt.Errorf("failed to connect: %w", err) } + c.controlMutex.Lock() c.isConnected = true + c.controlMutex.Unlock() return nil } - c.controlMutex.Lock() defer c.controlMutex.Unlock() opts := c.Options.Copy() @@ -192,7 +196,7 @@ func (c *Connection) Connect(ctx context.Context) error { c.isConnected = false errOut := errBuf.String() if isHostKeyError(errOut) { - return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut) + return fmt.Errorf("%w: host key verification failed: %v (%s)", protocol.ErrNonRetryable, err, errOut) } return fmt.Errorf("failed to start ssh multiplexing control master: %w (%s)", err, errOut) } @@ -218,10 +222,9 @@ func (c *Connection) closeControl() error { return ErrControlPathNotSet } - args := make([]string, 0, 4+len(c.args())+1) + args := make([]string, 0, 4+len(c.args())) args = append(args, "-O", "exit", "-S", c.controlPath) args = append(args, c.args()...) - args = append(args, c.userhost()) log.Trace(context.Background(), "closing ssh multiplexing control master", log.KeyHost, c) cmd := exec.Command("ssh", args...) //nolint:noctx // cleanup code path, no context available @@ -235,7 +238,10 @@ func (c *Connection) closeControl() error { // StartProcess executes a command on the remote host, streaming stdin, stdout and stderr. func (c *Connection) StartProcess(ctx context.Context, cmdStr string, stdin io.Reader, stdout, stderr io.Writer) (protocol.Waiter, error) { - if !c.DisableMultiplexing && !c.isConnected { + c.controlMutex.Lock() + connected := c.isConnected + c.controlMutex.Unlock() + if !c.DisableMultiplexing && !connected { return nil, errNotConnected } @@ -286,29 +292,38 @@ func (c *Connection) String() string { // connections this probes the control master via ssh -O check. For // non-multiplexed connections it runs a no-op command over a fresh session. func (c *Connection) IsConnected() bool { - if !c.isConnected { + c.controlMutex.Lock() + connected := c.isConnected + cp := c.controlPath + c.controlMutex.Unlock() + + if !connected { return false } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if !c.DisableMultiplexing { - if c.controlPath == "" { + if cp == "" { // Control path not known (e.g. set via ssh_config -F); skip probe // to avoid incorrectly marking a live connection as disconnected. return true } args := make([]string, 0, 4+len(c.args())) - args = append(args, "-O", "check", "-S", c.controlPath) + args = append(args, "-O", "check", "-S", cp) args = append(args, c.args()...) if exec.CommandContext(ctx, "ssh", args...).Run() != nil { + c.controlMutex.Lock() c.isConnected = false + c.controlMutex.Unlock() return false } return true } proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) if err != nil || proc.Wait() != nil { + c.controlMutex.Lock() c.isConnected = false + c.controlMutex.Unlock() return false } return true From 3a0f3083180c4b3664976f057f99bd2f063b373d Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 15:20:04 +0200 Subject: [PATCH 6/9] lint and other findings Signed-off-by: Kimmo Lehto --- protocol/openssh/connection.go | 18 ++++++++++-------- protocol/winrm/connection.go | 25 +++++++++++++++++++++---- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index 95d6633a..491cf287 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -165,7 +165,7 @@ func (c *Connection) Connect(ctx context.Context) error { if err != nil { errOut := errBuf.String() if isHostKeyError(errOut) { - return fmt.Errorf("%w: host key verification failed: %v (%s)", protocol.ErrNonRetryable, err, errOut) + return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut) } return fmt.Errorf("failed to connect: %w", err) } @@ -196,7 +196,7 @@ func (c *Connection) Connect(ctx context.Context) error { c.isConnected = false errOut := errBuf.String() if isHostKeyError(errOut) { - return fmt.Errorf("%w: host key verification failed: %v (%s)", protocol.ErrNonRetryable, err, errOut) + return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut) } return fmt.Errorf("failed to start ssh multiplexing control master: %w (%s)", err, errOut) } @@ -205,7 +205,7 @@ func (c *Connection) Connect(ctx context.Context) error { if cp, ok := c.Options["ControlPath"].(string); ok { c.controlPath = cp } - log.Trace(ctx, "started ssh multipliexing control master", log.KeyHost, c) + log.Trace(ctx, "started ssh multiplexing control master", log.KeyHost, c) return nil } @@ -294,7 +294,7 @@ func (c *Connection) String() string { func (c *Connection) IsConnected() bool { c.controlMutex.Lock() connected := c.isConnected - cp := c.controlPath + controlPath := c.controlPath c.controlMutex.Unlock() if !connected { @@ -303,13 +303,13 @@ func (c *Connection) IsConnected() bool { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if !c.DisableMultiplexing { - if cp == "" { + if controlPath == "" { // Control path not known (e.g. set via ssh_config -F); skip probe // to avoid incorrectly marking a live connection as disconnected. return true } args := make([]string, 0, 4+len(c.args())) - args = append(args, "-O", "check", "-S", cp) + args = append(args, "-O", "check", "-S", controlPath) args = append(args, c.args()...) if exec.CommandContext(ctx, "ssh", args...).Run() != nil { c.controlMutex.Lock() @@ -330,10 +330,12 @@ func (c *Connection) IsConnected() bool { } // Disconnect disconnects from the remote host. If multiplexing is enabled, this will close the control master. -// If multiplexing is disabled, this will do nothing. +// If multiplexing is disabled, this marks the connection as disconnected. func (c *Connection) Disconnect() { if c.DisableMultiplexing { - // nothing to do + c.controlMutex.Lock() + c.isConnected = false + c.controlMutex.Unlock() return } diff --git a/protocol/winrm/connection.go b/protocol/winrm/connection.go index a61d5152..3f9a27b7 100644 --- a/protocol/winrm/connection.go +++ b/protocol/winrm/connection.go @@ -37,6 +37,7 @@ type Connection struct { key []byte cert []byte + mu sync.Mutex client *winrm.Client } @@ -86,7 +87,10 @@ func (c *Connection) IsWindows() bool { // IsConnected returns true if the WinRM connection is alive by running a no-op // command. WinRM is stateless HTTP so the only real liveness test is a probe. func (c *Connection) IsConnected() bool { - if c.client == nil { + c.mu.Lock() + connected := c.client != nil + c.mu.Unlock() + if !connected { return false } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -198,14 +202,18 @@ func (c *Connection) Connect(ctx context.Context) error { return fmt.Errorf("create winrm client: %w", err) } + c.mu.Lock() c.client = client + c.mu.Unlock() return nil } // Disconnect closes the WinRM connection. func (c *Connection) Disconnect() { + c.mu.Lock() c.client = nil + c.mu.Unlock() } type command struct { @@ -253,14 +261,17 @@ func (c *command) Close() error { // StartProcess executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that // blocks until the command finishes and returns an error if the exit code is not zero. func (c *Connection) StartProcess(ctx context.Context, cmd string, stdin io.Reader, stdout, stderr io.Writer) (protocol.Waiter, error) { - if c.client == nil { + c.mu.Lock() + client := c.client + c.mu.Unlock() + if client == nil { return nil, errNotConnected } if len(cmd) > 8191 { return nil, fmt.Errorf("%w: %w: command too long (%d/%d)", protocol.ErrNonRetryable, errInvalidCommand, len(cmd), 8191) } - shell, err := c.client.CreateShell() + shell, err := client.CreateShell() if err != nil { return nil, fmt.Errorf("create shell: %w", err) } @@ -312,10 +323,16 @@ func (c *Connection) StartProcess(ctx context.Context, cmd string, stdin io.Read // ExecInteractive executes a command on the host and passes stdin/stdout/stderr as-is to the session. func (c *Connection) ExecInteractive(cmd string, stdin io.Reader, stdout, stderr io.Writer) error { + c.mu.Lock() + client := c.client + c.mu.Unlock() + if client == nil { + return errNotConnected + } if cmd == "" { cmd = "cmd.exe" } - _, err := c.client.RunWithContextWithInput(context.Background(), cmd, stdout, stderr, stdin) + _, err := client.RunWithContextWithInput(context.Background(), cmd, stdout, stderr, stdin) if err != nil { return fmt.Errorf("execute command in interactive mode: %w", err) } From 544a67ec5d8112fad5474cfe1e96f2e6c37c45bd Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 15:52:14 +0200 Subject: [PATCH 7/9] final(?) tweaks Signed-off-by: Kimmo Lehto --- client.go | 7 ++++--- protocol/connection.go | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 7131c73c..5466dfba 100644 --- a/client.go +++ b/client.go @@ -390,9 +390,10 @@ func (c *Client) IsConnected() bool { return c.connection.IsConnected() } -// Protocol returns the protocol family used to connect to the host: -// "SSH", "WinRM", or "Local". Both the native SSH and OpenSSH -// implementations return "SSH". +// Protocol returns the protocol family used to connect to the host, +// such as "SSH", "WinRM", or "Local". Both the native SSH and OpenSSH +// implementations return "SSH". Custom or test implementations may +// return other values. func (c *Client) Protocol() string { if c.connection == nil { return "uninitialized" diff --git a/protocol/connection.go b/protocol/connection.go index b65ea70c..60656aa7 100644 --- a/protocol/connection.go +++ b/protocol/connection.go @@ -50,8 +50,9 @@ type InteractiveExecer interface { // Connection is the minimum interface for protocol implementations. type Connection interface { fmt.Stringer - // Protocol returns the protocol family: "SSH", "WinRM", or "Local". + // Protocol returns the protocol family, such as "SSH", "WinRM", or "Local". // Both the native SSH and OpenSSH implementations return "SSH". + // Custom or test implementations may return other values. Protocol() string // ProtocolName returns the specific implementation name: "SSH", "OpenSSH", // "WinRM", or "Local". Use this for logging or diagnostics where the From 2d7917ebe2981585e7b1aa0dedd907a95893d6a3 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 16:17:47 +0200 Subject: [PATCH 8/9] wish these could be found in one go.. Signed-off-by: Kimmo Lehto --- client.go | 7 ++++--- protocol/connection.go | 17 ++++++++++------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 5466dfba..1527e902 100644 --- a/client.go +++ b/client.go @@ -401,9 +401,10 @@ func (c *Client) Protocol() string { return c.connection.Protocol() } -// ProtocolName returns the specific protocol implementation name: -// "SSH", "OpenSSH", "WinRM", or "Local". Use this for logging or -// diagnostics where the distinction between native SSH and OpenSSH matters. +// ProtocolName returns the specific protocol implementation name, such as +// "SSH", "OpenSSH", "WinRM", or "Local". Use this for logging or diagnostics +// where the distinction between native SSH and OpenSSH matters. Custom or +// test implementations may return other values. func (c *Client) ProtocolName() string { if c.connection == nil { return "uninitialized" diff --git a/protocol/connection.go b/protocol/connection.go index 60656aa7..45bcd8c8 100644 --- a/protocol/connection.go +++ b/protocol/connection.go @@ -54,15 +54,18 @@ type Connection interface { // Both the native SSH and OpenSSH implementations return "SSH". // Custom or test implementations may return other values. Protocol() string - // ProtocolName returns the specific implementation name: "SSH", "OpenSSH", - // "WinRM", or "Local". Use this for logging or diagnostics where the - // distinction between native SSH and OpenSSH matters. + // ProtocolName returns the specific implementation name, such as "SSH", + // "OpenSSH", "WinRM", or "Local". Use this for logging or diagnostics + // where the distinction between native SSH and OpenSSH matters. Custom + // or test implementations may return other values. ProtocolName() string // IsConnected returns true if the connection is currently active. - // All implementations perform an active liveness probe (e.g. SSH keepalive, - // ssh -O check for OpenSSH multiplexing, or a no-op command for WinRM). - // Localhost always returns true. Callers should be aware this may block - // up to a timeout (typically 10s) and may cause side-effects on the remote. + // Built-in implementations attempt an active liveness probe where + // possible (e.g. SSH keepalive, ssh -O check for OpenSSH multiplexing, + // or a no-op command for WinRM), but some implementations may skip the + // probe or always return true (e.g. Localhost). Callers should be aware + // this may block up to a timeout (typically 10s) and may cause + // side-effects on the remote. IsConnected() bool IPAddress() string ProcessStarter From 08a6177b9dcd7a8cc06aa3a43f6121b0a0ef0e40 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Wed, 25 Mar 2026 16:32:14 +0200 Subject: [PATCH 9/9] more findings Signed-off-by: Kimmo Lehto --- protocol/openssh/connection.go | 15 +++++++++------ protocol/winrm/connection.go | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index 491cf287..613d39a8 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -303,13 +303,16 @@ func (c *Connection) IsConnected() bool { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if !c.DisableMultiplexing { - if controlPath == "" { - // Control path not known (e.g. set via ssh_config -F); skip probe - // to avoid incorrectly marking a live connection as disconnected. - return true + var args []string + if controlPath != "" { + args = make([]string, 0, 4+len(c.args())) + args = append(args, "-O", "check", "-S", controlPath) + } else { + // ControlPath comes from ssh_config (-F); let ssh resolve it from options. + args = make([]string, 0, 2+len(c.Options.ToArgs())+len(c.args())) + args = append(args, c.Options.ToArgs()...) + args = append(args, "-O", "check") } - args := make([]string, 0, 4+len(c.args())) - args = append(args, "-O", "check", "-S", controlPath) args = append(args, c.args()...) if exec.CommandContext(ctx, "ssh", args...).Run() != nil { c.controlMutex.Lock() diff --git a/protocol/winrm/connection.go b/protocol/winrm/connection.go index 3f9a27b7..c7f0cc58 100644 --- a/protocol/winrm/connection.go +++ b/protocol/winrm/connection.go @@ -95,7 +95,7 @@ func (c *Connection) IsConnected() bool { } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) + proc, err := c.StartProcess(ctx, "cmd.exe /c exit 0", nil, nil, nil) if err != nil { return false }