Skip to content
Open
1 change: 1 addition & 0 deletions pkg/drift/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ package drift
// DefaultDetectors returns the drift detectors that are enabled by default.
func DefaultDetectors() []Detector {
	detectors := make([]Detector, 0, 2)
	detectors = append(detectors, NewKubernetesVersionDetector())
	detectors = append(detectors, NewRebootDetector())
	return detectors
}
1 change: 1 addition & 0 deletions pkg/drift/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type RemediationAction string
const (
	// RemediationActionUnspecified is the zero value; no remediation action is set.
	RemediationActionUnspecified RemediationAction = ""
	// RemediationActionKubernetesUpgrade requests a Kubernetes upgrade remediation.
	RemediationActionKubernetesUpgrade RemediationAction = "kubernetes-upgrade"
	// RemediationActionReboot requests a node reboot remediation.
	RemediationActionReboot RemediationAction = "reboot"
)

// Remediation describes what the agent should do to address a drift.
Expand Down
51 changes: 51 additions & 0 deletions pkg/drift/node_reboot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package drift

import (
"context"

"github.com/Azure/AKSFlexNode/pkg/config"
"github.com/Azure/AKSFlexNode/pkg/spec"
"github.com/Azure/AKSFlexNode/pkg/status"
)

// NodeRebootFindingID identifies the drift finding emitted when the collected
// node status reports that a reboot is needed.
const NodeRebootFindingID = "node-reboot"

// RebootDetector emits a drift finding with a reboot remediation when the
// node status snapshot has NeedReboot set.
type RebootDetector struct{}

// NewRebootDetector constructs a RebootDetector.
func NewRebootDetector() *RebootDetector {
	return new(RebootDetector)
}

// Name returns the detector's identifier used in logs and reports.
func (d *RebootDetector) Name() string {
	const detectorName = "RebootDetector"
	return detectorName
}

// Detect returns a single reboot finding when the status snapshot reports
// NeedReboot, and no findings otherwise. The config and spec arguments are
// unused. A canceled context aborts detection with the context's error.
func (d *RebootDetector) Detect(
	ctx context.Context,
	_ *config.Config,
	_ *spec.ManagedClusterSpec,
	statusSnap *status.NodeStatus,
) ([]Finding, error) {
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	// No snapshot, or no reboot requested: nothing to report.
	if statusSnap == nil || !statusSnap.NeedReboot {
		return nil, nil
	}

	finding := Finding{
		ID:      NodeRebootFindingID,
		Title:   "Node reboot required",
		Details: "Node status indicates a reboot is needed",
		Remediation: Remediation{
			Action: RemediationActionReboot,
		},
	}
	return []Finding{finding}, nil
}
82 changes: 82 additions & 0 deletions pkg/drift/node_reboot_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package drift

import (
"context"
"testing"

"github.com/Azure/AKSFlexNode/pkg/status"
)

// TestRebootDetector_Name verifies the detector reports its canonical name.
func TestRebootDetector_Name(t *testing.T) {
	t.Parallel()
	const want = "RebootDetector"
	if got := NewRebootDetector().Name(); got != want {
		t.Errorf("expected name %q, got %q", want, got)
	}
}

// TestRebootDetector_NilStatus_NoFindings verifies that a nil status snapshot
// yields no findings and no error.
func TestRebootDetector_NilStatus_NoFindings(t *testing.T) {
	t.Parallel()
	findings, err := NewRebootDetector().Detect(context.Background(), nil, nil, nil)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := len(findings); got != 0 {
		t.Errorf("expected no findings, got %d", got)
	}
}

func TestRebootDetector_NeedRebootFalse_NoFindings(t *testing.T) {
t.Parallel()
d := NewRebootDetector()
statusSnap := &status.NodeStatus{
NeedReboot: false,
}
findings, err := d.Detect(context.Background(), nil, nil, statusSnap)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(findings) != 0 {
t.Errorf("expected no findings, got %d", len(findings))
}
}

// TestRebootDetector_NeedRebootTrue_ReturnsFinding verifies that a snapshot
// with NeedReboot set yields exactly one fully-populated finding.
//
// Fix: the original test asserted ID, Title, and Remediation.Action but never
// the Details field the detector also sets; that assertion is added here.
func TestRebootDetector_NeedRebootTrue_ReturnsFinding(t *testing.T) {
	t.Parallel()
	d := NewRebootDetector()
	statusSnap := &status.NodeStatus{
		NeedReboot: true,
	}
	findings, err := d.Detect(context.Background(), nil, nil, statusSnap)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(findings) != 1 {
		t.Fatalf("expected 1 finding, got %d", len(findings))
	}

	f := findings[0]
	if f.ID != NodeRebootFindingID {
		t.Errorf("expected ID %q, got %q", NodeRebootFindingID, f.ID)
	}
	if f.Title != "Node reboot required" {
		t.Errorf("unexpected title: %q", f.Title)
	}
	if f.Details != "Node status indicates a reboot is needed" {
		t.Errorf("unexpected details: %q", f.Details)
	}
	if f.Remediation.Action != RemediationActionReboot {
		t.Errorf("expected action %q, got %q", RemediationActionReboot, f.Remediation.Action)
	}
}

// TestRebootDetector_CanceledContext_ReturnsError verifies that detection on
// an already-canceled context is aborted with the context's own error.
//
// Fix: the original test only asserted err != nil; since Detect returns
// ctx.Err() verbatim, we can and should pin the specific error value.
func TestRebootDetector_CanceledContext_ReturnsError(t *testing.T) {
	t.Parallel()
	d := NewRebootDetector()
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	statusSnap := &status.NodeStatus{NeedReboot: true}
	_, err := d.Detect(ctx, nil, nil, statusSnap)
	if err == nil {
		t.Fatal("expected error from canceled context")
	}
	// Detect returns ctx.Err() directly, which is context.Canceled here.
	if err != context.Canceled {
		t.Errorf("expected context.Canceled, got %v", err)
	}
}
72 changes: 72 additions & 0 deletions pkg/drift/remediation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"os"
"os/exec"
"sync/atomic"
"time"

Expand All @@ -16,6 +17,7 @@ import (
"github.com/Azure/AKSFlexNode/pkg/kube"
"github.com/Azure/AKSFlexNode/pkg/spec"
"github.com/Azure/AKSFlexNode/pkg/status"
"github.com/Azure/AKSFlexNode/pkg/systemd"
)

const driftKubernetesUpgradeOperation = "drift-kubernetes-upgrade"
Expand All @@ -28,6 +30,8 @@ const (
upgradeStepUncordon = "uncordon"
)

const agentServiceName = "aks-flex-node-agent.service"

// maxManagedClusterSpecAge is a safety guard to avoid acting on very stale spec snapshots.
// In normal operation we run drift immediately after a successful spec collection, so this
// should rarely block remediation.
Expand Down Expand Up @@ -154,6 +158,13 @@ func detectAndRemediate(
logger.Info("Kubernetes upgrade remediation completed successfully")
return detectErr

case RemediationActionReboot:
if err := runRebootRemediation(ctx, logger); err != nil {
return fmt.Errorf("reboot remediation failed: %w", err)
}
logger.Info("Reboot remediation completed without error")
return detectErr

default:
return fmt.Errorf("unsupported drift remediation action: %q", plan.Action)
}
Expand Down Expand Up @@ -267,6 +278,67 @@ func runKubernetesUpgradeRemediation(
return result, err
}

// runRebootRemediation reboots the node via `systemctl reboot`, but only when
// the aks-flex-node-agent is itself managed by systemd and currently active.
// When the agent is not running as a systemd service (e.g., development or
// test mode) the reboot is skipped and treated as a successful no-op.
//
// Fix: the local result of GetUnitStatus was named `status`, which shadows the
// imported pkg/status package within this function; renamed to `unitStatus`.
func runRebootRemediation(
	ctx context.Context,
	logger *logrus.Logger,
) error {
	// Key design points:
	// - Only reboot if aks-flex-node-agent is running as a systemd service
	// - If not running under systemd, skip reboot (agent may be running in development/test mode)
	// - Use systemctl reboot for a clean shutdown
	if logger == nil {
		logger = logrus.New()
	}

	// Check if aks-flex-node-agent is managed by systemd.
	// We use GetUnitStatus to check if the service is active and running under systemd.
	mgr := systemd.New()
	unitStatus, err := mgr.GetUnitStatus(ctx, agentServiceName)
	if err != nil {
		if errors.Is(err, systemd.ErrUnitNotFound) {
			logger.Warn("aks-flex-node-agent is not running as a systemd service; skipping reboot")
			// Not running under systemd is an expected scenario (e.g., dev/test); treat as a no-op, not an error.
			return nil
		}
		logger.WithError(err).Warn("Failed to check systemd service status; aborting reboot remediation")
		return fmt.Errorf("failed to check systemd service status: %w", err)
	}

	// Only reboot if the service is active.
	if unitStatus.ActiveState != systemd.UnitActiveStateActive {
		logger.Warnf("aks-flex-node-agent service is not active (state: %s); skipping reboot", unitStatus.ActiveState)
		return nil
	}

	logger.Info("Initiating system reboot via systemctl")

	// Use systemctl reboot for a clean shutdown.
	// This will gracefully stop services and sync filesystems before rebooting.
	// To avoid silently ignoring immediate failures (e.g., DBus unavailable), run the
	// command and check its exit status, using a short timeout if no deadline is set.
	rebootCtx := ctx
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		var cancel context.CancelFunc
		rebootCtx, cancel = context.WithTimeout(ctx, 30*time.Second)
		defer cancel()
	}

	cmd := exec.CommandContext(rebootCtx, "systemctl", "reboot")
	output, err := cmd.CombinedOutput()
	if err != nil {
		// If the context was canceled or timed out, surface that information explicitly.
		if errors.Is(rebootCtx.Err(), context.DeadlineExceeded) {
			logger.WithError(err).WithField("output", string(output)).
				Error("systemctl reboot timed out")
			return fmt.Errorf("systemctl reboot timed out: %w", err)
		}
		logger.WithError(err).WithField("output", string(output)).
			Error("systemctl reboot failed")
		return fmt.Errorf("systemctl reboot failed: %w", err)
	}
	return nil
}

// handleExecutionResult mirrors main's handleExecutionResult but lives in drift so remediation
// can share the same logging and error semantics.
func handleExecutionResult(result *bootstrapper.ExecutionResult, operation string, logger *logrus.Logger) error {
Expand Down
73 changes: 73 additions & 0 deletions pkg/status/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package status
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
Expand All @@ -13,6 +14,7 @@ import (
"github.com/Azure/AKSFlexNode/pkg/kube"
"github.com/Azure/AKSFlexNode/pkg/utils"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
Expand Down Expand Up @@ -60,6 +62,9 @@ func (c *Collector) CollectStatus(ctx context.Context) (*NodeStatus, error) {
}
status.ArcStatus = arcStatus

// Inspect node conditions to determine whether a reboot is needed
status.NeedReboot = c.checkRebootNeeded(ctx)

return status, nil
}

Expand Down Expand Up @@ -297,3 +302,71 @@ func GetStatusFilePath() string {
// Fallback to temp directory for testing/development
return filepath.Join("/tmp/aks-flex-node", "status.json")
}

// getBootTime derives the host's boot time by subtracting the kernel-reported
// uptime (via sysinfo(2)) from the current wall-clock time.
func getBootTime() (time.Time, error) {
	var info unix.Sysinfo_t
	if err := unix.Sysinfo(&info); err != nil {
		return time.Time{}, fmt.Errorf("failed to get system info: %w", err)
	}

	// info.Uptime is seconds since boot, so boot time = now - uptime.
	elapsed := time.Duration(info.Uptime) * time.Second
	return time.Now().Add(-elapsed), nil
}

// getNodeName returns the machine's hostname normalized for use as a
// Kubernetes node name: trimmed of surrounding whitespace and lowercased.
// It fails if the hostname cannot be read or normalizes to an empty string.
func getNodeName() (string, error) {
	host, err := os.Hostname()
	if err != nil {
		return "", fmt.Errorf("failed to get hostname: %w", err)
	}

	trimmed := strings.TrimSpace(host)
	name := strings.ToLower(trimmed)
	if len(name) == 0 {
		return "", fmt.Errorf("node name is empty")
	}

	return name, nil
}

// checkRebootNeeded reports whether the node's Kubernetes conditions indicate
// a reboot is required. Any collection failure (boot time, hostname, client,
// or node lookup) is logged and treated as "no reboot needed" so that status
// collection never fails solely because this check could not run.
func (c *Collector) checkRebootNeeded(ctx context.Context) bool {
	hostBootTime, err := getBootTime()
	if err != nil {
		c.logger.Warnf("Failed to get boot time: %v", err)
		return false
	}

	nodeName, err := getNodeName()
	if err != nil {
		c.logger.Errorf("failed to get node name: %s", err.Error())
		return false
	}

	clientset, err := kube.KubeletClientset()
	if err != nil {
		c.logger.Errorf("failed to get kubelet clientset: %s", err.Error())
		return false
	}

	// Fetch the node with a bounded timeout while respecting the caller's context.
	ctxWithTimeout, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	node, err := clientset.CoreV1().Nodes().Get(ctxWithTimeout, nodeName, metav1.GetOptions{})
	if err != nil {
		c.logger.Errorf("failed to get node: %s", err.Error())
		return false
	}

	// A KernelDeadlock condition that turned true after the last boot means the
	// problem has not yet been cleared by a reboot.
	for i := range node.Status.Conditions {
		cond := &node.Status.Conditions[i]
		if cond.Type != "KernelDeadlock" {
			continue
		}
		if cond.Status == corev1.ConditionTrue && cond.LastTransitionTime.After(hostBootTime) {
			c.logger.Infof("Node has a kernel deadlock since %s, rebooting...",
				cond.LastTransitionTime.Format("2006-01-02 15:04:05"))
			return true
		}
	}
	return false
}
2 changes: 2 additions & 0 deletions pkg/status/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type NodeStatus struct {
LastUpdatedBy LastUpdatedBy `json:"lastUpdatedBy,omitempty"`
LastUpdatedReason LastUpdatedReason `json:"lastUpdatedReason,omitempty"`
AgentVersion string `json:"agentVersion"`

NeedReboot bool `json:"needReboot,omitempty"`
}

// ArcStatus contains Azure Arc machine registration and connection status
Expand Down
Loading