Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions internal/app/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,9 @@ func ResolveConnection(ctx context.Context, opts ConnectionOptions) (*Connection

var sshClient *ssh.Client
if opts.Profile.SSHConfig != nil {
passphrase := opts.Profile.SSHConfig.Passphrase
if passphrase != "" {
pp, xe := secret.Resolve(passphrase, secret.Options{AllowPlaintext: allowPlaintext})
if xe != nil {
return nil, xe
}
passphrase = pp
}

sshOpts := ssh.Options{
Host: opts.Profile.SSHConfig.Host,
Port: opts.Profile.SSHConfig.Port,
User: opts.Profile.SSHConfig.User,
IdentityFile: opts.Profile.SSHConfig.IdentityFile,
Passphrase: passphrase,
KnownHostsFile: opts.Profile.SSHConfig.KnownHostsFile,
SkipKnownHostsCheck: opts.SkipHostKeyCheck || opts.Profile.SSHConfig.SkipHostKey,
sshOpts, xe := resolveSSHOptions(opts.Profile, allowPlaintext, opts.SkipHostKeyCheck)
if xe != nil {
return nil, xe
}
sc, xe := ssh.Connect(ctx, sshOpts)
if xe != nil {
Expand Down
22 changes: 21 additions & 1 deletion internal/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"path/filepath"
"strings"
"sync/atomic"
"time"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
Expand Down Expand Up @@ -54,13 +55,32 @@ func Connect(ctx context.Context, opts Options) (*Client, *errors.XError) {
}

addr := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
client, err := ssh.Dial("tcp", addr, config)

// Use net.Dialer with context so that context cancellation/timeout
// can interrupt the TCP connection phase (ssh.Dial does not accept context).
d := net.Dialer{}
netConn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, errors.Wrap(errors.CodeSSHDialFailed, "failed to connect to ssh server", map[string]any{"host": opts.Host}, err)
}

// Perform SSH handshake over the established TCP connection.
// Set a deadline derived from context to prevent hanging during handshake.
if deadline, ok := ctx.Deadline(); ok {
_ = netConn.SetDeadline(deadline)
}
sshConn, chans, reqs, err := ssh.NewClientConn(netConn, addr, config)
if err != nil {
_ = netConn.Close()
if strings.Contains(err.Error(), "unable to authenticate") {
return nil, errors.Wrap(errors.CodeSSHAuthFailed, "ssh authentication failed", map[string]any{"host": opts.Host}, err)
}
return nil, errors.Wrap(errors.CodeSSHDialFailed, "failed to connect to ssh server", map[string]any{"host": opts.Host}, err)
}
// Clear the deadline after successful handshake so it doesn't affect later I/O.
_ = netConn.SetDeadline(time.Time{})

client := ssh.NewClient(sshConn, chans, reqs)
c := &Client{client: client}
c.alive.Store(true)
return c, nil
Expand Down
184 changes: 184 additions & 0 deletions internal/ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"net"
"os"
"path/filepath"
Expand Down Expand Up @@ -298,7 +300,7 @@
func TestConnect_DialFailureReturnsCode(t *testing.T) {
keyPath := writeTestKey(t, t.TempDir(), "id_rsa")

ln, err := net.Listen("tcp", "127.0.0.1:0")

Check failure on line 303 in internal/ssh/client_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "127.0.0.1:0" 3 times.

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0jncaRZXezoAn1aUfq&open=AZ0jncaRZXezoAn1aUfq&pullRequest=38
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
Expand Down Expand Up @@ -461,3 +463,185 @@

return false
}

// ============================================================================
// Tests using in-process SSH server
// ============================================================================

func TestConnect_RealSSHServer(t *testing.T) {

Check warning on line 471 in internal/ssh/client_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestConnect_RealSSHServer" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0jncaRZXezoAn1aUfr&open=AZ0jncaRZXezoAn1aUfr&pullRequest=38
srv := newTestSSHServer(t)
opts := connectToTestServer(srv)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

client, xe := Connect(ctx, opts)
if xe != nil {
t.Fatalf("connect to test SSH server failed: %v", xe)
}
defer client.Close()

if !client.Alive() {
t.Error("client should be alive after connect")
}
}

func TestConnect_RealSSHServer_Keepalive(t *testing.T) {

Check warning on line 489 in internal/ssh/client_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestConnect_RealSSHServer_Keepalive" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0jncaRZXezoAn1aUfs&open=AZ0jncaRZXezoAn1aUfs&pullRequest=38
srv := newTestSSHServer(t)
opts := connectToTestServer(srv)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

client, xe := Connect(ctx, opts)
if xe != nil {
t.Fatalf("connect failed: %v", xe)

Check failure on line 498 in internal/ssh/client_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "connect failed: %v" 3 times.

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0jncaRZXezoAn1aUfp&open=AZ0jncaRZXezoAn1aUfp&pullRequest=38
}
defer client.Close()

// SendKeepalive should succeed on a real server
if err := client.SendKeepalive(); err != nil {
t.Errorf("keepalive should succeed: %v", err)
}
}

func TestConnect_RealSSHServer_KeepaliveRejected(t *testing.T) {

Check warning on line 508 in internal/ssh/client_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestConnect_RealSSHServer_KeepaliveRejected" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0jncaRZXezoAn1aUft&open=AZ0jncaRZXezoAn1aUft&pullRequest=38
srv := newTestSSHServer(t)
srv.mu.Lock()
srv.onKeepalive = func() bool { return false }
srv.mu.Unlock()

opts := connectToTestServer(srv)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

client, xe := Connect(ctx, opts)
if xe != nil {
t.Fatalf("connect failed: %v", xe)
}
defer client.Close()

// SendKeepalive returns nil for the request itself (the server replied),
// but the reply payload indicates rejection. The current implementation
// only checks if the request call itself fails, not the reply value.
// This test verifies no panic occurs.
_ = client.SendKeepalive()
}

func TestConnect_ContextCancelled(t *testing.T) {

Check warning on line 532 in internal/ssh/client_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestConnect_ContextCancelled" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0jncaRZXezoAn1aUfu&open=AZ0jncaRZXezoAn1aUfu&pullRequest=38
// Start a TCP listener that never accepts SSH handshake
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()

host, port := parseHostPort(ln.Addr().String())

ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately

start := time.Now()
_, xe := Connect(ctx, Options{
Host: host,
Port: port,
SkipKnownHostsCheck: true,
})
elapsed := time.Since(start)

if xe == nil {
t.Fatal("expected error when context is cancelled")
}
if elapsed > 2*time.Second {
t.Errorf("expected fast return on cancelled context, took %v", elapsed)
}
}

func TestConnect_ContextTimeout_DuringHandshake(t *testing.T) {

Check warning on line 561 in internal/ssh/client_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestConnect_ContextTimeout_DuringHandshake" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0jncaRZXezoAn1aUfv&open=AZ0jncaRZXezoAn1aUfv&pullRequest=38
// TCP listener that accepts but doesn't do SSH handshake (black hole)
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()

go func() {
conn, err := ln.Accept()
if err != nil {
return
}
// Hold connection open without doing SSH handshake
defer conn.Close()
buf := make([]byte, 1024)
for {
if _, err := conn.Read(buf); err != nil {
return
}
}
}()

host, port := parseHostPort(ln.Addr().String())

ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()

start := time.Now()
_, xe := Connect(ctx, Options{
Host: host,
Port: port,
SkipKnownHostsCheck: true,
})
elapsed := time.Since(start)

if xe == nil {
t.Fatal("expected error on timeout")
}
if elapsed > 2*time.Second {
t.Errorf("expected timeout within ~200ms, took %v", elapsed)
}
}

func TestClient_DialContext_RealSSHTunnel(t *testing.T) {

Check warning on line 605 in internal/ssh/client_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestClient_DialContext_RealSSHTunnel" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0jncaRZXezoAn1aUfw&open=AZ0jncaRZXezoAn1aUfw&pullRequest=38
// Start an echo server
echoLn := startEchoServer(t)
echoHost, echoPort := parseHostPort(echoLn.Addr().String())

// Start SSH server that forwards direct-tcpip to echo server
srv := newTestSSHServer(t)
srv.mu.Lock()
srv.onDirectTCPIP = func(destHost string, destPort uint32) (net.Conn, error) {
return net.Dial("tcp", echoLn.Addr().String())
}
srv.mu.Unlock()

opts := connectToTestServer(srv)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

client, xe := Connect(ctx, opts)
if xe != nil {
t.Fatalf("connect failed: %v", xe)
}
defer client.Close()

// Dial through SSH tunnel to echo server
conn, err := client.DialContext(ctx, "tcp", net.JoinHostPort(echoHost, fmt.Sprintf("%d", echoPort)))
if err != nil {
t.Fatalf("dial through tunnel failed: %v", err)
}
defer conn.Close()

// Verify data roundtrip
msg := []byte("hello-ssh-tunnel")
if _, err := conn.Write(msg); err != nil {
t.Fatalf("write failed: %v", err)
}
buf := make([]byte, len(msg))
if _, err := io.ReadFull(conn, buf); err != nil {
t.Fatalf("read failed: %v", err)
}
if string(buf) != string(msg) {
t.Errorf("echo mismatch: got %q, want %q", buf, msg)
}
}
Loading
Loading