Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions duckdbservice/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import (
"google.golang.org/grpc/credentials"
)

var bootstrapBundledExtensions = server.BootstrapBundledExtensions
var exitProcess = os.Exit

// DuckDBService is a standalone Arrow Flight SQL service backed by DuckDB.
type DuckDBService struct {
cfg ServiceConfig
Expand Down Expand Up @@ -198,6 +201,14 @@ func (p *SessionPool) reapLoop() {
func Run(cfg ServiceConfig) {
svc := NewDuckDBService(cfg)

if err := bootstrapBundledExtensions(cfg.ServerConfig.DataDir); err != nil {
slog.Error("Failed to bootstrap bundled DuckDB extensions.",
"source", "/app/extensions",
"extension_directory", cfg.ServerConfig.DataDir+"/extensions",
"error", err)
exitProcess(1)
}

// Pre-warm the DuckDB instance (load extensions, attach DuckLake)
// in the background so we don't block the gRPC server from starting.
// This ensures that waitForWorker doesn't time out during spawn.
Expand Down
41 changes: 41 additions & 0 deletions duckdbservice/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ package duckdbservice
import (
"context"
"database/sql"
"errors"
"testing"

_ "github.com/duckdb/duckdb-go/v2"
)

// exitPanic is the panic value raised by the test's exitProcess stub. It
// carries the requested exit code so the deferred recover in the test can
// inspect it after unwinding out of Run.
type exitPanic struct {
code int
}

func TestInitSearchPath(t *testing.T) {
db, err := sql.Open("duckdb", "")
if err != nil {
Expand Down Expand Up @@ -57,3 +62,39 @@ func TestInitSearchPath(t *testing.T) {
}
})
}

// TestRunExitsWhenBundledExtensionBootstrapFails stubs the bootstrap hook so it
// fails and verifies that Run reports exit code 1 through exitProcess.
func TestRunExitsWhenBundledExtensionBootstrapFails(t *testing.T) {
	// Put the package-level hooks back once the test body has unwound.
	origBootstrap, origExit := bootstrapBundledExtensions, exitProcess
	defer func() {
		bootstrapBundledExtensions, exitProcess = origBootstrap, origExit
	}()

	bootstrapBundledExtensions = func(string) error { return errors.New("boom") }

	// The stub records the code and panics so that control leaves Run instead
	// of continuing past the exit call.
	observed := -1
	exitProcess = func(code int) {
		observed = code
		panic(exitPanic{code: code})
	}

	defer func() {
		recovered := recover()
		payload, isExit := recovered.(exitPanic)
		switch {
		case !isExit:
			t.Fatalf("expected exit panic, got %v", recovered)
		case payload.code != 1:
			t.Fatalf("expected exit code 1, got %d", payload.code)
		case observed != 1:
			t.Fatalf("expected exitProcess to be called with 1, got %d", observed)
		}
	}()

	Run(ServiceConfig{})
	t.Fatal("expected Run to exit")
}
59 changes: 59 additions & 0 deletions server/bundled_extensions_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package server

import (
"database/sql"
"os"
"path/filepath"
"sync"
"testing"

_ "github.com/duckdb/duckdb-go/v2"
)

func TestSeedBundledExtensionsCopiesMissingFiles(t *testing.T) {
Expand Down Expand Up @@ -226,3 +229,59 @@ func TestBootstrapBundledExtensionsRunsOncePerExtensionDirectory(t *testing.T) {
t.Fatalf("expected bootstrap to run once, got %q", string(got))
}
}

// TestSetExtensionDirectorySetsPathAfterBootstrap bootstraps a bundled
// extension into a fresh data directory, then checks that
// setExtensionDirectory points DuckDB at <dataDir>/extensions and that the
// bundled file is readable from that directory with its original contents.
func TestSetExtensionDirectorySetsPathAfterBootstrap(t *testing.T) {
	bundledRoot, dataDir := t.TempDir(), t.TempDir()

	const extFile = "postgres_scanner.duckdb_extension"
	platformDir := filepath.Join("v1.5.2", "linux_arm64")

	// Lay out a fake bundled extension tree: <root>/<version>/<platform>/<file>.
	bundledPlatformDir := filepath.Join(bundledRoot, platformDir)
	if err := os.MkdirAll(bundledPlatformDir, 0o755); err != nil {
		t.Fatalf("mkdir src: %v", err)
	}
	if err := os.WriteFile(filepath.Join(bundledPlatformDir, extFile), []byte("nightly"), 0o644); err != nil {
		t.Fatalf("write src extension: %v", err)
	}

	savedRoot := bundledDuckDBExtensionsDir
	defer func() { bundledDuckDBExtensionsDir = savedRoot }()
	bundledDuckDBExtensionsDir = bundledRoot

	// Reset the package-level bootstrap bookkeeping to its zero value.
	bundledExtensionBootstrap = struct {
		mu     sync.Mutex
		byPath map[string]error
	}{}

	db, err := sql.Open("duckdb", ":memory:")
	if err != nil {
		t.Fatalf("open duckdb: %v", err)
	}
	defer func() { _ = db.Close() }()

	if err := bootstrapBundledExtensions(dataDir); err != nil {
		t.Fatalf("bootstrapBundledExtensions: %v", err)
	}
	if err := setExtensionDirectory(db, dataDir); err != nil {
		t.Fatalf("setExtensionDirectory: %v", err)
	}

	wantExtDir := filepath.Join(dataDir, "extensions")

	var gotExtDir string
	row := db.QueryRow("SELECT current_setting('extension_directory')")
	if err := row.Scan(&gotExtDir); err != nil {
		t.Fatalf("query extension_directory: %v", err)
	}
	if gotExtDir != wantExtDir {
		t.Fatalf("extension_directory = %q, want %q", gotExtDir, wantExtDir)
	}

	seeded, err := os.ReadFile(filepath.Join(wantExtDir, platformDir, extFile))
	if err != nil {
		t.Fatalf("read dst extension: %v", err)
	}
	if string(seeded) != "nightly" {
		t.Fatalf("expected seeded extension to match bundled contents, got %q", string(seeded))
	}
}
8 changes: 3 additions & 5 deletions server/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"database/sql"
"fmt"
"log/slog"
"path/filepath"
"sync"
"time"
)
Expand Down Expand Up @@ -32,13 +31,12 @@ func NewDuckLakeCheckpointer(cfg Config) (*DuckLakeCheckpointer, error) {
return nil, fmt.Errorf("checkpoint: open duckdb: %w", err)
}

extDir := filepath.Join(cfg.DataDir, "extensions")
if _, err := db.Exec(fmt.Sprintf("SET extension_directory = '%s'", extDir)); err != nil {
if err := setExtensionDirectory(db, cfg.DataDir); err != nil {
_ = db.Close()
return nil, fmt.Errorf("checkpoint: set extension_directory: %w", err)
return nil, fmt.Errorf("checkpoint: set extension directory: %w", err)
}

if _, err := db.Exec("INSTALL ducklake; LOAD ducklake"); err != nil {
if err := LoadExtensions(db, []string{"ducklake"}); err != nil {
_ = db.Close()
return nil, fmt.Errorf("checkpoint: load ducklake: %w", err)
}
Expand Down
10 changes: 3 additions & 7 deletions server/querylog.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"hash/fnv"
"log/slog"
"net"
"path/filepath"
"regexp"
"strings"
"sync"
Expand Down Expand Up @@ -69,15 +68,12 @@ func NewQueryLogger(cfg Config) (*QueryLogger, error) {
return nil, fmt.Errorf("querylog: open duckdb: %w", err)
}

// Set extension directory under DataDir so DuckDB doesn't rely on $HOME/.duckdb
extDir := filepath.Join(cfg.DataDir, "extensions")
if _, err := db.Exec(fmt.Sprintf("SET extension_directory = '%s'", extDir)); err != nil {
if err := setExtensionDirectory(db, cfg.DataDir); err != nil {
_ = db.Close()
return nil, fmt.Errorf("querylog: set extension_directory: %w", err)
return nil, fmt.Errorf("querylog: set extension directory: %w", err)
}

// Load ducklake extension
if _, err := db.Exec("INSTALL ducklake; LOAD ducklake"); err != nil {
if err := LoadExtensions(db, []string{"ducklake"}); err != nil {
_ = db.Close()
return nil, fmt.Errorf("querylog: load ducklake: %w", err)
}
Expand Down
50 changes: 37 additions & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ func bootstrapBundledExtensions(dataDir string) error {
return err
}

// BootstrapBundledExtensions eagerly seeds the bundled extension binaries into
// the configured extension_directory cache, once per data directory. It is the
// exported wrapper around bootstrapBundledExtensions and returns that
// function's error unchanged.
func BootstrapBundledExtensions(dataDir string) error {
	err := bootstrapBundledExtensions(dataDir)
	return err
}

// setExtensionDirectory points DuckDB's extension_directory setting at
// <dataDir>/extensions for the given database handle.
//
// The directory is embedded in the statement as a single-quoted SQL string
// literal, so any single quote in the path is doubled first. Without that, a
// data directory such as /home/o'brien/data would terminate the literal early
// and yield a malformed (or unintentionally altered) SET statement.
//
// It returns a wrapped error naming the directory when the SET fails.
func setExtensionDirectory(db *sql.DB, dataDir string) error {
	extDir := filepath.Join(dataDir, "extensions")

	// Build the quoted literal by hand: opening quote, the path with every
	// embedded ' doubled, closing quote.
	literal := make([]byte, 0, len(extDir)+2)
	literal = append(literal, '\'')
	for i := 0; i < len(extDir); i++ {
		if extDir[i] == '\'' {
			literal = append(literal, '\'')
		}
		literal = append(literal, extDir[i])
	}
	literal = append(literal, '\'')

	if _, err := db.Exec("SET extension_directory = " + string(literal)); err != nil {
		return fmt.Errorf("set extension_directory %s: %w", extDir, err)
	}

	return nil
}

// passwordPattern matches password=<value> or password: <value> with quoted or unquoted values.
var passwordPattern = regexp.MustCompile(`(?i)(password\s*[=:]\s*)("[^"]*"|[^\s"]+)`)

Expand Down Expand Up @@ -500,7 +515,7 @@ func New(cfg Config) (*Server, error) {
}

if err := bootstrapBundledExtensions(cfg.DataDir); err != nil {
slog.Warn("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(cfg.DataDir, "extensions"), "error", err)
return nil, fmt.Errorf("failed to bootstrap bundled DuckDB extensions: %w", err)
}

// Initialize query logger (non-fatal on error)
Expand Down Expand Up @@ -896,13 +911,9 @@ func openBaseDB(cfg Config, username string) (*sql.DB, error) {
slog.Debug("Set DuckDB temp_directory.", "temp_directory", tempDir)
}

// Set extension directory under DataDir so DuckDB doesn't rely on $HOME/.duckdb
// for autoloading/installing extensions.
extDir := filepath.Join(cfg.DataDir, "extensions")
if _, err := db.Exec(fmt.Sprintf("SET extension_directory = '%s'", extDir)); err != nil {
slog.Warn("Failed to set DuckDB extension_directory.", "extension_directory", extDir, "error", err)
} else {
slog.Debug("Set DuckDB extension_directory.", "extension_directory", extDir)
if err := setExtensionDirectory(db, cfg.DataDir); err != nil {
_ = db.Close()
return nil, fmt.Errorf("failed to configure extension_directory: %w", err)
}

// Load configured extensions
Expand Down Expand Up @@ -1200,11 +1211,15 @@ func LoadExtensions(db *sql.DB, extensions []string) error {
for _, ext := range extensions {
name, installCmd := parseExtensionName(ext)

// First install the extension (downloads if needed)
if _, err := db.Exec("INSTALL " + installCmd); err != nil {
slog.Warn("Failed to install extension.", "extension", installCmd, "error", err)
lastErr = err
continue
if shouldInstallExtension(name) {
// First install the extension (downloads if needed). Bundled extensions
// are preseeded into the extension cache and INSTALL can overwrite that
// bundled binary with DuckDB's repository copy.
if _, err := db.Exec("INSTALL " + installCmd); err != nil {
slog.Warn("Failed to install extension.", "extension", installCmd, "error", err)
lastErr = err
continue
}
}

// Then load it into the current session
Expand All @@ -1220,6 +1235,15 @@ func LoadExtensions(db *sql.DB, extensions []string) error {
return lastErr
}

// shouldInstallExtension reports whether INSTALL should be issued for the
// named extension. It is false exactly when hasBundledExtensionBinary finds a
// bundled binary for that name.
func shouldInstallExtension(name string) bool {
	bundled := hasBundledExtensionBinary(name)
	return !bundled
}

// hasBundledExtensionBinary reports whether a file named
// <name>.duckdb_extension exists exactly two directory levels below
// bundledDuckDBExtensionsDir. A glob error is treated as "not bundled".
func hasBundledExtensionBinary(name string) bool {
	pattern := filepath.Join(bundledDuckDBExtensionsDir, "*", "*", name+".duckdb_extension")
	found, err := filepath.Glob(pattern)
	if err != nil {
		return false
	}
	return len(found) > 0
}

// boolPtr returns a pointer to a freshly allocated bool holding v.
func boolPtr(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}

func duckLakeDisableMetadataThreadLocalCacheEnabled(dlCfg DuckLakeConfig) bool {
Expand Down
38 changes: 38 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,44 @@ func TestParseExtensionName(t *testing.T) {
}
}

// TestNewFailsWhenBundledExtensionBootstrapFails points the bundled extension
// root at a regular file and expects New to return an error that mentions the
// bundled extension bootstrap.
func TestNewFailsWhenBundledExtensionBootstrapFails(t *testing.T) {
	// Use a plain file as the bundled root so it cannot be walked as a directory.
	notADir := filepath.Join(t.TempDir(), "extensions-file")
	if err := os.WriteFile(notADir, []byte("not-a-directory"), 0o644); err != nil {
		t.Fatalf("write bundled root file: %v", err)
	}

	savedRoot := bundledDuckDBExtensionsDir
	defer func() { bundledDuckDBExtensionsDir = savedRoot }()
	bundledDuckDBExtensionsDir = notADir

	// Reset the package-level bootstrap bookkeeping to its zero value.
	bundledExtensionBootstrap = struct {
		mu     sync.Mutex
		byPath map[string]error
	}{}

	tlsDir := t.TempDir()
	certPath := filepath.Join(tlsDir, "server.crt")
	keyPath := filepath.Join(tlsDir, "server.key")
	if err := generateSelfSignedCert(certPath, keyPath); err != nil {
		t.Fatalf("generateSelfSignedCert: %v", err)
	}

	cfg := Config{
		Host:        "127.0.0.1",
		Port:        5432,
		DataDir:     t.TempDir(),
		Users:       map[string]string{"postgres": "postgres"},
		TLSCertFile: certPath,
		TLSKeyFile:  keyPath,
	}

	_, err := New(cfg)
	switch {
	case err == nil:
		t.Fatal("expected bootstrap failure")
	case !strings.Contains(err.Error(), "bootstrap bundled DuckDB extensions"):
		t.Fatalf("expected bootstrap error, got %v", err)
	}
}

func TestNeedsCredentialRefresh(t *testing.T) {
tests := []struct {
name string
Expand Down
4 changes: 2 additions & 2 deletions server/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ import (
func RunShell(cfg Config) {
sem := make(chan struct{}, 1)
if err := bootstrapBundledExtensions(cfg.DataDir); err != nil {
slog.Warn("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(cfg.DataDir, "extensions"), "error", err)
slog.Error("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(cfg.DataDir, "extensions"), "error", err)
os.Exit(1)
}

db, err := CreateDBConnection(cfg, sem, "shell", processStartTime, processVersion)
if err != nil {
slog.Error("Failed to create database connection.", "error", err)
Expand Down
5 changes: 4 additions & 1 deletion server/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,10 @@ func runChildWorker(tcpConn *net.TCPConn, cfg *ChildConfig) int {
}

if err := bootstrapBundledExtensions(serverCfg.DataDir); err != nil {
slog.Warn("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(serverCfg.DataDir, "extensions"), "error", err)
slog.Error("Failed to bootstrap bundled DuckDB extensions.", "source", bundledDuckDBExtensionsDir, "extension_directory", filepath.Join(serverCfg.DataDir, "extensions"), "error", err)
_ = writeErrorResponse(writer, "FATAL", "58000", fmt.Sprintf("failed to prepare bundled extensions: %v", err))
_ = writer.Flush()
return ExitError
}

// Create DuckDB connection
Expand Down
Loading