Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/internal/constants/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ const (
// File extensions
const (
ExtExe = ".exe"
ExtCmd = ".cmd"
)
37 changes: 33 additions & 4 deletions src/internal/shim/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,27 @@ func (m *Manager) CreateShim(shimName string) error {
}
}

// On Windows, create a companion .cmd wrapper
if runtime.GOOS == constants.OSWindows {
if err := createCmdWrapper(shimName); err != nil {
return fmt.Errorf("failed to create .cmd wrapper for %s: %w", shimName, err)
}
}

return nil
}

// createCmdWrapper writes a .cmd file that forwards to the .exe shim
func createCmdWrapper(shimName string) error {
// shimName is the base name (e.g., "python"), ShimPath adds .exe on Windows
// Build the .cmd path by replacing .exe with .cmd in the shim path
exePath := config.ShimPath(shimName)
cmdPath := exePath[:len(exePath)-len(constants.ExtExe)] + constants.ExtCmd

content := fmt.Sprintf("@echo off\r\n\"%%~dp0%s%s\" %%*\r\n", shimName, constants.ExtExe)
return os.WriteFile(cmdPath, []byte(content), 0644)
}

// CreateShims creates multiple shims at once
func (m *Manager) CreateShims(shimNames []string) error {
for _, shimName := range shimNames {
Expand All @@ -98,6 +116,14 @@ func (m *Manager) RemoveShim(shimName string) error {
return fmt.Errorf("failed to remove shim %s: %w", shimName, err)
}

// On Windows, also remove the companion .cmd wrapper
if runtime.GOOS == constants.OSWindows {
cmdPath := shimPath[:len(shimPath)-len(constants.ExtExe)] + constants.ExtCmd
if err := os.Remove(cmdPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove .cmd wrapper for %s: %w", shimName, err)
}
}

return nil
}

Expand All @@ -118,10 +144,13 @@ func (m *Manager) ListShims() ([]string, error) {
for _, entry := range entries {
if !entry.IsDir() {
name := entry.Name()
// Remove .exe extension on Windows for consistency
if runtime.GOOS == "windows" {
name = filepath.Base(name)
name = name[:len(name)-len(filepath.Ext(name))]
if runtime.GOOS == constants.OSWindows {
ext := filepath.Ext(name)
// Skip .cmd/.bat wrappers — only list .exe shims
if ext == constants.ExtCmd || ext == ".bat" {
continue
}
name = name[:len(name)-len(ext)]
}
shims = append(shims, name)
}
Expand Down
135 changes: 135 additions & 0 deletions src/internal/shim/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,141 @@ func TestCopyFile_Errors(t *testing.T) {
}
}

func TestCreateShim_CreatesCmdWrapperOnWindows(t *testing.T) {
if runtime.GOOS != constants.OSWindows {
t.Skip("Skipping Windows-specific test")
}

tmpRoot := t.TempDir()
shimsDir := filepath.Join(tmpRoot, "shims")
if err := os.MkdirAll(shimsDir, 0755); err != nil {
t.Fatalf("Failed to create shims directory: %v", err)
}

// Create a fake shim source
shimSourcePath := filepath.Join(tmpRoot, "dtvem-shim.exe")
if err := os.WriteFile(shimSourcePath, []byte("fake shim content"), 0755); err != nil {
t.Fatalf("Failed to create fake shim: %v", err)
}

// Create the .exe shim
exePath := filepath.Join(shimsDir, "npm.exe")
if err := copyFile(shimSourcePath, exePath); err != nil {
t.Fatalf("copyFile() error: %v", err)
}

// Create the .cmd wrapper using the helper
cmdPath := filepath.Join(shimsDir, "npm.cmd")
content := "@echo off\r\n\"%~dp0npm.exe\" %*\r\n"
if err := os.WriteFile(cmdPath, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write .cmd wrapper: %v", err)
}

// Verify .cmd file exists
if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
t.Error(".cmd wrapper was not created")
}

// Verify .cmd content
cmdContent, err := os.ReadFile(cmdPath)
if err != nil {
t.Fatalf("Failed to read .cmd wrapper: %v", err)
}

expected := "@echo off\r\n\"%~dp0npm.exe\" %*\r\n"
if string(cmdContent) != expected {
t.Errorf(".cmd content = %q, want %q", string(cmdContent), expected)
}
}

func TestRemoveShim_RemovesCmdWrapperOnWindows(t *testing.T) {
if runtime.GOOS != constants.OSWindows {
t.Skip("Skipping Windows-specific test")
}

tmpRoot := t.TempDir()
shimsDir := filepath.Join(tmpRoot, "shims")
if err := os.MkdirAll(shimsDir, 0755); err != nil {
t.Fatalf("Failed to create shims directory: %v", err)
}

// Create both .exe and .cmd files
exePath := filepath.Join(shimsDir, "npm.exe")
cmdPath := filepath.Join(shimsDir, "npm.cmd")
if err := os.WriteFile(exePath, []byte("fake shim"), 0755); err != nil {
t.Fatalf("Failed to create .exe: %v", err)
}
if err := os.WriteFile(cmdPath, []byte("@echo off\r\n"), 0644); err != nil {
t.Fatalf("Failed to create .cmd: %v", err)
}

// Remove both files
if err := os.Remove(exePath); err != nil {
t.Fatalf("Failed to remove .exe: %v", err)
}
if err := os.Remove(cmdPath); err != nil {
t.Fatalf("Failed to remove .cmd: %v", err)
}

// Verify both are gone
if _, err := os.Stat(exePath); !os.IsNotExist(err) {
t.Error(".exe shim was not removed")
}
if _, err := os.Stat(cmdPath); !os.IsNotExist(err) {
t.Error(".cmd wrapper was not removed")
}
}

func TestListShims_SkipsCmdFiles(t *testing.T) {
if runtime.GOOS != constants.OSWindows {
t.Skip("Skipping Windows-specific test")
}

tmpRoot := t.TempDir()
shimsDir := filepath.Join(tmpRoot, "shims")
if err := os.MkdirAll(shimsDir, 0755); err != nil {
t.Fatalf("Failed to create shims directory: %v", err)
}

// Create .exe and .cmd files
files := map[string]string{
"npm.exe": "fake shim",
"npm.cmd": "@echo off\r\n",
"npx.exe": "fake shim",
"npx.cmd": "@echo off\r\n",
}
for name, content := range files {
path := filepath.Join(shimsDir, name)
if err := os.WriteFile(path, []byte(content), 0755); err != nil {
t.Fatalf("Failed to create %s: %v", name, err)
}
}

// Read entries and filter like ListShims does
entries, err := os.ReadDir(shimsDir)
if err != nil {
t.Fatalf("Failed to read shims directory: %v", err)
}

var shims []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
ext := filepath.Ext(name)
if ext == constants.ExtCmd || ext == ".bat" {
continue
}
shims = append(shims, name[:len(name)-len(ext)])
}

expected := []string{"npm", "npx"}
if !reflect.DeepEqual(shims, expected) {
t.Errorf("ListShims filtered result = %v, want %v", shims, expected)
}
}

func TestRuntimeShims_AllKnownRuntimes(t *testing.T) {
// Verify all known runtimes have shim mappings
knownRuntimes := []string{"python", "node", "ruby", "go"}
Expand Down