Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions internal/cli/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,16 @@ func TestDefaultGlobalPreloadKeepsExistingProcessEnv(t *testing.T) {
}
}

func TestDefaultGlobalPreloadReturnsContextError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

err := defaultGlobalPreload(ctx)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context canceled, got %v", err)
}
}

func TestWriteURLDispatchSuccessOutput(t *testing.T) {
var buffer bytes.Buffer
err := writeURLDispatchSuccessOutput(&buffer, urlscheme.DispatchResult{
Expand Down
5 changes: 5 additions & 0 deletions internal/config/env_platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ package config
func PersistUserEnvVar(key string, value string) error {
return nil
}

// DeleteUserEnvVar 删除用户级环境变量;非 Windows 平台当前无需额外处理。
func DeleteUserEnvVar(key string) error {
return nil
}
17 changes: 17 additions & 0 deletions internal/config/env_platform_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//go:build !windows

package config

import "testing"

func TestPersistUserEnvVarNoopOnNonWindows(t *testing.T) {
if err := PersistUserEnvVar("NEOCODE_TEST_KEY", "value"); err != nil {
t.Fatalf("PersistUserEnvVar() error = %v", err)
}
}

func TestDeleteUserEnvVarNoopOnNonWindows(t *testing.T) {
if err := DeleteUserEnvVar("NEOCODE_TEST_KEY"); err != nil {
t.Fatalf("DeleteUserEnvVar() error = %v", err)
}
}
28 changes: 28 additions & 0 deletions internal/config/env_platform_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,31 @@ func PersistUserEnvVar(key string, value string) error {
}
return nil
}

// DeleteUserEnvVar 删除 Windows 用户级环境变量,不存在时视为成功。
func DeleteUserEnvVar(key string) error {
normalizedKey := strings.TrimSpace(key)
if normalizedKey == "" {
return errors.New("config: env key is empty")
}
if strings.ContainsAny(normalizedKey, " \t\r\n=") {
return fmt.Errorf("config: env key %q is invalid", normalizedKey)
}

envKey, err := registry.OpenKey(registry.CURRENT_USER, windowsUserEnvironmentKey, registry.SET_VALUE)
if err != nil {
if errors.Is(err, registry.ErrNotExist) {
return nil
}
return fmt.Errorf("config: open windows user env: %w", err)
}
defer envKey.Close()

if err := envKey.DeleteValue(normalizedKey); err != nil {
if errors.Is(err, registry.ErrNotExist) {
return nil
}
return fmt.Errorf("config: delete windows user env %q: %w", normalizedKey, err)
}
return nil
}
40 changes: 40 additions & 0 deletions internal/config/envfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,46 @@ func LoadPersistedEnv(baseDir string) error {
return nil
}

// RemovePersistedEnvVar 从持久化 .env 文件中删除指定键;文件不存在时视为成功。
func RemovePersistedEnvVar(baseDir string, key string) error {
normalizedKey := strings.TrimSpace(key)
if normalizedKey == "" {
return errors.New("config: env key is empty")
}
if strings.ContainsAny(normalizedKey, " \t\r\n=") {
return fmt.Errorf("config: env key %q is invalid", normalizedKey)
}

envPath := EnvFilePath(baseDir)
data, err := os.ReadFile(envPath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("config: read env file: %w", err)
}

lines := strings.Split(strings.ReplaceAll(string(data), "\r\n", "\n"), "\n")
filtered := make([]string, 0, len(lines))
for _, line := range lines {
currentKey, _, ok := parseEnvAssignment(line)
if ok && currentKey == normalizedKey {
continue
}
filtered = append(filtered, line)
}

content := strings.Join(filtered, "\n")
content = strings.TrimRight(content, "\n")
if content != "" {
content += "\n"
}
if err := os.WriteFile(envPath, []byte(content), 0o600); err != nil {
return fmt.Errorf("config: write env file: %w", err)
}
return nil
}

func parseEnvAssignment(line string) (string, string, bool) {
trimmed := strings.TrimSpace(line)
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
Expand Down
90 changes: 90 additions & 0 deletions internal/config/envfile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,96 @@ func TestPersistEnvVarRejectsInvalidInput(t *testing.T) {
}
}

func TestRemovePersistedEnvVarRemovesEntryOnly(t *testing.T) {
baseDir := t.TempDir()
path := EnvFilePath(baseDir)
content := "KEEP=1\nREMOVE=2\nKEEP_AGAIN=3\n"
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
t.Fatalf("MkdirAll() error = %v", err)
}
if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}

if err := RemovePersistedEnvVar(baseDir, "REMOVE"); err != nil {
t.Fatalf("RemovePersistedEnvVar() error = %v", err)
}

data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
got := string(data)
if strings.Contains(got, "REMOVE=2") {
t.Fatalf("expected key to be removed, got %q", got)
}
if !strings.Contains(got, "KEEP=1") || !strings.Contains(got, "KEEP_AGAIN=3") {
t.Fatalf("expected other lines preserved, got %q", got)
}
}

func TestRemovePersistedEnvVarHandlesMissingFileAndInvalidKey(t *testing.T) {
baseDir := t.TempDir()

if err := RemovePersistedEnvVar(baseDir, "MISSING"); err != nil {
t.Fatalf("expected missing file to be ignored, got %v", err)
}
if err := RemovePersistedEnvVar(baseDir, " "); err == nil {
t.Fatal("expected empty key error")
}
if err := RemovePersistedEnvVar(baseDir, "BAD KEY"); err == nil {
t.Fatal("expected invalid key error")
}
}

func TestParseEnvAssignmentAndValueVariants(t *testing.T) {
tests := []struct {
line string
wantKey string
wantVal string
wantOkay bool
}{
{line: "", wantOkay: false},
{line: "# comment", wantOkay: false},
{line: "NO_EQUALS", wantOkay: false},
{line: "export KEY=value", wantKey: "KEY", wantVal: "value", wantOkay: true},
{line: "KEY='single quoted value'", wantKey: "KEY", wantVal: "single quoted value", wantOkay: true},
{line: `KEY="line\tvalue"`, wantKey: "KEY", wantVal: "line\tvalue", wantOkay: true},
{line: `KEY="unterminated`, wantKey: "KEY", wantVal: `"unterminated`, wantOkay: true},
{line: "SPACED = plain ", wantKey: "SPACED", wantVal: "plain", wantOkay: true},
}
for _, tt := range tests {
key, val, ok := parseEnvAssignment(tt.line)
if ok != tt.wantOkay {
t.Fatalf("parseEnvAssignment(%q) ok = %v, want %v", tt.line, ok, tt.wantOkay)
}
if !tt.wantOkay {
continue
}
if key != tt.wantKey || val != tt.wantVal {
t.Fatalf("parseEnvAssignment(%q) = (%q,%q), want (%q,%q)", tt.line, key, val, tt.wantKey, tt.wantVal)
}
}
}

func TestEncodeEnvValue(t *testing.T) {
tests := []struct {
value string
want string
}{
{value: "", want: `""`},
{value: "plain", want: "plain"},
{value: "has space", want: `"has space"`},
{value: `has"quote`, want: `"has\"quote"`},
{value: "has#hash", want: `"has#hash"`},
}
for _, tt := range tests {
if got := encodeEnvValue(tt.value); got != tt.want {
t.Fatalf("encodeEnvValue(%q) = %q, want %q", tt.value, got, tt.want)
}
}
}

func captureEnv(t *testing.T, key string) func() {
t.Helper()
value, exists := os.LookupEnv(key)
Expand Down
66 changes: 66 additions & 0 deletions internal/config/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,72 @@ func TestSaveCustomProviderPersistsDriverSpecificSettings(t *testing.T) {
}
}

func TestSaveCustomProviderRejectsUnsafeProviderName(t *testing.T) {
t.Parallel()

baseDir := t.TempDir()
invalidNames := []string{
"",
" ",
"../escape",
"..",
"team/gateway",
`team\gateway`,
"/tmp/abs",
"中文",
}
for _, name := range invalidNames {
err := SaveCustomProvider(
baseDir,
name,
provider.DriverOpenAICompat,
"https://llm.example.com/v1",
"CUSTOM_API_KEY",
provider.OpenAICompatibleAPIStyleChatCompletions,
"",
"",
)
if err == nil {
t.Fatalf("expected SaveCustomProvider to reject %q", name)
}
}
}

func TestDeleteCustomProviderRejectsUnsafeProviderName(t *testing.T) {
t.Parallel()

baseDir := t.TempDir()
invalidNames := []string{
"",
"../escape",
"team/gateway",
`team\gateway`,
}
for _, name := range invalidNames {
if err := DeleteCustomProvider(baseDir, name); err == nil {
t.Fatalf("expected DeleteCustomProvider to reject %q", name)
}
}
}

func TestDeleteCustomProviderRemovesProviderDir(t *testing.T) {
t.Parallel()

baseDir := t.TempDir()
providerName := "team-gateway"
providerDir := filepath.Join(baseDir, providersDirName, providerName)
if err := os.MkdirAll(providerDir, 0o755); err != nil {
t.Fatalf("MkdirAll() error = %v", err)
}

if err := DeleteCustomProvider(baseDir, providerName); err != nil {
t.Fatalf("DeleteCustomProvider() error = %v", err)
}
if _, err := os.Stat(providerDir); !os.IsNotExist(err) {
t.Fatalf("expected provider dir to be removed, stat err = %v", err)
}
}

func TestLoaderLoadsUnknownCustomProviderDriverUsingTopLevelBaseURL(t *testing.T) {
t.Parallel()

Expand Down
32 changes: 32 additions & 0 deletions internal/config/provider_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"bytes"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -218,6 +219,10 @@ func SaveCustomProvider(
deploymentMode string,
apiVersion string,
) error {
if err := validateCustomProviderName(name); err != nil {
return err
}

providersDir := filepath.Join(baseDir, providersDirName, name)
if err := os.MkdirAll(providersDir, 0o755); err != nil {
return fmt.Errorf("config: create provider dir: %w", err)
Expand Down Expand Up @@ -265,6 +270,33 @@ func SaveCustomProvider(

// DeleteCustomProvider 删除自定义 provider。
func DeleteCustomProvider(baseDir string, name string) error {
if err := validateCustomProviderName(name); err != nil {
return err
}
providersDir := filepath.Join(baseDir, providersDirName, name)
return os.RemoveAll(providersDir)
}

// validateCustomProviderName 校验 provider 名称,拒绝路径穿越和分隔符语义。
func validateCustomProviderName(name string) error {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return errors.New("config: provider name is empty")
}
if trimmed == "." || trimmed == ".." {
return fmt.Errorf("config: provider name %q is invalid", name)
}
if strings.ContainsAny(trimmed, `/\`) {
return fmt.Errorf("config: provider name %q is invalid", name)
}
if filepath.IsAbs(trimmed) {
return fmt.Errorf("config: provider name %q is invalid", name)
}
for _, r := range trimmed {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' {
continue
}
return fmt.Errorf("config: provider name %q contains unsupported character %q", name, string(r))
}
return nil
}
Loading
Loading