diff --git a/database/database.go b/database/database.go index 67543936..dec12d18 100644 --- a/database/database.go +++ b/database/database.go @@ -40,7 +40,7 @@ func New(ctx context.Context, logger log.Logger, config DatabaseConfig) (*sql.DB if err != nil { return nil, fmt.Errorf("connecting to postgres: %w", err) } - return ApplyConnectionsConfig(db, &config.Postgres.Connections, logger), nil + return db, nil } return nil, ErrMissingConfig @@ -112,3 +112,4 @@ func ApplyConnectionsConfig(db *sql.DB, connections *ConnectionsConfig, logger l return db } + diff --git a/database/database_test.go b/database/database_test.go index a648aa02..bc17f75c 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -121,6 +121,7 @@ func TestDataTooLong(t *testing.T) { } } + func TestConnectionsConfigOrder(t *testing.T) { bs, err := os.ReadFile("database.go") require.NoError(t, err) diff --git a/database/postgres.go b/database/postgres.go index 7c98cdd2..dd326cd5 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -7,10 +7,11 @@ import ( "fmt" "net" "strings" + "time" "cloud.google.com/go/alloydbconn" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" "github.com/moov-io/base/log" ) @@ -23,35 +24,81 @@ const ( ) func postgresConnection(ctx context.Context, logger log.Logger, config PostgresConfig, databaseName string) (*sql.DB, error) { - var connStr string - if config.Alloy != nil { - c, err := getAlloyDBConnectorConnStr(ctx, config, databaseName) - if err != nil { - return nil, logger.LogErrorf("creating alloydb connection: %w", err).Err() - } - connStr = c - } else { - c, err := getPostgresConnStr(config, databaseName) - if err != nil { - return nil, logger.LogErrorf("creating postgres connection: %w", err).Err() - } - connStr = c + poolConfig, err := buildPgxPoolConfig(ctx, config, databaseName) + if err != nil { + return nil, logger.LogErrorf("building pgx pool config: %w", err).Err() } - db, err := sql.Open("pgx", connStr) + applyPgxPoolConnectionsConfig(logger, poolConfig, config.Connections) + + // Ping connections that have been idle for more than 200ms before handing + // them to the caller. This catches dead connections left by an AlloyDB + // switchover before a query is attempted, without adding overhead on + // hot connections used moments ago. + // HealthCheckPeriod (the background reaper) does NOT ping — it only evicts + // connections that have exceeded their age thresholds. ShouldPing is the + // mechanism that actually tests liveness at acquire time. + poolConfig.ShouldPing = func(_ context.Context, p pgxpool.ShouldPingParams) bool { + return p.IdleDuration > 200*time.Millisecond + } + + pool, err := pgxpool.NewWithConfig(ctx, poolConfig) if err != nil { - return nil, logger.LogErrorf("opening database: %w", err).Err() + return nil, logger.LogErrorf("creating pgx pool: %w", err).Err() } - err = db.Ping() + err = pool.Ping(ctx) if err != nil { - _ = db.Close() + pool.Close() return nil, logger.LogErrorf("connecting to database: %w", err).Err() } + // OpenDBFromPool wraps pgxpool in a *sql.DB for compatibility with the rest + // of the codebase. It automatically sets MaxIdleConns to 0 on the sql.DB — + // this must not be overridden, as pgxpool manages its own connection pool + // and a non-zero value would prevent connections from being released back. + db := stdlib.OpenDBFromPool(pool) + return db, nil } +// applyPgxPoolConnectionsConfig translates ConnectionsConfig onto a pgxpool.Config. +// MaxIdle has no pgxpool equivalent — pgxpool caps total connections via MaxConns +// rather than idle count, and keeps a floor via MinConns. When set, MaxIdle is +// logged and ignored so operators aren't misled into thinking it took effect. +func applyPgxPoolConnectionsConfig(logger log.Logger, poolConfig *pgxpool.Config, connections ConnectionsConfig) { + if connections.MaxOpen > 0 { + logger.Logf("setting pgx pool MaxConns to %d", connections.MaxOpen) + poolConfig.MaxConns = int32(connections.MaxOpen) + } + + if connections.MaxIdle > 0 { + logger.Logf("ignoring ConnectionsConfig.MaxIdle=%d: pgxpool has no MaxIdle equivalent", connections.MaxIdle) + } + + if connections.MaxIdleTime > 0 { + logger.Logf("setting pgx pool MaxConnIdleTime to %v", connections.MaxIdleTime) + poolConfig.MaxConnIdleTime = connections.MaxIdleTime + } + + if connections.MaxLifetime > 0 { + logger.Logf("setting pgx pool MaxConnLifetime to %v", connections.MaxLifetime) + poolConfig.MaxConnLifetime = connections.MaxLifetime + } +} + +func buildPgxPoolConfig(ctx context.Context, config PostgresConfig, databaseName string) (*pgxpool.Config, error) { + if config.Alloy != nil { + return buildAlloyDBPoolConfig(ctx, config, databaseName) + } + + connStr, err := getPostgresConnStr(config, databaseName) + if err != nil { + return nil, err + } + return pgxpool.ParseConfig(connStr) +} + func getPostgresConnStr(config PostgresConfig, databaseName string) (string, error) { url := fmt.Sprintf("postgres://%s:%s@%s/%s", config.User, config.Password, config.Address, databaseName) @@ -81,9 +128,9 @@ func getPostgresConnStr(config PostgresConfig, databaseName string) (string, err return connStr, nil } -func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, databaseName string) (string, error) { +func buildAlloyDBPoolConfig(ctx context.Context, config PostgresConfig, databaseName string) (*pgxpool.Config, error) { if config.Alloy == nil { - return "", fmt.Errorf("missing alloy config") + return nil, fmt.Errorf("missing alloy config") } var dialer *alloydbconn.Dialer @@ -92,7 +139,7 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data if config.Alloy.UseIAM { d, err := alloydbconn.NewDialer(ctx, alloydbconn.WithIAMAuthN()) if err != nil { - return "", fmt.Errorf("creating alloydb dialer: %v", err) + return nil, fmt.Errorf("creating alloydb dialer: %v", err) } dialer = d dsn = fmt.Sprintf( @@ -104,7 +151,7 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data } else { d, err := alloydbconn.NewDialer(ctx) if err != nil { - return "", fmt.Errorf("creating alloydb dialer: %v", err) + return nil, fmt.Errorf("creating alloydb dialer: %v", err) } dialer = d dsn = fmt.Sprintf( @@ -114,12 +161,9 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data ) } - // TODO - //cleanup := func() error { return d.Close() } - - connConfig, err := pgx.ParseConfig(dsn) + poolConfig, err := pgxpool.ParseConfig(dsn) if err != nil { - return "", fmt.Errorf("failed to parse pgx config: %v", err) + return nil, fmt.Errorf("failed to parse pgx pool config: %v", err) } var connOptions []alloydbconn.DialOption @@ -127,12 +171,11 @@ func getAlloyDBConnectorConnStr(ctx context.Context, config PostgresConfig, data connOptions = append(connOptions, alloydbconn.WithPSC()) } - connConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) { + poolConfig.ConnConfig.DialFunc = func(ctx context.Context, _ string, _ string) (net.Conn, error) { return dialer.Dial(ctx, config.Alloy.InstanceURI, connOptions...) } - connStr := stdlib.RegisterConnConfig(connConfig) - return connStr, nil + return poolConfig, nil } // PostgresUniqueViolation returns true when the provided error matches the Postgres code @@ -164,3 +207,4 @@ func PostgresDeadlockFound(err error) bool { return strings.Contains(err.Error(), postgresErrDeadlockFound) } +