diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ac04e73..b890771 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,6 +5,7 @@ on: branches: [main] pull_request: branches: [main] + workflow_dispatch: permissions: contents: read diff --git a/cmd/xsql/mcp.go b/cmd/xsql/mcp.go index 3cfd32e..d47bad1 100644 --- a/cmd/xsql/mcp.go +++ b/cmd/xsql/mcp.go @@ -4,6 +4,9 @@ import ( "context" "net/http" "os" + "os/signal" + "syscall" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/spf13/cobra" @@ -86,6 +89,17 @@ func runMCPServer(opts *mcpServerOptions) error { Addr: resolved.httpAddr, Handler: handler, } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-sigChan + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = httpServer.Shutdown(ctx) + }() + return httpServer.ListenAndServe() default: return errors.New(errors.CodeCfgInvalid, "unsupported mcp transport", map[string]any{"transport": resolved.transport}) diff --git a/cmd/xsql/profile.go b/cmd/xsql/profile.go index b5d9879..6a8035a 100644 --- a/cmd/xsql/profile.go +++ b/cmd/xsql/profile.go @@ -39,25 +39,9 @@ func newProfileListCommand(w *output.Writer) *cobra.Command { return xe } - type profileInfo struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - DB string `json:"db"` - Mode string `json:"mode"` // "read-only" or "read-write" - } - - profiles := make([]profileInfo, 0, len(cfg.Profiles)) + profiles := make([]config.ProfileInfo, 0, len(cfg.Profiles)) for name, p := range cfg.Profiles { - mode := "read-only" - if p.UnsafeAllowWrite { - mode = "read-write" - } - profiles = append(profiles, profileInfo{ - Name: name, - Description: p.Description, - DB: p.DB, - Mode: mode, - }) + profiles = append(profiles, config.ProfileToInfo(name, p)) } result := map[string]any{ diff --git a/cmd/xsql/query.go b/cmd/xsql/query.go index d429d0f..61d54f7 100644 --- a/cmd/xsql/query.go +++ b/cmd/xsql/query.go @@ -12,11 +12,15 @@ import ( "github.com/zx06/xsql/internal/output" ) +const DefaultQueryTimeout = 30 * time.Second + // QueryFlags holds the flags for the query command type QueryFlags struct { UnsafeAllowWrite bool AllowPlaintext bool SSHSkipHostKey bool + QueryTimeout int + QueryTimeoutSet bool } // NewQueryCommand creates the query command @@ -28,6 +32,7 @@ func NewQueryCommand(w *output.Writer) *cobra.Command { Short: "Execute a SQL query (read-only by default)", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + flags.QueryTimeoutSet = cmd.Flags().Changed("query-timeout") return runQuery(cmd, args, flags, w) }, } @@ -35,6 +40,7 @@ func NewQueryCommand(w *output.Writer) *cobra.Command { cmd.Flags().BoolVar(&flags.UnsafeAllowWrite, "unsafe-allow-write", false, "Allow write operations (bypasses read-only protection)") cmd.Flags().BoolVar(&flags.AllowPlaintext, "allow-plaintext", false, "Allow plaintext secrets in config") cmd.Flags().BoolVar(&flags.SSHSkipHostKey, "ssh-skip-known-hosts-check", false, "Skip SSH known_hosts check (dangerous)") + cmd.Flags().IntVar(&flags.QueryTimeout, "query-timeout", 0, "Query timeout in seconds (default: 30)") return cmd } @@ -52,7 +58,13 @@ func runQuery(cmd *cobra.Command, args []string, flags *QueryFlags, w *output.Wr return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil) } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + timeout := DefaultQueryTimeout + if flags.QueryTimeoutSet && flags.QueryTimeout > 0 { + timeout = time.Duration(flags.QueryTimeout) * time.Second + } else if p.QueryTimeout > 0 { + timeout = time.Duration(p.QueryTimeout) * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() conn, xe := app.ResolveConnection(ctx, app.ConnectionOptions{ diff --git a/cmd/xsql/schema.go b/cmd/xsql/schema.go index 4109725..ed4e974 100644 --- a/cmd/xsql/schema.go +++ b/cmd/xsql/schema.go @@ -12,12 +12,16 @@ import ( "github.com/zx06/xsql/internal/output" ) +const DefaultSchemaTimeout = 60 * time.Second + // SchemaFlags holds the flags for the schema command type SchemaFlags struct { - TablePattern string - IncludeSystem bool - AllowPlaintext bool - SSHSkipHostKey bool + TablePattern string + IncludeSystem bool + AllowPlaintext bool + SSHSkipHostKey bool + SchemaTimeout int + SchemaTimeoutSet bool } // NewSchemaCommand creates the schema command @@ -41,6 +45,7 @@ func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command { Use: "dump", Short: "Dump database schema (tables, columns, indexes, foreign keys)", RunE: func(cmd *cobra.Command, args []string) error { + flags.SchemaTimeoutSet = cmd.Flags().Changed("schema-timeout") return runSchemaDump(cmd, args, flags, w) }, } @@ -49,6 +54,7 @@ func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command { cmd.Flags().BoolVar(&flags.IncludeSystem, "include-system", false, "Include system tables") cmd.Flags().BoolVar(&flags.AllowPlaintext, "allow-plaintext", false, "Allow plaintext secrets in config") cmd.Flags().BoolVar(&flags.SSHSkipHostKey, "ssh-skip-known-hosts-check", false, "Skip SSH known_hosts check (dangerous)") + cmd.Flags().IntVar(&flags.SchemaTimeout, "schema-timeout", 0, "Schema dump timeout in seconds (default: 60)") return cmd } @@ -65,7 +71,13 @@ func runSchemaDump(cmd *cobra.Command, args []string, flags *SchemaFlags, w *out return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil) } - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + timeout := DefaultSchemaTimeout + if flags.SchemaTimeoutSet && flags.SchemaTimeout > 0 { + timeout = time.Duration(flags.SchemaTimeout) * time.Second + } else if p.SchemaTimeout > 0 { + timeout = time.Duration(p.SchemaTimeout) * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() conn, xe := app.ResolveConnection(ctx, app.ConnectionOptions{ diff --git a/docs/config.md b/docs/config.md index f2e0760..b1101a0 100644 --- a/docs/config.md +++ b/docs/config.md @@ -131,6 +131,15 @@ profiles: | `allow_plaintext` | bool | 允许明文密码(默认 false) | | `format` | string | 输出格式:json/yaml/table/csv/auto | | `ssh_proxy` | string | SSH 代理名称(引用 `ssh_proxies` 中定义的名称) | +| `query_timeout` | int | 查询超时秒数(默认 30 秒) | +| `schema_timeout` | int | Schema 导出超时秒数(默认 60 秒) | + +## CLI Timeout Flags + +| Flag | 说明 | +|------|------| +| `--query-timeout ` | 查询超时(覆盖 profile 配置) | +| `--schema-timeout ` | Schema 导出超时(覆盖 profile 配置) | ## Secrets diff --git a/docs/testing.md b/docs/testing.md index 1d3bc45..5dac9b9 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -197,6 +197,59 @@ func TestMain(m *testing.M) { --- +## SSH 测试 + +### 单元测试 +SSH 客户端单元测试位于 `internal/ssh/client_test.go`,测试: +- 路径扩展(tilde expansion) +- 认证方法构建(私钥、默认密钥查找) +- known_hosts 校验 +- SSH dial 失败错误码返回 +- Passphrase 保护的密钥(正确/错误 passphrase) + +```bash +# 运行 SSH 单元测试 +go test -v ./internal/ssh/... +``` + +### E2E SSH 测试 + +#### CLI Flags 测试 +测试 SSH CLI flags 与配置文件的合并行为: +- `--ssh-skip-known-hosts-check` +- `--ssh-identity-file` +- `--ssh-user` +- `--ssh-host` + +```bash +# 运行 SSH CLI flag 测试 +go test -tags=e2e -v -run "SSH" ./tests/e2e/... +``` + +#### 真实 SSH 测试 +需要真实 SSH 服务器的测试(跳过如果环境未配置): +- `ssh_proxy_success_test.go` + +设置环境变量: +```bash +export SSH_TEST_HOST=your-ssh-server +export SSH_TEST_PORT=22 +export SSH_TEST_USER=your-user +export SSH_TEST_KEY_PATH=/path/to/private/key +export SSH_KNOWN_HOSTS_FILE=/path/to/known_hosts # 可选 + +# 可选:MySQL/PG over SSH +export XSQL_TEST_MYSQL_DSN="..." +export XSQL_TEST_PG_DSN="..." +``` + +### 注意事项 +- SSH 代理测试需要可访问的 SSH 服务器 +- 使用 `SkipKnownHostsCheck: true` 进行测试,避免 known_hosts 问题 +- SSH 密钥应有适当的权限(600 或 400) + +--- + ## E2E 测试 ### 定位 @@ -213,7 +266,8 @@ tests/e2e/ ├── profile_test.go # profile 命令测试 ├── proxy_test.go # proxy 命令测试 ├── readonly_test.go # 只读策略测试 - └── ssh_proxy_success_test.go # SSH 代理成功测试 + ├── ssh_cli_flags_test.go # SSH CLI flags 测试 + └── ssh_proxy_success_test.go # SSH 代理成功测试(需要真实 SSH) ``` ### 运行测试 diff --git a/internal/config/types.go b/internal/config/types.go index e77896d..9aa09ea 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -36,6 +36,10 @@ type Profile struct { AllowPlaintext bool `yaml:"allow_plaintext"` // 允许明文密码 UnsafeAllowWrite bool `yaml:"unsafe_allow_write"` // 允许写操作(绕过只读保护) + // 超时配置(秒) + QueryTimeout int `yaml:"query_timeout"` // 查询超时,默认 30 秒 + SchemaTimeout int `yaml:"schema_timeout"` // Schema 导出超时,默认 60 秒 + // SSH proxy 引用(引用 ssh_proxies 中定义的名称) SSHProxy string `yaml:"ssh_proxy"` @@ -83,3 +87,23 @@ type Options struct { // WorkDir 用于默认路径(为空则使用进程当前工作目录)。 WorkDir string } + +type ProfileInfo struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + DB string `json:"db"` + Mode string `json:"mode"` // "read-only" or "read-write" +} + +func ProfileToInfo(name string, p Profile) ProfileInfo { + mode := "read-only" + if p.UnsafeAllowWrite { + mode = "read-write" + } + return ProfileInfo{ + Name: name, + Description: p.Description, + DB: p.DB, + Mode: mode, + } +} diff --git a/internal/mcp/streamable_http.go b/internal/mcp/streamable_http.go index 171e001..864ecc3 100644 --- a/internal/mcp/streamable_http.go +++ b/internal/mcp/streamable_http.go @@ -1,9 +1,11 @@ package mcp import ( + "context" "crypto/subtle" "net/http" "strings" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -20,10 +22,15 @@ const ( bearerPrefix = "Bearer " unauthorized = "unauthorized" headerMissing = "authorization header is required" + + DefaultHTTPTimeout = 60 * time.Second ) -// NewStreamableHTTPHandler creates a streamable HTTP handler with required auth. func NewStreamableHTTPHandler(server *mcp.Server, authToken string) (http.Handler, error) { + return NewStreamableHTTPHandlerWithTimeout(server, authToken, DefaultHTTPTimeout) +} + +func NewStreamableHTTPHandlerWithTimeout(server *mcp.Server, authToken string, timeout time.Duration) (http.Handler, error) { if server == nil { return nil, errors.New(errors.CodeInternal, "mcp server is nil", nil) } @@ -33,7 +40,16 @@ func NewStreamableHTTPHandler(server *mcp.Server, authToken string) (http.Handle handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { return server }, &mcp.StreamableHTTPOptions{JSONResponse: true}) - return requireAuth(handler, authToken), nil + authHandler := requireAuth(handler, authToken) + return withTimeout(authHandler, timeout), nil +} + +func withTimeout(next http.Handler, timeout time.Duration) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx, cancel := context.WithTimeout(req.Context(), timeout) + defer cancel() + next.ServeHTTP(w, req.WithContext(ctx)) + }) } func requireAuth(next http.Handler, token string) http.Handler { diff --git a/internal/mcp/tools.go b/internal/mcp/tools.go index 589d6a7..8c11c5c 100644 --- a/internal/mcp/tools.go +++ b/internal/mcp/tools.go @@ -309,25 +309,9 @@ func (h *ToolHandler) Query(ctx context.Context, req *mcp.CallToolRequest, input // ProfileList lists all profiles func (h *ToolHandler) ProfileList(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, any, error) { - type profileInfo struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - DB string `json:"db"` - Mode string `json:"mode"` - } - - profiles := make([]profileInfo, 0, len(h.config.Profiles)) + profiles := make([]config.ProfileInfo, 0, len(h.config.Profiles)) for name, p := range h.config.Profiles { - mode := "read-only" - if p.UnsafeAllowWrite { - mode = "read-write" - } - profiles = append(profiles, profileInfo{ - Name: name, - Description: p.Description, - DB: p.DB, - Mode: mode, - }) + profiles = append(profiles, config.ProfileToInfo(name, p)) } output := map[string]any{ @@ -347,7 +331,6 @@ func (h *ToolHandler) ProfileList(ctx context.Context, req *mcp.CallToolRequest, }, nil, nil } - // Return result directly in content per RFC return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: string(jsonData)}, @@ -463,7 +446,10 @@ func (h *ToolHandler) formatError(err error) string { "details": xe.Details, }, } - jsonData, _ := json.MarshalIndent(output, "", " ") + jsonData, jsonErr := json.MarshalIndent(output, "", " ") + if jsonErr != nil { + return `{"ok":false,"error":{"code":"XSQL_INTERNAL","message":"failed to format error"}}` + } return string(jsonData) } diff --git a/internal/output/writer.go b/internal/output/writer.go index fbbb07f..0093f4c 100644 --- a/internal/output/writer.go +++ b/internal/output/writer.go @@ -152,9 +152,12 @@ func writeTable(out io.Writer, env Envelope) error { _, _ = fmt.Fprintf(tw, "%s\t%v\n", k, m[k]) } } else { - // 只有这里不得已才使用 JSON 格式化 - b, _ := json.MarshalIndent(env.Data, "", " ") - _, _ = fmt.Fprintf(tw, "%s\n", b) + b, err := json.MarshalIndent(env.Data, "", " ") + if err == nil { + _, _ = fmt.Fprintf(tw, "%s\n", b) + } else { + _, _ = fmt.Fprintf(tw, "%v\n", env.Data) + } } } return tw.Flush() diff --git a/internal/ssh/client_test.go b/internal/ssh/client_test.go index 982528c..3e615f3 100644 --- a/internal/ssh/client_test.go +++ b/internal/ssh/client_test.go @@ -336,14 +336,25 @@ func TestClientClose_NoClient(t *testing.T) { func writeTestKey(t *testing.T, dir, name string) string { t.Helper() + return writeTestKeyWithPassphrase(t, dir, name, "") +} + +func writeTestKeyWithPassphrase(t *testing.T, dir, name, passphrase string) string { + t.Helper() key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("failed to generate key: %v", err) } - keyBytes := x509.MarshalPKCS1PrivateKey(key) - pemBytes := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes}) + var pemBytes []byte + if passphrase != "" { + block, _ := ssh.MarshalPrivateKeyWithPassphrase(key, passphrase, []byte(passphrase)) + pemBytes = pem.EncodeToMemory(block) + } else { + keyBytes := x509.MarshalPKCS1PrivateKey(key) + pemBytes = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: keyBytes}) + } path := filepath.Join(dir, name) if err := os.WriteFile(path, pemBytes, 0600); err != nil { t.Fatalf("failed to write key: %v", err) @@ -351,6 +362,43 @@ func writeTestKey(t *testing.T, dir, name string) string { return path } +func TestBuildAuthMethods_WithPassphrase(t *testing.T) { + keyPath := writeTestKeyWithPassphrase(t, t.TempDir(), "id_rsa_passphrase", "testpassphrase") + + opts := Options{ + IdentityFile: keyPath, + Passphrase: "testpassphrase", + } + + methods, xe := buildAuthMethods(opts) + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + if len(methods) == 0 { + t.Fatal("expected auth methods") + } +} + +func TestBuildAuthMethods_WithWrongPassphrase(t *testing.T) { + keyPath := writeTestKeyWithPassphrase(t, t.TempDir(), "id_rsa_wrong_pass", "correctpassphrase") + + opts := Options{ + IdentityFile: keyPath, + Passphrase: "wrongpassphrase", + } + + methods, xe := buildAuthMethods(opts) + if xe == nil { + t.Fatal("expected error for wrong passphrase") + } + if xe.Code != errors.CodeCfgInvalid && xe.Code != errors.CodeSSHAuthFailed { + t.Errorf("expected CodeCfgInvalid or CodeSSHAuthFailed, got %s", xe.Code) + } + if len(methods) != 0 { + t.Error("expected no auth methods for wrong passphrase") + } +} + // Helper function to check if a path contains another path component (cross-platform) func containsPath(path, component string) bool { p := filepath.ToSlash(filepath.Clean(path)) diff --git a/tests/e2e/ssh_cli_flags_test.go b/tests/e2e/ssh_cli_flags_test.go new file mode 100644 index 0000000..a2d29cb --- /dev/null +++ b/tests/e2e/ssh_cli_flags_test.go @@ -0,0 +1,147 @@ +//go:build e2e + +package e2e + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestSSH_SkipKnownHostsCheckFlag(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "xsql.yaml") + configContent := ` +ssh_proxies: + test_ssh: + host: localhost + port: 22 + user: testuser + identity_file: /nonexistent/key + +profiles: + test: + db: mysql + host: localhost + ssh_proxy: test_ssh +` + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "query", "SELECT 1", + "--config", configPath, + "--profile", "test", + "--format", "json", + "--ssh-skip-known-hosts-check") + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v\noutput: %s", err, stdout) + } + + if exitCode == 0 { + t.Log("Query succeeded with SSH proxy") + } else { + if resp.Error != nil && resp.Error.Code == "XSQL_SSH_HOSTKEY_MISMATCH" { + t.Error("should not get XSQL_SSH_HOSTKEY_MISMATCH when --ssh-skip-known-hosts-check is set") + } + } +} + +func TestSSH_IdentityFileFlagOverride(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "xsql.yaml") + configContent := ` +ssh_proxies: + test_ssh: + host: localhost + port: 22 + user: testuser + identity_file: /nonexistent/config_key + +profiles: + test: + db: mysql + host: localhost + ssh_proxy: test_ssh +` + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + t.Fatal(err) + } + + _, _, exitCode := runXSQL(t, "query", "SELECT 1", + "--config", configPath, + "--profile", "test", + "--format", "json", + "--ssh-identity-file", "/nonexistent/cli_key") + + if exitCode == 0 { + t.Log("Query succeeded with SSH proxy") + } else { + t.Log("Query failed (expected without real SSH server)") + } +} + +func TestSSH_UserFlagOverride(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "xsql.yaml") + configContent := ` +ssh_proxies: + test_ssh: + host: localhost + port: 22 + user: config_user + identity_file: /nonexistent/key + +profiles: + test: + db: mysql + host: localhost + ssh_proxy: test_ssh +` + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + t.Fatal(err) + } + + _, _, exitCode := runXSQL(t, "query", "SELECT 1", + "--config", configPath, + "--profile", "test", + "--format", "json", + "--ssh-user", "cli_user") + + if exitCode == 0 { + t.Log("Query succeeded with SSH proxy") + } else { + t.Log("Query failed (expected without real SSH server)") + } +} + +func TestSSH_HostFlag(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "xsql.yaml") + configContent := ` +profiles: + test: + db: mysql + host: localhost +` + if err := os.WriteFile(configPath, []byte(configContent), 0600); err != nil { + t.Fatal(err) + } + + _, _, exitCode := runXSQL(t, "query", "SELECT 1", + "--config", configPath, + "--profile", "test", + "--format", "json", + "--ssh-host", "example.com", + "--ssh-user", "test", + "--ssh-skip-known-hosts-check") + + if exitCode == 0 { + t.Log("Query succeeded with SSH proxy") + } else { + t.Log("Query failed (expected without real SSH server)") + } +}