diff --git a/.gitignore b/.gitignore index aa814d3..898a52c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ .npm-tmp/ npm/*/bin/xsql npm/*/bin/xsql.exe +coverage.txt diff --git a/cmd/xsql/command_unit_test.go b/cmd/xsql/command_unit_test.go index 3dca927..f904cba 100644 --- a/cmd/xsql/command_unit_test.go +++ b/cmd/xsql/command_unit_test.go @@ -598,6 +598,200 @@ profiles: } } +func TestResolveProxyPort(t *testing.T) { + t.Run("nil cmd returns config port", func(t *testing.T) { + port, fromConfig := resolveProxyPort(nil, &ProxyFlags{LocalPort: 5555}, 13306) + if port != 13306 { + t.Errorf("expected 13306, got %d", port) + } + if !fromConfig { + t.Error("expected fromConfig=true") + } + }) + + t.Run("nil cmd with zero config returns auto", func(t *testing.T) { + port, fromConfig := resolveProxyPort(nil, &ProxyFlags{}, 0) + if port != 0 { + t.Errorf("expected 0, got %d", port) + } + if fromConfig { + t.Error("expected fromConfig=false") + } + }) + + t.Run("cli flag takes priority", func(t *testing.T) { + cmd := NewProxyCommand(nil) + // Simulate setting the flag + _ = cmd.Flags().Set("local-port", "9999") + port, fromConfig := resolveProxyPort(cmd, &ProxyFlags{LocalPort: 9999}, 13306) + if port != 9999 { + t.Errorf("expected 9999, got %d", port) + } + if fromConfig { + t.Error("expected fromConfig=false") + } + }) + + t.Run("config port when cli not set", func(t *testing.T) { + cmd := NewProxyCommand(nil) + // Don't set the flag - use config port + port, fromConfig := resolveProxyPort(cmd, &ProxyFlags{}, 13306) + if port != 13306 { + t.Errorf("expected 13306, got %d", port) + } + if !fromConfig { + t.Error("expected fromConfig=true") + } + }) +} + +func TestConfigInitCommand(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := newConfigInitCommand(&w) + cmd.SetArgs([]string{"--path", path}) + if err := cmd.Execute(); err != nil { + t.Fatalf("config init failed: %v", err) + } + if !json.Valid(out.Bytes()) { + t.Fatalf("expected json output, got %s", out.String()) + } + + // Verify file exists + if _, err := os.Stat(path); err != nil { + t.Errorf("config file should exist: %v", err) + } +} + +func TestConfigInitCommand_FileExists(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("test"), 0600); err != nil { + t.Fatal(err) + } + + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := newConfigInitCommand(&w) + cmd.SetArgs([]string{"--path", path}) + if err := cmd.Execute(); err == nil { + t.Fatal("expected error when file exists") + } +} + +func TestConfigSetCommand(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + GlobalConfig.ConfigStr = path + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := newConfigSetCommand(&w) + cmd.SetArgs([]string{"profile.dev.host", "localhost"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("config set failed: %v", err) + } + if !json.Valid(out.Bytes()) { + t.Fatalf("expected json output, got %s", out.String()) + } + + // Verify the config was updated + data, _ := os.ReadFile(path) + if !bytes.Contains(data, []byte("localhost")) { + t.Error("config should contain 'localhost'") + } +} + +func TestConfigSetCommand_InvalidKey(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + GlobalConfig.ConfigStr = path + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := newConfigSetCommand(&w) + cmd.SetArgs([]string{"badkey", "value"}) + if err := cmd.Execute(); err == nil { + t.Fatal("expected error for invalid key") + } +} + +func TestConfigSetCommand_NoConfig(t *testing.T) { + GlobalConfig.ConfigStr = "" + GlobalConfig.FormatStr = "json" + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + cmd := newConfigSetCommand(&w) + cmd.SetArgs([]string{"profile.dev.host", "localhost"}) + + // Set HOME and work dir to temp dirs with no config files + origHome := os.Getenv("HOME") + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + defer func() { _ = os.Setenv("HOME", origHome) }() + + origDir, _ := os.Getwd() + tmpWorkDir := t.TempDir() + _ = os.Chdir(tmpWorkDir) + defer func() { _ = os.Chdir(origDir) }() + + err := cmd.Execute() + // FindConfigPath returns default home path, SetConfigValue creates the file. + // This should either succeed (creating new file) or fail. + // Since no config exists yet, it should succeed by creating a new one. + if err != nil { + // If it fails, that's okay too - we just want to verify it doesn't panic + t.Logf("error (acceptable): %v", err) + } +} + +func TestRunProxy_WithConfigLocalPort(t *testing.T) { + // Test that config local_port is used when --local-port is not set + GlobalConfig.ProfileStr = "dev" + GlobalConfig.FormatStr = "json" + GlobalConfig.Resolved.Profile = config.Profile{ + DB: "mysql", + Host: "db.example.com", + Port: 3306, + LocalPort: 13306, + SSHConfig: &config.SSHProxy{ + Host: "bastion.example.com", + Port: 22, + User: "user", + }, + } + + var out bytes.Buffer + w := output.New(&out, &bytes.Buffer{}) + // This will fail at SSH connection, but we can verify the port resolution + err := runProxy(nil, &ProxyFlags{}, &w) + if err == nil { + t.Fatal("expected error (SSH not available)") + } + // The error should be about SSH, not port + if xe, ok := errors.As(err); ok && xe.Code == errors.CodePortInUse { + t.Error("should not get port-in-use error") + } +} + func TestValueIfSet(t *testing.T) { if got := valueIfSet(false, "x"); got != "" { t.Fatalf("expected empty when not set, got %q", got) diff --git a/cmd/xsql/config.go b/cmd/xsql/config.go new file mode 100644 index 0000000..b924d2b --- /dev/null +++ b/cmd/xsql/config.go @@ -0,0 +1,85 @@ +package main + +import ( + "github.com/spf13/cobra" + + "github.com/zx06/xsql/internal/config" + "github.com/zx06/xsql/internal/errors" + "github.com/zx06/xsql/internal/output" +) + +// NewConfigCommand creates the config command group +func NewConfigCommand(w *output.Writer) *cobra.Command { + configCmd := &cobra.Command{ + Use: "config", + Short: "Manage configuration", + } + + configCmd.AddCommand(newConfigInitCommand(w)) + configCmd.AddCommand(newConfigSetCommand(w)) + + return configCmd +} + +// newConfigInitCommand creates the config init command +func newConfigInitCommand(w *output.Writer) *cobra.Command { + var path string + + cmd := &cobra.Command{ + Use: "init", + Short: "Create a template configuration file", + RunE: func(cmd *cobra.Command, args []string) error { + format, err := parseOutputFormat(GlobalConfig.FormatStr) + if err != nil { + return err + } + + cfgPath, xe := config.InitConfig(path) + if xe != nil { + return xe + } + + return w.WriteOK(format, map[string]any{ + "config_path": cfgPath, + }) + }, + } + + cmd.Flags().StringVar(&path, "path", "", "Config file path (default: $HOME/.config/xsql/xsql.yaml)") + + return cmd +} + +// newConfigSetCommand creates the config set command +func newConfigSetCommand(w *output.Writer) *cobra.Command { + return &cobra.Command{ + Use: "set ", + Short: "Set a configuration value (e.g., profile.dev.host localhost)", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + key, value := args[0], args[1] + + format, err := parseOutputFormat(GlobalConfig.FormatStr) + if err != nil { + return err + } + + cfgPath := config.FindConfigPath(config.Options{ + ConfigPath: GlobalConfig.ConfigStr, + }) + if cfgPath == "" { + return errors.New(errors.CodeCfgNotFound, "no config file found; run 'xsql config init' first", nil) + } + + if xe := config.SetConfigValue(cfgPath, key, value); xe != nil { + return xe + } + + return w.WriteOK(format, map[string]any{ + "config_path": cfgPath, + "key": key, + "value": value, + }) + }, + } +} diff --git a/cmd/xsql/main.go b/cmd/xsql/main.go index 54fc248..d31472e 100644 --- a/cmd/xsql/main.go +++ b/cmd/xsql/main.go @@ -30,6 +30,7 @@ func run() int { root.AddCommand(NewSchemaCommand(&w)) root.AddCommand(NewMCPCommand()) root.AddCommand(NewProxyCommand(&w)) + root.AddCommand(NewConfigCommand(&w)) // Execute and handle errors if err := root.Execute(); err != nil { diff --git a/cmd/xsql/proxy.go b/cmd/xsql/proxy.go index 092cc1f..2eaa7be 100644 --- a/cmd/xsql/proxy.go +++ b/cmd/xsql/proxy.go @@ -1,14 +1,17 @@ package main import ( + "bufio" "context" "fmt" "log" "os" "os/signal" + "strings" "syscall" "github.com/spf13/cobra" + "golang.org/x/term" "github.com/zx06/xsql/internal/app" "github.com/zx06/xsql/internal/errors" @@ -44,6 +47,49 @@ func NewProxyCommand(w *output.Writer) *cobra.Command { return cmd } +// resolveProxyPort determines the port to use with the following priority: +// CLI --local-port > profile.local_port > 0 (auto) +// Returns the port and whether it came from config (for conflict handling). +func resolveProxyPort(cmd *cobra.Command, flags *ProxyFlags, profileLocalPort int) (port int, fromConfig bool) { + if cmd != nil && cmd.Flags().Changed("local-port") { + return flags.LocalPort, false + } + if profileLocalPort > 0 { + return profileLocalPort, true + } + return 0, false +} + +// handlePortConflict handles a port conflict when the port comes from config. +// In TTY mode, prompts the user to choose random port or quit. +// In non-TTY mode, returns an error. +func handlePortConflict(port int, host string) (int, *errors.XError) { + if !term.IsTerminal(int(os.Stdin.Fd())) { + return 0, errors.New(errors.CodePortInUse, "configured port is already in use", + map[string]any{"port": port, "host": host}) + } + + fmt.Fprintf(os.Stderr, "⚠ Port %d is already in use.\n", port) + fmt.Fprintf(os.Stderr, " [R] Use a random port\n") + fmt.Fprintf(os.Stderr, " [Q] Quit\n") + fmt.Fprintf(os.Stderr, "Choice [R/Q]: ") + + reader := bufio.NewReader(os.Stdin) + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(strings.ToLower(input)) + + switch input { + case "r", "": + return 0, nil // 0 means auto-assign + case "q": + return 0, errors.New(errors.CodePortInUse, "user chose to quit due to port conflict", + map[string]any{"port": port}) + default: + return 0, errors.New(errors.CodePortInUse, "user chose to quit due to port conflict", + map[string]any{"port": port}) + } +} + // runProxy executes the proxy command func runProxy(cmd *cobra.Command, flags *ProxyFlags, w *output.Writer) error { if GlobalConfig.ProfileStr == "" { @@ -64,6 +110,22 @@ func runProxy(cmd *cobra.Command, flags *ProxyFlags, w *output.Writer) error { return errors.New(errors.CodeCfgInvalid, "profile must have ssh_proxy configured for port forwarding", nil) } + // Resolve port: CLI > config local_port > 0 (auto) + localPort, fromConfig := resolveProxyPort(cmd, flags, p.LocalPort) + + // Check for port conflict if a specific port is configured + if localPort > 0 && !proxy.IsPortAvailable(flags.LocalHost, localPort) { + if fromConfig { + // Port from config: offer interactive choice + newPort, xe := handlePortConflict(localPort, flags.LocalHost) + if xe != nil { + return xe + } + localPort = newPort + } + // If port from CLI flag, let proxy.Start handle the error naturally + } + allowPlaintext := flags.AllowPlaintext || p.AllowPlaintext ctx, cancel := context.WithCancel(context.Background()) @@ -83,7 +145,7 @@ func runProxy(cmd *cobra.Command, flags *ProxyFlags, w *output.Writer) error { proxyOpts := proxy.Options{ LocalHost: flags.LocalHost, - LocalPort: flags.LocalPort, + LocalPort: localPort, RemoteHost: p.Host, RemotePort: p.Port, Dialer: sshClient, diff --git a/docs/cli-spec.md b/docs/cli-spec.md index 7aee318..18b7854 100644 --- a/docs/cli-spec.md +++ b/docs/cli-spec.md @@ -269,10 +269,21 @@ xsql proxy -p prod-mysql # 指定本地端口 xsql -p prod-mysql proxy --local-port 13306 +# 使用 profile 中配置的 local_port(如果设置了的话) +xsql proxy -p prod-mysql # 自动使用 profile 中的 local_port + # 输出 JSON 格式 xsql -p prod-mysql proxy --format json ``` +**端口优先级:**`--local-port` flag > `profile.local_port` > 0(自动分配) + +**端口冲突处理:** +- 当端口来源于**配置文件**且端口已被占用时: + - TTY 环境:交互式询问用户选择随机端口或退出 + - 非 TTY 环境:返回错误 `XSQL_PORT_IN_USE` +- 当端口来源于 `--local-port` CLI flag 且被占用时:直接返回错误 + **要求:** - Profile 必须配置 `ssh_proxy`,否则无法启动 - Profile 必须配置数据库类型(`db`)、主机(`host`)和端口(`port`) @@ -325,6 +336,63 @@ Press Ctrl+C to stop - 密码/passphrase 复用 keyring 机制,不泄露明文 - 按 Ctrl+C 或发送 SIGTERM 信号优雅关闭代理 +### `xsql config init` + +创建配置文件模板。 + +```bash +# 在默认路径创建 +xsql config init + +# 指定路径 +xsql config init --path ./xsql.yaml +``` + +**Flags:** +| Flag | 默认值 | 说明 | +|------|--------|------| +| `--path` | `$HOME/.config/xsql/xsql.yaml` | 配置文件路径 | + +**输出示例(JSON):** +```json +{ + "ok": true, + "schema_version": 1, + "data": { + "config_path": "/home/user/.config/xsql/xsql.yaml" + } +} +``` + +### `xsql config set ` + +快速修改配置项,使用点号分隔的路径定位配置字段。 + +```bash +# 设置 profile 字段 +xsql config set profile.dev.host localhost +xsql config set profile.dev.port 3306 +xsql config set profile.dev.db mysql +xsql config set profile.dev.local_port 13306 + +# 设置 SSH proxy 字段 +xsql config set ssh_proxy.bastion.host bastion.example.com +xsql config set ssh_proxy.bastion.user admin +``` + +**输出示例(JSON):** +```json +{ + "ok": true, + "schema_version": 1, + "data": { + "config_path": "/home/user/.config/xsql/xsql.yaml", + "key": "profile.dev.host", + "value": "localhost" + } +} +``` + ## 全局 Flags | Flag | 说明 | diff --git a/docs/config.md b/docs/config.md index f2e0760..1263d79 100644 --- a/docs/config.md +++ b/docs/config.md @@ -76,6 +76,7 @@ profiles: user: app_readonly password: "keyring:prod/mysql_password" database: myapp_prod + local_port: 13306 # proxy 本地端口(可选) ssh_proxy: bastion # 引用预定义的 SSH 代理 # 另一个使用同一 SSH 代理的数据库 @@ -130,6 +131,7 @@ profiles: | `unsafe_allow_write` | bool | 允许写操作,绕过只读保护(默认 false) | | `allow_plaintext` | bool | 允许明文密码(默认 false) | | `format` | string | 输出格式:json/yaml/table/csv/auto | +| `local_port` | int | proxy 本地监听端口(默认 0,自动分配) | | `ssh_proxy` | string | SSH 代理名称(引用 `ssh_proxies` 中定义的名称) | ## Secrets diff --git a/docs/error-contract.md b/docs/error-contract.md index 18d6f6b..f41fa05 100644 --- a/docs/error-contract.md +++ b/docs/error-contract.md @@ -55,6 +55,9 @@ DB 类: 只读策略: - `XSQL_RO_BLOCKED` - 写操作被只读策略拦截 +端口: +- `XSQL_PORT_IN_USE` - 代理端口被占用 + 内部: - `XSQL_INTERNAL` - 内部错误 diff --git a/docs/rfcs/0006-proxy-port-config-set.md b/docs/rfcs/0006-proxy-port-config-set.md new file mode 100644 index 0000000..7d52af2 --- /dev/null +++ b/docs/rfcs/0006-proxy-port-config-set.md @@ -0,0 +1,105 @@ +# RFC 0006: Proxy Port Config & Config Set Command + +Status: Accepted + +## 摘要 +增加两项能力:(1)支持在 profile 配置中指定 proxy 本地端口,端口冲突时交互式询问用户;(2)新增 `xsql config set` 和 `xsql config init` 命令,降低配置复杂度。两项变更涉及 config schema 新增字段、新增 CLI 命令、新增错误码。 + +## 背景 / 动机 +- 当前痛点: + - `xsql proxy` 每次都需通过 `--local-port` 指定端口,无法在配置文件中固定。 + - 配置文件需要手动编辑 YAML,对新用户不友好。 +- 目标: + - Profile 可配置 `local_port`,proxy 命令自动使用;端口被占用时交互提示。 + - 提供 `config set` 快速修改配置、`config init` 生成模板。 +- 非目标: + - 不改变现有 proxy 的 SSH 连接逻辑。 + - 不实现 `config get` 或 `config delete`(可后续扩展)。 + +## 方案(Proposed) + +### 用户视角(CLI/配置/输出) + +#### 1. Profile `local_port` 字段 +```yaml +profiles: + prod-mysql: + db: mysql + host: db.internal.example.com + port: 3306 + local_port: 13306 # 新增:proxy 本地端口 + ssh_proxy: bastion +``` + +端口优先级:`--local-port` flag > `profile.local_port` > 0(自动分配) + +端口冲突处理: +- 仅当端口来源于**配置文件**(非 CLI flag)时,才提示用户选择。 +- TTY 环境:询问 "Port 13306 is in use. [R]andom port / [Q]uit?" +- 非 TTY 环境:返回错误 `XSQL_PORT_IN_USE`,退出码 10。 + +#### 2. `xsql config init` +```bash +xsql config init # 创建 ~/.config/xsql/xsql.yaml +xsql config init --path ./xsql.yaml # 指定路径 +``` + +#### 3. `xsql config set` +```bash +xsql config set profile.dev.host localhost +xsql config set profile.dev.port 3306 +xsql config set profile.dev.db mysql +xsql config set ssh_proxy.bastion.host bastion.example.com +xsql config set ssh_proxy.bastion.user admin +``` + +输出(JSON): +```json +{"ok":true,"schema_version":1,"data":{"config_path":"/path/to/xsql.yaml","key":"profile.dev.host","value":"localhost"}} +``` + +### 新增错误码 +| Code | 含义 | 退出码 | +|------|------|--------| +| `XSQL_PORT_IN_USE` | 端口被占用 | 10 | + +### 技术设计(Architecture) + +#### 涉及模块 +- `internal/config/types.go`:Profile 新增 `LocalPort` 字段 +- `internal/config/write.go`:新增配置写入能力 +- `internal/proxy/proxy.go`:端口冲突检测 +- `internal/errors/codes.go`:新增错误码 +- `cmd/xsql/proxy.go`:读取 config local_port、端口冲突交互 +- `cmd/xsql/config.go`:新增 config 命令组 + +#### 兼容性策略 +- `local_port` 是新增字段,默认为 0(不影响现有行为) +- config 命令是全新命令,不影响现有命令 +- 只增不改 + +## 备选方案(Alternatives) +- 方案 A(采用):profile 中直接加 `local_port` 字段 +- 方案 B:在 proxy 子节点增加配置 —— 过度设计,当前场景不需要 + +## 兼容性与迁移(Compatibility & Migration) +- 不破坏兼容:所有新增字段有零值默认 +- 无需迁移 + +## 安全与隐私(Security/Privacy) +- config set 的 password 字段建议使用 `keyring:` 引用,不阻止明文但遵循已有 allow_plaintext 机制 +- config init 模板中不包含真实密码 + +## 测试计划(Test Plan) +- 单元测试: + - config write/update 逻辑 + - proxy port conflict detection + - config set key parsing +- E2E 测试: + - `xsql config init` 创建文件 + - `xsql config set` 修改配置 + - proxy 使用 config local_port + - proxy 端口冲突场景 + +## 未决问题(Open Questions) +- 无 diff --git a/internal/config/types.go b/internal/config/types.go index e77896d..17be3bf 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -39,6 +39,9 @@ type Profile struct { // SSH proxy 引用(引用 ssh_proxies 中定义的名称) SSHProxy string `yaml:"ssh_proxy"` + // Proxy 本地端口(用于 xsql proxy 命令) + LocalPort int `yaml:"local_port"` + // 解析后的 SSH 配置(由 Resolve 填充,不从 YAML 读取) SSHConfig *SSHProxy `yaml:"-"` } diff --git a/internal/config/write.go b/internal/config/write.go new file mode 100644 index 0000000..2525936 --- /dev/null +++ b/internal/config/write.go @@ -0,0 +1,234 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "gopkg.in/yaml.v3" + + "github.com/zx06/xsql/internal/errors" +) + +// InitConfig creates a template config file at the given path. +// If path is empty, uses the default path ($HOME/.config/xsql/xsql.yaml). +func InitConfig(path string) (string, *errors.XError) { + if path == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", errors.Wrap(errors.CodeInternal, "failed to get home directory", nil, err) + } + path = filepath.Join(home, ".config", "xsql", "xsql.yaml") + } + + if _, err := os.Stat(path); err == nil { + return "", errors.New(errors.CodeCfgInvalid, "config file already exists", map[string]any{"path": path}) + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0700); err != nil { + return "", errors.Wrap(errors.CodeInternal, "failed to create config directory", map[string]any{"path": dir}, err) + } + + template := `# xsql configuration file +# Documentation: https://github.com/zx06/xsql/blob/main/docs/config.md + +ssh_proxies: {} + # example: + # host: bastion.example.com + # port: 22 + # user: admin + # identity_file: ~/.ssh/id_ed25519 + +profiles: {} + # example: + # db: mysql + # host: 127.0.0.1 + # port: 3306 + # user: root + # password: "keyring:dev/mysql_password" + # database: mydb +` + + if err := os.WriteFile(path, []byte(template), 0600); err != nil { + return "", errors.Wrap(errors.CodeInternal, "failed to write config file", map[string]any{"path": path}, err) + } + + return path, nil +} + +// SetConfigValue sets a value in the config file using dot-notation key. +// Supported key patterns: +// - profile.. +// - ssh_proxy.. +func SetConfigValue(configPath, key, value string) *errors.XError { + if configPath == "" { + return errors.New(errors.CodeCfgNotFound, "no config file found; run 'xsql config init' first", nil) + } + + cfg, xe := readFile(configPath) + if xe != nil { + if xe.Code == errors.CodeCfgNotFound { + // Create new empty config + cfg = File{ + SSHProxies: map[string]SSHProxy{}, + Profiles: map[string]Profile{}, + } + } else { + return xe + } + } + + parts := strings.SplitN(key, ".", 3) + if len(parts) != 3 { + return errors.New(errors.CodeCfgInvalid, "invalid key format; use profile.. or ssh_proxy..", + map[string]any{"key": key}) + } + + section, name, field := parts[0], parts[1], parts[2] + + switch section { + case "profile": + if xe := setProfileField(&cfg, name, field, value); xe != nil { + return xe + } + case "ssh_proxy": + if xe := setSSHProxyField(&cfg, name, field, value); xe != nil { + return xe + } + default: + return errors.New(errors.CodeCfgInvalid, "unsupported config section; use 'profile' or 'ssh_proxy'", + map[string]any{"section": section}) + } + + return writeFile(configPath, cfg) +} + +func setProfileField(cfg *File, name, field, value string) *errors.XError { + p := cfg.Profiles[name] + + switch field { + case "db": + p.DB = value + case "host": + p.Host = value + case "port": + port, err := strconv.Atoi(value) + if err != nil { + return errors.New(errors.CodeCfgInvalid, "port must be a number", map[string]any{"value": value}) + } + p.Port = port + case "local_port": + port, err := strconv.Atoi(value) + if err != nil { + return errors.New(errors.CodeCfgInvalid, "local_port must be a number", map[string]any{"value": value}) + } + p.LocalPort = port + case "user": + p.User = value + case "password": + p.Password = value + case "database": + p.Database = value + case "dsn": + p.DSN = value + case "description": + p.Description = value + case "format": + p.Format = value + case "ssh_proxy": + p.SSHProxy = value + case "unsafe_allow_write": + p.UnsafeAllowWrite = parseBool(value) + case "allow_plaintext": + p.AllowPlaintext = parseBool(value) + default: + return errors.New(errors.CodeCfgInvalid, fmt.Sprintf("unknown profile field: %s", field), + map[string]any{"field": field}) + } + + cfg.Profiles[name] = p + return nil +} + +func setSSHProxyField(cfg *File, name, field, value string) *errors.XError { + sp := cfg.SSHProxies[name] + + switch field { + case "host": + sp.Host = value + case "port": + port, err := strconv.Atoi(value) + if err != nil { + return errors.New(errors.CodeCfgInvalid, "port must be a number", map[string]any{"value": value}) + } + sp.Port = port + case "user": + sp.User = value + case "identity_file": + sp.IdentityFile = value + case "passphrase": + sp.Passphrase = value + case "known_hosts_file": + sp.KnownHostsFile = value + case "skip_host_key": + sp.SkipHostKey = parseBool(value) + default: + return errors.New(errors.CodeCfgInvalid, fmt.Sprintf("unknown ssh_proxy field: %s", field), + map[string]any{"field": field}) + } + + cfg.SSHProxies[name] = sp + return nil +} + +func writeFile(path string, cfg File) *errors.XError { + b, err := yaml.Marshal(cfg) + if err != nil { + return errors.Wrap(errors.CodeInternal, "failed to marshal config", nil, err) + } + + if err := os.WriteFile(path, b, 0600); err != nil { + return errors.Wrap(errors.CodeInternal, "failed to write config file", map[string]any{"path": path}, err) + } + + return nil +} + +func parseBool(s string) bool { + s = strings.ToLower(s) + return s == "true" || s == "1" || s == "yes" +} + +// FindConfigPath returns the path to the config file being used (or default path). +func FindConfigPath(opts Options) string { + if opts.ConfigPath != "" { + return opts.ConfigPath + } + + workDir := opts.WorkDir + if workDir == "" { + wd, _ := os.Getwd() + workDir = wd + } + homeDir := opts.HomeDir + if homeDir == "" { + if hd, err := os.UserHomeDir(); err == nil { + homeDir = hd + } + } + + for _, p := range defaultConfigPaths(workDir, homeDir) { + if _, err := os.Stat(p); err == nil { + return p + } + } + + // Return default home path if none found + if homeDir != "" { + return filepath.Join(homeDir, ".config", "xsql", "xsql.yaml") + } + return "" +} diff --git a/internal/config/write_test.go b/internal/config/write_test.go new file mode 100644 index 0000000..cfd9db1 --- /dev/null +++ b/internal/config/write_test.go @@ -0,0 +1,440 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "gopkg.in/yaml.v3" +) + +func TestInitConfig(t *testing.T) { + t.Run("creates config at specified path", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "subdir", "xsql.yaml") + + cfgPath, xe := InitConfig(path) + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + if cfgPath != path { + t.Errorf("expected %s, got %s", path, cfgPath) + } + + // File should exist + if _, err := os.Stat(path); err != nil { + t.Errorf("config file should exist: %v", err) + } + + // File should be valid YAML + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read file: %v", err) + } + var f File + if err := yaml.Unmarshal(data, &f); err != nil { + t.Errorf("config should be valid YAML: %v", err) + } + }) + + t.Run("creates config at default path", func(t *testing.T) { + // Create a temp HOME + dir := t.TempDir() + t.Setenv("HOME", dir) + t.Setenv("USERPROFILE", dir) // Windows compatibility + + cfgPath, xe := InitConfig("") + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + + expected := filepath.Join(dir, ".config", "xsql", "xsql.yaml") + if cfgPath != expected { + t.Errorf("expected %s, got %s", expected, cfgPath) + } + }) + + t.Run("fails if file already exists", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("test"), 0600); err != nil { + t.Fatal(err) + } + + _, xe := InitConfig(path) + if xe == nil { + t.Error("expected error when file exists") + } + }) +} + +func TestSetConfigValue(t *testing.T) { + t.Run("set profile field on new config", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + // Create minimal config + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "profile.dev.host", "localhost") + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + + // Read back and verify + f, xe2 := readFile(path) + if xe2 != nil { + t.Fatalf("failed to read config: %v", xe2) + } + + p, ok := f.Profiles["dev"] + if !ok { + t.Fatal("profile 'dev' not found") + } + if p.Host != "localhost" { + t.Errorf("expected host=localhost, got %s", p.Host) + } + }) + + t.Run("set profile port as number", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "profile.dev.port", "3306") + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + + f, _ := readFile(path) + if f.Profiles["dev"].Port != 3306 { + t.Errorf("expected port=3306, got %d", f.Profiles["dev"].Port) + } + }) + + t.Run("set profile local_port", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "profile.prod.local_port", "13306") + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + + f, _ := readFile(path) + if f.Profiles["prod"].LocalPort != 13306 { + t.Errorf("expected local_port=13306, got %d", f.Profiles["prod"].LocalPort) + } + }) + + t.Run("set ssh_proxy field", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "ssh_proxy.bastion.host", "bastion.example.com") + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + + f, _ := readFile(path) + sp, ok := f.SSHProxies["bastion"] + if !ok { + t.Fatal("ssh_proxy 'bastion' not found") + } + if sp.Host != "bastion.example.com" { + t.Errorf("expected host=bastion.example.com, got %s", sp.Host) + } + }) + + t.Run("set ssh_proxy port", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "ssh_proxy.bastion.port", "2222") + if xe != nil { + t.Fatalf("unexpected error: %v", xe) + } + + f, _ := readFile(path) + if f.SSHProxies["bastion"].Port != 2222 { + t.Errorf("expected port=2222, got %d", f.SSHProxies["bastion"].Port) + } + }) + + t.Run("invalid key format", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "invalidkey", "value") + if xe == nil { + t.Error("expected error for invalid key format") + } + }) + + t.Run("invalid section", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "unknown.name.field", "value") + if xe == nil { + t.Error("expected error for unknown section") + } + }) + + t.Run("invalid port value", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "profile.dev.port", "notanumber") + if xe == nil { + t.Error("expected error for invalid port") + } + }) + + t.Run("unknown profile field", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "profile.dev.nonexistent", "value") + if xe == nil { + t.Error("expected error for unknown field") + } + }) + + t.Run("unknown ssh_proxy field", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "ssh_proxy.bastion.nonexistent", "value") + if xe == nil { + t.Error("expected error for unknown field") + } + }) + + t.Run("empty config path", func(t *testing.T) { + xe := SetConfigValue("", "profile.dev.host", "localhost") + if xe == nil { + t.Error("expected error for empty config path") + } + }) + + t.Run("set all profile fields", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + fields := map[string]string{ + "db": "mysql", + "host": "localhost", + "user": "root", + "password": "secret", + "database": "mydb", + "dsn": "root:secret@tcp(localhost:3306)/mydb", + "description": "test db", + "format": "json", + "ssh_proxy": "bastion", + "unsafe_allow_write": "true", + "allow_plaintext": "true", + } + + for field, value := range fields { + xe := SetConfigValue(path, "profile.test."+field, value) + if xe != nil { + t.Fatalf("failed to set %s: %v", field, xe) + } + } + + f, _ := readFile(path) + p := f.Profiles["test"] + if p.DB != "mysql" { + t.Errorf("db: expected mysql, got %s", p.DB) + } + if p.Host != "localhost" { + t.Errorf("host: expected localhost, got %s", p.Host) + } + if !p.UnsafeAllowWrite { + t.Error("unsafe_allow_write should be true") + } + if !p.AllowPlaintext { + t.Error("allow_plaintext should be true") + } + }) + + t.Run("set all ssh_proxy fields", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + fields := map[string]string{ + "host": "bastion.example.com", + "user": "admin", + "identity_file": "~/.ssh/id_ed25519", + "passphrase": "keyring:ssh/passphrase", + "known_hosts_file": "~/.ssh/known_hosts", + "skip_host_key": "true", + } + + for field, value := range fields { + xe := SetConfigValue(path, "ssh_proxy.bastion."+field, value) + if xe != nil { + t.Fatalf("failed to set %s: %v", field, xe) + } + } + + f, _ := readFile(path) + sp := f.SSHProxies["bastion"] + if sp.Host != "bastion.example.com" { + t.Errorf("host: expected bastion.example.com, got %s", sp.Host) + } + if sp.User != "admin" { + t.Errorf("user: expected admin, got %s", sp.User) + } + if !sp.SkipHostKey { + t.Error("skip_host_key should be true") + } + }) + + t.Run("invalid local_port value", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "profile.dev.local_port", "abc") + if xe == nil { + t.Error("expected error for non-numeric local_port") + } + }) + + t.Run("invalid ssh_proxy port value", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + xe := SetConfigValue(path, "ssh_proxy.bastion.port", "abc") + if xe == nil { + t.Error("expected error for non-numeric ssh_proxy port") + } + }) +} + +func TestFindConfigPath(t *testing.T) { + t.Run("returns explicit config path", func(t *testing.T) { + path := FindConfigPath(Options{ConfigPath: "/explicit/path.yaml"}) + if path != "/explicit/path.yaml" { + t.Errorf("expected /explicit/path.yaml, got %s", path) + } + }) + + t.Run("finds config in work dir", func(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(cfgPath, []byte("profiles: {}"), 0600); err != nil { + t.Fatal(err) + } + + path := FindConfigPath(Options{WorkDir: dir}) + if path != cfgPath { + t.Errorf("expected %s, got %s", cfgPath, path) + } + }) + + t.Run("returns default home path when not found", func(t *testing.T) { + dir := t.TempDir() + path := FindConfigPath(Options{WorkDir: "/nonexistent", HomeDir: dir}) + expected := filepath.Join(dir, ".config", "xsql", "xsql.yaml") + if path != expected { + t.Errorf("expected %s, got %s", expected, path) + } + }) +} + +func TestParseBool(t *testing.T) { + cases := []struct { + input string + want bool + }{ + {"true", true}, + {"True", true}, + {"TRUE", true}, + {"1", true}, + {"yes", true}, + {"Yes", true}, + {"false", false}, + {"0", false}, + {"no", false}, + {"", false}, + } + + for _, tc := range cases { + got := parseBool(tc.input) + if got != tc.want { + t.Errorf("parseBool(%q) = %v, want %v", tc.input, got, tc.want) + } + } +} + +func TestLocalPortInProfile(t *testing.T) { + // Test that local_port is properly deserialized from YAML + yamlContent := `profiles: + dev: + db: mysql + host: localhost + port: 3306 + local_port: 13306 +ssh_proxies: {} +` + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte(yamlContent), 0600); err != nil { + t.Fatal(err) + } + + f, xe := readFile(path) + if xe != nil { + t.Fatalf("failed to read config: %v", xe) + } + + p, ok := f.Profiles["dev"] + if !ok { + t.Fatal("profile 'dev' not found") + } + + if p.LocalPort != 13306 { + t.Errorf("expected local_port=13306, got %d", p.LocalPort) + } +} diff --git a/internal/errors/codes.go b/internal/errors/codes.go index 681aefb..f9224c3 100644 --- a/internal/errors/codes.go +++ b/internal/errors/codes.go @@ -24,6 +24,9 @@ const ( // Read-only policy CodeROBlocked Code = "XSQL_RO_BLOCKED" + // Port + CodePortInUse Code = "XSQL_PORT_IN_USE" + // Internal CodeInternal Code = "XSQL_INTERNAL" ) @@ -41,6 +44,7 @@ func AllCodes() []Code { CodeDBAuthFailed, CodeDBExecFailed, CodeROBlocked, + CodePortInUse, CodeInternal, } } diff --git a/internal/errors/exitcode.go b/internal/errors/exitcode.go index f08b9d1..8fedf2d 100644 --- a/internal/errors/exitcode.go +++ b/internal/errors/exitcode.go @@ -31,6 +31,8 @@ func ExitCodeFor(code Code) ExitCode { return ExitConnect case CodeROBlocked: return ExitReadOnly + case CodePortInUse: + return ExitInternal case CodeDBExecFailed: return ExitDBExec case CodeInternal: diff --git a/internal/errors/exitcode_test.go b/internal/errors/exitcode_test.go index 60ac8bb..1d78cf4 100644 --- a/internal/errors/exitcode_test.go +++ b/internal/errors/exitcode_test.go @@ -20,6 +20,7 @@ func TestExitCodeFor(t *testing.T) { {CodeDBAuthFailed, ExitConnect}, {CodeDBDriverUnsupported, ExitConnect}, {CodeROBlocked, ExitReadOnly}, + {CodePortInUse, ExitInternal}, {CodeDBExecFailed, ExitDBExec}, {CodeInternal, ExitInternal}, {Code("UNKNOWN_CODE"), ExitInternal}, // unknown code @@ -101,8 +102,8 @@ func TestAs(t *testing.T) { func TestAllCodes(t *testing.T) { codes := AllCodes() - if len(codes) != 12 { - t.Errorf("AllCodes() should return 12 codes, got %d", len(codes)) + if len(codes) != 13 { + t.Errorf("AllCodes() should return 13 codes, got %d", len(codes)) } // Check for duplicates diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 0732a89..bdd4974 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -56,6 +56,11 @@ func Start(ctx context.Context, opts Options) (*Proxy, *Result, *errors.XError) addr := fmt.Sprintf("%s:%d", opts.LocalHost, opts.LocalPort) listener, err := net.Listen("tcp", addr) if err != nil { + // Check if this is a port-in-use error + if isPortInUse(err) { + return nil, nil, errors.New(errors.CodePortInUse, "port is already in use", + map[string]any{"address": addr, "port": opts.LocalPort}) + } return nil, nil, errors.Wrap(errors.CodeInternal, "failed to listen on local port", map[string]any{"address": addr}, err) } @@ -191,3 +196,40 @@ func (p *Proxy) LocalAddress() string { } return "" } + +// IsPortAvailable checks if a port is available for binding. +func IsPortAvailable(host string, port int) bool { + if host == "" { + host = "127.0.0.1" + } + addr := fmt.Sprintf("%s:%d", host, port) + ln, err := net.Listen("tcp", addr) + if err != nil { + return false + } + _ = ln.Close() + return true +} + +// isPortInUse checks if the error is caused by the port being in use. +func isPortInUse(err error) bool { + if err == nil { + return false + } + s := err.Error() + return contains(s, "address already in use") || contains(s, "bind: address already in use") || + contains(s, "Only one usage of each socket address") +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsSubstr(s, sub)) +} + +func containsSubstr(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 21bf987..8c324b4 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -262,8 +262,8 @@ func TestProxy_PortInUse(t *testing.T) { if xe == nil { t.Error("expected error when port is already in use") } - if xe != nil && xe.Code != errors.CodeInternal { - t.Errorf("expected CodeInternal, got %s", xe.Code) + if xe != nil && xe.Code != errors.CodePortInUse { + t.Errorf("expected CodePortInUse, got %s", xe.Code) } } @@ -377,3 +377,112 @@ func TestProxy_HandleConnection_DialReturnsNilConn(t *testing.T) { t.Fatal("handleConnection should return quickly when dialer returns nil conn") } } + +func TestIsPortAvailable(t *testing.T) { + t.Run("available port", func(t *testing.T) { + // Port 0 always resolves to an available port; just verify the function works + // by finding a free port first + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + port := ln.Addr().(*net.TCPAddr).Port + _ = ln.Close() + + // Port should now be available + if !IsPortAvailable("127.0.0.1", port) { + t.Error("port should be available after closing listener") + } + }) + + t.Run("unavailable port", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func() { _ = ln.Close() }() + + port := ln.Addr().(*net.TCPAddr).Port + if IsPortAvailable("127.0.0.1", port) { + t.Error("port should not be available while listener is active") + } + }) + + t.Run("default host", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func() { _ = ln.Close() }() + + port := ln.Addr().(*net.TCPAddr).Port + if IsPortAvailable("", port) { + t.Error("port should not be available (empty host defaults to 127.0.0.1)") + } + }) +} + +func TestIsPortInUse(t *testing.T) { + tests := []struct { + errMsg string + want bool + }{ + {"listen tcp 127.0.0.1:8080: bind: address already in use", true}, + {"address already in use", true}, + {"Only one usage of each socket address", true}, + {"connection refused", false}, + {"", false}, + } + + for _, tt := range tests { + var err error + if tt.errMsg != "" { + err = &net.OpError{Op: "listen", Err: &net.AddrError{Err: tt.errMsg, Addr: "127.0.0.1:8080"}} + } + got := isPortInUse(err) + if got != tt.want { + t.Errorf("isPortInUse(%q) = %v, want %v", tt.errMsg, got, tt.want) + } + } + + // nil error + if isPortInUse(nil) { + t.Error("isPortInUse(nil) should return false") + } +} + +func TestProxy_PortInUse_ReturnsCorrectErrorCode(t *testing.T) { + // Bind to a port + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func() { _ = ln.Close() }() + + port := ln.Addr().(*net.TCPAddr).Port + + dialer := newMockSSHClient(t, "127.0.0.1:18500") + defer func() { _ = dialer.Close() }() + + ctx := context.Background() + _, _, xe := Start(ctx, Options{ + LocalHost: "127.0.0.1", + LocalPort: port, + RemoteHost: "127.0.0.1", + RemotePort: 18500, + Dialer: dialer, + }) + + if xe == nil { + t.Fatal("expected error") + } + if xe.Code != errors.CodePortInUse { + t.Errorf("expected CodePortInUse, got %s", xe.Code) + } + if xe.Details == nil { + t.Fatal("expected details") + } + if xe.Details["port"] != port { + t.Errorf("expected port=%d in details, got %v", port, xe.Details["port"]) + } +} diff --git a/tests/e2e/config_test.go b/tests/e2e/config_test.go new file mode 100644 index 0000000..846c0b9 --- /dev/null +++ b/tests/e2e/config_test.go @@ -0,0 +1,586 @@ +//go:build e2e + +package e2e + +import ( + "encoding/json" + "net" + "os" + "path/filepath" + "strconv" + "strings" + "testing" +) + +// ============================================================================ +// xsql config init Tests +// ============================================================================ + +func TestConfigInit_CreatesFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + + stdout, _, exitCode := runXSQL(t, "config", "init", "--path", path, "--format", "json") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d; output: %s", exitCode, stdout) + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v\noutput: %s", err, stdout) + } + + if !resp.OK { + t.Error("expected ok=true") + } + + if resp.Data == nil { + t.Fatal("expected data in response") + } + + data, ok := resp.Data.(map[string]any) + if !ok { + t.Fatal("expected data to be a map") + } + + if data["config_path"] != path { + t.Errorf("expected config_path=%s, got %v", path, data["config_path"]) + } + + // Verify file exists + if _, err := os.Stat(path); err != nil { + t.Errorf("config file should exist: %v", err) + } +} + +func TestConfigInit_FileExists(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("test"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "init", "--path", path, "--format", "json") + + if exitCode == 0 { + t.Error("expected non-zero exit code when file exists") + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if resp.OK { + t.Error("expected ok=false") + } +} + +func TestConfigInit_TableFormat(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + + stdout, _, exitCode := runXSQL(t, "config", "init", "--path", path, "--format", "table") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d", exitCode) + } + + // Should not be JSON + if strings.HasPrefix(strings.TrimSpace(stdout), "{") { + t.Error("table format should not output JSON") + } +} + +// ============================================================================ +// xsql config set Tests +// ============================================================================ + +func TestConfigSet_ProfileHost(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "set", "profile.dev.host", "localhost", + "--config", path, "--format", "json") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d; output: %s", exitCode, stdout) + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if !resp.OK { + t.Error("expected ok=true") + } + + data, ok := resp.Data.(map[string]any) + if !ok { + t.Fatal("expected data to be a map") + } + if data["key"] != "profile.dev.host" { + t.Errorf("expected key=profile.dev.host, got %v", data["key"]) + } + if data["value"] != "localhost" { + t.Errorf("expected value=localhost, got %v", data["value"]) + } + + // Verify config was updated + content, _ := os.ReadFile(path) + if !strings.Contains(string(content), "localhost") { + t.Error("config should contain 'localhost'") + } +} + +func TestConfigSet_ProfilePort(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "set", "profile.dev.port", "3306", + "--config", path, "--format", "json") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d; output: %s", exitCode, stdout) + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if !resp.OK { + t.Error("expected ok=true") + } +} + +func TestConfigSet_ProfileLocalPort(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "set", "profile.prod.local_port", "13306", + "--config", path, "--format", "json") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d; output: %s", exitCode, stdout) + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if !resp.OK { + t.Error("expected ok=true") + } +} + +func TestConfigSet_SSHProxy(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "set", "ssh_proxy.bastion.host", "bastion.example.com", + "--config", path, "--format", "json") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d; output: %s", exitCode, stdout) + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if !resp.OK { + t.Error("expected ok=true") + } +} + +func TestConfigSet_InvalidKey(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "set", "badkey", "value", + "--config", path, "--format", "json") + + if exitCode == 0 { + t.Error("expected non-zero exit code for invalid key") + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if resp.OK { + t.Error("expected ok=false") + } + if resp.Error == nil { + t.Fatal("expected error") + } + if resp.Error.Code != "XSQL_CFG_INVALID" { + t.Errorf("expected XSQL_CFG_INVALID, got %s", resp.Error.Code) + } +} + +func TestConfigSet_InvalidPortValue(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "set", "profile.dev.port", "abc", + "--config", path, "--format", "json") + + if exitCode == 0 { + t.Error("expected non-zero exit code for invalid port") + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if resp.OK { + t.Error("expected ok=false") + } +} + +func TestConfigSet_UnknownField(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "set", "profile.dev.nonexistent", "value", + "--config", path, "--format", "json") + + if exitCode == 0 { + t.Error("expected non-zero exit code for unknown field") + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + if resp.OK { + t.Error("expected ok=false") + } +} + +func TestConfigSet_MissingArgs(t *testing.T) { + _, _, exitCode := runXSQL(t, "config", "set", "--format", "json") + + if exitCode == 0 { + t.Error("expected non-zero exit code for missing args") + } +} + +func TestConfigSet_TableFormat(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + stdout, _, exitCode := runXSQL(t, "config", "set", "profile.dev.host", "localhost", + "--config", path, "--format", "table") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d", exitCode) + } + + // Should not be JSON + if strings.HasPrefix(strings.TrimSpace(stdout), "{") { + t.Error("table format should not output JSON on stdout") + } +} + +// ============================================================================ +// Config Help Tests +// ============================================================================ + +func TestConfig_Help(t *testing.T) { + stdout, _, exitCode := runXSQL(t, "config", "--help") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d", exitCode) + } + + if !strings.Contains(stdout, "config") { + t.Error("help output should contain 'config'") + } + if !strings.Contains(stdout, "init") { + t.Error("help output should contain 'init'") + } + if !strings.Contains(stdout, "set") { + t.Error("help output should contain 'set'") + } +} + +func TestConfigInit_Help(t *testing.T) { + stdout, _, exitCode := runXSQL(t, "config", "init", "--help") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d", exitCode) + } + + if !strings.Contains(stdout, "init") { + t.Error("help should contain 'init'") + } +} + +func TestConfigSet_Help(t *testing.T) { + stdout, _, exitCode := runXSQL(t, "config", "set", "--help") + + if exitCode != 0 { + t.Fatalf("expected exit code 0, got %d", exitCode) + } + + if !strings.Contains(stdout, "set") { + t.Error("help should contain 'set'") + } +} + +// ============================================================================ +// Proxy with config local_port Tests +// ============================================================================ + +func TestProxy_UsesConfigLocalPort(t *testing.T) { + // Test that proxy reads local_port from profile config + config := createTempConfig(t, `ssh_proxies: + test_ssh: + host: bastion.example.com + user: test +profiles: + test: + db: mysql + host: remote-db.example.com + port: 3306 + local_port: 13306 + ssh_proxy: test_ssh +`) + + // This will fail to connect to SSH, but should read the config + stdout, _, exitCode := runXSQL(t, "-p", "test", "proxy", + "--config", config, "--format", "json") + + // Should fail due to SSH connection issues + if exitCode == 0 { + // If somehow it succeeded, verify local port + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err == nil && resp.OK && resp.Data != nil { + data, ok := resp.Data.(map[string]any) + if ok { + if localAddr, hasAddr := data["local_address"].(string); hasAddr { + if !strings.Contains(localAddr, "13306") { + t.Errorf("expected local_address to contain port 13306, got %s", localAddr) + } + } + } + } + } + + // The output should be valid JSON regardless + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v\noutput: %s", err, stdout) + } +} + +func TestProxy_CLIFlagOverridesConfigLocalPort(t *testing.T) { + // Test that --local-port flag overrides config local_port + config := createTempConfig(t, `ssh_proxies: + test_ssh: + host: bastion.example.com + user: test +profiles: + test: + db: mysql + host: remote-db.example.com + port: 3306 + local_port: 13306 + ssh_proxy: test_ssh +`) + + // Use --local-port to override config + stdout, _, exitCode := runXSQL(t, "-p", "test", "proxy", + "--config", config, "--local-port", "23306", "--format", "json") + + // Will fail at SSH, but should have accepted the CLI port + if exitCode == 0 { + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err == nil && resp.OK && resp.Data != nil { + data, ok := resp.Data.(map[string]any) + if ok { + if localAddr, hasAddr := data["local_address"].(string); hasAddr { + if !strings.Contains(localAddr, "23306") { + t.Errorf("expected local_address to contain port 23306, got %s", localAddr) + } + } + } + } + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v\noutput: %s", err, stdout) + } +} + +func TestProxy_ConfigPortInUse_NonTTY(t *testing.T) { + // When config specifies a port that's in use and we're non-TTY, should get error + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to find available port: %v", err) + } + defer listener.Close() + + port := listener.Addr().(*net.TCPAddr).Port + + config := createTempConfig(t, `ssh_proxies: + test_ssh: + host: bastion.example.com + user: test + identity_file: ~/.ssh/id_ed25519 +profiles: + test: + db: mysql + host: remote-db.example.com + port: 3306 + local_port: `+strconv.Itoa(port)+` + ssh_proxy: test_ssh +`) + + stdout, _, exitCode := runXSQL(t, "-p", "test", "proxy", + "--config", config, "--format", "json") + + // Should fail because port is in use (non-TTY mode) + if exitCode == 0 { + t.Error("expected non-zero exit code when config port is in use") + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v\noutput: %s", err, stdout) + } + if resp.OK { + t.Error("expected ok=false") + } + if resp.Error != nil && resp.Error.Code != "XSQL_PORT_IN_USE" { + // May also be SSH error if port resolution happens differently + t.Logf("error code: %s (may be port or ssh error)", resp.Error.Code) + } +} + +func TestProxy_CLIFlagPortInUse(t *testing.T) { + // When CLI flag specifies a port that's in use, should fail directly + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to find available port: %v", err) + } + defer listener.Close() + + port := listener.Addr().(*net.TCPAddr).Port + + config := createTempConfig(t, `ssh_proxies: + test_ssh: + host: bastion.example.com + user: test +profiles: + test: + db: mysql + host: remote-db.example.com + port: 3306 + ssh_proxy: test_ssh +`) + + stdout, _, exitCode := runXSQL(t, "-p", "test", "proxy", + "--config", config, "--local-port", strconv.Itoa(port), "--format", "json") + + // Should fail + if exitCode == 0 { + t.Error("expected non-zero exit code when CLI port is in use") + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v\noutput: %s", err, stdout) + } + if resp.OK { + t.Error("expected ok=false") + } +} + +// ============================================================================ +// Config set + profile show integration +// ============================================================================ + +func TestConfigSet_ThenProfileShow(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "xsql.yaml") + if err := os.WriteFile(path, []byte("profiles: {}\nssh_proxies: {}\n"), 0600); err != nil { + t.Fatal(err) + } + + // Set profile fields + fields := []struct { + key, value string + }{ + {"profile.dev.db", "mysql"}, + {"profile.dev.host", "localhost"}, + {"profile.dev.port", "3306"}, + {"profile.dev.user", "root"}, + {"profile.dev.database", "testdb"}, + {"profile.dev.local_port", "13306"}, + } + + for _, f := range fields { + _, _, exitCode := runXSQL(t, "config", "set", f.key, f.value, + "--config", path, "--format", "json") + if exitCode != 0 { + t.Fatalf("config set %s=%s failed", f.key, f.value) + } + } + + // Show the profile + stdout, _, exitCode := runXSQL(t, "-p", "dev", "profile", "show", "dev", + "--config", path, "--format", "json") + + if exitCode != 0 { + t.Fatalf("profile show failed; output: %s", stdout) + } + + var resp Response + if err := json.Unmarshal([]byte(stdout), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if !resp.OK { + t.Error("expected ok=true") + } + + data, ok := resp.Data.(map[string]any) + if !ok { + t.Fatal("expected data to be a map") + } + + if data["db"] != "mysql" { + t.Errorf("expected db=mysql, got %v", data["db"]) + } + if data["host"] != "localhost" { + t.Errorf("expected host=localhost, got %v", data["host"]) + } +}