From 60ddb73e37bc91342765b539c9c75d347ad6f29e Mon Sep 17 00:00:00 2001 From: Michel Osswald Date: Thu, 11 Jun 2026 09:34:36 +0200 Subject: [PATCH] fix(clawpatch): refuse non-socket local runtime paths --- internal/localruntime/service.go | 21 ++++++++++++++-- internal/localruntime/service_test.go | 36 +++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/internal/localruntime/service.go b/internal/localruntime/service.go index d1797cc..6f66ce2 100644 --- a/internal/localruntime/service.go +++ b/internal/localruntime/service.go @@ -64,8 +64,8 @@ func NewService(opts Options) (*Service, error) { func (s *Service) SocketPath() string { return s.socketPath } func (s *Service) Start(ctx context.Context) error { - if err := os.Remove(s.socketPath); err != nil && !errors.Is(err, os.ErrNotExist) { - return fmt.Errorf("remove stale local runtime socket: %w", err) + if err := removeStaleSocket(s.socketPath); err != nil { + return err } ln, err := net.Listen("unix", s.socketPath) if err != nil { @@ -79,6 +79,23 @@ func (s *Service) Start(ctx context.Context) error { return nil } +func removeStaleSocket(path string) error { + info, err := os.Lstat(path) + if errors.Is(err, os.ErrNotExist) { + return nil + } + if err != nil { + return fmt.Errorf("inspect local runtime socket path: %w", err) + } + if info.Mode()&os.ModeSocket == 0 { + return fmt.Errorf("local runtime socket path exists and is not a socket: %s", path) + } + if err := os.Remove(path); err != nil { + return fmt.Errorf("remove stale local runtime socket: %w", err) + } + return nil +} + func (s *Service) Stop() { _ = s.Shutdown(context.Background()) } diff --git a/internal/localruntime/service_test.go b/internal/localruntime/service_test.go index 4416c9d..dcedeee 100644 --- a/internal/localruntime/service_test.go +++ b/internal/localruntime/service_test.go @@ -248,6 +248,42 @@ func TestServiceAllowsNonblockingHookPayloadDecodeFailure(t *testing.T) { } } +func TestServiceStartRefusesExistingRegularFile(t *testing.T) { + t.Parallel() + + core, err := runtimecore.New(&stubRuntime{}) + if err != nil { + t.Fatalf("runtimecore.New() error = %v", err) + } + socketPath := filepath.Join(t.TempDir(), "kontext.sock") + contents := []byte("keep this file") + if err := os.WriteFile(socketPath, contents, 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + service, err := NewService(Options{ + SocketPath: socketPath, + Core: core, + }) + if err != nil { + t.Fatalf("NewService() error = %v", err) + } + + err = service.Start(context.Background()) + if err == nil { + t.Fatal("Start() error = nil, want non-socket path error") + } + if !strings.Contains(err.Error(), "not a socket") { + t.Fatalf("Start() error = %v, want non-socket path error", err) + } + got, readErr := os.ReadFile(socketPath) + if readErr != nil { + t.Fatalf("ReadFile() error = %v", readErr) + } + if string(got) != string(contents) { + t.Fatalf("file contents = %q, want %q", got, contents) + } +} + func newTestService(t *testing.T, runtime *stubRuntime, asyncIngest bool) *Service { t.Helper()