diff --git a/duckdbservice/service.go b/duckdbservice/service.go
index e714a71..fb9f50a 100644
--- a/duckdbservice/service.go
+++ b/duckdbservice/service.go
@@ -25,6 +25,9 @@ import (
 	"google.golang.org/grpc/credentials"
 )
 
+var bootstrapBundledExtensions = server.BootstrapBundledExtensions // test seam: swapped out to simulate a bootstrap failure
+var exitProcess = os.Exit                                          // test seam: swapped out so tests can observe the exit code
+
 // DuckDBService is a standalone Arrow Flight SQL service backed by DuckDB.
 type DuckDBService struct {
 	cfg ServiceConfig
@@ -198,6 +201,14 @@ func (p *SessionPool) reapLoop() {
 func Run(cfg ServiceConfig) {
 	svc := NewDuckDBService(cfg)
 
+	if err := bootstrapBundledExtensions(cfg.ServerConfig.DataDir); err != nil {
+		slog.Error("Failed to bootstrap bundled DuckDB extensions.",
+			"source", "/app/extensions", // NOTE(review): hard-coded; presumably the bundled extension root used by the server package — confirm they agree
+			"extension_directory", cfg.ServerConfig.DataDir+"/extensions", // log-only value, built by plain string concatenation
+			"error", err)
+		exitProcess(1) // bootstrap failure is fatal; this is os.Exit unless a test replaced exitProcess
+	}
+
 	// Pre-warm the DuckDB instance (load extensions, attach DuckLake)
 	// in the background so we don't block the gRPC server from starting.
 	// This ensures that waitForWorker doesn't time out during spawn.
diff --git a/duckdbservice/service_test.go b/duckdbservice/service_test.go
index b0ae3c1..361f608 100644
--- a/duckdbservice/service_test.go
+++ b/duckdbservice/service_test.go
@@ -3,11 +3,16 @@ package duckdbservice
 import (
 	"context"
 	"database/sql"
+	"errors"
 	"testing"
 
 	_ "github.com/duckdb/duckdb-go/v2"
 )
 
+type exitPanic struct {
+	code int
+}
+
 func TestInitSearchPath(t *testing.T) {
 	db, err := sql.Open("duckdb", "")
 	if err != nil {
@@ -57,3 +62,39 @@ func TestInitSearchPath(t *testing.T) {
 		}
 	})
 }
+
+func TestRunExitsWhenBundledExtensionBootstrapFails(t *testing.T) {
+	prevBootstrap := bootstrapBundledExtensions
+	prevExit := exitProcess
+	defer func() {
+		bootstrapBundledExtensions = prevBootstrap
+		exitProcess = prevExit
+	}()
+
+	bootstrapBundledExtensions = func(string) error {
+		return errors.New("boom")
+	}
+
+	exitCode := -1
+	exitProcess = func(code int) {
+		exitCode = code
+		panic(exitPanic{code: code})
+	}
+
+	defer func() {
+		r := recover()
+		p, ok := r.(exitPanic)
+		if !ok {
+			t.Fatalf("expected exit panic, got %v", r)
+		}
+		if p.code != 1 {
+			t.Fatalf("expected exit code 1, got %d", p.code)
+		}
+		if exitCode != 1 {
+			t.Fatalf("expected exitProcess to be called with 1, got %d", exitCode)
+		}
+	}()
+
+	Run(ServiceConfig{})
+	t.Fatal("expected Run to exit")
+}
diff --git a/server/bundled_extensions_test.go b/server/bundled_extensions_test.go
index a2c18f3..071f7f8 100644
--- a/server/bundled_extensions_test.go
+++ b/server/bundled_extensions_test.go
@@ -1,10 +1,13 @@
 package server
 
 import (
+	"database/sql"
 	"os"
 	"path/filepath"
 	"sync"
 	"testing"
+
+	_ "github.com/duckdb/duckdb-go/v2"
 )
 
 func TestSeedBundledExtensionsCopiesMissingFiles(t *testing.T) {
@@ -226,3 +229,59 @@ func TestBootstrapBundledExtensionsRunsOncePerExtensionDirectory(t *testing.T) {
 		t.Fatalf("expected bootstrap to run once, got %q", string(got))
 	}
 }
+
+func TestSetExtensionDirectorySetsPathAfterBootstrap(t *testing.T) {
+	bundledRoot := t.TempDir()
+	dataDir := t.TempDir()
+
+	srcDir := filepath.Join(bundledRoot, "v1.5.2", "linux_arm64")
+	if err := os.MkdirAll(srcDir, 0o755); err != nil {
+		t.Fatalf("mkdir src: %v", err)
+	}
+	srcExt := filepath.Join(srcDir, "postgres_scanner.duckdb_extension")
+	if err := os.WriteFile(srcExt, []byte("nightly"), 0o644); err != nil {
+		t.Fatalf("write src extension: %v", err)
+	}
+
+	prevBundledRoot := bundledDuckDBExtensionsDir
+	bundledDuckDBExtensionsDir = bundledRoot
+	defer func() { bundledDuckDBExtensionsDir = prevBundledRoot }()
+
+	bundledExtensionBootstrap = struct {
+		mu     sync.Mutex
+		byPath map[string]error
+	}{}
+
+	db, err := sql.Open("duckdb", ":memory:")
+	if err != nil {
+		t.Fatalf("open duckdb: %v", err)
+	}
+	defer func() { _ = db.Close() }()
+
+	if err := bootstrapBundledExtensions(dataDir); err != nil {
+		t.Fatalf("bootstrapBundledExtensions: %v", err)
+	}
+
+	if err := setExtensionDirectory(db, dataDir); err != nil {
+		t.Fatalf("setExtensionDirectory: %v", err)
+	}
+
+	var gotExtDir string
+	if err := db.QueryRow("SELECT current_setting('extension_directory')").Scan(&gotExtDir); err != nil {
+		t.Fatalf("query extension_directory: %v", err)
+	}
+
+	wantExtDir := filepath.Join(dataDir, "extensions")
+	if gotExtDir != wantExtDir {
+		t.Fatalf("extension_directory = %q, want %q", gotExtDir, wantExtDir)
+	}
+
+	dstExt := filepath.Join(wantExtDir, "v1.5.2", "linux_arm64", "postgres_scanner.duckdb_extension")
+	got, err := os.ReadFile(dstExt)
+	if err != nil {
+		t.Fatalf("read dst extension: %v", err)
+	}
+	if string(got) != "nightly" {
+		t.Fatalf("expected seeded extension to match bundled contents, got %q", string(got))
+	}
+}
diff --git a/server/checkpoint.go b/server/checkpoint.go
index 90386f5..01a2c2d 100644
--- a/server/checkpoint.go
+++ b/server/checkpoint.go
@@ -4,7 +4,6 @@ import (
 	"database/sql"
 	"fmt"
 	"log/slog"
-	"path/filepath"
 	"sync"
 	"time"
 )
@@ -32,13 +31,12 @@ func NewDuckLakeCheckpointer(cfg Config) (*DuckLakeCheckpointer, error) {
 		return nil, fmt.Errorf("checkpoint: open duckdb: %w", err)
 	}
 
-	extDir := filepath.Join(cfg.DataDir, "extensions")
-	if _, err := db.Exec(fmt.Sprintf("SET extension_directory = '%s'", extDir)); err != nil {
+	if err := setExtensionDirectory(db, cfg.DataDir); err != nil {
 		_ = db.Close()
-		return nil, fmt.Errorf("checkpoint: set extension_directory: %w", err)
+		return nil, fmt.Errorf("checkpoint: set extension directory: %w", err)
 	}
 
-	if _, err := db.Exec("INSTALL ducklake; LOAD ducklake"); err != nil {
+	if err := LoadExtensions(db, []string{"ducklake"}); err != nil {
 		_ = db.Close()
 		return nil, fmt.Errorf("checkpoint: load ducklake: %w", err)
 	}
diff --git a/server/querylog.go b/server/querylog.go
index 0d3633c..09ea5c2 100644
--- a/server/querylog.go
+++ b/server/querylog.go
@@ -7,7 +7,6 @@ import (
 	"hash/fnv"
 	"log/slog"
 	"net"
-	"path/filepath"
 	"regexp"
 	"strings"
 	"sync"
@@ -69,15 +68,12 @@ func NewQueryLogger(cfg Config) (*QueryLogger, error) {
 		return nil, fmt.Errorf("querylog: open duckdb: %w", err)
 	}
 
-	// Set extension directory under DataDir so DuckDB doesn't rely on $HOME/.duckdb
-	extDir := filepath.Join(cfg.DataDir, "extensions")
-	if _, err := db.Exec(fmt.Sprintf("SET extension_directory = '%s'", extDir)); err != nil {
+	if err := setExtensionDirectory(db, cfg.DataDir); err != nil {
 		_ = db.Close()
-		return nil, fmt.Errorf("querylog: set extension_directory: %w", err)
+		return nil, fmt.Errorf("querylog: set extension directory: %w", err)
 	}
 
-	// Load ducklake extension
-	if _, err := db.Exec("INSTALL ducklake; LOAD ducklake"); err != nil {
+	if err := LoadExtensions(db, []string{"ducklake"}); err != nil {
 		_ = db.Close()
 		return nil, fmt.Errorf("querylog: load ducklake: %w", err)
 	}
diff --git a/server/server.go b/server/server.go
index d062869..e3b8cfa 100644
--- a/server/server.go
+++ b/server/server.go
@@ -67,6 +67,34 @@ func bootstrapBundledExtensions(dataDir string) error {
 	return err
 }
 
+// BootstrapBundledExtensions eagerly seeds bundled extension binaries into the
+// configured extension_directory cache once per data directory.
+func BootstrapBundledExtensions(dataDir string) error {
+	return bootstrapBundledExtensions(dataDir)
+}
+
+// setExtensionDirectory points db's extension_directory setting at
+// <dataDir>/extensions. Single quotes in the path are doubled because the
+// path is embedded in a SQL string literal; without that a data directory
+// such as /srv/o'brien would produce a malformed SET statement.
+func setExtensionDirectory(db *sql.DB, dataDir string) error {
+	extDir := filepath.Join(dataDir, "extensions")
+
+	literal := make([]byte, 0, len(extDir))
+	for i := 0; i < len(extDir); i++ {
+		if extDir[i] == '\'' {
+			literal = append(literal, '\'')
+		}
+		literal = append(literal, extDir[i])
+	}
+
+	if _, err := db.Exec(fmt.Sprintf("SET extension_directory = '%s'", literal)); err != nil {
+		return fmt.Errorf("set extension_directory %s: %w", extDir, err)
+	}
+
+	return nil
+}
+
 // passwordPattern matches password= or password: with quoted or unquoted values.
 var passwordPattern = regexp.MustCompile(`(?i)(password\s*[=:]\s*)("[^"]*"|[^\s"]+)`)
 
@@ -500,7 +528,7 @@ func New(cfg Config) (*Server, error) {
 	}
 
 	if err := bootstrapBundledExtensions(cfg.DataDir); err != nil {
-		slog.Warn("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(cfg.DataDir, "extensions"), "error", err)
+		return nil, fmt.Errorf("failed to bootstrap bundled DuckDB extensions: %w", err)
 	}
 
 	// Initialize query logger (non-fatal on error)
@@ -896,13 +924,9 @@ func openBaseDB(cfg Config, username string) (*sql.DB, error) {
 		slog.Debug("Set DuckDB temp_directory.", "temp_directory", tempDir)
 	}
 
-	// Set extension directory under DataDir so DuckDB doesn't rely on $HOME/.duckdb
-	// for autoloading/installing extensions.
-	extDir := filepath.Join(cfg.DataDir, "extensions")
-	if _, err := db.Exec(fmt.Sprintf("SET extension_directory = '%s'", extDir)); err != nil {
-		slog.Warn("Failed to set DuckDB extension_directory.", "extension_directory", extDir, "error", err)
-	} else {
-		slog.Debug("Set DuckDB extension_directory.", "extension_directory", extDir)
+	if err := setExtensionDirectory(db, cfg.DataDir); err != nil {
+		_ = db.Close()
+		return nil, fmt.Errorf("failed to configure extension_directory: %w", err)
 	}
 
 	// Load configured extensions
@@ -1200,11 +1224,15 @@ func LoadExtensions(db *sql.DB, extensions []string) error {
 	for _, ext := range extensions {
 		name, installCmd := parseExtensionName(ext)
 
-		// First install the extension (downloads if needed)
-		if _, err := db.Exec("INSTALL " + installCmd); err != nil {
-			slog.Warn("Failed to install extension.", "extension", installCmd, "error", err)
-			lastErr = err
-			continue
+		if shouldInstallExtension(name) {
+			// First install the extension (downloads if needed). Bundled extensions
+			// are preseeded into the extension cache and INSTALL can overwrite that
+			// bundled binary with DuckDB's repository copy.
+			if _, err := db.Exec("INSTALL " + installCmd); err != nil {
+				slog.Warn("Failed to install extension.", "extension", installCmd, "error", err)
+				lastErr = err
+				continue
+			}
 		}
 
 		// Then load it into the current session
@@ -1220,6 +1248,21 @@ func LoadExtensions(db *sql.DB, extensions []string) error {
 	return lastErr
 }
 
+// shouldInstallExtension reports whether LoadExtensions should run INSTALL for
+// name; extensions with a bundled binary are only loaded.
+func shouldInstallExtension(name string) bool {
+	return !hasBundledExtensionBinary(name)
+}
+
+// hasBundledExtensionBinary reports whether <name>.duckdb_extension exists two
+// directory levels below bundledDuckDBExtensionsDir.
+// NOTE(review): any matching subdirectory counts, not only the one for the
+// running DuckDB version/platform — confirm that is intended.
+func hasBundledExtensionBinary(name string) bool {
+	matches, err := filepath.Glob(filepath.Join(bundledDuckDBExtensionsDir, "*", "*", name+".duckdb_extension"))
+	return err == nil && len(matches) > 0
+}
+
 func boolPtr(v bool) *bool { return &v }
 
 func duckLakeDisableMetadataThreadLocalCacheEnabled(dlCfg DuckLakeConfig) bool {
diff --git a/server/server_test.go b/server/server_test.go
index 357a849..0ec2df3 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -112,6 +112,44 @@ func TestParseExtensionName(t *testing.T) {
 	}
 }
 
+func TestNewFailsWhenBundledExtensionBootstrapFails(t *testing.T) {
+	bundledRoot := filepath.Join(t.TempDir(), "extensions-file")
+	if err := os.WriteFile(bundledRoot, []byte("not-a-directory"), 0o644); err != nil {
+		t.Fatalf("write bundled root file: %v", err)
+	}
+
+	prevBundledRoot := bundledDuckDBExtensionsDir
+	bundledDuckDBExtensionsDir = bundledRoot
+	defer func() { bundledDuckDBExtensionsDir = prevBundledRoot }()
+
+	bundledExtensionBootstrap = struct {
+		mu     sync.Mutex
+		byPath map[string]error
+	}{}
+
+	certDir := t.TempDir()
+	certFile := filepath.Join(certDir, "server.crt")
+	keyFile := filepath.Join(certDir, "server.key")
+	if err := generateSelfSignedCert(certFile, keyFile); err != nil {
+		t.Fatalf("generateSelfSignedCert: %v", err)
+	}
+
+	_, err := New(Config{
+		Host:        "127.0.0.1",
+		Port:        5432,
+		DataDir:     t.TempDir(),
+		Users:       map[string]string{"postgres": "postgres"},
+		TLSCertFile: certFile,
+		TLSKeyFile:  keyFile,
+	})
+	if err == nil {
+		t.Fatal("expected bootstrap failure")
+	}
+	if !strings.Contains(err.Error(), "bootstrap bundled DuckDB extensions") {
+		t.Fatalf("expected bootstrap error, got %v", err)
+	}
+}
+
 func TestNeedsCredentialRefresh(t *testing.T) {
 	tests := []struct {
 		name string
diff --git a/server/shell.go b/server/shell.go
index 9619b9b..f61fe98 100644
--- a/server/shell.go
+++ b/server/shell.go
@@ -20,9 +20,9 @@ import (
 func RunShell(cfg Config) {
 	sem := make(chan struct{}, 1)
 	if err := bootstrapBundledExtensions(cfg.DataDir); err != nil {
-		slog.Warn("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(cfg.DataDir, "extensions"), "error", err)
+		slog.Error("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(cfg.DataDir, "extensions"), "error", err)
+		os.Exit(1)
 	}
-
 	db, err := CreateDBConnection(cfg, sem, "shell", processStartTime, processVersion)
 	if err != nil {
 		slog.Error("Failed to create database connection.", "error", err)
diff --git a/server/worker.go b/server/worker.go
index 624d4d0..69c6a6d 100644
--- a/server/worker.go
+++ b/server/worker.go
@@ -295,7 +295,10 @@ func runChildWorker(tcpConn *net.TCPConn, cfg *ChildConfig) int {
 	}
 
 	if err := bootstrapBundledExtensions(serverCfg.DataDir); err != nil {
-		slog.Warn("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(serverCfg.DataDir, "extensions"), "error", err)
+		slog.Error("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(serverCfg.DataDir, "extensions"), "error", err)
+		_ = writeErrorResponse(writer, "FATAL", "58000", fmt.Sprintf("failed to prepare bundled extensions: %v", err))
+		_ = writer.Flush() // best effort: the worker returns ExitError whether or not the client saw the message
+		return ExitError
 	}
 
 	// Create DuckDB connection