diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 54a7a4f61..177e91d19 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -50,7 +50,8 @@ jobs: output-file-path: bench.txt # Access token to push results to gh-pages branch github-token: ${{ secrets.GITHUB_TOKEN }} - auto-push: true + # Fork PRs cannot push to upstream gh-pages with GITHUB_TOKEN. + auto-push: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository }} # Fail if performance drops by more than 50% alert-threshold: '200%' comment-on-alert: true diff --git a/internal/autoscaler/protos/externalgrpc.pb.go b/internal/autoscaler/protos/externalgrpc.pb.go index 6bc0ceb36..efb7c52bf 100644 --- a/internal/autoscaler/protos/externalgrpc.pb.go +++ b/internal/autoscaler/protos/externalgrpc.pb.go @@ -19,6 +19,8 @@ // protoc v3.21.12 // source: cloudprovider/externalgrpc/protos/externalgrpc.proto +//lint:file-ignore SA1019 Generated protobuf compatibility shims may reference deprecated APIs. + package protos import ( diff --git a/internal/repositories/libvirt/adapter.go b/internal/repositories/libvirt/adapter.go index 52bd73e8e..4be947501 100644 --- a/internal/repositories/libvirt/adapter.go +++ b/internal/repositories/libvirt/adapter.go @@ -22,10 +22,38 @@ import ( "time" "github.com/digitalocean/go-libvirt" + libvirtsocket "github.com/digitalocean/go-libvirt/socket" "github.com/google/uuid" "github.com/poyrazk/thecloud/internal/core/ports" ) +type libvirtUnixDialer struct { + uri string + timeout time.Duration +} + +var _ libvirtsocket.Dialer = (*libvirtUnixDialer)(nil) + +func (d *libvirtUnixDialer) Dial() (net.Conn, error) { + return net.DialTimeout("unix", d.uri, d.timeout) +} + +// SocketDialer is the interface for creating network connections. +// It allows dependency injection for testing. +type SocketDialer interface { + Dial() (net.Conn, error) +} + +// DialerOption configures NewLibvirtAdapter. +type DialerOption func(*LibvirtAdapter) + +// WithSocketDialer injects a custom socket dialer for testing. +func WithSocketDialer(d SocketDialer) DialerOption { + return func(a *LibvirtAdapter) { + a.socketDialer = d + } +} + const ( defaultPoolName = "default" userDataFileName = "user-data" @@ -66,6 +94,9 @@ type LibvirtAdapter struct { execCommandContext func(ctx context.Context, name string, arg ...string) *exec.Cmd lookPath func(file string) (string, error) osOpen func(name string) (*os.File, error) + + // Socket dialer for testability + socketDialer SocketDialer } func (a *LibvirtAdapter) recordPortMapping(name string, hPortStr string, cPort string) error { @@ -85,7 +116,7 @@ func (a *LibvirtAdapter) recordPortMapping(name string, hPortStr string, cPort s } // NewLibvirtAdapter creates a LibvirtAdapter connected to the provided URI. -func NewLibvirtAdapter(logger *slog.Logger, uri string) (*LibvirtAdapter, error) { +func NewLibvirtAdapter(logger *slog.Logger, uri string, opts ...DialerOption) (*LibvirtAdapter, error) { if uri == "" { uri = os.Getenv("LIBVIRT_URI") } @@ -93,27 +124,8 @@ func NewLibvirtAdapter(logger *slog.Logger, uri string) (*LibvirtAdapter, error) uri = "/var/run/libvirt/libvirt-sock" } - // Connect to libvirt socket - c, err := net.DialTimeout("unix", uri, 2*time.Second) - if err != nil { - // Fallback to session mode if system socket fails - if !strings.Contains(uri, "session") { - sessionUri := filepath.Join(os.Getenv("HOME"), ".cache/libvirt/libvirt-sock") - if c2, err2 := net.DialTimeout("unix", sessionUri, 2*time.Second); err2 == nil { - c = c2 - uri = sessionUri - } else { - return nil, fmt.Errorf("failed to dial libvirt (system and session): %w", err) - } - } else { - return nil, fmt.Errorf("failed to dial libvirt: %w", err) - } - } - - //nolint:staticcheck - l := libvirt.New(c) + // Create adapter with default dialer first (needed for option application) adapter := &LibvirtAdapter{ - client: &RealLibvirtClient{conn: l}, logger: logger, uri: uri, portMappings: make(map[string]map[string]int), @@ -128,6 +140,39 @@ func NewLibvirtAdapter(logger *slog.Logger, uri string) (*LibvirtAdapter, error) osOpen: os.Open, } + // Apply options (e.g., inject mock dialer for testing) + for _, opt := range opts { + opt(adapter) + } + + // Determine dialer to use + var dialer SocketDialer + if adapter.socketDialer != nil { + dialer = adapter.socketDialer + } else { + dialer = &libvirtUnixDialer{uri: uri, timeout: 2 * time.Second} + } + + // Verify connectivity + if _, err := dialer.Dial(); err != nil { + // Fallback to session mode if system socket fails + if !strings.Contains(uri, "session") { + sessionUri := filepath.Join(os.Getenv("HOME"), ".cache/libvirt/libvirt-sock") + sessionDialer := &libvirtUnixDialer{uri: sessionUri, timeout: 2 * time.Second} + if _, err2 := sessionDialer.Dial(); err2 == nil { + dialer = sessionDialer + adapter.uri = sessionUri + } else { + return nil, fmt.Errorf("failed to dial libvirt (system and session): %w", err) + } + } else { + return nil, fmt.Errorf("failed to dial libvirt: %w", err) + } + } + + l := libvirt.NewWithDialer(dialer) + adapter.client = &RealLibvirtClient{conn: l} + connectCtx, connectCancel := context.WithTimeout(context.Background(), 10*time.Second) defer connectCancel() diff --git a/internal/repositories/libvirt/adapter_unit_test.go b/internal/repositories/libvirt/adapter_unit_test.go index 6b8f87f9b..52801d0b2 100644 --- a/internal/repositories/libvirt/adapter_unit_test.go +++ b/internal/repositories/libvirt/adapter_unit_test.go @@ -829,3 +829,78 @@ func TestLibvirtAdapter_IsNotFound(t *testing.T) { assert.False(t, a.isNotFound(err)) }) } + +func TestLibvirtUnixDialer(t *testing.T) { + t.Parallel() + + t.Run("ImplementsSocketDialerInterface", func(t *testing.T) { + t.Parallel() + var dialer SocketDialer = &libvirtUnixDialer{uri: "/var/run/libvirt/libvirt-sock", timeout: 2 * time.Second} + assert.NotNil(t, dialer) + }) + + t.Run("DialTimeoutReturnsConn", func(t *testing.T) { + t.Parallel() + // This test requires a real Unix socket to exist. + // If /var/run/libvirt/libvirt-sock doesn't exist, this will fail. + // In CI without libvirt, we expect this to fail gracefully. + dialer := &libvirtUnixDialer{uri: "/var/run/libvirt/libvirt-sock", timeout: 100 * time.Millisecond} + conn, err := dialer.Dial() + if err != nil { + // Expected if no libvirt socket exists (e.g., in CI without libvirt) + return + } + require.NotNil(t, conn) + conn.Close() + }) + + t.Run("DialInvalidPathReturnsError", func(t *testing.T) { + t.Parallel() + dialer := &libvirtUnixDialer{uri: "/nonexistent/path/to/socket", timeout: 100 * time.Millisecond} + conn, err := dialer.Dial() + require.Error(t, err) + assert.Nil(t, conn) + }) + + t.Run("DialWithVeryShortTimeout", func(t *testing.T) { + t.Parallel() + dialer := &libvirtUnixDialer{uri: "/var/run/libvirt/libvirt-sock", timeout: 1 * time.Millisecond} + conn, err := dialer.Dial() + // Should either succeed quickly or timeout/fail + if err != nil { + // Expected - socket unreachable or timeout + return + } + require.NotNil(t, conn) + conn.Close() + }) +} + +func TestWithSocketDialerOption(t *testing.T) { + t.Parallel() + + t.Run("SetsSocketDialerField", func(t *testing.T) { + t.Parallel() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + // Create a mock dialer + mockDialer := &mockSocketDialer{err: fmt.Errorf("test error")} + + // Create adapter with the option + a := &LibvirtAdapter{logger: logger} + WithSocketDialer(mockDialer)(a) + + // Verify the dialer was set + assert.Equal(t, mockDialer, a.socketDialer) + }) +} + +// mockSocketDialer is a test double for SocketDialer. +type mockSocketDialer struct { + conn net.Conn + err error +} + +func (m *mockSocketDialer) Dial() (net.Conn, error) { + return m.conn, m.err +}