From b8960609eb6e87a0b4aeeb3352e88bd7c1e92353 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Mar 2026 22:56:54 +0000 Subject: [PATCH 1/2] Initial plan From 65df37944c754b6d9e4c0be578abfe15aeabdb80 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Mar 2026 23:07:19 +0000 Subject: [PATCH 2/2] Add tests for runRebootRemediation via dependency injection Co-authored-by: runzhen <32292691+runzhen@users.noreply.github.com> Agent-Logs-Url: https://github.com/Azure/AKSFlexNode/sessions/c9b7f2cd-a2e3-41b0-a177-ab46f6eaa9f3 --- go.mod | 2 +- pkg/drift/remediation.go | 28 +++++-- pkg/drift/remediation_test.go | 148 ++++++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index a880a5f..45a6332 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.10.1 github.com/spf13/viper v1.21.0 + golang.org/x/sys v0.40.0 google.golang.org/grpc v1.72.3 google.golang.org/protobuf v1.36.8 k8s.io/api v0.35.0 @@ -103,7 +104,6 @@ require ( golang.org/x/net v0.49.0 // indirect golang.org/x/oauth2 v0.33.0 // indirect golang.org/x/sync v0.19.0 // indirect - golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect golang.org/x/text v0.33.0 // indirect golang.org/x/time v0.9.0 // indirect diff --git a/pkg/drift/remediation.go b/pkg/drift/remediation.go index 90be5a6..a2b30ca 100644 --- a/pkg/drift/remediation.go +++ b/pkg/drift/remediation.go @@ -278,9 +278,27 @@ func runKubernetesUpgradeRemediation( return result, err } +// rebootCommandRunner is a function that executes the reboot command and returns combined output. +// It is a field type to allow injection in tests. +type rebootCommandRunner func(ctx context.Context) ([]byte, error) + +func defaultRebootRunner(ctx context.Context) ([]byte, error) { + cmd := exec.CommandContext(ctx, "systemctl", "reboot") // #nosec G204 -- fixed command, no user input + return cmd.CombinedOutput() +} + func runRebootRemediation( ctx context.Context, logger *logrus.Logger, +) error { + return runRebootRemediationWithDeps(ctx, logger, systemd.New(), defaultRebootRunner) +} + +func runRebootRemediationWithDeps( + ctx context.Context, + logger *logrus.Logger, + mgr systemd.Manager, + runner rebootCommandRunner, ) error { // Key design points: // - Only reboot if aks-flex-node-agent is running as a systemd service @@ -292,8 +310,7 @@ func runRebootRemediation( // Check if aks-flex-node-agent is managed by systemd. // We use GetUnitStatus to check if the service is active and running under systemd. - mgr := systemd.New() - status, err := mgr.GetUnitStatus(ctx, agentServiceName) + unitStatus, err := mgr.GetUnitStatus(ctx, agentServiceName) if err != nil { if errors.Is(err, systemd.ErrUnitNotFound) { logger.Warn("aks-flex-node-agent is not running as a systemd service; skipping reboot") @@ -305,8 +322,8 @@ func runRebootRemediation( } // Only reboot if the service is active - if status.ActiveState != systemd.UnitActiveStateActive { - logger.Warnf("aks-flex-node-agent service is not active (state: %s); skipping reboot", status.ActiveState) + if unitStatus.ActiveState != systemd.UnitActiveStateActive { + logger.Warnf("aks-flex-node-agent service is not active (state: %s); skipping reboot", unitStatus.ActiveState) return nil } @@ -323,8 +340,7 @@ func runRebootRemediation( defer cancel() } - cmd := exec.CommandContext(rebootCtx, "systemctl", "reboot") - output, err := cmd.CombinedOutput() + output, err := runner(rebootCtx) if err != nil { // If the context was canceled or timed out, surface that information explicitly. if errors.Is(rebootCtx.Err(), context.DeadlineExceeded) { diff --git a/pkg/drift/remediation_test.go b/pkg/drift/remediation_test.go index 9037a94..3b8bc55 100644 --- a/pkg/drift/remediation_test.go +++ b/pkg/drift/remediation_test.go @@ -8,12 +8,14 @@ import ( "testing" "time" + "github.com/coreos/go-systemd/v22/dbus" "github.com/sirupsen/logrus" "github.com/Azure/AKSFlexNode/pkg/bootstrapper" "github.com/Azure/AKSFlexNode/pkg/config" "github.com/Azure/AKSFlexNode/pkg/spec" "github.com/Azure/AKSFlexNode/pkg/status" + "github.com/Azure/AKSFlexNode/pkg/systemd" ) type countingDetector struct { @@ -218,3 +220,149 @@ func TestShouldMarkKubeletUnhealthyAfterUpgradeFailure(t *testing.T) { t.Fatalf("nil error marked unhealthy=true, want false") } } + +// stubSystemdManager is a test double for systemd.Manager. +type stubSystemdManager struct { + getUnitStatusResult dbus.UnitStatus + getUnitStatusErr error +} + +var _ systemd.Manager = (*stubSystemdManager)(nil) + +func (s *stubSystemdManager) DaemonReload(_ context.Context) error { return nil } +func (s *stubSystemdManager) EnableUnit(_ context.Context, _ string) error { return nil } +func (s *stubSystemdManager) DisableUnit(_ context.Context, _ string) error { return nil } +func (s *stubSystemdManager) MaskUnit(_ context.Context, _ string) error { return nil } +func (s *stubSystemdManager) StartUnit(_ context.Context, _ string) error { return nil } +func (s *stubSystemdManager) StopUnit(_ context.Context, _ string) error { return nil } +func (s *stubSystemdManager) ReloadOrRestartUnit(_ context.Context, _ string) error { return nil } + +func (s *stubSystemdManager) GetUnitStatus(_ context.Context, _ string) (dbus.UnitStatus, error) { + return s.getUnitStatusResult, s.getUnitStatusErr +} + +func (s *stubSystemdManager) EnsureUnitFile(_ context.Context, _ string, _ []byte) (bool, error) { + return false, nil +} + +func (s *stubSystemdManager) EnsureDropInFile(_ context.Context, _ string, _ string, _ []byte) (bool, error) { + return false, nil +} + +// stubRebootRunner records whether it was called and returns a configurable result. +type stubRebootRunner struct { + called bool + output []byte + err error +} + +func (s *stubRebootRunner) run(_ context.Context) ([]byte, error) { + s.called = true + return s.output, s.err +} + +func TestRunRebootRemediationWithDeps(t *testing.T) { + t.Parallel() + + statusActive := dbus.UnitStatus{ActiveState: systemd.UnitActiveStateActive} + statusInactive := dbus.UnitStatus{ActiveState: systemd.UnitActiveStateInactive} + dbusErr := errors.New("dbus connection refused") + rebootErr := errors.New("exit status 1") + timeoutErr := errors.New("signal: killed") + + // pastDeadlineCtx returns a context whose deadline has already elapsed so that + // runRebootRemediationWithDeps receives a context with hasDeadline==true and + // ctx.Err()==DeadlineExceeded, exercising the timeout error path. + pastDeadlineCtx := func() (context.Context, context.CancelFunc) { + return context.WithDeadline(context.Background(), time.Now().Add(-time.Second)) + } + + tests := []struct { + name string + mgr *stubSystemdManager + runner *stubRebootRunner + makeCtx func() (context.Context, context.CancelFunc) + wantErr error + wantErrWrapped bool // whether wantErr should be wrapped inside the returned error + wantRunnerCalled bool + }{ + { + name: "unit-not-found/no-op", + mgr: &stubSystemdManager{getUnitStatusErr: systemd.ErrUnitNotFound}, + runner: &stubRebootRunner{}, + wantRunnerCalled: false, + }, + { + name: "inactive/no-op", + mgr: &stubSystemdManager{getUnitStatusResult: statusInactive}, + runner: &stubRebootRunner{}, + wantRunnerCalled: false, + }, + { + name: "active/executes-reboot", + mgr: &stubSystemdManager{getUnitStatusResult: statusActive}, + runner: &stubRebootRunner{}, + wantRunnerCalled: true, + }, + { + name: "get-unit-status-error/surfaced", + mgr: &stubSystemdManager{getUnitStatusErr: dbusErr}, + runner: &stubRebootRunner{}, + wantErr: dbusErr, + wantErrWrapped: true, + wantRunnerCalled: false, + }, + { + name: "reboot-command-fails/surfaced", + mgr: &stubSystemdManager{getUnitStatusResult: statusActive}, + runner: &stubRebootRunner{err: rebootErr}, + wantErr: rebootErr, + wantErrWrapped: true, + wantRunnerCalled: true, + }, + { + name: "reboot-command-times-out/surfaced", + mgr: &stubSystemdManager{getUnitStatusResult: statusActive}, + runner: &stubRebootRunner{err: timeoutErr}, + makeCtx: pastDeadlineCtx, + wantErr: timeoutErr, + wantErrWrapped: true, + wantRunnerCalled: true, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var ctx context.Context + var cancel context.CancelFunc + if tc.makeCtx != nil { + ctx, cancel = tc.makeCtx() + } else { + ctx, cancel = context.WithCancel(context.Background()) + } + defer cancel() + + err := runRebootRemediationWithDeps(ctx, logrus.New(), tc.mgr, tc.runner.run) + + if tc.wantErr == nil { + if err != nil { + t.Fatalf("err=%v, want nil", err) + } + } else { + if err == nil { + t.Fatalf("err=nil, want error wrapping %v", tc.wantErr) + } + if tc.wantErrWrapped && !errors.Is(err, tc.wantErr) { + t.Fatalf("err=%v, want to wrap %v", err, tc.wantErr) + } + } + + if tc.wantRunnerCalled != tc.runner.called { + t.Fatalf("runner.called=%v, want %v", tc.runner.called, tc.wantRunnerCalled) + } + }) + } +}