Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,39 @@ func (c *Client) OS() (*os.Release, error) {
return os, nil
}

// Protocol returns the protocol used to connect to the host.
// IsConnected returns true if the underlying connection is currently active.
// This delegates to the protocol connection's IsConnected, which may perform
// an active liveness probe (e.g. ssh -O check for OpenSSH multiplexed sessions,
// or a no-op command for WinRM/non-multiplexed SSH) and may block up to a timeout.
func (c *Client) IsConnected() bool {
if c.connection == nil {
return false
}
return c.connection.IsConnected()
}

// Protocol returns the protocol family used to connect to the host,
// such as "SSH", "WinRM", or "Local". Both the native SSH and OpenSSH
// implementations return "SSH". Custom or test implementations may
// return other values.
func (c *Client) Protocol() string {
if c.connection == nil {
return "uninitialized"
}
return c.connection.Protocol()
}

// ProtocolName returns the specific protocol implementation name, such as
// "SSH", "OpenSSH", "WinRM", or "Local". Use this for logging or diagnostics
// where the distinction between native SSH and OpenSSH matters. Custom or
// test implementations may return other values.
func (c *Client) ProtocolName() string {
if c.connection == nil {
return "uninitialized"
}
return c.connection.ProtocolName()
}

// Address returns the address of the host.
func (c *Client) Address() string {
if c.connection != nil {
Expand Down
34 changes: 34 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,40 @@ func TestClientPackageManagerErrorFallback(t *testing.T) {
require.ErrorIs(t, err, mockErr)
}

func TestClientReconnect(t *testing.T) {
conn := rigtest.NewMockConnection()
conn.AddCommandOutput(rigtest.Match("echo hello"), "hello")

client, err := rig.NewClient(rig.WithConnection(conn))
require.NoError(t, err)

require.NoError(t, client.Connect(context.Background()))
require.True(t, client.IsConnected())

out, err := client.ExecOutput("echo hello")
require.NoError(t, err)
require.Equal(t, "hello", out)

client.Disconnect()
require.False(t, client.IsConnected())

require.NoError(t, client.Connect(context.Background()))
require.True(t, client.IsConnected())

out, err = client.ExecOutput("echo hello")
require.NoError(t, err)
require.Equal(t, "hello", out)
}

func TestClientProtocolName(t *testing.T) {
conn := rigtest.NewMockConnection()
client, err := rig.NewClient(rig.WithConnection(conn))
require.NoError(t, err)

require.Equal(t, "mock", client.Protocol())
require.Equal(t, "mock", client.ProtocolName())
}

type testConfig struct {
Hosts []*testHost `yaml:"hosts"`
}
Expand Down
16 changes: 16 additions & 0 deletions protocol/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,23 @@ type InteractiveExecer interface {
// Connection is the minimum interface for protocol implementations.
type Connection interface {
fmt.Stringer
// Protocol returns the protocol family, such as "SSH", "WinRM", or "Local".
// Both the native SSH and OpenSSH implementations return "SSH".
// Custom or test implementations may return other values.
Protocol() string
// ProtocolName returns the specific implementation name, such as "SSH",
// "OpenSSH", "WinRM", or "Local". Use this for logging or diagnostics
// where the distinction between native SSH and OpenSSH matters. Custom
// or test implementations may return other values.
ProtocolName() string
// IsConnected returns true if the connection is currently active.
// Built-in implementations attempt an active liveness probe where
// possible (e.g. SSH keepalive, ssh -O check for OpenSSH multiplexing,
// or a no-op command for WinRM), but some implementations may skip the
// probe or always return true (e.g. Localhost). Callers should be aware
// this may block up to a timeout (typically 10s) and may cause
// side-effects on the remote.
IsConnected() bool
IPAddress() string
ProcessStarter
WindowsChecker
Expand Down
12 changes: 11 additions & 1 deletion protocol/localhost/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ func NewConnection() (*Connection, error) {
return &Connection{}, nil
}

// Protocol returns the protocol name, "Local".
// Protocol returns the protocol family, "Local".
func (c *Connection) Protocol() string {
return "Local"
}

// ProtocolName returns the implementation name, "Local".
func (c *Connection) ProtocolName() string {
return "Local"
}

// IPAddress returns the connection address.
func (c *Connection) IPAddress() string {
return "127.0.0.1"
Expand All @@ -46,6 +51,11 @@ func (c *Connection) IsWindows() bool {
return runtime.GOOS == "windows"
}

// IsConnected always returns true for localhost — there is no connection to lose.
func (c *Connection) IsConnected() bool {
return true
}

// StartProcess executes a command on the remote host and uses the passed in streams for stdin, stdout and stderr. It returns a Waiter with a .Wait() function that
// blocks until the command finishes and returns an error if the exit code is not zero.
func (c *Connection) StartProcess(ctx context.Context, cmd string, stdin io.Reader, stdout, stderr io.Writer) (protocol.Waiter, error) {
Expand Down
116 changes: 92 additions & 24 deletions protocol/openssh/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,21 @@ var (
errNotConnected = errors.New("not connected")
)

// isHostKeyError reports whether ssh stderr output indicates a host key
// verification failure. These are fatal and should not be retried.
func isHostKeyError(stderr string) bool {
return strings.Contains(stderr, "Host key verification failed") ||
strings.Contains(stderr, "REMOTE HOST IDENTIFICATION HAS CHANGED")
}

// Connection is a rig.Connection implementation that uses the system openssh client "ssh" to connect to remote hosts.
// The connection is by default multiplexec over a control master, so that subsequent connections don't need to re-authenticate.
type Connection struct {
log.LoggerInjectable `yaml:"-"`
Config `yaml:",inline"`

isConnected bool
controlPath string
controlMutex sync.Mutex

isWindows *bool
Expand All @@ -44,8 +52,13 @@ func NewConnection(cfg Config) (*Connection, error) {
return &Connection{Config: cfg}, nil
}

// Protocol returns the protocol name.
// Protocol returns the protocol family, "SSH".
func (c *Connection) Protocol() string {
return "SSH"
}

// ProtocolName returns the implementation name, "OpenSSH".
func (c *Connection) ProtocolName() string {
return "OpenSSH"
}

Expand Down Expand Up @@ -135,20 +148,33 @@ func (c *Connection) args() []string {

// Connect connects to the remote host. If multiplexing is enabled, this will start a control master. If multiplexing is disabled, this will just run a noop command to check connectivity.
func (c *Connection) Connect(ctx context.Context) error {
c.controlMutex.Lock()
if c.isConnected {
c.controlMutex.Unlock()
return nil
}

if c.DisableMultiplexing {
// just run a noop command to check connectivity
if _, err := c.StartProcess(ctx, "exit 0", nil, nil, nil); err != nil {
c.controlMutex.Unlock()
// Run a noop to check connectivity. Capture stderr to detect host key failures.
errBuf := bytes.NewBuffer(nil)
proc, err := c.StartProcess(ctx, "exit 0", nil, nil, errBuf)
if err == nil {
err = proc.Wait()
}
if err != nil {
errOut := errBuf.String()
if isHostKeyError(errOut) {
return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut)
}
return fmt.Errorf("failed to connect: %w", err)
}
c.controlMutex.Lock()
c.isConnected = true
c.controlMutex.Unlock()
return nil
}

c.controlMutex.Lock()
defer c.controlMutex.Unlock()

opts := c.Options.Copy()
Expand All @@ -162,24 +188,24 @@ func (c *Connection) Connect(ctx context.Context) error {
args = append(args, c.args()...)

cmd := exec.CommandContext(ctx, "ssh", args...)
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("create stderr pipe: %w", err)
}
defer stderr.Close()
errBuf := bytes.NewBuffer(nil)
go func() {
_, _ = io.Copy(errBuf, stderr)
}()
cmd.Stderr = errBuf

log.Trace(ctx, "starting ssh control master", log.KeyHost, c, log.KeyCommand, strings.Join(args, " "))
if err := cmd.Run(); err != nil {
c.isConnected = false
return fmt.Errorf("failed to start ssh multiplexing control master: %w (%s)", err, errBuf.String())
errOut := errBuf.String()
if isHostKeyError(errOut) {
return fmt.Errorf("%w: host key verification failed: %w (%s)", protocol.ErrNonRetryable, err, errOut)
}
return fmt.Errorf("failed to start ssh multiplexing control master: %w (%s)", err, errOut)
}

c.isConnected = true
log.Trace(ctx, "started ssh multipliexing control master", log.KeyHost, c)
if cp, ok := c.Options["ControlPath"].(string); ok {
c.controlPath = cp
}
log.Trace(ctx, "started ssh multiplexing control master", log.KeyHost, c)

return nil
}
Expand All @@ -192,15 +218,13 @@ func (c *Connection) closeControl() error {
return nil
}

controlPath, ok := c.Options["ControlPath"].(string)
if !ok {
if c.controlPath == "" {
return ErrControlPathNotSet
}

args := make([]string, 0, 4+len(c.args())+1)
args = append(args, "-O", "exit", "-S", controlPath)
args := make([]string, 0, 4+len(c.args()))
args = append(args, "-O", "exit", "-S", c.controlPath)
args = append(args, c.args()...)
args = append(args, c.userhost())

log.Trace(context.Background(), "closing ssh multiplexing control master", log.KeyHost, c)
cmd := exec.Command("ssh", args...) //nolint:noctx // cleanup code path, no context available
Expand All @@ -214,7 +238,10 @@ func (c *Connection) closeControl() error {

// StartProcess executes a command on the remote host, streaming stdin, stdout and stderr.
func (c *Connection) StartProcess(ctx context.Context, cmdStr string, stdin io.Reader, stdout, stderr io.Writer) (protocol.Waiter, error) {
if !c.DisableMultiplexing && !c.isConnected {
c.controlMutex.Lock()
connected := c.isConnected
c.controlMutex.Unlock()
if !c.DisableMultiplexing && !connected {
return nil, errNotConnected
}

Expand Down Expand Up @@ -261,16 +288,57 @@ func (c *Connection) String() string {
return c.name
}

// IsConnected returns true if the connection is connected.
// IsConnected returns true if the connection is alive. For multiplexed
// connections this probes the control master via ssh -O check. For
// non-multiplexed connections it runs a no-op command over a fresh session.
func (c *Connection) IsConnected() bool {
return c.isConnected
c.controlMutex.Lock()
connected := c.isConnected
controlPath := c.controlPath
c.controlMutex.Unlock()

if !connected {
return false
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if !c.DisableMultiplexing {
var args []string
if controlPath != "" {
args = make([]string, 0, 4+len(c.args()))
args = append(args, "-O", "check", "-S", controlPath)
} else {
// ControlPath comes from ssh_config (-F); let ssh resolve it from options.
args = make([]string, 0, 2+len(c.Options.ToArgs())+len(c.args()))
args = append(args, c.Options.ToArgs()...)
args = append(args, "-O", "check")
}
args = append(args, c.args()...)
if exec.CommandContext(ctx, "ssh", args...).Run() != nil {
c.controlMutex.Lock()
c.isConnected = false
c.controlMutex.Unlock()
return false
}
return true
}
proc, err := c.StartProcess(ctx, "exit 0", nil, nil, nil)
if err != nil || proc.Wait() != nil {
c.controlMutex.Lock()
c.isConnected = false
c.controlMutex.Unlock()
return false
}
return true
}

// Disconnect disconnects from the remote host. If multiplexing is enabled, this will close the control master.
// If multiplexing is disabled, this will do nothing.
// If multiplexing is disabled, this marks the connection as disconnected.
func (c *Connection) Disconnect() {
if c.DisableMultiplexing {
// nothing to do
c.controlMutex.Lock()
c.isConnected = false
c.controlMutex.Unlock()
return
}

Expand Down
7 changes: 6 additions & 1 deletion protocol/ssh/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,16 @@ func (c *Connection) SetDefaults(ctx context.Context) {
})
}

// Protocol returns the protocol name, "SSH".
// Protocol returns the protocol family, "SSH".
func (c *Connection) Protocol() string {
return "SSH"
}

// ProtocolName returns the implementation name, "SSH".
func (c *Connection) ProtocolName() string {
return "SSH"
}

// IPAddress returns the connection address.
func (c *Connection) IPAddress() string {
return c.Address
Expand Down
Loading
Loading