Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch:

permissions:
contents: read
Expand Down
14 changes: 14 additions & 0 deletions cmd/xsql/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -86,6 +89,17 @@ func runMCPServer(opts *mcpServerOptions) error {
Addr: resolved.httpAddr,
Handler: handler,
}

sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

go func() {
<-sigChan
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_ = httpServer.Shutdown(ctx)
}()

return httpServer.ListenAndServe()
default:
return errors.New(errors.CodeCfgInvalid, "unsupported mcp transport", map[string]any{"transport": resolved.transport})
Expand Down
20 changes: 2 additions & 18 deletions cmd/xsql/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,9 @@ func newProfileListCommand(w *output.Writer) *cobra.Command {
return xe
}

type profileInfo struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
DB string `json:"db"`
Mode string `json:"mode"` // "read-only" or "read-write"
}

profiles := make([]profileInfo, 0, len(cfg.Profiles))
profiles := make([]config.ProfileInfo, 0, len(cfg.Profiles))
for name, p := range cfg.Profiles {
mode := "read-only"
if p.UnsafeAllowWrite {
mode = "read-write"
}
profiles = append(profiles, profileInfo{
Name: name,
Description: p.Description,
DB: p.DB,
Mode: mode,
})
profiles = append(profiles, config.ProfileToInfo(name, p))
}

result := map[string]any{
Expand Down
14 changes: 13 additions & 1 deletion cmd/xsql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ import (
"github.com/zx06/xsql/internal/output"
)

const DefaultQueryTimeout = 30 * time.Second

// QueryFlags holds the flags for the query command
type QueryFlags struct {
UnsafeAllowWrite bool
AllowPlaintext bool
SSHSkipHostKey bool
QueryTimeout int
QueryTimeoutSet bool
}

// NewQueryCommand creates the query command
Expand All @@ -28,13 +32,15 @@ func NewQueryCommand(w *output.Writer) *cobra.Command {
Short: "Execute a SQL query (read-only by default)",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
flags.QueryTimeoutSet = cmd.Flags().Changed("query-timeout")
return runQuery(cmd, args, flags, w)
},
}

cmd.Flags().BoolVar(&flags.UnsafeAllowWrite, "unsafe-allow-write", false, "Allow write operations (bypasses read-only protection)")
cmd.Flags().BoolVar(&flags.AllowPlaintext, "allow-plaintext", false, "Allow plaintext secrets in config")
cmd.Flags().BoolVar(&flags.SSHSkipHostKey, "ssh-skip-known-hosts-check", false, "Skip SSH known_hosts check (dangerous)")
cmd.Flags().IntVar(&flags.QueryTimeout, "query-timeout", 0, "Query timeout in seconds (default: 30)")

return cmd
}
Expand All @@ -52,7 +58,13 @@ func runQuery(cmd *cobra.Command, args []string, flags *QueryFlags, w *output.Wr
return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil)
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
timeout := DefaultQueryTimeout
if flags.QueryTimeoutSet && flags.QueryTimeout > 0 {
timeout = time.Duration(flags.QueryTimeout) * time.Second
} else if p.QueryTimeout > 0 {
timeout = time.Duration(p.QueryTimeout) * time.Second
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

conn, xe := app.ResolveConnection(ctx, app.ConnectionOptions{
Expand Down
22 changes: 17 additions & 5 deletions cmd/xsql/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ import (
"github.com/zx06/xsql/internal/output"
)

const DefaultSchemaTimeout = 60 * time.Second

// SchemaFlags holds the flags for the schema command
type SchemaFlags struct {
TablePattern string
IncludeSystem bool
AllowPlaintext bool
SSHSkipHostKey bool
TablePattern string
IncludeSystem bool
AllowPlaintext bool
SSHSkipHostKey bool
SchemaTimeout int
SchemaTimeoutSet bool
}

// NewSchemaCommand creates the schema command
Expand All @@ -41,6 +45,7 @@ func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command {
Use: "dump",
Short: "Dump database schema (tables, columns, indexes, foreign keys)",
RunE: func(cmd *cobra.Command, args []string) error {
flags.SchemaTimeoutSet = cmd.Flags().Changed("schema-timeout")
return runSchemaDump(cmd, args, flags, w)
},
}
Expand All @@ -49,6 +54,7 @@ func NewSchemaDumpCommand(w *output.Writer, flags *SchemaFlags) *cobra.Command {
cmd.Flags().BoolVar(&flags.IncludeSystem, "include-system", false, "Include system tables")
cmd.Flags().BoolVar(&flags.AllowPlaintext, "allow-plaintext", false, "Allow plaintext secrets in config")
cmd.Flags().BoolVar(&flags.SSHSkipHostKey, "ssh-skip-known-hosts-check", false, "Skip SSH known_hosts check (dangerous)")
cmd.Flags().IntVar(&flags.SchemaTimeout, "schema-timeout", 0, "Schema dump timeout in seconds (default: 60)")

return cmd
}
Expand All @@ -65,7 +71,13 @@ func runSchemaDump(cmd *cobra.Command, args []string, flags *SchemaFlags, w *out
return errors.New(errors.CodeCfgInvalid, "db type is required (mysql|pg)", nil)
}

ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
timeout := DefaultSchemaTimeout
if flags.SchemaTimeoutSet && flags.SchemaTimeout > 0 {
timeout = time.Duration(flags.SchemaTimeout) * time.Second
} else if p.SchemaTimeout > 0 {
timeout = time.Duration(p.SchemaTimeout) * time.Second
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

conn, xe := app.ResolveConnection(ctx, app.ConnectionOptions{
Expand Down
9 changes: 9 additions & 0 deletions docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ profiles:
| `allow_plaintext` | bool | 允许明文密码(默认 false) |
| `format` | string | 输出格式:json/yaml/table/csv/auto |
| `ssh_proxy` | string | SSH 代理名称(引用 `ssh_proxies` 中定义的名称) |
| `query_timeout` | int | 查询超时秒数(默认 30 秒) |
| `schema_timeout` | int | Schema 导出超时秒数(默认 60 秒) |

## CLI Timeout Flags

| Flag | 说明 |
|------|------|
| `--query-timeout <seconds>` | 查询超时(覆盖 profile 配置) |
| `--schema-timeout <seconds>` | Schema 导出超时(覆盖 profile 配置) |

## Secrets

Expand Down
56 changes: 55 additions & 1 deletion docs/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,59 @@ func TestMain(m *testing.M) {

---

## SSH 测试

### 单元测试
SSH 客户端单元测试位于 `internal/ssh/client_test.go`,测试:
- 路径扩展(tilde expansion)
- 认证方法构建(私钥、默认密钥查找)
- known_hosts 校验
- SSH dial 失败错误码返回
- Passphrase 保护的密钥(正确/错误 passphrase)

```bash
# 运行 SSH 单元测试
go test -v ./internal/ssh/...
```

### E2E SSH 测试

#### CLI Flags 测试
测试 SSH CLI flags 与配置文件的合并行为:
- `--ssh-skip-known-hosts-check`
- `--ssh-identity-file`
- `--ssh-user`
- `--ssh-host`

```bash
# 运行 SSH CLI flag 测试
go test -tags=e2e -v -run "SSH" ./tests/e2e/...
```

#### 真实 SSH 测试
需要真实 SSH 服务器的测试(跳过如果环境未配置):
- `ssh_proxy_success_test.go`

设置环境变量:
```bash
export SSH_TEST_HOST=your-ssh-server
export SSH_TEST_PORT=22
export SSH_TEST_USER=your-user
export SSH_TEST_KEY_PATH=/path/to/private/key
export SSH_KNOWN_HOSTS_FILE=/path/to/known_hosts # 可选

# 可选:MySQL/PG over SSH
export XSQL_TEST_MYSQL_DSN="..."
export XSQL_TEST_PG_DSN="..."
```

### 注意事项
- SSH 代理测试需要可访问的 SSH 服务器
- 使用 `SkipKnownHostsCheck: true` 进行测试,避免 known_hosts 问题
- SSH 密钥应有适当的权限(600 或 400)

---

## E2E 测试

### 定位
Expand All @@ -213,7 +266,8 @@ tests/e2e/
├── profile_test.go # profile 命令测试
├── proxy_test.go # proxy 命令测试
├── readonly_test.go # 只读策略测试
└── ssh_proxy_success_test.go # SSH 代理成功测试
├── ssh_cli_flags_test.go # SSH CLI flags 测试
└── ssh_proxy_success_test.go # SSH 代理成功测试(需要真实 SSH)
```

### 运行测试
Expand Down
24 changes: 24 additions & 0 deletions internal/config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ type Profile struct {
AllowPlaintext bool `yaml:"allow_plaintext"` // 允许明文密码
UnsafeAllowWrite bool `yaml:"unsafe_allow_write"` // 允许写操作(绕过只读保护)

// 超时配置(秒)
QueryTimeout int `yaml:"query_timeout"` // 查询超时,默认 30 秒
SchemaTimeout int `yaml:"schema_timeout"` // Schema 导出超时,默认 60 秒

// SSH proxy 引用(引用 ssh_proxies 中定义的名称)
SSHProxy string `yaml:"ssh_proxy"`

Expand Down Expand Up @@ -83,3 +87,23 @@ type Options struct {
// WorkDir 用于默认路径(为空则使用进程当前工作目录)。
WorkDir string
}

type ProfileInfo struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
DB string `json:"db"`
Mode string `json:"mode"` // "read-only" or "read-write"
}

func ProfileToInfo(name string, p Profile) ProfileInfo {
mode := "read-only"
if p.UnsafeAllowWrite {
mode = "read-write"
}
return ProfileInfo{
Name: name,
Description: p.Description,
DB: p.DB,
Mode: mode,
}
}
20 changes: 18 additions & 2 deletions internal/mcp/streamable_http.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package mcp

import (
"context"
"crypto/subtle"
"net/http"
"strings"
"time"

"github.com/modelcontextprotocol/go-sdk/mcp"

Expand All @@ -20,10 +22,15 @@ const (
bearerPrefix = "Bearer "
unauthorized = "unauthorized"
headerMissing = "authorization header is required"

DefaultHTTPTimeout = 60 * time.Second
)

// NewStreamableHTTPHandler creates a streamable HTTP handler with required auth.
func NewStreamableHTTPHandler(server *mcp.Server, authToken string) (http.Handler, error) {
return NewStreamableHTTPHandlerWithTimeout(server, authToken, DefaultHTTPTimeout)
}

func NewStreamableHTTPHandlerWithTimeout(server *mcp.Server, authToken string, timeout time.Duration) (http.Handler, error) {
if server == nil {
return nil, errors.New(errors.CodeInternal, "mcp server is nil", nil)
}
Expand All @@ -33,7 +40,16 @@ func NewStreamableHTTPHandler(server *mcp.Server, authToken string) (http.Handle
handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
return server
}, &mcp.StreamableHTTPOptions{JSONResponse: true})
return requireAuth(handler, authToken), nil
authHandler := requireAuth(handler, authToken)
return withTimeout(authHandler, timeout), nil
}

func withTimeout(next http.Handler, timeout time.Duration) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(req.Context(), timeout)
defer cancel()
next.ServeHTTP(w, req.WithContext(ctx))
})
}

func requireAuth(next http.Handler, token string) http.Handler {
Expand Down
26 changes: 6 additions & 20 deletions internal/mcp/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,25 +309,9 @@ func (h *ToolHandler) Query(ctx context.Context, req *mcp.CallToolRequest, input

// ProfileList lists all profiles
func (h *ToolHandler) ProfileList(ctx context.Context, req *mcp.CallToolRequest, input struct{}) (*mcp.CallToolResult, any, error) {
type profileInfo struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
DB string `json:"db"`
Mode string `json:"mode"`
}

profiles := make([]profileInfo, 0, len(h.config.Profiles))
profiles := make([]config.ProfileInfo, 0, len(h.config.Profiles))
for name, p := range h.config.Profiles {
mode := "read-only"
if p.UnsafeAllowWrite {
mode = "read-write"
}
profiles = append(profiles, profileInfo{
Name: name,
Description: p.Description,
DB: p.DB,
Mode: mode,
})
profiles = append(profiles, config.ProfileToInfo(name, p))
}

output := map[string]any{
Expand All @@ -347,7 +331,6 @@ func (h *ToolHandler) ProfileList(ctx context.Context, req *mcp.CallToolRequest,
}, nil, nil
}

// Return result directly in content per RFC
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: string(jsonData)},
Expand Down Expand Up @@ -463,7 +446,10 @@ func (h *ToolHandler) formatError(err error) string {
"details": xe.Details,
},
}
jsonData, _ := json.MarshalIndent(output, "", " ")
jsonData, jsonErr := json.MarshalIndent(output, "", " ")
if jsonErr != nil {
return `{"ok":false,"error":{"code":"XSQL_INTERNAL","message":"failed to format error"}}`
}
return string(jsonData)
}

Expand Down
9 changes: 6 additions & 3 deletions internal/output/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,12 @@ func writeTable(out io.Writer, env Envelope) error {
_, _ = fmt.Fprintf(tw, "%s\t%v\n", k, m[k])
}
} else {
// 只有这里不得已才使用 JSON 格式化
b, _ := json.MarshalIndent(env.Data, "", " ")
_, _ = fmt.Fprintf(tw, "%s\n", b)
b, err := json.MarshalIndent(env.Data, "", " ")
if err == nil {
_, _ = fmt.Fprintf(tw, "%s\n", b)
} else {
_, _ = fmt.Fprintf(tw, "%v\n", env.Data)
}
}
}
return tw.Flush()
Expand Down
Loading
Loading