-
Notifications
You must be signed in to change notification settings - Fork 0
Replace PostgresqlConnector with ConnectionStringProvider #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7e90184
fe2a24e
ddb9337
3cd2b7f
b8b7cfe
5d0d0a3
9002590
9807241
a0d89ad
38ca632
2be12a0
59e5efd
8fe5ff6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,8 +5,9 @@ import ( | |
| "errors" | ||
| "fmt" | ||
| "log" | ||
| "net" | ||
| "net/url" | ||
| "time" | ||
| "strings" | ||
|
|
||
| "database/sql" | ||
| "database/sql/driver" | ||
|
|
@@ -20,109 +21,161 @@ import ( | |
| "github.com/lib/pq" | ||
| ) | ||
|
|
||
| type baseConnectionStringProvider interface { | ||
| getBaseConnectionString(ctx context.Context) (string, error) | ||
| } | ||
| const defaultPostgresPort = "5432" | ||
|
|
||
| var pqDriver = &pq.Driver{} | ||
|
|
||
| type PostgresqlConnector struct { | ||
| baseConnectionStringProvider | ||
| searchPath string | ||
| // ConnectionStringProvider returns a Postgres connection string for use by clients | ||
| // that need a DSN (e.g., pq.Listener) or to build a connector. | ||
| type ConnectionStringProvider interface { | ||
| ConnectionString(ctx context.Context) (string, error) | ||
| } | ||
|
|
||
| func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector { | ||
| return &PostgresqlConnector{ | ||
| baseConnectionStringProvider: conn.baseConnectionStringProvider, | ||
| searchPath: searchPath, | ||
| } | ||
| type connectionStringProviderFunc func(context.Context) (string, error) | ||
|
|
||
| func (f connectionStringProviderFunc) ConnectionString(ctx context.Context) (string, error) { | ||
| return f(ctx) | ||
| } | ||
|
|
||
| func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { | ||
| dsn, err := conn.GetConnectionString(ctx) | ||
| // NewConnectionStringProviderFromURLString parses rawURL and constructs a provider. | ||
| // | ||
| // Standard Postgres example: | ||
| // | ||
| // postgres://<user>:<pass>@<host>:<port>/<db-name>?sslmode=require | ||
| // | ||
| // IAM example 1: | ||
| // | ||
| // postgres+rds-iam://<user>@<rds-endpoint>:<port>/<db-name> | ||
| // | ||
| // IAM example 2 (cross-account): | ||
| // | ||
| // postgres+rds-iam://<user>@<rds-endpoint>:<port>/<db-name>?assume_role_arn=<...>&assume_role_session_name=<...> | ||
| // | ||
| // For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call. | ||
| func NewConnectionStringProviderFromURLString(ctx context.Context, rawURL string) (ConnectionStringProvider, error) { | ||
| u, err := url.Parse(rawURL) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("get connection string: %w", err) | ||
| return nil, fmt.Errorf("parsing URL: %w", err) | ||
| } | ||
| pqConnector, err := pq.NewConnector(dsn) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("create pq connector: %w", err) | ||
|
|
||
| switch u.Scheme { | ||
| case "postgres", "postgresql": | ||
| return &staticConnectionStringProvider{connectionString: u.String()}, nil | ||
| case "postgres+rds-iam": | ||
| return newIAMConnectionStringProviderFromURL(ctx, u) | ||
| default: | ||
| return nil, fmt.Errorf("unsupported URL scheme: %q (expected postgres, postgresql, or postgres+rds-iam)", u.Scheme) | ||
| } | ||
| } | ||
|
|
||
| return pqConnector.Connect(ctx) | ||
| // ToConnector wraps a ConnectionStringProvider as a driver.Connector. | ||
| // Each Connect(ctx) call asks the provider for a fresh DSN. | ||
| func ToConnector(provider ConnectionStringProvider) driver.Connector { | ||
| return &postgresqlConnector{connectionStringProvider: provider} | ||
| } | ||
|
|
||
| func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) { | ||
| dsn, err := conn.getBaseConnectionString(ctx) | ||
| if err != nil { | ||
| return "", fmt.Errorf("get base connection string: %w", err) | ||
| // WithSchemaSearchPath returns a ConnectionStringProvider that appends search_path | ||
| // to the DSN produced by the underlying provider. | ||
| func WithSchemaSearchPath(provider ConnectionStringProvider, searchPath string) ConnectionStringProvider { | ||
| return connectionStringProviderFunc(func(ctx context.Context) (string, error) { | ||
| dsn, err := provider.ConnectionString(ctx) | ||
| if err != nil { | ||
| return "", fmt.Errorf("ConnectionString failed: %w", err) | ||
| } | ||
|
|
||
| dsnWithPath, err := addSearchPathToURL(dsn, searchPath) | ||
| if err != nil { | ||
| return "", fmt.Errorf("applying schema search path failed: %w", err) | ||
| } | ||
|
|
||
| return dsnWithPath, nil | ||
| }) | ||
| } | ||
|
|
||
| // ConnectDB opens a connection using the connector and verifies it with a ping | ||
| func ConnectDB(conn driver.Connector) (*sqlx.DB, error) { | ||
| sqlDB := sql.OpenDB(conn) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is less code: You don't need the intermediate variable sqlDB. |
||
| db := sqlx.NewDb(sqlDB, "postgres") | ||
| if err := db.Ping(); err != nil { | ||
| db.Close() | ||
| return nil, err | ||
| } | ||
| if conn.searchPath == "" { | ||
| return dsn, nil | ||
| return db, nil | ||
| } | ||
|
|
||
| // MustConnectDB is like ConnectDB but panics on error | ||
| func MustConnectDB(conn driver.Connector) *sqlx.DB { | ||
| db, err := ConnectDB(conn) | ||
| if err != nil { | ||
| panic(err) | ||
| } | ||
| return db | ||
| } | ||
|
|
||
| // Add search path | ||
| u, err := url.Parse(dsn) | ||
| // addSearchPathToURL returns a copy of u with search_path set in the query string. | ||
| // It returns an error if search_path is already present. | ||
| func addSearchPathToURL(rawURL string, searchPath string) (string, error) { | ||
| u, err := url.Parse(rawURL) | ||
| if err != nil { | ||
| return "", fmt.Errorf("parse DSN URL: %w", err) | ||
| return "", fmt.Errorf("url string failed to parse while adding search path: %w", err) | ||
| } | ||
|
|
||
| if searchPath == "" { | ||
| return u.String(), nil | ||
| } | ||
|
|
||
| q := u.Query() | ||
| if v := q.Get("search_path"); v != "" { | ||
| return "", fmt.Errorf("search_path already set to %q", v) | ||
| } | ||
| q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed | ||
| q.Set("search_path", searchPath) | ||
| u.RawQuery = q.Encode() | ||
| return u.String(), nil | ||
| } | ||
|
|
||
| func (c *PostgresqlConnector) Driver() driver.Driver { | ||
| return &pq.Driver{} | ||
| type postgresqlConnector struct { | ||
| connectionStringProvider ConnectionStringProvider | ||
| } | ||
|
|
||
| type staticConnectionStringProvider struct { | ||
| connectionString string | ||
| } | ||
| func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { | ||
| dsn, err := c.connectionStringProvider.ConnectionString(ctx) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("getting connection string from provider: %w", err) | ||
| } | ||
| pqConnector, err := pq.NewConnector(dsn) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("creating pq connector: %w", err) | ||
| } | ||
|
|
||
| func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { | ||
| return p.connectionString, nil | ||
| return pqConnector.Connect(ctx) | ||
| } | ||
|
|
||
| func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector { | ||
| return &PostgresqlConnector{ | ||
| baseConnectionStringProvider: &staticConnectionStringProvider{connectionString}, | ||
| } | ||
| func (c *postgresqlConnector) Driver() driver.Driver { | ||
| return pqDriver | ||
| } | ||
|
|
||
| type IAMAuthConfig struct { | ||
| RDSEndpoint string | ||
| User string | ||
| Database string | ||
|
|
||
| // Optional: cross-account role assumption. | ||
| // Set this to a role ARN in the RDS account (Account A) that has rds-db:connect. | ||
| AssumeRoleARN string | ||
|
|
||
| // Optional: if your trust policy requires an external ID. | ||
| AssumeRoleExternalID string | ||
|
|
||
| // Optional: override the default session name. | ||
| AssumeRoleSessionName string | ||
|
|
||
| // Optional: override STS assume role duration. | ||
| // If zero, SDK default is used. | ||
| AssumeRoleDuration time.Duration | ||
| type staticConnectionStringProvider struct { | ||
| connectionString string | ||
| } | ||
|
|
||
| type iamAuthConnectionStringProvider struct { | ||
| IAMAuthConfig | ||
| func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { | ||
| return p.connectionString, nil | ||
| } | ||
|
|
||
| region string | ||
| creds aws.CredentialsProvider | ||
| type rdsIAMConnectionStringProvider struct { | ||
| RDSEndpoint string | ||
| Region string | ||
| User string | ||
| Database string | ||
| CredentialsProvider aws.CredentialsProvider | ||
| } | ||
|
|
||
| func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { | ||
| authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds) | ||
| func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { | ||
| authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider) | ||
| if err != nil { | ||
| return "", fmt.Errorf("building auth token: %w", err) | ||
| } | ||
| log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database) | ||
| log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database) | ||
|
|
||
| dsnURL := &url.URL{ | ||
| Scheme: "postgresql", | ||
|
|
@@ -134,9 +187,43 @@ func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Co | |
| return dsnURL.String(), nil | ||
| } | ||
|
|
||
| func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) { | ||
| if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" { | ||
| return nil, errors.New("RDS endpoint, user, and database are required") | ||
| func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) { | ||
| user := "" | ||
| if u.User != nil { | ||
| user = u.User.Username() | ||
| if _, hasPw := u.User.Password(); hasPw { | ||
| return nil, errors.New("postgres+rds-iam URL must not include a password") | ||
| } | ||
| } | ||
| if user == "" { | ||
| return nil, errors.New("postgres+rds-iam URL missing username") | ||
| } | ||
|
|
||
| host := u.Hostname() | ||
| if host == "" { | ||
| return nil, errors.New("postgres+rds-iam URL missing host") | ||
| } | ||
|
|
||
| port := u.Port() | ||
| if port == "" { | ||
| port = defaultPostgresPort | ||
| } | ||
|
|
||
| // Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does this mean:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the words after the ':' describe in more details what that means. I mean we historically have relied on this all over the place. We have lots of connection strings that don't have /postgres at the end to specify the database, instead we rely on lib/psql behavior where if the username is 'postgres' then the database will be set to 'postgres' if you don't set it. |
||
| dbName := strings.TrimPrefix(u.Path, "/") | ||
| if dbName == "" { | ||
| dbName = user | ||
| } | ||
|
|
||
| q := u.Query() | ||
| supportedParams := map[string]struct{}{ | ||
| "assume_role_arn": {}, | ||
| "assume_role_session_name": {}, | ||
| } | ||
| for k := range q { | ||
| if _, ok := supportedParams[k]; !ok { | ||
| return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k) | ||
| } | ||
| } | ||
|
|
||
| awsCfg, err := awsconfig.LoadDefaultConfig(ctx) | ||
|
|
@@ -149,66 +236,25 @@ func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) | |
| } | ||
|
|
||
| creds := awsCfg.Credentials | ||
|
|
||
| // Cross-account support: | ||
| // If AssumeRoleARN is set, assume a role in the RDS account (Account A) | ||
| // using the ECS task role creds from Account B as the source credentials. | ||
| if cfg.AssumeRoleARN != "" { | ||
| log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database) | ||
| assumeRoleARN := q.Get("assume_role_arn") | ||
| if assumeRoleARN != "" { | ||
| stsClient := sts.NewFromConfig(awsCfg) | ||
|
|
||
| sessionName := cfg.AssumeRoleSessionName | ||
| sessionName := q.Get("assume_role_session_name") | ||
| if sessionName == "" { | ||
| sessionName = "pgutils-rds-iam" | ||
| } | ||
|
|
||
| assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) { | ||
| assumeRoleOpts.RoleSessionName = sessionName | ||
|
|
||
| if cfg.AssumeRoleExternalID != "" { | ||
| assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID) | ||
| } | ||
|
|
||
| if cfg.AssumeRoleDuration != 0 { | ||
| assumeRoleOpts.Duration = cfg.AssumeRoleDuration | ||
| } | ||
| log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName) | ||
| assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) { | ||
| opts.RoleSessionName = sessionName | ||
| }) | ||
|
|
||
| // Cache to avoid calling STS too frequently. | ||
| creds = aws.NewCredentialsCache(assumeProvider) | ||
| } | ||
|
|
||
| return &PostgresqlConnector{ | ||
| baseConnectionStringProvider: &iamAuthConnectionStringProvider{ | ||
| IAMAuthConfig: *cfg, | ||
| region: awsCfg.Region, | ||
| creds: creds, | ||
| }, | ||
| return &rdsIAMConnectionStringProvider{ | ||
| Region: awsCfg.Region, | ||
| RDSEndpoint: net.JoinHostPort(host, port), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here you have Endpoint including port. |
||
| User: user, | ||
| Database: dbName, | ||
| CredentialsProvider: creds, | ||
| }, nil | ||
| } | ||
|
|
||
| // Provides missing sqlx.OpenDB | ||
| func OpenDB(conn *PostgresqlConnector) *sqlx.DB { | ||
| sqlDB := sql.OpenDB(conn) | ||
| return sqlx.NewDb(sqlDB, "postgres") | ||
| } | ||
|
|
||
| // ConnectDB opens a connection using the connector and verifies it with a ping | ||
| func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) { | ||
| db := OpenDB(conn) | ||
| if err := db.Ping(); err != nil { | ||
| db.Close() | ||
| return nil, err | ||
| } | ||
| return db, nil | ||
| } | ||
|
|
||
| // MustConnectDB is like ConnectDB but panics on error | ||
| func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { | ||
| db, err := ConnectDB(conn) | ||
| if err != nil { | ||
| panic(err) | ||
| } | ||
| return db | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.