Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.10.1
github.com/spf13/viper v1.21.0
golang.org/x/sys v0.40.0
google.golang.org/grpc v1.72.3
google.golang.org/protobuf v1.36.8
k8s.io/api v0.35.0
Expand Down Expand Up @@ -103,7 +104,6 @@ require (
golang.org/x/net v0.49.0 // indirect
golang.org/x/oauth2 v0.33.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/term v0.39.0 // indirect
golang.org/x/text v0.33.0 // indirect
golang.org/x/time v0.9.0 // indirect
Expand Down
28 changes: 22 additions & 6 deletions pkg/drift/remediation.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,27 @@ func runKubernetesUpgradeRemediation(
return result, err
}

// rebootCommandRunner is the signature of a function that executes the reboot
// command and returns its combined output together with any execution error.
// It is a named function type (passed as a parameter, not stored in a struct
// field) so that tests can inject a stub into runRebootRemediationWithDeps
// instead of triggering a real reboot.
type rebootCommandRunner func(ctx context.Context) ([]byte, error)

// defaultRebootRunner is the production reboot runner. It invokes
// `systemctl reboot` bound to ctx and hands back the command's combined
// stdout/stderr along with the error reported by the command, if any.
func defaultRebootRunner(ctx context.Context) ([]byte, error) {
	// The command and its argument are constants; nothing here comes from user input.
	return exec.CommandContext(ctx, "systemctl", "reboot").CombinedOutput() // #nosec G204 -- fixed command, no user input
}

// runRebootRemediation is the production entry point for the reboot
// remediation. It wires in the real dependencies (a manager obtained from
// systemd.New and defaultRebootRunner) and delegates all decision making to
// runRebootRemediationWithDeps, returning whatever error that call yields.
func runRebootRemediation(ctx context.Context, logger *logrus.Logger) error {
	mgr := systemd.New()
	return runRebootRemediationWithDeps(ctx, logger, mgr, defaultRebootRunner)
}

func runRebootRemediationWithDeps(
ctx context.Context,
logger *logrus.Logger,
mgr systemd.Manager,
runner rebootCommandRunner,
) error {
// Key design points:
// - Only reboot if aks-flex-node-agent is running as a systemd service
Expand All @@ -292,8 +310,7 @@ func runRebootRemediation(

// Check if aks-flex-node-agent is managed by systemd.
// We use GetUnitStatus to check if the service is active and running under systemd.
mgr := systemd.New()
status, err := mgr.GetUnitStatus(ctx, agentServiceName)
unitStatus, err := mgr.GetUnitStatus(ctx, agentServiceName)
if err != nil {
if errors.Is(err, systemd.ErrUnitNotFound) {
logger.Warn("aks-flex-node-agent is not running as a systemd service; skipping reboot")
Expand All @@ -305,8 +322,8 @@ func runRebootRemediation(
}

// Only reboot if the service is active
if status.ActiveState != systemd.UnitActiveStateActive {
logger.Warnf("aks-flex-node-agent service is not active (state: %s); skipping reboot", status.ActiveState)
if unitStatus.ActiveState != systemd.UnitActiveStateActive {
logger.Warnf("aks-flex-node-agent service is not active (state: %s); skipping reboot", unitStatus.ActiveState)
return nil
}

Expand All @@ -323,8 +340,7 @@ func runRebootRemediation(
defer cancel()
}

cmd := exec.CommandContext(rebootCtx, "systemctl", "reboot")
output, err := cmd.CombinedOutput()
output, err := runner(rebootCtx)
if err != nil {
// If the context was canceled or timed out, surface that information explicitly.
if errors.Is(rebootCtx.Err(), context.DeadlineExceeded) {
Expand Down
148 changes: 148 additions & 0 deletions pkg/drift/remediation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import (
"testing"
"time"

"github.com/coreos/go-systemd/v22/dbus"
"github.com/sirupsen/logrus"

"github.com/Azure/AKSFlexNode/pkg/bootstrapper"
"github.com/Azure/AKSFlexNode/pkg/config"
"github.com/Azure/AKSFlexNode/pkg/spec"
"github.com/Azure/AKSFlexNode/pkg/status"
"github.com/Azure/AKSFlexNode/pkg/systemd"
)

type countingDetector struct {
Expand Down Expand Up @@ -218,3 +220,149 @@ func TestShouldMarkKubeletUnhealthyAfterUpgradeFailure(t *testing.T) {
t.Fatalf("nil error marked unhealthy=true, want false")
}
}

// stubSystemdManager is a minimal test double for systemd.Manager.
// Only GetUnitStatus is configurable; every other method ignores its
// arguments and returns a fixed, successful result.
type stubSystemdManager struct {
	getUnitStatusResult dbus.UnitStatus // value handed back by GetUnitStatus
	getUnitStatusErr    error           // error handed back by GetUnitStatus
}

// Compile-time proof that the stub satisfies systemd.Manager.
var _ systemd.Manager = (*stubSystemdManager)(nil)

// The unit lifecycle operations below are no-ops that always return nil.

func (m *stubSystemdManager) DaemonReload(context.Context) error                { return nil }
func (m *stubSystemdManager) EnableUnit(context.Context, string) error          { return nil }
func (m *stubSystemdManager) DisableUnit(context.Context, string) error         { return nil }
func (m *stubSystemdManager) MaskUnit(context.Context, string) error            { return nil }
func (m *stubSystemdManager) StartUnit(context.Context, string) error           { return nil }
func (m *stubSystemdManager) StopUnit(context.Context, string) error            { return nil }
func (m *stubSystemdManager) ReloadOrRestartUnit(context.Context, string) error { return nil }

// GetUnitStatus returns the canned status and error configured on the stub,
// regardless of the context or unit name supplied.
func (m *stubSystemdManager) GetUnitStatus(context.Context, string) (dbus.UnitStatus, error) {
	return m.getUnitStatusResult, m.getUnitStatusErr
}

// EnsureUnitFile always returns false and a nil error.
func (m *stubSystemdManager) EnsureUnitFile(context.Context, string, []byte) (bool, error) {
	return false, nil
}

// EnsureDropInFile always returns false and a nil error.
func (m *stubSystemdManager) EnsureDropInFile(context.Context, string, string, []byte) (bool, error) {
	return false, nil
}

// stubRebootRunner is a fake reboot runner for tests. Its run method notes
// that it was invoked and replies with the preconfigured output and error.
type stubRebootRunner struct {
	called bool   // set to true the first time run is invoked
	output []byte // bytes returned from run
	err    error  // error returned from run
}

// run marks the stub as called and returns the configured output and error.
// The context argument is ignored.
func (r *stubRebootRunner) run(context.Context) ([]byte, error) {
	r.called = true

	out, runErr := r.output, r.err
	return out, runErr
}

// TestRunRebootRemediationWithDeps is a table-driven test that feeds
// runRebootRemediationWithDeps a stubbed systemd manager and a stubbed reboot
// runner, then verifies the returned error and whether the runner was invoked.
func TestRunRebootRemediationWithDeps(t *testing.T) {
	t.Parallel()

	// Canned unit states returned by the stub manager.
	activeUnit := dbus.UnitStatus{ActiveState: systemd.UnitActiveStateActive}
	inactiveUnit := dbus.UnitStatus{ActiveState: systemd.UnitActiveStateInactive}

	// Sentinel errors injected through the stubs and matched with errors.Is.
	errDBus := errors.New("dbus connection refused")
	errReboot := errors.New("exit status 1")
	errKilled := errors.New("signal: killed")

	// expiredCtx builds a context whose deadline is already one second in the
	// past, so the context carries a deadline and reports DeadlineExceeded.
	// This is intended to drive the function under test down its timeout
	// error-handling path.
	expiredCtx := func() (context.Context, context.CancelFunc) {
		return context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
	}

	// cancellableCtx is the default used by cases that do not supply makeCtx:
	// a plain cancellable context with no deadline.
	cancellableCtx := func() (context.Context, context.CancelFunc) {
		return context.WithCancel(context.Background())
	}

	type testCase struct {
		name             string
		mgr              *stubSystemdManager
		runner           *stubRebootRunner
		makeCtx          func() (context.Context, context.CancelFunc)
		wantErr          error
		wantErrWrapped   bool // whether wantErr should be wrapped inside the returned error
		wantRunnerCalled bool
	}

	cases := []testCase{
		{
			name:             "unit-not-found/no-op",
			mgr:              &stubSystemdManager{getUnitStatusErr: systemd.ErrUnitNotFound},
			runner:           &stubRebootRunner{},
			wantRunnerCalled: false,
		},
		{
			name:             "inactive/no-op",
			mgr:              &stubSystemdManager{getUnitStatusResult: inactiveUnit},
			runner:           &stubRebootRunner{},
			wantRunnerCalled: false,
		},
		{
			name:             "active/executes-reboot",
			mgr:              &stubSystemdManager{getUnitStatusResult: activeUnit},
			runner:           &stubRebootRunner{},
			wantRunnerCalled: true,
		},
		{
			name:             "get-unit-status-error/surfaced",
			mgr:              &stubSystemdManager{getUnitStatusErr: errDBus},
			runner:           &stubRebootRunner{},
			wantErr:          errDBus,
			wantErrWrapped:   true,
			wantRunnerCalled: false,
		},
		{
			name:             "reboot-command-fails/surfaced",
			mgr:              &stubSystemdManager{getUnitStatusResult: activeUnit},
			runner:           &stubRebootRunner{err: errReboot},
			wantErr:          errReboot,
			wantErrWrapped:   true,
			wantRunnerCalled: true,
		},
		{
			name:             "reboot-command-times-out/surfaced",
			mgr:              &stubSystemdManager{getUnitStatusResult: activeUnit},
			runner:           &stubRebootRunner{err: errKilled},
			makeCtx:          expiredCtx,
			wantErr:          errKilled,
			wantErrWrapped:   true,
			wantRunnerCalled: true,
		},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy; the parallel subtest closure outlives the iteration
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			newCtx := tc.makeCtx
			if newCtx == nil {
				newCtx = cancellableCtx
			}
			ctx, cancel := newCtx()
			defer cancel()

			err := runRebootRemediationWithDeps(ctx, logrus.New(), tc.mgr, tc.runner.run)

			switch {
			case tc.wantErr == nil:
				if err != nil {
					t.Fatalf("err=%v, want nil", err)
				}
			case err == nil:
				t.Fatalf("err=nil, want error wrapping %v", tc.wantErr)
			case tc.wantErrWrapped && !errors.Is(err, tc.wantErr):
				t.Fatalf("err=%v, want to wrap %v", err, tc.wantErr)
			}

			if got := tc.runner.called; got != tc.wantRunnerCalled {
				t.Fatalf("runner.called=%v, want %v", got, tc.wantRunnerCalled)
			}
		})
	}
}