From f8dcc535da28e322aa1aa5cf91132fa661b8bada Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Fri, 12 Dec 2025 05:27:36 -0800 Subject: [PATCH 1/7] init --- cmd/rds-iam-psql/README.md | 114 ++++++++++++++++++++++++++++++ cmd/rds-iam-psql/main.go | 139 +++++++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 cmd/rds-iam-psql/README.md create mode 100644 cmd/rds-iam-psql/main.go diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md new file mode 100644 index 0000000..b6dae57 --- /dev/null +++ b/cmd/rds-iam-psql/README.md @@ -0,0 +1,114 @@ +# rds-iam-psql + +A simple CLI tool that bridges AWS RDS IAM authentication into an interactive `psql` session. It generates a short-lived IAM auth token and launches `psql` with the token as the password, so you never have to manage database passwords. + +## Why? + +RDS IAM authentication lets you connect to PostgreSQL using your AWS credentials instead of a static database password. However, the auth tokens are temporary (15 minutes) and cumbersome to generate manually. This tool handles token generation automatically and drops you into a familiar `psql` shell. + +## Installation + +```bash +go install github.com/corbaltcode/go-libraries/cmd/rds-iam-psql@latest +``` + +Or build from source: + +```bash +cd ./cmd/rds-iam-psql +go build +``` + +## Prerequisites + +- **psql** installed and available in your PATH +- **AWS credentials** configured (via environment variables, `~/.aws/credentials`, IAM role, etc.) +- **RDS IAM authentication enabled** on your database instance +- A database user configured for IAM authentication (created with `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) + +## Usage + +```bash +rds-iam-psql -host -user -db [options] +``` + +### Required Flags + +| Flag | Description | +|------|-------------| +| `-host` | RDS endpoint hostname (without port), e.g. `mydb.abc123.us-east-1.rds.amazonaws.com` | +| `-user` | Database username configured for IAM auth | +| `-db` | Database name to connect to | + +### Optional Flags + +| Flag | Default | Description | +|------|---------|-------------| +| `-port` | `5432` | PostgreSQL port | +| `-region` | auto | AWS region. If omitted, inferred from AWS config or the hostname | +| `-profile` | | AWS shared config profile to use (e.g. `dev`, `prod`) | +| `-psql` | `psql` | Path to the `psql` binary | +| `-sslmode` | `require` | SSL mode (`require`, `verify-full`, etc.) | +| `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) | + +## Examples + +Basic connection: + +```bash +rds-iam-psql -host mydb.abc123.us-east-1.rds.amazonaws.com -user app_user -db myapp +``` + +With a specific AWS profile and schema: + +```bash +rds-iam-psql \ + -host mydb.abc123.us-east-1.rds.amazonaws.com \ + -user app_user \ + -db myapp \ + -profile production \ + -search-path "app_schema,public" +``` + +Using a non-standard port and explicit region: + +```bash +rds-iam-psql \ + -host mydb.abc123.us-east-1.rds.amazonaws.com \ + -port 5433 \ + -user admin \ + -db postgres \ + -region us-east-1 +``` + +## How It Works + +1. Loads your AWS credentials from the standard credential chain +2. Generates a temporary RDS IAM auth token using `auth.BuildAuthToken` +3. Launches `psql` with: + - `PGPASSWORD` set to the auth token + - `PGSSLMODE` set according to `-sslmode` + - `PGOPTIONS` set if `-search-path` is provided +4. Attaches stdin/stdout/stderr for interactive use + +## Setting Up IAM Auth on RDS + +1. Enable IAM authentication on your RDS instance +2. Create a database user and grant IAM privileges: + ```sql + CREATE USER myuser WITH LOGIN; + GRANT rds_iam TO myuser; + ``` +3. Attach an IAM policy allowing `rds-db:connect` to your AWS user/role: + ```json + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "rds-db:connect", + "Resource": "arn:aws:rds-db:::dbuser:/" + } + ] + } + ``` diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go new file mode 100644 index 0000000..d076047 --- /dev/null +++ b/cmd/rds-iam-psql/main.go @@ -0,0 +1,139 @@ +// rds-iam-psql.go +package main + +import ( + "context" + "flag" + "fmt" + "log" + "os" + "os/exec" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" +) + +func main() { + var ( + host = flag.String("host", "", "RDS PostgreSQL endpoint hostname (no port, e.g. mydb.abc123.us-east-1.rds.amazonaws.com)") + port = flag.Int("port", 5432, "RDS PostgreSQL port (default 5432)") + user = flag.String("user", "", "Database user name") + dbName = flag.String("db", "", "Database name") + region = flag.String("region", "", "AWS region for the RDS instance (e.g. us-east-1). If empty, uses AWS config or tries to infer from host.") + profile = flag.String("profile", "", "Optional AWS shared config profile (e.g. dev)") + psqlPath = flag.String("psql", "psql", "Path to psql binary") + sslMode = flag.String("sslmode", "require", "PGSSLMODE for psql (e.g. require, verify-full)") + searchPath = flag.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')") + ) + flag.Parse() + + if *host == "" || *user == "" || *dbName == "" { + log.Fatalf("host, user, and db are required\n\nUsage example:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\" -region us-east-1\n", os.Args[0]) + } + + ctx := context.Background() + + // Load AWS config (standard RDS/IAM auth expects your AWS creds, *not* the DB password). + var cfg aws.Config + var err error + if *profile != "" { + cfg, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithSharedConfigProfile(*profile)) + } else { + cfg, err = awsconfig.LoadDefaultConfig(ctx) + } + if err != nil { + log.Fatalf("failed to load AWS config: %v", err) + } + + awsRegion := *region + if awsRegion == "" { + awsRegion = cfg.Region + } + if awsRegion == "" { + // Last resort: try to infer from the hostname if it looks like a standard RDS endpoint. + if inferred := inferRegionFromHost(*host); inferred != "" { + awsRegion = inferred + } + } + + if awsRegion == "" { + log.Fatalf("AWS region is not set; pass -region or set AWS_REGION / configure your AWS profile") + } + + endpointWithPort := fmt.Sprintf("%s:%d", *host, *port) + + // Generate the IAM auth token. + authToken, err := auth.BuildAuthToken(ctx, endpointWithPort, awsRegion, *user, cfg.Credentials) + if err != nil { + log.Fatalf("failed to build RDS IAM auth token: %v", err) + } + + // Prepare psql command. We pass the token through PGPASSWORD and SSL mode via PGSSLMODE. + cmd := exec.Command( + *psqlPath, + "--host", *host, + "--port", fmt.Sprintf("%d", *port), + "--username", *user, + "--dbname", *dbName, + ) + + // Attach stdio so it behaves like an interactive shell. + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Inherit existing env and add PG vars. + env := os.Environ() + env = append(env, + "PGPASSWORD="+authToken, + "PGSSLMODE="+*sslMode, + ) + + // If a search path is provided, wire it through PGOPTIONS. + if sp := strings.TrimSpace(*searchPath); sp != "" { + // Build our addition: one -c flag. + add := "-c search_path=" + sp + + // Check if PGOPTIONS already exists; if so, append. + found := false + for i, e := range env { + if strings.HasPrefix(e, "PGOPTIONS=") { + current := strings.TrimPrefix(e, "PGOPTIONS=") + if strings.TrimSpace(current) == "" { + env[i] = "PGOPTIONS=" + add + } else { + env[i] = "PGOPTIONS=" + current + " " + add + } + found = true + break + } + } + if !found { + env = append(env, "PGOPTIONS="+add) + } + } + + cmd.Env = env + + if err := cmd.Run(); err != nil { + // psql will print its own error messages; just propagate the exit code. + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + log.Fatalf("failed to run psql: %v", err) + } +} + +// inferRegionFromHost tries to pull the AWS region out of a typical RDS hostname like +// "mydb.abc123.us-east-1.rds.amazonaws.com". If it can't, it returns "". +func inferRegionFromHost(host string) string { + parts := strings.Split(host, ".") + for i := 0; i < len(parts); i++ { + if parts[i] == "rds" && i > 0 { + return parts[i-1] + } + } + return "" +} From 200599b9d491e3bc1642c4bc2b2869af087cc182 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:25:02 -0800 Subject: [PATCH 2/7] Fix ctrl-c issue, and remove parsing region from db hostname --- cmd/rds-iam-psql/main.go | 64 +++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index d076047..659518c 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -1,4 +1,3 @@ -// rds-iam-psql.go package main import ( @@ -8,7 +7,9 @@ import ( "log" "os" "os/exec" + "os/signal" "strings" + "syscall" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" @@ -51,12 +52,6 @@ func main() { if awsRegion == "" { awsRegion = cfg.Region } - if awsRegion == "" { - // Last resort: try to infer from the hostname if it looks like a standard RDS endpoint. - if inferred := inferRegionFromHost(*host); inferred != "" { - awsRegion = inferred - } - } if awsRegion == "" { log.Fatalf("AWS region is not set; pass -region or set AWS_REGION / configure your AWS profile") @@ -93,10 +88,8 @@ func main() { // If a search path is provided, wire it through PGOPTIONS. if sp := strings.TrimSpace(*searchPath); sp != "" { - // Build our addition: one -c flag. add := "-c search_path=" + sp - // Check if PGOPTIONS already exists; if so, append. found := false for i, e := range env { if strings.HasPrefix(e, "PGOPTIONS=") { @@ -117,23 +110,46 @@ func main() { cmd.Env = env - if err := cmd.Run(); err != nil { - // psql will print its own error messages; just propagate the exit code. - if exitErr, ok := err.(*exec.ExitError); ok { - os.Exit(exitErr.ExitCode()) - } - log.Fatalf("failed to run psql: %v", err) + // --- Ctrl-C handling --- + // The key idea: keep psql in the same foreground process group so it can read + // from the terminal. We intercept SIGINT only to prevent THIS wrapper from + // exiting; psql will still receive SIGINT normally and cancel the current + // query / line as expected. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + if err := cmd.Start(); err != nil { + log.Fatalf("failed to start psql: %v", err) } -} -// inferRegionFromHost tries to pull the AWS region out of a typical RDS hostname like -// "mydb.abc123.us-east-1.rds.amazonaws.com". If it can't, it returns "". -func inferRegionFromHost(host string) string { - parts := strings.Split(host, ".") - for i := 0; i < len(parts); i++ { - if parts[i] == "rds" && i > 0 { - return parts[i-1] + waitCh := make(chan error, 1) + go func() { waitCh <- cmd.Wait() }() + + for { + select { + case sig := <-sigCh: + switch sig { + case os.Interrupt: + // Swallow SIGINT so this wrapper doesn't exit. + // psql still gets SIGINT (same terminal foreground process group). + continue + case syscall.SIGTERM: + // If we're being terminated, pass it through to psql and exit accordingly. + if cmd.Process != nil { + _ = cmd.Process.Signal(syscall.SIGTERM) + } + } + case err := <-waitCh: + // psql exited; now we exit with the same code. + if err == nil { + return + } + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + log.Fatalf("psql failed: %v", err) } } - return "" } + From a59d002953481f76fe932baff31e3e0a396a7293 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:41:17 -0800 Subject: [PATCH 3/7] Sts check --- cmd/rds-iam-psql/main.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index 659518c..43a6870 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -14,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/aws/aws-sdk-go-v2/service/sts" ) func main() { @@ -48,6 +49,11 @@ func main() { log.Fatalf("failed to load AWS config: %v", err) } + // Fail fast + print identity (account/arn/role-ish). + if err := printCallerIdentity(ctx, cfg); err != nil { + log.Fatalf("AWS credentials check failed: %v", err) + } + awsRegion := *region if awsRegion == "" { awsRegion = cfg.Region @@ -153,3 +159,19 @@ func main() { } } +func printCallerIdentity(ctx context.Context, cfg aws.Config) error { + stsClient := sts.NewFromConfig(cfg) + + out, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err) + } + + account := aws.ToString(out.Account) + arn := aws.ToString(out.Arn) + + fmt.Printf("AWS Account: %s\n", account) + fmt.Printf("Caller ARN: %s\n", arn) + + return nil +} From ece0181e437f3a12a80fa91967c914b11f5b0254 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:41:39 -0800 Subject: [PATCH 4/7] White space --- cmd/rds-iam-psql/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index 43a6870..f6213b0 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -175,3 +175,4 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error { return nil } + From ac386aa4f076709e4c7c354311e39cb71fdf7728 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:45:21 -0800 Subject: [PATCH 5/7] tighter sts print --- cmd/rds-iam-psql/main.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index f6213b0..d6e2c6f 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -167,12 +167,7 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error { return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err) } - account := aws.ToString(out.Account) - arn := aws.ToString(out.Arn) - - fmt.Printf("AWS Account: %s\n", account) - fmt.Printf("Caller ARN: %s\n", arn) - + fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn)) return nil } From d46322af740cc377805bf3abf1d9db4b37c84fd6 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:45:41 -0800 Subject: [PATCH 6/7] go fmt --- cmd/rds-iam-psql/main.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index d6e2c6f..8406c0d 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -170,4 +170,3 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error { fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn)) return nil } - From b2884c959839b60a9f7cfaf6a3a562d9d8539982 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Wed, 18 Feb 2026 19:15:23 -0800 Subject: [PATCH 7/7] use connetor --- cmd/rds-iam-psql/README.md | 85 ++++++----- cmd/rds-iam-psql/main.go | 139 +++++++++++------- pgutils/connector.go | 288 +++++++++++++++++++++---------------- pgutils/listener.go | 7 +- 4 files changed, 307 insertions(+), 212 deletions(-) diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md index b6dae57..a316106 100644 --- a/cmd/rds-iam-psql/README.md +++ b/cmd/rds-iam-psql/README.md @@ -1,10 +1,14 @@ # rds-iam-psql -A simple CLI tool that bridges AWS RDS IAM authentication into an interactive `psql` session. It generates a short-lived IAM auth token and launches `psql` with the token as the password, so you never have to manage database passwords. +A CLI that launches an interactive `psql` session from either: +- a positional connection URL, or +- individual `-host/-port/-user/-db` flags. + +It supports standard PostgreSQL URLs and `pgutils` custom IAM URLs (`postgres+rds-iam://...`). ## Why? -RDS IAM authentication lets you connect to PostgreSQL using your AWS credentials instead of a static database password. However, the auth tokens are temporary (15 minutes) and cumbersome to generate manually. This tool handles token generation automatically and drops you into a familiar `psql` shell. +RDS IAM authentication lets you connect using AWS credentials instead of a static DB password. IAM auth tokens are short-lived and inconvenient to generate manually. This tool resolves a fresh DSN through `pgutils` and opens `psql` for you. ## Installation @@ -22,74 +26,85 @@ go build ## Prerequisites - **psql** installed and available in your PATH -- **AWS credentials** configured (via environment variables, `~/.aws/credentials`, IAM role, etc.) -- **RDS IAM authentication enabled** on your database instance -- A database user configured for IAM authentication (created with `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) +- For IAM URLs (`postgres+rds-iam://...`), **AWS credentials** configured (env vars, `~/.aws/credentials`, IAM role, etc.) +- For IAM URLs (`postgres+rds-iam://...`), **AWS_REGION** set +- For IAM URLs (`postgres+rds-iam://...`), **RDS IAM authentication enabled** on your database instance +- For IAM URLs (`postgres+rds-iam://...`), a DB user configured for IAM auth (for example: `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) ## Usage ```bash -rds-iam-psql -host -user -db [options] +rds-iam-psql [connection-url] [options] ``` -### Required Flags +```bash +rds-iam-psql -host -user -db [options] +``` -| Flag | Description | -|------|-------------| -| `-host` | RDS endpoint hostname (without port), e.g. `mydb.abc123.us-east-1.rds.amazonaws.com` | -| `-user` | Database username configured for IAM auth | -| `-db` | Database name to connect to | +`connection-url` supports: +- `postgres+rds-iam://user@host:5432/dbname` +- `postgres://user:pass@host:5432/dbname?...` +- `postgresql://user:pass@host:5432/dbname?...` -### Optional Flags +If `connection-url` is provided, do not combine it with `-host/-port/-user/-db`. + +### Flags | Flag | Default | Description | |------|---------|-------------| +| `-host` | | Endpoint hostname (required if `connection-url` is not provided) | | `-port` | `5432` | PostgreSQL port | -| `-region` | auto | AWS region. If omitted, inferred from AWS config or the hostname | -| `-profile` | | AWS shared config profile to use (e.g. `dev`, `prod`) | +| `-user` | | DB username (required if `connection-url` is not provided) | +| `-db` | | DB name (required if `connection-url` is not provided) | | `-psql` | `psql` | Path to the `psql` binary | | `-sslmode` | `require` | SSL mode (`require`, `verify-full`, etc.) | | `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) | ## Examples -Basic connection: +Positional IAM URL (your requested form): + +```bash +./rds-iam-psql 'postgres+rds-iam://server@acremins-test.cicxifnkufnd.us-east-1.rds.amazonaws.com:5432/postgres' +``` + +IAM URL with cross-account role assumption: + +```bash +rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp?assume_role_arn=arn:aws:iam::123456789012:role/db-connect&assume_role_session_name=rds-iam-psql' +``` + +Flag-based IAM connection: ```bash rds-iam-psql -host mydb.abc123.us-east-1.rds.amazonaws.com -user app_user -db myapp ``` -With a specific AWS profile and schema: +Standard PostgreSQL URL (non-IAM): ```bash -rds-iam-psql \ - -host mydb.abc123.us-east-1.rds.amazonaws.com \ - -user app_user \ - -db myapp \ - -profile production \ - -search-path "app_schema,public" +rds-iam-psql 'postgresql://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable' ``` -Using a non-standard port and explicit region: +With search path: ```bash rds-iam-psql \ -host mydb.abc123.us-east-1.rds.amazonaws.com \ - -port 5433 \ - -user admin \ - -db postgres \ - -region us-east-1 + -user app_user \ + -db myapp \ + -search-path "app_schema,public" ``` ## How It Works -1. Loads your AWS credentials from the standard credential chain -2. Generates a temporary RDS IAM auth token using `auth.BuildAuthToken` -3. Launches `psql` with: - - `PGPASSWORD` set to the auth token - - `PGSSLMODE` set according to `-sslmode` - - `PGOPTIONS` set if `-search-path` is provided -4. Attaches stdin/stdout/stderr for interactive use +1. Parses input from either positional URL or `-host/-port/-user/-db`. +2. Builds a `pgutils.ConnectionStringProvider` from the URL. +3. For IAM URLs, validates AWS auth context (including `AWS_REGION`). +4. Resolves a DSN from the provider and launches `psql` with: +- `PGPASSWORD` set from the URL password/token +- `PGSSLMODE` set from `-sslmode` +- `PGOPTIONS` set when `-search-path` is provided ## Setting Up IAM Auth on RDS diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index 8406c0d..e554ea8 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -5,16 +5,19 @@ import ( "flag" "fmt" "log" + "net" + "net/url" "os" "os/exec" "os/signal" + "strconv" "strings" "syscall" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/corbaltcode/go-libraries/pgutils" ) func main() { @@ -23,76 +26,80 @@ func main() { port = flag.Int("port", 5432, "RDS PostgreSQL port (default 5432)") user = flag.String("user", "", "Database user name") dbName = flag.String("db", "", "Database name") - region = flag.String("region", "", "AWS region for the RDS instance (e.g. us-east-1). If empty, uses AWS config or tries to infer from host.") - profile = flag.String("profile", "", "Optional AWS shared config profile (e.g. dev)") psqlPath = flag.String("psql", "psql", "Path to psql binary") sslMode = flag.String("sslmode", "require", "PGSSLMODE for psql (e.g. require, verify-full)") searchPath = flag.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')") ) flag.Parse() - if *host == "" || *user == "" || *dbName == "" { - log.Fatalf("host, user, and db are required\n\nUsage example:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\" -region us-east-1\n", os.Args[0]) + args := flag.Args() + if len(args) > 1 { + log.Fatalf("expected at most one positional connection URL argument, got %d", len(args)) } - ctx := context.Background() - - // Load AWS config (standard RDS/IAM auth expects your AWS creds, *not* the DB password). - var cfg aws.Config - var err error - if *profile != "" { - cfg, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithSharedConfigProfile(*profile)) - } else { - cfg, err = awsconfig.LoadDefaultConfig(ctx) + connectionURLArg := "" + if len(args) == 1 { + connectionURLArg = args[0] } + + rawURL, usesIAM, err := buildRawURL(connectionURLArg, *host, *port, *user, *dbName) if err != nil { - log.Fatalf("failed to load AWS config: %v", err) + log.Fatalf("%v\n\nUsage examples:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\"\n %s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'\n", err, os.Args[0], os.Args[0]) } - // Fail fast + print identity (account/arn/role-ish). - if err := printCallerIdentity(ctx, cfg); err != nil { - log.Fatalf("AWS credentials check failed: %v", err) - } + ctx := context.Background() - awsRegion := *region - if awsRegion == "" { - awsRegion = cfg.Region + connectionStringProvider, err := pgutils.NewConnectionStringProviderFromURLString(ctx, rawURL) + if err != nil { + log.Fatalf("failed to create connection string provider: %v", err) } - if awsRegion == "" { - log.Fatalf("AWS region is not set; pass -region or set AWS_REGION / configure your AWS profile") + if usesIAM { + if os.Getenv("AWS_REGION") == "" { + log.Fatalf("AWS_REGION must be set for IAM auth") + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx) + if err != nil { + log.Fatalf("failed to load AWS config: %v", err) + } + if err := printCallerIdentity(ctx, cfg); err != nil { + log.Fatalf("AWS credentials check failed: %v", err) + } } - endpointWithPort := fmt.Sprintf("%s:%d", *host, *port) + dsnWithToken, err := connectionStringProvider.ConnectionString(ctx) + if err != nil { + log.Fatalf("failed to get connection string from provider: %v", err) + } - // Generate the IAM auth token. - authToken, err := auth.BuildAuthToken(ctx, endpointWithPort, awsRegion, *user, cfg.Credentials) + parsedURL, err := url.Parse(dsnWithToken) if err != nil { - log.Fatalf("failed to build RDS IAM auth token: %v", err) + log.Fatalf("failed to parse connection string from provider: %v", err) } - // Prepare psql command. We pass the token through PGPASSWORD and SSL mode via PGSSLMODE. - cmd := exec.Command( - *psqlPath, - "--host", *host, - "--port", fmt.Sprintf("%d", *port), - "--username", *user, - "--dbname", *dbName, - ) + password := "" + if parsedURL.User != nil { + var ok bool + password, ok = parsedURL.User.Password() + if ok { + parsedURL.User = url.User(parsedURL.User.Username()) + } + } + + // Pass DSN to psql without password in argv, and provide password via env. + cmd := exec.Command(*psqlPath, parsedURL.String()) - // Attach stdio so it behaves like an interactive shell. cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // Inherit existing env and add PG vars. env := os.Environ() - env = append(env, - "PGPASSWORD="+authToken, - "PGSSLMODE="+*sslMode, - ) + if password != "" { + env = append(env, "PGPASSWORD="+password) + } + env = append(env, "PGSSLMODE="+*sslMode) - // If a search path is provided, wire it through PGOPTIONS. if sp := strings.TrimSpace(*searchPath); sp != "" { add := "-c search_path=" + sp @@ -116,11 +123,8 @@ func main() { cmd.Env = env - // --- Ctrl-C handling --- - // The key idea: keep psql in the same foreground process group so it can read - // from the terminal. We intercept SIGINT only to prevent THIS wrapper from - // exiting; psql will still receive SIGINT normally and cancel the current - // query / line as expected. + // Keep psql in the foreground process group. Swallow SIGINT in wrapper so + // psql handles Ctrl-C directly. sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) defer signal.Stop(sigCh) @@ -137,17 +141,13 @@ func main() { case sig := <-sigCh: switch sig { case os.Interrupt: - // Swallow SIGINT so this wrapper doesn't exit. - // psql still gets SIGINT (same terminal foreground process group). continue case syscall.SIGTERM: - // If we're being terminated, pass it through to psql and exit accordingly. if cmd.Process != nil { _ = cmd.Process.Signal(syscall.SIGTERM) } } case err := <-waitCh: - // psql exited; now we exit with the same code. if err == nil { return } @@ -159,6 +159,41 @@ func main() { } } +func buildRawURL(connectionURLArg, host string, port int, user, dbName string) (string, bool, error) { + if connectionURLArg != "" { + if host != "" || user != "" || dbName != "" || port != 5432 { + return "", false, fmt.Errorf("positional connection URL cannot be combined with -host, -port, -user, or -db") + } + parsedURL, err := url.Parse(connectionURLArg) + if err != nil { + return "", false, fmt.Errorf("failed to parse positional connection URL: %w", err) + } + switch parsedURL.Scheme { + case "postgres+rds-iam": + return connectionURLArg, true, nil + case "postgres", "postgresql": + return connectionURLArg, false, nil + default: + return "", false, fmt.Errorf("unsupported connection URL scheme %q (expected postgres, postgresql, or postgres+rds-iam)", parsedURL.Scheme) + } + } + + if host == "" || user == "" || dbName == "" { + return "", false, fmt.Errorf("host, user, and db are required when no positional connection URL is provided") + } + if port <= 0 { + return "", false, fmt.Errorf("invalid port: %d", port) + } + + iamURL := &url.URL{ + Scheme: "postgres+rds-iam", + User: url.User(user), + Host: net.JoinHostPort(host, strconv.Itoa(port)), + Path: "/" + dbName, + } + return iamURL.String(), true, nil +} + func printCallerIdentity(ctx context.Context, cfg aws.Config) error { stsClient := sts.NewFromConfig(cfg) diff --git a/pgutils/connector.go b/pgutils/connector.go index 21dce91..1a0b773 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -5,8 +5,9 @@ import ( "errors" "fmt" "log" + "net" "net/url" - "time" + "strings" "database/sql" "database/sql/driver" @@ -20,109 +21,161 @@ import ( "github.com/lib/pq" ) -type baseConnectionStringProvider interface { - getBaseConnectionString(ctx context.Context) (string, error) -} +const defaultPostgresPort = "5432" + +var pqDriver = &pq.Driver{} -type PostgresqlConnector struct { - baseConnectionStringProvider - searchPath string +// ConnectionStringProvider returns a Postgres connection string for use by clients +// that need a DSN (e.g., pq.Listener) or to build a connector. +type ConnectionStringProvider interface { + ConnectionString(ctx context.Context) (string, error) } -func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: conn.baseConnectionStringProvider, - searchPath: searchPath, - } +type connectionStringProviderFunc func(context.Context) (string, error) + +func (f connectionStringProviderFunc) ConnectionString(ctx context.Context) (string, error) { + return f(ctx) } -func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { - dsn, err := conn.GetConnectionString(ctx) +// NewConnectionStringProviderFromURLString parses rawURL and constructs a provider. +// +// Standard Postgres example: +// +// postgres://user:pass@host:5432/dbname?sslmode=require +// +// IAM example 1: +// +// postgres+rds-iam://user@host:5432/dbname +// +// IAM example 2 (cross-account): +// +// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=... +// +// For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call. +func NewConnectionStringProviderFromURLString(ctx context.Context, rawURL string) (ConnectionStringProvider, error) { + u, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("get connection string: %w", err) + return nil, fmt.Errorf("parsing URL: %w", err) } - pqConnector, err := pq.NewConnector(dsn) - if err != nil { - return nil, fmt.Errorf("create pq connector: %w", err) + + switch u.Scheme { + case "postgres", "postgresql": + return &staticConnectionStringProvider{connectionString: u.String()}, nil + case "postgres+rds-iam": + return newIAMConnectionStringProviderFromURL(ctx, u) + default: + return nil, fmt.Errorf("unsupported URL scheme: %q (expected postgres, postgresql, or postgres+rds-iam)", u.Scheme) } +} - return pqConnector.Connect(ctx) +// ToConnector wraps a ConnectionStringProvider as a driver.Connector. +// Each Connect(ctx) call asks the provider for a fresh DSN. +func ToConnector(provider ConnectionStringProvider) driver.Connector { + return &postgresqlConnector{connectionStringProvider: provider} } -func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) { - dsn, err := conn.getBaseConnectionString(ctx) - if err != nil { - return "", fmt.Errorf("get base connection string: %w", err) +// WithSchemaSearchPath returns a ConnectionStringProvider that appends search_path +// to the DSN produced by the underlying provider. +func WithSchemaSearchPath(provider ConnectionStringProvider, searchPath string) ConnectionStringProvider { + return connectionStringProviderFunc(func(ctx context.Context) (string, error) { + dsn, err := provider.ConnectionString(ctx) + if err != nil { + return "", fmt.Errorf("ConnectionString failed: %w", err) + } + + dsnWithPath, err := addSearchPathToURL(dsn, searchPath) + if err != nil { + return "", fmt.Errorf("applying schema search path failed: %w", err) + } + + return dsnWithPath, nil + }) +} + +// ConnectDB opens a connection using the connector and verifies it with a ping +func ConnectDB(conn driver.Connector) (*sqlx.DB, error) { + sqlDB := sql.OpenDB(conn) + db := sqlx.NewDb(sqlDB, "postgres") + if err := db.Ping(); err != nil { + db.Close() + return nil, err } - if conn.searchPath == "" { - return dsn, nil + return db, nil +} + +// MustConnectDB is like ConnectDB but panics on error +func MustConnectDB(conn driver.Connector) *sqlx.DB { + db, err := ConnectDB(conn) + if err != nil { + panic(err) } + return db +} - // Add search path - u, err := url.Parse(dsn) +// addSearchPathToURL returns a copy of u with search_path set in the query string. +// It returns an error if search_path is already present. +func addSearchPathToURL(rawURL string, searchPath string) (string, error) { + u, err := url.Parse(rawURL) if err != nil { - return "", fmt.Errorf("parse DSN URL: %w", err) + return "", fmt.Errorf("url string failed to parse while adding search path: %w", err) + } + + if searchPath == "" { + return u.String(), nil } + q := u.Query() if v := q.Get("search_path"); v != "" { return "", fmt.Errorf("search_path already set to %q", v) } - q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed + q.Set("search_path", searchPath) u.RawQuery = q.Encode() return u.String(), nil } -func (c *PostgresqlConnector) Driver() driver.Driver { - return &pq.Driver{} +type postgresqlConnector struct { + connectionStringProvider ConnectionStringProvider } -type staticConnectionStringProvider struct { - connectionString string -} +func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { + dsn, err := c.connectionStringProvider.ConnectionString(ctx) + if err != nil { + return nil, fmt.Errorf("getting connection string from provider: %w", err) + } + pqConnector, err := pq.NewConnector(dsn) + if err != nil { + return nil, fmt.Errorf("creating pq connector: %w", err) + } -func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - return p.connectionString, nil + return pqConnector.Connect(ctx) } -func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: &staticConnectionStringProvider{connectionString}, - } +func (c *postgresqlConnector) Driver() driver.Driver { + return pqDriver } -type IAMAuthConfig struct { - RDSEndpoint string - User string - Database string - - // Optional: cross-account role assumption. - // Set this to a role ARN in the RDS account (Account A) that has rds-db:connect. - AssumeRoleARN string - - // Optional: if your trust policy requires an external ID. - AssumeRoleExternalID string - - // Optional: override the default session name. - AssumeRoleSessionName string - - // Optional: override STS assume role duration. - // If zero, SDK default is used. - AssumeRoleDuration time.Duration +type staticConnectionStringProvider struct { + connectionString string } -type iamAuthConnectionStringProvider struct { - IAMAuthConfig +func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + return p.connectionString, nil +} - region string - creds aws.CredentialsProvider +type rdsIAMConnectionStringProvider struct { + RDSEndpoint string + Region string + User string + Database string + CredentialsProvider aws.CredentialsProvider } -func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds) +func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider) if err != nil { return "", fmt.Errorf("building auth token: %w", err) } - log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database) + log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database) dsnURL := &url.URL{ Scheme: "postgresql", @@ -134,9 +187,43 @@ func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Co return dsnURL.String(), nil } -func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) { - if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" { - return nil, errors.New("RDS endpoint, user, and database are required") +func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) { + user := "" + if u.User != nil { + user = u.User.Username() + if _, hasPw := u.User.Password(); hasPw { + return nil, errors.New("postgres+rds-iam URL must not include a password") + } + } + if user == "" { + return nil, errors.New("postgres+rds-iam URL missing username") + } + + host := u.Hostname() + if host == "" { + return nil, errors.New("postgres+rds-iam URL missing host") + } + + port := u.Port() + if port == "" { + port = defaultPostgresPort + } + + // Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username. + dbName := strings.TrimPrefix(u.Path, "/") + if dbName == "" { + dbName = user + } + + q := u.Query() + supportedParams := map[string]struct{}{ + "assume_role_arn": {}, + "assume_role_session_name": {}, + } + for k := range q { + if _, ok := supportedParams[k]; !ok { + return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k) + } } awsCfg, err := awsconfig.LoadDefaultConfig(ctx) @@ -149,66 +236,25 @@ func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) } creds := awsCfg.Credentials - - // Cross-account support: - // If AssumeRoleARN is set, assume a role in the RDS account (Account A) - // using the ECS task role creds from Account B as the source credentials. - if cfg.AssumeRoleARN != "" { - log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database) + assumeRoleARN := q.Get("assume_role_arn") + if assumeRoleARN != "" { stsClient := sts.NewFromConfig(awsCfg) - - sessionName := cfg.AssumeRoleSessionName + sessionName := q.Get("assume_role_session_name") if sessionName == "" { sessionName = "pgutils-rds-iam" } - - assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) { - assumeRoleOpts.RoleSessionName = sessionName - - if cfg.AssumeRoleExternalID != "" { - assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID) - } - - if cfg.AssumeRoleDuration != 0 { - assumeRoleOpts.Duration = cfg.AssumeRoleDuration - } + log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName) + assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) { + opts.RoleSessionName = sessionName }) - - // Cache to avoid calling STS too frequently. creds = aws.NewCredentialsCache(assumeProvider) } - return &PostgresqlConnector{ - baseConnectionStringProvider: &iamAuthConnectionStringProvider{ - IAMAuthConfig: *cfg, - region: awsCfg.Region, - creds: creds, - }, + return &rdsIAMConnectionStringProvider{ + Region: awsCfg.Region, + RDSEndpoint: net.JoinHostPort(host, port), + User: user, + Database: dbName, + CredentialsProvider: creds, }, nil } - -// Provides missing sqlx.OpenDB -func OpenDB(conn *PostgresqlConnector) *sqlx.DB { - sqlDB := sql.OpenDB(conn) - return sqlx.NewDb(sqlDB, "postgres") -} - -// ConnectDB opens a connection using the connector and verifies it with a ping -func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) { - db := OpenDB(conn) - if err := db.Ping(); err != nil { - db.Close() - return nil, err - } - return db, nil -} - -// MustConnectDB is like ConnectDB but panics on error -func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { - db, err := ConnectDB(conn) - if err != nil { - panic(err) - } - return db -} - diff --git a/pgutils/listener.go b/pgutils/listener.go index 958462c..d1a7d06 100644 --- a/pgutils/listener.go +++ b/pgutils/listener.go @@ -69,7 +69,7 @@ func listenerEventToString(t pq.ListenerEventType) string { // The callback is invoked from the listener goroutine; it MUST NOT block // for long periods. If you need to do heavy work, offload it to another // goroutine. -func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error { +func Listen(ctx context.Context, provider ConnectionStringProvider, pgChannelName string, callback func(*pq.Notification), onClose func()) error { if callback == nil { return fmt.Errorf("listener callback cannot be nil") } @@ -77,9 +77,9 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string reconnectEventCh := make(chan struct{}, 1) // We just need a single reconnect event to trigger, so buffer size of 1 makeListener := func() (*pq.Listener, error) { - url, err := conn.GetConnectionString(ctx) + url, err := provider.ConnectionString(ctx) if err != nil { - return nil, fmt.Errorf("get url: %w", err) + return nil, fmt.Errorf("error getting connection string from provider: %w", err) } cb := func(t pq.ListenerEventType, e error) { @@ -174,4 +174,3 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string return nil } -