From ca798c38d1bbf8f637f97056a57371de6c313c72 Mon Sep 17 00:00:00 2001 From: xgopilot Date: Thu, 16 Apr 2026 06:13:48 +0000 Subject: [PATCH] fix(provider): mask api key render and harden provider add rollback - mask provider add API key in TUI view - validate custom provider names to block path traversal - rollback persisted/process/user env and provider files on add failure - add tests to raise requested file coverage above 80% Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: creatang <165447160+creatang@users.noreply.github.com> --- internal/cli/root_test.go | 10 ++ internal/config/env_platform.go | 5 + internal/config/env_platform_test.go | 17 +++ internal/config/env_platform_windows.go | 28 +++++ internal/config/envfile.go | 40 +++++++ internal/config/envfile_test.go | 90 +++++++++++++++ internal/config/loader_test.go | 66 +++++++++++ internal/config/provider_loader.go | 32 ++++++ internal/tui/core/app/update.go | 98 +++++++++++++--- internal/tui/core/app/update_test.go | 123 ++++++++++++++++++++ internal/tui/core/app/view.go | 10 +- internal/tui/core/app/view_test.go | 143 ++++++++++++++++++++++++ 12 files changed, 646 insertions(+), 16 deletions(-) create mode 100644 internal/config/env_platform_test.go diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 182050fc..d0716057 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -1006,6 +1006,16 @@ func TestDefaultGlobalPreloadKeepsExistingProcessEnv(t *testing.T) { } } +func TestDefaultGlobalPreloadReturnsContextError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := defaultGlobalPreload(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } +} + func TestWriteURLDispatchSuccessOutput(t *testing.T) { var buffer bytes.Buffer err := writeURLDispatchSuccessOutput(&buffer, urlscheme.DispatchResult{ diff --git a/internal/config/env_platform.go b/internal/config/env_platform.go index 07afef6e..aa8ee3b2 100644 --- a/internal/config/env_platform.go +++ b/internal/config/env_platform.go @@ -7,3 +7,8 @@ package config func PersistUserEnvVar(key string, value string) error { return nil } + +// DeleteUserEnvVar 删除用户级环境变量;非 Windows 平台当前无需额外处理。 +func DeleteUserEnvVar(key string) error { + return nil +} diff --git a/internal/config/env_platform_test.go b/internal/config/env_platform_test.go new file mode 100644 index 00000000..d141c14c --- /dev/null +++ b/internal/config/env_platform_test.go @@ -0,0 +1,17 @@ +//go:build !windows + +package config + +import "testing" + +func TestPersistUserEnvVarNoopOnNonWindows(t *testing.T) { + if err := PersistUserEnvVar("NEOCODE_TEST_KEY", "value"); err != nil { + t.Fatalf("PersistUserEnvVar() error = %v", err) + } +} + +func TestDeleteUserEnvVarNoopOnNonWindows(t *testing.T) { + if err := DeleteUserEnvVar("NEOCODE_TEST_KEY"); err != nil { + t.Fatalf("DeleteUserEnvVar() error = %v", err) + } +} diff --git a/internal/config/env_platform_windows.go b/internal/config/env_platform_windows.go index f98af530..21edc605 100644 --- a/internal/config/env_platform_windows.go +++ b/internal/config/env_platform_windows.go @@ -36,3 +36,31 @@ func PersistUserEnvVar(key string, value string) error { } return nil } + +// DeleteUserEnvVar 删除 Windows 用户级环境变量,不存在时视为成功。 +func DeleteUserEnvVar(key string) error { + normalizedKey := strings.TrimSpace(key) + if normalizedKey == "" { + return errors.New("config: env key is empty") + } + if strings.ContainsAny(normalizedKey, " \t\r\n=") { + return fmt.Errorf("config: env key %q is invalid", normalizedKey) + } + + envKey, err := registry.OpenKey(registry.CURRENT_USER, windowsUserEnvironmentKey, registry.SET_VALUE) + if err != nil { + if errors.Is(err, registry.ErrNotExist) { + return nil + } + return fmt.Errorf("config: open windows user env: %w", err) + } + defer envKey.Close() + + if err := envKey.DeleteValue(normalizedKey); err != nil { + if errors.Is(err, registry.ErrNotExist) { + return nil + } + return fmt.Errorf("config: delete windows user env %q: %w", normalizedKey, err) + } + return nil +} diff --git a/internal/config/envfile.go b/internal/config/envfile.go index 93fd241e..b92f9cc7 100644 --- a/internal/config/envfile.go +++ b/internal/config/envfile.go @@ -101,6 +101,46 @@ func LoadPersistedEnv(baseDir string) error { return nil } +// RemovePersistedEnvVar 从持久化 .env 文件中删除指定键;文件不存在时视为成功。 +func RemovePersistedEnvVar(baseDir string, key string) error { + normalizedKey := strings.TrimSpace(key) + if normalizedKey == "" { + return errors.New("config: env key is empty") + } + if strings.ContainsAny(normalizedKey, " \t\r\n=") { + return fmt.Errorf("config: env key %q is invalid", normalizedKey) + } + + envPath := EnvFilePath(baseDir) + data, err := os.ReadFile(envPath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("config: read env file: %w", err) + } + + lines := strings.Split(strings.ReplaceAll(string(data), "\r\n", "\n"), "\n") + filtered := make([]string, 0, len(lines)) + for _, line := range lines { + currentKey, _, ok := parseEnvAssignment(line) + if ok && currentKey == normalizedKey { + continue + } + filtered = append(filtered, line) + } + + content := strings.Join(filtered, "\n") + content = strings.TrimRight(content, "\n") + if content != "" { + content += "\n" + } + if err := os.WriteFile(envPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("config: write env file: %w", err) + } + return nil +} + func parseEnvAssignment(line string) (string, string, bool) { trimmed := strings.TrimSpace(line) if trimmed == "" || strings.HasPrefix(trimmed, "#") { diff --git a/internal/config/envfile_test.go b/internal/config/envfile_test.go index b49efe37..4c6110ea 100644 --- a/internal/config/envfile_test.go +++ b/internal/config/envfile_test.go @@ -107,6 +107,96 @@ func TestPersistEnvVarRejectsInvalidInput(t *testing.T) { } } +func TestRemovePersistedEnvVarRemovesEntryOnly(t *testing.T) { + baseDir := t.TempDir() + path := EnvFilePath(baseDir) + content := "KEEP=1\nREMOVE=2\nKEEP_AGAIN=3\n" + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + if err := RemovePersistedEnvVar(baseDir, "REMOVE"); err != nil { + t.Fatalf("RemovePersistedEnvVar() error = %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + got := string(data) + if strings.Contains(got, "REMOVE=2") { + t.Fatalf("expected key to be removed, got %q", got) + } + if !strings.Contains(got, "KEEP=1") || !strings.Contains(got, "KEEP_AGAIN=3") { + t.Fatalf("expected other lines preserved, got %q", got) + } +} + +func TestRemovePersistedEnvVarHandlesMissingFileAndInvalidKey(t *testing.T) { + baseDir := t.TempDir() + + if err := RemovePersistedEnvVar(baseDir, "MISSING"); err != nil { + t.Fatalf("expected missing file to be ignored, got %v", err) + } + if err := RemovePersistedEnvVar(baseDir, " "); err == nil { + t.Fatal("expected empty key error") + } + if err := RemovePersistedEnvVar(baseDir, "BAD KEY"); err == nil { + t.Fatal("expected invalid key error") + } +} + +func TestParseEnvAssignmentAndValueVariants(t *testing.T) { + tests := []struct { + line string + wantKey string + wantVal string + wantOkay bool + }{ + {line: "", wantOkay: false}, + {line: "# comment", wantOkay: false}, + {line: "NO_EQUALS", wantOkay: false}, + {line: "export KEY=value", wantKey: "KEY", wantVal: "value", wantOkay: true}, + {line: "KEY='single quoted value'", wantKey: "KEY", wantVal: "single quoted value", wantOkay: true}, + {line: `KEY="line\tvalue"`, wantKey: "KEY", wantVal: "line\tvalue", wantOkay: true}, + {line: `KEY="unterminated`, wantKey: "KEY", wantVal: `"unterminated`, wantOkay: true}, + {line: "SPACED = plain ", wantKey: "SPACED", wantVal: "plain", wantOkay: true}, + } + for _, tt := range tests { + key, val, ok := parseEnvAssignment(tt.line) + if ok != tt.wantOkay { + t.Fatalf("parseEnvAssignment(%q) ok = %v, want %v", tt.line, ok, tt.wantOkay) + } + if !tt.wantOkay { + continue + } + if key != tt.wantKey || val != tt.wantVal { + t.Fatalf("parseEnvAssignment(%q) = (%q,%q), want (%q,%q)", tt.line, key, val, tt.wantKey, tt.wantVal) + } + } +} + +func TestEncodeEnvValue(t *testing.T) { + tests := []struct { + value string + want string + }{ + {value: "", want: `""`}, + {value: "plain", want: "plain"}, + {value: "has space", want: `"has space"`}, + {value: `has"quote`, want: `"has\"quote"`}, + {value: "has#hash", want: `"has#hash"`}, + } + for _, tt := range tests { + if got := encodeEnvValue(tt.value); got != tt.want { + t.Fatalf("encodeEnvValue(%q) = %q, want %q", tt.value, got, tt.want) + } + } +} + func captureEnv(t *testing.T, key string) func() { t.Helper() value, exists := os.LookupEnv(key) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index 7526d763..e5c0cca3 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -976,6 +976,72 @@ func TestSaveCustomProviderPersistsDriverSpecificSettings(t *testing.T) { } } +func TestSaveCustomProviderRejectsUnsafeProviderName(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + invalidNames := []string{ + "", + " ", + "../escape", + "..", + "team/gateway", + `team\gateway`, + "/tmp/abs", + "中文", + } + for _, name := range invalidNames { + err := SaveCustomProvider( + baseDir, + name, + provider.DriverOpenAICompat, + "https://llm.example.com/v1", + "CUSTOM_API_KEY", + provider.OpenAICompatibleAPIStyleChatCompletions, + "", + "", + ) + if err == nil { + t.Fatalf("expected SaveCustomProvider to reject %q", name) + } + } +} + +func TestDeleteCustomProviderRejectsUnsafeProviderName(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + invalidNames := []string{ + "", + "../escape", + "team/gateway", + `team\gateway`, + } + for _, name := range invalidNames { + if err := DeleteCustomProvider(baseDir, name); err == nil { + t.Fatalf("expected DeleteCustomProvider to reject %q", name) + } + } +} + +func TestDeleteCustomProviderRemovesProviderDir(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + providerName := "team-gateway" + providerDir := filepath.Join(baseDir, providersDirName, providerName) + if err := os.MkdirAll(providerDir, 0o755); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + + if err := DeleteCustomProvider(baseDir, providerName); err != nil { + t.Fatalf("DeleteCustomProvider() error = %v", err) + } + if _, err := os.Stat(providerDir); !os.IsNotExist(err) { + t.Fatalf("expected provider dir to be removed, stat err = %v", err) + } +} + func TestLoaderLoadsUnknownCustomProviderDriverUsingTopLevelBaseURL(t *testing.T) { t.Parallel() diff --git a/internal/config/provider_loader.go b/internal/config/provider_loader.go index da15f0e9..dd51cb7d 100644 --- a/internal/config/provider_loader.go +++ b/internal/config/provider_loader.go @@ -2,6 +2,7 @@ package config import ( "bytes" + "errors" "fmt" "os" "path/filepath" @@ -218,6 +219,10 @@ func SaveCustomProvider( deploymentMode string, apiVersion string, ) error { + if err := validateCustomProviderName(name); err != nil { + return err + } + providersDir := filepath.Join(baseDir, providersDirName, name) if err := os.MkdirAll(providersDir, 0o755); err != nil { return fmt.Errorf("config: create provider dir: %w", err) @@ -265,6 +270,33 @@ func SaveCustomProvider( // DeleteCustomProvider 删除自定义 provider。 func DeleteCustomProvider(baseDir string, name string) error { + if err := validateCustomProviderName(name); err != nil { + return err + } providersDir := filepath.Join(baseDir, providersDirName, name) return os.RemoveAll(providersDir) } + +// validateCustomProviderName 校验 provider 名称,拒绝路径穿越和分隔符语义。 +func validateCustomProviderName(name string) error { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return errors.New("config: provider name is empty") + } + if trimmed == "." || trimmed == ".." { + return fmt.Errorf("config: provider name %q is invalid", name) + } + if strings.ContainsAny(trimmed, `/\`) { + return fmt.Errorf("config: provider name %q is invalid", name) + } + if filepath.IsAbs(trimmed) { + return fmt.Errorf("config: provider name %q is invalid", name) + } + for _, r := range trimmed { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' { + continue + } + return fmt.Errorf("config: provider name %q contains unsupported character %q", name, string(r)) + } + return nil +} diff --git a/internal/tui/core/app/update.go b/internal/tui/core/app/update.go index db720cea..aac312ca 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -45,6 +45,7 @@ const providerAddSelectTimeout = 10 * time.Second var panelOrder = []panel{panelTranscript, panelActivity, panelInput} var persistProviderUserEnvVar = config.PersistUserEnvVar +var deleteProviderUserEnvVar = config.DeleteUserEnvVar func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd @@ -2367,6 +2368,28 @@ func (a *App) runProviderAddFlow(request providerAddRequest) tea.Cmd { return func() tea.Msg { apiKeyEnv := providerAddAPIKeyEnv(request.Name) + previousEnvValue, hadPreviousEnv := os.LookupEnv(apiKeyEnv) + + rollback := func(processEnvApplied bool, userEnvPersisted bool, envPersisted bool, providerSaved bool, originalErr error) error { + return rollbackProviderAddSideEffects( + baseDir, + request.Name, + apiKeyEnv, + processEnvApplied, + userEnvPersisted, + envPersisted, + hadPreviousEnv, + previousEnvValue, + providerSaved, + originalErr, + ) + } + + providerSaved := false + envPersisted := false + userEnvPersisted := false + processEnvApplied := false + if err := config.SaveCustomProvider( baseDir, request.Name, @@ -2382,19 +2405,17 @@ func (a *App) runProviderAddFlow(request providerAddRequest) tea.Cmd { Error: sanitizeProviderAddError(fmt.Errorf("save provider config: %w", err), request.APIKey, baseDir), } } + providerSaved = true if err := config.PersistEnvVar(baseDir, apiKeyEnv, request.APIKey); err != nil { - if rollbackErr := config.DeleteCustomProvider(baseDir, request.Name); rollbackErr != nil { - err = fmt.Errorf("%w (rollback failed: %v)", err, rollbackErr) - } + err = rollback(processEnvApplied, userEnvPersisted, envPersisted, providerSaved, err) return providerAddResultMsg{ Name: request.Name, Error: sanitizeProviderAddError(fmt.Errorf("persist api key: %w", err), request.APIKey, baseDir), } } + envPersisted = true if err := persistProviderUserEnvVar(apiKeyEnv, request.APIKey); err != nil { - if rollbackErr := config.DeleteCustomProvider(baseDir, request.Name); rollbackErr != nil { - err = fmt.Errorf("%w (rollback failed: %v)", err, rollbackErr) - } + err = rollback(processEnvApplied, userEnvPersisted, envPersisted, providerSaved, err) return providerAddResultMsg{ Name: request.Name, Error: sanitizeProviderAddError( @@ -2404,19 +2425,17 @@ func (a *App) runProviderAddFlow(request providerAddRequest) tea.Cmd { ), } } + userEnvPersisted = true if err := os.Setenv(apiKeyEnv, request.APIKey); err != nil { - if rollbackErr := config.DeleteCustomProvider(baseDir, request.Name); rollbackErr != nil { - err = fmt.Errorf("%w (rollback failed: %v)", err, rollbackErr) - } + err = rollback(processEnvApplied, userEnvPersisted, envPersisted, providerSaved, err) return providerAddResultMsg{ Name: request.Name, Error: sanitizeProviderAddError(fmt.Errorf("apply api key env: %w", err), request.APIKey, baseDir), } } + processEnvApplied = true if _, err := configManager.Reload(context.Background()); err != nil { - if rollbackErr := config.DeleteCustomProvider(baseDir, request.Name); rollbackErr != nil { - err = fmt.Errorf("%w (rollback failed: %v)", err, rollbackErr) - } + err = rollback(processEnvApplied, userEnvPersisted, envPersisted, providerSaved, err) return providerAddResultMsg{ Name: request.Name, Error: sanitizeProviderAddError(fmt.Errorf("reload config snapshot: %w", err), request.APIKey, baseDir), @@ -2428,9 +2447,7 @@ func (a *App) runProviderAddFlow(request providerAddRequest) tea.Cmd { selection, err := providerSvc.SelectProvider(ctx, request.Name) if err != nil { - if rollbackErr := config.DeleteCustomProvider(baseDir, request.Name); rollbackErr != nil { - err = fmt.Errorf("%w (rollback failed: %v)", err, rollbackErr) - } + err = rollback(processEnvApplied, userEnvPersisted, envPersisted, providerSaved, err) if errors.Is(err, context.DeadlineExceeded) { err = fmt.Errorf( "model discovery timed out after %s; check base URL, API key, and network connectivity", @@ -2450,6 +2467,57 @@ func (a *App) runProviderAddFlow(request providerAddRequest) tea.Cmd { } } +// rollbackProviderAddSideEffects 回滚 provider add 过程中已落地的副作用,避免失败后残留配置与密钥。 +func rollbackProviderAddSideEffects( + baseDir string, + providerName string, + apiKeyEnv string, + processEnvApplied bool, + userEnvPersisted bool, + envPersisted bool, + hadPreviousEnv bool, + previousEnvValue string, + providerSaved bool, + originalErr error, +) error { + rollbackErrs := make([]error, 0, 4) + + if processEnvApplied { + if hadPreviousEnv { + if err := os.Setenv(apiKeyEnv, previousEnvValue); err != nil { + rollbackErrs = append(rollbackErrs, fmt.Errorf("restore process env: %w", err)) + } + } else { + if err := os.Unsetenv(apiKeyEnv); err != nil { + rollbackErrs = append(rollbackErrs, fmt.Errorf("unset process env: %w", err)) + } + } + } + + if userEnvPersisted { + if err := deleteProviderUserEnvVar(apiKeyEnv); err != nil { + rollbackErrs = append(rollbackErrs, fmt.Errorf("delete user env: %w", err)) + } + } + + if envPersisted { + if err := config.RemovePersistedEnvVar(baseDir, apiKeyEnv); err != nil { + rollbackErrs = append(rollbackErrs, fmt.Errorf("remove persisted env: %w", err)) + } + } + + if providerSaved { + if err := config.DeleteCustomProvider(baseDir, providerName); err != nil { + rollbackErrs = append(rollbackErrs, fmt.Errorf("delete provider config: %w", err)) + } + } + + if len(rollbackErrs) == 0 { + return originalErr + } + return fmt.Errorf("%w (rollback failed: %v)", originalErr, errors.Join(rollbackErrs...)) +} + func (a *App) handleProviderAddResultMsg(msg providerAddResultMsg) { if a.providerAddForm == nil { return diff --git a/internal/tui/core/app/update_test.go b/internal/tui/core/app/update_test.go index e3e4faf1..134afe1b 100644 --- a/internal/tui/core/app/update_test.go +++ b/internal/tui/core/app/update_test.go @@ -332,6 +332,129 @@ func TestSubmitProviderAddFormRedactsSensitiveError(t *testing.T) { } } +func TestSubmitProviderAddFormRollsBackPersistedStateOnSelectFailure(t *testing.T) { + restorePersistUserEnv := persistProviderUserEnvVar + restoreDeleteUserEnv := deleteProviderUserEnvVar + persistProviderUserEnvVar = func(key string, value string) error { return nil } + deleteProviderUserEnvVar = func(key string) error { return nil } + t.Cleanup(func() { persistProviderUserEnvVar = restorePersistUserEnv }) + t.Cleanup(func() { deleteProviderUserEnvVar = restoreDeleteUserEnv }) + + providerName := "rollback-gateway" + envName := providerAddAPIKeyEnv(providerName) + restoreEnv := captureEnv(t, envName) + defer restoreEnv() + if err := os.Setenv(envName, "previous-value"); err != nil { + t.Fatalf("Setenv() error = %v", err) + } + + service := stubProviderService{ + selectErr: errors.New("select failed"), + } + app, _ := newTestAppWithProviderService(t, service) + app.startProviderAddForm() + app.providerAddForm.Name = providerName + app.providerAddForm.Driver = provider.DriverOpenAICompat + app.providerAddForm.BaseURL = "https://rollback.example.com/v1" + app.providerAddForm.APIKey = "sk-failed-rollback" + app.providerAddForm.APIStyle = provider.OpenAICompatibleAPIStyleChatCompletions + + cmd := app.submitProviderAddForm() + if cmd == nil { + t.Fatalf("expected async command") + } + msg := cmd() + result, ok := msg.(providerAddResultMsg) + if !ok { + t.Fatalf("expected providerAddResultMsg, got %T", msg) + } + if strings.TrimSpace(result.Error) == "" { + t.Fatalf("expected failure result") + } + + if got := os.Getenv(envName); got != "previous-value" { + t.Fatalf("expected process env restored, got %q", got) + } + + envData, readErr := os.ReadFile(config.EnvFilePath(app.configManager.BaseDir())) + if readErr != nil { + t.Fatalf("read env file: %v", readErr) + } + if strings.Contains(string(envData), envName+"=") { + t.Fatalf("expected persisted env key rollback, got %q", string(envData)) + } + + providerPath := filepath.Join(app.configManager.BaseDir(), "providers", providerName) + if _, err := os.Stat(providerPath); !os.IsNotExist(err) { + t.Fatalf("expected provider dir rollback, stat err = %v", err) + } +} + +func TestSubmitProviderAddFormRollsBackOnUserEnvPersistFailure(t *testing.T) { + restorePersistUserEnv := persistProviderUserEnvVar + restoreDeleteUserEnv := deleteProviderUserEnvVar + persistProviderUserEnvVar = func(key string, value string) error { return errors.New("user env failed") } + deleteProviderUserEnvVar = func(key string) error { return nil } + t.Cleanup(func() { persistProviderUserEnvVar = restorePersistUserEnv }) + t.Cleanup(func() { deleteProviderUserEnvVar = restoreDeleteUserEnv }) + + providerName := "user-env-fail" + envName := providerAddAPIKeyEnv(providerName) + restoreEnv := captureEnv(t, envName) + defer restoreEnv() + _ = os.Unsetenv(envName) + + app, _ := newTestApp(t) + app.startProviderAddForm() + app.providerAddForm.Name = providerName + app.providerAddForm.Driver = provider.DriverOpenAICompat + app.providerAddForm.BaseURL = "https://rollback.example.com/v1" + app.providerAddForm.APIKey = "sk-failed" + app.providerAddForm.APIStyle = provider.OpenAICompatibleAPIStyleChatCompletions + + cmd := app.submitProviderAddForm() + if cmd == nil { + t.Fatalf("expected async command") + } + msg := cmd() + result, ok := msg.(providerAddResultMsg) + if !ok { + t.Fatalf("expected providerAddResultMsg, got %T", msg) + } + if strings.TrimSpace(result.Error) == "" { + t.Fatalf("expected failure result") + } + + if got := os.Getenv(envName); got != "" { + t.Fatalf("expected process env to stay unset, got %q", got) + } + + envData, readErr := os.ReadFile(config.EnvFilePath(app.configManager.BaseDir())) + if readErr != nil { + t.Fatalf("read env file: %v", readErr) + } + if strings.Contains(string(envData), envName+"=") { + t.Fatalf("expected persisted env key rollback, got %q", string(envData)) + } + + providerPath := filepath.Join(app.configManager.BaseDir(), "providers", providerName) + if _, err := os.Stat(providerPath); !os.IsNotExist(err) { + t.Fatalf("expected provider dir rollback, stat err = %v", err) + } +} + +func captureEnv(t *testing.T, key string) func() { + t.Helper() + value, exists := os.LookupEnv(key) + return func() { + if exists { + _ = os.Setenv(key, value) + return + } + _ = os.Unsetenv(key) + } +} + func TestAppUpdateBasic(t *testing.T) { app, _ := newTestApp(t) diff --git a/internal/tui/core/app/view.go b/internal/tui/core/app/view.go index ce702f69..2eb3b516 100644 --- a/internal/tui/core/app/view.go +++ b/internal/tui/core/app/view.go @@ -235,7 +235,7 @@ func (a App) renderProviderAddForm() string { case providerAddFieldAPIVersion: fields = append(fields, renderField{label: "API Version", value: a.providerAddForm.APIVersion}) case providerAddFieldAPIKey: - fields = append(fields, renderField{label: "API Key", value: a.providerAddForm.APIKey, required: true}) + fields = append(fields, renderField{label: "API Key", value: maskedSecret(a.providerAddForm.APIKey), required: true}) } } @@ -268,6 +268,14 @@ func (a App) renderProviderAddForm() string { return sb.String() } +// maskedSecret 将敏感输入渲染为固定掩码,避免在终端界面泄露明文。 +func maskedSecret(value string) string { + if strings.TrimSpace(value) == "" { + return "" + } + return "******" +} + func (a App) renderPrompt(width int) string { if a.pendingPermission != nil { box := a.styles.inputBoxFocused diff --git a/internal/tui/core/app/view_test.go b/internal/tui/core/app/view_test.go index 0c807d5a..6f291439 100644 --- a/internal/tui/core/app/view_test.go +++ b/internal/tui/core/app/view_test.go @@ -234,3 +234,146 @@ func TestRenderBody(t *testing.T) { t.Fatalf("expected renderBody output") } } + +func TestMaskedSecret(t *testing.T) { + if got := maskedSecret(""); got != "" { + t.Fatalf("maskedSecret(empty) = %q, want empty", got) + } + if got := maskedSecret(" "); got != "" { + t.Fatalf("maskedSecret(space) = %q, want empty", got) + } + if got := maskedSecret("sk-12345"); got != "******" { + t.Fatalf("maskedSecret(secret) = %q, want ******", got) + } +} + +func TestRenderProviderAddFormMasksAPIKeyAndShowsHints(t *testing.T) { + app, _ := newTestApp(t) + app.startProviderAddForm() + app.providerAddForm.Driver = "openaicompat" + app.providerAddForm.Name = "team-gateway" + app.providerAddForm.APIKey = "sk-secret-98765" + app.providerAddForm.BaseURL = "" + app.providerAddForm.APIStyle = "" + app.providerAddForm.Error = "input invalid" + app.providerAddForm.ErrorIsHard = true + + form := app.renderProviderAddForm() + if strings.Contains(form, "sk-secret-98765") { + t.Fatalf("expected api key to be masked, got %q", form) + } + if !strings.Contains(form, "API Key: ******") { + t.Fatalf("expected masked api key, got %q", form) + } + if !strings.Contains(form, "留空会自动填充默认地址") { + t.Fatalf("expected base url hint, got %q", form) + } + if !strings.Contains(form, "默认 chat_completions") { + t.Fatalf("expected api style hint, got %q", form) + } + if !strings.Contains(form, "[Error] input invalid") { + t.Fatalf("expected hard error label, got %q", form) + } +} + +func TestRenderProviderAddFormPromptLabel(t *testing.T) { + app, _ := newTestApp(t) + app.startProviderAddForm() + app.providerAddForm.Driver = "anthropic" + app.providerAddForm.Error = "continue input" + app.providerAddForm.ErrorIsHard = false + + form := app.renderProviderAddForm() + if !strings.Contains(form, "[Prompt] continue input") { + t.Fatalf("expected prompt label, got %q", form) + } +} + +func TestViewSmallWindowHint(t *testing.T) { + app, _ := newTestApp(t) + app.width = 40 + app.height = 10 + + view := app.View() + if !strings.Contains(view, "Window too small.") { + t.Fatalf("expected small-window hint, got %q", view) + } +} + +func TestViewNormalIncludesHeaderAndBody(t *testing.T) { + app, _ := newTestApp(t) + app.width = 100 + app.height = 30 + app.state.CurrentModel = "test-model" + app.state.StatusText = "running" + app.state.IsAgentRunning = true + app.runProgressKnown = true + app.runProgressValue = 0.42 + app.runProgressLabel = "loading" + app.state.InputText = "hi" + app.input.SetValue("hi") + + view := app.View() + if strings.TrimSpace(view) == "" { + t.Fatalf("expected non-empty view") + } + if !strings.Contains(view, "NeoCode") { + t.Fatalf("expected header text, got %q", view) + } + if !strings.Contains(view, "42% loading") { + t.Fatalf("expected progress header, got %q", view) + } +} + +func TestRenderPanelAndActivityPreview(t *testing.T) { + app, _ := newTestApp(t) + panel := app.renderPanel("Title", "Sub", "Body", 60, 8, true) + if !strings.Contains(panel, "Title") || !strings.Contains(panel, "Body") { + t.Fatalf("expected panel content, got %q", panel) + } + + if got := app.renderActivityPreview(60); got != "" { + t.Fatalf("expected empty activity preview, got %q", got) + } + app.activities = []tuistate.ActivityEntry{{Kind: "tool", Title: "Run", Detail: "Detail"}} + withActivity := app.renderActivityPreview(60) + if !strings.Contains(withActivity, activityTitle) { + t.Fatalf("expected activity panel title, got %q", withActivity) + } +} + +func TestRenderMessageContentWithCopyBranches(t *testing.T) { + app, _ := newTestApp(t) + + app.markdownRenderer = nil + rendered, bindings := app.renderMessageContentWithCopy("hello", 40, app.styles.messageBody, 1) + if len(bindings) != 0 || strings.TrimSpace(rendered) == "" { + t.Fatalf("expected fallback content without bindings, got rendered=%q bindings=%v", rendered, bindings) + } + + app, _ = newTestApp(t) + content := "hello\n```go\nfmt.Println(\"x\")\n```\nworld" + rendered, bindings = app.renderMessageContentWithCopy(content, 60, app.styles.messageBody, 3) + if strings.TrimSpace(rendered) == "" { + t.Fatalf("expected rendered markdown content") + } + if len(bindings) != 1 { + t.Fatalf("expected one copy binding, got %d", len(bindings)) + } + if bindings[0].ID != 3 || !strings.Contains(bindings[0].Code, "fmt.Println") { + t.Fatalf("unexpected binding: %+v", bindings[0]) + } +} + +func TestNormalizeAndTrimHelpers(t *testing.T) { + trimmed := trimRenderedTrailingWhitespace("line1 \nline2\t") + if strings.HasSuffix(trimmed, "\t") || strings.HasSuffix(trimmed, " ") { + t.Fatalf("expected trailing whitespace trimmed, got %q", trimmed) + } + + normalized := normalizeBlockRightEdge("a\nbb", 6) + lines := strings.Split(normalized, "\n") + if len(lines) != 2 { + t.Fatalf("expected two lines, got %q", normalized) + } +}