diff --git a/client.go b/client.go index 67bc2110..1527e902 100644 --- a/client.go +++ b/client.go @@ -379,7 +379,21 @@ func (c *Client) OS() (*os.Release, error) { return os, nil } -// Protocol returns the protocol used to connect to the host. +// IsConnected returns true if the underlying connection is currently active. +// This delegates to the protocol connection's IsConnected, which may perform +// an active liveness probe (e.g. ssh -O check for OpenSSH multiplexed sessions, +// or a no-op command for WinRM/non-multiplexed SSH) and may block up to a timeout. +func (c *Client) IsConnected() bool { + if c.connection == nil { + return false + } + return c.connection.IsConnected() +} + +// Protocol returns the protocol family used to connect to the host, +// such as "SSH", "WinRM", or "Local". Both the native SSH and OpenSSH +// implementations return "SSH". Custom or test implementations may +// return other values. func (c *Client) Protocol() string { if c.connection == nil { return "uninitialized" @@ -387,6 +401,17 @@ func (c *Client) Protocol() string { return c.connection.Protocol() } +// ProtocolName returns the specific protocol implementation name, such as +// "SSH", "OpenSSH", "WinRM", or "Local". Use this for logging or diagnostics +// where the distinction between native SSH and OpenSSH matters. Custom or +// test implementations may return other values. +func (c *Client) ProtocolName() string { + if c.connection == nil { + return "uninitialized" + } + return c.connection.ProtocolName() +} + // Address returns the address of the host. func (c *Client) Address() string { if c.connection != nil { diff --git a/client_test.go b/client_test.go index 3700db63..0a47f521 100644 --- a/client_test.go +++ b/client_test.go @@ -103,6 +103,40 @@ func TestClientPackageManagerErrorFallback(t *testing.T) { require.ErrorIs(t, err, mockErr) } +func TestClientReconnect(t *testing.T) { + conn := rigtest.NewMockConnection() + conn.AddCommandOutput(rigtest.Match("echo hello"), "hello") + + client, err := rig.NewClient(rig.WithConnection(conn)) + require.NoError(t, err) + + require.NoError(t, client.Connect(context.Background())) + require.True(t, client.IsConnected()) + + out, err := client.ExecOutput("echo hello") + require.NoError(t, err) + require.Equal(t, "hello", out) + + client.Disconnect() + require.False(t, client.IsConnected()) + + require.NoError(t, client.Connect(context.Background())) + require.True(t, client.IsConnected()) + + out, err = client.ExecOutput("echo hello") + require.NoError(t, err) + require.Equal(t, "hello", out) +} + +func TestClientProtocolName(t *testing.T) { + conn := rigtest.NewMockConnection() + client, err := rig.NewClient(rig.WithConnection(conn)) + require.NoError(t, err) + + require.Equal(t, "mock", client.Protocol()) + require.Equal(t, "mock", client.ProtocolName()) +} + type testConfig struct { Hosts []*testHost `yaml:"hosts"` } diff --git a/protocol/connection.go b/protocol/connection.go index 11aa53f6..45bcd8c8 100644 --- a/protocol/connection.go +++ b/protocol/connection.go @@ -50,7 +50,23 @@ type InteractiveExecer interface { // Connection is the minimum interface for protocol implementations. type Connection interface { fmt.Stringer + // Protocol returns the protocol family, such as "SSH", "WinRM", or "Local". + // Both the native SSH and OpenSSH implementations return "SSH". + // Custom or test implementations may return other values. Protocol() string + // ProtocolName returns the specific implementation name, such as "SSH", + // "OpenSSH", "WinRM", or "Local". Use this for logging or diagnostics + // where the distinction between native SSH and OpenSSH matters. Custom + // or test implementations may return other values. + ProtocolName() string + // IsConnected returns true if the connection is currently active. + // Built-in implementations attempt an active liveness probe where + // possible (e.g. SSH keepalive, ssh -O check for OpenSSH multiplexing, + // or a no-op command for WinRM), but some implementations may skip the + // probe or always return true (e.g. Localhost). Callers should be aware + // this may block up to a timeout (typically 10s) and may cause + // side-effects on the remote. + IsConnected() bool IPAddress() string ProcessStarter WindowsChecker diff --git a/protocol/localhost/connection.go b/protocol/localhost/connection.go index bc5c595a..b6039495 100644 --- a/protocol/localhost/connection.go +++ b/protocol/localhost/connection.go @@ -26,11 +26,16 @@ func NewConnection() (*Connection, error) { return &Connection{}, nil } -// Protocol returns the protocol name, "Local". +// Protocol returns the protocol family, "Local". func (c *Connection) Protocol() string { return "Local" } +// ProtocolName returns the implementation name, "Local". +func (c *Connection) ProtocolName() string { + return "Local" +} + // IPAddress returns the connection address. func (c *Connection) IPAddress() string { return "127.0.0.1" @@ -46,6 +51,11 @@ func (c *Connection) IsWindows() bool { return runtime.GOOS == "windows" } +// IsConnected always returns true for localhost — there is no connection to lose. +func (c *Connection) IsConnected() bool { + return true +} + // StartProcess executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that // blocks until the command finishes and returns an error if the exit code is not zero. func (c *Connection) StartProcess(ctx context.Context, cmd string, stdin io.Reader, stdout, stderr io.Writer) (protocol.Waiter, error) { diff --git a/protocol/openssh/connection.go b/protocol/openssh/connection.go index 23ea787e..613d39a8 100644 --- a/protocol/openssh/connection.go +++ b/protocol/openssh/connection.go @@ -24,6 +24,13 @@ var ( errNotConnected = errors.New("not connected") ) +// isHostKeyError reports whether ssh stderr output indicates a host key +// verification failure. These are fatal and should not be retried. +func isHostKeyError(stderr string) bool { + return strings.Contains(stderr, "Host key verification failed") || + strings.Contains(stderr, "REMOTE HOST IDENTIFICATION HAS CHANGED") +} + // Connection is a rig.Connection implementation that uses the system openssh client "ssh" to connect to remote hosts. // The connection is by default multiplexec over a control master, so that subsequent connections don't need to re-authenticate. type Connection struct { @@ -31,6 +38,7 @@ type Connection struct { Config `yaml:",inline"` isConnected bool + controlPath string controlMutex sync.Mutex isWindows *bool @@ -44,8 +52,13 @@ func NewConnection(cfg Config) (*Connection, error) { return &Connection{Config: cfg}, nil } -// Protocol returns the protocol name. +// Protocol returns the protocol family, "SSH". func (c *Connection) Protocol() string { + return "SSH" +} + +// ProtocolName returns the implementation name, "OpenSSH". +func (c *Connection) ProtocolName() string { return "OpenSSH" } @@ -135,20 +148,33 @@ func (c *Connection) args() []string { // Connect connects to the remote host. If multiplexing is enabled, this will start a control master. If multiplexing is disabled, this will just run a noop command to check connectivity. func (c *Connection) Connect(ctx context.Context) error { + c.controlMutex.Lock() if c.isConnected { + c.controlMutex.Unlock() return nil } if c.DisableMultiplexing { - // just run a noop command to check connectivity - if _, err := c.StartProcess(ctx, "exit 0", nil, nil, nil); err != nil { + c.controlMutex.Unlock() + // Run a noop to check connectivity. Capture stderr to detect host key failures. + errBuf := bytes.NewBuffer(nil) + proc, err := c.StartProcess(ctx, "exit 0", nil, nil, errBuf) + if err == nil { + err = proc.Wait() + } + if err != nil { + errOut := errBuf.String() + if isHostKeyError(errOut) { + return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut) + } return fmt.Errorf("failed to connect: %w", err) } + c.controlMutex.Lock() c.isConnected = true + c.controlMutex.Unlock() return nil } - c.controlMutex.Lock() defer c.controlMutex.Unlock() opts := c.Options.Copy() @@ -162,24 +188,24 @@ func (c *Connection) Connect(ctx context.Context) error { args = append(args, c.args()...) cmd := exec.CommandContext(ctx, "ssh", args...) - stderr, err := cmd.StderrPipe() - if err != nil { - return fmt.Errorf("create stderr pipe: %w", err) - } - defer stderr.Close() errBuf := bytes.NewBuffer(nil) - go func() { - _, _ = io.Copy(errBuf, stderr) - }() + cmd.Stderr = errBuf log.Trace(ctx, "starting ssh control master", log.KeyHost, c, log.KeyCommand, strings.Join(args, " ")) if err := cmd.Run(); err != nil { c.isConnected = false - return fmt.Errorf("failed to start ssh multiplexing control master: %w (%s)", err, errBuf.String()) + errOut := errBuf.String() + if isHostKeyError(errOut) { + return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut) + } + return fmt.Errorf("failed to start ssh multiplexing control master: %w (%s)", err, errOut) } c.isConnected = true - log.Trace(ctx, "started ssh multipliexing control master", log.KeyHost, c) + if cp, ok := c.Options["ControlPath"].(string); ok { + c.controlPath = cp + } + log.Trace(ctx, "started ssh multiplexing control master", log.KeyHost, c) return nil } @@ -192,15 +218,13 @@ func (c *Connection) closeControl() error { return nil } - controlPath, ok := c.Options["ControlPath"].(string) - if !ok { + if c.controlPath == "" { return ErrControlPathNotSet } - args := make([]string, 0, 4+len(c.args())+1) - args = append(args, "-O", "exit", "-S", controlPath) + args := make([]string, 0, 4+len(c.args())) + args = append(args, "-O", "exit", "-S", c.controlPath) args = append(args, c.args()...) - args = append(args, c.userhost()) log.Trace(context.Background(), "closing ssh multiplexing control master", log.KeyHost, c) cmd := exec.Command("ssh", args...) //nolint:noctx // cleanup code path, no context available @@ -214,7 +238,10 @@ func (c *Connection) closeControl() error { // StartProcess executes a command on the remote host, streaming stdin, stdout and stderr. func (c *Connection) StartProcess(ctx context.Context, cmdStr string, stdin io.Reader, stdout, stderr io.Writer) (protocol.Waiter, error) { - if !c.DisableMultiplexing && !c.isConnected { + c.controlMutex.Lock() + connected := c.isConnected + c.controlMutex.Unlock() + if !c.DisableMultiplexing && !connected { return nil, errNotConnected } @@ -261,16 +288,57 @@ func (c *Connection) String() string { return c.name } -// IsConnected returns true if the connection is connected. +// IsConnected returns true if the connection is alive. For multiplexed +// connections this probes the control master via ssh -O check. For +// non-multiplexed connections it runs a no-op command over a fresh session. func (c *Connection) IsConnected() bool { - return c.isConnected + c.controlMutex.Lock() + connected := c.isConnected + controlPath := c.controlPath + c.controlMutex.Unlock() + + if !connected { + return false + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if !c.DisableMultiplexing { + var args []string + if controlPath != "" { + args = make([]string, 0, 4+len(c.args())) + args = append(args, "-O", "check", "-S", controlPath) + } else { + // ControlPath comes from ssh_config (-F); let ssh resolve it from options. + args = make([]string, 0, 2+len(c.Options.ToArgs())+len(c.args())) + args = append(args, c.Options.ToArgs()...) + args = append(args, "-O", "check") + } + args = append(args, c.args()...) + if exec.CommandContext(ctx, "ssh", args...).Run() != nil { + c.controlMutex.Lock() + c.isConnected = false + c.controlMutex.Unlock() + return false + } + return true + } + proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil) + if err != nil || proc.Wait() != nil { + c.controlMutex.Lock() + c.isConnected = false + c.controlMutex.Unlock() + return false + } + return true } // Disconnect disconnects from the remote host. If multiplexing is enabled, this will close the control master. -// If multiplexing is disabled, this will do nothing. +// If multiplexing is disabled, this marks the connection as disconnected. func (c *Connection) Disconnect() { if c.DisableMultiplexing { - // nothing to do + c.controlMutex.Lock() + c.isConnected = false + c.controlMutex.Unlock() return } diff --git a/protocol/ssh/connection.go b/protocol/ssh/connection.go index 87cd255f..26708420 100644 --- a/protocol/ssh/connection.go +++ b/protocol/ssh/connection.go @@ -151,11 +151,16 @@ func (c *Connection) SetDefaults(ctx context.Context) { }) } -// Protocol returns the protocol name, "SSH". +// Protocol returns the protocol family, "SSH". func (c *Connection) Protocol() string { return "SSH" } +// ProtocolName returns the implementation name, "SSH". +func (c *Connection) ProtocolName() string { + return "SSH" +} + // IPAddress returns the connection address. func (c *Connection) IPAddress() string { return c.Address diff --git a/protocol/winrm/connection.go b/protocol/winrm/connection.go index 57794770..c7f0cc58 100644 --- a/protocol/winrm/connection.go +++ b/protocol/winrm/connection.go @@ -37,6 +37,7 @@ type Connection struct { key []byte cert []byte + mu sync.Mutex client *winrm.Client } @@ -54,11 +55,16 @@ func NewConnection(cfg Config, opts ...Option) (*Connection, error) { return c, nil } -// Protocol returns the protocol name, "WinRM". +// Protocol returns the protocol family, "WinRM". func (c *Connection) Protocol() string { return "WinRM" } +// ProtocolName returns the implementation name, "WinRM". +func (c *Connection) ProtocolName() string { + return "WinRM" +} + // IPAddress returns the connection address. func (c *Connection) IPAddress() string { return c.Address @@ -78,6 +84,24 @@ func (c *Connection) IsWindows() bool { return true } +// IsConnected returns true if the WinRM connection is alive by running a no-op +// command. WinRM is stateless HTTP so the only real liveness test is a probe. +func (c *Connection) IsConnected() bool { + c.mu.Lock() + connected := c.client != nil + c.mu.Unlock() + if !connected { + return false + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + proc, err := c.StartProcess(ctx, "cmd.exe /c exit 0", nil, nil, nil) + if err != nil { + return false + } + return proc.Wait() == nil +} + func (c *Connection) loadCertificates() error { c.caCert = nil if c.CACertPath != "" { @@ -178,14 +202,18 @@ func (c *Connection) Connect(ctx context.Context) error { return fmt.Errorf("create winrm client: %w", err) } + c.mu.Lock() c.client = client + c.mu.Unlock() return nil } // Disconnect closes the WinRM connection. func (c *Connection) Disconnect() { + c.mu.Lock() c.client = nil + c.mu.Unlock() } type command struct { @@ -233,14 +261,17 @@ func (c *command) Close() error { // StartProcess executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that // blocks until the command finishes and returns an error if the exit code is not zero. func (c *Connection) StartProcess(ctx context.Context, cmd string, stdin io.Reader, stdout, stderr io.Writer) (protocol.Waiter, error) { - if c.client == nil { + c.mu.Lock() + client := c.client + c.mu.Unlock() + if client == nil { return nil, errNotConnected } if len(cmd) > 8191 { return nil, fmt.Errorf("%w: %w: command too long (%d/%d)", protocol.ErrNonRetryable, errInvalidCommand, len(cmd), 8191) } - shell, err := c.client.CreateShell() + shell, err := client.CreateShell() if err != nil { return nil, fmt.Errorf("create shell: %w", err) } @@ -292,10 +323,16 @@ func (c *Connection) StartProcess(ctx context.Context, cmd string, stdin io.Read // ExecInteractive executes a command on the host and passes stdin/stdout/stderr as-is to the session. func (c *Connection) ExecInteractive(cmd string, stdin io.Reader, stdout, stderr io.Writer) error { + c.mu.Lock() + client := c.client + c.mu.Unlock() + if client == nil { + return errNotConnected + } if cmd == "" { cmd = "cmd.exe" } - _, err := c.client.RunWithContextWithInput(context.Background(), cmd, stdout, stderr, stdin) + _, err := client.RunWithContextWithInput(context.Background(), cmd, stdout, stderr, stdin) if err != nil { return fmt.Errorf("execute command in interactive mode: %w", err) } diff --git a/rigtest/mockrunner.go b/rigtest/mockrunner.go index afe1d22a..6fedcd98 100644 --- a/rigtest/mockrunner.go +++ b/rigtest/mockrunner.go @@ -109,8 +109,9 @@ type MockConnection struct { log.LoggerInjectable commands []string *MockStarter - Windows bool - mu sync.Mutex + Windows bool + mu sync.Mutex + connected bool } // NewMockConnection creates a new mock connection. @@ -120,6 +121,21 @@ func NewMockConnection() *MockConnection { } } +// Connect marks the mock connection as connected. +func (m *MockConnection) Connect(_ context.Context) error { + m.mu.Lock() + defer m.mu.Unlock() + m.connected = true + return nil +} + +// Disconnect marks the mock connection as disconnected. +func (m *MockConnection) Disconnect() { + m.mu.Lock() + defer m.mu.Unlock() + m.connected = false +} + // IsWindows returns true if the runner's connection is set to be a Windows client. func (m *MockRunner) IsWindows() bool { return m.Windows @@ -141,8 +157,20 @@ func (m *MockConnection) IsWindows() bool { return m.Windows } // String returns the string representation of the client. func (m *MockConnection) String() string { return "mockclient" } +const mockProtocol = "mock" + // Protocol returns the protocol of the client. -func (m *MockConnection) Protocol() string { return "mock" } +func (m *MockConnection) Protocol() string { return mockProtocol } + +// ProtocolName returns the protocol name of the client. +func (m *MockConnection) ProtocolName() string { return mockProtocol } + +// IsConnected returns true if the client is connected. +func (m *MockConnection) IsConnected() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.connected +} // IPAddress returns the IP address of the client. func (m *MockConnection) IPAddress() string { return "mock" }