Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func New(ctx context.Context, logger log.Logger, config DatabaseConfig) (*sql.DB
if err != nil {
return nil, fmt.Errorf("connecting to postgres: %w", err)
}
return ApplyConnectionsConfig(db, &config.Postgres.Connections, logger), nil
return db, nil
}

return nil, ErrMissingConfig
Expand Down Expand Up @@ -112,3 +112,4 @@ func ApplyConnectionsConfig(db *sql.DB, connections *ConnectionsConfig, logger l

return db
}

1 change: 1 addition & 0 deletions database/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ func TestDataTooLong(t *testing.T) {
}
}


func TestConnectionsConfigOrder(t *testing.T) {
bs, err := os.ReadFile("database.go")
require.NoError(t, err)
Expand Down
104 changes: 74 additions & 30 deletions database/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
"fmt"
"net"
"strings"
"time"

"cloud.google.com/go/alloydbconn"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/moov-io/base/log"
)
Expand All @@ -23,35 +24,81 @@ const (
)

func postgresConnection(ctx context.Context, logger log.Logger, config PostgresConfig, databaseName string) (*sql.DB, error) {
var connStr string
if config.Alloy != nil {
c, err := getAlloyDBConnectorConnStr(ctx, config, databaseName)
if err != nil {
return nil, logger.LogErrorf("creating alloydb connection: %w", err).Err()
}
connStr = c
} else {
c, err := getPostgresConnStr(config, databaseName)
if err != nil {
return nil, logger.LogErrorf("creating postgres connection: %w", err).Err()
}
connStr = c
poolConfig, err := buildPgxPoolConfig(ctx, config, databaseName)
if err != nil {
return nil, logger.LogErrorf("building pgx pool config: %w", err).Err()
}

db, err := sql.Open("pgx", connStr)
applyPgxPoolConnectionsConfig(logger, poolConfig, config.Connections)

// Ping connections that have been idle for more than 200ms before handing
// them to the caller. This catches dead connections left by an AlloyDB
// switchover before a query is attempted, without adding overhead on
// hot connections used moments ago.
// HealthCheckPeriod (the background reaper) does NOT ping — it only evicts
// connections that have exceeded their age thresholds. ShouldPing is the
// mechanism that actually tests liveness at acquire time.
poolConfig.ShouldPing = func(_ context.Context, p pgxpool.ShouldPingParams) bool {
return p.IdleDuration > 200*time.Millisecond
}

pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, logger.LogErrorf("opening database: %w", err).Err()
return nil, logger.LogErrorf("creating pgx pool: %w", err).Err()
}

err = db.Ping()
err = pool.Ping(ctx)
if err != nil {
_ = db.Close()
pool.Close()
return nil, logger.LogErrorf("connecting to database: %w", err).Err()
}

// OpenDBFromPool wraps pgxpool in a *sql.DB for compatibility with the rest
// of the codebase. It automatically sets MaxIdleConns to 0 on the sql.DB —
// this must not be overridden, as pgxpool manages its own connection pool
// and a non-zero value would prevent connections from being released back.
db := stdlib.OpenDBFromPool(pool)
Comment thread
stevemsmith marked this conversation as resolved.

return db, nil
}

// applyPgxPoolConnectionsConfig translates ConnectionsConfig onto a pgxpool.Config.
// MaxIdle has no pgxpool equivalent — pgxpool caps total connections via MaxConns
// rather than idle count, and keeps a floor via MinConns. When set, MaxIdle is
// logged and ignored so operators aren't misled into thinking it took effect.
func applyPgxPoolConnectionsConfig(logger log.Logger, poolConfig *pgxpool.Config, connections ConnectionsConfig) {
if connections.MaxOpen > 0 {
logger.Logf("setting pgx pool MaxConns to %d", connections.MaxOpen)
poolConfig.MaxConns = int32(connections.MaxOpen)
}

if connections.MaxIdle > 0 {
logger.Logf("ignoring ConnectionsConfig.MaxIdle=%d: pgxpool has no MaxIdle equivalent", connections.MaxIdle)
}

if connections.MaxIdleTime > 0 {
logger.Logf("setting pgx pool MaxConnIdleTime to %v", connections.MaxIdleTime)
poolConfig.MaxConnIdleTime = connections.MaxIdleTime
}

if connections.MaxLifetime > 0 {
logger.Logf("setting pgx pool MaxConnLifetime to %v", connections.MaxLifetime)
poolConfig.MaxConnLifetime = connections.MaxLifetime
}
}

func buildPgxPoolConfig(ctx context.Context, config PostgresConfig, databaseName string) (*pgxpool.Config, error) {
if config.Alloy != nil {
return buildAlloyDBPoolConfig(ctx, config, databaseName)
}

connStr, err := getPostgresConnStr(config, databaseName)
if err != nil {
return nil, err
}
return pgxpool.ParseConfig(connStr)
}

func getPostgresConnStr(config PostgresConfig, databaseName string) (string, error) {
url := fmt.Sprintf("postgres://%s:%s@%s/%s", config.User, config.Password, config.Address, databaseName)

Expand Down Expand Up @@ -81,9 +128,9 @@ func getPostgresConnStr(config PostgresConfig, databaseName string) (string, err
return connStr, nil
}

func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, databaseName string) (string, error) {
func buildAlloyDBPoolConfig(ctx context.Context, config PostgresConfig, databaseName string) (*pgxpool.Config, error) {
if config.Alloy == nil {
return "", fmt.Errorf("missing alloy config")
return nil, fmt.Errorf("missing alloy config")
}

var dialer *alloydbconn.Dialer
Expand All @@ -92,7 +139,7 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data
if config.Alloy.UseIAM {
d, err := alloydbconn.NewDialer(ctx, alloydbconn.WithIAMAuthN())
if err != nil {
return "", fmt.Errorf("creating alloydb dialer: %v", err)
return nil, fmt.Errorf("creating alloydb dialer: %v", err)
}
dialer = d
dsn = fmt.Sprintf(
Expand All @@ -104,7 +151,7 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data
} else {
d, err := alloydbconn.NewDialer(ctx)
if err != nil {
return "", fmt.Errorf("creating alloydb dialer: %v", err)
return nil, fmt.Errorf("creating alloydb dialer: %v", err)
}
dialer = d
dsn = fmt.Sprintf(
Expand All @@ -114,25 +161,21 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data
)
}

// TODO
//cleanup := func() error { return d.Close() }

connConfig, err := pgx.ParseConfig(dsn)
poolConfig, err := pgxpool.ParseConfig(dsn)
if err != nil {
return "", fmt.Errorf("failed to parse pgx config: %v", err)
return nil, fmt.Errorf("failed to parse pgx pool config: %v", err)
}

var connOptions []alloydbconn.DialOption
if config.Alloy.UsePSC {
connOptions = append(connOptions, alloydbconn.WithPSC())
}

connConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) {
poolConfig.ConnConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) {
return dialer.Dial(ctx, config.Alloy.InstanceURI, connOptions...)
}

connStr := stdlib.RegisterConnConfig(connConfig)
return connStr, nil
return poolConfig, nil
}

// PostgresUniqueViolation returns true when the provided error matches the Postgres code
Expand Down Expand Up @@ -164,3 +207,4 @@ func PostgresDeadlockFound(err error) bool {

return strings.Contains(err.Error(), postgresErrDeadlockFound)
}