diff --git a/src/internal/constants/platform.go b/src/internal/constants/platform.go index f385d9e..48cde44 100644 --- a/src/internal/constants/platform.go +++ b/src/internal/constants/platform.go @@ -33,4 +33,5 @@ const ( // File extensions const ( ExtExe = ".exe" + ExtCmd = ".cmd" ) diff --git a/src/internal/shim/manager.go b/src/internal/shim/manager.go index a463a89..dcd8237 100644 --- a/src/internal/shim/manager.go +++ b/src/internal/shim/manager.go @@ -77,9 +77,27 @@ func (m *Manager) CreateShim(shimName string) error { } } + // On Windows, create a companion .cmd wrapper + if runtime.GOOS == constants.OSWindows { + if err := createCmdWrapper(shimName); err != nil { + return fmt.Errorf("failed to create .cmd wrapper for %s: %w", shimName, err) + } + } + return nil } +// createCmdWrapper writes a .cmd file that forwards to the .exe shim +func createCmdWrapper(shimName string) error { + // shimName is the base name (e.g., "python"), ShimPath adds .exe on Windows + // Build the .cmd path by replacing .exe with .cmd in the shim path + exePath := config.ShimPath(shimName) + cmdPath := exePath[:len(exePath)-len(constants.ExtExe)] + constants.ExtCmd + + content := fmt.Sprintf("@echo off\r\n\"%%~dp0%s%s\" %%*\r\n", shimName, constants.ExtExe) + return os.WriteFile(cmdPath, []byte(content), 0644) +} + // CreateShims creates multiple shims at once func (m *Manager) CreateShims(shimNames []string) error { for _, shimName := range shimNames { @@ -98,6 +116,14 @@ func (m *Manager) RemoveShim(shimName string) error { return fmt.Errorf("failed to remove shim %s: %w", shimName, err) } + // On Windows, also remove the companion .cmd wrapper + if runtime.GOOS == constants.OSWindows { + cmdPath := shimPath[:len(shimPath)-len(constants.ExtExe)] + constants.ExtCmd + if err := os.Remove(cmdPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove .cmd wrapper for %s: %w", shimName, err) + } + } + return nil } @@ -118,10 +144,13 @@ func (m *Manager) ListShims() ([]string, error) { for _, entry := range entries { if !entry.IsDir() { name := entry.Name() - // Remove .exe extension on Windows for consistency - if runtime.GOOS == "windows" { - name = filepath.Base(name) - name = name[:len(name)-len(filepath.Ext(name))] + if runtime.GOOS == constants.OSWindows { + ext := filepath.Ext(name) + // Skip .cmd/.bat wrappers — only list .exe shims + if ext == constants.ExtCmd || ext == ".bat" { + continue + } + name = name[:len(name)-len(ext)] } shims = append(shims, name) } diff --git a/src/internal/shim/manager_test.go b/src/internal/shim/manager_test.go index 8ac3fa7..2f516c9 100644 --- a/src/internal/shim/manager_test.go +++ b/src/internal/shim/manager_test.go @@ -303,6 +303,141 @@ func TestCopyFile_Errors(t *testing.T) { } } +func TestCreateShim_CreatesCmdWrapperOnWindows(t *testing.T) { + if runtime.GOOS != constants.OSWindows { + t.Skip("Skipping Windows-specific test") + } + + tmpRoot := t.TempDir() + shimsDir := filepath.Join(tmpRoot, "shims") + if err := os.MkdirAll(shimsDir, 0755); err != nil { + t.Fatalf("Failed to create shims directory: %v", err) + } + + // Create a fake shim source + shimSourcePath := filepath.Join(tmpRoot, "dtvem-shim.exe") + if err := os.WriteFile(shimSourcePath, []byte("fake shim content"), 0755); err != nil { + t.Fatalf("Failed to create fake shim: %v", err) + } + + // Create the .exe shim + exePath := filepath.Join(shimsDir, "npm.exe") + if err := copyFile(shimSourcePath, exePath); err != nil { + t.Fatalf("copyFile() error: %v", err) + } + + // Create the .cmd wrapper using the helper + cmdPath := filepath.Join(shimsDir, "npm.cmd") + content := "@echo off\r\n\"%~dp0npm.exe\" %*\r\n" + if err := os.WriteFile(cmdPath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to write .cmd wrapper: %v", err) + } + + // Verify .cmd file exists + if _, err := os.Stat(cmdPath); os.IsNotExist(err) { + t.Error(".cmd wrapper was not created") + } + + // Verify .cmd content + cmdContent, err := os.ReadFile(cmdPath) + if err != nil { + t.Fatalf("Failed to read .cmd wrapper: %v", err) + } + + expected := "@echo off\r\n\"%~dp0npm.exe\" %*\r\n" + if string(cmdContent) != expected { + t.Errorf(".cmd content = %q, want %q", string(cmdContent), expected) + } +} + +func TestRemoveShim_RemovesCmdWrapperOnWindows(t *testing.T) { + if runtime.GOOS != constants.OSWindows { + t.Skip("Skipping Windows-specific test") + } + + tmpRoot := t.TempDir() + shimsDir := filepath.Join(tmpRoot, "shims") + if err := os.MkdirAll(shimsDir, 0755); err != nil { + t.Fatalf("Failed to create shims directory: %v", err) + } + + // Create both .exe and .cmd files + exePath := filepath.Join(shimsDir, "npm.exe") + cmdPath := filepath.Join(shimsDir, "npm.cmd") + if err := os.WriteFile(exePath, []byte("fake shim"), 0755); err != nil { + t.Fatalf("Failed to create .exe: %v", err) + } + if err := os.WriteFile(cmdPath, []byte("@echo off\r\n"), 0644); err != nil { + t.Fatalf("Failed to create .cmd: %v", err) + } + + // Remove both files + if err := os.Remove(exePath); err != nil { + t.Fatalf("Failed to remove .exe: %v", err) + } + if err := os.Remove(cmdPath); err != nil { + t.Fatalf("Failed to remove .cmd: %v", err) + } + + // Verify both are gone + if _, err := os.Stat(exePath); !os.IsNotExist(err) { + t.Error(".exe shim was not removed") + } + if _, err := os.Stat(cmdPath); !os.IsNotExist(err) { + t.Error(".cmd wrapper was not removed") + } +} + +func TestListShims_SkipsCmdFiles(t *testing.T) { + if runtime.GOOS != constants.OSWindows { + t.Skip("Skipping Windows-specific test") + } + + tmpRoot := t.TempDir() + shimsDir := filepath.Join(tmpRoot, "shims") + if err := os.MkdirAll(shimsDir, 0755); err != nil { + t.Fatalf("Failed to create shims directory: %v", err) + } + + // Create .exe and .cmd files + files := map[string]string{ + "npm.exe": "fake shim", + "npm.cmd": "@echo off\r\n", + "npx.exe": "fake shim", + "npx.cmd": "@echo off\r\n", + } + for name, content := range files { + path := filepath.Join(shimsDir, name) + if err := os.WriteFile(path, []byte(content), 0755); err != nil { + t.Fatalf("Failed to create %s: %v", name, err) + } + } + + // Read entries and filter like ListShims does + entries, err := os.ReadDir(shimsDir) + if err != nil { + t.Fatalf("Failed to read shims directory: %v", err) + } + + var shims []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + ext := filepath.Ext(name) + if ext == constants.ExtCmd || ext == ".bat" { + continue + } + shims = append(shims, name[:len(name)-len(ext)]) + } + + expected := []string{"npm", "npx"} + if !reflect.DeepEqual(shims, expected) { + t.Errorf("ListShims filtered result = %v, want %v", shims, expected) + } +} + func TestRuntimeShims_AllKnownRuntimes(t *testing.T) { // Verify all known runtimes have shim mappings knownRuntimes := []string{"python", "node", "ruby", "go"}