diff --git a/pkg/drift/defaults.go b/pkg/drift/defaults.go index b4f82be..9a61d7d 100644 --- a/pkg/drift/defaults.go +++ b/pkg/drift/defaults.go @@ -7,5 +7,6 @@ package drift func DefaultDetectors() []Detector { return []Detector{ NewKubernetesVersionDetector(), + NewRebootDetector(), } } diff --git a/pkg/drift/detector.go b/pkg/drift/detector.go index 2b7608e..400e50d 100644 --- a/pkg/drift/detector.go +++ b/pkg/drift/detector.go @@ -25,6 +25,7 @@ type RemediationAction string const ( RemediationActionUnspecified RemediationAction = "" RemediationActionKubernetesUpgrade RemediationAction = "kubernetes-upgrade" + RemediationActionReboot RemediationAction = "reboot" ) // Remediation describes what the agent should do to address a drift. diff --git a/pkg/drift/node_reboot.go b/pkg/drift/node_reboot.go new file mode 100644 index 0000000..9d6874c --- /dev/null +++ b/pkg/drift/node_reboot.go @@ -0,0 +1,51 @@ +package drift + +import ( + "context" + + "github.com/Azure/AKSFlexNode/pkg/config" + "github.com/Azure/AKSFlexNode/pkg/spec" + "github.com/Azure/AKSFlexNode/pkg/status" +) + +const NodeRebootFindingID = "node-reboot" + +type RebootDetector struct{} + +func NewRebootDetector() *RebootDetector { + return &RebootDetector{} +} + +func (d *RebootDetector) Name() string { + return "RebootDetector" +} + +func (d *RebootDetector) Detect( + ctx context.Context, + _ *config.Config, + _ *spec.ManagedClusterSpec, + statusSnap *status.NodeStatus, +) ([]Finding, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + if statusSnap == nil { + return nil, nil + } + + if !statusSnap.NeedReboot { + return nil, nil + } + + return []Finding{ + { + ID: NodeRebootFindingID, + Title: "Node reboot required", + Details: "Node status indicates a reboot is needed", + Remediation: Remediation{ + Action: RemediationActionReboot, + }, + }, + }, nil +} diff --git a/pkg/drift/node_reboot_test.go b/pkg/drift/node_reboot_test.go new file mode 100644 index 0000000..2082c13 --- /dev/null 
+++ b/pkg/drift/node_reboot_test.go @@ -0,0 +1,82 @@ +package drift + +import ( + "context" + "testing" + + "github.com/Azure/AKSFlexNode/pkg/status" +) + +func TestRebootDetector_Name(t *testing.T) { + t.Parallel() + d := NewRebootDetector() + if name := d.Name(); name != "RebootDetector" { + t.Errorf("expected name %q, got %q", "RebootDetector", name) + } +} + +func TestRebootDetector_NilStatus_NoFindings(t *testing.T) { + t.Parallel() + d := NewRebootDetector() + findings, err := d.Detect(context.Background(), nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(findings) != 0 { + t.Errorf("expected no findings, got %d", len(findings)) + } +} + +func TestRebootDetector_NeedRebootFalse_NoFindings(t *testing.T) { + t.Parallel() + d := NewRebootDetector() + statusSnap := &status.NodeStatus{ + NeedReboot: false, + } + findings, err := d.Detect(context.Background(), nil, nil, statusSnap) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(findings) != 0 { + t.Errorf("expected no findings, got %d", len(findings)) + } +} + +func TestRebootDetector_NeedRebootTrue_ReturnsFinding(t *testing.T) { + t.Parallel() + d := NewRebootDetector() + statusSnap := &status.NodeStatus{ + NeedReboot: true, + } + findings, err := d.Detect(context.Background(), nil, nil, statusSnap) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(findings) != 1 { + t.Fatalf("expected 1 finding, got %d", len(findings)) + } + + f := findings[0] + if f.ID != NodeRebootFindingID { + t.Errorf("expected ID %q, got %q", NodeRebootFindingID, f.ID) + } + if f.Title != "Node reboot required" { + t.Errorf("unexpected title: %q", f.Title) + } + if f.Remediation.Action != RemediationActionReboot { + t.Errorf("expected action %q, got %q", RemediationActionReboot, f.Remediation.Action) + } +} + +func TestRebootDetector_CanceledContext_ReturnsError(t *testing.T) { + t.Parallel() + d := NewRebootDetector() + ctx, cancel := 
context.WithCancel(context.Background()) + cancel() + + statusSnap := &status.NodeStatus{NeedReboot: true} + _, err := d.Detect(ctx, nil, nil, statusSnap) + if err == nil { + t.Fatal("expected error from canceled context") + } +} diff --git a/pkg/drift/remediation.go b/pkg/drift/remediation.go index bcf1bd9..90be5a6 100644 --- a/pkg/drift/remediation.go +++ b/pkg/drift/remediation.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "os/exec" "sync/atomic" "time" @@ -16,6 +17,7 @@ import ( "github.com/Azure/AKSFlexNode/pkg/kube" "github.com/Azure/AKSFlexNode/pkg/spec" "github.com/Azure/AKSFlexNode/pkg/status" + "github.com/Azure/AKSFlexNode/pkg/systemd" ) const driftKubernetesUpgradeOperation = "drift-kubernetes-upgrade" @@ -28,6 +30,8 @@ const ( upgradeStepUncordon = "uncordon" ) +const agentServiceName = "aks-flex-node-agent.service" + // maxManagedClusterSpecAge is a safety guard to avoid acting on very stale spec snapshots. // In normal operation we run drift immediately after a successful spec collection, so this // should rarely block remediation. 
@@ -154,6 +158,13 @@ func detectAndRemediate(
 		logger.Info("Kubernetes upgrade remediation completed successfully")
 		return detectErr
 
+	case RemediationActionReboot:
+		if err := runRebootRemediation(ctx, logger); err != nil {
+			return fmt.Errorf("reboot remediation failed: %w", err)
+		}
+		logger.Info("Reboot remediation completed without error")
+		return detectErr
+
 	default:
 		return fmt.Errorf("unsupported drift remediation action: %q", plan.Action)
 	}
@@ -267,6 +278,67 @@ func runKubernetesUpgradeRemediation(
 	return result, err
 }
 
+func runRebootRemediation(
+	ctx context.Context,
+	logger *logrus.Logger,
+) error {
+	// Key design points:
+	// - Only reboot if aks-flex-node-agent is running as a systemd service
+	// - If not running under systemd, skip reboot (agent may be running in development/test mode)
+	// - Use systemctl reboot for a clean shutdown
+	if logger == nil {
+		logger = logrus.New()
+	}
+
+	// Check if aks-flex-node-agent is managed by systemd.
+	// We use GetUnitStatus to check if the service is active and running under systemd.
+	mgr := systemd.New()
+	unitStatus, err := mgr.GetUnitStatus(ctx, agentServiceName)
+	if err != nil {
+		if errors.Is(err, systemd.ErrUnitNotFound) {
+			logger.Warn("aks-flex-node-agent is not running as a systemd service; skipping reboot")
+			// Not running under systemd is an expected scenario (e.g., dev/test); treat as a no-op, not an error.
+			return nil
+		}
+		logger.WithError(err).Warn("Failed to check systemd service status; aborting reboot remediation")
+		return fmt.Errorf("failed to check systemd service status: %w", err)
+	}
+
+	// Only reboot if the service is active
+	if unitStatus.ActiveState != systemd.UnitActiveStateActive {
+		logger.Warnf("aks-flex-node-agent service is not active (state: %s); skipping reboot", unitStatus.ActiveState)
+		return nil
+	}
+
+	logger.Info("Initiating system reboot via systemctl")
+
+	// Use systemctl reboot for a clean shutdown.
+ // This will gracefully stop services and sync filesystems before rebooting. + // To avoid silently ignoring immediate failures (e.g., DBus unavailable), run the + // command and check its exit status, using a short timeout if no deadline is set. + rebootCtx := ctx + if _, hasDeadline := ctx.Deadline(); !hasDeadline { + var cancel context.CancelFunc + rebootCtx, cancel = context.WithTimeout(ctx, 30*time.Second) + defer cancel() + } + + cmd := exec.CommandContext(rebootCtx, "systemctl", "reboot") + output, err := cmd.CombinedOutput() + if err != nil { + // If the context was canceled or timed out, surface that information explicitly. + if errors.Is(rebootCtx.Err(), context.DeadlineExceeded) { + logger.WithError(err).WithField("output", string(output)). + Error("systemctl reboot timed out") + return fmt.Errorf("systemctl reboot timed out: %w", err) + } + logger.WithError(err).WithField("output", string(output)). + Error("systemctl reboot failed") + return fmt.Errorf("systemctl reboot failed: %w", err) + } + return nil +} + // handleExecutionResult mirrors main's handleExecutionResult but lives in drift so remediation // can share the same logging and error semantics. 
func handleExecutionResult(result *bootstrapper.ExecutionResult, operation string, logger *logrus.Logger) error { diff --git a/pkg/status/collector.go b/pkg/status/collector.go index dc4d72b..4cd7b19 100644 --- a/pkg/status/collector.go +++ b/pkg/status/collector.go @@ -3,6 +3,7 @@ package status import ( "context" "encoding/json" + "fmt" "os" "os/exec" "path/filepath" @@ -13,6 +14,7 @@ import ( "github.com/Azure/AKSFlexNode/pkg/kube" "github.com/Azure/AKSFlexNode/pkg/utils" "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -60,6 +62,9 @@ func (c *Collector) CollectStatus(ctx context.Context) (*NodeStatus, error) { } status.ArcStatus = arcStatus + // Check if reboot is needed node condition + status.NeedReboot = c.checkRebootNeeded(ctx) + return status, nil } @@ -297,3 +302,71 @@ func GetStatusFilePath() string { // Fallback to temp directory for testing/development return filepath.Join("/tmp/aks-flex-node", "status.json") } + +func getBootTime() (time.Time, error) { + var sysinfo unix.Sysinfo_t + if err := unix.Sysinfo(&sysinfo); err != nil { + return time.Time{}, fmt.Errorf("failed to get system info: %w", err) + } + + // Calculate boot time: current time - uptime + // Sysinfo.Uptime is in seconds since boot + uptime := time.Duration(sysinfo.Uptime) * time.Second + bootTime := time.Now().Add(-uptime) + return bootTime, nil +} + +func getNodeName() (string, error) { + host, err := os.Hostname() + if err != nil { + return "", fmt.Errorf("failed to get hostname: %w", err) + } + + nodeName := strings.ToLower(strings.TrimSpace(host)) + if nodeName == "" { + return "", fmt.Errorf("node name is empty") + } + + return nodeName, nil +} + +func (c *Collector) checkRebootNeeded(ctx context.Context) bool { + hostBootTime, err := getBootTime() + if err != nil { + c.logger.Warnf("Failed to get boot time: %v", err) + return false + } + nodeName, err := getNodeName() + if err != nil { + 
c.logger.Errorf("failed to get node name: %s", err.Error())
+		return false
+	}
+
+	clientset, err := kube.KubeletClientset()
+	if err != nil {
+		c.logger.Errorf("failed to get kubelet clientset: %s", err.Error())
+		return false
+	}
+
+	// Get the node with a timeout and respecting the passed-in context
+	ctxWithTimeout, cancel := context.WithTimeout(ctx, 10*time.Second)
+	defer cancel()
+
+	node, err := clientset.CoreV1().Nodes().Get(ctxWithTimeout, nodeName, metav1.GetOptions{})
+	if err != nil {
+		c.logger.Errorf("failed to get node: %s", err.Error())
+		return false
+	}
+
+	for _, condition := range node.Status.Conditions {
+		switch condition.Type {
+		case "KernelDeadlock":
+			if condition.Status == corev1.ConditionTrue && condition.LastTransitionTime.After(hostBootTime) {
+				c.logger.Infof("Node has a kernel deadlock since %s; marking node for reboot",
+					condition.LastTransitionTime.Format("2006-01-02 15:04:05"))
+				return true
+			}
+		}
+	}
+	return false
+}
diff --git a/pkg/status/types.go b/pkg/status/types.go
index ea32d81..c2c1aab 100644
--- a/pkg/status/types.go
+++ b/pkg/status/types.go
@@ -41,6 +41,8 @@ type NodeStatus struct {
 	LastUpdatedBy     LastUpdatedBy     `json:"lastUpdatedBy,omitempty"`
 	LastUpdatedReason LastUpdatedReason `json:"lastUpdatedReason,omitempty"`
 	AgentVersion      string            `json:"agentVersion"`
+
+	NeedReboot bool `json:"needReboot,omitempty"`
 }
 
 // ArcStatus contains Azure Arc machine registration and connection status