Skip to content
288 changes: 167 additions & 121 deletions pgutils/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import (
"errors"
"fmt"
"log"
"net"
"net/url"
"time"
"strings"

"database/sql"
"database/sql/driver"
Expand All @@ -20,109 +21,161 @@ import (
"github.com/lib/pq"
)

type baseConnectionStringProvider interface {
getBaseConnectionString(ctx context.Context) (string, error)
}
const defaultPostgresPort = "5432"

var pqDriver = &pq.Driver{}

type PostgresqlConnector struct {
baseConnectionStringProvider
searchPath string
// ConnectionStringProvider returns a Postgres connection string for use by clients
// that need a DSN (e.g., pq.Listener) or to build a connector.
type ConnectionStringProvider interface {
ConnectionString(ctx context.Context) (string, error)
}

func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector {
return &PostgresqlConnector{
baseConnectionStringProvider: conn.baseConnectionStringProvider,
searchPath: searchPath,
}
type connectionStringProviderFunc func(context.Context) (string, error)

func (f connectionStringProviderFunc) ConnectionString(ctx context.Context) (string, error) {
return f(ctx)
}

func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
dsn, err := conn.GetConnectionString(ctx)
// NewConnectionStringProviderFromURLString parses rawURL and constructs a provider.
//
// Standard Postgres example:
//
// postgres://<user>:<pass>@<host>:<port>/<db-name>?sslmode=require
//
// IAM example 1:
//
// postgres+rds-iam://<user>@<rds-endpoint>:<port>/<db-name>
//
// IAM example 2 (cross-account):
//
// postgres+rds-iam://<user>@<rds-endpoint>:<port>/<db-name>?assume_role_arn=<...>&assume_role_session_name=<...>
//
// For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call.
func NewConnectionStringProviderFromURLString(ctx context.Context, rawURL string) (ConnectionStringProvider, error) {
u, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("get connection string: %w", err)
return nil, fmt.Errorf("parsing URL: %w", err)
}
pqConnector, err := pq.NewConnector(dsn)
if err != nil {
return nil, fmt.Errorf("create pq connector: %w", err)

switch u.Scheme {
case "postgres", "postgresql":
return &staticConnectionStringProvider{connectionString: u.String()}, nil
case "postgres+rds-iam":
return newIAMConnectionStringProviderFromURL(ctx, u)
default:
return nil, fmt.Errorf("unsupported URL scheme: %q (expected postgres, postgresql, or postgres+rds-iam)", u.Scheme)
}
}

return pqConnector.Connect(ctx)
// ToConnector wraps a ConnectionStringProvider as a driver.Connector.
// Each Connect(ctx) call asks the provider for a fresh DSN.
func ToConnector(provider ConnectionStringProvider) driver.Connector {
return &postgresqlConnector{connectionStringProvider: provider}
}

func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) {
dsn, err := conn.getBaseConnectionString(ctx)
if err != nil {
return "", fmt.Errorf("get base connection string: %w", err)
// WithSchemaSearchPath returns a ConnectionStringProvider that appends search_path
// to the DSN produced by the underlying provider.
func WithSchemaSearchPath(provider ConnectionStringProvider, searchPath string) ConnectionStringProvider {
return connectionStringProviderFunc(func(ctx context.Context) (string, error) {
dsn, err := provider.ConnectionString(ctx)
if err != nil {
return "", fmt.Errorf("ConnectionString failed: %w", err)
}

dsnWithPath, err := addSearchPathToURL(dsn, searchPath)
if err != nil {
return "", fmt.Errorf("applying schema search path failed: %w", err)
}

return dsnWithPath, nil
})
}

// ConnectDB opens a connection using the connector and verifies it with a ping
func ConnectDB(conn driver.Connector) (*sqlx.DB, error) {
sqlDB := sql.OpenDB(conn)
Copy link
Contributor

@leslie-corbalt leslie-corbalt Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is less code:
func ConnectDB(conn driver.Connector) (*sqlx.DB, error) {
db := sqlx.NewDb(sql.OpenDB(conn), "postgres")
if err := db.Ping(); err != nil {
db.Close()
return nil, err
}
return db, nil
}

You don't need the intermediate variable sqlDB.

db := sqlx.NewDb(sqlDB, "postgres")
if err := db.Ping(); err != nil {
db.Close()
return nil, err
}
if conn.searchPath == "" {
return dsn, nil
return db, nil
}

// MustConnectDB is like ConnectDB but panics on error
func MustConnectDB(conn driver.Connector) *sqlx.DB {
db, err := ConnectDB(conn)
if err != nil {
panic(err)
}
return db
}

// Add search path
u, err := url.Parse(dsn)
// addSearchPathToURL returns a copy of u with search_path set in the query string.
// It returns an error if search_path is already present.
func addSearchPathToURL(rawURL string, searchPath string) (string, error) {
u, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("parse DSN URL: %w", err)
return "", fmt.Errorf("url string failed to parse while adding search path: %w", err)
}

if searchPath == "" {
return u.String(), nil
}

q := u.Query()
if v := q.Get("search_path"); v != "" {
return "", fmt.Errorf("search_path already set to %q", v)
}
q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed
q.Set("search_path", searchPath)
u.RawQuery = q.Encode()
return u.String(), nil
}

func (c *PostgresqlConnector) Driver() driver.Driver {
return &pq.Driver{}
type postgresqlConnector struct {
connectionStringProvider ConnectionStringProvider
}

type staticConnectionStringProvider struct {
connectionString string
}
func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
dsn, err := c.connectionStringProvider.ConnectionString(ctx)
if err != nil {
return nil, fmt.Errorf("getting connection string from provider: %w", err)
}
pqConnector, err := pq.NewConnector(dsn)
if err != nil {
return nil, fmt.Errorf("creating pq connector: %w", err)
}

func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
return p.connectionString, nil
return pqConnector.Connect(ctx)
}

func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector {
return &PostgresqlConnector{
baseConnectionStringProvider: &staticConnectionStringProvider{connectionString},
}
func (c *postgresqlConnector) Driver() driver.Driver {
return pqDriver
}

type IAMAuthConfig struct {
RDSEndpoint string
User string
Database string

// Optional: cross-account role assumption.
// Set this to a role ARN in the RDS account (Account A) that has rds-db:connect.
AssumeRoleARN string

// Optional: if your trust policy requires an external ID.
AssumeRoleExternalID string

// Optional: override the default session name.
AssumeRoleSessionName string

// Optional: override STS assume role duration.
// If zero, SDK default is used.
AssumeRoleDuration time.Duration
type staticConnectionStringProvider struct {
connectionString string
}

type iamAuthConnectionStringProvider struct {
IAMAuthConfig
func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) {
return p.connectionString, nil
}

region string
creds aws.CredentialsProvider
type rdsIAMConnectionStringProvider struct {
RDSEndpoint string
Region string
User string
Database string
CredentialsProvider aws.CredentialsProvider
}

func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds)
func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) {
authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider)
if err != nil {
return "", fmt.Errorf("building auth token: %w", err)
}
log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database)
log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database)

dsnURL := &url.URL{
Scheme: "postgresql",
Expand All @@ -134,9 +187,43 @@ func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Co
return dsnURL.String(), nil
}

func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) {
if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" {
return nil, errors.New("RDS endpoint, user, and database are required")
func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) {
user := ""
if u.User != nil {
user = u.User.Username()
if _, hasPw := u.User.Password(); hasPw {
return nil, errors.New("postgres+rds-iam URL must not include a password")
}
}
if user == "" {
return nil, errors.New("postgres+rds-iam URL missing username")
}

host := u.Hostname()
if host == "" {
return nil, errors.New("postgres+rds-iam URL missing host")
}

port := u.Port()
if port == "" {
port = defaultPostgresPort
}

// Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this mean: Match libpq/psql defaulting. Sounds cryptic to me.
Also, why set a default? Why not force the user to set this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the words after the ':' describe in more details what that means. if dbname isn't specified, dbname defaults to username. Is that confusing?

I mean we historically have relied on this all over the place. We have lots of connection strings that don't have /postgres at the end to specify the database, instead we rely on lib/psql behavior where if the username is 'postgres' then the database will be set to 'postgres' if you don't set it.

dbName := strings.TrimPrefix(u.Path, "/")
if dbName == "" {
dbName = user
}

q := u.Query()
supportedParams := map[string]struct{}{
"assume_role_arn": {},
"assume_role_session_name": {},
}
for k := range q {
if _, ok := supportedParams[k]; !ok {
return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k)
}
}

awsCfg, err := awsconfig.LoadDefaultConfig(ctx)
Expand All @@ -149,66 +236,25 @@ func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig)
}

creds := awsCfg.Credentials

// Cross-account support:
// If AssumeRoleARN is set, assume a role in the RDS account (Account A)
// using the ECS task role creds from Account B as the source credentials.
if cfg.AssumeRoleARN != "" {
log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database)
assumeRoleARN := q.Get("assume_role_arn")
if assumeRoleARN != "" {
stsClient := sts.NewFromConfig(awsCfg)

sessionName := cfg.AssumeRoleSessionName
sessionName := q.Get("assume_role_session_name")
if sessionName == "" {
sessionName = "pgutils-rds-iam"
}

assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) {
assumeRoleOpts.RoleSessionName = sessionName

if cfg.AssumeRoleExternalID != "" {
assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID)
}

if cfg.AssumeRoleDuration != 0 {
assumeRoleOpts.Duration = cfg.AssumeRoleDuration
}
log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName)
assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) {
opts.RoleSessionName = sessionName
})

// Cache to avoid calling STS too frequently.
creds = aws.NewCredentialsCache(assumeProvider)
}

return &PostgresqlConnector{
baseConnectionStringProvider: &iamAuthConnectionStringProvider{
IAMAuthConfig: *cfg,
region: awsCfg.Region,
creds: creds,
},
return &rdsIAMConnectionStringProvider{
Region: awsCfg.Region,
RDSEndpoint: net.JoinHostPort(host, port),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you have Endpoint including port.
You mentioned that RDSEndpoint does not include port.

User: user,
Database: dbName,
CredentialsProvider: creds,
}, nil
}

// Provides missing sqlx.OpenDB
func OpenDB(conn *PostgresqlConnector) *sqlx.DB {
sqlDB := sql.OpenDB(conn)
return sqlx.NewDb(sqlDB, "postgres")
}

// ConnectDB opens a connection using the connector and verifies it with a ping
func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) {
db := OpenDB(conn)
if err := db.Ping(); err != nil {
db.Close()
return nil, err
}
return db, nil
}

// MustConnectDB is like ConnectDB but panics on error
func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB {
db, err := ConnectDB(conn)
if err != nil {
panic(err)
}
return db
}

7 changes: 3 additions & 4 deletions pgutils/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,17 @@ func listenerEventToString(t pq.ListenerEventType) string {
// The callback is invoked from the listener goroutine; it MUST NOT block
// for long periods. If you need to do heavy work, offload it to another
// goroutine.
func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error {
func Listen(ctx context.Context, provider ConnectionStringProvider, pgChannelName string, callback func(*pq.Notification), onClose func()) error {
if callback == nil {
return fmt.Errorf("listener callback cannot be nil")
}

reconnectEventCh := make(chan struct{}, 1) // We just need a single reconnect event to trigger, so buffer size of 1

makeListener := func() (*pq.Listener, error) {
url, err := conn.GetConnectionString(ctx)
url, err := provider.ConnectionString(ctx)
if err != nil {
return nil, fmt.Errorf("get url: %w", err)
return nil, fmt.Errorf("error getting connection string from provider: %w", err)
}

cb := func(t pq.ListenerEventType, e error) {
Expand Down Expand Up @@ -174,4 +174,3 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string

return nil
}

Loading