diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 38cd98df..b056cd03 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,7 +23,9 @@ jobs: go-version-file: go.mod - name: Stub frontend embed dir - run: mkdir -p internal/web/dist && echo ok > internal/web/dist/stub.html + run: | + mkdir -p internal/web/dist + echo '
' > internal/web/dist/index.html - uses: golangci/golangci-lint-action@1e7e51e771db61008b38414a730f564565cf7c20 # v9.2.0 with: @@ -55,7 +57,9 @@ jobs: path-type: inherit - name: Stub frontend embed dir - run: mkdir -p internal/web/dist && echo ok > internal/web/dist/stub.html + run: | + mkdir -p internal/web/dist + echo '
' > internal/web/dist/index.html - name: Run Go tests run: go test -tags fts5 ./... -v -count=1 @@ -80,7 +84,9 @@ jobs: go-version-file: go.mod - name: Stub frontend embed dir - run: mkdir -p internal/web/dist && echo ok > internal/web/dist/stub.html + run: | + mkdir -p internal/web/dist + echo '
' > internal/web/dist/index.html - name: Test with coverage run: go test -tags fts5 -race -coverprofile=coverage.out ./... @@ -98,6 +104,42 @@ jobs: if: steps.codecov.outcome == 'failure' run: echo "::warning::Codecov upload failed" + integration: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:16-alpine + env: + POSTGRES_USER: agentsview_test + POSTGRES_PASSWORD: agentsview_test_password + POSTGRES_DB: agentsview_test + ports: + - 5433:5432 + options: >- + --health-cmd "pg_isready -U agentsview_test -d agentsview_test" + --health-interval 2s + --health-timeout 5s + --health-retries 10 + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 + with: + go-version-file: go.mod + + - name: Stub frontend embed dir + run: | + mkdir -p internal/web/dist + echo '
' > internal/web/dist/index.html + + - name: Run PostgreSQL integration tests + run: make test-postgres-ci + env: + CGO_ENABLED: "1" + TEST_PG_URL: postgres://agentsview_test:agentsview_test_password@localhost:5433/agentsview_test?sslmode=disable + e2e: runs-on: ubuntu-latest steps: diff --git a/CLAUDE.md b/CLAUDE.md index e640e588..f854687c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,15 +8,20 @@ agentsview is a local web viewer for AI agent sessions (Claude Code, Codex, Copi ``` CLI (agentsview) → Config → DB (SQLite/FTS5) - ↓ + ↓ ↓ File Watcher → Sync Engine → Parser (Claude, Codex, Copilot, Gemini, OpenCode, Amp) - ↓ + ↓ ↓ HTTP Server → REST API + SSE + Embedded SPA + ↓ + PG Push Sync → PostgreSQL (optional) + ↑ + HTTP Server (pg serve) ← PostgreSQL ``` - **Server**: HTTP server with auto-port discovery (default 8080) -- **Storage**: SQLite with WAL mode, FTS5 for full-text search +- **Storage**: SQLite with WAL mode, FTS5 for full-text search; optional PostgreSQL for multi-machine shared access - **Sync**: File watcher + periodic sync (15min) for session directories +- **PG Sync**: On-demand push sync from SQLite to PostgreSQL via `pg push` - **Frontend**: Svelte 5 SPA embedded in the Go binary at build time - **Config**: Env vars (`AGENT_VIEWER_DATA_DIR`, `CLAUDE_PROJECTS_DIR`, `CODEX_SESSIONS_DIR`, `COPILOT_DIR`, `GEMINI_DIR`, `OPENCODE_DIR`, `AMP_DIR`) and CLI flags @@ -24,8 +29,9 @@ CLI (agentsview) → Config → DB (SQLite/FTS5) - `cmd/agentsview/` - Go server entrypoint - `cmd/testfixture/` - Test data generator for E2E tests -- `internal/config/` - Config loading, flag registration, legacy migration +- `internal/config/` - Config loading (TOML, JSON migration), flag registration - `internal/db/` - SQLite operations (sessions, messages, search, analytics) +- `internal/postgres/` - PostgreSQL support: push sync, read-only store, schema, connection helpers - `internal/parser/` - Session file parsers (Claude, Codex, Copilot, Gemini, OpenCode, Amp, content 
extraction) - `internal/server/` - HTTP handlers, SSE, middleware, search, export - `internal/sync/` - Sync engine, file watcher, discovery, hashing @@ -39,6 +45,7 @@ CLI (agentsview) → Config → DB (SQLite/FTS5) | Path | Purpose | |------|---------| | `cmd/agentsview/main.go` | CLI entry point, server startup, file watcher | +| `cmd/agentsview/pg.go` | pg command group (push, status, serve) | | `internal/server/server.go` | HTTP router and handler setup | | `internal/server/sessions.go` | Session list/detail API handlers | | `internal/server/search.go` | Full-text search API | @@ -51,6 +58,15 @@ CLI (agentsview) → Config → DB (SQLite/FTS5) | `internal/parser/codex.go` | Codex session parser | | `internal/parser/copilot.go` | Copilot CLI session parser | | `internal/parser/amp.go` | Amp session parser | +| `internal/postgres/connect.go` | Connection setup, SSL checks, DSN helpers | +| `internal/postgres/schema.go` | PG DDL, schema management | +| `internal/postgres/push.go` | Push logic, fingerprinting | +| `internal/postgres/sync.go` | Push sync lifecycle | +| `internal/postgres/store.go` | PostgreSQL read-only store | +| `internal/postgres/sessions.go` | PG session queries (read side) | +| `internal/postgres/messages.go` | PG message queries, ILIKE search | +| `internal/postgres/analytics.go` | PG analytics queries | +| `internal/postgres/time.go` | Timestamp conversion helpers | | `internal/config/config.go` | Config loading, flag registration | ## Development @@ -78,6 +94,29 @@ make lint # golangci-lint make vet # go vet ``` +### PostgreSQL Integration Tests + +PG integration tests require a real PostgreSQL instance and the `pgtest` +build tag. 
The easiest way to run them is with docker-compose: + +```bash +make test-postgres # Starts PG container, runs tests, leaves container running +make postgres-down # Stop the test container when done +``` + +Or manually with an existing PostgreSQL instance: + +```bash +TEST_PG_URL="postgres://user:pass@host:5432/dbname?sslmode=disable" \ + CGO_ENABLED=1 go test -tags "fts5,pgtest" ./internal/postgres/... -v +``` + +Tests create and drop the `agentsview` schema, so use a dedicated +database or one where schema changes are acceptable. + +The CI pipeline runs these tests automatically via a GitHub Actions +service container (see `.github/workflows/ci.yml`, `integration` job). + ### Test Guidelines - Table-driven tests for Go code diff --git a/Makefile b/Makefile index 7073e91a..dab00c76 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ LDFLAGS := -X main.version=$(VERSION) \ LDFLAGS_RELEASE := $(LDFLAGS) -s -w DESKTOP_DIST_DIR := dist/desktop -.PHONY: build build-release install frontend frontend-dev dev desktop-dev desktop-build desktop-macos-app desktop-macos-dmg desktop-windows-installer desktop-linux-appimage desktop-app test test-short e2e vet lint tidy clean release release-darwin-arm64 release-darwin-amd64 release-linux-amd64 install-hooks ensure-embed-dir help +.PHONY: build build-release install frontend frontend-dev dev desktop-dev desktop-build desktop-macos-app desktop-macos-dmg desktop-windows-installer desktop-linux-appimage desktop-app test test-short test-postgres test-postgres-ci postgres-up postgres-down e2e vet lint tidy clean release release-darwin-arm64 release-darwin-amd64 release-linux-amd64 install-hooks ensure-embed-dir help # Ensure go:embed has at least one file (no-op if frontend is built) ensure-embed-dir: @@ -141,6 +141,25 @@ test: ensure-embed-dir test-short: ensure-embed-dir go test -tags fts5 ./... 
-short -count=1 +# Start test PostgreSQL container +postgres-up: + docker compose -f docker-compose.test.yml up -d --wait + +# Stop test PostgreSQL container +postgres-down: + docker compose -f docker-compose.test.yml down + +# Run PostgreSQL integration tests (starts postgres automatically) +test-postgres: ensure-embed-dir postgres-up + @echo "Waiting for postgres to be ready..." + @sleep 2 + TEST_PG_URL="postgres://agentsview_test:agentsview_test_password@localhost:5433/agentsview_test?sslmode=disable" \ + CGO_ENABLED=1 go test -tags "fts5,pgtest" -v ./internal/postgres/... -count=1 + +# PostgreSQL integration tests for CI (postgres already running as service) +test-postgres-ci: ensure-embed-dir + CGO_ENABLED=1 go test -tags "fts5,pgtest" -v ./internal/postgres/... -count=1 + # Run Playwright E2E tests e2e: cd frontend && npx playwright test @@ -224,6 +243,9 @@ help: @echo "" @echo " test - Run all tests" @echo " test-short - Run fast tests only" + @echo " test-postgres - Run PostgreSQL integration tests" + @echo " postgres-up - Start test PostgreSQL container" + @echo " postgres-down - Stop test PostgreSQL container" @echo " e2e - Run Playwright E2E tests" @echo " vet - Run go vet" @echo " lint - Run golangci-lint" diff --git a/README.md b/README.md index 8726e37a..f4b868c7 100644 --- a/README.md +++ b/README.md @@ -123,23 +123,18 @@ agentsview -host 127.0.0.1 -port 8080 \ -allowed-subnet 192.168.1.0/24 ``` -You can persist the same settings in `~/.agentsview/config.json`: - -```json -{ - "public_url": "https://viewer.example.test", - "proxy": { - "mode": "caddy", - "bind_host": "0.0.0.0", - "public_port": 8443, - "tls_cert": "/home/user/.certs/viewer.crt", - "tls_key": "/home/user/.certs/viewer.key", - "allowed_subnets": [ - "10.0/16", - "192.168.1.0/24" - ] - } -} +You can persist the same settings in `~/.agentsview/config.toml`: + +```toml +public_url = "https://viewer.example.test" + +[proxy] +mode = "caddy" +bind_host = "0.0.0.0" +public_port = 8443 
+tls_cert = "/home/user/.certs/viewer.crt" +tls_key = "/home/user/.certs/viewer.key" +allowed_subnets = ["10.0/16", "192.168.1.0/24"] ``` `public_origins` remains available as an advanced override when you @@ -169,6 +164,66 @@ need to allow additional browser origins beyond the main `public_url`. | `r` | Sync sessions | | `?` | Show all shortcuts | +## PostgreSQL Sync + +agentsview can push session data from the local SQLite database to a +remote PostgreSQL instance, enabling shared team dashboards and +centralized search across multiple machines. + +### Push Sync (SQLite to PG) + +Configure `pg` in `~/.agentsview/config.toml`: + +```toml +[pg] +url = "postgres://user:pass@host:5432/dbname?sslmode=require" +machine_name = "my-laptop" +``` + +Use `sslmode=require` (or `verify-full` for CA-verified connections) +for non-local PostgreSQL instances. Only use `sslmode=disable` for +trusted local/loopback connections. + +The `machine_name` identifies which machine pushed each session +(must not be `"local"`, which is reserved). + +CLI commands: + +```bash +agentsview pg push # push now +agentsview pg push --full # force full re-push (bypasses heuristic) +agentsview pg status # show sync status +``` + +Push is on-demand — run `pg push` whenever you want to sync to +PostgreSQL. There is no automatic background push. + +### PG Read-Only Mode + +Serve the web UI directly from PostgreSQL with no local SQLite. +Configure `[pg].url` in config (as shown above), then: + +```bash +agentsview pg serve # default: 127.0.0.1:8080 +agentsview pg serve -port 9090 # custom port +``` + +This mode is useful for shared team viewers where multiple machines +push to a central PG database and one or more read-only instances +serve the UI. Uploads, file watching, and local sync are disabled. +By default, `pg serve` binds to `127.0.0.1`. When a non-loopback +`-host` is specified, remote access is enabled automatically and an +auth token is generated and printed to stdout. 
+ +### Known Limitations + +- **Deleted sessions**: Sessions permanently pruned from SQLite + (via `agentsview prune`) are not propagated as deletions to PG. + Sessions soft-deleted with `deleted_at` are synced correctly. +- **Change detection**: Push uses aggregate length statistics + rather than content hashes. Use `-full` to force a complete + re-push if content was rewritten in-place. + ## Documentation Full documentation is available at @@ -218,6 +273,7 @@ PATH/API keys overrides). cmd/agentsview/ CLI entrypoint internal/config/ Configuration loading internal/db/ SQLite operations (sessions, search, analytics) +internal/postgres/ PostgreSQL support (push sync, read-only store, schema) internal/parser/ Session parsers (all supported agents) internal/server/ HTTP handlers, SSE, middleware internal/sync/ Sync engine, file watcher, discovery diff --git a/cmd/agentsview/main.go b/cmd/agentsview/main.go index d13a1ecb..af0126f8 100644 --- a/cmd/agentsview/main.go +++ b/cmd/agentsview/main.go @@ -49,6 +49,9 @@ func main() { case "sync": runSync(os.Args[2:]) return + case "pg": + runPG(os.Args[2:]) + return case "token-use": runTokenUse(os.Args[2:]) return @@ -76,6 +79,9 @@ Usage: agentsview [flags] Start the server (default command) agentsview serve [flags] Start the server (explicit) agentsview sync [flags] Sync session data without serving + agentsview pg push [flags] Push local data to PostgreSQL + agentsview pg status Show PG sync status + agentsview pg serve [flags] Serve from PostgreSQL (read-only) agentsview token-use Show token usage for a session (JSON) agentsview prune [flags] Delete sessions matching filters agentsview update [flags] Check for and install updates @@ -94,10 +100,18 @@ Server flags: -tls-cert string TLS certificate path for managed Caddy HTTPS mode -tls-key string TLS key path for managed Caddy HTTPS mode -allowed-subnet str Client CIDR allowed to connect to the managed proxy + -no-browser Don't open browser on startup Sync flags: -full 
Force a full resync regardless of data version +PG push flags: + -full Bypass per-message skip heuristic + +PG serve flags: + -host string Host to bind to (default "127.0.0.1") + -port int Port to listen on (default 8080) + Prune flags: -project string Sessions whose project contains this substring -max-messages int Sessions with at most N messages (default -1) @@ -121,21 +135,20 @@ Environment variables: IFLOW_DIR iFlow projects directory AMP_DIR Amp threads directory AGENT_VIEWER_DATA_DIR Data directory (database, config) + AGENTSVIEW_PG_URL PostgreSQL connection URL for sync + AGENTSVIEW_PG_MACHINE Machine name for PG sync + AGENTSVIEW_PG_SCHEMA PG schema name (default "agentsview") Watcher excludes: - Add "watch_exclude_patterns" to ~/.agentsview/config.json to skip + Add "watch_exclude_patterns" to ~/.agentsview/config.toml to skip directory names/patterns while recursively watching roots. Example: - { - "watch_exclude_patterns": [".git", "node_modules", ".next", "dist"] - } + watch_exclude_patterns = [".git", "node_modules", ".next", "dist"] Multiple directories: - Add arrays to ~/.agentsview/config.json to scan multiple locations: - { - "claude_project_dirs": ["/path/one", "/path/two"], - "codex_sessions_dirs": ["/codex/a", "/codex/b"] - } + Add arrays to ~/.agentsview/config.toml to scan multiple locations: + claude_project_dirs = ["/path/one", "/path/two"] + codex_sessions_dirs = ["/codex/a", "/codex/b"] When set, these override the default directory. Environment variables override config file arrays. 
@@ -170,6 +183,7 @@ func runServe(args []string) { start := time.Now() cfg := mustLoadConfig(args) setupLogFile(cfg.DataDir) + if err := validateServeConfig(cfg); err != nil { fatal("invalid serve config: %v", err) } @@ -650,3 +664,4 @@ func startUnwatchedPoll(engine *sync.Engine) { engine.SyncAll(context.Background(), nil) } } + diff --git a/cmd/agentsview/pg.go b/cmd/agentsview/pg.go new file mode 100644 index 00000000..32f2873f --- /dev/null +++ b/cmd/agentsview/pg.go @@ -0,0 +1,290 @@ +package main + +import ( + "context" + "encoding/base64" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/wesm/agentsview/internal/config" + "github.com/wesm/agentsview/internal/db" + "github.com/wesm/agentsview/internal/postgres" + "github.com/wesm/agentsview/internal/server" +) + +func runPG(args []string) { + if len(args) == 0 { + fmt.Fprintln(os.Stderr, + "usage: agentsview pg ") + os.Exit(1) + } + switch args[0] { + case "push": + runPGPush(args[1:]) + case "status": + runPGStatus(args[1:]) + case "serve": + runPGServe(args[1:]) + default: + fmt.Fprintf(os.Stderr, + "unknown pg command: %s\n", args[0]) + os.Exit(1) + } +} + +func runPGPush(args []string) { + fs := flag.NewFlagSet("pg push", flag.ExitOnError) + full := fs.Bool("full", false, + "Force full local resync and PG push") + if err := fs.Parse(args); err != nil { + log.Fatalf("parsing flags: %v", err) + } + + appCfg, err := config.LoadMinimal() + if err != nil { + log.Fatalf("loading config: %v", err) + } + if err := os.MkdirAll(appCfg.DataDir, 0o755); err != nil { + log.Fatalf("creating data dir: %v", err) + } + setupLogFile(appCfg.DataDir) + + database, err := db.Open(appCfg.DBPath) + if err != nil { + fatal("opening database: %v", err) + } + defer database.Close() + + if appCfg.CursorSecret != "" { + secret, decErr := base64.StdEncoding.DecodeString( + appCfg.CursorSecret, + ) + if decErr != nil { + fatal("invalid cursor secret: %v", decErr) + } + 
database.SetCursorSecret(secret) + } + + // Run local sync first so newly discovered sessions + // are available for push. If a full resync was performed + // (e.g. due to data version change), force a full PG push + // since watermarks become stale after a local rebuild. + didResync := runLocalSync(appCfg, database, *full) + forceFull := *full || didResync + + pgCfg, err := appCfg.ResolvePG() + if err != nil { + fatal("pg push: %v", err) + } + if pgCfg.URL == "" { + fatal("pg push: url not configured") + } + + ps, err := postgres.New( + pgCfg.URL, pgCfg.Schema, database, + pgCfg.MachineName, pgCfg.AllowInsecure, + ) + if err != nil { + fatal("pg push: %v", err) + } + defer ps.Close() + + ctx, stop := signal.NotifyContext( + context.Background(), os.Interrupt, + ) + defer stop() + + if err := ps.EnsureSchema(ctx); err != nil { + fatal("pg push schema: %v", err) + } + result, err := ps.Push(ctx, forceFull) + if err != nil { + fatal("pg push: %v", err) + } + fmt.Printf( + "Pushed %d sessions, %d messages in %s\n", + result.SessionsPushed, + result.MessagesPushed, + result.Duration.Round(time.Millisecond), + ) + if result.Errors > 0 { + fatal("pg push: %d session(s) failed", + result.Errors) + } +} + +func runPGStatus(args []string) { + fs := flag.NewFlagSet("pg status", flag.ExitOnError) + if err := fs.Parse(args); err != nil { + log.Fatalf("parsing flags: %v", err) + } + + appCfg, err := config.LoadMinimal() + if err != nil { + log.Fatalf("loading config: %v", err) + } + if err := os.MkdirAll(appCfg.DataDir, 0o755); err != nil { + log.Fatalf("creating data dir: %v", err) + } + setupLogFile(appCfg.DataDir) + + database, err := db.Open(appCfg.DBPath) + if err != nil { + fatal("opening database: %v", err) + } + defer database.Close() + + pgCfg, err := appCfg.ResolvePG() + if err != nil { + fatal("pg status: %v", err) + } + if pgCfg.URL == "" { + fatal("pg status: url not configured") + } + + ps, err := postgres.New( + pgCfg.URL, pgCfg.Schema, database, + 
pgCfg.MachineName, pgCfg.AllowInsecure, + ) + if err != nil { + fatal("pg status: %v", err) + } + defer ps.Close() + + ctx, stop := signal.NotifyContext( + context.Background(), os.Interrupt, + ) + defer stop() + + status, err := ps.Status(ctx) + if err != nil { + fatal("pg status: %v", err) + } + fmt.Printf("Machine: %s\n", status.Machine) + fmt.Printf("Last push: %s\n", + valueOrNever(status.LastPushAt)) + fmt.Printf("PG sessions: %d\n", status.PGSessions) + fmt.Printf("PG messages: %d\n", status.PGMessages) +} + +func runPGServe(args []string) { + fs := flag.NewFlagSet("pg serve", flag.ExitOnError) + host := fs.String("host", "127.0.0.1", + "Host to bind to") + port := fs.Int("port", 8080, + "Port to listen on") + basePath := fs.String("base-path", "", + "URL prefix for reverse-proxy subpath (e.g. /agentsview)") + if err := fs.Parse(args); err != nil { + log.Fatalf("parsing flags: %v", err) + } + + appCfg, err := config.LoadMinimal() + if err != nil { + log.Fatalf("loading config: %v", err) + } + setupLogFile(appCfg.DataDir) + + pgCfg, err := appCfg.ResolvePG() + if err != nil { + fatal("pg serve: %v", err) + } + if pgCfg.URL == "" { + fatal("pg serve: url not configured") + } + + store, err := postgres.NewStore( + pgCfg.URL, pgCfg.Schema, pgCfg.AllowInsecure, + ) + if err != nil { + fatal("pg serve: %v", err) + } + defer store.Close() + + ctx, stop := signal.NotifyContext( + context.Background(), + os.Interrupt, syscall.SIGTERM, + ) + defer stop() + + if err := postgres.CheckSchemaCompat( + ctx, store.DB(), + ); err != nil { + fatal("pg serve: schema incompatible: %v", err) + } + + appCfg.Host = *host + // Enable remote access with auth when binding to a + // non-loopback address; keep it off for localhost. 
+ if !isLoopbackHost(*host) { + appCfg.RemoteAccess = true + if err := appCfg.EnsureAuthToken(); err != nil { + fatal("pg serve: generating auth token: %v", err) + } + fmt.Printf("Auth token: %s\n", appCfg.AuthToken) + } else { + appCfg.RemoteAccess = false + } + appCfg.Port = server.FindAvailablePort(*host, *port) + if appCfg.Port != *port { + fmt.Printf("Port %d in use, using %d\n", + *port, appCfg.Port) + } + + opts := []server.Option{ + server.WithVersion(server.VersionInfo{ + Version: version, + Commit: commit, + BuildDate: buildDate, + ReadOnly: true, + }), + server.WithBaseContext(ctx), + } + if *basePath != "" { + opts = append(opts, server.WithBasePath(*basePath)) + } + srv := server.New(appCfg, store, nil, opts...) + + serveErrCh := make(chan error, 1) + go func() { + serveErrCh <- srv.ListenAndServe() + }() + if err := waitForLocalPort( + ctx, appCfg.Host, appCfg.Port, + 5*time.Second, serveErrCh, + ); err != nil { + shutdownCtx, cancel := context.WithTimeout( + context.Background(), 5*time.Second, + ) + defer cancel() + _ = srv.Shutdown(shutdownCtx) + fatal("pg serve: server failed to start: %v", err) + } + + fmt.Printf( + "agentsview %s (pg read-only) at http://%s:%d\n", + version, appCfg.Host, appCfg.Port, + ) + + select { + case err := <-serveErrCh: + if err != nil && err != http.ErrServerClosed { + fatal("pg serve: server error: %v", err) + } + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout( + context.Background(), 5*time.Second, + ) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil && + err != http.ErrServerClosed { + fatal("pg serve: shutdown error: %v", err) + } + } +} diff --git a/cmd/agentsview/sync.go b/cmd/agentsview/sync.go index 6bf8a197..043ec3fa 100644 --- a/cmd/agentsview/sync.go +++ b/cmd/agentsview/sync.go @@ -41,7 +41,9 @@ func parseSyncFlags(args []string) (SyncConfig, error) { ) } - return SyncConfig{Full: *full}, nil + return SyncConfig{ + Full: *full, + }, nil } func runSync(args []string) { @@ 
-65,8 +67,7 @@ func runSync(args []string) { setupLogFile(appCfg.DataDir) - var database *db.DB - database, err = db.Open(appCfg.DBPath) + database, err := db.Open(appCfg.DBPath) if err != nil { fatal("opening database: %v", err) } @@ -80,6 +81,16 @@ func runSync(args []string) { database.SetCursorSecret(secret) } + runLocalSync(appCfg, database, cfg.Full) +} + +// runLocalSync runs a local sync (incremental or full resync). +// It returns true if a full resync was performed, which callers +// can use to force a full PG push (watermarks become stale after +// a local resync). +func runLocalSync( + appCfg config.Config, database *db.DB, full bool, +) bool { for _, def := range parser.Registry { if !appCfg.IsUserConfigured(def.Type) { continue @@ -97,8 +108,9 @@ func runSync(args []string) { Machine: "local", }) + didResync := full || database.NeedsResync() ctx := context.Background() - if cfg.Full || database.NeedsResync() { + if didResync { runInitialResync(ctx, engine) } else { runInitialSync(ctx, engine) @@ -112,4 +124,12 @@ func runSync(args []string) { stats.SessionCount, stats.MessageCount, ) } + return didResync +} + +func valueOrNever(s string) string { + if s == "" { + return "never" + } + return s } diff --git a/docker-compose.test.yml b/docker-compose.test.yml new file mode 100644 index 00000000..8ab3f0d1 --- /dev/null +++ b/docker-compose.test.yml @@ -0,0 +1,22 @@ +# Docker Compose file for integration testing with PostgreSQL +# Usage: +# make postgres-up +# make test-postgres +# make postgres-down + +services: + postgres: + image: postgres:16-alpine + environment: + POSTGRES_USER: agentsview_test + POSTGRES_PASSWORD: agentsview_test_password + POSTGRES_DB: agentsview_test + ports: + - "5433:5432" # Non-standard port to avoid conflict with local postgres + healthcheck: + test: ["CMD-SHELL", "pg_isready -U agentsview_test -d agentsview_test"] + interval: 2s + timeout: 5s + retries: 10 + tmpfs: + - /var/lib/postgresql/data # Use tmpfs for faster tests, no 
persistence needed diff --git a/frontend/index.html b/frontend/index.html index 5083228a..21c0ad48 100644 --- a/frontend/index.html +++ b/frontend/index.html @@ -3,7 +3,7 @@ - + @@ -11,6 +11,6 @@
- + diff --git a/frontend/src/lib/api/client.ts b/frontend/src/lib/api/client.ts index 46ee2956..3bdd630e 100644 --- a/frontend/src/lib/api/client.ts +++ b/frontend/src/lib/api/client.ts @@ -40,7 +40,18 @@ const AUTH_TOKEN_KEY = "agentsview-auth-token"; function getBase(): string { const server = getServerUrl(); - return server ? `${server}/api/v1` : "/api/v1"; + if (server) return `${server}/api/v1`; + // Use the tag injected by --base-path so the app + // works behind a reverse-proxy subpath (e.g. /agentsview/api/v1). + // Only derive from baseURI when a real tag exists; + // otherwise fall back to "/api/v1" so SPA fallback pages on + // non-root URLs don't produce wrong API paths. + const baseEl = document.querySelector("base[href]"); + if (baseEl) { + const base = new URL(document.baseURI).pathname.replace(/\/$/, ""); + return `${base}/api/v1`; + } + return "/api/v1"; } export function getServerUrl(): string { diff --git a/frontend/src/lib/components/sidebar/SessionList.svelte b/frontend/src/lib/components/sidebar/SessionList.svelte index b58a6f35..1038c5e9 100644 --- a/frontend/src/lib/components/sidebar/SessionList.svelte +++ b/frontend/src/lib/components/sidebar/SessionList.svelte @@ -45,11 +45,21 @@ ); }); - // Ensure agents are loaded when dropdown opens. + let machineSearch = $state(""); + let sortedMachines = $derived.by(() => { + const machines = [...sessions.machines].sort(); + if (!machineSearch) return machines; + const q = machineSearch.toLowerCase(); + return machines.filter((m) => m.toLowerCase().includes(q)); + }); + + // Ensure agents and machines are loaded when dropdown opens. $effect(() => { if (showFilterDropdown) { sessions.loadAgents(); + sessions.loadMachines(); agentSearch = ""; + machineSearch = ""; } }); @@ -518,6 +528,43 @@ {/each} + {#if sessions.machines.length > 0} +
+ + {#if sessions.machines.length > 5} + + {/if} +
+ {#each sortedMachines as machine (machine)} + {@const selected = + sessions.filters.machine === machine} + + {:else} + + {machineSearch ? "No match" : "No machines"} + + {/each} +
+
+ {/if}
diff --git a/frontend/src/lib/stores/sessions.svelte.ts b/frontend/src/lib/stores/sessions.svelte.ts index 70f6ee3a..d4cfbdf4 100644 --- a/frontend/src/lib/stores/sessions.svelte.ts +++ b/frontend/src/lib/stores/sessions.svelte.ts @@ -20,6 +20,7 @@ export interface SessionGroup { interface Filters { project: string; + machine: string; agent: string; date: string; dateFrom: string; @@ -35,6 +36,7 @@ interface Filters { function defaultFilters(): Filters { return { project: "", + machine: "", agent: "", date: "", dateFrom: "", @@ -52,6 +54,7 @@ class SessionsStore { sessions: Session[] = $state([]); projects: ProjectInfo[] = $state([]); agents: AgentInfo[] = $state([]); + machines: string[] = $state([]); activeSessionId: string | null = $state(null); childSessions: Map = $state(new Map()); nextCursor: string | null = $state(null); @@ -71,6 +74,9 @@ class SessionsStore { private agentsVersion: number = 0; private refreshVersion: number = 0; private childSessionsVersion: number = 0; + private machinesLoaded: boolean = false; + private machinesPromise: Promise | null = null; + private machinesVersion: number = 0; get activeSession(): Session | undefined { return this.sessions.find((s) => s.id === this.activeSessionId); @@ -90,6 +96,7 @@ class SessionsStore { return { project: f.project || undefined, exclude_project: exclude, + machine: f.machine || undefined, agent: f.agent || undefined, date: f.date || undefined, date_from: f.dateFrom || undefined, @@ -143,6 +150,7 @@ class SessionsStore { this.filters = { project, + machine: params["machine"] ?? "", agent: params["agent"] ?? "", date: params["date"] ?? "", dateFrom: params["date_from"] ?? "", @@ -321,6 +329,31 @@ class SessionsStore { return this.agentsPromise; } + async loadMachines() { + if (this.machinesLoaded) return; + if (this.machinesPromise) return this.machinesPromise; + const ver = this.machinesVersion; + this.machinesPromise = (async () => { + try { + const params = this.filters.includeOneShot + ? 
{ include_one_shot: true as const } + : {}; + const res = await api.getMachines(params); + if (ver === this.machinesVersion) { + this.machines = res.machines; + this.machinesLoaded = true; + } + } catch { + // Non-fatal; machines list stays stale. + } finally { + if (ver === this.machinesVersion) { + this.machinesPromise = null; + } + } + })(); + return this.machinesPromise; + } + private setActiveSession(id: string | null) { if (id === this.activeSessionId) return; this.activeSessionId = id; @@ -431,6 +464,12 @@ class SessionsStore { this.load(); } + setMachineFilter(machine: string) { + this.filters.machine = this.filters.machine === machine ? "" : machine; + this.activeSessionId = null; + this.load(); + } + setAgentFilter(agent: string) { if (this.filters.agent === agent) { this.filters.agent = ""; @@ -497,6 +536,7 @@ class SessionsStore { get hasActiveFilters(): boolean { const f = this.filters; return !!( + f.machine || f.agent || f.recentlyActive || f.hideUnknownProject || @@ -555,8 +595,12 @@ class SessionsStore { this.agentsVersion++; this.agentsLoaded = false; this.agentsPromise = null; + this.machinesVersion++; + this.machinesLoaded = false; + this.machinesPromise = null; this.loadProjects(); this.loadAgents(); + this.loadMachines(); sync.loadStats( this.filters.includeOneShot ? 
{ include_one_shot: true } diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index f4b3dc3f..6f46e6a0 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -13,6 +13,7 @@ function gitCommit(): string { } export default defineConfig({ + base: "./", plugins: [svelte()], define: { "import.meta.env.VITE_BUILD_COMMIT": JSON.stringify( diff --git a/go.mod b/go.mod index 2576a1c4..251d8099 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,11 @@ module github.com/wesm/agentsview go 1.25.5 require ( + github.com/BurntSushi/toml v1.6.0 github.com/fsnotify/fsnotify v1.9.0 github.com/google/go-cmp v0.7.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 + github.com/jackc/pgx/v5 v5.7.4 github.com/mattn/go-sqlite3 v1.14.34 github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 @@ -14,9 +16,17 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7daeca4b..24df9d24 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= +github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= @@ -6,10 +10,27 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg= +github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pmezard/go-difflib v1.0.0 
h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= @@ -18,11 +39,19 @@ github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.28.0 
h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config/config.go b/internal/config/config.go index eb4bd35d..e6ca2335 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "bytes" "crypto/rand" "encoding/base64" "encoding/json" @@ -12,11 +13,14 @@ import ( "net/url" "os" "path/filepath" + "regexp" "slices" "strconv" "strings" + "sync" "time" + "github.com/BurntSushi/toml" "github.com/wesm/agentsview/internal/parser" ) @@ -24,64 +28,74 @@ import ( type TerminalConfig struct { // Mode: "auto" (detect terminal), "custom" (use CustomBin), // or "clipboard" (never launch, always copy). - Mode string `json:"mode"` + Mode string `json:"mode" toml:"mode"` // CustomBin is the terminal binary path (used when Mode == "custom"). - CustomBin string `json:"custom_bin,omitempty"` + CustomBin string `json:"custom_bin,omitempty" toml:"custom_bin"` // CustomArgs is a template for terminal args. Use {cmd} as // placeholder for the resume command (e.g. "-- bash -c {cmd}"). 
- CustomArgs string `json:"custom_args,omitempty"` + CustomArgs string `json:"custom_args,omitempty" toml:"custom_args"` } // ProxyConfig controls an optional managed reverse proxy. type ProxyConfig struct { // Mode enables a managed proxy implementation. // Currently supported: "caddy". - Mode string `json:"mode,omitempty"` + Mode string `json:"mode,omitempty" toml:"mode"` // Bin overrides the proxy executable path. - Bin string `json:"bin,omitempty"` + Bin string `json:"bin,omitempty" toml:"bin"` // BindHost is the local interface/IP the proxy binds to. - BindHost string `json:"bind_host,omitempty"` + BindHost string `json:"bind_host,omitempty" toml:"bind_host"` // PublicPort is the external port exposed by the proxy. - PublicPort int `json:"public_port,omitempty"` + PublicPort int `json:"public_port,omitempty" toml:"public_port"` // TLSCert and TLSKey are used by managed HTTPS mode. - TLSCert string `json:"tls_cert,omitempty"` - TLSKey string `json:"tls_key,omitempty"` + TLSCert string `json:"tls_cert,omitempty" toml:"tls_cert"` + TLSKey string `json:"tls_key,omitempty" toml:"tls_key"` // AllowedSubnets restrict inbound clients to these CIDRs. - AllowedSubnets []string `json:"allowed_subnets,omitempty"` + AllowedSubnets []string `json:"allowed_subnets,omitempty" toml:"allowed_subnets"` +} + +// PGConfig holds PostgreSQL connection settings. +type PGConfig struct { + URL string `toml:"url" json:"url"` + Schema string `toml:"schema" json:"schema"` + MachineName string `toml:"machine_name" json:"machine_name"` + AllowInsecure bool `toml:"allow_insecure" json:"allow_insecure"` } // Config holds all application configuration. 
type Config struct { - Host string `json:"host"` - Port int `json:"port"` - DataDir string `json:"data_dir"` - DBPath string `json:"-"` - PublicURL string `json:"public_url,omitempty"` - PublicOrigins []string `json:"public_origins,omitempty"` - Proxy ProxyConfig `json:"proxy,omitempty"` - WatchExcludePatterns []string `json:"watch_exclude_patterns,omitempty"` - CursorSecret string `json:"cursor_secret"` - GithubToken string `json:"github_token,omitempty"` - Terminal TerminalConfig `json:"terminal,omitempty"` - AuthToken string `json:"auth_token,omitempty"` - RemoteAccess bool `json:"remote_access"` - WriteTimeout time.Duration `json:"-"` + Host string `json:"host" toml:"host"` + Port int `json:"port" toml:"port"` + DataDir string `json:"data_dir" toml:"data_dir"` + DBPath string `json:"-" toml:"-"` + PublicURL string `json:"public_url,omitempty" toml:"public_url"` + PublicOrigins []string `json:"public_origins,omitempty" toml:"public_origins"` + Proxy ProxyConfig `json:"proxy,omitempty" toml:"proxy"` + WatchExcludePatterns []string `json:"watch_exclude_patterns,omitempty" toml:"watch_exclude_patterns"` + CursorSecret string `json:"cursor_secret" toml:"cursor_secret"` + GithubToken string `json:"github_token,omitempty" toml:"github_token"` + Terminal TerminalConfig `json:"terminal,omitempty" toml:"terminal"` + AuthToken string `json:"auth_token,omitempty" toml:"auth_token"` + RemoteAccess bool `json:"remote_access" toml:"remote_access"` + NoBrowser bool `json:"no_browser" toml:"no_browser"` + PG PGConfig `json:"pg,omitempty" toml:"pg"` + WriteTimeout time.Duration `json:"-" toml:"-"` // AgentDirs maps each AgentType to its configured // directories. Single-dir agents store a one-element // slice; unconfigured agents use nil. - AgentDirs map[parser.AgentType][]string `json:"-"` + AgentDirs map[parser.AgentType][]string `json:"-" toml:"-"` // agentDirSource tracks how each agent's dirs were // set so loadFile doesn't override env-set values. 
agentDirSource map[parser.AgentType]dirSource - ResultContentBlockedCategories []string `json:"result_content_blocked_categories,omitempty"` + ResultContentBlockedCategories []string `json:"result_content_blocked_categories,omitempty" toml:"result_content_blocked_categories"` // HostExplicit is true when the user passed --host on the CLI. // Used to prevent auto-bind to 0.0.0.0 when the user // explicitly requested a specific host. - HostExplicit bool `json:"-"` + HostExplicit bool `json:"-" toml:"-"` } type dirSource int @@ -181,31 +195,73 @@ func LoadMinimal() (Config, error) { } func (c *Config) configPath() string { + return filepath.Join(c.DataDir, "config.toml") +} + +func (c *Config) jsonConfigPath() string { return filepath.Join(c.DataDir, "config.json") } -func (c *Config) loadFile() error { - data, err := os.ReadFile(c.configPath()) +// migrateJSONToTOML converts config.json to config.toml if +// config.json exists and config.toml does not. The original +// JSON file is renamed to config.json.bak. 
+func (c *Config) migrateJSONToTOML() error { + jsonPath := c.jsonConfigPath() + tomlPath := c.configPath() + + if _, err := os.Stat(tomlPath); err == nil { + return nil // TOML already exists + } + data, err := os.ReadFile(jsonPath) if os.IsNotExist(err) { - return nil + return nil // no JSON to migrate } if err != nil { + return fmt.Errorf("reading config.json for migration: %w", err) + } + + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return fmt.Errorf("parsing config.json for migration: %w", err) + } + + var buf bytes.Buffer + if err := toml.NewEncoder(&buf).Encode(m); err != nil { + return fmt.Errorf("encoding config.toml: %w", err) + } + if err := os.WriteFile(tomlPath, buf.Bytes(), 0o600); err != nil { + return fmt.Errorf("writing config.toml: %w", err) + } + if err := os.Rename(jsonPath, jsonPath+".bak"); err != nil { + return fmt.Errorf("renaming config.json to .bak: %w", err) + } + return nil +} + +func (c *Config) loadFile() error { + if err := c.migrateJSONToTOML(); err != nil { return err } + path := c.configPath() + if _, err := os.Stat(path); os.IsNotExist(err) { + return nil + } + var file struct { - GithubToken string `json:"github_token"` - CursorSecret string `json:"cursor_secret"` - PublicURL string `json:"public_url"` - PublicOrigins []string `json:"public_origins"` - Proxy ProxyConfig `json:"proxy"` - WatchExcludePatterns []string `json:"watch_exclude_patterns"` - ResultContentBlockedCategories []string `json:"result_content_blocked_categories"` - Terminal TerminalConfig `json:"terminal"` - AuthToken string `json:"auth_token"` - RemoteAccess bool `json:"remote_access"` - } - if err := json.Unmarshal(data, &file); err != nil { + GithubToken string `toml:"github_token"` + CursorSecret string `toml:"cursor_secret"` + PublicURL string `toml:"public_url"` + PublicOrigins []string `toml:"public_origins"` + Proxy ProxyConfig `toml:"proxy"` + WatchExcludePatterns []string `toml:"watch_exclude_patterns"` + 
ResultContentBlockedCategories []string `toml:"result_content_blocked_categories"` + Terminal TerminalConfig `toml:"terminal"` + AuthToken string `toml:"auth_token"` + RemoteAccess bool `toml:"remote_access"` + PG PGConfig `toml:"pg"` + } + if _, err := toml.DecodeFile(path, &file); err != nil { return fmt.Errorf("parsing config: %w", err) } if file.GithubToken != "" { @@ -239,11 +295,25 @@ func (c *Config) loadFile() error { c.AuthToken = file.AuthToken } c.RemoteAccess = file.RemoteAccess + // Merge pg field-by-field so env vars override only + // the fields they set, preserving config-file settings. + if file.PG.URL != "" && c.PG.URL == "" { + c.PG.URL = file.PG.URL + } + if file.PG.Schema != "" && c.PG.Schema == "" { + c.PG.Schema = file.PG.Schema + } + if file.PG.MachineName != "" && c.PG.MachineName == "" { + c.PG.MachineName = file.PG.MachineName + } + if file.PG.AllowInsecure { + c.PG.AllowInsecure = true + } // Parse config-file dir arrays for agents that have a // ConfigKey. Only apply when not already set by env var. 
- var raw map[string]json.RawMessage - if err := json.Unmarshal(data, &raw); err != nil { + var raw map[string]any + if _, err := toml.DecodeFile(path, &raw); err != nil { return fmt.Errorf("parsing config raw: %w", err) } for _, def := range parser.Registry { @@ -257,14 +327,27 @@ func (c *Config) loadFile() error { if c.agentDirSource[def.Type] == dirEnv { continue } - var dirs []string - if err := json.Unmarshal(rawVal, &dirs); err != nil { + rawSlice, ok := rawVal.([]any) + if !ok { log.Printf( - "config: %s: expected string array: %v", - def.ConfigKey, err, + "config: %s: expected string array: got %T", + def.ConfigKey, rawVal, ) continue } + dirs := make([]string, 0, len(rawSlice)) + for _, v := range rawSlice { + s, ok := v.(string) + if !ok { + log.Printf( + "config: %s: expected string array: element is %T", + def.ConfigKey, v, + ) + dirs = nil + break + } + dirs = append(dirs, s) + } if len(dirs) > 0 { c.AgentDirs[def.Type] = dirs c.agentDirSource[def.Type] = dirFile @@ -289,24 +372,39 @@ func (c *Config) ensureCursorSecret() error { return fmt.Errorf("creating data dir: %w", err) } + existing, err := c.readConfigMap() + if err != nil { + return err + } + + existing["cursor_secret"] = secret + return c.writeConfigMap(existing) +} + +// readConfigMap reads the TOML config file into a map. Returns +// an empty map if the file does not exist. 
+func (c *Config) readConfigMap() (map[string]any, error) { existing := make(map[string]any) data, err := os.ReadFile(c.configPath()) if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("reading config: %w", err) + return nil, fmt.Errorf("reading config: %w", err) } if err == nil { - if err := json.Unmarshal(data, &existing); err != nil { - return fmt.Errorf("existing config invalid: %w", err) + if _, err := toml.Decode(string(data), &existing); err != nil { + return nil, fmt.Errorf("existing config invalid: %w", err) } } + return existing, nil +} - existing["cursor_secret"] = secret - out, err := json.MarshalIndent(existing, "", " ") - if err != nil { +// writeConfigMap encodes a map as TOML and writes it to the +// config file. +func (c *Config) writeConfigMap(m map[string]any) error { + var buf bytes.Buffer + if err := toml.NewEncoder(&buf).Encode(m); err != nil { return fmt.Errorf("marshaling config: %w", err) } - - if err := os.WriteFile(c.configPath(), out, 0o600); err != nil { + if err := os.WriteFile(c.configPath(), buf.Bytes(), 0o600); err != nil { return fmt.Errorf("writing config: %w", err) } return nil @@ -322,6 +420,15 @@ func (c *Config) loadEnv() { if v := os.Getenv("AGENT_VIEWER_DATA_DIR"); v != "" { c.DataDir = v } + if v := os.Getenv("AGENTSVIEW_PG_URL"); v != "" { + c.PG.URL = v + } + if v := os.Getenv("AGENTSVIEW_PG_SCHEMA"); v != "" { + c.PG.Schema = v + } + if v := os.Getenv("AGENTSVIEW_PG_MACHINE"); v != "" { + c.PG.MachineName = v + } } type stringListFlag []string @@ -421,6 +528,8 @@ func applyFlags(cfg *Config, fs *flag.FlagSet) { cfg.Proxy.TLSKey = f.Value.String() case "allowed-subnet": cfg.Proxy.AllowedSubnets = splitFlagList(f.Value.String()) + case "no-browser": + cfg.NoBrowser = f.Value.String() == "true" } }) } @@ -712,78 +821,138 @@ func ResolveDataDir() (string, error) { return cfg.DataDir, nil } +// ResolvePG returns a copy of PG config with defaults applied +// and environment variables expanded in URL. 
+func (c *Config) ResolvePG() (PGConfig, error) { + pg := c.PG + if pg.URL != "" { + expanded, err := expandBracedEnv(pg.URL) + if err != nil { + return pg, fmt.Errorf("expanding url: %w", err) + } + pg.URL = expanded + } + if pg.Schema == "" { + pg.Schema = "agentsview" + } + if pg.MachineName == "" { + h, err := os.Hostname() + if err != nil { + return pg, fmt.Errorf("os.Hostname failed (%w); set machine_name explicitly in config", err) + } + pg.MachineName = h + } + return pg, nil +} + +var ( + bracedEnvPattern = regexp.MustCompile(`\$\{([A-Za-z_][A-Za-z0-9_]*)\}`) + bareEnvPattern = regexp.MustCompile(`^\$([A-Za-z_][A-Za-z0-9_]*)$`) + partialBareEnvPattern = regexp.MustCompile(`\$([A-Za-z_][A-Za-z0-9_]*)`) +) + +// bareEnvWarned tracks which bare $VAR names have already been warned +// about, so each distinct variable triggers a warning at most once. +var bareEnvWarned sync.Map + +// ResetBareEnvWarned clears the warning dedup state. Exported for tests. +func ResetBareEnvWarned() { + bareEnvWarned.Range(func(k, _ any) bool { bareEnvWarned.Delete(k); return true }) +} + +// expandBracedEnv expands ${VAR} references in s. As a convenience, +// if the entire string is a single bare $VAR (e.g. "$PGURL"), it is +// expanded as a whole-string shortcut. Bare $VAR references embedded +// in a larger string (e.g. "postgres://$USER@host") are NOT expanded; +// use ${VAR} syntax instead. +func expandBracedEnv(s string) (string, error) { + if parts := bareEnvPattern.FindStringSubmatch(s); parts != nil { + val, ok := os.LookupEnv(parts[1]) + if !ok { + return "", fmt.Errorf("environment variable %s is not set", parts[1]) + } + return val, nil + } + + // Warn about bare $VAR references that won't be expanded. 
+ if remaining := bracedEnvPattern.ReplaceAllString(s, ""); partialBareEnvPattern.MatchString(remaining) { + for _, m := range partialBareEnvPattern.FindAllStringSubmatch(remaining, -1) { + if _, set := os.LookupEnv(m[1]); set { + if _, warned := bareEnvWarned.LoadOrStore(m[1], true); !warned { + log.Printf("warning: pg.url contains bare $%s which will NOT be expanded; use ${%s} syntax instead", m[1], m[1]) + } + } + } + } + + var missingVars []string + result := bracedEnvPattern.ReplaceAllStringFunc(s, func(match string) string { + name := bracedEnvPattern.FindStringSubmatch(match)[1] + val, ok := os.LookupEnv(name) + if !ok { + missingVars = append(missingVars, name) + return "" + } + return val + }) + if len(missingVars) > 0 { + return "", fmt.Errorf("environment variable(s) not set: %s", + strings.Join(missingVars, ", ")) + } + return result, nil +} + // SaveTerminalConfig persists terminal settings to the config file. func (c *Config) SaveTerminalConfig(tc TerminalConfig) error { if err := os.MkdirAll(c.DataDir, 0o700); err != nil { return fmt.Errorf("creating data dir: %w", err) } - existing := make(map[string]any) - data, err := os.ReadFile(c.configPath()) - if err != nil && !os.IsNotExist(err) { + existing, err := c.readConfigMap() + if err != nil { return fmt.Errorf("reading config file: %w", err) } - if err == nil { - if err := json.Unmarshal(data, &existing); err != nil { - return fmt.Errorf( - "existing config is invalid, cannot update: %w", - err, - ) - } - } existing["terminal"] = tc - out, err := json.MarshalIndent(existing, "", " ") - if err != nil { - return fmt.Errorf("marshaling config: %w", err) - } - - if err := os.WriteFile(c.configPath(), out, 0o600); err != nil { - return fmt.Errorf("writing config: %w", err) + if err := c.writeConfigMap(existing); err != nil { + return err } c.Terminal = tc return nil } // SaveSettings persists a partial settings update to the config file. -// The patch map contains JSON keys mapped to their new values. 
Only +// The patch map contains config keys mapped to their new values. Only // the keys present in patch are written; other config keys are preserved. func (c *Config) SaveSettings(patch map[string]any) error { if err := os.MkdirAll(c.DataDir, 0o700); err != nil { return fmt.Errorf("creating data dir: %w", err) } - existing := make(map[string]any) - data, err := os.ReadFile(c.configPath()) - if err != nil && !os.IsNotExist(err) { + existing, err := c.readConfigMap() + if err != nil { return fmt.Errorf("reading config file: %w", err) } - if err == nil { - if err := json.Unmarshal(data, &existing); err != nil { - return fmt.Errorf( - "existing config is invalid, cannot update: %w", - err, - ) - } - } maps.Copy(existing, patch) - out, err := json.MarshalIndent(existing, "", " ") - if err != nil { - return fmt.Errorf("marshaling config: %w", err) - } - - if err := os.WriteFile(c.configPath(), out, 0o600); err != nil { - return fmt.Errorf("writing config: %w", err) + if err := c.writeConfigMap(existing); err != nil { + return err } // Update in-memory config for known keys. 
if v, ok := patch["terminal"]; ok { - if b, err := json.Marshal(v); err == nil { - var tc TerminalConfig - if err := json.Unmarshal(b, &tc); err == nil { - c.Terminal = tc + if tc, ok := v.(TerminalConfig); ok { + c.Terminal = tc + } else if m, ok := v.(map[string]any); ok { + if s, ok := m["mode"].(string); ok { + c.Terminal.Mode = s + } + if s, ok := m["custom_bin"].(string); ok { + c.Terminal.CustomBin = s + } + if s, ok := m["custom_args"].(string); ok { + c.Terminal.CustomArgs = s } } } @@ -823,27 +992,13 @@ func (c *Config) EnsureAuthToken() error { return fmt.Errorf("creating data dir: %w", err) } - existing := make(map[string]any) - data, err := os.ReadFile(c.configPath()) - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("reading config: %w", err) - } - if err == nil { - if err := json.Unmarshal(data, &existing); err != nil { - return fmt.Errorf("existing config invalid: %w", err) - } - } - - existing["auth_token"] = token - out, err := json.MarshalIndent(existing, "", " ") + existing, err := c.readConfigMap() if err != nil { - return fmt.Errorf("marshaling config: %w", err) + return err } - if err := os.WriteFile(c.configPath(), out, 0o600); err != nil { - return fmt.Errorf("writing config: %w", err) - } - return nil + existing["auth_token"] = token + return c.writeConfigMap(existing) } // SaveGithubToken persists the GitHub token to the config file. 
@@ -852,27 +1007,13 @@ func (c *Config) SaveGithubToken(token string) error { return fmt.Errorf("creating data dir: %w", err) } - existing := make(map[string]any) - data, err := os.ReadFile(c.configPath()) - if err != nil && !os.IsNotExist(err) { + existing, err := c.readConfigMap() + if err != nil { return fmt.Errorf("reading config file: %w", err) } - if err == nil { - if err := json.Unmarshal(data, &existing); err != nil { - return fmt.Errorf( - "existing config is invalid, cannot update: %w", - err, - ) - } - } existing["github_token"] = token - out, err := json.MarshalIndent(existing, "", " ") - if err != nil { - return fmt.Errorf("marshaling config: %w", err) - } - - if err := os.WriteFile(c.configPath(), out, 0o600); err != nil { + if err := c.writeConfigMap(existing); err != nil { return fmt.Errorf("writing config: %w", err) } c.GithubToken = token diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3a599066..2abaad35 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -2,7 +2,6 @@ package config import ( "bytes" - "encoding/json" "flag" "log" "os" @@ -11,10 +10,11 @@ import ( "strings" "testing" + "github.com/BurntSushi/toml" "github.com/wesm/agentsview/internal/parser" ) -const configFileName = "config.json" +const configFileName = "config.toml" func skipIfNotUnix(t *testing.T) { t.Helper() @@ -32,11 +32,11 @@ func skipIfNotUnix(t *testing.T) { func writeConfig(t *testing.T, dir string, data any) { t.Helper() - b, err := json.Marshal(data) - if err != nil { + var buf bytes.Buffer + if err := toml.NewEncoder(&buf).Encode(data); err != nil { t.Fatalf("marshal config: %v", err) } - if err := os.WriteFile(filepath.Join(dir, configFileName), b, 0o600); err != nil { + if err := os.WriteFile(filepath.Join(dir, configFileName), buf.Bytes(), 0o600); err != nil { t.Fatalf("write config: %v", err) } } @@ -369,10 +369,10 @@ func TestSaveGithubToken_RejectsCorruptConfig(t *testing.T) { tmp := 
setupTestEnv(t) cfg := Config{DataDir: tmp} - // Write invalid JSON to config file + // Write invalid TOML to config file path := filepath.Join(tmp, configFileName) if err := os.WriteFile( - path, []byte("not json"), 0o600, + path, []byte("[invalid toml = ="), 0o600, ); err != nil { t.Fatal(err) } @@ -392,7 +392,7 @@ func TestSaveGithubToken_ReturnsErrorOnReadFailure(t *testing.T) { // Create a config file that is not readable path := filepath.Join(tmp, configFileName) if err := os.WriteFile( - path, []byte(`{"k":"v"}`), 0o000, + path, []byte("k = \"v\"\n"), 0o000, ); err != nil { t.Fatal(err) } @@ -422,7 +422,7 @@ func TestSaveGithubToken_PreservesExistingKeys(t *testing.T) { t.Fatal(err) } var result map[string]any - if err := json.Unmarshal(got, &result); err != nil { + if _, err := toml.Decode(string(got), &result); err != nil { t.Fatal(err) } if result["custom_key"] != "value" { @@ -728,3 +728,201 @@ func TestLoadFile_ResultContentBlockedCategories(t *testing.T) { }) } } + +func TestLoadFile_PGConfig(t *testing.T) { + tests := []struct { + name string + config map[string]any + envURL string + want PGConfig + }{ + { + "NoConfig", + map[string]any{}, + "", + PGConfig{}, + }, + { + "FromConfigFile", + map[string]any{ + "pg": map[string]any{ + "url": "postgres://localhost/test", + "machine_name": "laptop", + }, + }, + "", + PGConfig{ + URL: "postgres://localhost/test", + MachineName: "laptop", + }, + }, + { + "EnvOverridesConfig", + map[string]any{ + "pg": map[string]any{ + "url": "postgres://from-config", + }, + }, + "postgres://from-env", + PGConfig{ + URL: "postgres://from-env", + }, + }, + { + "EnvURLMergesFileFields", + map[string]any{ + "pg": map[string]any{ + "url": "postgres://from-config", + "machine_name": "laptop", + }, + }, + "postgres://from-env", + PGConfig{ + URL: "postgres://from-env", + MachineName: "laptop", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := setupTestEnv(t) + writeConfig(t, dir, tt.config) 
+ if tt.envURL != "" { + t.Setenv("AGENTSVIEW_PG_URL", tt.envURL) + } + + cfg, err := LoadMinimal() + if err != nil { + t.Fatal(err) + } + + if cfg.PG.URL != tt.want.URL { + t.Errorf( + "URL = %q, want %q", + cfg.PG.URL, + tt.want.URL, + ) + } + if cfg.PG.MachineName != tt.want.MachineName { + t.Errorf( + "MachineName = %q, want %q", + cfg.PG.MachineName, + tt.want.MachineName, + ) + } + }) + } +} + +func TestResolvePG_Defaults(t *testing.T) { + cfg := Config{ + PG: PGConfig{ + URL: "postgres://localhost/test", + }, + } + resolved, err := cfg.ResolvePG() + if err != nil { + t.Fatalf("ResolvePG: %v", err) + } + + if resolved.Schema != "agentsview" { + t.Errorf("Schema = %q, want agentsview", resolved.Schema) + } + if resolved.MachineName == "" { + t.Error("MachineName should default to hostname") + } +} + +func TestResolvePG_ExpandsEnvVars(t *testing.T) { + t.Setenv("PGPASS", "env-secret") + t.Setenv("PGURL", "postgres://localhost/test") + + cfg := Config{ + PG: PGConfig{ + URL: "${PGURL}?password=${PGPASS}", + }, + } + + resolved, err := cfg.ResolvePG() + if err != nil { + t.Fatalf("ResolvePG: %v", err) + } + + want := "postgres://localhost/test?password=env-secret" + if resolved.URL != want { + t.Fatalf("URL = %q, want %q", resolved.URL, want) + } +} + +func TestResolvePG_ExpandsBareEnvOnlyForWholeValue(t *testing.T) { + t.Setenv("PGURL", "postgres://localhost/test") + + cfg := Config{ + PG: PGConfig{ + URL: "$PGURL", + }, + } + + resolved, err := cfg.ResolvePG() + if err != nil { + t.Fatalf("ResolvePG: %v", err) + } + + want := "postgres://localhost/test" + if resolved.URL != want { + t.Fatalf("URL = %q, want %q", resolved.URL, want) + } +} + +func TestResolvePG_PreservesLiteralDollarSequencesInURL(t *testing.T) { + t.Setenv("PGPASS", "env-secret") + + cfg := Config{ + PG: PGConfig{ + URL: "postgres://user:pa$word@localhost/db?application_name=$client&password=${PGPASS}", + }, + } + + resolved, err := cfg.ResolvePG() + if err != nil { + t.Fatalf("ResolvePG: %v", 
err) + } + + want := "postgres://user:pa$word@localhost/db?application_name=$client&password=env-secret" + if resolved.URL != want { + t.Fatalf("URL = %q, want %q", resolved.URL, want) + } +} + +func TestResolvePG_ErrorsOnMissingEnvVar(t *testing.T) { + cfg := Config{ + PG: PGConfig{ + URL: "${NONEXISTENT_PG_VAR}", + }, + } + + _, err := cfg.ResolvePG() + if err == nil { + t.Fatal("expected error for unset env var") + } + if !strings.Contains(err.Error(), "NONEXISTENT_PG_VAR") { + t.Errorf("error = %v, want mention of NONEXISTENT_PG_VAR", err) + } +} + +func TestResolvePG_ErrorsOnMissingBareEnvVar(t *testing.T) { + cfg := Config{ + PG: PGConfig{ + URL: "$NONEXISTENT_PG_BARE_VAR", + }, + } + + _, err := cfg.ResolvePG() + if err == nil { + t.Fatal("expected error for unset bare env var") + } + if !strings.Contains(err.Error(), "NONEXISTENT_PG_BARE_VAR") { + t.Errorf("error = %v, want mention of NONEXISTENT_PG_BARE_VAR", err) + } +} diff --git a/internal/config/persistence_test.go b/internal/config/persistence_test.go index 9fa08c07..5b0fed10 100644 --- a/internal/config/persistence_test.go +++ b/internal/config/persistence_test.go @@ -1,20 +1,19 @@ package config import ( - "encoding/json" "os" "path/filepath" "testing" + + "github.com/BurntSushi/toml" ) func readConfigFile(t *testing.T, dir string) Config { t.Helper() - data, err := os.ReadFile(filepath.Join(dir, configFileName)) - if err != nil { - t.Fatalf("reading config file: %v", err) - } var fileCfg Config - if err := json.Unmarshal(data, &fileCfg); err != nil { + if _, err := toml.DecodeFile( + filepath.Join(dir, configFileName), &fileCfg, + ); err != nil { t.Fatalf("parsing config file: %v", err) } return fileCfg @@ -63,7 +62,7 @@ func TestCursorSecret_GeneratedAndPersisted(t *testing.T) { func TestCursorSecret_RegeneratedIfMissing(t *testing.T) { dir := setupTestEnv(t) - initialContent := `{"cursor_secret": ""}` + initialContent := "cursor_secret = \"\"\n" if err := os.WriteFile(filepath.Join(dir, 
configFileName), []byte(initialContent), 0o600); err != nil { t.Fatalf("write config: %v", err) } @@ -86,7 +85,7 @@ func TestCursorSecret_RegeneratedIfMissing(t *testing.T) { func TestCursorSecret_LoadErrorOnInvalidConfig(t *testing.T) { dir := setupTestEnv(t) - if err := os.WriteFile(filepath.Join(dir, configFileName), []byte("{invalid-json"), 0o600); err != nil { + if err := os.WriteFile(filepath.Join(dir, configFileName), []byte("[invalid toml = ="), 0o600); err != nil { t.Fatalf("write config: %v", err) } @@ -99,7 +98,7 @@ func TestCursorSecret_LoadErrorOnInvalidConfig(t *testing.T) { func TestCursorSecret_PreservesOtherFields(t *testing.T) { dir := setupTestEnv(t) - if err := os.WriteFile(filepath.Join(dir, configFileName), []byte(`{"github_token": "my-token"}`), 0o600); err != nil { + if err := os.WriteFile(filepath.Join(dir, configFileName), []byte("github_token = \"my-token\"\n"), 0o600); err != nil { t.Fatalf("write config: %v", err) } diff --git a/internal/db/db.go b/internal/db/db.go index f9e5af06..b37dec4b 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -79,6 +79,9 @@ func (db *DB) Path() string { return db.path } +// ReadOnly returns false for the local SQLite store. +func (db *DB) ReadOnly() bool { return false } + // SetCursorSecret updates the secret key used for cursor signing. func (db *DB) SetCursorSecret(secret []byte) { db.cursorMu.Lock() @@ -300,6 +303,10 @@ func (db *DB) migrateColumns() error { "sessions", "peak_context_tokens", "ALTER TABLE sessions ADD COLUMN peak_context_tokens INTEGER NOT NULL DEFAULT 0", }, + { + "sessions", "local_modified_at", + "ALTER TABLE sessions ADD COLUMN local_modified_at TEXT", + }, } for _, m := range migrations { @@ -633,3 +640,28 @@ func (db *DB) Update(fn func(tx *sql.Tx) error) error { func (db *DB) Reader() *sql.DB { return db.getReader() } + +// GetSyncState reads a value from the pg_sync_state table. 
+func (db *DB) GetSyncState(key string) (string, error) { + var value string + err := db.getReader().QueryRow( + "SELECT value FROM pg_sync_state WHERE key = ?", key, + ).Scan(&value) + if err == sql.ErrNoRows { + return "", nil + } + return value, err +} + +// SetSyncState writes a value to the pg_sync_state table. +func (db *DB) SetSyncState(key, value string) error { + db.mu.Lock() + defer db.mu.Unlock() + _, err := db.getWriter().Exec( + `INSERT INTO pg_sync_state (key, value) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value`, + key, value, + ) + return err +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go index 813af327..1a7549fd 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -4509,3 +4509,164 @@ func TestUpdateSessionIncremental(t *testing.T) { t.Errorf("FileHash cleared: %v", got.FileHash) } } + +func TestSyncState_GetSetRoundtrip(t *testing.T) { + d := testDB(t) + + // Initially empty. + val, err := d.GetSyncState("last_push_at") + requireNoError(t, err, "get initial") + if val != "" { + t.Fatalf("initial value = %q, want empty", val) + } + + // Set and read back. + if err := d.SetSyncState("last_push_at", "2026-03-11T12:00:00.000Z"); err != nil { + t.Fatalf("set: %v", err) + } + val, err = d.GetSyncState("last_push_at") + requireNoError(t, err, "get after set") + if val != "2026-03-11T12:00:00.000Z" { + t.Fatalf("value = %q, want 2026-03-11T12:00:00.000Z", val) + } + + // Update. + if err := d.SetSyncState("last_push_at", "2026-03-11T13:00:00.000Z"); err != nil { + t.Fatalf("update: %v", err) + } + val, err = d.GetSyncState("last_push_at") + requireNoError(t, err, "get after update") + if val != "2026-03-11T13:00:00.000Z" { + t.Fatalf("value = %q, want 2026-03-11T13:00:00.000Z", val) + } +} + +func TestListSessionsModifiedBetween(t *testing.T) { + d := testDB(t) + ctx := context.Background() + + // Insert sessions with different timestamps. 
+ sessions := []Session{ + {ID: "s1", Project: "p", Machine: "local", Agent: "claude", CreatedAt: "2026-03-10T12:00:00.000Z"}, + {ID: "s2", Project: "p", Machine: "local", Agent: "claude", CreatedAt: "2026-03-11T12:00:00.000Z"}, + {ID: "s3", Project: "p", Machine: "local", Agent: "claude", CreatedAt: "2026-03-12T12:00:00.000Z"}, + } + for _, s := range sessions { + if err := d.UpsertSession(s); err != nil { + t.Fatalf("upsert %s: %v", s.ID, err) + } + } + + // Backdate created_at for deterministic test results. + for _, s := range sessions { + _, err := d.getWriter().Exec( + "UPDATE sessions SET created_at = ? WHERE id = ?", + s.CreatedAt, s.ID, + ) + if err != nil { + t.Fatalf("backdate %s: %v", s.ID, err) + } + } + + // Query all. + all, err := d.ListSessionsModifiedBetween(ctx, "", "") + if err != nil { + t.Fatalf("list all: %v", err) + } + if len(all) != 3 { + t.Fatalf("list all = %d, want 3", len(all)) + } + + // Query with since. + since, err := d.ListSessionsModifiedBetween(ctx, "2026-03-11T00:00:00Z", "") + if err != nil { + t.Fatalf("list since: %v", err) + } + if len(since) != 2 { + t.Fatalf("list since = %d, want 2", len(since)) + } + + // Query with until. + until, err := d.ListSessionsModifiedBetween(ctx, "", "2026-03-11T12:00:00.000Z") + if err != nil { + t.Fatalf("list until: %v", err) + } + if len(until) != 2 { + t.Fatalf("list until = %d, want 2", len(until)) + } + + // Query with both. 
+ between, err := d.ListSessionsModifiedBetween(ctx, "2026-03-10T12:00:00.000Z", "2026-03-11T12:00:00.000Z") + if err != nil { + t.Fatalf("list between: %v", err) + } + if len(between) != 1 { + t.Fatalf("list between = %d, want 1 (s2 only)", len(between)) + } + if between[0].ID != "s2" { + t.Errorf("between[0].ID = %q, want s2", between[0].ID) + } +} + +func TestMessageContentFingerprint(t *testing.T) { + d := testDB(t) + sess := Session{ID: "fp-sess", Project: "p", Machine: "local", Agent: "claude"} + if err := d.UpsertSession(sess); err != nil { + t.Fatalf("upsert: %v", err) + } + if err := d.InsertMessages([]Message{ + {SessionID: "fp-sess", Ordinal: 0, Role: "user", Content: "hello", ContentLength: 5}, + {SessionID: "fp-sess", Ordinal: 1, Role: "assistant", Content: "hi there!", ContentLength: 9}, + }); err != nil { + t.Fatalf("insert: %v", err) + } + + sum, max, min, err := d.MessageContentFingerprint("fp-sess") + if err != nil { + t.Fatalf("fingerprint: %v", err) + } + if sum != 14 { + t.Errorf("sum = %d, want 14", sum) + } + if max != 9 { + t.Errorf("max = %d, want 9", max) + } + if min != 5 { + t.Errorf("min = %d, want 5", min) + } +} + +func TestToolCallCountAndFingerprint(t *testing.T) { + d := testDB(t) + sess := Session{ID: "tc-sess", Project: "p", Machine: "local", Agent: "claude"} + if err := d.UpsertSession(sess); err != nil { + t.Fatalf("upsert: %v", err) + } + if err := d.InsertMessages([]Message{ + { + SessionID: "tc-sess", Ordinal: 0, Role: "assistant", Content: "tool", + ToolCalls: []ToolCall{ + {ToolName: "Read", Category: "Read", ResultContentLength: 100}, + {ToolName: "Write", Category: "Write", ResultContentLength: 50}, + }, + }, + }); err != nil { + t.Fatalf("insert: %v", err) + } + + count, err := d.ToolCallCount("tc-sess") + if err != nil { + t.Fatalf("count: %v", err) + } + if count != 2 { + t.Errorf("count = %d, want 2", count) + } + + sum, err := d.ToolCallContentFingerprint("tc-sess") + if err != nil { + t.Fatalf("fingerprint: %v", 
err) + } + if sum != 150 { + t.Errorf("sum = %d, want 150", sum) + } +} diff --git a/internal/db/messages.go b/internal/db/messages.go index cf91760c..aa53dae5 100644 --- a/internal/db/messages.go +++ b/internal/db/messages.go @@ -536,6 +536,39 @@ func (db *DB) MessageCount(sessionID string) (int, error) { return count, err } +// MessageContentFingerprint returns a lightweight fingerprint of all +// messages for a session, computed as the sum, max, and min of +// content_length values. +func (db *DB) MessageContentFingerprint(sessionID string) (sum, max, min int64, err error) { + err = db.getReader().QueryRow( + "SELECT COALESCE(SUM(content_length), 0), COALESCE(MAX(content_length), 0), COALESCE(MIN(content_length), 0) FROM messages WHERE session_id = ?", + sessionID, + ).Scan(&sum, &max, &min) + return sum, max, min, err +} + +// ToolCallCount returns the number of tool_calls rows for a session. +func (db *DB) ToolCallCount(sessionID string) (int, error) { + var n int + err := db.getReader().QueryRow( + "SELECT COUNT(*) FROM tool_calls WHERE session_id = ?", + sessionID, + ).Scan(&n) + return n, err +} + +// ToolCallContentFingerprint returns the sum of result_content_length +// values for a session's tool calls, used as a lightweight content +// change detector. +func (db *DB) ToolCallContentFingerprint(sessionID string) (int64, error) { + var sum int64 + err := db.getReader().QueryRow( + "SELECT COALESCE(SUM(result_content_length), 0) FROM tool_calls WHERE session_id = ?", + sessionID, + ).Scan(&sum) + return sum, err +} + // GetMessageByOrdinal returns a single message by session ID and ordinal. 
func (db *DB) GetMessageByOrdinal( sessionID string, ordinal int, diff --git a/internal/db/schema.sql b/internal/db/schema.sql index c0eb703c..0b434593 100644 --- a/internal/db/schema.sql +++ b/internal/db/schema.sql @@ -14,6 +14,7 @@ CREATE TABLE IF NOT EXISTS sessions ( file_size INTEGER, file_mtime INTEGER, file_hash TEXT, + local_modified_at TEXT, parent_session_id TEXT, relationship_type TEXT NOT NULL DEFAULT '', total_output_tokens INTEGER NOT NULL DEFAULT 0, @@ -181,3 +182,9 @@ CREATE TABLE IF NOT EXISTS skipped_files ( file_path TEXT PRIMARY KEY, file_mtime INTEGER NOT NULL ); + +-- PG sync state: stores watermarks for push sync +CREATE TABLE IF NOT EXISTS pg_sync_state ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); diff --git a/internal/db/sessions.go b/internal/db/sessions.go index c1fe3fa0..d70b14be 100644 --- a/internal/db/sessions.go +++ b/internal/db/sessions.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "strings" + "time" ) // ErrInvalidCursor is returned when a cursor cannot be decoded or verified. @@ -45,7 +46,7 @@ const sessionFullCols = `id, project, machine, agent, parent_session_id, relationship_type, total_output_tokens, peak_context_tokens, deleted_at, file_path, file_size, file_mtime, - file_hash, created_at` + file_hash, local_modified_at, created_at` const ( // DefaultSessionLimit is the default number of sessions returned. 
@@ -95,6 +96,7 @@ type Session struct { FileSize *int64 `json:"file_size,omitempty"` FileMtime *int64 `json:"file_mtime,omitempty"` FileHash *string `json:"file_hash,omitempty"` + LocalModifiedAt *string `json:"local_modified_at,omitempty"` CreatedAt string `json:"created_at"` } @@ -459,7 +461,7 @@ func (db *DB) GetSessionFull( &s.ParentSessionID, &s.RelationshipType, &s.TotalOutputTokens, &s.PeakContextTokens, &s.DeletedAt, &s.FilePath, &s.FileSize, - &s.FileMtime, &s.FileHash, &s.CreatedAt, + &s.FileMtime, &s.FileHash, &s.LocalModifiedAt, &s.CreatedAt, ) if err == sql.ErrNoRows { return nil, nil @@ -942,7 +944,7 @@ func (db *DB) GetMachines( } defer rows.Close() - var machines []string + machines := []string{} for rows.Next() { var m string if err := rows.Scan(&m); err != nil { @@ -1068,7 +1070,9 @@ func (db *DB) SoftDeleteSession(id string) error { db.mu.Lock() defer db.mu.Unlock() _, err := db.getWriter().Exec( - `UPDATE sessions SET deleted_at = strftime('%Y-%m-%dT%H:%M:%fZ','now') + `UPDATE sessions + SET deleted_at = strftime('%Y-%m-%dT%H:%M:%fZ','now'), + local_modified_at = strftime('%Y-%m-%dT%H:%M:%fZ','now') WHERE id = ? AND deleted_at IS NULL`, id, ) return err @@ -1081,7 +1085,10 @@ func (db *DB) RestoreSession(id string) (int64, error) { db.mu.Lock() defer db.mu.Unlock() res, err := db.getWriter().Exec( - "UPDATE sessions SET deleted_at = NULL WHERE id = ? AND deleted_at IS NOT NULL", + `UPDATE sessions + SET deleted_at = NULL, + local_modified_at = strftime('%Y-%m-%dT%H:%M:%fZ','now') + WHERE id = ? AND deleted_at IS NOT NULL`, id, ) if err != nil { @@ -1097,7 +1104,10 @@ func (db *DB) RenameSession(id string, displayName *string) error { db.mu.Lock() defer db.mu.Unlock() _, err := db.getWriter().Exec( - "UPDATE sessions SET display_name = ? WHERE id = ? AND deleted_at IS NULL", + `UPDATE sessions + SET display_name = ?, + local_modified_at = strftime('%Y-%m-%dT%H:%M:%fZ','now') + WHERE id = ? 
AND deleted_at IS NULL`, displayName, id, ) return err @@ -1210,3 +1220,111 @@ func (db *DB) DeleteSessions(ids []string) (int, error) { } return total, nil } + +// ListSessionsModifiedBetween returns all sessions created or +// modified after since and at or before until. +// +// Uses file_mtime (nanoseconds since epoch from the source file) +// as the primary modification signal so that active sessions with +// new messages are detected even when ended_at has not changed. +// Falls back to session timestamps for rows without file_mtime. +// +// Precision note: file_mtime is compared as nanosecond integers, +// while text timestamps are normalized to millisecond precision +// (strftime '%f' -> 3 decimal places). Sub-millisecond differences +// in text timestamp fields are therefore truncated. +func (db *DB) ListSessionsModifiedBetween( + ctx context.Context, since, until string, +) ([]Session, error) { + query := "SELECT " + sessionFullCols + " FROM sessions" + var ( + args []any + where []string + ) + if since != "" { + sinceTime, err := time.Parse(time.RFC3339Nano, since) + if err != nil { + return nil, fmt.Errorf( + "parsing since timestamp %q: %w", since, err, + ) + } + sinceText := sinceTime.UTC().Format("2006-01-02T15:04:05.000Z") + sinceNano := sinceTime.UnixNano() + where = append(where, `(file_mtime > ? + OR `+sqliteSyncTimestampExpr(colLocalModifiedAt)+` > ? + OR `+sqliteSyncTimestampExpr(colBestTimestamp)+` > ? + OR `+sqliteSyncTimestampExpr(colCreatedAt)+` > ?)`) + args = append(args, sinceNano, sinceText, sinceText, sinceText) + } + if until != "" { + untilTime, err := time.Parse(time.RFC3339Nano, until) + if err != nil { + return nil, fmt.Errorf( + "parsing until timestamp %q: %w", until, err, + ) + } + untilText := untilTime.UTC().Format("2006-01-02T15:04:05.000Z") + untilNano := untilTime.UnixNano() + // COALESCE(file_mtime, -1) maps NULL to -1, which is always + // <= untilNano. 
This is intentional: rows without file_mtime + // should pass the upper-bound check and fall through to the + // timestamp comparisons below. The since clause omits COALESCE + // so that NULL file_mtime does not satisfy > sinceNano. + where = append(where, `(COALESCE(file_mtime, -1) <= ? + AND COALESCE(`+sqliteSyncTimestampExpr(colLocalModifiedAt)+`, '') <= ? + AND `+sqliteSyncTimestampExpr(colBestTimestamp)+` <= ? + AND `+sqliteSyncTimestampExpr(colCreatedAt)+` <= ?)`) + args = append(args, untilNano, untilText, untilText, untilText) + } + if len(where) > 0 { + query += " WHERE " + strings.Join(where, " AND ") + } + query += ` ORDER BY created_at` + + rows, err := db.getReader().QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf( + "listing sessions modified since %s: %w", + since, err, + ) + } + defer rows.Close() + + var sessions []Session + for rows.Next() { + var s Session + err := rows.Scan( + &s.ID, &s.Project, &s.Machine, &s.Agent, + &s.FirstMessage, &s.DisplayName, &s.StartedAt, &s.EndedAt, + &s.MessageCount, &s.UserMessageCount, + &s.ParentSessionID, &s.RelationshipType, + &s.TotalOutputTokens, &s.PeakContextTokens, + &s.DeletedAt, &s.FilePath, &s.FileSize, + &s.FileMtime, &s.FileHash, &s.LocalModifiedAt, &s.CreatedAt, + ) + if err != nil { + return nil, fmt.Errorf("scanning session: %w", err) + } + sessions = append(sessions, s) + } + return sessions, rows.Err() +} + +// trustedSQLiteExpr is a string type for SQL expressions known to be safe +// (literals, column references). Using a distinct type prevents accidental +// injection of user input, mirroring the trustedSQL pattern in pgsync/time.go. 
+type trustedSQLiteExpr string + +const ( + colLocalModifiedAt trustedSQLiteExpr = "NULLIF(local_modified_at, '')" + colBestTimestamp trustedSQLiteExpr = `COALESCE( + NULLIF(ended_at, ''), + NULLIF(started_at, ''), + created_at + )` + colCreatedAt trustedSQLiteExpr = "created_at" +) + +func sqliteSyncTimestampExpr(expr trustedSQLiteExpr) string { + return "strftime('%Y-%m-%dT%H:%M:%fZ', " + string(expr) + ")" +} diff --git a/internal/db/store.go b/internal/db/store.go new file mode 100644 index 00000000..98c5c45a --- /dev/null +++ b/internal/db/store.go @@ -0,0 +1,93 @@ +package db + +import "context" + +// ErrReadOnly is returned by write methods on read-only store +// implementations (e.g. the PostgreSQL reader). +var ErrReadOnly = errReadOnly{} + +type errReadOnly struct{} + +func (errReadOnly) Error() string { return "not available in remote mode" } + +// Store is the interface the HTTP server uses for all data access. +// The concrete *DB (SQLite) satisfies it implicitly. The pgdb +// package provides a read-only PostgreSQL implementation. +type Store interface { + // Cursor pagination. + SetCursorSecret(secret []byte) + EncodeCursor(endedAt, id string, total ...int) string + DecodeCursor(s string) (SessionCursor, error) + + // Sessions. + ListSessions(ctx context.Context, f SessionFilter) (SessionPage, error) + GetSession(ctx context.Context, id string) (*Session, error) + GetSessionFull(ctx context.Context, id string) (*Session, error) + GetChildSessions(ctx context.Context, parentID string) ([]Session, error) + + // Messages. + GetMessages(ctx context.Context, sessionID string, from, limit int, asc bool) ([]Message, error) + GetAllMessages(ctx context.Context, sessionID string) ([]Message, error) + GetMinimap(ctx context.Context, sessionID string) ([]MinimapEntry, error) + GetMinimapFrom(ctx context.Context, sessionID string, from int) ([]MinimapEntry, error) + + // Search. 
+ HasFTS() bool + Search(ctx context.Context, f SearchFilter) (SearchPage, error) + SearchSession(ctx context.Context, sessionID, query string) ([]int, error) + + // SSE change detection. + GetSessionVersion(id string) (count int, fileMtime int64, ok bool) + + // Metadata. + GetStats(ctx context.Context, excludeOneShot bool) (Stats, error) + GetProjects(ctx context.Context, excludeOneShot bool) ([]ProjectInfo, error) + GetAgents(ctx context.Context, excludeOneShot bool) ([]AgentInfo, error) + GetMachines(ctx context.Context, excludeOneShot bool) ([]string, error) + + // Analytics. + GetAnalyticsSummary(ctx context.Context, f AnalyticsFilter) (AnalyticsSummary, error) + GetAnalyticsActivity(ctx context.Context, f AnalyticsFilter, granularity string) (ActivityResponse, error) + GetAnalyticsHeatmap(ctx context.Context, f AnalyticsFilter, metric string) (HeatmapResponse, error) + GetAnalyticsProjects(ctx context.Context, f AnalyticsFilter) (ProjectsAnalyticsResponse, error) + GetAnalyticsHourOfWeek(ctx context.Context, f AnalyticsFilter) (HourOfWeekResponse, error) + GetAnalyticsSessionShape(ctx context.Context, f AnalyticsFilter) (SessionShapeResponse, error) + GetAnalyticsTools(ctx context.Context, f AnalyticsFilter) (ToolsAnalyticsResponse, error) + GetAnalyticsVelocity(ctx context.Context, f AnalyticsFilter) (VelocityResponse, error) + GetAnalyticsTopSessions(ctx context.Context, f AnalyticsFilter, metric string) (TopSessionsResponse, error) + + // Stars (local-only; PG returns ErrReadOnly). + StarSession(sessionID string) (bool, error) + UnstarSession(sessionID string) error + ListStarredSessionIDs(ctx context.Context) ([]string, error) + BulkStarSessions(sessionIDs []string) error + + // Pins (local-only; PG returns ErrReadOnly). 
+ PinMessage(sessionID string, messageID int64, note *string) (int64, error) + UnpinMessage(sessionID string, messageID int64) error + ListPinnedMessages(ctx context.Context, sessionID string) ([]PinnedMessage, error) + + // Insights (local-only; PG returns ErrReadOnly). + ListInsights(ctx context.Context, f InsightFilter) ([]Insight, error) + GetInsight(ctx context.Context, id int64) (*Insight, error) + InsertInsight(s Insight) (int64, error) + DeleteInsight(id int64) error + + // Session management (local-only; PG returns ErrReadOnly). + RenameSession(id string, displayName *string) error + SoftDeleteSession(id string) error + RestoreSession(id string) (int64, error) + DeleteSessionIfTrashed(id string) (int64, error) + ListTrashedSessions(ctx context.Context) ([]Session, error) + EmptyTrash() (int, error) + + // Upload (local-only; PG returns ErrReadOnly). + UpsertSession(s Session) error + ReplaceSessionMessages(sessionID string, msgs []Message) error + + // ReadOnly returns true for remote/PG-backed stores. + ReadOnly() bool +} + +// Compile-time check: *DB satisfies Store. +var _ Store = (*DB)(nil) diff --git a/internal/insight/prompt.go b/internal/insight/prompt.go index ee9d0d2d..68eef965 100644 --- a/internal/insight/prompt.go +++ b/internal/insight/prompt.go @@ -23,7 +23,7 @@ type GenerateRequest struct { // a prompt for the AI agent. func BuildPrompt( ctx context.Context, - database *db.DB, + database db.Store, req GenerateRequest, ) (string, error) { filter := db.SessionFilter{ diff --git a/internal/parser/claude.go b/internal/parser/claude.go index c119a6ba..e7dda862 100644 --- a/internal/parser/claude.go +++ b/internal/parser/claude.go @@ -830,7 +830,13 @@ func truncate(s string, maxLen int) string { if len(s) <= maxLen { return s } - return s[:maxLen] + "..." + // Truncate at a valid rune boundary to avoid producing + // invalid UTF-8. + r := []rune(s) + if len(r) <= maxLen { + return s + } + return string(r[:maxLen]) + "..." 
} // isClaudeSystemMessage returns true if the content matches diff --git a/internal/parser/claude_parser_test.go b/internal/parser/claude_parser_test.go index e492e530..c41a9c33 100644 --- a/internal/parser/claude_parser_test.go +++ b/internal/parser/claude_parser_test.go @@ -4,6 +4,7 @@ import ( "os" "path/filepath" "testing" + "unicode/utf8" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -554,3 +555,70 @@ func loadFixture(t *testing.T, name string) string { require.NoError(t, err) return string(data) } + +func TestTruncateRespectsRuneBoundaries(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + want string + }{ + { + name: "ASCII within limit", + input: "hello", + maxLen: 10, + want: "hello", + }, + { + name: "ASCII truncated", + input: "hello world", + maxLen: 5, + want: "hello...", + }, + { + name: "multibyte within limit", + input: "café", + maxLen: 10, + want: "café", + }, + { + name: "multibyte at boundary", + // 4 runes: c, a, f, é — truncate at 3 runes + input: "café", + maxLen: 3, + want: "caf...", + }, + { + name: "CJK characters", + // 3 runes, each 3 bytes + input: "你好世界", + maxLen: 2, + want: "你好...", + }, + { + name: "ellipsis character preserved", + // U+2026 is 3 bytes but 1 rune + input: "abc\u2026def", + maxLen: 4, + want: "abc\u2026...", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := truncate(tc.input, tc.maxLen) + if got != tc.want { + t.Errorf( + "truncate(%q, %d) = %q, want %q", + tc.input, tc.maxLen, got, tc.want, + ) + } + // Verify result is valid UTF-8. 
+ if !utf8.ValidString(got) { + t.Errorf( + "truncate produced invalid UTF-8: %q", + got, + ) + } + }) + } +} diff --git a/internal/parser/openclaw_test.go b/internal/parser/openclaw_test.go index b1167e57..b3b18824 100644 --- a/internal/parser/openclaw_test.go +++ b/internal/parser/openclaw_test.go @@ -46,6 +46,7 @@ func TestParseOpenClawSession_Basic(t *testing.T) { } if sess == nil { t.Fatal("expected session, got nil") + return } if sess.ID != "openclaw:main:abc-123" { diff --git a/internal/parser/types.go b/internal/parser/types.go index a24f8cc6..dc8d6336 100644 --- a/internal/parser/types.go +++ b/internal/parser/types.go @@ -30,7 +30,7 @@ type AgentDef struct { Type AgentType DisplayName string // "Claude Code", "Codex", etc. EnvVar string // env var for dir override - ConfigKey string // JSON key in config.json ("" = none) + ConfigKey string // TOML key in config.toml ("" = none) DefaultDirs []string // paths relative to $HOME IDPrefix string // session ID prefix ("" for Claude) WatchSubdirs []string // subdirs to watch (nil = watch root) diff --git a/internal/postgres/analytics.go b/internal/postgres/analytics.go new file mode 100644 index 00000000..249b7bf6 --- /dev/null +++ b/internal/postgres/analytics.go @@ -0,0 +1,2044 @@ +package postgres + +import ( + "context" + "fmt" + "math" + "sort" + "strings" + "time" + + "github.com/wesm/agentsview/internal/db" +) + +// maxPGVars is the maximum bind variables per IN clause. +const maxPGVars = 500 + +// pgQueryChunked executes a callback for each chunk of IDs, +// splitting at maxPGVars to avoid excessive bind variables. +func pgQueryChunked( + ids []string, + fn func(chunk []string) error, +) error { + for i := 0; i < len(ids); i += maxPGVars { + end := min(i+maxPGVars, len(ids)) + if err := fn(ids[i:end]); err != nil { + return err + } + } + return nil +} + +// pgInPlaceholders returns a "(placeholders)" string for PG +// numbered parameters. 
+func pgInPlaceholders( + ids []string, pb *paramBuilder, +) string { + phs := make([]string, len(ids)) + for i, id := range ids { + phs[i] = pb.add(id) + } + return "(" + strings.Join(phs, ",") + ")" +} + +// analyticsUTCRange returns UTC time bounds padded by +/-14h +// to cover all possible timezone offsets. +func analyticsUTCRange( + f db.AnalyticsFilter, +) (string, string) { + from := f.From + "T00:00:00Z" + to := f.To + "T23:59:59Z" + tFrom, err := time.Parse(time.RFC3339, from) + if err != nil { + return from, to + } + tTo, err := time.Parse(time.RFC3339, to) + if err != nil { + return from, to + } + return tFrom.Add(-14 * time.Hour).Format(time.RFC3339), + tTo.Add(14 * time.Hour).Format(time.RFC3339) +} + +// buildAnalyticsWhere builds a WHERE clause with PG +// placeholders. dateCol is the date expression. +func buildAnalyticsWhere( + f db.AnalyticsFilter, + dateCol string, + pb *paramBuilder, +) string { + preds := []string{ + "message_count > 0", + "relationship_type NOT IN ('subagent', 'fork')", + "deleted_at IS NULL", + } + utcFrom, utcTo := analyticsUTCRange(f) + preds = append(preds, + dateCol+" >= "+pb.add(utcFrom)+"::timestamptz") + preds = append(preds, + dateCol+" <= "+pb.add(utcTo)+"::timestamptz") + if f.Machine != "" { + preds = append(preds, + "machine = "+pb.add(f.Machine)) + } + if f.Project != "" { + preds = append(preds, + "project = "+pb.add(f.Project)) + } + if f.Agent != "" { + agents := strings.Split(f.Agent, ",") + if len(agents) == 1 { + preds = append(preds, + "agent = "+pb.add(agents[0])) + } else { + phs := make([]string, len(agents)) + for i, a := range agents { + phs[i] = pb.add(a) + } + preds = append(preds, + "agent IN ("+ + strings.Join(phs, ",")+ + ")") + } + } + if f.MinUserMessages > 0 { + preds = append(preds, + "user_message_count >= "+ + pb.add(f.MinUserMessages)) + } + if f.ExcludeOneShot { + preds = append(preds, "user_message_count > 1") + } + if f.ActiveSince != "" { + preds = append(preds, + "COALESCE(ended_at, 
started_at, created_at)"+ + " >= "+pb.add(f.ActiveSince)+ + "::timestamptz") + } + return strings.Join(preds, " AND ") +} + +// localTime parses a UTC timestamp string and converts it to +// the given location. +func localTime( + ts string, loc *time.Location, +) (time.Time, bool) { + t, err := time.Parse(time.RFC3339Nano, ts) + if err != nil { + t, err = time.Parse("2006-01-02T15:04:05Z", ts) + if err != nil { + return time.Time{}, false + } + } + return t.In(loc), true +} + +// localDate converts a UTC timestamp string to a local date +// string (YYYY-MM-DD). +func localDate(ts string, loc *time.Location) string { + t, ok := localTime(ts, loc) + if !ok { + if len(ts) >= 10 { + return ts[:10] + } + return "" + } + return t.Format("2006-01-02") +} + +// inDateRange checks if a local date falls within [from, to]. +func inDateRange(date, from, to string) bool { + return date >= from && date <= to +} + +// medianInt returns the median of a sorted int slice. +func medianInt(sorted []int, n int) int { + if n == 0 { + return 0 + } + if n%2 == 0 { + return (sorted[n/2-1] + sorted[n/2]) / 2 + } + return sorted[n/2] +} + +// percentileFloat returns the value at the given percentile +// from a pre-sorted float64 slice. +func percentileFloat( + sorted []float64, pct float64, +) float64 { + n := len(sorted) + if n == 0 { + return 0 + } + idx := int(float64(n) * pct) + if idx >= n { + idx = n - 1 + } + return sorted[idx] +} + +// analyticsLocation loads the timezone from the filter. +func analyticsLocation( + f db.AnalyticsFilter, +) *time.Location { + if f.Timezone == "" { + return time.UTC + } + loc, err := time.LoadLocation(f.Timezone) + if err != nil { + return time.UTC + } + return loc +} + +// matchesTimeFilter checks whether a local time matches the +// active hour and/or day-of-week filter. 
+func matchesTimeFilter( + f db.AnalyticsFilter, t time.Time, +) bool { + if f.DayOfWeek != nil { + dow := (int(t.Weekday()) + 6) % 7 // ISO Mon=0 + if dow != *f.DayOfWeek { + return false + } + } + if f.Hour != nil { + if t.Hour() != *f.Hour { + return false + } + } + return true +} + +// pgDateCol is the date column expression for analytics. +const pgDateCol = "COALESCE(started_at, created_at)" + +// pgDateColS is the date column with "s." table prefix. +const pgDateColS = "COALESCE(s.started_at, s.created_at)" + +// filteredSessionIDs returns session IDs that have at least +// one message matching the hour/dow filter. +func (s *Store) filteredSessionIDs( + ctx context.Context, f db.AnalyticsFilter, +) (map[string]bool, error) { + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateColS, pb) + + query := `SELECT s.id, + TO_CHAR(m.timestamp AT TIME ZONE 'UTC', + 'YYYY-MM-DD"T"HH24:MI:SS"Z"') + FROM sessions s + JOIN messages m ON m.session_id = s.id + WHERE ` + where + ` AND m.timestamp IS NOT NULL` + + rows, err := s.pg.QueryContext( + ctx, query, pb.args..., + ) + if err != nil { + return nil, fmt.Errorf( + "querying filtered session IDs: %w", err, + ) + } + defer rows.Close() + + ids := make(map[string]bool) + for rows.Next() { + var sid, msgTS string + if err := rows.Scan(&sid, &msgTS); err != nil { + return nil, fmt.Errorf( + "scanning filtered session ID: %w", err, + ) + } + if ids[sid] { + continue + } + t, ok := localTime(msgTS, loc) + if !ok { + continue + } + if matchesTimeFilter(f, t) { + ids[sid] = true + } + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf( + "iterating filtered session IDs: %w", err, + ) + } + return ids, nil +} + +// bucketDate truncates a date to the start of its bucket. 
+func bucketDate(date string, granularity string) string { + t, err := time.Parse("2006-01-02", date) + if err != nil { + return date + } + switch granularity { + case "week": + weekday := int(t.Weekday()) + if weekday == 0 { + weekday = 7 + } + t = t.AddDate(0, 0, -(weekday - 1)) + return t.Format("2006-01-02") + case "month": + return t.Format("2006-01") + "-01" + default: + return date + } +} + +// scanDateCol scans a TIMESTAMPTZ column and returns it as +// an ISO-8601 string for client-side date processing. +func scanDateCol(t *time.Time) string { + if t == nil { + return "" + } + return FormatISO8601(*t) +} + +// --- Summary --- + +// GetAnalyticsSummary returns aggregate statistics. +func (s *Store) GetAnalyticsSummary( + ctx context.Context, f db.AnalyticsFilter, +) (db.AnalyticsSummary, error) { + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateCol, pb) + + var timeIDs map[string]bool + if f.HasTimeFilter() { + var err error + timeIDs, err = s.filteredSessionIDs(ctx, f) + if err != nil { + return db.AnalyticsSummary{}, err + } + } + + query := `SELECT id, ` + pgDateCol + + `, message_count, agent, project + FROM sessions WHERE ` + where + + ` ORDER BY message_count ASC` + + rows, err := s.pg.QueryContext( + ctx, query, pb.args..., + ) + if err != nil { + return db.AnalyticsSummary{}, + fmt.Errorf( + "querying analytics summary: %w", err, + ) + } + defer rows.Close() + + type sessionRow struct { + date string + messages int + agent string + project string + } + + var all []sessionRow + for rows.Next() { + var id string + var ts *time.Time + var mc int + var agent, project string + if err := rows.Scan( + &id, &ts, &mc, &agent, &project, + ); err != nil { + return db.AnalyticsSummary{}, + fmt.Errorf( + "scanning summary row: %w", err, + ) + } + date := localDate(scanDateCol(ts), loc) + if !inDateRange(date, f.From, f.To) { + continue + } + if timeIDs != nil && !timeIDs[id] { + continue + } + all = append(all, 
sessionRow{ + date: date, messages: mc, + agent: agent, project: project, + }) + } + if err := rows.Err(); err != nil { + return db.AnalyticsSummary{}, + fmt.Errorf( + "iterating summary rows: %w", err, + ) + } + + var summary db.AnalyticsSummary + summary.Agents = make(map[string]*db.AgentSummary) + + if len(all) == 0 { + return summary, nil + } + + days := make(map[string]bool) + projects := make(map[string]int) + msgCounts := make([]int, 0, len(all)) + + for _, r := range all { + summary.TotalSessions++ + summary.TotalMessages += r.messages + days[r.date] = true + projects[r.project] += r.messages + msgCounts = append(msgCounts, r.messages) + + if summary.Agents[r.agent] == nil { + summary.Agents[r.agent] = &db.AgentSummary{} + } + summary.Agents[r.agent].Sessions++ + summary.Agents[r.agent].Messages += r.messages + } + + summary.ActiveProjects = len(projects) + summary.ActiveDays = len(days) + summary.AvgMessages = math.Round( + float64(summary.TotalMessages)/ + float64(summary.TotalSessions)*10, + ) / 10 + + sort.Ints(msgCounts) + n := len(msgCounts) + if n%2 == 0 { + summary.MedianMessages = + (msgCounts[n/2-1] + msgCounts[n/2]) / 2 + } else { + summary.MedianMessages = msgCounts[n/2] + } + p90Idx := int(float64(n) * 0.9) + if p90Idx >= n { + p90Idx = n - 1 + } + summary.P90Messages = msgCounts[p90Idx] + + maxMsgs := 0 + for name, count := range projects { + if count > maxMsgs || + (count == maxMsgs && name < summary.MostActive) { + maxMsgs = count + summary.MostActive = name + } + } + + if summary.TotalMessages > 0 { + counts := make([]int, 0, len(projects)) + for _, c := range projects { + counts = append(counts, c) + } + sort.Sort(sort.Reverse(sort.IntSlice(counts))) + top := min(3, len(counts)) + topSum := 0 + for _, c := range counts[:top] { + topSum += c + } + summary.Concentration = math.Round( + float64(topSum)/ + float64(summary.TotalMessages)*1000, + ) / 1000 + } + + return summary, nil +} + +// --- Activity --- + +// GetAnalyticsActivity returns 
session/message counts grouped +// by time bucket. +func (s *Store) GetAnalyticsActivity( + ctx context.Context, f db.AnalyticsFilter, + granularity string, +) (db.ActivityResponse, error) { + if granularity == "" { + granularity = "day" + } + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateColS, pb) + + var timeIDs map[string]bool + if f.HasTimeFilter() { + var err error + timeIDs, err = s.filteredSessionIDs(ctx, f) + if err != nil { + return db.ActivityResponse{}, err + } + } + + query := `SELECT ` + pgDateColS + `, s.agent, s.id, + m.role, m.has_thinking, COUNT(*) + FROM sessions s + LEFT JOIN messages m ON m.session_id = s.id + WHERE ` + where + ` + GROUP BY s.id, ` + pgDateColS + + `, s.agent, m.role, m.has_thinking` + + rows, err := s.pg.QueryContext( + ctx, query, pb.args..., + ) + if err != nil { + return db.ActivityResponse{}, + fmt.Errorf( + "querying analytics activity: %w", err, + ) + } + defer rows.Close() + + buckets := make(map[string]*db.ActivityEntry) + sessionSeen := make(map[string]string) + var sessionIDs []string + + for rows.Next() { + var tsVal *time.Time + var agent, sid string + var role *string + var hasThinking *bool + var count int + if err := rows.Scan( + &tsVal, &agent, &sid, &role, + &hasThinking, &count, + ); err != nil { + return db.ActivityResponse{}, + fmt.Errorf( + "scanning activity row: %w", err, + ) + } + + date := localDate(scanDateCol(tsVal), loc) + if !inDateRange(date, f.From, f.To) { + continue + } + if timeIDs != nil && !timeIDs[sid] { + continue + } + bucket := bucketDate(date, granularity) + + entry, ok := buckets[bucket] + if !ok { + entry = &db.ActivityEntry{ + Date: bucket, + ByAgent: make(map[string]int), + } + buckets[bucket] = entry + } + + if _, seen := sessionSeen[sid]; !seen { + sessionSeen[sid] = bucket + sessionIDs = append(sessionIDs, sid) + entry.Sessions++ + } + + if role != nil { + entry.Messages += count + entry.ByAgent[agent] += count + switch *role { + case 
"user": + entry.UserMessages += count + case "assistant": + entry.AssistantMessages += count + } + if hasThinking != nil && *hasThinking { + entry.ThinkingMessages += count + } + } + } + if err := rows.Err(); err != nil { + return db.ActivityResponse{}, + fmt.Errorf( + "iterating activity rows: %w", err, + ) + } + + if len(sessionIDs) > 0 { + err = pgQueryChunked(sessionIDs, + func(chunk []string) error { + return s.mergeActivityToolCalls( + ctx, chunk, sessionSeen, buckets, + ) + }) + if err != nil { + return db.ActivityResponse{}, err + } + } + + series := make([]db.ActivityEntry, 0, len(buckets)) + for _, e := range buckets { + series = append(series, *e) + } + sort.Slice(series, func(i, j int) bool { + return series[i].Date < series[j].Date + }) + + return db.ActivityResponse{ + Granularity: granularity, + Series: series, + }, nil +} + +// mergeActivityToolCalls queries tool_calls for a chunk of +// session IDs and adds counts to the matching activity +// buckets. +func (s *Store) mergeActivityToolCalls( + ctx context.Context, + chunk []string, + sessionBucket map[string]string, + buckets map[string]*db.ActivityEntry, +) error { + pb := ¶mBuilder{} + ph := pgInPlaceholders(chunk, pb) + q := `SELECT session_id, COUNT(*) + FROM tool_calls + WHERE session_id IN ` + ph + ` + GROUP BY session_id` + rows, err := s.pg.QueryContext(ctx, q, pb.args...) + if err != nil { + return fmt.Errorf( + "querying activity tool_calls: %w", err, + ) + } + defer rows.Close() + + for rows.Next() { + var sid string + var count int + if err := rows.Scan(&sid, &count); err != nil { + return fmt.Errorf( + "scanning activity tool_call: %w", err, + ) + } + bucket := sessionBucket[sid] + if entry, ok := buckets[bucket]; ok { + entry.ToolCalls += count + } + } + return rows.Err() +} + +// --- Heatmap --- + +// MaxHeatmapDays is the maximum number of day entries. +const MaxHeatmapDays = 366 + +// clampFrom returns from clamped so [from, to] spans at +// most MaxHeatmapDays. 
+func clampFrom(from, to string) string { + start, err := time.Parse("2006-01-02", from) + if err != nil { + return from + } + end, err := time.Parse("2006-01-02", to) + if err != nil { + return from + } + earliest := end.AddDate(0, 0, -(MaxHeatmapDays - 1)) + if start.Before(earliest) { + return earliest.Format("2006-01-02") + } + return from +} + +// computeQuartileLevels computes thresholds from sorted +// values. +func computeQuartileLevels( + sorted []int, +) db.HeatmapLevels { + if len(sorted) == 0 { + return db.HeatmapLevels{ + L1: 1, L2: 2, L3: 3, L4: 4, + } + } + n := len(sorted) + return db.HeatmapLevels{ + L1: sorted[0], + L2: sorted[n/4], + L3: sorted[n/2], + L4: sorted[n*3/4], + } +} + +// assignLevel determines the heatmap level (0-4) for a value. +func assignLevel(value int, levels db.HeatmapLevels) int { + if value <= 0 { + return 0 + } + if value <= levels.L2 { + return 1 + } + if value <= levels.L3 { + return 2 + } + if value <= levels.L4 { + return 3 + } + return 4 +} + +// buildDateEntries creates a HeatmapEntry for each day in +// [from, to]. +func buildDateEntries( + from, to string, + values map[string]int, + levels db.HeatmapLevels, +) []db.HeatmapEntry { + start, err := time.Parse("2006-01-02", from) + if err != nil { + return nil + } + end, err := time.Parse("2006-01-02", to) + if err != nil { + return nil + } + + entries := []db.HeatmapEntry{} + for d := start; !d.After(end); d = d.AddDate(0, 0, 1) { + date := d.Format("2006-01-02") + v := values[date] + entries = append(entries, db.HeatmapEntry{ + Date: date, + Value: v, + Level: assignLevel(v, levels), + }) + } + return entries +} + +// GetAnalyticsHeatmap returns daily counts with intensity +// levels. 
+func (s *Store) GetAnalyticsHeatmap( + ctx context.Context, f db.AnalyticsFilter, + metric string, +) (db.HeatmapResponse, error) { + if metric == "" { + metric = "messages" + } + + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateCol, pb) + + var timeIDs map[string]bool + if f.HasTimeFilter() { + var err error + timeIDs, err = s.filteredSessionIDs(ctx, f) + if err != nil { + return db.HeatmapResponse{}, err + } + } + + query := `SELECT id, ` + pgDateCol + + `, message_count + FROM sessions WHERE ` + where + + rows, err := s.pg.QueryContext( + ctx, query, pb.args..., + ) + if err != nil { + return db.HeatmapResponse{}, + fmt.Errorf( + "querying analytics heatmap: %w", err, + ) + } + defer rows.Close() + + dayCounts := make(map[string]int) + daySessions := make(map[string]int) + + for rows.Next() { + var id string + var ts *time.Time + var mc int + if err := rows.Scan(&id, &ts, &mc); err != nil { + return db.HeatmapResponse{}, + fmt.Errorf( + "scanning heatmap row: %w", err, + ) + } + date := localDate(scanDateCol(ts), loc) + if !inDateRange(date, f.From, f.To) { + continue + } + if timeIDs != nil && !timeIDs[id] { + continue + } + dayCounts[date] += mc + daySessions[date]++ + } + if err := rows.Err(); err != nil { + return db.HeatmapResponse{}, + fmt.Errorf( + "iterating heatmap rows: %w", err, + ) + } + + source := dayCounts + if metric == "sessions" { + source = daySessions + } + + entriesFrom := clampFrom(f.From, f.To) + + var values []int + for date, v := range source { + if v > 0 && date >= entriesFrom && date <= f.To { + values = append(values, v) + } + } + sort.Ints(values) + + levels := computeQuartileLevels(values) + + entries := buildDateEntries( + entriesFrom, f.To, source, levels, + ) + + return db.HeatmapResponse{ + Metric: metric, + Entries: entries, + Levels: levels, + EntriesFrom: entriesFrom, + }, nil +} + +// --- Projects --- + +// GetAnalyticsProjects returns per-project analytics. 
+func (s *Store) GetAnalyticsProjects( + ctx context.Context, f db.AnalyticsFilter, +) (db.ProjectsAnalyticsResponse, error) { + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateCol, pb) + + var timeIDs map[string]bool + if f.HasTimeFilter() { + var err error + timeIDs, err = s.filteredSessionIDs(ctx, f) + if err != nil { + return db.ProjectsAnalyticsResponse{}, err + } + } + + query := `SELECT id, project, ` + pgDateCol + `, + message_count, agent + FROM sessions WHERE ` + where + + ` ORDER BY project, ` + pgDateCol + + rows, err := s.pg.QueryContext( + ctx, query, pb.args..., + ) + if err != nil { + return db.ProjectsAnalyticsResponse{}, + fmt.Errorf( + "querying analytics projects: %w", err, + ) + } + defer rows.Close() + + type projectData struct { + name string + sessions int + messages int + first string + last string + counts []int + agents map[string]int + days map[string]int + } + + projectMap := make(map[string]*projectData) + var projectOrder []string + + for rows.Next() { + var id, project, agent string + var ts *time.Time + var mc int + if err := rows.Scan( + &id, &project, &ts, &mc, &agent, + ); err != nil { + return db.ProjectsAnalyticsResponse{}, + fmt.Errorf( + "scanning project row: %w", err, + ) + } + date := localDate(scanDateCol(ts), loc) + if !inDateRange(date, f.From, f.To) { + continue + } + if timeIDs != nil && !timeIDs[id] { + continue + } + + pd, ok := projectMap[project] + if !ok { + pd = &projectData{ + name: project, + agents: make(map[string]int), + days: make(map[string]int), + } + projectMap[project] = pd + projectOrder = append( + projectOrder, project, + ) + } + + pd.sessions++ + pd.messages += mc + pd.counts = append(pd.counts, mc) + pd.agents[agent]++ + pd.days[date] += mc + + if pd.first == "" || date < pd.first { + pd.first = date + } + if date > pd.last { + pd.last = date + } + } + if err := rows.Err(); err != nil { + return db.ProjectsAnalyticsResponse{}, + fmt.Errorf( + "iterating 
project rows: %w", err, + ) + } + + projects := make( + []db.ProjectAnalytics, 0, len(projectMap), + ) + for _, name := range projectOrder { + pd := projectMap[name] + sort.Ints(pd.counts) + n := len(pd.counts) + + avg := 0.0 + if n > 0 { + avg = math.Round( + float64(pd.messages)/float64(n)*10, + ) / 10 + } + + trend := 0.0 + if len(pd.days) > 0 { + trend = math.Round( + float64(pd.messages)/ + float64(len(pd.days))*10, + ) / 10 + } + + projects = append(projects, db.ProjectAnalytics{ + Name: pd.name, + Sessions: pd.sessions, + Messages: pd.messages, + FirstSession: pd.first, + LastSession: pd.last, + AvgMessages: avg, + MedianMessages: medianInt(pd.counts, n), + Agents: pd.agents, + DailyTrend: trend, + }) + } + + sort.Slice(projects, func(i, j int) bool { + return projects[i].Messages > projects[j].Messages + }) + + return db.ProjectsAnalyticsResponse{ + Projects: projects, + }, nil +} + +// --- Hour-of-Week --- + +// GetAnalyticsHourOfWeek returns message counts bucketed by +// day-of-week and hour-of-day. 
+func (s *Store) GetAnalyticsHourOfWeek( + ctx context.Context, f db.AnalyticsFilter, +) (db.HourOfWeekResponse, error) { + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateColS, pb) + + query := `SELECT ` + pgDateColS + `, + TO_CHAR(m.timestamp AT TIME ZONE 'UTC', + 'YYYY-MM-DD"T"HH24:MI:SS"Z"') + FROM sessions s + JOIN messages m ON m.session_id = s.id + WHERE ` + where + ` AND m.timestamp IS NOT NULL` + + rows, err := s.pg.QueryContext( + ctx, query, pb.args..., + ) + if err != nil { + return db.HourOfWeekResponse{}, + fmt.Errorf( + "querying hour-of-week: %w", err, + ) + } + defer rows.Close() + + var grid [7][24]int + + for rows.Next() { + var sessTS *time.Time + var msgTS string + if err := rows.Scan(&sessTS, &msgTS); err != nil { + return db.HourOfWeekResponse{}, + fmt.Errorf( + "scanning hour-of-week row: %w", + err, + ) + } + sessDate := localDate(scanDateCol(sessTS), loc) + if !inDateRange(sessDate, f.From, f.To) { + continue + } + t, ok := localTime(msgTS, loc) + if !ok { + continue + } + dow := (int(t.Weekday()) + 6) % 7 + grid[dow][t.Hour()]++ + } + if err := rows.Err(); err != nil { + return db.HourOfWeekResponse{}, + fmt.Errorf( + "iterating hour-of-week rows: %w", err, + ) + } + + cells := make([]db.HourOfWeekCell, 0, 168) + for d := range 7 { + for h := range 24 { + cells = append(cells, db.HourOfWeekCell{ + DayOfWeek: d, + Hour: h, + Messages: grid[d][h], + }) + } + } + + return db.HourOfWeekResponse{Cells: cells}, nil +} + +// --- Session Shape --- + +// lengthBucket returns the bucket label for a message count. +func lengthBucket(mc int) string { + switch { + case mc <= 5: + return "1-5" + case mc <= 15: + return "6-15" + case mc <= 30: + return "16-30" + case mc <= 60: + return "31-60" + case mc <= 120: + return "61-120" + default: + return "121+" + } +} + +// durationBucket returns the bucket label for a duration in +// minutes. 
+func durationBucket(mins float64) string { + switch { + case mins < 5: + return "<5m" + case mins < 15: + return "5-15m" + case mins < 30: + return "15-30m" + case mins < 60: + return "30-60m" + case mins < 120: + return "1-2h" + default: + return "2h+" + } +} + +// autonomyBucket returns the bucket label for an autonomy +// ratio. +func autonomyBucket(ratio float64) string { + switch { + case ratio < 0.5: + return "<0.5" + case ratio < 1: + return "0.5-1" + case ratio < 2: + return "1-2" + case ratio < 5: + return "2-5" + case ratio < 10: + return "5-10" + default: + return "10+" + } +} + +var ( + lengthOrder = map[string]int{ + "1-5": 0, "6-15": 1, "16-30": 2, + "31-60": 3, "61-120": 4, "121+": 5, + } + durationOrder = map[string]int{ + "<5m": 0, "5-15m": 1, "15-30m": 2, + "30-60m": 3, "1-2h": 4, "2h+": 5, + } + autonomyOrder = map[string]int{ + "<0.5": 0, "0.5-1": 1, "1-2": 2, + "2-5": 3, "5-10": 4, "10+": 5, + } +) + +// sortBuckets sorts distribution buckets by defined order. +func sortBuckets( + buckets []db.DistributionBucket, + order map[string]int, +) { + sort.Slice(buckets, func(i, j int) bool { + return order[buckets[i].Label] < + order[buckets[j].Label] + }) +} + +// mapToBuckets converts a label->count map to sorted buckets. +func mapToBuckets( + m map[string]int, order map[string]int, +) []db.DistributionBucket { + buckets := make( + []db.DistributionBucket, 0, len(m), + ) + for label, count := range m { + buckets = append(buckets, db.DistributionBucket{ + Label: label, Count: count, + }) + } + sortBuckets(buckets, order) + return buckets +} + +// GetAnalyticsSessionShape returns distribution histograms +// for session length, duration, and autonomy ratio. 
+func (s *Store) GetAnalyticsSessionShape( + ctx context.Context, f db.AnalyticsFilter, +) (db.SessionShapeResponse, error) { + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateCol, pb) + + var timeIDs map[string]bool + if f.HasTimeFilter() { + var err error + timeIDs, err = s.filteredSessionIDs(ctx, f) + if err != nil { + return db.SessionShapeResponse{}, err + } + } + + query := `SELECT ` + pgDateCol + `, + EXTRACT(EPOCH FROM ended_at - started_at) + AS duration_sec, + message_count, id FROM sessions WHERE ` + where + + rows, err := s.pg.QueryContext( + ctx, query, pb.args..., + ) + if err != nil { + return db.SessionShapeResponse{}, + fmt.Errorf( + "querying session shape: %w", err, + ) + } + defer rows.Close() + + lengthCounts := make(map[string]int) + durationCounts := make(map[string]int) + var sessionIDs []string + totalCount := 0 + + for rows.Next() { + var tsVal *time.Time + var durationSec *float64 + var mc int + var id string + if err := rows.Scan( + &tsVal, &durationSec, &mc, &id, + ); err != nil { + return db.SessionShapeResponse{}, + fmt.Errorf( + "scanning session shape row: %w", + err, + ) + } + date := localDate(scanDateCol(tsVal), loc) + if !inDateRange(date, f.From, f.To) { + continue + } + if timeIDs != nil && !timeIDs[id] { + continue + } + + totalCount++ + lengthCounts[lengthBucket(mc)]++ + sessionIDs = append(sessionIDs, id) + + if durationSec != nil && *durationSec >= 0 { + mins := *durationSec / 60.0 + durationCounts[durationBucket(mins)]++ + } + } + if err := rows.Err(); err != nil { + return db.SessionShapeResponse{}, + fmt.Errorf( + "iterating session shape rows: %w", + err, + ) + } + + autonomyCounts := make(map[string]int) + if len(sessionIDs) > 0 { + err := pgQueryChunked(sessionIDs, + func(chunk []string) error { + return s.queryAutonomyChunk( + ctx, chunk, autonomyCounts, + ) + }) + if err != nil { + return db.SessionShapeResponse{}, err + } + } + + return db.SessionShapeResponse{ + Count: 
totalCount, + LengthDistribution: mapToBuckets( + lengthCounts, lengthOrder, + ), + DurationDistribution: mapToBuckets( + durationCounts, durationOrder, + ), + AutonomyDistribution: mapToBuckets( + autonomyCounts, autonomyOrder, + ), + }, nil +} + +// queryAutonomyChunk queries autonomy stats for a chunk of +// session IDs. +func (s *Store) queryAutonomyChunk( + ctx context.Context, + chunk []string, + counts map[string]int, +) error { + pb := ¶mBuilder{} + ph := pgInPlaceholders(chunk, pb) + q := `SELECT session_id, + SUM(CASE WHEN role='user' THEN 1 ELSE 0 END), + SUM(CASE WHEN role='assistant' + AND has_tool_use=true THEN 1 ELSE 0 END) + FROM messages + WHERE session_id IN ` + ph + ` + GROUP BY session_id` + + rows, err := s.pg.QueryContext(ctx, q, pb.args...) + if err != nil { + return fmt.Errorf("querying autonomy: %w", err) + } + defer rows.Close() + + for rows.Next() { + var sid string + var userCount, toolCount int + if err := rows.Scan( + &sid, &userCount, &toolCount, + ); err != nil { + return fmt.Errorf( + "scanning autonomy row: %w", err, + ) + } + if userCount > 0 { + ratio := float64(toolCount) / + float64(userCount) + counts[autonomyBucket(ratio)]++ + } + } + return rows.Err() +} + +// --- Tools --- + +// GetAnalyticsTools returns tool usage analytics. 
+func (s *Store) GetAnalyticsTools( + ctx context.Context, f db.AnalyticsFilter, +) (db.ToolsAnalyticsResponse, error) { + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateCol, pb) + + var timeIDs map[string]bool + if f.HasTimeFilter() { + var err error + timeIDs, err = s.filteredSessionIDs(ctx, f) + if err != nil { + return db.ToolsAnalyticsResponse{}, err + } + } + + sessQ := `SELECT id, ` + pgDateCol + `, agent + FROM sessions WHERE ` + where + + sessRows, err := s.pg.QueryContext( + ctx, sessQ, pb.args..., + ) + if err != nil { + return db.ToolsAnalyticsResponse{}, + fmt.Errorf( + "querying tool sessions: %w", err, + ) + } + defer sessRows.Close() + + type sessInfo struct { + date string + agent string + } + sessionMap := make(map[string]sessInfo) + var sessionIDs []string + + for sessRows.Next() { + var id, agent string + var ts *time.Time + if err := sessRows.Scan( + &id, &ts, &agent, + ); err != nil { + return db.ToolsAnalyticsResponse{}, + fmt.Errorf( + "scanning tool session: %w", err, + ) + } + date := localDate(scanDateCol(ts), loc) + if !inDateRange(date, f.From, f.To) { + continue + } + if timeIDs != nil && !timeIDs[id] { + continue + } + sessionMap[id] = sessInfo{ + date: date, agent: agent, + } + sessionIDs = append(sessionIDs, id) + } + if err := sessRows.Err(); err != nil { + return db.ToolsAnalyticsResponse{}, + fmt.Errorf( + "iterating tool sessions: %w", err, + ) + } + + resp := db.ToolsAnalyticsResponse{ + ByCategory: []db.ToolCategoryCount{}, + ByAgent: []db.ToolAgentBreakdown{}, + Trend: []db.ToolTrendEntry{}, + } + + if len(sessionIDs) == 0 { + return resp, nil + } + + type toolRow struct { + sessionID string + category string + } + var toolRows []toolRow + + err = pgQueryChunked(sessionIDs, + func(chunk []string) error { + chunkPB := ¶mBuilder{} + ph := pgInPlaceholders(chunk, chunkPB) + q := `SELECT session_id, category + FROM tool_calls + WHERE session_id IN ` + ph + rows, qErr := s.pg.QueryContext( 
+ ctx, q, chunkPB.args..., + ) + if qErr != nil { + return fmt.Errorf( + "querying tool_calls: %w", qErr, + ) + } + defer rows.Close() + for rows.Next() { + var sid, cat string + if err := rows.Scan(&sid, &cat); err != nil { + return fmt.Errorf( + "scanning tool_call: %w", err, + ) + } + toolRows = append(toolRows, toolRow{ + sessionID: sid, category: cat, + }) + } + return rows.Err() + }) + if err != nil { + return db.ToolsAnalyticsResponse{}, err + } + + if len(toolRows) == 0 { + return resp, nil + } + + catCounts := make(map[string]int) + agentCats := make(map[string]map[string]int) + trendBuckets := make(map[string]map[string]int) + + for _, tr := range toolRows { + info := sessionMap[tr.sessionID] + catCounts[tr.category]++ + + if agentCats[info.agent] == nil { + agentCats[info.agent] = make(map[string]int) + } + agentCats[info.agent][tr.category]++ + + week := bucketDate(info.date, "week") + if trendBuckets[week] == nil { + trendBuckets[week] = make(map[string]int) + } + trendBuckets[week][tr.category]++ + } + + resp.TotalCalls = len(toolRows) + + resp.ByCategory = make( + []db.ToolCategoryCount, 0, len(catCounts), + ) + for cat, count := range catCounts { + pct := math.Round( + float64(count)/ + float64(resp.TotalCalls)*1000, + ) / 10 + resp.ByCategory = append(resp.ByCategory, + db.ToolCategoryCount{ + Category: cat, Count: count, Pct: pct, + }) + } + sort.Slice(resp.ByCategory, func(i, j int) bool { + if resp.ByCategory[i].Count != + resp.ByCategory[j].Count { + return resp.ByCategory[i].Count > + resp.ByCategory[j].Count + } + return resp.ByCategory[i].Category < + resp.ByCategory[j].Category + }) + + agentKeys := make([]string, 0, len(agentCats)) + for k := range agentCats { + agentKeys = append(agentKeys, k) + } + sort.Strings(agentKeys) + resp.ByAgent = make( + []db.ToolAgentBreakdown, 0, len(agentKeys), + ) + for _, agent := range agentKeys { + cats := agentCats[agent] + total := 0 + for _, c := range cats { + total += c + } + catList := make( + 
[]db.ToolCategoryCount, 0, len(cats), + ) + for cat, count := range cats { + pct := math.Round( + float64(count)/float64(total)*1000, + ) / 10 + catList = append(catList, db.ToolCategoryCount{ + Category: cat, Count: count, Pct: pct, + }) + } + sort.Slice(catList, func(i, j int) bool { + if catList[i].Count != catList[j].Count { + return catList[i].Count > catList[j].Count + } + return catList[i].Category < + catList[j].Category + }) + resp.ByAgent = append(resp.ByAgent, + db.ToolAgentBreakdown{ + Agent: agent, + Total: total, + Categories: catList, + }) + } + + resp.Trend = make( + []db.ToolTrendEntry, 0, len(trendBuckets), + ) + for week, cats := range trendBuckets { + resp.Trend = append(resp.Trend, db.ToolTrendEntry{ + Date: week, ByCat: cats, + }) + } + sort.Slice(resp.Trend, func(i, j int) bool { + return resp.Trend[i].Date < resp.Trend[j].Date + }) + + return resp, nil +} + +// --- Velocity --- + +// velocityMsg holds per-message data needed for velocity. +type velocityMsg struct { + role string + ts time.Time + valid bool + contentLength int +} + +// queryVelocityMsgs fetches messages for a chunk of session +// IDs and appends them to sessionMsgs. +func (s *Store) queryVelocityMsgs( + ctx context.Context, + chunk []string, + loc *time.Location, + sessionMsgs map[string][]velocityMsg, +) error { + pb := ¶mBuilder{} + ph := pgInPlaceholders(chunk, pb) + q := `SELECT session_id, ordinal, role, + TO_CHAR(timestamp AT TIME ZONE 'UTC', + 'YYYY-MM-DD"T"HH24:MI:SS.US"Z"'), + content_length + FROM messages + WHERE session_id IN ` + ph + ` + ORDER BY session_id, ordinal` + + rows, err := s.pg.QueryContext(ctx, q, pb.args...) 
+ if err != nil { + return fmt.Errorf( + "querying velocity messages: %w", err, + ) + } + defer rows.Close() + + for rows.Next() { + var sid string + var ordinal int + var role string + var ts *string + var cl int + if err := rows.Scan( + &sid, &ordinal, &role, &ts, &cl, + ); err != nil { + return fmt.Errorf( + "scanning velocity msg: %w", err, + ) + } + tsStr := "" + if ts != nil { + tsStr = *ts + } + t, ok := localTime(tsStr, loc) + sessionMsgs[sid] = append(sessionMsgs[sid], + velocityMsg{ + role: role, ts: t, valid: ok, + contentLength: cl, + }) + } + return rows.Err() +} + +// complexityBucket returns the complexity label. +func complexityBucket(mc int) string { + switch { + case mc <= 15: + return "1-15" + case mc <= 60: + return "16-60" + default: + return "61+" + } +} + +// velocityAccumulator collects raw values for a velocity +// group. +type velocityAccumulator struct { + turnCycles []float64 + firstResponses []float64 + totalMsgs int + totalChars int + totalToolCalls int + activeMinutes float64 + sessions int +} + +func (a *velocityAccumulator) computeOverview() db.VelocityOverview { + sort.Float64s(a.turnCycles) + sort.Float64s(a.firstResponses) + + var v db.VelocityOverview + v.TurnCycleSec = db.Percentiles{ + P50: math.Round( + percentileFloat(a.turnCycles, 0.5)*10) / 10, + P90: math.Round( + percentileFloat(a.turnCycles, 0.9)*10) / 10, + } + v.FirstResponseSec = db.Percentiles{ + P50: math.Round( + percentileFloat( + a.firstResponses, 0.5)*10) / 10, + P90: math.Round( + percentileFloat( + a.firstResponses, 0.9)*10) / 10, + } + if a.activeMinutes > 0 { + v.MsgsPerActiveMin = math.Round( + float64(a.totalMsgs)/ + a.activeMinutes*10) / 10 + v.CharsPerActiveMin = math.Round( + float64(a.totalChars)/ + a.activeMinutes*10) / 10 + v.ToolCallsPerActiveMin = math.Round( + float64(a.totalToolCalls)/ + a.activeMinutes*10) / 10 + } + return v +} + +// GetAnalyticsVelocity computes turn cycle, first response, +// and throughput metrics. 
+func (s *Store) GetAnalyticsVelocity( + ctx context.Context, f db.AnalyticsFilter, +) (db.VelocityResponse, error) { + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateCol, pb) + + var timeIDs map[string]bool + if f.HasTimeFilter() { + var err error + timeIDs, err = s.filteredSessionIDs(ctx, f) + if err != nil { + return db.VelocityResponse{}, err + } + } + + sessQuery := `SELECT id, ` + pgDateCol + `, agent, + message_count FROM sessions WHERE ` + where + + sessRows, err := s.pg.QueryContext( + ctx, sessQuery, pb.args..., + ) + if err != nil { + return db.VelocityResponse{}, + fmt.Errorf( + "querying velocity sessions: %w", err, + ) + } + defer sessRows.Close() + + type sessInfo struct { + agent string + mc int + } + sessionMap := make(map[string]sessInfo) + var sessionIDs []string + + for sessRows.Next() { + var id, agent string + var ts *time.Time + var mc int + if err := sessRows.Scan( + &id, &ts, &agent, &mc, + ); err != nil { + return db.VelocityResponse{}, + fmt.Errorf( + "scanning velocity session: %w", + err, + ) + } + date := localDate(scanDateCol(ts), loc) + if !inDateRange(date, f.From, f.To) { + continue + } + if timeIDs != nil && !timeIDs[id] { + continue + } + sessionMap[id] = sessInfo{agent: agent, mc: mc} + sessionIDs = append(sessionIDs, id) + } + if err := sessRows.Err(); err != nil { + return db.VelocityResponse{}, + fmt.Errorf( + "iterating velocity sessions: %w", err, + ) + } + + if len(sessionIDs) == 0 { + return db.VelocityResponse{ + ByAgent: []db.VelocityBreakdown{}, + ByComplexity: []db.VelocityBreakdown{}, + }, nil + } + + sessionMsgs := make(map[string][]velocityMsg) + err = pgQueryChunked(sessionIDs, + func(chunk []string) error { + return s.queryVelocityMsgs( + ctx, chunk, loc, sessionMsgs, + ) + }) + if err != nil { + return db.VelocityResponse{}, err + } + + toolCountMap := make(map[string]int) + err = pgQueryChunked(sessionIDs, + func(chunk []string) error { + chunkPB := ¶mBuilder{} + ph := 
pgInPlaceholders(chunk, chunkPB) + q := `SELECT session_id, COUNT(*) + FROM tool_calls + WHERE session_id IN ` + ph + ` + GROUP BY session_id` + rows, qErr := s.pg.QueryContext( + ctx, q, chunkPB.args..., + ) + if qErr != nil { + return fmt.Errorf( + "querying velocity tool_calls: %w", + qErr, + ) + } + defer rows.Close() + for rows.Next() { + var sid string + var count int + if err := rows.Scan( + &sid, &count, + ); err != nil { + return fmt.Errorf( + "scanning velocity tool_call: %w", + err, + ) + } + toolCountMap[sid] = count + } + return rows.Err() + }) + if err != nil { + return db.VelocityResponse{}, err + } + + overall := &velocityAccumulator{} + byAgent := make(map[string]*velocityAccumulator) + byComplexity := make(map[string]*velocityAccumulator) + + const maxCycleSec = 1800.0 + const maxGapSec = 300.0 + + for _, sid := range sessionIDs { + info := sessionMap[sid] + msgs := sessionMsgs[sid] + if len(msgs) < 2 { + continue + } + + agentKey := info.agent + compKey := complexityBucket(info.mc) + + if byAgent[agentKey] == nil { + byAgent[agentKey] = &velocityAccumulator{} + } + if byComplexity[compKey] == nil { + byComplexity[compKey] = &velocityAccumulator{} + } + + accums := []*velocityAccumulator{ + overall, + byAgent[agentKey], + byComplexity[compKey], + } + + for _, a := range accums { + a.sessions++ + } + + for i := 1; i < len(msgs); i++ { + prev := msgs[i-1] + cur := msgs[i] + if !prev.valid || !cur.valid { + continue + } + if prev.role == "user" && + cur.role == "assistant" { + delta := cur.ts.Sub(prev.ts).Seconds() + if delta > 0 && delta <= maxCycleSec { + for _, a := range accums { + a.turnCycles = append( + a.turnCycles, delta, + ) + } + } + } + } + + var firstUser, firstAsst *velocityMsg + firstUserIdx := -1 + for i := range msgs { + if msgs[i].role == "user" && msgs[i].valid { + firstUser = &msgs[i] + firstUserIdx = i + break + } + } + if firstUserIdx >= 0 { + for i := firstUserIdx + 1; i < len(msgs); i++ { + if msgs[i].role == "assistant" && + 
msgs[i].valid { + firstAsst = &msgs[i] + break + } + } + } + if firstUser != nil && firstAsst != nil { + delta := firstAsst.ts.Sub( + firstUser.ts, + ).Seconds() + if delta < 0 { + delta = 0 + } + for _, a := range accums { + a.firstResponses = append( + a.firstResponses, delta, + ) + } + } + + activeSec := 0.0 + asstChars := 0 + for i, m := range msgs { + if m.role == "assistant" { + asstChars += m.contentLength + } + if i > 0 && msgs[i-1].valid && m.valid { + gap := m.ts.Sub( + msgs[i-1].ts, + ).Seconds() + if gap > 0 { + if gap > maxGapSec { + gap = maxGapSec + } + activeSec += gap + } + } + } + activeMins := activeSec / 60.0 + if activeMins > 0 { + tc := toolCountMap[sid] + for _, a := range accums { + a.totalMsgs += len(msgs) + a.totalChars += asstChars + a.totalToolCalls += tc + a.activeMinutes += activeMins + } + } + } + + resp := db.VelocityResponse{ + Overall: overall.computeOverview(), + } + + agentKeys := make([]string, 0, len(byAgent)) + for k := range byAgent { + agentKeys = append(agentKeys, k) + } + sort.Strings(agentKeys) + resp.ByAgent = make( + []db.VelocityBreakdown, 0, len(agentKeys), + ) + for _, k := range agentKeys { + a := byAgent[k] + resp.ByAgent = append(resp.ByAgent, + db.VelocityBreakdown{ + Label: k, + Sessions: a.sessions, + Overview: a.computeOverview(), + }) + } + + compOrder := map[string]int{ + "1-15": 0, "16-60": 1, "61+": 2, + } + compKeys := make([]string, 0, len(byComplexity)) + for k := range byComplexity { + compKeys = append(compKeys, k) + } + sort.Slice(compKeys, func(i, j int) bool { + return compOrder[compKeys[i]] < + compOrder[compKeys[j]] + }) + resp.ByComplexity = make( + []db.VelocityBreakdown, 0, len(compKeys), + ) + for _, k := range compKeys { + a := byComplexity[k] + resp.ByComplexity = append(resp.ByComplexity, + db.VelocityBreakdown{ + Label: k, + Sessions: a.sessions, + Overview: a.computeOverview(), + }) + } + + return resp, nil +} + +// --- Top Sessions --- + +// GetAnalyticsTopSessions returns the top 10 
sessions by the +// given metric. +func (s *Store) GetAnalyticsTopSessions( + ctx context.Context, f db.AnalyticsFilter, + metric string, +) (db.TopSessionsResponse, error) { + if metric == "" { + metric = "messages" + } + loc := analyticsLocation(f) + pb := ¶mBuilder{} + where := buildAnalyticsWhere(f, pgDateCol, pb) + + var timeIDs map[string]bool + if f.HasTimeFilter() { + var err error + timeIDs, err = s.filteredSessionIDs(ctx, f) + if err != nil { + return db.TopSessionsResponse{}, err + } + } + + needsGoSort := metric == "duration" + if metric != "duration" && metric != "messages" { + metric = "messages" + } + orderExpr := "message_count DESC, id ASC" + if metric == "duration" { + where += " AND started_at IS NOT NULL" + + " AND ended_at IS NOT NULL" + } + + limitClause := " LIMIT 1000" + if f.HasTimeFilter() || needsGoSort { + limitClause = "" + } + query := `SELECT id, ` + pgDateCol + `, project, + first_message, message_count, + EXTRACT(EPOCH FROM ended_at - started_at) + AS duration_sec + FROM sessions WHERE ` + where + + ` ORDER BY ` + orderExpr + limitClause + + rows, err := s.pg.QueryContext( + ctx, query, pb.args..., + ) + if err != nil { + return db.TopSessionsResponse{}, + fmt.Errorf( + "querying top sessions: %w", err, + ) + } + defer rows.Close() + + sessions := []db.TopSession{} + for rows.Next() { + var id, project string + var ts *time.Time + var firstMsg *string + var mc int + var durationSec *float64 + if err := rows.Scan( + &id, &ts, &project, &firstMsg, + &mc, &durationSec, + ); err != nil { + return db.TopSessionsResponse{}, + fmt.Errorf( + "scanning top session: %w", err, + ) + } + date := localDate(scanDateCol(ts), loc) + if !inDateRange(date, f.From, f.To) { + continue + } + if timeIDs != nil && !timeIDs[id] { + continue + } + durMin := 0.0 + if durationSec != nil { + durMin = *durationSec / 60.0 + } else if needsGoSort { + continue + } + sessions = append(sessions, db.TopSession{ + ID: id, + Project: project, + FirstMessage: firstMsg, 
+ MessageCount: mc, + DurationMin: durMin, + }) + } + if err := rows.Err(); err != nil { + return db.TopSessionsResponse{}, + fmt.Errorf( + "iterating top sessions: %w", err, + ) + } + + sessions = rankTopSessions(sessions, needsGoSort) + + return db.TopSessionsResponse{ + Metric: metric, + Sessions: sessions, + }, nil +} + +// rankTopSessions sorts sessions by duration (if +// needsGoSort), truncates to top 10, and rounds DurationMin. +func rankTopSessions( + sessions []db.TopSession, needsGoSort bool, +) []db.TopSession { + if sessions == nil { + return []db.TopSession{} + } + if needsGoSort && len(sessions) > 1 { + sort.SliceStable(sessions, func(i, j int) bool { + if sessions[i].DurationMin != + sessions[j].DurationMin { + return sessions[i].DurationMin > + sessions[j].DurationMin + } + return sessions[i].ID < sessions[j].ID + }) + } + if len(sessions) > 10 { + sessions = sessions[:10] + } + for i := range sessions { + sessions[i].DurationMin = math.Round( + sessions[i].DurationMin*10) / 10 + } + return sessions +} diff --git a/internal/postgres/analytics_pgtest_test.go b/internal/postgres/analytics_pgtest_test.go new file mode 100644 index 00000000..4bf82ea0 --- /dev/null +++ b/internal/postgres/analytics_pgtest_test.go @@ -0,0 +1,114 @@ +package postgres + +import ( + "testing" + + "github.com/wesm/agentsview/internal/db" +) + +func TestRankTopSessions_DurationSort(t *testing.T) { + sessions := []db.TopSession{ + {ID: "a", DurationMin: 10.0}, + {ID: "b", DurationMin: 30.0}, + {ID: "c", DurationMin: 20.0}, + } + got := rankTopSessions(sessions, true) + if got[0].ID != "b" || got[1].ID != "c" || + got[2].ID != "a" { + t.Errorf("expected b,c,a order, got %s,%s,%s", + got[0].ID, got[1].ID, got[2].ID) + } +} + +func TestRankTopSessions_DurationTieBreaker(t *testing.T) { + sessions := []db.TopSession{ + {ID: "z", DurationMin: 5.0}, + {ID: "a", DurationMin: 5.0}, + {ID: "m", DurationMin: 5.0}, + } + got := rankTopSessions(sessions, true) + if got[0].ID != "a" || 
got[1].ID != "m" || + got[2].ID != "z" { + t.Errorf( + "expected a,m,z tie-break order, got %s,%s,%s", + got[0].ID, got[1].ID, got[2].ID) + } +} + +func TestRankTopSessions_NearTiePrecision(t *testing.T) { + sessions := []db.TopSession{ + {ID: "a", DurationMin: 10.04}, + {ID: "b", DurationMin: 10.06}, + } + got := rankTopSessions(sessions, true) + if got[0].ID != "b" { + t.Errorf("expected b first (10.06 > 10.04), got %s", + got[0].ID) + } + if got[0].DurationMin != 10.1 || + got[1].DurationMin != 10.0 { + t.Errorf("expected rounded 10.1, 10.0; got %.1f, %.1f", + got[0].DurationMin, got[1].DurationMin) + } +} + +func TestRankTopSessions_TruncatesTo10(t *testing.T) { + sessions := make([]db.TopSession, 15) + for i := range sessions { + sessions[i] = db.TopSession{ + ID: string(rune('a' + i)), + DurationMin: float64(i), + } + } + got := rankTopSessions(sessions, true) + if len(got) != 10 { + t.Errorf("expected 10 sessions, got %d", len(got)) + } + if got[0].DurationMin != 14.0 { + t.Errorf( + "expected first session duration 14.0, got %.1f", + got[0].DurationMin) + } +} + +func TestRankTopSessions_NoSortForMessages(t *testing.T) { + sessions := []db.TopSession{ + {ID: "c", MessageCount: 10}, + {ID: "a", MessageCount: 30}, + {ID: "b", MessageCount: 20}, + } + got := rankTopSessions(sessions, false) + if got[0].ID != "c" || got[1].ID != "a" || + got[2].ID != "b" { + t.Errorf( + "expected preserved order c,a,b, got %s,%s,%s", + got[0].ID, got[1].ID, got[2].ID) + } +} + +func TestRankTopSessions_NilInput(t *testing.T) { + got := rankTopSessions(nil, true) + if got == nil { + t.Error("expected non-nil empty slice, got nil") + } + if len(got) != 0 { + t.Errorf("expected empty slice, got %d elements", + len(got)) + } +} + +func TestRankTopSessions_RoundsForDisplay(t *testing.T) { + sessions := []db.TopSession{ + {ID: "a", DurationMin: 12.349}, + {ID: "b", DurationMin: 12.351}, + } + got := rankTopSessions(sessions, true) + if got[0].DurationMin != 12.4 { + 
t.Errorf("expected 12.4, got %v", + got[0].DurationMin) + } + if got[1].DurationMin != 12.3 { + t.Errorf("expected 12.3, got %v", + got[1].DurationMin) + } +} diff --git a/internal/postgres/connect.go b/internal/postgres/connect.go new file mode 100644 index 00000000..7bb01016 --- /dev/null +++ b/internal/postgres/connect.go @@ -0,0 +1,231 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "log" + "net" + "net/url" + "regexp" + "strings" + "time" + + "github.com/jackc/pgx/v5/pgconn" + _ "github.com/jackc/pgx/v5/stdlib" +) + +// RedactDSN returns the host portion of the DSN for diagnostics, +// stripping credentials, query parameters, and path components +// that may contain secrets. +func RedactDSN(dsn string) string { + u, err := url.Parse(dsn) + if err != nil { + return "" + } + return u.Hostname() +} + +// CheckSSL returns an error when the PG connection string targets +// a non-loopback host without TLS encryption. It uses the pgx +// driver's own DSN parser to resolve the effective host and TLS +// configuration, avoiding bypasses from exotic DSN formats. +// +// A connection is rejected when any path in the TLS negotiation +// chain (primary + fallbacks) permits plaintext for a non-loopback +// host. This rejects sslmode=disable, allow, and prefer. +func CheckSSL(dsn string) error { + cfg, err := pgconn.ParseConfig(dsn) + if err != nil { + return fmt.Errorf("parsing pg connection string: %w", err) + } + if isLoopback(cfg.Host) { + return nil + } + if hasPlaintextPath(cfg) { + return fmt.Errorf( + "pg connection to %s permits plaintext; "+ + "set sslmode=require (or verify-full) "+ + "for non-local hosts, "+ + "or set allow_insecure = true under [pg] "+ + "in config to override", + cfg.Host, + ) + } + return nil +} + +// WarnInsecureSSL logs a warning when the PG connection string +// targets a non-loopback host without TLS encryption. Uses the +// pgx driver's DSN parser for accurate host/TLS resolution. 
+func WarnInsecureSSL(dsn string) { + cfg, err := pgconn.ParseConfig(dsn) + if err != nil { + return + } + if isLoopback(cfg.Host) { + return + } + if hasPlaintextPath(cfg) { + log.Printf( + "warning: pg connection to %s permits "+ + "plaintext; consider sslmode=require or "+ + "verify-full for non-local hosts", + cfg.Host, + ) + } +} + +// hasPlaintextPath returns true if any path in the pgconn +// connection chain (primary config + fallbacks) has TLS disabled. +// This catches sslmode=disable (no TLS), sslmode=allow (plaintext +// first, TLS fallback), and sslmode=prefer (TLS first, plaintext +// fallback). +func hasPlaintextPath(cfg *pgconn.Config) bool { + if cfg.TLSConfig == nil { + return true + } + for _, fb := range cfg.Fallbacks { + if fb.TLSConfig == nil { + return true + } + } + return false +} + +// isLoopback returns true if host is a loopback address, +// localhost, a unix socket path, or empty (defaults to local +// connection). +func isLoopback(host string) bool { + if host == "" || host == "localhost" { + return true + } + ip := net.ParseIP(host) + if ip != nil && ip.IsLoopback() { + return true + } + // Unix socket paths start with / + if len(host) > 0 && host[0] == '/' { + return true + } + return false +} + +// validIdentifier matches simple SQL identifiers (letters, +// digits, underscores). Used to reject schema names that could +// enable SQL injection. +var validIdentifier = regexp.MustCompile( + `^[a-zA-Z_][a-zA-Z0-9_]*$`, +) + +// quoteIdentifier double-quotes a SQL identifier, escaping any +// embedded double quotes. Rejects empty or non-identifier strings +// to prevent injection. 
+func quoteIdentifier(name string) (string, error) { + if name == "" { + return "", fmt.Errorf( + "schema name must not be empty", + ) + } + if !validIdentifier.MatchString(name) { + return "", fmt.Errorf( + "invalid schema name: %q", name, + ) + } + return `"` + name + `"`, nil +} + +// Open opens a PG connection pool, validates SSL, and sets +// search_path to the given schema on every connection. +// +// The schema name is validated and quoted to prevent injection. +// When allowInsecure is true, non-loopback connections without +// TLS produce a warning instead of failing. +func Open( + dsn, schema string, allowInsecure bool, +) (*sql.DB, error) { + if dsn == "" { + return nil, fmt.Errorf("postgres URL is required") + } + quoted, err := quoteIdentifier(schema) + if err != nil { + return nil, fmt.Errorf("invalid pg schema: %w", err) + } + + if allowInsecure { + WarnInsecureSSL(dsn) + } else if err := CheckSSL(dsn); err != nil { + return nil, err + } + + // Append search_path and timezone as runtime parameters in + // the DSN so every connection in the pool inherits them. + // pgx's stdlib driver passes options through to ConnConfig. + connStr, err := appendConnParams(dsn, map[string]string{ + "search_path": quoted, + "TimeZone": "UTC", + }) + if err != nil { + return nil, fmt.Errorf( + "setting connection params: %w", err, + ) + } + + db, err := sql.Open("pgx", connStr) + if err != nil { + return nil, fmt.Errorf( + "opening pg (host=%s): %w", + RedactDSN(dsn), err, + ) + } + db.SetMaxOpenConns(5) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(30 * time.Minute) + db.SetConnMaxIdleTime(5 * time.Minute) + + ctx, cancel := context.WithTimeout( + context.Background(), 10*time.Second, + ) + defer cancel() + if err := db.PingContext(ctx); err != nil { + db.Close() + return nil, fmt.Errorf( + "pg ping (host=%s): %w", + RedactDSN(dsn), err, + ) + } + return db, nil +} + +// appendConnParams injects key=value connection parameters into +// the DSN. 
For URI-style DSNs it adds query parameters; for +// key=value DSNs it appends key=value pairs. +func appendConnParams( + dsn string, params map[string]string, +) (string, error) { + // URI format: postgres://... + if strings.HasPrefix(dsn, "postgres://") || + strings.HasPrefix(dsn, "postgresql://") { + u, err := url.Parse(dsn) + if err != nil { + return "", fmt.Errorf( + "parsing pg URI: %w", err, + ) + } + q := u.Query() + for k, v := range params { + q.Set(k, v) + } + u.RawQuery = q.Encode() + return u.String(), nil + } + // Key=value format: append parameters. + result := dsn + for k, v := range params { + if result != "" { + result += " " + } + result += k + "=" + v + } + return result, nil +} diff --git a/internal/postgres/connect_test.go b/internal/postgres/connect_test.go new file mode 100644 index 00000000..fec3071a --- /dev/null +++ b/internal/postgres/connect_test.go @@ -0,0 +1,191 @@ +package postgres + +import "testing" + +func TestCheckSSL(t *testing.T) { + tests := []struct { + name string + dsn string + wantErr bool + }{ + { + "loopback localhost", + "postgres://user:pass@localhost:5432/db", + false, + }, + { + "loopback 127.0.0.1", + "postgres://user:pass@127.0.0.1:5432/db", + false, + }, + { + "loopback ::1", + "postgres://user:pass@[::1]:5432/db", + false, + }, + { + "empty host defaults local", + "", + false, + }, + { + "remote with require", + "postgres://u:p@remote:5432/db?sslmode=require", + false, + }, + { + "remote with verify-full", + "postgres://u:p@remote:5432/db?sslmode=verify-full", + false, + }, + { + "remote no sslmode", + "postgres://u:p@remote:5432/db", + true, + }, + { + "remote sslmode=disable", + "postgres://u:p@remote:5432/db?sslmode=disable", + true, + }, + { + "remote sslmode=prefer", + "postgres://u:p@remote:5432/db?sslmode=prefer", + true, + }, + { + "remote sslmode=allow", + "postgres://u:p@remote:5432/db?sslmode=allow", + true, + }, + { + "kv remote require", + "host=remote sslmode=require", + false, + }, + { + "kv remote 
disable", + "host=remote sslmode=disable", + true, + }, + { + "kv unix socket", + "host=/var/run/postgresql sslmode=disable", + false, + }, + { + "uri query host disable", + "postgres:///db?host=remote&sslmode=disable", + true, + }, + { + "uri query host require", + "postgres:///db?host=remote&sslmode=require", + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CheckSSL(tt.dsn) + if (err != nil) != tt.wantErr { + t.Errorf( + "CheckSSL() error = %v, wantErr %v", + err, tt.wantErr, + ) + } + }) + } +} + +func TestRedactDSN(t *testing.T) { + tests := []struct { + name string + dsn string + want string + }{ + { + "strips credentials", + "postgres://user:secret@myhost:5432/db", + "myhost", + }, + { + "empty dsn", + "", + "", + }, + { + "invalid dsn", + "://bad", + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := RedactDSN(tt.dsn) + if got != tt.want { + t.Errorf( + "RedactDSN() = %q, want %q", + got, tt.want, + ) + } + }) + } +} + +func TestIsLoopback(t *testing.T) { + tests := []struct { + host string + want bool + }{ + {"", true}, + {"localhost", true}, + {"127.0.0.1", true}, + {"::1", true}, + {"/var/run/postgresql", true}, + {"remote.host.com", false}, + {"10.0.0.1", false}, + } + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + if got := isLoopback(tt.host); got != tt.want { + t.Errorf( + "isLoopback(%q) = %v, want %v", + tt.host, got, tt.want, + ) + } + }) + } +} + +func TestQuoteIdentifier(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + {"simple", "agentsview", `"agentsview"`, false}, + {"underscore", "my_schema", `"my_schema"`, false}, + {"empty", "", "", true}, + {"has spaces", "bad schema", "", true}, + {"has semicolon", "bad;drop", "", true}, + {"starts with digit", "1bad", "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := quoteIdentifier(tt.input) + if (err != 
nil) != tt.wantErr { + t.Errorf( + "quoteIdentifier() err = %v, wantErr %v", + err, tt.wantErr, + ) + } + if got != tt.want { + t.Errorf( + "quoteIdentifier() = %q, want %q", + got, tt.want, + ) + } + }) + } +} diff --git a/internal/postgres/helpers_pgtest_test.go b/internal/postgres/helpers_pgtest_test.go new file mode 100644 index 00000000..4fe5c46a --- /dev/null +++ b/internal/postgres/helpers_pgtest_test.go @@ -0,0 +1,17 @@ +//go:build pgtest + +package postgres + +import ( + "os" + "testing" +) + +func testPGURL(t *testing.T) string { + t.Helper() + url := os.Getenv("TEST_PG_URL") + if url == "" { + t.Skip("TEST_PG_URL not set; skipping PG tests") + } + return url +} diff --git a/internal/postgres/integration_test.go b/internal/postgres/integration_test.go new file mode 100644 index 00000000..780b1676 --- /dev/null +++ b/internal/postgres/integration_test.go @@ -0,0 +1,115 @@ +//go:build pgtest + +package postgres + +import ( + "context" + "testing" + "time" + + "github.com/wesm/agentsview/internal/db" +) + +func TestPGConnectivity(t *testing.T) { + pgURL := testPGURL(t) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "connectivity-test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx, cancel := context.WithTimeout( + context.Background(), 10*time.Second, + ) + defer cancel() + + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + status, err := ps.Status(ctx) + if err != nil { + t.Fatalf("get status: %v", err) + } + + t.Logf("PG Sync Status: %+v", status) +} + +func TestPGPushCycle(t *testing.T) { + pgURL := testPGURL(t) + + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, "machine-a", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { 
+ t.Fatalf("ensure schema: %v", err) + } + + started := time.Now().UTC().Format(time.RFC3339) + firstMsg := "hello from pg" + sess := db.Session{ + ID: "pg-sess-001", + Project: "pg-project", + Machine: "local", + Agent: "test-agent", + FirstMessage: &firstMsg, + StartedAt: &started, + MessageCount: 1, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + if err := local.InsertMessages([]db.Message{{ + SessionID: "pg-sess-001", + Ordinal: 0, + Role: "user", + Content: firstMsg, + }}); err != nil { + t.Fatalf("insert message: %v", err) + } + + pushResult, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("push: %v", err) + } + if pushResult.SessionsPushed != 1 || + pushResult.MessagesPushed != 1 { + t.Fatalf( + "pushed %d sessions, %d messages; want 1/1", + pushResult.SessionsPushed, + pushResult.MessagesPushed, + ) + } + + status, err := ps.Status(ctx) + if err != nil { + t.Fatalf("status: %v", err) + } + if status.PGSessions != 1 { + t.Errorf( + "pg sessions = %d, want 1", + status.PGSessions, + ) + } + if status.PGMessages != 1 { + t.Errorf( + "pg messages = %d, want 1", + status.PGMessages, + ) + } +} diff --git a/internal/postgres/messages.go b/internal/postgres/messages.go new file mode 100644 index 00000000..ac844a49 --- /dev/null +++ b/internal/postgres/messages.go @@ -0,0 +1,403 @@ +package postgres + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/wesm/agentsview/internal/db" +) + +const attachToolCallBatchSize = 500 + +// GetMessages returns paginated messages for a session. 
+func (s *Store) GetMessages( + ctx context.Context, + sessionID string, from, limit int, asc bool, +) ([]db.Message, error) { + if limit <= 0 || limit > db.MaxMessageLimit { + limit = db.DefaultMessageLimit + } + + dir := "ASC" + op := ">=" + if !asc { + dir = "DESC" + op = "<=" + } + + query := fmt.Sprintf(` + SELECT session_id, ordinal, role, content, + timestamp, has_thinking, has_tool_use, + content_length + FROM messages + WHERE session_id = $1 AND ordinal %s $2 + ORDER BY ordinal %s + LIMIT $3`, op, dir) + + rows, err := s.pg.QueryContext( + ctx, query, sessionID, from, limit, + ) + if err != nil { + return nil, fmt.Errorf( + "querying messages: %w", err, + ) + } + defer rows.Close() + + msgs, err := scanPGMessages(rows) + if err != nil { + return nil, err + } + if err := s.attachToolCalls(ctx, msgs); err != nil { + return nil, err + } + return msgs, nil +} + +// GetAllMessages returns all messages for a session ordered +// by ordinal. +func (s *Store) GetAllMessages( + ctx context.Context, sessionID string, +) ([]db.Message, error) { + rows, err := s.pg.QueryContext(ctx, ` + SELECT session_id, ordinal, role, content, + timestamp, has_thinking, has_tool_use, + content_length + FROM messages + WHERE session_id = $1 + ORDER BY ordinal ASC`, sessionID) + if err != nil { + return nil, fmt.Errorf( + "querying all messages: %w", err, + ) + } + defer rows.Close() + + msgs, err := scanPGMessages(rows) + if err != nil { + return nil, err + } + if err := s.attachToolCalls(ctx, msgs); err != nil { + return nil, err + } + return msgs, nil +} + +// GetMinimap returns lightweight metadata for all messages +// in a session. +func (s *Store) GetMinimap( + ctx context.Context, sessionID string, +) ([]db.MinimapEntry, error) { + return s.GetMinimapFrom(ctx, sessionID, 0) +} + +// GetMinimapFrom returns lightweight metadata for messages +// starting at ordinal >= from. 
+func (s *Store) GetMinimapFrom( + ctx context.Context, sessionID string, from int, +) ([]db.MinimapEntry, error) { + rows, err := s.pg.QueryContext(ctx, ` + SELECT ordinal, role, content_length, + has_thinking, has_tool_use + FROM messages + WHERE session_id = $1 AND ordinal >= $2 + ORDER BY ordinal ASC`, sessionID, from) + if err != nil { + return nil, fmt.Errorf( + "querying minimap: %w", err, + ) + } + defer rows.Close() + + entries := []db.MinimapEntry{} + for rows.Next() { + var e db.MinimapEntry + if err := rows.Scan( + &e.Ordinal, &e.Role, &e.ContentLength, + &e.HasThinking, &e.HasToolUse, + ); err != nil { + return nil, fmt.Errorf( + "scanning minimap entry: %w", err, + ) + } + entries = append(entries, e) + } + return entries, rows.Err() +} + +// SearchSession performs ILIKE substring search within a single +// session's messages, returning matching ordinals. +func (s *Store) SearchSession( + ctx context.Context, sessionID, query string, +) ([]int, error) { + if query == "" { + return nil, nil + } + like := "%" + escapeLike(query) + "%" + rows, err := s.pg.QueryContext(ctx, ` + SELECT DISTINCT m.ordinal + FROM messages m + LEFT JOIN tool_calls tc + ON tc.session_id = m.session_id + AND tc.message_ordinal = m.ordinal + WHERE m.session_id = $1 + AND (m.content ILIKE $2 + OR tc.result_content ILIKE $2) + ORDER BY m.ordinal ASC`, + sessionID, like, + ) + if err != nil { + return nil, fmt.Errorf( + "searching session: %w", err, + ) + } + defer rows.Close() + + var ordinals []int + for rows.Next() { + var ord int + if err := rows.Scan(&ord); err != nil { + return nil, fmt.Errorf( + "scanning ordinal: %w", err, + ) + } + ordinals = append(ordinals, ord) + } + return ordinals, rows.Err() +} + +// HasFTS returns true because ILIKE search is available. +func (s *Store) HasFTS() bool { return true } + +// escapeLike escapes SQL LIKE metacharacters so the bind +// parameter is treated as a literal substring. 
+func escapeLike(v string) string { + r := strings.NewReplacer( + `\`, `\\`, `%`, `\%`, `_`, `\_`, + ) + return r.Replace(v) +} + +// stripFTSQuotes removes surrounding double quotes that +// prepareFTSQuery adds for SQLite FTS phrase matching. +func stripFTSQuotes(v string) string { + if len(v) >= 2 && v[0] == '"' && v[len(v)-1] == '"' { + return v[1 : len(v)-1] + } + return v +} + +// Search performs ILIKE-based search across messages. +func (s *Store) Search( + ctx context.Context, f db.SearchFilter, +) (db.SearchPage, error) { + if f.Limit <= 0 || f.Limit > db.MaxSearchLimit { + f.Limit = db.DefaultSearchLimit + } + + searchTerm := stripFTSQuotes(f.Query) + if searchTerm == "" { + return db.SearchPage{}, nil + } + + whereClauses := []string{ + `m.content ILIKE '%' || $1 || '%' ESCAPE E'\\'`, + "s.deleted_at IS NULL", + } + args := []any{escapeLike(searchTerm), searchTerm} + argIdx := 3 + + if f.Project != "" { + whereClauses = append( + whereClauses, + fmt.Sprintf("s.project = $%d", argIdx), + ) + args = append(args, f.Project) + argIdx++ + } + + query := fmt.Sprintf(` + SELECT m.session_id, s.project, m.ordinal, + m.role, + COALESCE( + TO_CHAR(m.timestamp AT TIME ZONE 'UTC', + 'YYYY-MM-DD"T"HH24:MI:SS"Z"'), + '' + ), + CASE WHEN POSITION( + LOWER($2) IN LOWER(m.content)) > 100 + THEN '...' || SUBSTRING(m.content + FROM GREATEST(1, POSITION( + LOWER($2) IN LOWER(m.content) + ) - 50) FOR 200) || '...' + ELSE SUBSTRING(m.content FROM 1 FOR 200) + || CASE WHEN LENGTH(m.content) > 200 + THEN '...' ELSE '' END + END AS snippet, + 1.0 AS rank + FROM messages m + JOIN sessions s ON m.session_id = s.id + WHERE %s + ORDER BY m.timestamp DESC NULLS LAST, m.session_id, m.ordinal + LIMIT $%d OFFSET $%d`, + strings.Join(whereClauses, " AND "), + argIdx, argIdx+1, + ) + args = append(args, f.Limit+1, f.Cursor) + + rows, err := s.pg.QueryContext(ctx, query, args...) 
+ if err != nil { + return db.SearchPage{}, + fmt.Errorf("searching: %w", err) + } + defer rows.Close() + + results := []db.SearchResult{} + for rows.Next() { + var r db.SearchResult + if err := rows.Scan( + &r.SessionID, &r.Project, &r.Ordinal, + &r.Role, &r.Timestamp, &r.Snippet, &r.Rank, + ); err != nil { + return db.SearchPage{}, + fmt.Errorf( + "scanning search result: %w", err, + ) + } + results = append(results, r) + } + if err := rows.Err(); err != nil { + return db.SearchPage{}, err + } + + page := db.SearchPage{Results: results} + if len(results) > f.Limit { + page.Results = results[:f.Limit] + page.NextCursor = f.Cursor + f.Limit + } + return page, nil +} + +// attachToolCalls loads tool_calls for the given messages and +// attaches them to each message's ToolCalls field. +func (s *Store) attachToolCalls( + ctx context.Context, msgs []db.Message, +) error { + if len(msgs) == 0 { + return nil + } + + ordToIdx := make(map[int]int, len(msgs)) + sessionID := msgs[0].SessionID + ordinals := make([]int, 0, len(msgs)) + for i, m := range msgs { + ordToIdx[m.Ordinal] = i + ordinals = append(ordinals, m.Ordinal) + } + + for i := 0; i < len(ordinals); i += attachToolCallBatchSize { + end := min(i+attachToolCallBatchSize, len(ordinals)) + if err := s.attachToolCallsBatch( + ctx, msgs, ordToIdx, sessionID, + ordinals[i:end], + ); err != nil { + return err + } + } + return nil +} + +func (s *Store) attachToolCallsBatch( + ctx context.Context, + msgs []db.Message, + ordToIdx map[int]int, + sessionID string, + batch []int, +) error { + if len(batch) == 0 { + return nil + } + + args := []any{sessionID} + phs := make([]string, len(batch)) + for i, ord := range batch { + args = append(args, ord) + phs[i] = fmt.Sprintf("$%d", i+2) + } + + query := fmt.Sprintf(` + SELECT message_ordinal, session_id, tool_name, + category, + COALESCE(tool_use_id, ''), + COALESCE(input_json, ''), + COALESCE(skill_name, ''), + COALESCE(result_content_length, 0), + COALESCE(result_content, ''), 
+ COALESCE(subagent_session_id, '') + FROM tool_calls + WHERE session_id = $1 + AND message_ordinal IN (%s) + ORDER BY message_ordinal, call_index`, + strings.Join(phs, ",")) + + rows, err := s.pg.QueryContext(ctx, query, args...) + if err != nil { + return fmt.Errorf( + "querying tool_calls: %w", err, + ) + } + defer rows.Close() + + for rows.Next() { + var tc db.ToolCall + var msgOrdinal int + if err := rows.Scan( + &msgOrdinal, &tc.SessionID, + &tc.ToolName, &tc.Category, + &tc.ToolUseID, &tc.InputJSON, &tc.SkillName, + &tc.ResultContentLength, &tc.ResultContent, + &tc.SubagentSessionID, + ); err != nil { + return fmt.Errorf( + "scanning tool_call: %w", err, + ) + } + if idx, ok := ordToIdx[msgOrdinal]; ok { + msgs[idx].ToolCalls = append( + msgs[idx].ToolCalls, tc, + ) + } + } + return rows.Err() +} + +// scanPGMessages scans message rows from PostgreSQL, +// converting TIMESTAMPTZ to string. +func scanPGMessages(rows interface { + Next() bool + Scan(dest ...any) error + Err() error +}, +) ([]db.Message, error) { + msgs := []db.Message{} + for rows.Next() { + var m db.Message + var ts *time.Time + if err := rows.Scan( + &m.SessionID, &m.Ordinal, &m.Role, + &m.Content, &ts, &m.HasThinking, + &m.HasToolUse, &m.ContentLength, + ); err != nil { + return nil, fmt.Errorf( + "scanning message: %w", err, + ) + } + if ts != nil { + m.Timestamp = FormatISO8601(*ts) + } + msgs = append(msgs, m) + } + return msgs, rows.Err() +} diff --git a/internal/postgres/messages_pgtest_test.go b/internal/postgres/messages_pgtest_test.go new file mode 100644 index 00000000..dba51b17 --- /dev/null +++ b/internal/postgres/messages_pgtest_test.go @@ -0,0 +1,38 @@ +//go:build pgtest + +package postgres + +import ( + "context" + "testing" + + "github.com/wesm/agentsview/internal/db" +) + +func TestStoreSearchILIKE(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + 
defer store.Close() + + ctx := context.Background() + page, err := store.Search(ctx, db.SearchFilter{ + Query: "hello", + Limit: 10, + }) + if err != nil { + t.Fatalf("Search: %v", err) + } + if len(page.Results) == 0 { + t.Error("expected at least 1 search result") + } + for _, r := range page.Results { + if r.SessionID != "store-test-001" { + t.Errorf("unexpected session %q", r.SessionID) + } + } +} diff --git a/internal/postgres/push.go b/internal/postgres/push.go new file mode 100644 index 00000000..80882ce0 --- /dev/null +++ b/internal/postgres/push.go @@ -0,0 +1,967 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "log" + "maps" + "sort" + "strings" + "time" + + "github.com/wesm/agentsview/internal/db" +) + +const lastPushBoundaryStateKey = "last_push_boundary_state" + +// syncStateStore abstracts sync state read/write operations on the +// local database. Used by push boundary state helpers. +type syncStateStore interface { + GetSyncState(key string) (string, error) + SetSyncState(key, value string) error +} + +type pushBoundaryState struct { + Cutoff string `json:"cutoff"` + Fingerprints map[string]string `json:"fingerprints"` +} + +// PushResult summarizes a push sync operation. +type PushResult struct { + SessionsPushed int + MessagesPushed int + Errors int + Duration time.Duration +} + +// Push syncs local sessions and messages to PostgreSQL. +// Only sessions modified since the last push are processed. +// When full is true, the per-message content heuristic is +// bypassed and every candidate session's messages are +// re-pushed unconditionally. +// +// Known limitation: sessions that are permanently deleted +// from SQLite (via prune) are not propagated as deletions +// to PG because the local rows no longer exist at push time. +// Sessions soft-deleted with deleted_at are synced correctly. +// Use a direct PG DELETE to remove permanently pruned +// sessions from PG if needed. 
+func (s *Sync) Push( + ctx context.Context, full bool, +) (PushResult, error) { + start := time.Now() + var result PushResult + + if err := s.normalizeSyncTimestamps(ctx); err != nil { + return result, err + } + + lastPush, err := s.local.GetSyncState("last_push_at") + if err != nil { + return result, fmt.Errorf( + "reading last_push_at: %w", err, + ) + } + if full { + lastPush = "" + // Caller requested a full push — the PG schema + // may have been dropped since schemaDone was set. + // Clear the memo so EnsureSchema re-runs. + s.schemaMu.Lock() + s.schemaDone = false + s.schemaMu.Unlock() + if err := s.normalizeSyncTimestamps( + ctx, + ); err != nil { + return result, err + } + } + + // Coherence check: if the local watermark says we've + // pushed before but PG has zero sessions for this + // machine, the PG side was reset (schema dropped, DB + // recreated, etc.). Force a full push so all sessions + // are re-synced. + if lastPush != "" { + pgCount, cErr := s.pgSessionCount(ctx) + if cErr != nil { + return result, cErr + } + if pgCount == 0 { + log.Printf( + "pgsync: local watermark set but PG has "+ + "0 sessions for machine %q; "+ + "forcing full push", + s.machine, + ) + lastPush = "" + full = true + s.schemaMu.Lock() + s.schemaDone = false + s.schemaMu.Unlock() + if err := s.normalizeSyncTimestamps( + ctx, + ); err != nil { + return result, err + } + } + } + + cutoff := time.Now().UTC().Format(LocalSyncTimestampLayout) + + allSessions, err := s.local.ListSessionsModifiedBetween( + ctx, lastPush, cutoff, + ) + if err != nil { + return result, fmt.Errorf( + "listing modified sessions: %w", err, + ) + } + + sessionByID := make( + map[string]db.Session, len(allSessions), + ) + for _, sess := range allSessions { + sessionByID[sess.ID] = sess + } + + var priorFingerprints map[string]string + var boundaryState map[string]string + var boundaryOK bool + if !full { + var bErr error + priorFingerprints, boundaryState, boundaryOK, bErr = readBoundaryAndFingerprints( 
+ s.local, lastPush, + ) + if bErr != nil { + return result, bErr + } + } + + if lastPush != "" { + ok := boundaryOK + windowStart, err := PreviousLocalSyncTimestamp( + lastPush, + ) + if err != nil { + return result, fmt.Errorf( + "computing push boundary window before %s: %w", + lastPush, err, + ) + } + boundarySessions, err := s.local.ListSessionsModifiedBetween( + ctx, windowStart, lastPush, + ) + if err != nil { + return result, fmt.Errorf( + "listing push boundary sessions: %w", err, + ) + } + + for _, sess := range boundarySessions { + marker := localSessionSyncMarker(sess) + if marker != lastPush { + continue + } + if ok { + fp := sessionPushFingerprint(sess) + if boundaryState[sess.ID] == fp { + continue + } + } + if _, exists := sessionByID[sess.ID]; exists { + continue + } + sessionByID[sess.ID] = sess + } + } + + if len(priorFingerprints) > 0 { + for id, sess := range sessionByID { + fp := sessionPushFingerprint(sess) + if priorFingerprints[id] == fp { + delete(sessionByID, id) + } + } + } + + var sessions []db.Session + for _, sess := range sessionByID { + sessions = append(sessions, sess) + } + sort.Slice(sessions, func(i, j int) bool { + return sessions[i].ID < sessions[j].ID + }) + + if len(sessions) == 0 { + if err := finalizePushState( + s.local, cutoff, sessions, nil, + ); err != nil { + return result, err + } + result.Duration = time.Since(start) + return result, nil + } + + var pushed []db.Session + const batchSize = 50 + for i := 0; i < len(sessions); i += batchSize { + end := min(i+batchSize, len(sessions)) + batch := sessions[i:end] + + batchResult, err := s.pushBatch( + ctx, batch, full, &pushed, + ) + if err != nil { + return result, err + } + if batchResult.ok { + result.SessionsPushed += batchResult.sessions + result.MessagesPushed += batchResult.messages + continue + } + // Batch failed — retry each session individually + // so one bad session doesn't block the rest. 
+ for _, sess := range batch { + sr, retryErr := s.pushBatch( + ctx, []db.Session{sess}, + full, &pushed, + ) + if retryErr != nil { + return result, retryErr + } + if sr.ok { + result.SessionsPushed += sr.sessions + result.MessagesPushed += sr.messages + } else { + result.Errors++ + } + } + } + + // When all sessions succeeded, advance the watermark to + // cutoff. When some failed, keep the watermark at lastPush + // so the failed sessions (plus any already-pushed ones) are + // re-evaluated next time. Already-pushed sessions are + // fingerprint-matched and skipped cheaply. + finalizeCutoff := cutoff + var mergedFingerprints map[string]string + if result.Errors > 0 { + finalizeCutoff = lastPush + mergedFingerprints = priorFingerprints + } + if err := finalizePushState( + s.local, finalizeCutoff, pushed, + mergedFingerprints, + ); err != nil { + return result, err + } + + result.Duration = time.Since(start) + return result, nil +} + +// pgSessionCount returns the number of sessions in PG for +// this machine. Used to detect schema resets. +func (s *Sync) pgSessionCount( + ctx context.Context, +) (int, error) { + var count int + err := s.pg.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sessions WHERE machine = $1", + s.machine, + ).Scan(&count) + if err != nil { + if isUndefinedTable(err) { + return 0, nil + } + return 0, fmt.Errorf( + "counting pg sessions: %w", err, + ) + } + return count, nil +} + +type batchResult struct { + ok bool + sessions int + messages int +} + +// pushBatch pushes a slice of sessions within a single +// transaction. On success it appends to pushed and returns +// ok=true with session/message counts. On a session-level +// error it rolls back and returns ok=false so the caller +// can retry individually. Fatal errors (BeginTx failure) +// return a non-nil error. 
+func (s *Sync) pushBatch( + ctx context.Context, + batch []db.Session, + full bool, + pushed *[]db.Session, +) (batchResult, error) { + tx, err := s.pg.BeginTx(ctx, nil) + if err != nil { + return batchResult{}, fmt.Errorf( + "begin pg tx: %w", err, + ) + } + + n := 0 + msgs := 0 + for _, sess := range batch { + if err := s.pushSession( + ctx, tx, sess, + ); err != nil { + log.Printf( + "pgsync: session %s: %v", + sess.ID, err, + ) + _ = tx.Rollback() + *pushed = (*pushed)[:len(*pushed)-n] + return batchResult{}, nil + } + + msgCount, err := s.pushMessages( + ctx, tx, sess.ID, full, + ) + if err != nil { + log.Printf( + "pgsync: session %s: %v", + sess.ID, err, + ) + _ = tx.Rollback() + *pushed = (*pushed)[:len(*pushed)-n] + return batchResult{}, nil + } + + // Bump updated_at when messages were rewritten + // but pushSession was a metadata no-op (its + // WHERE clause skips unchanged rows). + if msgCount > 0 { + if _, err := tx.ExecContext(ctx, ` + UPDATE sessions + SET updated_at = NOW() + WHERE id = $1`, + sess.ID, + ); err != nil { + log.Printf( + "pgsync: bumping updated_at %s: %v", + sess.ID, err, + ) + _ = tx.Rollback() + *pushed = (*pushed)[:len(*pushed)-n] + return batchResult{}, nil + } + } + + *pushed = append(*pushed, sess) + n++ + msgs += msgCount + } + + if err := tx.Commit(); err != nil { + log.Printf( + "pgsync: batch commit failed: %v", err, + ) + *pushed = (*pushed)[:len(*pushed)-n] + return batchResult{}, nil + } + return batchResult{ok: true, sessions: n, messages: msgs}, nil +} + +func finalizePushState( + local syncStateStore, + cutoff string, + sessions []db.Session, + priorFingerprints map[string]string, +) error { + if err := local.SetSyncState( + "last_push_at", cutoff, + ); err != nil { + return fmt.Errorf("updating last_push_at: %w", err) + } + return writePushBoundaryState( + local, cutoff, sessions, priorFingerprints, + ) +} + +func readBoundaryAndFingerprints( + local syncStateStore, + cutoff string, +) ( + fingerprints 
map[string]string, + boundary map[string]string, + boundaryOK bool, + err error, +) { + raw, err := local.GetSyncState( + lastPushBoundaryStateKey, + ) + if err != nil { + return nil, nil, false, fmt.Errorf( + "reading %s: %w", + lastPushBoundaryStateKey, err, + ) + } + if raw == "" { + return nil, nil, false, nil + } + var state pushBoundaryState + if err := json.Unmarshal( + []byte(raw), &state, + ); err != nil { + return nil, nil, false, nil + } + fingerprints = state.Fingerprints + if cutoff != "" && + state.Cutoff == cutoff && + state.Fingerprints != nil { + boundary = state.Fingerprints + boundaryOK = true + } + return fingerprints, boundary, boundaryOK, nil +} + +func writePushBoundaryState( + local syncStateStore, + cutoff string, + sessions []db.Session, + priorFingerprints map[string]string, +) error { + state := pushBoundaryState{ + Cutoff: cutoff, + Fingerprints: make( + map[string]string, + len(priorFingerprints)+len(sessions), + ), + } + maps.Copy(state.Fingerprints, priorFingerprints) + for _, sess := range sessions { + state.Fingerprints[sess.ID] = sessionPushFingerprint(sess) + } + data, err := json.Marshal(state) + if err != nil { + return fmt.Errorf( + "encoding %s: %w", + lastPushBoundaryStateKey, err, + ) + } + if err := local.SetSyncState( + lastPushBoundaryStateKey, string(data), + ); err != nil { + return fmt.Errorf( + "writing %s: %w", + lastPushBoundaryStateKey, err, + ) + } + return nil +} + +func localSessionSyncMarker(sess db.Session) string { + marker, err := NormalizeLocalSyncTimestamp(sess.CreatedAt) + if err != nil || marker == "" { + if err != nil { + log.Printf( + "pgsync: normalizing CreatedAt %q for "+ + "session %s: %v (skipping non-RFC3339 "+ + "value)", + sess.CreatedAt, sess.ID, err, + ) + } + marker = "" + } + for _, value := range []*string{ + sess.LocalModifiedAt, + sess.EndedAt, + sess.StartedAt, + } { + if value == nil { + continue + } + normalized, err := NormalizeLocalSyncTimestamp(*value) + if err != nil { + continue 
+ } + if normalized > marker { + marker = normalized + } + } + if sess.FileMtime != nil { + fileMtime := time.Unix( + 0, *sess.FileMtime, + ).UTC().Format(LocalSyncTimestampLayout) + if fileMtime > marker { + marker = fileMtime + } + } + if marker == "" { + log.Printf( + "pgsync: session %s: all timestamps failed "+ + "normalization, falling back to raw "+ + "CreatedAt %q", + sess.ID, sess.CreatedAt, + ) + marker = sess.CreatedAt + } + return marker +} + +func sessionPushFingerprint(sess db.Session) string { + fields := []string{ + sess.ID, + sess.Project, + sess.Machine, + sess.Agent, + stringValue(sess.FirstMessage), + stringValue(sess.DisplayName), + stringValue(sess.StartedAt), + stringValue(sess.EndedAt), + stringValue(sess.DeletedAt), + fmt.Sprintf("%d", sess.MessageCount), + fmt.Sprintf("%d", sess.UserMessageCount), + stringValue(sess.ParentSessionID), + sess.RelationshipType, + stringValue(sess.FileHash), + int64Value(sess.FileMtime), + stringValue(sess.LocalModifiedAt), + sess.CreatedAt, + } + var b strings.Builder + for _, f := range fields { + fmt.Fprintf(&b, "%d:%s", len(f), f) + } + return b.String() +} + +func stringValue(value *string) string { + if value == nil { + return "" + } + return *value +} + +func int64Value(value *int64) string { + if value == nil { + return "" + } + return fmt.Sprintf("%d", *value) +} + +// nilStr converts a nil or empty *string to SQL NULL. +// Sanitizes before checking emptiness so strings like "\x00" +// that reduce to "" are correctly returned as NULL. +func nilStr(s *string) any { + if s == nil { + return nil + } + v := sanitizePG(*s) + if v == "" { + return nil + } + return v +} + +// nilStrTS converts a nil or empty *string timestamp to a +// *time.Time for PG TIMESTAMPTZ columns. +func nilStrTS(s *string) any { + if s == nil || *s == "" { + return nil + } + t, ok := ParseSQLiteTimestamp(*s) + if !ok { + return nil + } + return t +} + +// pushSession upserts a single session into PG. 
+// File-level metadata (file_hash, file_path, file_size, +// file_mtime) is intentionally not synced to PG -- it is +// local-only and used solely by the sync engine to detect +// re-parsed sessions. +func (s *Sync) pushSession( + ctx context.Context, tx *sql.Tx, sess db.Session, +) error { + createdAt, _ := ParseSQLiteTimestamp(sess.CreatedAt) + _, err := tx.ExecContext(ctx, ` + INSERT INTO sessions ( + id, machine, project, agent, + first_message, display_name, + created_at, started_at, ended_at, deleted_at, + message_count, user_message_count, + parent_session_id, relationship_type, + updated_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, + $7, $8, $9, $10, + $11, $12, $13, $14, NOW() + ) + ON CONFLICT (id) DO UPDATE SET + machine = EXCLUDED.machine, + project = EXCLUDED.project, + agent = EXCLUDED.agent, + first_message = EXCLUDED.first_message, + display_name = EXCLUDED.display_name, + created_at = EXCLUDED.created_at, + started_at = EXCLUDED.started_at, + ended_at = EXCLUDED.ended_at, + deleted_at = EXCLUDED.deleted_at, + message_count = EXCLUDED.message_count, + user_message_count = EXCLUDED.user_message_count, + parent_session_id = EXCLUDED.parent_session_id, + relationship_type = EXCLUDED.relationship_type, + updated_at = NOW() + WHERE sessions.machine IS DISTINCT FROM EXCLUDED.machine + OR sessions.project IS DISTINCT FROM EXCLUDED.project + OR sessions.agent IS DISTINCT FROM EXCLUDED.agent + OR sessions.first_message IS DISTINCT FROM EXCLUDED.first_message + OR sessions.display_name IS DISTINCT FROM EXCLUDED.display_name + OR sessions.created_at IS DISTINCT FROM EXCLUDED.created_at + OR sessions.started_at IS DISTINCT FROM EXCLUDED.started_at + OR sessions.ended_at IS DISTINCT FROM EXCLUDED.ended_at + OR sessions.deleted_at IS DISTINCT FROM EXCLUDED.deleted_at + OR sessions.message_count IS DISTINCT FROM EXCLUDED.message_count + OR sessions.user_message_count IS DISTINCT FROM EXCLUDED.user_message_count + OR sessions.parent_session_id IS DISTINCT FROM 
EXCLUDED.parent_session_id + OR sessions.relationship_type IS DISTINCT FROM EXCLUDED.relationship_type`, + sess.ID, s.machine, + sanitizePG(sess.Project), + sess.Agent, + nilStr(sess.FirstMessage), + nilStr(sess.DisplayName), + createdAt, + nilStrTS(sess.StartedAt), + nilStrTS(sess.EndedAt), + nilStrTS(sess.DeletedAt), + sess.MessageCount, sess.UserMessageCount, + nilStr(sess.ParentSessionID), + sess.RelationshipType, + ) + return err +} + +// pushMessages replaces a session's messages and tool calls +// in PG. It skips the replacement when the PG message count +// already matches the local count, avoiding redundant work +// for metadata-only changes. +func (s *Sync) pushMessages( + ctx context.Context, + tx *sql.Tx, + sessionID string, + full bool, +) (int, error) { + localCount, err := s.local.MessageCount(sessionID) + if err != nil { + return 0, fmt.Errorf( + "counting local messages: %w", err, + ) + } + if localCount == 0 { + if _, err := tx.ExecContext(ctx, + `DELETE FROM tool_calls WHERE session_id = $1`, + sessionID, + ); err != nil { + return 0, fmt.Errorf( + "deleting stale pg tool_calls: %w", err, + ) + } + if _, err := tx.ExecContext(ctx, + `DELETE FROM messages WHERE session_id = $1`, + sessionID, + ); err != nil { + return 0, fmt.Errorf( + "deleting stale pg messages: %w", err, + ) + } + return 0, nil + } + + var pgCount int + var pgContentSum, pgContentMax, pgContentMin int64 + var pgToolCallCount int + var pgTCContentSum int64 + if err := tx.QueryRowContext(ctx, + `SELECT COUNT(*), + COALESCE(SUM(content_length), 0), + COALESCE(MAX(content_length), 0), + COALESCE(MIN(content_length), 0) + FROM messages + WHERE session_id = $1`, + sessionID, + ).Scan( + &pgCount, &pgContentSum, + &pgContentMax, &pgContentMin, + ); err != nil { + return 0, fmt.Errorf( + "counting pg messages: %w", err, + ) + } + if err := tx.QueryRowContext(ctx, + `SELECT COUNT(*), + COALESCE(SUM(result_content_length), 0) + FROM tool_calls + WHERE session_id = $1`, + sessionID, + 
).Scan(&pgToolCallCount, &pgTCContentSum); err != nil { + return 0, fmt.Errorf( + "counting pg tool_calls: %w", err, + ) + } + + if !full && pgCount == localCount && pgCount > 0 { + localSum, localMax, localMin, err := s.local.MessageContentFingerprint(sessionID) + if err != nil { + return 0, fmt.Errorf( + "computing local content fingerprint: %w", + err, + ) + } + localTCCount, err := s.local.ToolCallCount(sessionID) + if err != nil { + return 0, fmt.Errorf( + "counting local tool_calls: %w", err, + ) + } + localTCSum, err := s.local.ToolCallContentFingerprint(sessionID) + if err != nil { + return 0, fmt.Errorf( + "computing local tool_call content "+ + "fingerprint: %w", err, + ) + } + if localSum == pgContentSum && + localMax == pgContentMax && + localMin == pgContentMin && + localTCCount == pgToolCallCount && + localTCSum == pgTCContentSum { + return 0, nil + } + } + + if _, err := tx.ExecContext(ctx, ` + DELETE FROM tool_calls + WHERE session_id = $1 + `, sessionID); err != nil { + return 0, fmt.Errorf( + "deleting pg tool_calls: %w", err, + ) + } + if _, err := tx.ExecContext(ctx, ` + DELETE FROM messages + WHERE session_id = $1 + `, sessionID); err != nil { + return 0, fmt.Errorf( + "deleting pg messages: %w", err, + ) + } + + count := 0 + startOrdinal := 0 + for { + msgs, err := s.local.GetMessages( + ctx, sessionID, startOrdinal, + db.MaxMessageLimit, true, + ) + if err != nil { + return count, fmt.Errorf( + "reading local messages: %w", err, + ) + } + if len(msgs) == 0 { + break + } + + nextOrdinal := msgs[len(msgs)-1].Ordinal + 1 + if nextOrdinal <= startOrdinal { + return count, fmt.Errorf( + "pushMessages %s: ordinal did not "+ + "advance (start=%d, last=%d)", + sessionID, startOrdinal, + msgs[len(msgs)-1].Ordinal, + ) + } + + if err := bulkInsertMessages( + ctx, tx, sessionID, msgs, + ); err != nil { + return count, err + } + if err := bulkInsertToolCalls( + ctx, tx, sessionID, msgs, + ); err != nil { + return count, err + } + count += len(msgs) + 
startOrdinal = nextOrdinal + } + + return count, nil +} + +const msgInsertBatch = 100 + +// bulkInsertMessages inserts messages using multi-row VALUES. +func bulkInsertMessages( + ctx context.Context, tx *sql.Tx, + sessionID string, msgs []db.Message, +) error { + for i := 0; i < len(msgs); i += msgInsertBatch { + end := min(i+msgInsertBatch, len(msgs)) + batch := msgs[i:end] + + var b strings.Builder + b.WriteString(`INSERT INTO messages ( + session_id, ordinal, role, content, + timestamp, has_thinking, has_tool_use, + content_length) VALUES `) + args := make([]any, 0, len(batch)*8) + for j, m := range batch { + if j > 0 { + b.WriteByte(',') + } + p := j*8 + 1 + fmt.Fprintf(&b, + "($%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d)", + p, p+1, p+2, p+3, + p+4, p+5, p+6, p+7, + ) + var ts any + if m.Timestamp != "" { + if t, ok := ParseSQLiteTimestamp( + m.Timestamp, + ); ok { + ts = t + } + } + args = append(args, + sessionID, m.Ordinal, m.Role, + sanitizePG(m.Content), ts, + m.HasThinking, + m.HasToolUse, m.ContentLength, + ) + } + if _, err := tx.ExecContext( + ctx, b.String(), args..., + ); err != nil { + return fmt.Errorf( + "bulk inserting messages: %w", err, + ) + } + } + return nil +} + +// bulkInsertToolCalls inserts tool calls using multi-row VALUES. +func bulkInsertToolCalls( + ctx context.Context, tx *sql.Tx, + sessionID string, msgs []db.Message, +) error { + // Collect all tool calls from messages. 
+ type tcRow struct { + ordinal int + index int + tc db.ToolCall + } + var rows []tcRow + for _, m := range msgs { + for i, tc := range m.ToolCalls { + rows = append(rows, tcRow{m.Ordinal, i, tc}) + } + } + if len(rows) == 0 { + return nil + } + + const tcBatch = 50 + for i := 0; i < len(rows); i += tcBatch { + end := min(i+tcBatch, len(rows)) + batch := rows[i:end] + + var b strings.Builder + b.WriteString(`INSERT INTO tool_calls ( + session_id, tool_name, category, + call_index, tool_use_id, input_json, + skill_name, result_content_length, + result_content, subagent_session_id, + message_ordinal) VALUES `) + args := make([]any, 0, len(batch)*11) + for j, r := range batch { + if j > 0 { + b.WriteByte(',') + } + p := j*11 + 1 + fmt.Fprintf(&b, + "($%d,$%d,$%d,$%d,$%d,$%d,"+ + "$%d,$%d,$%d,$%d,$%d)", + p, p+1, p+2, p+3, p+4, p+5, + p+6, p+7, p+8, p+9, p+10, + ) + args = append(args, + sessionID, + sanitizePG(r.tc.ToolName), + sanitizePG(r.tc.Category), + r.index, + sanitizePG(r.tc.ToolUseID), + nilIfEmpty(r.tc.InputJSON), + nilIfEmpty(r.tc.SkillName), + nilIfZero(r.tc.ResultContentLength), + nilIfEmpty(r.tc.ResultContent), + nilIfEmpty(r.tc.SubagentSessionID), + r.ordinal, + ) + } + if _, err := tx.ExecContext( + ctx, b.String(), args..., + ); err != nil { + return fmt.Errorf( + "bulk inserting tool_calls: %w", err, + ) + } + } + return nil +} + +// normalizeSyncTimestamps ensures schema exists and normalizes +// local sync state timestamps. +func (s *Sync) normalizeSyncTimestamps( + ctx context.Context, +) error { + s.schemaMu.Lock() + defer s.schemaMu.Unlock() + if !s.schemaDone { + if err := EnsureSchema( + ctx, s.pg, s.schema, + ); err != nil { + return err + } + s.schemaDone = true + } + return NormalizeLocalSyncStateTimestamps(s.local) +} + +// sanitizePG strips null bytes and replaces invalid UTF-8 +// sequences so text can be safely inserted into PostgreSQL, +// which enforces strict UTF-8 encoding. 
+func sanitizePG(s string) string { + s = strings.ReplaceAll(s, "\x00", "") + s = strings.ToValidUTF8(s, "") + return s +} + +func nilIfEmpty(s string) any { + s = sanitizePG(s) + if s == "" { + return nil + } + return s +} + +func nilIfZero(n int) any { + if n == 0 { + return nil + } + return n +} diff --git a/internal/postgres/push_test.go b/internal/postgres/push_test.go new file mode 100644 index 00000000..b8773cad --- /dev/null +++ b/internal/postgres/push_test.go @@ -0,0 +1,421 @@ +package postgres + +import ( + "encoding/json" + "testing" + + "github.com/wesm/agentsview/internal/db" +) + +type syncStateReaderStub struct { + value string + err error +} + +func (s syncStateReaderStub) GetSyncState( + key string, +) (string, error) { + return s.value, s.err +} + +func (s syncStateReaderStub) SetSyncState( + string, string, +) error { + return nil +} + +type syncStateStoreStub struct { + values map[string]string +} + +func (s *syncStateStoreStub) GetSyncState( + key string, +) (string, error) { + return s.values[key], nil +} + +func (s *syncStateStoreStub) SetSyncState( + key, value string, +) error { + if s.values == nil { + s.values = make(map[string]string) + } + s.values[key] = value + return nil +} + +func TestReadPushBoundaryStateValidity(t *testing.T) { + const cutoff = "2026-03-11T12:34:56.123Z" + + tests := []struct { + name string + raw string + wantValid bool + wantLen int + }{ + { + name: "missing state", + raw: "", + wantValid: false, + wantLen: 0, + }, + { + name: "bare map without cutoff", + raw: `{"sess-001":"fingerprint"}`, + wantValid: false, + wantLen: 0, + }, + { + name: "malformed payload", + raw: `{`, + wantValid: false, + wantLen: 0, + }, + { + name: "stale cutoff", + raw: `{"cutoff":"2026-03-11T12:34:56.122Z","fingerprints":{"sess-001":"fingerprint"}}`, + wantValid: false, + wantLen: 0, + }, + { + name: "matching cutoff", + raw: `{"cutoff":"2026-03-11T12:34:56.123Z","fingerprints":{"sess-001":"fingerprint"}}`, + wantValid: true, + 
wantLen: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, got, valid, err := readBoundaryAndFingerprints( + syncStateReaderStub{value: tc.raw}, + cutoff, + ) + if err != nil { + t.Fatalf( + "readBoundaryAndFingerprints: %v", err, + ) + } + if valid != tc.wantValid { + t.Fatalf( + "valid = %v, want %v", + valid, tc.wantValid, + ) + } + if len(got) != tc.wantLen { + t.Fatalf( + "len(state) = %d, want %d", + len(got), tc.wantLen, + ) + } + }) + } +} + +func TestLocalSessionSyncMarkerNormalizesSecondPrecisionTimestamps(t *testing.T) { + startedAt := "2026-03-11T12:34:56Z" + endedAt := "2026-03-11T12:34:56.123Z" + + got := localSessionSyncMarker(db.Session{ + CreatedAt: "2026-03-11T12:34:55Z", + StartedAt: &startedAt, + EndedAt: &endedAt, + }) + + if got != endedAt { + t.Fatalf( + "localSessionSyncMarker = %q, want %q", + got, endedAt, + ) + } +} + +func TestSessionPushFingerprintDiffers(t *testing.T) { + base := db.Session{ + ID: "sess-001", + Project: "proj", + Machine: "laptop", + Agent: "claude", + MessageCount: 5, + UserMessageCount: 2, + CreatedAt: "2026-03-11T12:00:00Z", + } + + fp1 := sessionPushFingerprint(base) + + tests := []struct { + name string + modify func(s db.Session) db.Session + }{ + { + name: "message count change", + modify: func(s db.Session) db.Session { + s.MessageCount = 6 + return s + }, + }, + { + name: "display name change", + modify: func(s db.Session) db.Session { + name := "new name" + s.DisplayName = &name + return s + }, + }, + { + name: "ended at change", + modify: func(s db.Session) db.Session { + ended := "2026-03-11T13:00:00Z" + s.EndedAt = &ended + return s + }, + }, + { + name: "file hash change", + modify: func(s db.Session) db.Session { + hash := "abc123" + s.FileHash = &hash + return s + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + modified := tc.modify(base) + fp2 := sessionPushFingerprint(modified) + if fp1 == fp2 { + t.Fatalf( + "fingerprint should 
differ after %s", + tc.name, + ) + } + }) + } + + if fp1 != sessionPushFingerprint(base) { + t.Fatal( + "identical sessions should produce " + + "identical fingerprints", + ) + } +} + +func TestSessionPushFingerprintNoFieldCollisions( + t *testing.T, +) { + s1 := db.Session{ + ID: "ab", + Project: "cd", + CreatedAt: "2026-03-11T12:00:00Z", + } + s2 := db.Session{ + ID: "a", + Project: "bcd", + CreatedAt: "2026-03-11T12:00:00Z", + } + if sessionPushFingerprint(s1) == sessionPushFingerprint(s2) { + t.Fatal( + "length-prefixed fingerprints should not collide", + ) + } +} + +func TestFinalizePushStatePersistsEmptyBoundary( + t *testing.T, +) { + const cutoff = "2026-03-11T12:34:56.123Z" + + store := &syncStateStoreStub{} + if err := finalizePushState( + store, cutoff, nil, nil, + ); err != nil { + t.Fatalf("finalizePushState: %v", err) + } + if got := store.values["last_push_at"]; got != cutoff { + t.Fatalf( + "last_push_at = %q, want %q", got, cutoff, + ) + } + + raw := store.values[lastPushBoundaryStateKey] + if raw == "" { + t.Fatal( + "last_push_boundary_state should be written", + ) + } + + var state pushBoundaryState + if err := json.Unmarshal( + []byte(raw), &state, + ); err != nil { + t.Fatalf("unmarshal boundary state: %v", err) + } + if state.Cutoff != cutoff { + t.Fatalf( + "boundary cutoff = %q, want %q", + state.Cutoff, cutoff, + ) + } + if len(state.Fingerprints) != 0 { + t.Fatalf( + "boundary fingerprints = %v, want empty", + state.Fingerprints, + ) + } +} + +func TestFinalizePushStateMergesPriorFingerprints( + t *testing.T, +) { + const cutoff = "2026-03-11T12:34:56.123Z" + + priorFingerprints := map[string]string{ + "sess-001": "fp-001", + } + + cycle2Sessions := []db.Session{ + { + ID: "sess-002", + CreatedAt: "2026-03-11T12:00:00Z", + MessageCount: 3, + }, + } + + store := &syncStateStoreStub{} + if err := finalizePushState( + store, cutoff, cycle2Sessions, + priorFingerprints, + ); err != nil { + t.Fatalf("finalizePushState: %v", err) + } + + raw := 
store.values[lastPushBoundaryStateKey] + if raw == "" { + t.Fatal( + "last_push_boundary_state should be written", + ) + } + + var state pushBoundaryState + if err := json.Unmarshal( + []byte(raw), &state, + ); err != nil { + t.Fatalf("unmarshal boundary state: %v", err) + } + + if len(state.Fingerprints) != 2 { + t.Fatalf( + "len(fingerprints) = %d, want 2", + len(state.Fingerprints), + ) + } + if state.Fingerprints["sess-001"] != "fp-001" { + t.Fatalf( + "sess-001 fingerprint = %q, want %q", + state.Fingerprints["sess-001"], "fp-001", + ) + } + if _, ok := state.Fingerprints["sess-002"]; !ok { + t.Fatal("sess-002 fingerprint should be present") + } +} + +func TestSanitizePG(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "clean string", + input: "hello world", + want: "hello world", + }, + { + name: "null bytes stripped", + input: "hello\x00world", + want: "helloworld", + }, + { + name: "multiple null bytes", + input: "\x00a\x00b\x00", + want: "ab", + }, + { + name: "truncated 3-byte sequence", + input: "hello\xe2world", + want: "helloworld", + }, + { + name: "truncated 2 of 3 bytes", + input: "hello\xe2\x80world", + want: "helloworld", + }, + { + name: "valid multibyte preserved", + // U+2026 HORIZONTAL ELLIPSIS = e2 80 a6 + input: "hello\xe2\x80\xa6world", + want: "hello\xe2\x80\xa6world", + }, + { + name: "null and invalid combined", + input: "a\x00b\xe2c", + want: "abc", + }, + { + name: "empty string", + input: "", + want: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := sanitizePG(tc.input) + if got != tc.want { + t.Errorf( + "sanitizePG(%q) = %q, want %q", + tc.input, got, tc.want, + ) + } + }) + } +} + +func TestNilIfEmptySanitizes(t *testing.T) { + got := nilIfEmpty("hello\x00world") + if got != "helloworld" { + t.Errorf( + "nilIfEmpty with null byte = %q, want %q", + got, "helloworld", + ) + } + + if nilIfEmpty("") != nil { + t.Error("nilIfEmpty(\"\") should be 
nil") + } + + // A string that reduces to empty after sanitization + // should return nil, not "". + if nilIfEmpty("\x00") != nil { + t.Error("nilIfEmpty(\"\\x00\") should be nil") + } +} + +func TestNilStrSanitizes(t *testing.T) { + s := "hello\xe2world" + got := nilStr(&s) + if got != "helloworld" { + t.Errorf( + "nilStr with invalid UTF-8 = %q, want %q", + got, "helloworld", + ) + } + + // A *string that reduces to empty after sanitization + // should return nil. + nul := "\x00" + if nilStr(&nul) != nil { + t.Error("nilStr(\"\\x00\") should be nil") + } +} diff --git a/internal/postgres/schema.go b/internal/postgres/schema.go new file mode 100644 index 00000000..8b58f37a --- /dev/null +++ b/internal/postgres/schema.go @@ -0,0 +1,206 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/jackc/pgx/v5/pgconn" +) + +// SchemaVersion is incremented when the PG schema changes in a +// way that requires migration logic. EnsureSchema writes it to +// sync_metadata so future versions can detect what they're +// working with. +const SchemaVersion = 1 + +// coreDDL creates the tables and indexes. It uses unqualified +// names because Open() sets search_path to the target schema. 
+const coreDDL = ` +CREATE TABLE IF NOT EXISTS sync_metadata ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + machine TEXT NOT NULL, + project TEXT NOT NULL, + agent TEXT NOT NULL, + first_message TEXT, + display_name TEXT, + created_at TIMESTAMPTZ, + started_at TIMESTAMPTZ, + ended_at TIMESTAMPTZ, + deleted_at TIMESTAMPTZ, + message_count INT NOT NULL DEFAULT 0, + user_message_count INT NOT NULL DEFAULT 0, + parent_session_id TEXT, + relationship_type TEXT NOT NULL DEFAULT '', + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS messages ( + session_id TEXT NOT NULL, + ordinal INT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + timestamp TIMESTAMPTZ, + has_thinking BOOLEAN NOT NULL DEFAULT FALSE, + has_tool_use BOOLEAN NOT NULL DEFAULT FALSE, + content_length INT NOT NULL DEFAULT 0, + PRIMARY KEY (session_id, ordinal), + FOREIGN KEY (session_id) + REFERENCES sessions(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS tool_calls ( + id BIGSERIAL PRIMARY KEY, + session_id TEXT NOT NULL, + tool_name TEXT NOT NULL, + category TEXT NOT NULL, + call_index INT NOT NULL DEFAULT 0, + tool_use_id TEXT NOT NULL DEFAULT '', + input_json TEXT, + skill_name TEXT, + result_content_length INT, + result_content TEXT, + subagent_session_id TEXT, + message_ordinal INT NOT NULL, + FOREIGN KEY (session_id) + REFERENCES sessions(id) ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_tool_calls_dedup + ON tool_calls (session_id, message_ordinal, call_index); + +CREATE INDEX IF NOT EXISTS idx_tool_calls_session + ON tool_calls (session_id); +` + +// EnsureSchema creates the schema (if needed), then runs +// idempotent CREATE TABLE / ALTER TABLE statements. The schema +// parameter is the unquoted schema name (e.g. "agentsview"). +// +// After CREATE SCHEMA, all table DDL uses unqualified names +// because Open() sets search_path to the target schema. 
+func EnsureSchema( + ctx context.Context, db *sql.DB, schema string, +) error { + quoted, err := quoteIdentifier(schema) + if err != nil { + return fmt.Errorf("invalid schema name: %w", err) + } + if _, err := db.ExecContext(ctx, + "CREATE SCHEMA IF NOT EXISTS "+quoted, + ); err != nil { + return fmt.Errorf("creating pg schema: %w", err) + } + if _, err := db.ExecContext(ctx, coreDDL); err != nil { + return fmt.Errorf("creating pg tables: %w", err) + } + + // Idempotent column additions for forward compatibility. + alters := []struct { + stmt string + desc string + }{ + { + `ALTER TABLE sessions + ADD COLUMN IF NOT EXISTS deleted_at TIMESTAMPTZ`, + "adding sessions.deleted_at", + }, + { + `ALTER TABLE sessions + ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ`, + "adding sessions.created_at", + }, + { + `ALTER TABLE tool_calls + ADD COLUMN IF NOT EXISTS call_index + INT NOT NULL DEFAULT 0`, + "adding tool_calls.call_index", + }, + } + for _, a := range alters { + if _, err := db.ExecContext(ctx, a.stmt); err != nil { + return fmt.Errorf("%s: %w", a.desc, err) + } + } + + // Record schema version for future migration detection. + if _, err := db.ExecContext(ctx, + `INSERT INTO sync_metadata (key, value) + VALUES ('schema_version', $1) + ON CONFLICT (key) DO UPDATE + SET value = EXCLUDED.value + WHERE sync_metadata.value::int < EXCLUDED.value::int`, + fmt.Sprintf("%d", SchemaVersion), + ); err != nil { + return fmt.Errorf("setting schema version: %w", err) + } + return nil +} + +// GetSchemaVersion reads the schema version from sync_metadata. +// Returns 0 if the key is missing (pre-versioned schema). 
+func GetSchemaVersion( + ctx context.Context, db *sql.DB, +) (int, error) { + var v int + err := db.QueryRowContext(ctx, + `SELECT value::int FROM sync_metadata + WHERE key = 'schema_version'`, + ).Scan(&v) + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf( + "reading schema version: %w", err, + ) + } + return v, nil +} + +// CheckSchemaCompat verifies that the PG schema has all columns +// required by query paths. This is a read-only probe that works +// against any PG role. Returns nil if compatible, or an error +// describing what is missing. +func CheckSchemaCompat( + ctx context.Context, db *sql.DB, +) error { + rows, err := db.QueryContext(ctx, + `SELECT id, created_at, deleted_at, updated_at + FROM sessions LIMIT 0`) + if err != nil { + return fmt.Errorf( + "sessions table missing required columns: %w", + err, + ) + } + rows.Close() + + rows, err = db.QueryContext(ctx, + `SELECT call_index FROM tool_calls LIMIT 0`) + if err != nil { + return fmt.Errorf( + "tool_calls table missing required columns: %w", + err, + ) + } + rows.Close() + return nil +} + +// IsReadOnlyError returns true when the error indicates a PG +// read-only or insufficient-privilege condition (SQLSTATE 25006 +// or 42501). Uses pgconn.PgError for reliable SQLSTATE matching. +func IsReadOnlyError(err error) bool { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + return pgErr.Code == "25006" || pgErr.Code == "42501" + } + return false +} diff --git a/internal/postgres/sessions.go b/internal/postgres/sessions.go new file mode 100644 index 00000000..c02ae1ad --- /dev/null +++ b/internal/postgres/sessions.go @@ -0,0 +1,615 @@ +package postgres + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/wesm/agentsview/internal/db" +) + +// Store wraps a PostgreSQL connection for read-only session +// queries. 
+type Store struct { + pg *sql.DB + cursorMu sync.RWMutex + cursorSecret []byte +} + +// pgSessionCols is the column list for standard PG session +// queries. PG has no file_path, file_size, file_mtime, +// file_hash, or local_modified_at columns. +const pgSessionCols = `id, project, machine, agent, + first_message, display_name, created_at, started_at, + ended_at, message_count, user_message_count, + parent_session_id, relationship_type, deleted_at` + +// paramBuilder generates numbered PostgreSQL placeholders. +type paramBuilder struct { + n int + args []any +} + +func (pb *paramBuilder) add(v any) string { + pb.n++ + pb.args = append(pb.args, v) + return fmt.Sprintf("$%d", pb.n) +} + +// scanPGSession scans a row with pgSessionCols into a +// db.Session, converting TIMESTAMPTZ columns to string. +func scanPGSession( + rs interface{ Scan(...any) error }, +) (db.Session, error) { + var s db.Session + var createdAt *time.Time + var startedAt, endedAt, deletedAt *time.Time + err := rs.Scan( + &s.ID, &s.Project, &s.Machine, &s.Agent, + &s.FirstMessage, &s.DisplayName, + &createdAt, &startedAt, &endedAt, + &s.MessageCount, &s.UserMessageCount, + &s.ParentSessionID, &s.RelationshipType, + &deletedAt, + ) + if err != nil { + return s, err + } + if createdAt != nil { + s.CreatedAt = FormatISO8601(*createdAt) + } + if startedAt != nil { + str := FormatISO8601(*startedAt) + s.StartedAt = &str + } + if endedAt != nil { + str := FormatISO8601(*endedAt) + s.EndedAt = &str + } + if deletedAt != nil { + str := FormatISO8601(*deletedAt) + s.DeletedAt = &str + } + return s, nil +} + +// scanPGSessionRows iterates rows and scans each. 
+func scanPGSessionRows( + rows *sql.Rows, +) ([]db.Session, error) { + sessions := []db.Session{} + for rows.Next() { + s, err := scanPGSession(rows) + if err != nil { + return nil, fmt.Errorf( + "scanning session: %w", err, + ) + } + sessions = append(sessions, s) + } + return sessions, rows.Err() +} + +// pgRootSessionFilter is the base WHERE clause for root +// sessions. +const pgRootSessionFilter = `message_count > 0 + AND relationship_type NOT IN ('subagent', 'fork') + AND deleted_at IS NULL` + +// buildPGSessionFilter returns a WHERE clause with $N +// placeholders and the corresponding args. +func buildPGSessionFilter( + f db.SessionFilter, +) (string, []any) { + pb := ¶mBuilder{} + basePreds := []string{ + "message_count > 0", + "deleted_at IS NULL", + } + if !f.IncludeChildren { + basePreds = append(basePreds, + "relationship_type NOT IN ('subagent', 'fork')") + } + + var filterPreds []string + + if f.Project != "" { + filterPreds = append(filterPreds, + "project = "+pb.add(f.Project)) + } + if f.ExcludeProject != "" { + filterPreds = append(filterPreds, + "project != "+pb.add(f.ExcludeProject)) + } + if f.Machine != "" { + filterPreds = append(filterPreds, + "machine = "+pb.add(f.Machine)) + } + if f.Agent != "" { + agents := strings.Split(f.Agent, ",") + if len(agents) == 1 { + filterPreds = append(filterPreds, + "agent = "+pb.add(agents[0])) + } else { + placeholders := make([]string, len(agents)) + for i, a := range agents { + placeholders[i] = pb.add(a) + } + filterPreds = append(filterPreds, + "agent IN ("+ + strings.Join(placeholders, ",")+ + ")", + ) + } + } + if f.Date != "" { + filterPreds = append(filterPreds, + "DATE(COALESCE(started_at, created_at) AT TIME ZONE 'UTC') = "+ + pb.add(f.Date)+"::date") + } + if f.DateFrom != "" { + filterPreds = append(filterPreds, + "DATE(COALESCE(started_at, created_at) AT TIME ZONE 'UTC') >= "+ + pb.add(f.DateFrom)+"::date") + } + if f.DateTo != "" { + filterPreds = append(filterPreds, + 
"DATE(COALESCE(started_at, created_at) AT TIME ZONE 'UTC') <= "+ + pb.add(f.DateTo)+"::date") + } + if f.ActiveSince != "" { + filterPreds = append(filterPreds, + "COALESCE(ended_at, started_at, created_at) >= "+ + pb.add(f.ActiveSince)+"::timestamptz") + } + if f.MinMessages > 0 { + filterPreds = append(filterPreds, + "message_count >= "+pb.add(f.MinMessages)) + } + if f.MaxMessages > 0 { + filterPreds = append(filterPreds, + "message_count <= "+pb.add(f.MaxMessages)) + } + if f.MinUserMessages > 0 { + filterPreds = append(filterPreds, + "user_message_count >= "+ + pb.add(f.MinUserMessages)) + } + + oneShotPred := "" + if f.ExcludeOneShot { + if f.IncludeChildren { + oneShotPred = "user_message_count > 1" + } else { + filterPreds = append(filterPreds, + "user_message_count > 1") + } + } + + hasFilters := len(filterPreds) > 0 || oneShotPred != "" + if !f.IncludeChildren || !hasFilters { + allPreds := append(basePreds, filterPreds...) + return strings.Join(allPreds, " AND "), pb.args + } + + baseWhere := strings.Join(basePreds, " AND ") + + rootMatchParts := append([]string{}, filterPreds...) + if oneShotPred != "" { + rootMatchParts = append(rootMatchParts, oneShotPred) + } + rootMatch := strings.Join(rootMatchParts, " AND ") + + subqWhere := "message_count > 0 AND deleted_at IS NULL" + if rootMatch != "" { + subqWhere += " AND " + rootMatch + } + + where := baseWhere + " AND (" + rootMatch + + " OR parent_session_id IN" + + " (SELECT id FROM sessions WHERE " + + subqWhere + "))" + + return where, pb.args +} + +// EncodeCursor returns a base64-encoded, HMAC-signed cursor. 
+func (s *Store) EncodeCursor( + endedAt, id string, total ...int, +) string { + t := 0 + if len(total) > 0 { + t = total[0] + } + c := db.SessionCursor{EndedAt: endedAt, ID: id, Total: t} + data, _ := json.Marshal(c) + + s.cursorMu.RLock() + secret := make([]byte, len(s.cursorSecret)) + copy(secret, s.cursorSecret) + s.cursorMu.RUnlock() + + mac := hmac.New(sha256.New, secret) + mac.Write(data) + sig := mac.Sum(nil) + + return base64.RawURLEncoding.EncodeToString(data) + "." + + base64.RawURLEncoding.EncodeToString(sig) +} + +// DecodeCursor parses a base64-encoded cursor string. +func (s *Store) DecodeCursor( + raw string, +) (db.SessionCursor, error) { + parts := strings.Split(raw, ".") + if len(parts) == 1 { + data, err := base64.RawURLEncoding.DecodeString( + parts[0], + ) + if err != nil { + return db.SessionCursor{}, + fmt.Errorf("%w: %v", + db.ErrInvalidCursor, err) + } + var c db.SessionCursor + if err := json.Unmarshal(data, &c); err != nil { + return db.SessionCursor{}, + fmt.Errorf("%w: %v", + db.ErrInvalidCursor, err) + } + c.Total = 0 + return c, nil + } else if len(parts) != 2 { + return db.SessionCursor{}, + fmt.Errorf("%w: invalid format", + db.ErrInvalidCursor) + } + + payload := parts[0] + sigStr := parts[1] + + data, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + return db.SessionCursor{}, + fmt.Errorf("%w: invalid payload: %v", + db.ErrInvalidCursor, err) + } + + sig, err := base64.RawURLEncoding.DecodeString(sigStr) + if err != nil { + return db.SessionCursor{}, + fmt.Errorf( + "%w: invalid signature encoding: %v", + db.ErrInvalidCursor, err) + } + + s.cursorMu.RLock() + secret := make([]byte, len(s.cursorSecret)) + copy(secret, s.cursorSecret) + s.cursorMu.RUnlock() + + mac := hmac.New(sha256.New, secret) + mac.Write(data) + expectedSig := mac.Sum(nil) + + if !hmac.Equal(sig, expectedSig) { + return db.SessionCursor{}, + fmt.Errorf("%w: signature mismatch", + db.ErrInvalidCursor) + } + + var c db.SessionCursor + if err 
:= json.Unmarshal(data, &c); err != nil { + return db.SessionCursor{}, + fmt.Errorf("%w: invalid json: %v", + db.ErrInvalidCursor, err) + } + return c, nil +} + +// ListSessions returns a cursor-paginated list of sessions. +func (s *Store) ListSessions( + ctx context.Context, f db.SessionFilter, +) (db.SessionPage, error) { + if f.Limit <= 0 || f.Limit > db.MaxSessionLimit { + f.Limit = db.DefaultSessionLimit + } + + where, args := buildPGSessionFilter(f) + + var total int + var cur db.SessionCursor + if f.Cursor != "" { + var err error + cur, err = s.DecodeCursor(f.Cursor) + if err != nil { + return db.SessionPage{}, err + } + total = cur.Total + } + + if total <= 0 { + countQ := "SELECT COUNT(*) FROM sessions WHERE " + + where + if err := s.pg.QueryRowContext( + ctx, countQ, args..., + ).Scan(&total); err != nil { + return db.SessionPage{}, + fmt.Errorf("counting sessions: %w", err) + } + } + + cursorPB := ¶mBuilder{ + n: len(args), + args: append([]any{}, args...), + } + cursorWhere := where + if f.Cursor != "" { + eaParam := cursorPB.add(cur.EndedAt) + idParam := cursorPB.add(cur.ID) + cursorWhere += ` AND ( + COALESCE(ended_at, started_at, created_at), + id + ) < (` + eaParam + `::timestamptz, ` + + idParam + `)` + } + + limitParam := cursorPB.add(f.Limit + 1) + query := "SELECT " + pgSessionCols + + " FROM sessions WHERE " + cursorWhere + ` + ORDER BY COALESCE( + ended_at, started_at, created_at + ) DESC, id DESC + LIMIT ` + limitParam + + rows, err := s.pg.QueryContext( + ctx, query, cursorPB.args..., + ) + if err != nil { + return db.SessionPage{}, + fmt.Errorf("querying sessions: %w", err) + } + defer rows.Close() + + sessions, err := scanPGSessionRows(rows) + if err != nil { + return db.SessionPage{}, err + } + + page := db.SessionPage{ + Sessions: sessions, Total: total, + } + if len(sessions) > f.Limit { + page.Sessions = sessions[:f.Limit] + last := page.Sessions[f.Limit-1] + ea := last.CreatedAt + if last.StartedAt != nil && *last.StartedAt != "" { + 
ea = *last.StartedAt + } + if last.EndedAt != nil && *last.EndedAt != "" { + ea = *last.EndedAt + } + page.NextCursor = s.EncodeCursor(ea, last.ID, total) + } + + return page, nil +} + +// GetSession returns a single session by ID, excluding +// soft-deleted sessions. +func (s *Store) GetSession( + ctx context.Context, id string, +) (*db.Session, error) { + row := s.pg.QueryRowContext( + ctx, + "SELECT "+pgSessionCols+ + " FROM sessions WHERE id = $1"+ + " AND deleted_at IS NULL", + id, + ) + sess, err := scanPGSession(row) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf( + "getting session %s: %w", id, err, + ) + } + return &sess, nil +} + +// GetSessionFull returns a single session by ID including +// soft-deleted sessions. +func (s *Store) GetSessionFull( + ctx context.Context, id string, +) (*db.Session, error) { + row := s.pg.QueryRowContext( + ctx, + "SELECT "+pgSessionCols+ + " FROM sessions WHERE id = $1", + id, + ) + sess, err := scanPGSession(row) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf( + "getting session full %s: %w", id, err, + ) + } + return &sess, nil +} + +// GetChildSessions returns sessions whose +// parent_session_id matches the given parentID. +func (s *Store) GetChildSessions( + ctx context.Context, parentID string, +) ([]db.Session, error) { + query := "SELECT " + pgSessionCols + + " FROM sessions" + + " WHERE parent_session_id = $1" + + " AND deleted_at IS NULL" + + " ORDER BY COALESCE(started_at, created_at) ASC" + rows, err := s.pg.QueryContext(ctx, query, parentID) + if err != nil { + return nil, fmt.Errorf( + "querying child sessions for %s: %w", + parentID, err, + ) + } + defer rows.Close() + + return scanPGSessionRows(rows) +} + +// GetStats returns database statistics, counting only root +// sessions with messages. 
+func (s *Store) GetStats( + ctx context.Context, excludeOneShot bool, +) (db.Stats, error) { + filter := pgRootSessionFilter + if excludeOneShot { + filter += " AND user_message_count > 1" + } + query := fmt.Sprintf(` + SELECT + (SELECT COUNT(*) FROM sessions + WHERE %s), + (SELECT COALESCE(SUM(message_count), 0) + FROM sessions WHERE %s), + (SELECT COUNT(DISTINCT project) FROM sessions + WHERE %s), + (SELECT COUNT(DISTINCT machine) FROM sessions + WHERE %s), + (SELECT MIN(COALESCE(started_at, created_at)) + FROM sessions + WHERE %s)`, + filter, filter, filter, filter, filter) + + var st db.Stats + var earliest *time.Time + err := s.pg.QueryRowContext(ctx, query).Scan( + &st.SessionCount, + &st.MessageCount, + &st.ProjectCount, + &st.MachineCount, + &earliest, + ) + if err != nil { + return db.Stats{}, + fmt.Errorf("fetching stats: %w", err) + } + if earliest != nil { + str := FormatISO8601(*earliest) + st.EarliestSession = &str + } + return st, nil +} + +// GetProjects returns project names with session counts. +func (s *Store) GetProjects( + ctx context.Context, excludeOneShot bool, +) ([]db.ProjectInfo, error) { + q := `SELECT project, COUNT(*) as session_count + FROM sessions + WHERE message_count > 0 + AND relationship_type NOT IN ('subagent', 'fork') + AND deleted_at IS NULL` + if excludeOneShot { + q += " AND user_message_count > 1" + } + q += " GROUP BY project ORDER BY project" + rows, err := s.pg.QueryContext(ctx, q) + if err != nil { + return nil, fmt.Errorf( + "querying projects: %w", err, + ) + } + defer rows.Close() + + projects := []db.ProjectInfo{} + for rows.Next() { + var pi db.ProjectInfo + if err := rows.Scan( + &pi.Name, &pi.SessionCount, + ); err != nil { + return nil, fmt.Errorf( + "scanning project: %w", err, + ) + } + projects = append(projects, pi) + } + return projects, rows.Err() +} + +// GetAgents returns distinct agent names with session counts. 
+func (s *Store) GetAgents( + ctx context.Context, excludeOneShot bool, +) ([]db.AgentInfo, error) { + q := `SELECT agent, COUNT(*) as session_count + FROM sessions + WHERE message_count > 0 AND agent <> '' + AND deleted_at IS NULL + AND relationship_type NOT IN ('subagent', 'fork')` + if excludeOneShot { + q += " AND user_message_count > 1" + } + q += " GROUP BY agent ORDER BY agent" + rows, err := s.pg.QueryContext(ctx, q) + if err != nil { + return nil, fmt.Errorf( + "querying agents: %w", err, + ) + } + defer rows.Close() + + agents := []db.AgentInfo{} + for rows.Next() { + var a db.AgentInfo + if err := rows.Scan( + &a.Name, &a.SessionCount, + ); err != nil { + return nil, fmt.Errorf( + "scanning agent: %w", err, + ) + } + agents = append(agents, a) + } + return agents, rows.Err() +} + +// GetMachines returns distinct machine names. +func (s *Store) GetMachines( + ctx context.Context, excludeOneShot bool, +) ([]string, error) { + q := `SELECT DISTINCT machine FROM sessions + WHERE deleted_at IS NULL` + if excludeOneShot { + q += " AND user_message_count > 1" + } + q += " ORDER BY machine" + rows, err := s.pg.QueryContext(ctx, q) + if err != nil { + return nil, err + } + defer rows.Close() + + machines := []string{} + for rows.Next() { + var m string + if err := rows.Scan(&m); err != nil { + return nil, err + } + machines = append(machines, m) + } + return machines, rows.Err() +} diff --git a/internal/postgres/store.go b/internal/postgres/store.go new file mode 100644 index 00000000..f198b161 --- /dev/null +++ b/internal/postgres/store.go @@ -0,0 +1,187 @@ +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/wesm/agentsview/internal/db" +) + +// Compile-time check: *Store satisfies db.Store. +var _ db.Store = (*Store)(nil) + +// NewStore opens a PostgreSQL connection using the shared Open() +// helper and returns a read-only Store. 
+// When allowInsecure is true, non-loopback connections without +// TLS produce a warning instead of failing. +func NewStore( + pgURL, schema string, allowInsecure bool, +) (*Store, error) { + pg, err := Open(pgURL, schema, allowInsecure) + if err != nil { + return nil, err + } + return &Store{pg: pg}, nil +} + +// DB returns the underlying *sql.DB for operations that need +// direct access (e.g. schema compatibility checks). +func (s *Store) DB() *sql.DB { return s.pg } + +// Close closes the underlying database connection. +func (s *Store) Close() error { + return s.pg.Close() +} + +// SetCursorSecret sets the HMAC key used for cursor signing. +func (s *Store) SetCursorSecret(secret []byte) { + s.cursorMu.Lock() + defer s.cursorMu.Unlock() + s.cursorSecret = append([]byte(nil), secret...) +} + +// ReadOnly returns true; this is a read-only data source. +func (s *Store) ReadOnly() bool { return true } + +// GetSessionVersion returns the message count and a hash of +// updated_at for SSE change detection. +func (s *Store) GetSessionVersion( + id string, +) (int, int64, bool) { + var count int + var updatedAt time.Time + err := s.pg.QueryRow( + `SELECT message_count, COALESCE(updated_at, created_at) + FROM sessions WHERE id = $1`, + id, + ).Scan(&count, &updatedAt) + if err != nil { + return 0, 0, false + } + formatted := FormatISO8601(updatedAt) + var h int64 + for _, c := range formatted { + h = h*31 + int64(c) + } + return count, h, true +} + +// ------------------------------------------------------------ +// Write stubs (all return db.ErrReadOnly) +// ------------------------------------------------------------ + +// StarSession is not supported in read-only mode. +func (s *Store) StarSession(_ string) (bool, error) { + return false, db.ErrReadOnly +} + +// UnstarSession is not supported in read-only mode. +func (s *Store) UnstarSession(_ string) error { + return db.ErrReadOnly +} + +// ListStarredSessionIDs returns an empty slice. 
+func (s *Store) ListStarredSessionIDs( + _ context.Context, +) ([]string, error) { + return []string{}, nil +} + +// BulkStarSessions is not supported in read-only mode. +func (s *Store) BulkStarSessions(_ []string) error { + return db.ErrReadOnly +} + +// PinMessage is not supported in read-only mode. +func (s *Store) PinMessage( + _ string, _ int64, _ *string, +) (int64, error) { + return 0, db.ErrReadOnly +} + +// UnpinMessage is not supported in read-only mode. +func (s *Store) UnpinMessage(_ string, _ int64) error { + return db.ErrReadOnly +} + +// ListPinnedMessages returns an empty slice. +func (s *Store) ListPinnedMessages( + _ context.Context, _ string, +) ([]db.PinnedMessage, error) { + return []db.PinnedMessage{}, nil +} + +// InsertInsight is not supported in read-only mode. +func (s *Store) InsertInsight( + _ db.Insight, +) (int64, error) { + return 0, db.ErrReadOnly +} + +// DeleteInsight is not supported in read-only mode. +func (s *Store) DeleteInsight(_ int64) error { + return db.ErrReadOnly +} + +// ListInsights returns an empty slice. +func (s *Store) ListInsights( + _ context.Context, _ db.InsightFilter, +) ([]db.Insight, error) { + return []db.Insight{}, nil +} + +// GetInsight returns nil. +func (s *Store) GetInsight( + _ context.Context, _ int64, +) (*db.Insight, error) { + return nil, nil +} + +// RenameSession is not supported in read-only mode. +func (s *Store) RenameSession( + _ string, _ *string, +) error { + return db.ErrReadOnly +} + +// SoftDeleteSession is not supported in read-only mode. +func (s *Store) SoftDeleteSession(_ string) error { + return db.ErrReadOnly +} + +// RestoreSession is not supported in read-only mode. +func (s *Store) RestoreSession(_ string) (int64, error) { + return 0, db.ErrReadOnly +} + +// DeleteSessionIfTrashed is not supported in read-only mode. +func (s *Store) DeleteSessionIfTrashed( + _ string, +) (int64, error) { + return 0, db.ErrReadOnly +} + +// ListTrashedSessions returns an empty slice. 
+func (s *Store) ListTrashedSessions( + _ context.Context, +) ([]db.Session, error) { + return []db.Session{}, nil +} + +// EmptyTrash is not supported in read-only mode. +func (s *Store) EmptyTrash() (int, error) { + return 0, db.ErrReadOnly +} + +// UpsertSession is not supported in read-only mode. +func (s *Store) UpsertSession(_ db.Session) error { + return db.ErrReadOnly +} + +// ReplaceSessionMessages is not supported in read-only mode. +func (s *Store) ReplaceSessionMessages( + _ string, _ []db.Message, +) error { + return db.ErrReadOnly +} diff --git a/internal/postgres/store_test.go b/internal/postgres/store_test.go new file mode 100644 index 00000000..6e5b8acd --- /dev/null +++ b/internal/postgres/store_test.go @@ -0,0 +1,319 @@ +//go:build pgtest + +package postgres + +import ( + "context" + "testing" + + "github.com/wesm/agentsview/internal/db" +) + +const testSchema = "agentsview_store_test" + +// ensureStoreSchema creates the test schema and seed data. +func ensureStoreSchema(t *testing.T, pgURL string) { + t.Helper() + pg, err := Open(pgURL, testSchema, true) + if err != nil { + t.Fatalf("connecting to pg: %v", err) + } + defer pg.Close() + + _, err = pg.Exec(` + DROP SCHEMA IF EXISTS ` + testSchema + ` CASCADE; + `) + if err != nil { + t.Fatalf("dropping schema: %v", err) + } + + ctx := context.Background() + if err := EnsureSchema(ctx, pg, testSchema); err != nil { + t.Fatalf("creating schema: %v", err) + } + + _, err = pg.Exec(` + INSERT INTO sessions + (id, machine, project, agent, first_message, + started_at, ended_at, message_count, + user_message_count) + VALUES + ('store-test-001', 'test-machine', + 'test-project', 'claude-code', + 'hello world', + '2026-03-12T10:00:00Z'::timestamptz, + '2026-03-12T10:30:00Z'::timestamptz, + 2, 1) + `) + if err != nil { + t.Fatalf("inserting test session: %v", err) + } + _, err = pg.Exec(` + INSERT INTO messages + (session_id, ordinal, role, content, + timestamp, content_length) + VALUES + ('store-test-001', 
0, 'user', + 'hello world', + '2026-03-12T10:00:00Z'::timestamptz, 11), + ('store-test-001', 1, 'assistant', + 'hi there', + '2026-03-12T10:00:01Z'::timestamptz, 8) + `) + if err != nil { + t.Fatalf("inserting test messages: %v", err) + } +} + +func TestNewStore(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + if !store.ReadOnly() { + t.Error("ReadOnly() = false, want true") + } + if !store.HasFTS() { + t.Error("HasFTS() = false, want true") + } +} + +func TestStoreListSessions(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + ctx := context.Background() + page, err := store.ListSessions( + ctx, db.SessionFilter{Limit: 10}, + ) + if err != nil { + t.Fatalf("ListSessions: %v", err) + } + if page.Total == 0 { + t.Error("expected at least 1 session") + } + t.Logf("sessions: %d, total: %d", + len(page.Sessions), page.Total) +} + +func TestStoreGetSession(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + ctx := context.Background() + sess, err := store.GetSession(ctx, "store-test-001") + if err != nil { + t.Fatalf("GetSession: %v", err) + } + if sess == nil { + t.Fatal("expected session, got nil") + } + if sess.Project != "test-project" { + t.Errorf("project = %q, want %q", + sess.Project, "test-project") + } +} + +func TestStoreGetMessages(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + ctx := context.Background() + msgs, err := store.GetMessages( + ctx, 
"store-test-001", 0, 100, true, + ) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(msgs) != 2 { + t.Errorf("got %d messages, want 2", len(msgs)) + } +} + +func TestStoreGetStats(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + ctx := context.Background() + stats, err := store.GetStats(ctx, false) + if err != nil { + t.Fatalf("GetStats: %v", err) + } + if stats.SessionCount == 0 { + t.Error("expected at least 1 session in stats") + } + t.Logf("stats: %+v", stats) +} + +func TestStoreSearch(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + ctx := context.Background() + page, err := store.Search(ctx, db.SearchFilter{ + Query: "hello", + Limit: 5, + }) + if err != nil { + t.Fatalf("Search: %v", err) + } + if len(page.Results) == 0 { + t.Error("expected at least 1 search result") + } + t.Logf("search results: %d", len(page.Results)) +} + +func TestStoreGetMinimap(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + ctx := context.Background() + entries, err := store.GetMinimap( + ctx, "store-test-001", + ) + if err != nil { + t.Fatalf("GetMinimap: %v", err) + } + if len(entries) != 2 { + t.Errorf("got %d entries, want 2", len(entries)) + } +} + +func TestStoreAnalyticsSummary(t *testing.T) { + pgURL := testPGURL(t) + ensureStoreSchema(t, pgURL) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + ctx := context.Background() + summary, err := store.GetAnalyticsSummary( + ctx, db.AnalyticsFilter{ + From: "2026-01-01", + To: 
"2026-12-31", + }, + ) + if err != nil { + t.Fatalf("GetAnalyticsSummary: %v", err) + } + if summary.TotalSessions == 0 { + t.Error("expected at least 1 session in summary") + } + t.Logf("summary: %+v", summary) +} + +func TestStoreWriteMethodsReturnReadOnly(t *testing.T) { + pgURL := testPGURL(t) + + store, err := NewStore(pgURL, testSchema, true) + if err != nil { + t.Fatalf("NewStore: %v", err) + } + defer store.Close() + + tests := []struct { + name string + fn func() error + }{ + {"StarSession", func() error { + _, err := store.StarSession("x") + return err + }}, + {"UnstarSession", func() error { + return store.UnstarSession("x") + }}, + {"BulkStarSessions", func() error { + return store.BulkStarSessions([]string{"x"}) + }}, + {"PinMessage", func() error { + _, err := store.PinMessage("x", 1, nil) + return err + }}, + {"UnpinMessage", func() error { + return store.UnpinMessage("x", 1) + }}, + {"InsertInsight", func() error { + _, err := store.InsertInsight(db.Insight{}) + return err + }}, + {"DeleteInsight", func() error { + return store.DeleteInsight(1) + }}, + {"RenameSession", func() error { + return store.RenameSession("x", nil) + }}, + {"SoftDeleteSession", func() error { + return store.SoftDeleteSession("x") + }}, + {"RestoreSession", func() error { + _, err := store.RestoreSession("x") + return err + }}, + {"DeleteSessionIfTrashed", func() error { + _, err := store.DeleteSessionIfTrashed("x") + return err + }}, + {"EmptyTrash", func() error { + _, err := store.EmptyTrash() + return err + }}, + {"UpsertSession", func() error { + return store.UpsertSession(db.Session{}) + }}, + {"ReplaceSessionMessages", func() error { + return store.ReplaceSessionMessages("x", nil) + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.fn() + if err != db.ErrReadOnly { + t.Errorf("got %v, want ErrReadOnly", err) + } + }) + } +} diff --git a/internal/postgres/store_unit_test.go b/internal/postgres/store_unit_test.go new file mode 100644 
index 00000000..9ac23cb6 --- /dev/null +++ b/internal/postgres/store_unit_test.go @@ -0,0 +1,44 @@ +package postgres + +import "testing" + +func TestStripFTSQuotes(t *testing.T) { + tests := []struct { + input string + want string + }{ + {`"hello world"`, "hello world"}, + {`hello`, "hello"}, + {`"single`, `"single`}, + {`""`, ""}, + {`"a"`, "a"}, + {`already unquoted`, "already unquoted"}, + } + for _, tt := range tests { + got := stripFTSQuotes(tt.input) + if got != tt.want { + t.Errorf("stripFTSQuotes(%q) = %q, want %q", + tt.input, got, tt.want) + } + } +} + +func TestEscapeLike(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"hello", "hello"}, + {"100%", `100\%`}, + {"under_score", `under\_score`}, + {`back\slash`, `back\\slash`}, + {`%_\`, `\%\_\\`}, + } + for _, tt := range tests { + got := escapeLike(tt.input) + if got != tt.want { + t.Errorf("escapeLike(%q) = %q, want %q", + tt.input, got, tt.want) + } + } +} diff --git a/internal/postgres/sync.go b/internal/postgres/sync.go new file mode 100644 index 00000000..bc9588b1 --- /dev/null +++ b/internal/postgres/sync.go @@ -0,0 +1,171 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "log" + "strings" + "sync" + + "github.com/wesm/agentsview/internal/db" +) + +// isUndefinedTable returns true when the error indicates the +// queried relation does not exist (PG SQLSTATE 42P01). We match +// only the SQLSTATE code to avoid false positives from other +// "does not exist" errors (missing columns, functions, etc.). +func isUndefinedTable(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "42P01") +} + +// Sync manages push-only sync from local SQLite to a remote +// PostgreSQL database. +type Sync struct { + pg *sql.DB + local *db.DB + machine string + schema string + + closeOnce sync.Once + closeErr error + + schemaMu sync.Mutex + schemaDone bool +} + +// New creates a Sync instance and verifies the PG connection. 
+// The machine name must not be "local", which is reserved as the +// SQLite sentinel for sessions that originated on this machine. +// When allowInsecure is true, non-loopback connections without TLS +// produce a warning instead of failing. +func New( + pgURL, schema string, local *db.DB, + machine string, allowInsecure bool, +) (*Sync, error) { + if pgURL == "" { + return nil, fmt.Errorf("postgres URL is required") + } + if machine == "" { + return nil, fmt.Errorf( + "machine name must not be empty", + ) + } + if machine == "local" { + return nil, fmt.Errorf( + "machine name %q is reserved; "+ + "choose a different pg.machine_name", + machine, + ) + } + if local == nil { + return nil, fmt.Errorf("local db is required") + } + + pg, err := Open(pgURL, schema, allowInsecure) + if err != nil { + return nil, err + } + + return &Sync{ + pg: pg, + local: local, + machine: machine, + schema: schema, + }, nil +} + +// Close closes the PostgreSQL connection pool. +// Callers must ensure no Push operations are in-flight +// before calling Close; otherwise those operations will fail +// with connection errors. +func (s *Sync) Close() error { + s.closeOnce.Do(func() { + s.closeErr = s.pg.Close() + }) + return s.closeErr +} + +// EnsureSchema creates the schema and tables in PG if they +// don't already exist. It also marks the schema as initialized +// so subsequent Push calls skip redundant checks. +func (s *Sync) EnsureSchema(ctx context.Context) error { + s.schemaMu.Lock() + defer s.schemaMu.Unlock() + if s.schemaDone { + return nil + } + if err := EnsureSchema(ctx, s.pg, s.schema); err != nil { + return err + } + s.schemaDone = true + return nil +} + +// Status returns sync status information. +// Sync state reads (last_push_at) are non-fatal because these +// are informational watermarks stored in SQLite. PG query +// failures are fatal because they indicate a connectivity +// problem that the caller needs to know about. 
+func (s *Sync) Status( + ctx context.Context, +) (SyncStatus, error) { + lastPush, err := s.local.GetSyncState("last_push_at") + if err != nil { + log.Printf( + "warning: reading last_push_at: %v", err, + ) + lastPush = "" + } + + var pgSessions int + err = s.pg.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sessions", + ).Scan(&pgSessions) + if err != nil { + if isUndefinedTable(err) { + return SyncStatus{ + Machine: s.machine, + LastPushAt: lastPush, + }, nil + } + return SyncStatus{}, fmt.Errorf( + "counting pg sessions: %w", err, + ) + } + + var pgMessages int + err = s.pg.QueryRowContext(ctx, + "SELECT COUNT(*) FROM messages", + ).Scan(&pgMessages) + if err != nil { + if isUndefinedTable(err) { + return SyncStatus{ + Machine: s.machine, + LastPushAt: lastPush, + PGSessions: pgSessions, + }, nil + } + return SyncStatus{}, fmt.Errorf( + "counting pg messages: %w", err, + ) + } + + return SyncStatus{ + Machine: s.machine, + LastPushAt: lastPush, + PGSessions: pgSessions, + PGMessages: pgMessages, + }, nil +} + +// SyncStatus holds summary information about the sync state. 
+type SyncStatus struct { + Machine string `json:"machine"` + LastPushAt string `json:"last_push_at"` + PGSessions int `json:"pg_sessions"` + PGMessages int `json:"pg_messages"` +} diff --git a/internal/postgres/sync_test.go b/internal/postgres/sync_test.go new file mode 100644 index 00000000..137113e5 --- /dev/null +++ b/internal/postgres/sync_test.go @@ -0,0 +1,1049 @@ +//go:build pgtest + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "regexp" + "testing" + "time" + + "github.com/wesm/agentsview/internal/db" +) + +func testDB(t *testing.T) *db.DB { + t.Helper() + d, err := db.Open(t.TempDir() + "/test.db") + if err != nil { + t.Fatalf("opening test db: %v", err) + } + t.Cleanup(func() { d.Close() }) + return d +} + +func cleanPGSchema(t *testing.T, pgURL string) { + t.Helper() + pg, err := sql.Open("pgx", pgURL) + if err != nil { + t.Fatalf("connecting to pg: %v", err) + } + defer pg.Close() + _, _ = pg.Exec( + "DROP SCHEMA IF EXISTS agentsview CASCADE", + ) +} + +func TestEnsureSchemaIdempotent(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("first EnsureSchema: %v", err) + } + + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("second EnsureSchema: %v", err) + } +} + +func TestPushSingleSession(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + 
started := "2026-03-11T12:00:00Z" + firstMsg := "hello world" + sess := db.Session{ + ID: "sess-001", + Project: "test-project", + Machine: "local", + Agent: "claude", + FirstMessage: &firstMsg, + StartedAt: &started, + MessageCount: 1, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + if err := local.InsertMessages([]db.Message{ + { + SessionID: "sess-001", + Ordinal: 0, + Role: "user", + Content: firstMsg, + }, + }); err != nil { + t.Fatalf("insert messages: %v", err) + } + + result, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("push: %v", err) + } + if result.SessionsPushed != 1 { + t.Errorf( + "sessions pushed = %d, want 1", + result.SessionsPushed, + ) + } + if result.MessagesPushed != 1 { + t.Errorf( + "messages pushed = %d, want 1", + result.MessagesPushed, + ) + } + + var pgProject, pgMachine string + err = ps.pg.QueryRowContext(ctx, + "SELECT project, machine FROM sessions WHERE id = $1", + "sess-001", + ).Scan(&pgProject, &pgMachine) + if err != nil { + t.Fatalf("querying pg session: %v", err) + } + if pgProject != "test-project" { + t.Errorf( + "pg project = %q, want %q", + pgProject, "test-project", + ) + } + if pgMachine != "test-machine" { + t.Errorf( + "pg machine = %q, want %q", + pgMachine, "test-machine", + ) + } + + var pgMsgContent string + err = ps.pg.QueryRowContext(ctx, + "SELECT content FROM messages WHERE session_id = $1 AND ordinal = 0", + "sess-001", + ).Scan(&pgMsgContent) + if err != nil { + t.Fatalf("querying pg message: %v", err) + } + if pgMsgContent != firstMsg { + t.Errorf( + "pg message content = %q, want %q", + pgMsgContent, firstMsg, + ) + } +} + +func TestPushIdempotent(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := 
context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + started := "2026-03-11T12:00:00Z" + sess := db.Session{ + ID: "sess-002", + Project: "test-project", + Machine: "local", + Agent: "claude", + StartedAt: &started, + MessageCount: 0, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + + result1, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("first push: %v", err) + } + if result1.SessionsPushed != 1 { + t.Errorf( + "first push sessions = %d, want 1", + result1.SessionsPushed, + ) + } + + result2, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("second push: %v", err) + } + if result2.SessionsPushed != 0 { + t.Errorf( + "second push sessions = %d, want 0", + result2.SessionsPushed, + ) + } +} + +func TestPushWithToolCalls(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + started := "2026-03-11T12:00:00Z" + sess := db.Session{ + ID: "sess-tc-001", + Project: "test-project", + Machine: "local", + Agent: "claude", + StartedAt: &started, + MessageCount: 1, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + if err := local.InsertMessages([]db.Message{ + { + SessionID: "sess-tc-001", + Ordinal: 0, + Role: "assistant", + Content: "tool use response", + HasToolUse: true, + ToolCalls: []db.ToolCall{ + { + ToolName: "Read", + Category: "Read", + ToolUseID: "toolu_001", + ResultContentLength: 42, + ResultContent: "file content here", + }, + }, + }, + }); err != nil { + t.Fatalf("insert messages: %v", err) + } + + result, err := ps.Push(ctx, 
false) + if err != nil { + t.Fatalf("push: %v", err) + } + if result.MessagesPushed != 1 { + t.Errorf( + "messages pushed = %d, want 1", + result.MessagesPushed, + ) + } + + var toolName string + var resultLen int + err = ps.pg.QueryRowContext(ctx, + "SELECT tool_name, result_content_length FROM tool_calls WHERE session_id = $1", + "sess-tc-001", + ).Scan(&toolName, &resultLen) + if err != nil { + t.Fatalf("querying pg tool_call: %v", err) + } + if toolName != "Read" { + t.Errorf( + "tool_name = %q, want %q", toolName, "Read", + ) + } + if resultLen != 42 { + t.Errorf( + "result_content_length = %d, want 42", + resultLen, + ) + } +} + +func TestStatus(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + status, err := ps.Status(ctx) + if err != nil { + t.Fatalf("status: %v", err) + } + if status.Machine != "test-machine" { + t.Errorf( + "machine = %q, want %q", + status.Machine, "test-machine", + ) + } + if status.PGSessions != 0 { + t.Errorf( + "pg sessions = %d, want 0", + status.PGSessions, + ) + } +} + +func TestStatusMissingSchema(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + status, err := ps.Status(ctx) + if err != nil { + t.Fatalf("status on missing schema: %v", err) + } + if status.PGSessions != 0 { + t.Errorf( + "pg sessions = %d, want 0", + status.PGSessions, + ) + } + if status.PGMessages != 0 { + t.Errorf( + "pg messages = 
%d, want 0", + status.PGMessages, + ) + } + if status.Machine != "test-machine" { + t.Errorf( + "machine = %q, want %q", + status.Machine, "test-machine", + ) + } +} + +func TestNewRejectsMachineLocal(t *testing.T) { + pgURL := testPGURL(t) + local := testDB(t) + _, err := New( + pgURL, "agentsview", local, "local", true, + ) + if err == nil { + t.Fatal("expected error for machine=local") + } +} + +func TestNewRejectsEmptyMachine(t *testing.T) { + pgURL := testPGURL(t) + local := testDB(t) + _, err := New( + pgURL, "agentsview", local, "", true, + ) + if err == nil { + t.Fatal("expected error for empty machine") + } +} + +func TestNewRejectsEmptyURL(t *testing.T) { + local := testDB(t) + _, err := New( + "", "agentsview", local, "test", true, + ) + if err == nil { + t.Fatal("expected error for empty URL") + } +} + +func TestPushUpdatedAtFormat(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + started := "2026-03-11T12:00:00Z" + sess := db.Session{ + ID: "sess-ts-001", + Project: "test-project", + Machine: "local", + Agent: "claude", + StartedAt: &started, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + + if _, err := ps.Push(ctx, false); err != nil { + t.Fatalf("push: %v", err) + } + + var updatedAt time.Time + err = ps.pg.QueryRowContext(ctx, + "SELECT updated_at FROM sessions WHERE id = $1", + "sess-ts-001", + ).Scan(&updatedAt) + if err != nil { + t.Fatalf("querying updated_at: %v", err) + } + + formatted := updatedAt.UTC().Format( + "2006-01-02T15:04:05.000000Z", + ) + pattern := regexp.MustCompile( + `^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{6}Z$`, + 
) + if !pattern.MatchString(formatted) { + t.Errorf( + "updated_at = %q, want ISO-8601 "+ + "microsecond format", formatted, + ) + } +} + +func TestPushBumpsUpdatedAtOnMessageRewrite( + t *testing.T, +) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "machine-a", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + started := time.Now().UTC().Format(time.RFC3339) + sess := db.Session{ + ID: "sess-bump-001", + Project: "test", + Machine: "local", + Agent: "test-agent", + StartedAt: &started, + MessageCount: 1, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + msg := db.Message{ + SessionID: "sess-bump-001", + Ordinal: 0, + Role: "user", + Content: "hello", + ContentLength: 5, + } + if err := local.ReplaceSessionMessages( + "sess-bump-001", []db.Message{msg}, + ); err != nil { + t.Fatalf("replace messages: %v", err) + } + + if _, err := ps.Push(ctx, false); err != nil { + t.Fatalf("initial push: %v", err) + } + + var updatedAt1 time.Time + if err := ps.pg.QueryRowContext(ctx, + "SELECT updated_at FROM sessions WHERE id = $1", + "sess-bump-001", + ).Scan(&updatedAt1); err != nil { + t.Fatalf("querying updated_at: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + result, err := ps.Push(ctx, true) + if err != nil { + t.Fatalf("full push: %v", err) + } + if result.MessagesPushed == 0 { + t.Fatal( + "expected messages to be pushed on full push", + ) + } + + var updatedAt2 time.Time + if err := ps.pg.QueryRowContext(ctx, + "SELECT updated_at FROM sessions WHERE id = $1", + "sess-bump-001", + ).Scan(&updatedAt2); err != nil { + t.Fatalf( + "querying updated_at after full push: %v", + err, + ) + } + + if !updatedAt2.After(updatedAt1) { + 
t.Errorf( + "updated_at not bumped: before=%v, after=%v", + updatedAt1, updatedAt2, + ) + } +} + +func TestPushFullBypassesHeuristic(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + started := "2026-03-11T12:00:00Z" + sess := db.Session{ + ID: "sess-full-001", + Project: "test-project", + Machine: "local", + Agent: "claude", + StartedAt: &started, + MessageCount: 1, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + if err := local.InsertMessages([]db.Message{ + { + SessionID: "sess-full-001", + Ordinal: 0, + Role: "user", + Content: "test", + }, + }); err != nil { + t.Fatalf("insert messages: %v", err) + } + + if _, err := ps.Push(ctx, false); err != nil { + t.Fatalf("first push: %v", err) + } + + if err := local.SetSyncState( + "last_push_at", "", + ); err != nil { + t.Fatalf("resetting watermark: %v", err) + } + + result, err := ps.Push(ctx, true) + if err != nil { + t.Fatalf("full push: %v", err) + } + if result.SessionsPushed != 1 { + t.Errorf( + "full push sessions = %d, want 1", + result.SessionsPushed, + ) + } + if result.MessagesPushed != 1 { + t.Errorf( + "full push messages = %d, want 1", + result.MessagesPushed, + ) + } +} + +func TestPushDetectsSchemaReset(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: 
%v", err) + } + + // Push a session so the watermark advances. + started := "2026-03-11T12:00:00Z" + sess := db.Session{ + ID: "sess-reset-001", + Project: "test-project", + Machine: "local", + Agent: "claude", + StartedAt: &started, + MessageCount: 1, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + if err := local.InsertMessages([]db.Message{{ + SessionID: "sess-reset-001", + Ordinal: 0, + Role: "user", + Content: "hello", + ContentLength: 5, + }}); err != nil { + t.Fatalf("insert message: %v", err) + } + + r1, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("initial push: %v", err) + } + if r1.SessionsPushed != 1 { + t.Fatalf( + "initial push sessions = %d, want 1", + r1.SessionsPushed, + ) + } + + // Simulate a PG schema reset — don't manually recreate; + // let Push detect and handle it via the coherence check. + cleanPGSchema(t, pgURL) + + // An incremental push should detect the mismatch + // (local watermark set, PG has 0 sessions), recreate + // the schema, and automatically force a full push. 
+ r2, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("post-reset push: %v", err) + } + if r2.SessionsPushed != 1 { + t.Errorf( + "post-reset push sessions = %d, want 1 "+ + "(should auto-detect schema reset)", + r2.SessionsPushed, + ) + } + if r2.MessagesPushed != 1 { + t.Errorf( + "post-reset push messages = %d, want 1", + r2.MessagesPushed, + ) + } +} + +func TestPushFullAfterSchemaDropRecreatesSchema( + t *testing.T, +) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + ctx := context.Background() + + sess := db.Session{ + ID: "sess-full-drop", + Project: "proj", + Machine: "test-machine", + Agent: "claude", + CreatedAt: "2026-03-11T12:00:00.000Z", + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + + r1, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("initial push: %v", err) + } + if r1.SessionsPushed != 1 { + t.Fatalf( + "initial push sessions = %d, want 1", + r1.SessionsPushed, + ) + } + + // Drop the schema without clearing local state. + cleanPGSchema(t, pgURL) + + // A full push should recreate the schema even though + // schemaDone is memoized from the first push. 
+ r2, err := ps.Push(ctx, true) + if err != nil { + t.Fatalf("full push after drop: %v", err) + } + if r2.SessionsPushed != 1 { + t.Errorf( + "full push sessions = %d, want 1", + r2.SessionsPushed, + ) + } +} + +func TestPushBatchesMultipleSessions(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + // Create 75 sessions to exercise two batches (50 + 25). + const totalSessions = 75 + for i := range totalSessions { + id := fmt.Sprintf("batch-sess-%03d", i) + started := "2026-03-11T12:00:00Z" + sess := db.Session{ + ID: id, + Project: "batch-project", + Machine: "local", + Agent: "claude", + StartedAt: &started, + MessageCount: 2, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session %d: %v", i, err) + } + if err := local.InsertMessages([]db.Message{ + { + SessionID: id, + Ordinal: 0, + Role: "user", + Content: fmt.Sprintf("msg %d", i), + ContentLength: 5, + }, + { + SessionID: id, + Ordinal: 1, + Role: "assistant", + Content: fmt.Sprintf("reply %d", i), + ContentLength: 7, + }, + }); err != nil { + t.Fatalf("insert messages %d: %v", i, err) + } + } + + result, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("push: %v", err) + } + if result.SessionsPushed != totalSessions { + t.Errorf( + "sessions pushed = %d, want %d", + result.SessionsPushed, totalSessions, + ) + } + if result.MessagesPushed != totalSessions*2 { + t.Errorf( + "messages pushed = %d, want %d", + result.MessagesPushed, totalSessions*2, + ) + } + if result.Errors != 0 { + t.Errorf("errors = %d, want 0", result.Errors) + } + + // Verify PG state. 
+ var pgSessions, pgMessages int + if err := ps.pg.QueryRowContext(ctx, + "SELECT COUNT(*) FROM sessions", + ).Scan(&pgSessions); err != nil { + t.Fatalf("counting pg sessions: %v", err) + } + if err := ps.pg.QueryRowContext(ctx, + "SELECT COUNT(*) FROM messages", + ).Scan(&pgMessages); err != nil { + t.Fatalf("counting pg messages: %v", err) + } + if pgSessions != totalSessions { + t.Errorf( + "pg sessions = %d, want %d", + pgSessions, totalSessions, + ) + } + if pgMessages != totalSessions*2 { + t.Errorf( + "pg messages = %d, want %d", + pgMessages, totalSessions*2, + ) + } +} + +func TestPushBulkInsertManyMessages(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + // Create a session with 250 messages to exercise + // multi-row VALUES batching (100 per batch). + const msgCount = 250 + started := "2026-03-11T12:00:00Z" + sess := db.Session{ + ID: "bulk-msg-sess", + Project: "test-project", + Machine: "local", + Agent: "claude", + StartedAt: &started, + MessageCount: msgCount, + } + if err := local.UpsertSession(sess); err != nil { + t.Fatalf("upsert session: %v", err) + } + msgs := make([]db.Message, msgCount) + for i := range msgs { + role := "user" + if i%2 == 1 { + role = "assistant" + } + msgs[i] = db.Message{ + SessionID: "bulk-msg-sess", + Ordinal: i, + Role: role, + Content: fmt.Sprintf("message %d", i), + ContentLength: len(fmt.Sprintf("message %d", i)), + } + // Add a tool call on every 10th assistant message. 
+ if role == "assistant" && i%10 == 1 { + msgs[i].HasToolUse = true + msgs[i].ToolCalls = []db.ToolCall{{ + ToolName: "Read", + Category: "Read", + ToolUseID: fmt.Sprintf("toolu_%d", i), + ResultContentLength: 10, + ResultContent: "some result", + }} + } + } + if err := local.InsertMessages(msgs); err != nil { + t.Fatalf("insert messages: %v", err) + } + + result, err := ps.Push(ctx, false) + if err != nil { + t.Fatalf("push: %v", err) + } + if result.SessionsPushed != 1 { + t.Errorf( + "sessions pushed = %d, want 1", + result.SessionsPushed, + ) + } + if result.MessagesPushed != msgCount { + t.Errorf( + "messages pushed = %d, want %d", + result.MessagesPushed, msgCount, + ) + } + + // Verify all messages landed in PG. + var pgMsgCount int + if err := ps.pg.QueryRowContext(ctx, + "SELECT COUNT(*) FROM messages WHERE session_id = $1", + "bulk-msg-sess", + ).Scan(&pgMsgCount); err != nil { + t.Fatalf("counting pg messages: %v", err) + } + if pgMsgCount != msgCount { + t.Errorf( + "pg messages = %d, want %d", + pgMsgCount, msgCount, + ) + } + + // Verify tool calls landed. + var pgTCCount int + if err := ps.pg.QueryRowContext(ctx, + "SELECT COUNT(*) FROM tool_calls WHERE session_id = $1", + "bulk-msg-sess", + ).Scan(&pgTCCount); err != nil { + t.Fatalf("counting pg tool_calls: %v", err) + } + // Every 10th assistant message (ordinals 1, 11, 21, ...). 
+ expectedTC := 0 + for i := range msgCount { + if i%2 == 1 && i%10 == 1 { + expectedTC++ + } + } + if pgTCCount != expectedTC { + t.Errorf( + "pg tool_calls = %d, want %d", + pgTCCount, expectedTC, + ) + } +} + +func TestPushSimplePK(t *testing.T) { + pgURL := testPGURL(t) + cleanPGSchema(t, pgURL) + t.Cleanup(func() { cleanPGSchema(t, pgURL) }) + + local := testDB(t) + ps, err := New( + pgURL, "agentsview", local, + "test-machine", true, + ) + if err != nil { + t.Fatalf("creating sync: %v", err) + } + defer ps.Close() + + ctx := context.Background() + if err := ps.EnsureSchema(ctx); err != nil { + t.Fatalf("ensure schema: %v", err) + } + + var constraintDef string + err = ps.pg.QueryRowContext(ctx, ` + SELECT pg_get_constraintdef(c.oid) + FROM pg_constraint c + JOIN pg_namespace n ON n.oid = c.connamespace + WHERE n.nspname = 'agentsview' + AND c.conrelid = 'agentsview.sessions'::regclass + AND c.contype = 'p' + `).Scan(&constraintDef) + if err != nil { + t.Fatalf("querying sessions PK: %v", err) + } + if constraintDef != "PRIMARY KEY (id)" { + t.Errorf( + "sessions PK = %q, want PRIMARY KEY (id)", + constraintDef, + ) + } + + err = ps.pg.QueryRowContext(ctx, ` + SELECT pg_get_constraintdef(c.oid) + FROM pg_constraint c + JOIN pg_namespace n ON n.oid = c.connamespace + WHERE n.nspname = 'agentsview' + AND c.conrelid = 'agentsview.messages'::regclass + AND c.contype = 'p' + `).Scan(&constraintDef) + if err != nil { + t.Fatalf("querying messages PK: %v", err) + } + if constraintDef != "PRIMARY KEY (session_id, ordinal)" { + t.Errorf( + "messages PK = %q, "+ + "want PRIMARY KEY (session_id, ordinal)", + constraintDef, + ) + } +} diff --git a/internal/postgres/sync_unit_test.go b/internal/postgres/sync_unit_test.go new file mode 100644 index 00000000..4faa8b36 --- /dev/null +++ b/internal/postgres/sync_unit_test.go @@ -0,0 +1,52 @@ +package postgres + +import ( + "errors" + "testing" +) + +func TestIsUndefinedTable(t *testing.T) { + tests := []struct { + name string 
+ err error + want bool + }{ + {"nil", nil, false}, + { + "unrelated error", + errors.New("connection refused"), + false, + }, + { + "generic does not exist", + errors.New( + `column "foo" does not exist`, + ), + false, + }, + { + "SQLSTATE 42P01", + errors.New( + `ERROR: relation "sessions" ` + + `does not exist (SQLSTATE 42P01)`, + ), + true, + }, + { + "bare SQLSTATE", + errors.New("42P01"), + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isUndefinedTable(tt.err) + if got != tt.want { + t.Errorf( + "isUndefinedTable(%v) = %v, want %v", + tt.err, got, tt.want, + ) + } + }) + } +} diff --git a/internal/postgres/time.go b/internal/postgres/time.go new file mode 100644 index 00000000..19b16a54 --- /dev/null +++ b/internal/postgres/time.go @@ -0,0 +1,132 @@ +package postgres + +import ( + "fmt" + "time" +) + +// Common timestamp formats found in SQLite data. +var sqliteFormats = []string{ + time.RFC3339Nano, + "2006-01-02T15:04:05.000Z", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", +} + +// ParseSQLiteTimestamp parses an ISO-8601 text timestamp from +// SQLite into a time.Time. Returns zero time and false for +// empty strings or unparseable values. +func ParseSQLiteTimestamp(s string) (time.Time, bool) { + if s == "" { + return time.Time{}, false + } + for _, f := range sqliteFormats { + if t, err := time.Parse(f, s); err == nil { + return t, true + } + } + return time.Time{}, false +} + +// FormatISO8601 formats a time.Time to ISO-8601 UTC string +// for JSON API responses. +func FormatISO8601(t time.Time) string { + return t.UTC().Format(time.RFC3339Nano) +} + +// syncTimestampLayout uses microsecond precision to match +// PostgreSQL's timestamp resolution. +const syncTimestampLayout = "2006-01-02T15:04:05.000000Z" + +// LocalSyncTimestampLayout uses millisecond precision to match +// SQLite's datetime resolution. 
+const LocalSyncTimestampLayout = "2006-01-02T15:04:05.000Z" + +// FormatSyncTimestamp formats a time as a microsecond-precision +// UTC ISO-8601 string for PG sync watermarks. +func FormatSyncTimestamp(t time.Time) string { + return t.UTC().Format(syncTimestampLayout) +} + +// NormalizeSyncTimestamp parses a RFC3339Nano timestamp and +// re-formats it to microsecond precision for PG sync. +func NormalizeSyncTimestamp(value string) (string, error) { + if value == "" { + return "", nil + } + ts, err := time.Parse(time.RFC3339Nano, value) + if err != nil { + return "", err + } + return FormatSyncTimestamp(ts), nil +} + +// NormalizeLocalSyncTimestamp parses a RFC3339Nano timestamp and +// re-formats it to millisecond precision for SQLite sync state. +func NormalizeLocalSyncTimestamp( + value string, +) (string, error) { + if value == "" { + return "", nil + } + ts, err := time.Parse(time.RFC3339Nano, value) + if err != nil { + return "", err + } + return ts.UTC().Format(LocalSyncTimestampLayout), nil +} + +// PreviousLocalSyncTimestamp returns the timestamp 1ms before +// the given value, formatted at millisecond precision. This +// creates a non-overlapping boundary for incremental sync +// queries against SQLite. +func PreviousLocalSyncTimestamp( + value string, +) (string, error) { + if value == "" { + return "", nil + } + ts, err := time.Parse(time.RFC3339Nano, value) + if err != nil { + return "", err + } + prev := ts.Add(-time.Millisecond) + return prev.UTC().Format(LocalSyncTimestampLayout), nil +} + +// SyncStateStore is the interface needed for normalizing local +// sync timestamps stored in SQLite. +type SyncStateStore interface { + GetSyncState(key string) (string, error) + SetSyncState(key, value string) error +} + +// NormalizeLocalSyncStateTimestamps normalizes the last_push_at +// watermark in the local SQLite sync state to millisecond +// precision. 
+func NormalizeLocalSyncStateTimestamps( + local SyncStateStore, +) error { + value, err := local.GetSyncState("last_push_at") + if err != nil { + return fmt.Errorf("reading last_push_at: %w", err) + } + if value == "" { + return nil + } + normalized, err := NormalizeLocalSyncTimestamp(value) + if err != nil { + return fmt.Errorf( + "normalizing last_push_at: %w", err, + ) + } + if normalized == value { + return nil + } + if err := local.SetSyncState( + "last_push_at", normalized, + ); err != nil { + return fmt.Errorf("writing last_push_at: %w", err) + } + return nil +} diff --git a/internal/postgres/time_test.go b/internal/postgres/time_test.go new file mode 100644 index 00000000..bcd34e61 --- /dev/null +++ b/internal/postgres/time_test.go @@ -0,0 +1,223 @@ +package postgres + +import ( + "testing" + "time" + + "github.com/wesm/agentsview/internal/db" +) + +func TestParseSQLiteTimestamp(t *testing.T) { + tests := []struct { + name string + input string + wantOK bool + wantUTC string + }{ + { + "RFC3339Nano", + "2026-03-11T12:34:56.123456789Z", + true, + "2026-03-11T12:34:56.123456789Z", + }, + { + "millisecond", + "2026-03-11T12:34:56.000Z", + true, + "2026-03-11T12:34:56Z", + }, + { + "second only", + "2026-03-11T12:34:56Z", + true, + "2026-03-11T12:34:56Z", + }, + { + "space separated", + "2026-03-11 12:34:56", + true, + "2026-03-11T12:34:56Z", + }, + { + "empty string", + "", + false, + "", + }, + { + "garbage", + "not-a-timestamp", + false, + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := ParseSQLiteTimestamp(tt.input) + if ok != tt.wantOK { + t.Fatalf( + "ParseSQLiteTimestamp(%q) ok = %v, "+ + "want %v", + tt.input, ok, tt.wantOK, + ) + } + if !ok { + return + } + gotStr := got.UTC().Format(time.RFC3339Nano) + if gotStr != tt.wantUTC { + t.Errorf( + "ParseSQLiteTimestamp(%q) = %q, "+ + "want %q", + tt.input, gotStr, tt.wantUTC, + ) + } + }) + } +} + +func TestFormatISO8601(t *testing.T) { + ts := time.Date( + 
2026, 3, 11, 12, 34, 56, 123456789, + time.UTC, + ) + got := FormatISO8601(ts) + want := "2026-03-11T12:34:56.123456789Z" + if got != want { + t.Errorf("FormatISO8601() = %q, want %q", got, want) + } +} + +func TestFormatISO8601NonUTC(t *testing.T) { + loc := time.FixedZone("EST", -5*3600) + ts := time.Date(2026, 3, 11, 7, 34, 56, 0, loc) + got := FormatISO8601(ts) + want := "2026-03-11T12:34:56Z" + if got != want { + t.Errorf( + "FormatISO8601() = %q, want %q (should be UTC)", + got, want, + ) + } +} + +func TestNormalizeSyncTimestamp(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + "second precision", + "2026-03-11T12:34:56Z", + "2026-03-11T12:34:56.000000Z", + }, + { + "nanosecond precision", + "2026-03-11T12:34:56.123456789Z", + "2026-03-11T12:34:56.123456Z", + }, + { + "empty", + "", + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeSyncTimestamp(tt.input) + if err != nil { + t.Fatalf("error = %v", err) + } + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestNormalizeLocalSyncTimestamp(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + "second precision", + "2026-03-11T12:34:56Z", + "2026-03-11T12:34:56.000Z", + }, + { + "microsecond precision", + "2026-03-11T12:34:56.123456Z", + "2026-03-11T12:34:56.123Z", + }, + { + "empty", + "", + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeLocalSyncTimestamp(tt.input) + if err != nil { + t.Fatalf("error = %v", err) + } + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestPreviousLocalSyncTimestamp(t *testing.T) { + got, err := PreviousLocalSyncTimestamp( + "2026-03-11T12:34:56.124Z", + ) + if err != nil { + t.Fatalf("error = %v", err) + } + want := "2026-03-11T12:34:56.123Z" + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + 
+func TestPreviousLocalSyncTimestampEmpty(t *testing.T) { + got, err := PreviousLocalSyncTimestamp("") + if err != nil { + t.Fatalf("error = %v", err) + } + if got != "" { + t.Errorf("got %q, want empty", got) + } +} + +func TestNormalizeLocalSyncStateTimestamps(t *testing.T) { + local, err := db.Open(t.TempDir() + "/test.db") + if err != nil { + t.Fatalf("opening test db: %v", err) + } + defer local.Close() + + if err := local.SetSyncState( + "last_push_at", + "2026-03-11T12:34:56.123456789Z", + ); err != nil { + t.Fatalf("SetSyncState: %v", err) + } + + if err := NormalizeLocalSyncStateTimestamps(local); err != nil { + t.Fatalf("NormalizeLocalSyncStateTimestamps: %v", err) + } + + got, err := local.GetSyncState("last_push_at") + if err != nil { + t.Fatalf("GetSyncState: %v", err) + } + want := "2026-03-11T12:34:56.123Z" + if got != want { + t.Errorf("last_push_at = %q, want %q", got, want) + } +} diff --git a/internal/server/basepath_test.go b/internal/server/basepath_test.go new file mode 100644 index 00000000..270e6c63 --- /dev/null +++ b/internal/server/basepath_test.go @@ -0,0 +1,128 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestBasePath_StripsPrefixForAPI(t *testing.T) { + s := testServer(t, 0, WithBasePath("/app")) + + req := httptest.NewRequest("GET", "/app/api/v1/sessions", nil) + req.Host = "127.0.0.1:0" + req.RemoteAddr = "127.0.0.1:12345" + w := httptest.NewRecorder() + s.Handler().ServeHTTP(w, req) + + // 200 or 503 (timeout) both confirm the route was matched + // and prefix was stripped. 404 or 403 would indicate a + // base-path routing failure. 
+	if w.Code == http.StatusNotFound ||
+		w.Code == http.StatusForbidden {
+		t.Fatalf("GET /app/api/v1/sessions = %d, want route match; body: %s",
+			w.Code, w.Body.String())
+	}
+}
+
+func TestBasePath_RedirectsBarePrefix(t *testing.T) {
+	s := testServer(t, 0, WithBasePath("/app"))
+
+	req := httptest.NewRequest("GET", "/app", nil)
+	w := httptest.NewRecorder()
+	s.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusMovedPermanently {
+		t.Fatalf("GET /app = %d, want 301", w.Code)
+	}
+	loc := w.Header().Get("Location")
+	if loc != "/app/" {
+		t.Fatalf("Location = %q, want /app/", loc)
+	}
+}
+
+func TestBasePath_InjectsBaseHrefIntoHTML(t *testing.T) {
+	s := testServer(t, 0, WithBasePath("/viewer"))
+
+	req := httptest.NewRequest("GET", "/viewer/", nil)
+	w := httptest.NewRecorder()
+	s.Handler().ServeHTTP(w, req)
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("GET /viewer/ = %d, want 200", w.Code)
+	}
+	body := w.Body.String()
+	if !strings.Contains(body, `<base href="/viewer/">`) {
+		t.Error("missing <base> tag in response")
+	}
+}
+
+func TestBasePath_RewritesAssetPaths(t *testing.T) {
+	s := testServer(t, 0, WithBasePath("/viewer"))
+
+	req := httptest.NewRequest("GET", "/viewer/", nil)
+	w := httptest.NewRecorder()
+	s.Handler().ServeHTTP(w, req)
+
+	body := w.Body.String()
+
+	// Asset paths should be prefixed.
+	if strings.Contains(body, `src="/assets/`) {
+		t.Error("found unprefixed src=\"/assets/ in HTML")
+	}
+	if strings.Contains(body, `href="/assets/`) {
+		t.Error("found unprefixed href=\"/assets/ in HTML")
+	}
+	if strings.Contains(body, `href="/favicon`) {
+		t.Error("found unprefixed href=\"/favicon in HTML")
+	}
+
+	// External URLs must NOT be prefixed.
+	if strings.Contains(body, `href="/viewer/https://`) {
+		t.Error("external URL was incorrectly prefixed")
+	}
+}
+
+func TestBasePath_SPAFallbackServesIndex(t *testing.T) {
+	s := testServer(t, 0, WithBasePath("/app"))
+
+	// A non-existent path should fall back to index.html
+	// with the base tag injected.
+ req := httptest.NewRequest("GET", "/app/some/route", nil) + w := httptest.NewRecorder() + s.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("GET /app/some/route = %d, want 200", w.Code) + } + if !strings.Contains(w.Body.String(), ``) { + t.Error("SPA fallback missing tag") + } +} + +func TestBasePath_RejectsSiblingPath(t *testing.T) { + s := testServer(t, 0, WithBasePath("/app")) + + // /appfoo should NOT be handled — only /app or /app/... + req := httptest.NewRequest("GET", "/appfoo/bar", nil) + req.Host = "127.0.0.1:0" + req.RemoteAddr = "127.0.0.1:12345" + w := httptest.NewRecorder() + s.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Fatalf( + "GET /appfoo/bar = %d, want 404", w.Code, + ) + } +} + +func TestBasePath_TrailingSlashNormalized(t *testing.T) { + s := testServer(t, 0, WithBasePath("/app/")) + + // WithBasePath trims trailing slash. + if s.basePath != "/app" { + t.Fatalf("basePath = %q, want /app", s.basePath) + } +} diff --git a/internal/server/events.go b/internal/server/events.go index d86d55fb..6ace7b8c 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -57,6 +57,14 @@ func (s *Server) sessionMonitor( sessionID, ) + if s.engine == nil { + // PG read mode: poll GetSessionVersion only, + // no file watching or fallback sync. + s.pollDBOnly(ctx, ch, sessionID, + lastCount, lastDBMtime) + return + } + // Track file mtime for fallback sync. sourcePath := s.engine.FindSourceFile(sessionID) var lastFileMtime int64 @@ -94,6 +102,35 @@ func (s *Server) sessionMonitor( return ch } +// pollDBOnly polls GetSessionVersion on a timer and signals ch +// when changes are detected. Used in PG-read mode where there is +// no sync engine or file watcher. 
+func (s *Server) pollDBOnly( + ctx context.Context, ch chan<- struct{}, + sessionID string, lastCount int, lastDBMtime int64, +) { + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + count, dbMtime, ok := s.db.GetSessionVersion(sessionID) + if ok && (count != lastCount || dbMtime != lastDBMtime) { + lastCount = count + lastDBMtime = dbMtime + select { + case ch <- struct{}{}: + case <-ctx.Done(): + return + } + } + } + } +} + // checkDBForChanges polls the database for a session's // message_count and file_mtime. If either changed, it // returns true. As a fallback, it monitors source file @@ -215,6 +252,11 @@ func (s *Server) handleWatchSession( func (s *Server) handleTriggerSync( w http.ResponseWriter, r *http.Request, ) { + if s.engine == nil { + writeError(w, http.StatusNotImplemented, + "not available in remote mode") + return + } stream, err := NewSSEStream(w) if err != nil { // Non-streaming fallback @@ -232,6 +274,11 @@ func (s *Server) handleTriggerSync( func (s *Server) handleTriggerResync( w http.ResponseWriter, r *http.Request, ) { + if s.engine == nil { + writeError(w, http.StatusNotImplemented, + "not available in remote mode") + return + } stream, err := NewSSEStream(w) if err != nil { stats := s.engine.ResyncAll(r.Context(), nil) @@ -248,6 +295,13 @@ func (s *Server) handleTriggerResync( func (s *Server) handleSyncStatus( w http.ResponseWriter, r *http.Request, ) { + if s.engine == nil { + writeJSON(w, http.StatusOK, map[string]any{ + "last_sync": "", + "stats": nil, + }) + return + } lastSync := s.engine.LastSync() stats := s.engine.LastSyncStats() diff --git a/internal/server/export.go b/internal/server/export.go index d7a5f8c3..b79e7a08 100644 --- a/internal/server/export.go +++ b/internal/server/export.go @@ -668,5 +668,11 @@ func truncateStr(s string, max int) string { if len(s) <= max { return s } - return s[:max] + "..." 
+ // Truncate at a valid rune boundary to avoid producing + // invalid UTF-8. + r := []rune(s) + if len(r) <= max { + return s + } + return string(r[:max]) + "..." } diff --git a/internal/server/insights.go b/internal/server/insights.go index 60a3fce2..e99b59ae 100644 --- a/internal/server/insights.go +++ b/internal/server/insights.go @@ -112,6 +112,9 @@ func (s *Server) handleDeleteInsight( } if err := s.db.DeleteInsight(id); err != nil { + if handleReadOnly(w, err) { + return + } writeError( w, http.StatusInternalServerError, err.Error(), ) @@ -137,6 +140,12 @@ func insightGenerateClientMessage(agent string) string { func (s *Server) handleGenerateInsight( w http.ResponseWriter, r *http.Request, ) { + if s.db.ReadOnly() { + writeError(w, http.StatusNotImplemented, + "insight generation is not available in read-only mode") + return + } + var req generateInsightRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeError(w, http.StatusBadRequest, diff --git a/internal/server/openers.go b/internal/server/openers.go index 36bf0810..7e06dded 100644 --- a/internal/server/openers.go +++ b/internal/server/openers.go @@ -155,6 +155,11 @@ func (s *Server) handleListOpeners( func (s *Server) handleGetSessionDir( w http.ResponseWriter, r *http.Request, ) { + if s.db.ReadOnly() { + writeError(w, http.StatusNotImplemented, + "not available in remote mode") + return + } sessionID := r.PathValue("id") session, err := s.db.GetSessionFull(r.Context(), sessionID) if err != nil { @@ -181,6 +186,11 @@ type openRequest struct { func (s *Server) handleOpenSession( w http.ResponseWriter, r *http.Request, ) { + if s.db.ReadOnly() { + writeError(w, http.StatusNotImplemented, + "not available in remote mode") + return + } sessionID := r.PathValue("id") session, err := s.db.GetSessionFull(r.Context(), sessionID) if err != nil { diff --git a/internal/server/pins.go b/internal/server/pins.go index ce30f097..4ebf2199 100644 --- a/internal/server/pins.go +++ 
b/internal/server/pins.go @@ -34,6 +34,9 @@ func (s *Server) handlePinMessage( id, err := s.db.PinMessage(sessionID, messageID, req.Note) if err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("pin message: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return @@ -59,6 +62,9 @@ func (s *Server) handleUnpinMessage( } if err := s.db.UnpinMessage(sessionID, messageID); err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("unpin message: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/response.go b/internal/server/response.go index 0aba87a5..9d6ce49a 100644 --- a/internal/server/response.go +++ b/internal/server/response.go @@ -6,6 +6,8 @@ import ( "errors" "log" "net/http" + + "github.com/wesm/agentsview/internal/db" ) // writeJSON writes v as JSON with the given HTTP status code. @@ -24,6 +26,17 @@ func writeError(w http.ResponseWriter, status int, msg string) { writeJSON(w, status, map[string]string{"error": msg}) } +// handleReadOnly checks for db.ErrReadOnly and writes a 501. +// Returns true if the error was handled. +func handleReadOnly(w http.ResponseWriter, err error) bool { + if errors.Is(err, db.ErrReadOnly) { + writeError(w, http.StatusNotImplemented, + "not available in remote mode") + return true + } + return false +} + // handleContextError checks for context.Canceled and // context.DeadlineExceeded. On cancellation it returns true // silently (client disconnected). On deadline exceeded it diff --git a/internal/server/resume.go b/internal/server/resume.go index 7f41f759..c68c721a 100644 --- a/internal/server/resume.go +++ b/internal/server/resume.go @@ -147,6 +147,14 @@ func (s *Server) handleResumeSession( return } + // Block actual launches in read-only mode. command_only + // requests above are safe and remain available. 
+ if s.db.ReadOnly() { + writeError(w, http.StatusNotImplemented, + "session launch not available in remote mode") + return + } + // If the caller specified a terminal opener, use it directly. if req.OpenerID != "" { openers := detectOpeners() diff --git a/internal/server/server.go b/internal/server/server.go index ec6c4bd7..dce1eeb3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "io" "io/fs" "log" "net" @@ -25,13 +26,14 @@ type VersionInfo struct { Version string `json:"version"` Commit string `json:"commit"` BuildDate string `json:"build_date"` + ReadOnly bool `json:"read_only,omitempty"` } // Server is the HTTP server that serves the SPA and REST API. type Server struct { mu gosync.RWMutex cfg config.Config - db *db.DB + db db.Store engine *sync.Engine mux *http.ServeMux httpSrv *http.Server @@ -56,11 +58,17 @@ type Server struct { // updates. Defaults to update.CheckForUpdate; tests // can override it via WithUpdateChecker. updateCheckFn UpdateCheckFunc + + // basePath is a URL prefix for reverse-proxy deployments + // (e.g. "/agentsview"). When set, all routes are served + // under this prefix and a tag is injected + // into the SPA's index.html. + basePath string } // New creates a new Server. func New( - cfg config.Config, database *db.DB, engine *sync.Engine, + cfg config.Config, database db.Store, engine *sync.Engine, opts ...Option, ) *Server { dist, err := web.Assets() @@ -111,6 +119,16 @@ func WithUpdateChecker(f UpdateCheckFunc) Option { return func(s *Server) { s.updateCheckFn = f } } +// WithBasePath sets a URL prefix for reverse-proxy deployments. +// The path must start with "/" and not end with "/" (e.g. +// "/agentsview"). When set, the server strips this prefix from +// incoming requests and injects a tag into the SPA. 
+func WithBasePath(path string) Option { + return func(s *Server) { + s.basePath = strings.TrimRight(path, "/") + } +} + // WithGenerateFunc overrides the insight generation function, // allowing tests to substitute a stub. Nil is ignored. func WithGenerateFunc(f insight.GenerateFunc) Option { @@ -246,15 +264,65 @@ func (s *Server) handleSPA(w http.ResponseWriter, r *http.Request) { f, err := s.spaFS.Open(path) if err == nil { f.Close() + // For index.html with a base path, inject . + if s.basePath != "" && path == "index.html" { + s.serveIndexWithBase(w, r) + return + } s.spaHandler.ServeHTTP(w, r) return } // SPA fallback: serve index.html for all routes + if s.basePath != "" { + s.serveIndexWithBase(w, r) + return + } r.URL.Path = "/" s.spaHandler.ServeHTTP(w, r) } +// serveIndexWithBase reads the embedded index.html, injects a +// tag, and rewrites root-relative asset paths so +// everything resolves correctly behind a reverse proxy subpath. +func (s *Server) serveIndexWithBase( + w http.ResponseWriter, _ *http.Request, +) { + f, err := s.spaFS.Open("index.html") + if err != nil { + http.Error(w, "index.html not found", + http.StatusInternalServerError) + return + } + defer f.Close() + data, err := io.ReadAll(f) + if err != nil { + http.Error(w, "reading index.html", + http.StatusInternalServerError) + return + } + html := string(data) + + // Rewrite root-relative asset paths (href="/...", src="/...") + // to include the base path prefix so the browser fetches + // assets through the reverse proxy. + bp := s.basePath + html = strings.ReplaceAll(html, `href="/`, `href="`+bp+`/`) + html = strings.ReplaceAll(html, `src="/`, `src="`+bp+`/`) + + // Inject AFTER rewriting paths so it doesn't + // get double-prefixed by the replacement above. 
+ baseTag := fmt.Sprintf( + ``, bp, + ) + html = strings.Replace( + html, "", "\n "+baseTag, 1, + ) + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = w.Write([]byte(html)) +} + // SetPort updates the listen port (for testing). func (s *Server) SetPort(port int) { s.mu.Lock() @@ -282,14 +350,15 @@ func (s *Server) Handler() http.Handler { s.cfg.Host, s.cfg.Port, s.cfg.PublicOrigins, ) allowedHosts := buildAllowedHosts( - s.cfg.Host, s.cfg.Port, s.cfg.PublicURL, + s.cfg.Host, s.cfg.Port, + s.cfg.PublicURL, s.cfg.PublicOrigins, ) bindAll := isBindAll(s.cfg.Host) bindAllIPs := map[string]bool(nil) if bindAll { bindAllIPs = localInterfaceIPs() } - return s.authMiddleware( + h := s.authMiddleware( hostCheckMiddleware( allowedHosts, bindAll, s.cfg.Port, bindAllIPs, corsMiddleware( @@ -297,6 +366,30 @@ func (s *Server) Handler() http.Handler { ), ), ) + if s.basePath != "" { + inner := h + prefix := s.basePath + h = http.HandlerFunc(func( + w http.ResponseWriter, r *http.Request, + ) { + p := r.URL.Path + // Redirect /basepath to /basepath/ for the SPA. + if p == prefix { + http.Redirect(w, r, + prefix+"/", http.StatusMovedPermanently) + return + } + // Only match full path-segment prefixes to + // prevent /basepathFOO from being handled. + if !strings.HasPrefix(p, prefix+"/") { + http.NotFound(w, r) + return + } + http.StripPrefix(prefix, inner). + ServeHTTP(w, r) + }) + } + return h } // buildAllowedHosts returns the set of Host header values that @@ -304,7 +397,10 @@ func (s *Server) Handler() http.Handler { // rebinding attacks where an attacker's domain resolves to // 127.0.0.1 — the browser sends the attacker's domain as the // Host header, which we reject. 
-func buildAllowedHosts(host string, port int, publicURL string) map[string]bool { +func buildAllowedHosts( + host string, port int, + publicURL string, publicOrigins []string, +) map[string]bool { hosts := make(map[string]bool) add := func(h string) { hosts[net.JoinHostPort(h, strconv.Itoa(port))] = true @@ -335,6 +431,9 @@ func buildAllowedHosts(host string, port int, publicURL string) map[string]bool if publicURL != "" { addHostHeadersFromOrigin(hosts, publicURL) } + for _, origin := range publicOrigins { + addHostHeadersFromOrigin(hosts, origin) + } return hosts } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 63489f39..360ea978 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1462,17 +1462,43 @@ func TestHostHeaderAllowsConfiguredPublicOriginHost(t *testing.T) { assertStatus(t, w, http.StatusOK) } -func TestHostHeaderPublicOriginsDoNotExpandTrustedHosts(t *testing.T) { +func TestHostHeaderPublicOriginsExpandTrustedHosts(t *testing.T) { te := setup(t, withPublicOrigins("http://viewer.example.test:8004")) req := httptest.NewRequest(http.MethodGet, "/api/v1/stats", nil) req.Host = "viewer.example.test:8004" - // Use loopback RemoteAddr so authMiddleware passes through and - // the 403 comes from hostCheckMiddleware, not auth. req.RemoteAddr = "127.0.0.1:1234" w := httptest.NewRecorder() te.srv.Handler().ServeHTTP(w, req) - assertStatus(t, w, http.StatusForbidden) + // public_origins should expand the host allowlist so + // reverse proxies forwarding the origin's Host are allowed. + assertStatus(t, w, http.StatusOK) +} + +func TestHostHeaderHTTPSPublicOriginExpandsTrustedHosts( + t *testing.T, +) { + te := setup(t, withPublicOrigins( + "https://viewer.example.test", + )) + + // Browsers omit :443 for HTTPS, so test the bare hostname + // that a reverse proxy would forward. 
+ for _, host := range []string{ + "viewer.example.test", + "viewer.example.test:443", + } { + t.Run(host, func(t *testing.T) { + req := httptest.NewRequest( + http.MethodGet, "/api/v1/stats", nil, + ) + req.Host = host + req.RemoteAddr = "127.0.0.1:1234" + w := httptest.NewRecorder() + te.srv.Handler().ServeHTTP(w, req) + assertStatus(t, w, http.StatusOK) + }) + } } func TestCORSAllowsConfiguredHTTPSPublicOrigin(t *testing.T) { diff --git a/internal/server/session_mgmt.go b/internal/server/session_mgmt.go index 41c1dfaf..df80d8c3 100644 --- a/internal/server/session_mgmt.go +++ b/internal/server/session_mgmt.go @@ -43,6 +43,9 @@ func (s *Server) handleRenameSession( } if err := s.db.RenameSession(id, req.DisplayName); err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("rename session: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return @@ -89,6 +92,9 @@ func (s *Server) handleDeleteSession( } if err := s.db.SoftDeleteSession(id); err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("soft delete session: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return @@ -106,6 +112,9 @@ func (s *Server) handleRestoreSession( n, err := s.db.RestoreSession(id) if err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("restore session: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return @@ -132,6 +141,9 @@ func (s *Server) handlePermanentDeleteSession( // performing the delete. 
n, err := s.db.DeleteSessionIfTrashed(id) if err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("permanent delete session: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return @@ -169,6 +181,9 @@ func (s *Server) handleEmptyTrash( ) { count, err := s.db.EmptyTrash() if err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("empty trash: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/settings.go b/internal/server/settings.go index df424f2d..43d9de22 100644 --- a/internal/server/settings.go +++ b/internal/server/settings.go @@ -79,6 +79,12 @@ type settingsUpdateRequest struct { func (s *Server) handleUpdateSettings( w http.ResponseWriter, r *http.Request, ) { + if s.db.ReadOnly() { + writeError(w, http.StatusNotImplemented, + "settings cannot be modified in read-only mode") + return + } + var req settingsUpdateRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeError(w, http.StatusBadRequest, "invalid JSON") diff --git a/internal/server/starred.go b/internal/server/starred.go index 7eeeeba7..2622928b 100644 --- a/internal/server/starred.go +++ b/internal/server/starred.go @@ -19,6 +19,9 @@ func (s *Server) handleStarSession( // the TOCTOU race of a separate GetSession + INSERT. 
ok, err := s.db.StarSession(id) if err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("star session: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return @@ -40,6 +43,9 @@ func (s *Server) handleUnstarSession( } if err := s.db.UnstarSession(id); err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("unstar session: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return @@ -82,6 +88,9 @@ func (s *Server) handleBulkStar( return } if err := s.db.BulkStarSessions(body.SessionIDs); err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("bulk star: %v", err) writeError(w, http.StatusInternalServerError, "internal error") return diff --git a/internal/server/statefile_test.go b/internal/server/statefile_test.go index f3f9af09..43f93abe 100644 --- a/internal/server/statefile_test.go +++ b/internal/server/statefile_test.go @@ -218,6 +218,7 @@ func TestFindRunningServer_LiveProcess(t *testing.T) { result := FindRunningServer(dir) if result == nil { t.Fatal("expected running server, got nil") + return } if result.Port != port { t.Errorf("port = %d, want %d", result.Port, port) @@ -260,6 +261,7 @@ func TestFindRunningServer_BindAll(t *testing.T) { t.Fatal( "expected running server for 0.0.0.0 host, got nil", ) + return } if result.Port != port { t.Errorf("port = %d, want %d", result.Port, port) diff --git a/internal/server/upload.go b/internal/server/upload.go index 5f95c227..7ed5f322 100644 --- a/internal/server/upload.go +++ b/internal/server/upload.go @@ -164,6 +164,12 @@ func (s *Server) saveSessionToDB( func (s *Server) handleUploadSession( w http.ResponseWriter, r *http.Request, ) { + if s.db.ReadOnly() { + writeError(w, http.StatusNotImplemented, + "uploads are not available in read-only mode") + return + } + req, errMsg := parseUploadRequest(r) if errMsg != "" { writeError(w, http.StatusBadRequest, errMsg) @@ -199,6 +205,9 @@ func (s *Server) handleUploadSession( for _, pr := 
range results { if err := s.saveSessionToDB(pr.Session, pr.Messages); err != nil { + if handleReadOnly(w, err) { + return + } log.Printf("Error saving session to DB: %v", err) writeError(w, http.StatusInternalServerError, "failed to save session to database")