From 6d9f2c57ab7edb9acd1017cd83cfd6ad60b7e99b Mon Sep 17 00:00:00 2001 From: abtreece Date: Fri, 8 May 2026 10:42:01 -0500 Subject: [PATCH] refactor: split confd run orchestration --- cmd/confd/cli.go | 306 +++++++++++++++++++++++++----------------- cmd/confd/cli_test.go | 199 +++++++++++++++++++++++++++ 2 files changed, 379 insertions(+), 126 deletions(-) diff --git a/cmd/confd/cli.go b/cmd/confd/cli.go index 462ab4773..8d63d55be 100644 --- a/cmd/confd/cli.go +++ b/cmd/confd/cli.go @@ -376,105 +376,175 @@ func (i *IMDSCmd) Run(cli *CLI) error { // run is the shared execution function for all backends func run(cli *CLI, backendCfg backends.Config) error { - // Load TOML config file if it exists (for defaults) - if err := loadConfigFile(cli, &backendCfg); err != nil { + resolvedBackendCfg, err := buildBackendConfig(cli, backendCfg) + if err != nil { return err } - applyBackendDefaults(&backendCfg) - // Process environment variables + configureLogging(cli) + + // Check-config mode: validate configuration and exit (no backend needed) + if cli.CheckConfig { + return template.ValidateConfig(cli.ConfDir, cli.Resource) + } + + // Validate mode: validate templates and exit (no backend needed) + if cli.Validate { + return template.ValidateTemplates(cli.ConfDir, cli.Resource, cli.MockData) + } + + if err := applySRVDiscovery(cli, &resolvedBackendCfg); err != nil { + return err + } + + log.Info("Starting confd") + log.Info("Backend set to %s", resolvedBackendCfg.Backend) + + template.InitTemplateCache(cli.TemplateCache, cli.StatCacheTTL) + + storeClient, err := backends.New(resolvedBackendCfg) + if err != nil { + return err + } + + storeClient, metricsServer := startObservability(cli.MetricsAddr, resolvedBackendCfg.Backend, storeClient) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tmplCfg, err := buildTemplateConfig(cli, storeClient, ctx) + if err != nil { + return err + } + + // Preflight mode: run connectivity checks and exit + if cli.Preflight { + return template.Preflight(tmplCfg) + } + + // One-time mode + if cli.Onetime { + if err := template.Process(tmplCfg); err != nil { + return err + } + return nil + } + + // Continuous mode with processor + channels := newProcessorChannels() + + // Create reload manager for SIGHUP handling + reloadMgr := service.NewReloadManager() + reloadChan := reloadMgr.Subscribe() + + processor := buildProcessor(cli, tmplCfg, channels, reloadChan) + go processor.Process() + + // Create shutdown manager for graceful shutdown coordination + shutdownMgr := service.NewShutdownManager(cli.ShutdownTimeout, metricsServer, storeClient) + + // Create systemd notifier and start watchdog if enabled + systemdNotifier := service.NewSystemdNotifier(cli.SystemdNotify, cli.WatchdogInterval) + systemdNotifier.StartWatchdog(ctx) + + // Notify systemd that we're ready + if err := systemdNotifier.NotifyReady(); err != nil { + log.Warning("Failed to notify systemd ready: %v", err) + } + + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + defer signal.Stop(signalChan) + + return superviseRuntime(cancel, channels, signalChan, reloadMgr, shutdownMgr, systemdNotifier) +} + +func buildBackendConfig(cli *CLI, backendCfg backends.Config) (backends.Config, error) { + if err := loadConfigFile(cli, &backendCfg); err != nil { + return backendCfg, err + } + applyBackendDefaults(&backendCfg) processEnv(&backendCfg) + applyConnectionSettings(cli, &backendCfg) + return backendCfg, nil +} - // Apply CLI timeout/retry config to backend config +func applyConnectionSettings(cli *CLI, backendCfg *backends.Config) { backendCfg.DialTimeout = cli.DialTimeout backendCfg.ReadTimeout = cli.ReadTimeout backendCfg.WriteTimeout = cli.WriteTimeout backendCfg.RetryMaxAttempts = cli.RetryMaxAttempts backendCfg.RetryBaseDelay = cli.RetryBaseDelay backendCfg.RetryMaxDelay = cli.RetryMaxDelay +} - // Set up logging +func configureLogging(cli *CLI) { if cli.LogLevel != "" { log.SetLevel(cli.LogLevel) } if cli.LogFormat != "" { log.SetFormat(cli.LogFormat) } +} - // Check-config mode: validate configuration and exit (no backend needed) - if cli.CheckConfig { - return template.ValidateConfig(cli.ConfDir, cli.Resource) +func applySRVDiscovery(cli *CLI, backendCfg *backends.Config) error { + if cli.SRVDomain != "" && cli.SRVRecord == "" { + cli.SRVRecord = fmt.Sprintf("_%s._tcp.%s.", backendCfg.Backend, cli.SRVDomain) } - - // Validate mode: validate templates and exit (no backend needed) - if cli.Validate { - return template.ValidateTemplates(cli.ConfDir, cli.Resource, cli.MockData) + if backendCfg.Backend == "env" || cli.SRVRecord == "" { + return nil } - // Handle SRV record discovery - if cli.SRVDomain != "" && cli.SRVRecord == "" { - cli.SRVRecord = fmt.Sprintf("_%s._tcp.%s.", backendCfg.Backend, cli.SRVDomain) + log.Info("SRV record set to %s", cli.SRVRecord) + srvNodes, err := getBackendNodesFromSRV(cli.SRVRecord) + if err != nil { + return fmt.Errorf("cannot get nodes from SRV records: %w", err) } - if backendCfg.Backend != "env" && cli.SRVRecord != "" { - log.Info("SRV record set to %s", cli.SRVRecord) - srvNodes, err := getBackendNodesFromSRV(cli.SRVRecord) - if err != nil { - return fmt.Errorf("cannot get nodes from SRV records: %w", err) + if backendCfg.Backend == "etcd" { + for i, v := range srvNodes { + srvNodes[i] = backendCfg.Scheme + "://" + v } - if backendCfg.Backend == "etcd" { - for i, v := range srvNodes { - srvNodes[i] = backendCfg.Scheme + "://" + v - } - } - backendCfg.BackendNodes = srvNodes } + backendCfg.BackendNodes = srvNodes + return nil +} - log.Info("Starting confd") - log.Info("Backend set to %s", backendCfg.Backend) - - // Initialize template cache - template.InitTemplateCache(cli.TemplateCache, cli.StatCacheTTL) - - // Create store client - storeClient, err := backends.New(backendCfg) - if err != nil { - return err +func startObservability(addr, backendName string, storeClient backends.StoreClient) (backends.StoreClient, *http.Server) { + if addr == "" { + return storeClient, nil } - // Start metrics server if configured - var metricsServer *http.Server - if cli.MetricsAddr != "" { - metrics.Initialize() - storeClient = metrics.WrapStoreClient(storeClient, backendCfg.Backend) - mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.HandlerFor(metrics.Registry, promhttp.HandlerOpts{})) - mux.HandleFunc("/health", metrics.HealthHandler(storeClient)) - mux.HandleFunc("/ready", metrics.ReadyHandler(storeClient)) - mux.HandleFunc("/ready/detailed", metrics.ReadyDetailedHandler(storeClient)) - metricsServer = &http.Server{ - Addr: cli.MetricsAddr, - Handler: mux, - ReadHeaderTimeout: 10 * time.Second, + metrics.Initialize() + wrappedClient := metrics.WrapStoreClient(storeClient, backendName) + metricsServer := buildMetricsServer(addr, wrappedClient) + go func() { + log.Info("Starting metrics server on %s", addr) + if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Error("Metrics server error: %v", err) } - go func() { - log.Info("Starting metrics server on %s", cli.MetricsAddr) - if err := metricsServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Error("Metrics server error: %v", err) - } - }() - } + }() + return wrappedClient, metricsServer +} - // Create root context for cancellation - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +func buildMetricsServer(addr string, storeClient backends.StoreClient) *http.Server { + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.HandlerFor(metrics.Registry, promhttp.HandlerOpts{})) + mux.HandleFunc("/health", metrics.HealthHandler(storeClient)) + mux.HandleFunc("/ready", metrics.ReadyHandler(storeClient)) + mux.HandleFunc("/ready/detailed", metrics.ReadyDetailedHandler(storeClient)) + return &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } +} - // Parse failure mode +func buildTemplateConfig(cli *CLI, storeClient backends.StoreClient, ctx context.Context) (template.Config, error) { failureMode, err := template.ParseFailureMode(cli.FailureMode) if err != nil { - return err + return template.Config{}, err } - // Build template config tmplCfg := template.Config{ ConfDir: cli.ConfDir, ConfigDir: filepath.Join(cli.ConfDir, "conf.d"), @@ -496,88 +566,61 @@ func run(cli *CLI, backendCfg backends.Config) error { FailureMode: failureMode, } - // Parse watch mode duration flags if cli.DebounceStr != "" { d, err := time.ParseDuration(cli.DebounceStr) if err != nil { - return fmt.Errorf("invalid debounce duration %q: %w", cli.DebounceStr, err) + return template.Config{}, fmt.Errorf("invalid debounce duration %q: %w", cli.DebounceStr, err) } tmplCfg.Debounce = d } if cli.BatchIntervalStr != "" { d, err := time.ParseDuration(cli.BatchIntervalStr) if err != nil { - return fmt.Errorf("invalid batch-interval duration %q: %w", cli.BatchIntervalStr, err) + return template.Config{}, fmt.Errorf("invalid batch-interval duration %q: %w", cli.BatchIntervalStr, err) } tmplCfg.BatchInterval = d } - // Preflight mode: run connectivity checks and exit - if cli.Preflight { - return template.Preflight(tmplCfg) - } - - // One-time mode - if cli.Onetime { - if err := template.Process(tmplCfg); err != nil { - return err - } - return nil - } + return tmplCfg, nil +} - // Continuous mode with processor - stopChan := make(chan bool) - doneChan := make(chan bool) - errChan := make(chan error, 10) +type processorChannels struct { + stop chan bool + done chan bool + err chan error +} - // Create reload manager for SIGHUP handling - reloadMgr := service.NewReloadManager() - reloadChan := reloadMgr.Subscribe() +func newProcessorChannels() processorChannels { + return processorChannels{ + stop: make(chan bool), + done: make(chan bool), + err: make(chan error, 10), + } +} - var processor template.Processor +func buildProcessor(cli *CLI, tmplCfg template.Config, channels processorChannels, reloadChan <-chan struct{}) template.Processor { if cli.Watch { if tmplCfg.BatchInterval > 0 { - // Use batch processor when --batch-interval is specified log.Info("Batch processing enabled with interval %v", tmplCfg.BatchInterval) - processor = template.BatchWatchProcessor(tmplCfg, stopChan, doneChan, errChan, reloadChan) - } else { - processor = template.WatchProcessor(tmplCfg, stopChan, doneChan, errChan, reloadChan) + return template.BatchWatchProcessor(tmplCfg, channels.stop, channels.done, channels.err, reloadChan) } - } else { - processor = template.IntervalProcessor(tmplCfg, stopChan, doneChan, errChan, cli.Interval, reloadChan) - } - - go processor.Process() - - // Create shutdown manager for graceful shutdown coordination - shutdownMgr := service.NewShutdownManager(cli.ShutdownTimeout, metricsServer, storeClient) - - // Create systemd notifier and start watchdog if enabled - systemdNotifier := service.NewSystemdNotifier(cli.SystemdNotify, cli.WatchdogInterval) - systemdNotifier.StartWatchdog(ctx) - - // Notify systemd that we're ready - if err := systemdNotifier.NotifyReady(); err != nil { - log.Warning("Failed to notify systemd ready: %v", err) - } - - shutdown := func() error { - if err := shutdownMgr.Shutdown(context.Background()); err != nil { - log.Error("Shutdown error: %v", err) - return err - } - if err := template.CloseAllCachedClients(); err != nil { - log.Warning("Error closing per-resource backend clients: %v", err) - } - return nil + return template.WatchProcessor(tmplCfg, channels.stop, channels.done, channels.err, reloadChan) } + return template.IntervalProcessor(tmplCfg, channels.stop, channels.done, channels.err, cli.Interval, reloadChan) +} - signalChan := make(chan os.Signal, 1) - signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) +func superviseRuntime( + cancel context.CancelFunc, + channels processorChannels, + signalChan <-chan os.Signal, + reloadMgr *service.ReloadManager, + shutdownMgr *service.ShutdownManager, + systemdNotifier *service.SystemdNotifier, +) error { var stopOnce sync.Once for { select { - case err := <-errChan: + case err := <-channels.err: log.Error("%s", err.Error()) case s := <-signalChan: switch s { @@ -599,16 +642,27 @@ func run(cli *CLI, backendCfg backends.Config) error { log.Warning("Failed to notify systemd stopping: %v", err) } cancel() // Cancel context to signal all goroutines - stopOnce.Do(func() { close(stopChan) }) - <-doneChan - return shutdown() + stopOnce.Do(func() { close(channels.stop) }) + <-channels.done + return shutdownRuntime(shutdownMgr) } - case <-doneChan: - return shutdown() + case <-channels.done: + return shutdownRuntime(shutdownMgr) } } } +func shutdownRuntime(shutdownMgr *service.ShutdownManager) error { + if err := shutdownMgr.Shutdown(context.Background()); err != nil { + log.Error("Shutdown error: %v", err) + return err + } + if err := template.CloseAllCachedClients(); err != nil { + log.Warning("Error closing per-resource backend clients: %v", err) + } + return nil +} + func applyBackendDefaults(cfg *backends.Config) { if len(cfg.BackendNodes) > 0 { return diff --git a/cmd/confd/cli_test.go b/cmd/confd/cli_test.go index 989fcdf4e..8dbdcb031 100644 --- a/cmd/confd/cli_test.go +++ b/cmd/confd/cli_test.go @@ -1,7 +1,10 @@ package main import ( + "context" "fmt" + "net/http" + "net/http/httptest" "os" "path/filepath" "strings" @@ -10,6 +13,8 @@ import ( "github.com/abtreece/confd/pkg/backends" "github.com/abtreece/confd/pkg/log" + "github.com/abtreece/confd/pkg/metrics" + "github.com/abtreece/confd/pkg/service" "github.com/alecthomas/kong" ) @@ -674,6 +679,200 @@ func TestProcessEnvDoesNotOverride(t *testing.T) { } } +type cliMockStoreClient struct { + closed bool +} + +func (m *cliMockStoreClient) GetValues(_ context.Context, _ []string) (map[string]string, error) { + return map[string]string{}, nil +} + +func (m *cliMockStoreClient) WatchPrefix(_ context.Context, _ string, _ []string, _ uint64, _ chan bool) (uint64, error) { + return 0, nil +} + +func (m *cliMockStoreClient) HealthCheck(_ context.Context) error { + return nil +} + +func (m *cliMockStoreClient) Close() error { + m.closed = true + return nil +} + +func TestBuildBackendConfigAppliesDefaultsAndConnectionSettings(t *testing.T) { + cli := &CLI{ + ConfigFile: filepath.Join(t.TempDir(), "missing.toml"), + DialTimeout: 2 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 4 * time.Second, + RetryMaxAttempts: 9, + RetryBaseDelay: 250 * time.Millisecond, + RetryMaxDelay: 7 * time.Second, + } + + cfg, err := buildBackendConfig(cli, backends.Config{Backend: "consul"}) + if err != nil { + t.Fatalf("buildBackendConfig() unexpected error: %v", err) + } + + if len(cfg.BackendNodes) != 1 || cfg.BackendNodes[0] != "127.0.0.1:8500" { + t.Fatalf("BackendNodes = %v, want default consul node", cfg.BackendNodes) + } + if cfg.DialTimeout != cli.DialTimeout || cfg.ReadTimeout != cli.ReadTimeout || cfg.WriteTimeout != cli.WriteTimeout { + t.Fatalf("timeouts were not copied from CLI: %#v", cfg) + } + if cfg.RetryMaxAttempts != cli.RetryMaxAttempts || cfg.RetryBaseDelay != cli.RetryBaseDelay || cfg.RetryMaxDelay != cli.RetryMaxDelay { + t.Fatalf("retry settings were not copied from CLI: %#v", cfg) + } +} + +func TestRunCheckConfigSkipsSRVDiscovery(t *testing.T) { + tmpDir := t.TempDir() + if err := os.Mkdir(filepath.Join(tmpDir, "conf.d"), 0755); err != nil { + t.Fatalf("failed to create conf.d: %v", err) + } + + cli := &CLI{ + ConfDir: tmpDir, + ConfigFile: filepath.Join(tmpDir, "missing.toml"), + CheckConfig: true, + SRVRecord: "_invalid._tcp.invalid.", + } + + err := run(cli, backends.Config{Backend: "consul"}) + if err != nil { + t.Fatalf("run() with --check-config returned error: %v", err) + } +} + +func TestBuildTemplateConfig(t *testing.T) { + client := &cliMockStoreClient{} + ctx := context.Background() + cli := &CLI{ + ConfDir: "/tmp/confd", + Noop: true, + Prefix: "/app", + SyncOnly: true, + KeepStageFile: true, + Diff: true, + DiffContext: 5, + Color: true, + BackendTimeout: 10 * time.Second, + CheckCmdTimeout: 11 * time.Second, + ReloadCmdTimeout: 12 * time.Second, + WatchErrorBackoff: 13 * time.Second, + PreflightTimeout: 14 * time.Second, + FailureMode: "fail-fast", + DebounceStr: "150ms", + BatchIntervalStr: "2s", + } + + cfg, err := buildTemplateConfig(cli, client, ctx) + if err != nil { + t.Fatalf("buildTemplateConfig() unexpected error: %v", err) + } + + if cfg.ConfigDir != "/tmp/confd/conf.d" || cfg.TemplateDir != "/tmp/confd/templates" { + t.Fatalf("unexpected template paths: config=%q template=%q", cfg.ConfigDir, cfg.TemplateDir) + } + if cfg.StoreClient != client || cfg.Ctx != ctx { + t.Fatal("store client or context was not propagated") + } + if cfg.Debounce != 150*time.Millisecond || cfg.BatchInterval != 2*time.Second { + t.Fatalf("watch durations not parsed: debounce=%v batch=%v", cfg.Debounce, cfg.BatchInterval) + } +} + +func TestBuildTemplateConfigInvalidDurations(t *testing.T) { + tests := []struct { + name string + cli CLI + want string + }{ + { + name: "debounce", + cli: CLI{FailureMode: "best-effort", DebounceStr: "bad"}, + want: "invalid debounce duration", + }, + { + name: "batch interval", + cli: CLI{FailureMode: "best-effort", BatchIntervalStr: "bad"}, + want: "invalid batch-interval duration", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := buildTemplateConfig(&tc.cli, &cliMockStoreClient{}, context.Background()) + if err == nil || !strings.Contains(err.Error(), tc.want) { + t.Fatalf("buildTemplateConfig() error = %v, want %q", err, tc.want) + } + }) + } +} + +func TestStartObservabilityDisabled(t *testing.T) { + client := &cliMockStoreClient{} + wrapped, server := startObservability("", "env", client) + if wrapped != client { + t.Fatal("disabled observability should return original client") + } + if server != nil { + t.Fatal("disabled observability should not create a server") + } +} + +func TestBuildMetricsServerHandlers(t *testing.T) { + metrics.Initialize() + client := &cliMockStoreClient{} + server := buildMetricsServer(":0", client) + + for _, path := range []string{"/health", "/ready", "/ready/detailed", "/metrics"} { + req := httptest.NewRequest(http.MethodGet, path, nil) + rec := httptest.NewRecorder() + server.Handler.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("%s returned status %d, want %d", path, rec.Code, http.StatusOK) + } + } +} + +func TestSuperviseRuntimeShutdownOnSignal(t *testing.T) { + client := &cliMockStoreClient{} + channels := processorChannels{ + stop: make(chan bool), + done: make(chan bool), + err: make(chan error, 1), + } + signalChan := make(chan os.Signal, 1) + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + <-channels.stop + close(channels.done) + }() + + signalChan <- os.Interrupt + err := superviseRuntime( + cancel, + channels, + signalChan, + service.NewReloadManager(), + service.NewShutdownManager(time.Second, nil, client), + service.NewSystemdNotifier(false, 0), + ) + if err != nil { + t.Fatalf("superviseRuntime() unexpected error: %v", err) + } + if ctx.Err() == nil { + t.Fatal("context was not cancelled") + } + if !client.closed { + t.Fatal("store client was not closed") + } +} + // clearCONFDEnvVars clears all CONFD_* and backend-specific env vars that could // interfere with Kong parsing, saving original values for restoration via t.Cleanup. func clearCONFDEnvVars(t *testing.T) {