Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions cmd/xsql/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/zx06/xsql/internal/errors"
"github.com/zx06/xsql/internal/output"
"github.com/zx06/xsql/internal/proxy"
"github.com/zx06/xsql/internal/ssh"
)

// ProxyFlags holds the flags for the proxy command
Expand Down Expand Up @@ -131,24 +132,36 @@ func runProxy(cmd *cobra.Command, flags *ProxyFlags, w *output.Writer) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

sshClient, xe := app.ResolveSSH(ctx, p, allowPlaintext, flags.SSHSkipHostKey)
// Use ReconnectDialer for automatic SSH reconnection
onStatus := func(event ssh.StatusEvent) {
switch event.Type {
case ssh.StatusDisconnected:
log.Printf("[proxy] SSH connection lost: %v", event.Error)
case ssh.StatusReconnecting:
log.Printf("[proxy] reconnecting to SSH server...")
case ssh.StatusReconnected:
log.Printf("[proxy] SSH reconnected successfully")
case ssh.StatusReconnectFailed:
log.Printf("[proxy] SSH reconnection failed: %v", event.Error)
}
}

reconnDialer, xe := app.ResolveReconnectableSSH(ctx, p, allowPlaintext, flags.SSHSkipHostKey, onStatus)
if xe != nil {
return xe
}
if sshClient != nil {
defer func() {
if err := sshClient.Close(); err != nil {
log.Printf("[proxy] failed to close ssh client: %v", err)
}
}()
}
defer func() {
if err := reconnDialer.Close(); err != nil {
log.Printf("[proxy] failed to close ssh dialer: %v", err)
}
}()

proxyOpts := proxy.Options{
LocalHost: flags.LocalHost,
LocalPort: localPort,
RemoteHost: p.Host,
RemotePort: p.Port,
Dialer: sshClient,
Dialer: reconnDialer,
}

px, result, xe := proxy.Start(ctx, proxyOpts)
Expand Down
28 changes: 28 additions & 0 deletions docs/ssh-proxy.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,33 @@
- 由 driver 回调触发时,通过 `sshClient.Dial("tcp", target)` 建立到 DB 的连接。
- 支持连接池:dial 可并发、安全复用 SSH client。

## SSH Keepalive 与自动重连

### 问题
长生命周期的 SSH 连接(如 `xsql proxy`)可能因网络中断而变为"死连接"。此时通过 proxy 的新连接请求会超时失败,且用户无法感知 proxy 已不可用。

### 解决方案
xsql 内置了 SSH 连接健康检测和自动重连机制:

1. **SSH Keepalive**:周期性发送 `keepalive@openssh.com` 请求探测连接存活状态
- 默认间隔:30 秒
- 默认最大失败次数:3(连续 3 次失败判定为死连接)

2. **自动重连(ReconnectDialer)**:
- 当 keepalive 检测到连接死亡时,自动触发重连
- 当 dial 操作失败时,自动尝试重连并重试
- 带指数退避的重试策略
- 重连过程中输出状态日志:
```
[proxy] SSH connection lost: <error>
[proxy] reconnecting to SSH server...
[proxy] SSH reconnected successfully
```

### 适用范围
- **`xsql proxy`**:使用 ReconnectDialer,支持自动重连(长生命周期连接)
- **`xsql query` / `xsql schema dump`**:每次命令创建新连接,不需要重连

## 配置方式

### SSH Proxy 复用(推荐)
Expand Down Expand Up @@ -68,3 +95,4 @@ profiles:
- 监听 `127.0.0.1:0`(或指定端口)分配端口
- 将 DB 连接指向本地端口
- 输出支持 JSON/YAML 或终端表格(详见 `docs/cli-spec.md`)
- 内置 SSH 自动重连:网络中断后自动恢复代理服务
57 changes: 47 additions & 10 deletions internal/app/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,29 +138,66 @@ func ResolveSSH(ctx context.Context, profile config.Profile, allowPlaintext, ski
return nil, nil
}

sshOpts, xe := resolveSSHOptions(profile, allowPlaintext, skipHostKeyCheck)
if xe != nil {
return nil, xe
}

sc, xe := ssh.Connect(ctx, sshOpts)
if xe != nil {
return nil, xe
}

return sc, nil
}

// ResolveReconnectableSSH creates an SSH ReconnectDialer for long-lived connections
// (e.g. the proxy command). It wraps the SSH connection with automatic keepalive
// monitoring and reconnection on failure.
func ResolveReconnectableSSH(ctx context.Context, profile config.Profile, allowPlaintext, skipHostKeyCheck bool, onStatus func(ssh.StatusEvent)) (*ssh.ReconnectDialer, *errors.XError) {
if profile.SSHConfig == nil {
return nil, errors.New(errors.CodeCfgInvalid, "profile must have ssh_proxy configured", nil)
}

sshOpts, xe := resolveSSHOptions(profile, allowPlaintext, skipHostKeyCheck)
if xe != nil {
return nil, xe
}

var ropts []ssh.ReconnectOption
if onStatus != nil {
ropts = append(ropts, ssh.WithStatusCallback(onStatus))
}

rd, err := ssh.NewReconnectDialer(ctx, sshOpts, ropts...)
if err != nil {
if xe, ok := err.(*errors.XError); ok {
return nil, xe
}
return nil, errors.Wrap(errors.CodeSSHDialFailed, "failed to establish reconnectable ssh connection", map[string]any{"host": profile.SSHConfig.Host}, err)
}

return rd, nil
}

// resolveSSHOptions builds SSH options from a profile, resolving secrets.
func resolveSSHOptions(profile config.Profile, allowPlaintext, skipHostKeyCheck bool) (ssh.Options, *errors.XError) {
passphrase := profile.SSHConfig.Passphrase
if passphrase != "" {
pp, xe := secret.Resolve(passphrase, secret.Options{AllowPlaintext: allowPlaintext})
if xe != nil {
return nil, xe
return ssh.Options{}, xe
}
passphrase = pp
}

sshOpts := ssh.Options{
return ssh.Options{
Host: profile.SSHConfig.Host,
Port: profile.SSHConfig.Port,
User: profile.SSHConfig.User,
IdentityFile: profile.SSHConfig.IdentityFile,
Passphrase: passphrase,
KnownHostsFile: profile.SSHConfig.KnownHostsFile,
SkipKnownHostsCheck: skipHostKeyCheck || profile.SSHConfig.SkipHostKey,
}

sc, xe := ssh.Connect(ctx, sshOpts)
if xe != nil {
return nil, xe
}

return sc, nil
}, nil
}
120 changes: 118 additions & 2 deletions internal/app/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/zx06/xsql/internal/config"
"github.com/zx06/xsql/internal/db"
"github.com/zx06/xsql/internal/errors"
"github.com/zx06/xsql/internal/ssh"
)

func TestResolveConnection_UnsupportedDriver(t *testing.T) {
Expand Down Expand Up @@ -118,20 +120,24 @@
func TestResolveSSH_PassphraseAllowed(t *testing.T) {
profile := config.Profile{
SSHConfig: &config.SSHProxy{
Host: "example.com",
Host: "127.0.0.1",

Check failure on line 123 in internal/app/conn_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "127.0.0.1" 6 times.

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0i-3mmGkojc2DtqzEV&open=AZ0i-3mmGkojc2DtqzEV&pullRequest=37
Port: 22,
User: "user",
Passphrase: "phrase-value",
},
}

client, err := ResolveSSH(nil, profile, true, false)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

client, err := ResolveSSH(ctx, profile, true, false)

if err == nil {
if client != nil {
client.Close()
}
}
// Error is acceptable (no SSH server running), just verifying passphrase resolves
}

func TestConnectionOptions_Fields(t *testing.T) {
Expand Down Expand Up @@ -320,3 +326,113 @@
t.Fatalf("expected ssh auth/dial failure, got %s", xe.Code)
}
}

func TestResolveReconnectableSSH_NoSSHConfig(t *testing.T) {

Check warning on line 330 in internal/app/conn_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestResolveReconnectableSSH_NoSSHConfig" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0i-3mmGkojc2DtqzEW&open=AZ0i-3mmGkojc2DtqzEW&pullRequest=37
profile := config.Profile{}

rd, xe := ResolveReconnectableSSH(context.Background(), profile, false, false, nil)
if rd != nil {
t.Fatal("expected nil dialer")

Check failure on line 335 in internal/app/conn_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Define a constant instead of duplicating this literal "expected nil dialer" 3 times.

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0i-3mmGkojc2DtqzEU&open=AZ0i-3mmGkojc2DtqzEU&pullRequest=37
}
if xe == nil {
t.Fatal("expected error when SSHConfig is nil")
}
if xe.Code != errors.CodeCfgInvalid {
t.Errorf("expected CodeCfgInvalid, got %s", xe.Code)
}
}

func TestResolveReconnectableSSH_PassphraseNotAllowed(t *testing.T) {

Check warning on line 345 in internal/app/conn_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestResolveReconnectableSSH_PassphraseNotAllowed" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0i-3mmGkojc2DtqzEX&open=AZ0i-3mmGkojc2DtqzEX&pullRequest=37
profile := config.Profile{
SSHConfig: &config.SSHProxy{
Host: "127.0.0.1",
Port: 22,
User: "user",
Passphrase: "plain-phrase",
},
}

rd, xe := ResolveReconnectableSSH(context.Background(), profile, false, false, nil)
if rd != nil {
t.Fatal("expected nil dialer")
}
if xe == nil {
t.Fatal("expected error for plaintext passphrase")
}
if xe.Code != errors.CodeCfgInvalid {
t.Errorf("expected CodeCfgInvalid, got %s", xe.Code)
}
}

func TestResolveReconnectableSSH_ConnectFails(t *testing.T) {

Check warning on line 367 in internal/app/conn_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestResolveReconnectableSSH_ConnectFails" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0i-3mmGkojc2DtqzEY&open=AZ0i-3mmGkojc2DtqzEY&pullRequest=37
profile := config.Profile{
SSHConfig: &config.SSHProxy{
Host: "127.0.0.1",
Port: 1, // unlikely to have SSH on port 1
User: "user",
},
}

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

var statusCalled bool
rd, xe := ResolveReconnectableSSH(ctx, profile, true, true, func(e ssh.StatusEvent) {
statusCalled = true
})
if rd != nil {
rd.Close()
t.Fatal("expected nil dialer")
}
if xe == nil {
t.Fatal("expected connection error")
}
// The error may be wrapped as XError or as a generic error
_ = statusCalled
}

func TestResolveReconnectableSSH_NilCallback(t *testing.T) {

Check warning on line 394 in internal/app/conn_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestResolveReconnectableSSH_NilCallback" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0i-3mmGkojc2DtqzEZ&open=AZ0i-3mmGkojc2DtqzEZ&pullRequest=37
profile := config.Profile{
SSHConfig: &config.SSHProxy{
Host: "127.0.0.1",
Port: 1,
User: "user",
},
}

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

// nil callback should not cause panic
rd, xe := ResolveReconnectableSSH(ctx, profile, true, true, nil)
if rd != nil {
rd.Close()
}
// Just verify no panic
_ = xe
}

func TestResolveSSHOptions_Success(t *testing.T) {

Check warning on line 415 in internal/app/conn_test.go

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Rename function "TestResolveSSHOptions_Success" to match the regular expression ^(_|[a-zA-Z0-9]+)$

See more on https://sonarcloud.io/project/issues?id=zx06_xsql&issues=AZ0i-3mmGkojc2DtqzEa&open=AZ0i-3mmGkojc2DtqzEa&pullRequest=37
profile := config.Profile{
SSHConfig: &config.SSHProxy{
Host: "bastion.example.com",
Port: 2222,
User: "admin",
IdentityFile: "~/.ssh/id_ed25519",
},
}

opts, xe := resolveSSHOptions(profile, true, true)
if xe != nil {
t.Fatalf("unexpected error: %v", xe)
}
if opts.Host != "bastion.example.com" {
t.Errorf("expected host bastion.example.com, got %s", opts.Host)
}
if opts.Port != 2222 {
t.Errorf("expected port 2222, got %d", opts.Port)
}
if !opts.SkipKnownHostsCheck {
t.Error("expected SkipKnownHostsCheck=true")
}
}
25 changes: 24 additions & 1 deletion internal/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"strings"
"sync/atomic"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
Expand All @@ -18,6 +19,7 @@ import (
// Client wraps ssh.Client and provides DialContext for use by database drivers.
type Client struct {
client *ssh.Client
alive atomic.Bool
}

// Connect establishes an SSH connection.
Expand Down Expand Up @@ -59,22 +61,43 @@ func Connect(ctx context.Context, opts Options) (*Client, *errors.XError) {
}
return nil, errors.Wrap(errors.CodeSSHDialFailed, "failed to connect to ssh server", map[string]any{"host": opts.Host}, err)
}
return &Client{client: client}, nil
c := &Client{client: client}
c.alive.Store(true)
return c, nil
}

// DialContext establishes a connection to the target through the SSH tunnel.
func (c *Client) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if c.client == nil {
return nil, fmt.Errorf("ssh client is not connected")
}
return c.client.Dial(network, addr)
}

// Close closes the SSH connection.
func (c *Client) Close() error {
c.alive.Store(false)
if c.client != nil {
return c.client.Close()
}
return nil
}

// SendKeepalive sends a single SSH keepalive request and returns any error.
// A nil error means the connection is alive.
func (c *Client) SendKeepalive() error {
if c.client == nil {
return fmt.Errorf("ssh client is nil")
}
_, _, err := c.client.SendRequest("keepalive@openssh.com", true, nil)
return err
}

// Alive reports whether the SSH connection is believed to be alive.
func (c *Client) Alive() bool {
return c.alive.Load()
}

func buildAuthMethods(opts Options) ([]ssh.AuthMethod, *errors.XError) {
var methods []ssh.AuthMethod

Expand Down
Loading
Loading