Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions cmd/rds-iam-psql/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# rds-iam-psql

A CLI that launches an interactive `psql` session from either:
- a positional connection URL, or
- individual `-host/-port/-user/-db` flags.

It supports standard PostgreSQL URLs and `pgutils` custom IAM URLs (`postgres+rds-iam://...`).

## Why?

RDS IAM authentication lets you connect using AWS credentials instead of a static DB password. IAM auth tokens are short-lived and inconvenient to generate manually. This tool resolves a fresh DSN through `pgutils` and opens `psql` for you.

## Installation

```bash
go install github.com/corbaltcode/go-libraries/cmd/rds-iam-psql@latest
```

Or build from source:

```bash
cd ./cmd/rds-iam-psql
go build
```

## Prerequisites

- **psql** installed and available in your PATH
- For IAM URLs (`postgres+rds-iam://...`), **AWS credentials** configured (env vars, `~/.aws/credentials`, IAM role, etc.)
- For IAM URLs (`postgres+rds-iam://...`), **AWS_REGION** set
- For IAM URLs (`postgres+rds-iam://...`), **RDS IAM authentication enabled** on your database instance
- For IAM URLs (`postgres+rds-iam://...`), a DB user configured for IAM auth (for example: `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`)

## Usage

```bash
rds-iam-psql [connection-url] [options]
```

```bash
rds-iam-psql -host <endpoint> -user <db-user> -db <database-name> [options]
```

`connection-url` supports:
- `postgres+rds-iam://user@host:5432/dbname`
- `postgres://user:pass@host:5432/dbname?...`
- `postgresql://user:pass@host:5432/dbname?...`

If `connection-url` is provided, do not combine it with `-host/-port/-user/-db`.

### Flags

| Flag | Default | Description |
|------|---------|-------------|
| `-host` | | Endpoint hostname (required if `connection-url` is not provided) |
| `-port` | `5432` | PostgreSQL port |
| `-user` | | DB username (required if `connection-url` is not provided) |
| `-db` | | DB name (required if `connection-url` is not provided) |
| `-psql` | `psql` | Path to the `psql` binary |
| `-sslmode` | `require` | SSL mode (`require`, `verify-full`, etc.) |
| `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) |

## Examples

Positional IAM URL (your requested form):

```bash
./rds-iam-psql 'postgres+rds-iam://server@acremins-test.cicxifnkufnd.us-east-1.rds.amazonaws.com:5432/postgres'
```

IAM URL with cross-account role assumption:

```bash
rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp?assume_role_arn=arn:aws:iam::123456789012:role/db-connect&assume_role_session_name=rds-iam-psql'
```

Flag-based IAM connection:

```bash
rds-iam-psql -host mydb.abc123.us-east-1.rds.amazonaws.com -user app_user -db myapp
```

Standard PostgreSQL URL (non-IAM):

```bash
rds-iam-psql 'postgresql://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable'
```

With search path:

```bash
rds-iam-psql \
-host mydb.abc123.us-east-1.rds.amazonaws.com \
-user app_user \
-db myapp \
-search-path "app_schema,public"
```

## How It Works

1. Parses input from either positional URL or `-host/-port/-user/-db`.
2. Builds a `pgutils.ConnectionStringProvider` from the URL.
3. For IAM URLs, validates AWS auth context (including `AWS_REGION`).
4. Resolves a DSN from the provider and launches `psql` with:
- `PGPASSWORD` set from the URL password/token
- `PGSSLMODE` set from `-sslmode`
- `PGOPTIONS` set when `-search-path` is provided

## Setting Up IAM Auth on RDS

1. Enable IAM authentication on your RDS instance
2. Create a database user and grant IAM privileges:
```sql
CREATE USER myuser WITH LOGIN;
GRANT rds_iam TO myuser;
```
3. Attach an IAM policy allowing `rds-db:connect` to your AWS user/role:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "rds-db:connect",
"Resource": "arn:aws:rds-db:<region>:<account-id>:dbuser:<dbi-resource-id>/<db-user>"
}
]
}
```
207 changes: 207 additions & 0 deletions cmd/rds-iam-psql/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package main

import (
"context"
"flag"
"fmt"
"log"
"net"
"net/url"
"os"
"os/exec"
"os/signal"
"strconv"
"strings"
"syscall"

"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/corbaltcode/go-libraries/pgutils"
)

func main() {
var (
host = flag.String("host", "", "RDS PostgreSQL endpoint hostname (no port, e.g. mydb.abc123.us-east-1.rds.amazonaws.com)")
port = flag.Int("port", 5432, "RDS PostgreSQL port (default 5432)")
user = flag.String("user", "", "Database user name")
dbName = flag.String("db", "", "Database name")
psqlPath = flag.String("psql", "psql", "Path to psql binary")
sslMode = flag.String("sslmode", "require", "PGSSLMODE for psql (e.g. require, verify-full)")
searchPath = flag.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')")
)
flag.Parse()

args := flag.Args()
if len(args) > 1 {
log.Fatalf("expected at most one positional connection URL argument, got %d", len(args))
}

connectionURLArg := ""
if len(args) == 1 {
connectionURLArg = args[0]
}

rawURL, usesIAM, err := buildRawURL(connectionURLArg, *host, *port, *user, *dbName)
if err != nil {
log.Fatalf("%v\n\nUsage examples:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\"\n %s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'\n", err, os.Args[0], os.Args[0])
}

ctx := context.Background()

connectionStringProvider, err := pgutils.NewConnectionStringProviderFromURLString(ctx, rawURL)
if err != nil {
log.Fatalf("failed to create connection string provider: %v", err)
}

if usesIAM {
if os.Getenv("AWS_REGION") == "" {
log.Fatalf("AWS_REGION must be set for IAM auth")
}

cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
log.Fatalf("failed to load AWS config: %v", err)
}
if err := printCallerIdentity(ctx, cfg); err != nil {
log.Fatalf("AWS credentials check failed: %v", err)
}
}

dsnWithToken, err := connectionStringProvider.ConnectionString(ctx)
if err != nil {
log.Fatalf("failed to get connection string from provider: %v", err)
}

parsedURL, err := url.Parse(dsnWithToken)
if err != nil {
log.Fatalf("failed to parse connection string from provider: %v", err)
}

password := ""
if parsedURL.User != nil {
var ok bool
password, ok = parsedURL.User.Password()
if ok {
parsedURL.User = url.User(parsedURL.User.Username())
}
}

// Pass DSN to psql without password in argv, and provide password via env.
cmd := exec.Command(*psqlPath, parsedURL.String())

cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

env := os.Environ()
if password != "" {
env = append(env, "PGPASSWORD="+password)
}
env = append(env, "PGSSLMODE="+*sslMode)

if sp := strings.TrimSpace(*searchPath); sp != "" {
add := "-c search_path=" + sp

found := false
for i, e := range env {
if strings.HasPrefix(e, "PGOPTIONS=") {
current := strings.TrimPrefix(e, "PGOPTIONS=")
if strings.TrimSpace(current) == "" {
env[i] = "PGOPTIONS=" + add
} else {
env[i] = "PGOPTIONS=" + current + " " + add
}
found = true
break
}
}
if !found {
env = append(env, "PGOPTIONS="+add)
}
}

cmd.Env = env

// Keep psql in the foreground process group. Swallow SIGINT in wrapper so
// psql handles Ctrl-C directly.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(sigCh)

if err := cmd.Start(); err != nil {
log.Fatalf("failed to start psql: %v", err)
}

waitCh := make(chan error, 1)
go func() { waitCh <- cmd.Wait() }()

for {
select {
case sig := <-sigCh:
switch sig {
case os.Interrupt:
continue
case syscall.SIGTERM:
if cmd.Process != nil {
_ = cmd.Process.Signal(syscall.SIGTERM)
}
}
case err := <-waitCh:
if err == nil {
return
}
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
log.Fatalf("psql failed: %v", err)
}
}
}

func buildRawURL(connectionURLArg, host string, port int, user, dbName string) (string, bool, error) {
if connectionURLArg != "" {
if host != "" || user != "" || dbName != "" || port != 5432 {
return "", false, fmt.Errorf("positional connection URL cannot be combined with -host, -port, -user, or -db")
}
parsedURL, err := url.Parse(connectionURLArg)
if err != nil {
return "", false, fmt.Errorf("failed to parse positional connection URL: %w", err)
}
switch parsedURL.Scheme {
case "postgres+rds-iam":
return connectionURLArg, true, nil
case "postgres", "postgresql":
return connectionURLArg, false, nil
default:
return "", false, fmt.Errorf("unsupported connection URL scheme %q (expected postgres, postgresql, or postgres+rds-iam)", parsedURL.Scheme)
}
}

if host == "" || user == "" || dbName == "" {
return "", false, fmt.Errorf("host, user, and db are required when no positional connection URL is provided")
}
if port <= 0 {
return "", false, fmt.Errorf("invalid port: %d", port)
}

iamURL := &url.URL{
Scheme: "postgres+rds-iam",
User: url.User(user),
Host: net.JoinHostPort(host, strconv.Itoa(port)),
Path: "/" + dbName,
}
return iamURL.String(), true, nil
}

func printCallerIdentity(ctx context.Context, cfg aws.Config) error {
stsClient := sts.NewFromConfig(cfg)

out, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err)
}

fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn))
return nil
}
Loading